In [42]:
from ipywidgets import HBox,VBox,widgets,Button,Checkbox,Dropdown,Layout,Box,Output,Label,FileUpload
import cv2 as cv
import numpy as np
from PIL import Image
import io

In [43]:
def softmax(x, axis=None):
    """Numerically stable softmax.

    Parameters
    ----------
    x : array_like
        Logits; converted with ``np.asarray`` so lists are accepted too.
    axis : int or None, optional
        Axis to normalise over. ``None`` (the default) normalises over all
        elements, so the whole result sums to 1.

    Returns
    -------
    numpy.ndarray
        Same shape as ``x``; non-negative values summing to 1 along ``axis``.
    """
    x = np.asarray(x)
    # Subtracting the max does not change the result but prevents exp() overflow.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)  # computed once instead of twice
    return e / e.sum(axis=axis, keepdims=True)

In [44]:
# Probability a class must exceed (after softmax in predict) to count as a match.
thresh = 0.5
# (mean, std) per channel, used by predict to normalise pixels scaled to [0, 1].
imagenet_stats  = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Label names; predict reports classes[i] for the i-th model output.
# The first 12 (capitalised) are cat breeds, the remaining 25 dog breeds.
classes = [
    'Abyssinian', 'Bengal', 'Birman', 'Bombay', 'British_Shorthair',
    'Egyptian_Mau', 'Maine_Coon', 'Persian', 'Ragdoll', 'Russian_Blue',
    'Siamese', 'Sphynx',
    'american_bulldog', 'american_pit_bull_terrier', 'basset_hound',
    'beagle', 'boxer', 'chihuahua', 'english_cocker_spaniel',
    'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese',
    'japanese_chin', 'keeshond', 'leonberger', 'miniature_pinscher',
    'newfoundland', 'pomeranian', 'pug', 'saint_bernard', 'samoyed',
    'scottish_terrier', 'shiba_inu', 'staffordshire_bull_terrier',
    'wheaten_terrier', 'yorkshire_terrier',
]
# ONNX model file, loaded through OpenCV's DNN classification wrapper.
onnx_name = 'resnet34_bce_ex_29_15.onnx'
net = cv.dnn_ClassificationModel(onnx_name)
net.setInputSize(224, 224);  # trailing ';' hides the return value's repr in the notebook

In [45]:
# UI pieces: the file picker, an output area for the thumbnail, and a text label
# for the prediction (created in this order: upload, output, label).
btn_upload, out_pl, lbl_pred = widgets.FileUpload(), widgets.Output(), widgets.Label()

In [46]:
def predict(img):
    """Classify one pet image and return a human-readable description.

    Parameters
    ----------
    img : PIL.Image.Image or numpy.ndarray
        The image to classify. Anything with a ``convert`` method (a PIL
        image) is converted to RGB first, so grayscale ('L'), palette ('P'),
        'LA', RGBA and CMYK uploads are handled. A 2-D array is treated as
        grayscale and replicated into three channels.

    Returns
    -------
    str
        ``'Class: <name>, Confidence <LOW|MEDIUM|HIGH>, Probability: <p>'``
        when exactly one class scores above ``thresh``, otherwise a message
        saying the image does not match one of the known pets.

    Notes
    -----
    Uses the module-level ``net``, ``softmax``, ``classes``, ``thresh`` and
    ``imagenet_stats``.
    """
    # np.array() of an 'L'/'P' image is 2-D (the [:, :, :3] slice would fail)
    # and 'LA'/'CMYK' would yield the wrong channels, so normalise to RGB.
    if hasattr(img, 'convert'):
        img = img.convert('RGB')
    pixels = np.asarray(img)
    if pixels.ndim == 2:
        pixels = np.stack([pixels] * 3, axis=-1)
    mean, std = imagenet_stats
    # Scale 8-bit pixels to [0, 1], then apply per-channel mean/std normalisation.
    img_normalized = (((pixels[:, :, :3] / 255) - mean) / std).astype('float32')
    # [0] takes the first element of net.predict's result (presumably the first
    # output blob); the [0] after softmax takes its first row, i.e. this assumes
    # a batch of one image.
    act_opencv = net.predict(img_normalized)[0]
    # NOTE(review): the model file name contains "bce", which usually means the
    # model was trained with per-class sigmoid outputs. Softmax sums to 1, so
    # with thresh >= 0.5 at most one class can pass. Confirm softmax is intended.
    preds = softmax(act_opencv)[0]
    n_preds = (preds > thresh).sum()
    if n_preds != 1:
        return f'At the given threshold: {thresh}, the supplied image does not match one of our pets!'
    # Exactly one class passed the threshold, so it is also the argmax.
    i_pred = preds.argmax()
    pred = preds[i_pred]
    if pred < 0.75:
        conf_text = 'LOW'
    elif pred < 0.95:
        conf_text = 'MEDIUM'
    else:
        conf_text = 'HIGH'
    return f'Class: {classes[i_pred]}, Confidence {conf_text}, Probability: {pred:.02f}'

In [47]:
def on_click(change):
    """Handle a new upload: show a thumbnail and the model's prediction.

    Registered as an observer of ``btn_upload``'s ``data`` trait. The
    ``change`` argument is not used; the latest upload is read from
    ``btn_upload.data`` directly.

    Returns
    -------
    bytes or None
        The raw bytes of the processed upload, or ``None`` when
        ``btn_upload.data`` is empty.
    """
    uploads = btn_upload.data
    # If the trait is ever reset to an empty list, [-1] would raise IndexError
    # inside the callback; there is simply nothing to show.
    if not uploads:
        return None
    raw = uploads[-1]
    out_pl.clear_output()
    try:
        img = Image.open(io.BytesIO(raw))
        lbl = predict(img)
    except OSError:
        # PIL signals undecodable/truncated files with OSError (including
        # UnidentifiedImageError); tell the user instead of failing silently.
        lbl_pred.value = 'Could not read the uploaded file as an image.'
        return raw
    # Predict on the full-size image first; thumbnail() shrinks it in place.
    img.thumbnail((256, 256))
    with out_pl:
        display(img)
    lbl_pred.value = lbl
    return raw

In [48]:
# Re-running this cell would otherwise attach a second copy of the handler, and
# every upload would then be processed twice. Detach any earlier registration of
# this same function object first. (Redefining on_click creates a new object,
# which this cannot detect.)
try:
    btn_upload.unobserve(on_click, names=['data'])
except (KeyError, ValueError):
    pass  # on_click was not registered yet
btn_upload.observe(on_click, names=['data'])
display(VBox([widgets.Label('Select your cat/dog!'), btn_upload, out_pl, lbl_pred]))

VBox(children=(Label(value='Select your cat/dog!'), FileUpload(value={}, description='Upload'), Output(), Labe…