In [None]:
import os
import glob
import torch
from torch.nn.functional import softmax
from torchvision import transforms
from IPython.display import clear_output

from ipywidgets.widgets import FileUpload, Label, Output, VBox
from PIL import Image
from io import BytesIO

from resnet50 import ResNet50
from utils import make_conv_dict, build_res_layers, get_device

In [None]:
#  Build the ResNet50 classifier with one output per class.
#  Class names are the sub-directory names of poke_imgs, sorted; the index of a
#  name in `classes` is the label index the model predicts (assumes training used
#  the same sorted-directory order, e.g. torchvision ImageFolder -- TODO confirm).
#  Only directories are counted so stray files cannot inflate num_classes, and
#  os.path.basename keeps this working with Windows path separators.
classes = sorted(
    os.path.basename(path)
    for path in glob.glob(os.path.join('poke_imgs', '*'))
    if os.path.isdir(path)
)
if not classes:
    raise FileNotFoundError(
        "No class sub-directories found under 'poke_imgs'; "
        "the class list (and num_classes) cannot be built."
    )
num_classes = len(classes)

#  Residual stages; in_chls is presumably each stage's input channel count
#  (see utils.build_res_layers -- TODO confirm).
layer1 = build_res_layers(1)
layer2 = build_res_layers(2, in_chls=256)
layer3 = build_res_layers(3, in_chls=512)
layer4 = build_res_layers(4, in_chls=1024)

#  pretrained=False: weights are restored from the fine-tuned checkpoint below.
res_50 = ResNet50(num_classes, layer1, layer2, layer3, layer4, pretrained=False, verbose=False)

In [None]:
#  Torch device used for inference, as selected by utils.get_device
device = get_device()

In [None]:
#  Restore the fine-tuned ResNet50 weights and move the model onto `device`
ckpt_path = os.path.join(
    'res50_ckpts',
    'res50_ft_data_eps100_lr0001_ckpt.pt',
)
# NOTE(review): torch.load unpickles the file -- only point this at trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location=device)  # tensors are mapped straight onto `device`
res_50.load_state_dict(ckpt['state_dict'])
res_50.to(device)
#  Wipe this cell's output area
clear_output()

In [None]:
#  Create method to get prediction
def predict(img, model, classes):
    """Classify a PIL image and return the top-1 confidence and class index.

    Args:
        img: PIL image to classify. It is converted to 3-channel RGB first, so
            RGBA / greyscale / palette uploads work with the 3-channel Normalize.
        model: classifier module; switched to eval mode and run on whatever
            device its parameters live on.
        classes: list of class names. Not used in the computation; kept for
            interface compatibility (callers index it with the returned ``pred``).

    Returns:
        (prob, pred): 1-element tensors holding the highest softmax probability
        and the index of that class.
    """
    #  Transform images (224x224, values scaled to [-1, 1] per channel)
    trans_func = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    img_input = trans_func(img.convert('RGB'))

    #  The input must live on the same device as the model's weights,
    #  otherwise a CUDA/MPS model fails on a CPU tensor.
    model_device = next(model.parameters()).device

    #  Compute prediction
    model.eval()
    with torch.no_grad():
        logits = model(img_input.unsqueeze(0).to(model_device))
        # Explicit dim: softmax over the class axis of the (1, num_classes) logits
        prob, pred = torch.max(softmax(logits, dim=1), dim=1)

    return prob, pred

In [None]:
#  Widgets for the web app.
#  accept='image/*' makes the browser's file picker offer image files only;
#  single-file mode (explicit here) means each upload replaces the previous one.
upload_btn = FileUpload(accept='image/*', multiple=False)
pred_lbl = Label()   # prediction text
out_disp = Output()  # thumbnail of the uploaded image

In [None]:
#  Write commands when new image is uploaded
def _latest_upload_bytes(value):
    """Return the bytes of the most recently uploaded file, or None if there is none.

    ``FileUpload.value`` has a different layout depending on the ipywidgets version:
    in 7.x it is a dict mapping filename -> {'metadata': ..., 'content': bytes};
    in 8.x it is a tuple of dicts whose 'content' entry is a memoryview.
    """
    files = list(value.values()) if isinstance(value, dict) else list(value)
    if not files:
        return None
    # bytes() normalizes both a memoryview (8.x) and bytes (7.x)
    return bytes(files[-1]['content'])


def on_upload(change):
    """Classify the newest upload and show it together with its prediction.

    Observer for ``upload_btn``'s ``value`` trait. Renders a thumbnail (at most
    128x128) in ``out_disp`` and writes the predicted class name and its softmax
    probability to ``pred_lbl``. Unreadable files are reported in ``pred_lbl``
    instead of raising inside the widget callback.
    """
    pred_lbl.value = 'Unknown'

    #  Read image; nothing to classify if the upload was cleared
    content = _latest_upload_bytes(change['new'])
    if content is None:
        return
    try:
        img = Image.open(BytesIO(content))
        img.load()  # Image.open is lazy; decode now so corrupt files fail here
    except OSError:
        # PIL raises OSError (incl. UnidentifiedImageError) for non-image / corrupt data
        pred_lbl.value = 'Could not read the uploaded file as an image.'
        return

    #  Compute prediction
    prob, pred_idx = predict(img, res_50, classes)

    #  Display image and prediction
    out_disp.clear_output()
    with out_disp:
        img.thumbnail((128, 128))
        display(img)

    pred_lbl.value = f'Prediction: {classes[pred_idx.item()]}, Probability: {prob.item():.4f}'

#  'value' exists (and is updated once per upload) in both ipywidgets 7 and 8;
#  the 'data' trait was removed in ipywidgets 8.
upload_btn.observe(on_upload, names='value')

In [None]:
#  Render the app: heading, upload button, image preview area, prediction label
display(VBox([
    Label('Upload your Pokemon image!'),
    upload_btn,
    out_disp,
    pred_lbl,
]))