In [2]:
import io
import pickle
import random

import ipywidgets as widgets
import numpy as np
import PIL
import PIL.Image
import PIL.ImageDraw
import PIL.ImageEnhance
import PIL.ImageFont
import torch
import torchvision.transforms as transforms

###### Loading classes dictionaries ########
# ImageNet class names, looked up by the classifier's predicted class index.
with open('imagenet_classes.pkl', 'rb') as f:
    imagenet = pickle.loads(f.read())
# Outline/caption colours picked at random for each detected object.
colors = [
    'blue', 'red', 'yellow', 'orange', 'green', 'purple', 'pink',
    'magenta', 'limegreen', 'olive', 'teal', 'violet', 'lawngreen',
]

# COCO category names indexed by the detector's label id ('N/A' = unused id).
coco_labels = [
    # 0-9
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]



# Raw PNG bytes of the placeholder image shown before anything is uploaded.
# `with` closes the file handle once the bytes are read (it was leaked before).
with open("image_select.png", "rb") as file:
    image = file.read()

###### Initialization of widgets ########

# Single-file picker restricted to image MIME types.
uploader = widgets.FileUpload(
    description='Select an image',
    accept='image/*',
    multiple=False,
)

# Main display area; starts out showing the placeholder PNG bytes.
viewer = widgets.Image(value=image)

def slider_creation(name):
    """Build a horizontal FloatSlider labelled *name*.

    The slider spans 0 to 2 in steps of 0.1 and starts at 1. It shows its value
    with one decimal place. Because continuous_update is False, it reports a new
    value only when the user releases the handle.
    """
    settings = {
        'description': name,
        'value': 1,
        'min': 0,
        'max': 2,
        'step': 0.1,
        'disabled': False,
        'continuous_update': False,
        'orientation': 'horizontal',
        'readout': True,
        'readout_format': '.1f',
    }
    return widgets.FloatSlider(**settings)

# Enhancement sliders, in the order Brightness, Color, Contrast.
sliders = list(map(slider_creation, ['Brightness', 'Color', 'Contrast']))

# Action buttons.
obj_det = widgets.Button(description="Object detection")
img_class = widgets.Button(description="Image classification")
semantic_segm = widgets.Button(description="Segmentation")
reset = widgets.Button(description="Discard changes")

###### Helper functions ########


    
    
# ImageNet-style preprocessing for a PIL image:
# shorter side resized to 256, then a 224x224 centre crop, then conversion to a
# float tensor, then per-channel mean/std normalisation.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

def buf2viewer(img):
    """Show a PIL image in the viewer widget.

    The image is encoded as PNG into an in-memory buffer. The resulting bytes
    become the viewer's new value.
    """
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()
    viewer.value = png_bytes

def viewer2torch():
    """Return the viewer's current image as a batched tensor.

    The image is decoded from the viewer bytes and converted to RGB. It then goes
    through ``to_tensor`` and gains a leading batch axis of size 1. No mean/std
    normalisation is applied.
    """
    rgb_image = PIL.Image.open(io.BytesIO(viewer.value)).convert('RGB')
    return transforms.functional.to_tensor(rgb_image).unsqueeze(0)






###### GUI interaction functions ########
def image_open(change):
    """Observer for ``uploader.value``: show the first uploaded file in the viewer.

    NOTE(review): this assumes the ipywidgets 7.x FileUpload value layout, a
    dict mapping each file name to a dict with a 'content' entry. Confirm this
    against the installed ipywidgets version.
    """
    uploads = change.new
    first_name = list(uploads.keys())[0]
    viewer.value = uploads[first_name]['content']

def function_generator(enhance_function):
    """Return a slider ``observe`` callback bound to *enhance_function*.

    ``observe`` hands its callback only the change record. The enhancer is
    therefore captured in a closure instead of being passed as an argument.

    Args:
        enhance_function: callable taking a PIL image and returning an object
            with an ``enhance(factor)`` method (e.g. ``PIL.ImageEnhance.Brightness``).

    Returns:
        A handler that enhances the image currently shown in the viewer by the
        slider's new value and displays the result. The factor is applied to
        whatever the viewer shows at that moment, so successive slider moves
        compound rather than being relative to the original upload.
    """
    def slider_chng(change):
        current = PIL.Image.open(io.BytesIO(viewer.value))
        enhanced = enhance_function(current).enhance(change.new)
        buf2viewer(enhanced)

    return slider_chng

def on_obj_clicked(b):
    """Run object detection on the viewer image and draw the detections.

    1. Loads the pickled detector from 'fasterrcnn_mobilenet_v3_large_fpn.pt'.
    2. Pads images shorter than 300 px with white space on top so captions have
       room.
    3. Keeps detections scoring above 0.2.
    4. Draws a randomly coloured box and a "label: score%" caption for each
       kept detection.
    5. Shows the result in the viewer.

    Args:
        b: the clicked Button (unused; required by ``Button.on_click``).
    """
    model = torch.load('fasterrcnn_mobilenet_v3_large_fpn.pt')
    model.eval()
    image = PIL.Image.open(io.BytesIO(viewer.value)).convert('RGB')

    def add_margin(pil_img, top, right, bottom, left, color):
        # Expand the canvas by the given margins, filled with `color`.
        # source: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/
        width, height = pil_img.size
        result = PIL.Image.new(
            pil_img.mode, (width + right + left, height + top + bottom), color
        )
        result.paste(pil_img, (left, top))
        return result

    if image.size[1] < 300:
        image = add_margin(image, 300 - image.size[1], 0, 0, 0, 'white')
    buf2viewer(image)

    # The detector is fed the plain [0, 1] tensor from viewer2torch(). This assumes
    # a torchvision detection model, which normalises its input internally.
    with torch.no_grad():  # inference only: skip building the autograd graph
        pred = model(viewer2torch())[0]

    keep = pred['scores'] > 0.2
    scores = pred['scores'][keep].tolist()
    labels = [coco_labels[i] for i in pred['labels'][keep].tolist()]
    boxes = pred['boxes'][keep].tolist()

    # Load the caption font once, not once per box. calibri.ttf may be missing
    # (e.g. outside Windows), so fall back to Pillow's built-in font.
    try:
        font = PIL.ImageFont.truetype(font='calibri.ttf', size=25)
    except OSError:
        font = PIL.ImageFont.load_default()

    # `image` is exactly what was just written to the viewer, so draw on it directly.
    draw = PIL.ImageDraw.Draw(image)
    for box, label, score in zip(boxes, labels, scores):
        color = random.choice(colors)
        draw.rectangle(box, width=6, outline=color)
        caption = label + ': ' + str(round(score * 100, 1)) + '%'
        # Clamp so captions of boxes touching the top edge stay visible.
        draw.text((box[0], max(box[1] - 30, 0)), caption, font=font, fill=color)
    buf2viewer(image)
    
def on_reset_clicked(b):
    """Discard edits by showing the uploaded file again.

    If nothing has been uploaded, the placeholder image is shown instead.
    Slider positions are not touched.

    Args:
        b: the clicked Button (unused; required by ``Button.on_click``).
    """
    uploads = uploader.value
    try:
        first_name = list(uploads.keys())[0]
    except IndexError:  # nothing uploaded yet
        viewer.value = image
    else:
        viewer.value = uploads[first_name]['content']
        
def on_class_clicked(b):
    """Classify the viewer image and overlay the five most likely ImageNet classes.

    1. Loads the pickled classifier from 'mobilenet_v3_large.pt'.
    2. Writes one "class: probability%" line per prediction onto 'blank.png'.
    3. Crops that panel to 350x200 and pastes it over the top-left corner of the
       displayed image.

    Args:
        b: the clicked Button (unused; required by ``Button.on_click``).
    """
    model = torch.load('mobilenet_v3_large.pt')
    model.eval()
    image = PIL.Image.open(io.BytesIO(viewer.value)).convert('RGB')

    # Use the module-level `transform` (resize, centre crop, ImageNet mean/std
    # normalisation) rather than viewer2torch(). viewer2torch() returns an
    # un-normalised [0, 1] tensor, which the pretrained classifier does not expect.
    with torch.no_grad():  # inference only: skip building the autograd graph
        pred = model(transform(image).unsqueeze(0))
    percentage = torch.nn.functional.softmax(pred, dim=1)[0] * 100
    _, indices = torch.sort(percentage, descending=True)
    # Use plain ints rather than 0-dim tensors. Tensors hash by identity, so a
    # tensor index would break lookups if `imagenet` is an int-keyed dict.
    top5 = indices[:5].tolist()
    features = [
        imagenet[idx] + ": " + str(round(percentage[idx].item(), 1)) + "%"
        for idx in top5
    ]

    # Load the font once. calibri.ttf may be missing (e.g. outside Windows),
    # so fall back to Pillow's built-in font.
    try:
        font = PIL.ImageFont.truetype(font='calibri.ttf', size=30)
    except OSError:
        font = PIL.ImageFont.load_default()

    with PIL.Image.open('blank.png') as blank_file:
        blank = blank_file.copy()  # load the pixels so the file can be closed
    draw = PIL.ImageDraw.Draw(blank)
    for row, text in enumerate(features):
        draw.text((30, 30 + row * 30), text, font=font, fill='black')
    panel = blank.crop((0, 0, 350, 200))

    # Enlarge (never shrink) a small image so the whole 350x200 panel fits on it.
    width, height = image.size
    if width < 350 or height < 200:
        image = image.resize((max(width, 350), max(height, 200)))
    image.paste(panel)
    buf2viewer(image)

def on_segm_clicked(b):
    """Run semantic segmentation and show the colour-coded class map.

    Loads the pickled model from 'segm_deeplabv3_mobilenet_v3.pt' and takes the
    per-pixel argmax of its 'out' scores. Each class id is painted with a fixed
    colour from a 21-entry palette, and the map replaces the viewer image.

    Args:
        b: the clicked Button (unused; required by ``Button.on_click``).
    """
    model = torch.load('segm_deeplabv3_mobilenet_v3.pt')
    model.eval()
    # Apply ImageNet mean/std normalisation (same constants as `transform`),
    # which the pretrained segmentation backbone expects. Full resolution is
    # kept, so the mask matches the displayed image.
    image = transforms.functional.normalize(
        viewer2torch(), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    with torch.no_grad():  # inference only: skip building the autograd graph
        pred = model(image)

    def decode_segmap(label_map, nc=21):
        # Map a 2-D array of class ids to an (H, W, 3) uint8 RGB image.
        # Ids outside [0, nc) stay black.
        # Palette from https://learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/
        label_colors = np.array([(0, 0, 0),  # 0=background
                   # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
                   (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
                   # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
                   (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
                   # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
                   (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
                   # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
                   (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)],
                   dtype=np.uint8)
        rgb = np.zeros(label_map.shape + (3,), dtype=np.uint8)
        valid = (label_map >= 0) & (label_map < nc)
        rgb[valid] = label_colors[label_map[valid]]  # vectorised palette lookup
        return rgb

    # Index the single batch element rather than squeeze(), which would also
    # collapse H or W when either equals 1.
    out = torch.argmax(pred['out'][0], dim=0).cpu().numpy()
    out = torch.from_numpy(decode_segmap(out)).permute(2, 0, 1)
    buf2viewer(transforms.ToPILImage()(out))




###### Initialization of GUI ########
  
# Display a newly uploaded file as soon as the FileUpload value changes.
uploader.observe(image_open, names='value')

# Pair each slider with the PIL enhancer it drives (same order as `sliders`).
enhancers = [
    PIL.ImageEnhance.Brightness,
    PIL.ImageEnhance.Color,
    PIL.ImageEnhance.Contrast,
]
for slider, enhancer in zip(sliders, enhancers):
    slider.observe(function_generator(enhancer), names='value')

# Button -> click handler.
for button, handler in (
    (obj_det, on_obj_clicked),
    (img_class, on_class_clicked),
    (reset, on_reset_clicked),
    (semantic_segm, on_segm_clicked),
):
    button.on_click(handler)

# Layout, 2 rows x 5 columns:
#   row 0: uploader, reset and the three model buttons, one per column
#   row 1: the three sliders, spanning 2, 2 and 1 columns
grid = widgets.GridspecLayout(2, 5)
for column, control in enumerate((uploader, reset, obj_det, img_class, semantic_segm)):
    grid[0, column] = control
grid[1, :2] = sliders[0]
grid[1, 2:4] = sliders[1]
grid[1, 4] = sliders[2]
display(viewer, grid)





Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x03\x8f\x00\x00\x02\x04\x08\x02\x00\x00\x006\x0b\x8c…

GridspecLayout(children=(FileUpload(value={}, accept='image/*', description='Select an image', layout=Layout(g…