In [29]:
import ipywidgets as widgets
from ipywidgets import *
from PIL import Image, ImageOps
from fastai.basic_train import load_learner
import torch
import torch.nn as nn
from fastai import *
from fastai.vision import open_image
import glob
import pandas as pd

In [30]:
class FeatureLoss(nn.Module):
    """Perceptual loss: pixel loss + weighted feature losses + weighted Gram-matrix losses.

    Activations of the layers ``m_feat[i]`` (``i`` in ``layer_ids``) are captured
    with hooks and compared between the prediction (``input``) and ``target``.

    NOTE(review): ``hook_outputs``, ``base_loss`` and ``gram_matrix`` are not
    defined or imported anywhere in this notebook. This class is presumably only
    here so that ``load_learner`` can unpickle an exported learner whose loss
    was a ``FeatureLoss``. Confirm before calling ``__init__`` or ``forward``
    directly.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        # Layers of m_feat whose outputs are compared between input and target.
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: stored activations are presumably kept attached to the
        # autograd graph so the feature terms contribute gradients.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        # One weight per hooked layer; zipped with the features in forward().
        self.wgts = layer_wgts
        n_layers = len(layer_ids)
        # One name per entry of feat_losses, in the order forward() builds them.
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(n_layers)]
                             + [f'gram_{i}' for i in range(n_layers)])

    def make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the activations stored by the hooks.

        Args:
            x: batch fed to ``m_feat``.
            clone: if True, return ``.clone()`` copies of the stored activations
                (used for the target pass so they stay independent of the
                later input pass).
        Returns:
            List with one activation per hooked layer.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss; the individual terms are kept in ``self.metrics``.

        Terms, in order:
        - ``base_loss(input, target)``;
        - ``base_loss`` on each hooked layer's features, times its weight;
        - ``base_loss`` on each layer's Gram matrices, times ``w**2 * 5e3``.
        """
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # Guard: if __init__ failed before `hooks` was assigned (e.g. hook_outputs
        # is undefined), there is nothing to remove, and raising here would only
        # produce an "Exception ignored in __del__" warning.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

In [31]:

        
def minfits(tl, margin=20):
    """Grow a box's top-left corner outward by ``margin`` pixels, clamped at 0.

    Args:
        tl: (x, y) top-left corner of a bounding box.
        margin: number of pixels to extend the box left and up (default 20).
    Returns:
        (left, top) with each coordinate equal to ``tl[i] - margin``, or 0 if
        that would be negative.
    """
    left = tl[0] - margin if tl[0] - margin >= 0 else 0
    top = tl[1] - margin if tl[1] - margin >= 0 else 0
    return left, top

def maxfits(br, margin=20, limit=(1024, 1024)):
    """Grow a box's bottom-right corner outward by ``margin`` pixels, clamped at ``limit``.

    Args:
        br: (x, y) bottom-right corner of a bounding box.
        margin: number of pixels to extend the box right and down (default 20).
        limit: (max_x, max_y) image bound; defaults to the previously
            hard-coded 1024x1024.
    Returns:
        (right, bottom) with each coordinate equal to ``br[i] + margin``, or
        ``limit[i]`` if that would exceed it.
    """
    right = br[0] + margin if br[0] + margin <= limit[0] else limit[0]
    bottom = br[1] + margin if br[1] + margin <= limit[1] else limit[1]
    return right, bottom

def pad(img, expected_size):
    """Resize/letterbox a crop to the model's input size.

    For crops narrower than ``expected_size``, both sides grow by the same number
    of pixels so the longer side becomes ``expected_size``. ``ImageOps.pad``
    then pads the result to an ``expected_size`` x ``expected_size`` square.
    Square crops are simply resized to that square.

    For crops at least ``expected_size`` wide, only the width is resized to
    ``expected_size`` (bilinear) and the height is kept.

    Args:
        img: PIL image (a crop produced by ``extract``).
        expected_size: target side length in pixels.
    Returns:
        The resized/padded PIL image.
    """
    width, height = img.size
    if width >= expected_size:
        # NOTE(review): this branch does not pad to a square. remove_pad's
        # fallback branch strips a vertical border from the model output, so
        # this presumably relies on the model/transforms producing a square
        # output. Confirm.
        return img.resize((expected_size, height), Image.BILINEAR)
    # Grow both sides by the same amount so the longer side hits expected_size.
    # For square crops this yields exactly (expected_size, expected_size).
    # Previously squares were returned unresized.
    diff = expected_size - max(width, height)
    img = img.resize((width + diff, height + diff))
    return ImageOps.pad(img, (expected_size, expected_size), Image.BILINEAR, None)

def extract(img, tl, br):
    """Cut the (margin-grown) box ``tl``-``br`` out of ``img`` and pad it for the model.

    Args:
        img: full PIL image.
        tl: (xmin, ymin) corner of the bounding box.
        br: (xmax, ymax) corner of the bounding box.
    Returns:
        The cropped region after ``pad(..., 128)``.
    """
    # (left, top) + (right, bottom) -> PIL crop box.
    crop_box = minfits(tl) + maxfits(br)
    return pad(img.crop(crop_box), 128)
        
def remove_pad(img, original_w, original_h, expected_size=128):
    """Undo ``pad``: strip the letterbox border from a model output and restore the crop size.

    Args:
        img: model output, presumably ``expected_size`` x ``expected_size``.
        original_w: width of the crop before ``pad`` was applied.
        original_h: height of the crop before ``pad`` was applied.
        expected_size: side length used by ``pad`` (128 in ``extract``).
    Returns:
        Image of size (original_w, original_h), resampled with LANCZOS.
    """
    import math  # not imported at notebook level

    width = img.size[0]
    if width > original_w:
        if original_w > original_h:
            # pad() grew both sides by (expected_size - original_w), giving a
            # content strip of this height. ImageOps.pad (default centring)
            # then placed it vertically centred.
            content_h = original_h + (expected_size - original_w)
            # round() of half the spare space mirrors how recent Pillow
            # versions compute the ImageOps.pad offset.
            top = round((expected_size - content_h) / 2)
            img = img.crop((0, top, expected_size, top + content_h))
        elif original_h > original_w:
            # Same with axes swapped. Only HALF of the spare width lies on
            # each side (the old code removed the whole spare width per side).
            content_w = original_w + (expected_size - original_h)
            left = round((expected_size - content_w) / 2)
            img = img.crop((left, 0, left + content_w, expected_size))
        # Square crops get no border from ImageOps.pad (equal aspect ratio),
        # so only the final resize is needed.
    else:
        # NOTE(review): assumes the output carries an equal vertical border of
        # ceil((expected_size - original_h) / 2) pixels top and bottom. pad()
        # does not add one for crops >= expected_size wide. Confirm against
        # the model's output size.
        border = math.ceil((expected_size - original_h) / 2)
        img = img.crop((0, border, expected_size, expected_size - border))
    return img.resize((original_w, original_h), Image.LANCZOS)

def insert(img1, img2, tl, br):
    """Paste one model prediction back into the full image at its margin-grown box.

    Args:
        img1: path (or file object) of the prediction image for this box.
        img2: full-size PIL image; modified in place by ``paste``.
        tl: (xmin, ymin) corner of the bounding box.
        br: (xmax, ymax) corner of the bounding box.
    Returns:
        ``img2`` with the prediction pasted in.
    """
    left, top = minfits(tl)
    right, bottom = maxfits(br)
    # Close the prediction file as soon as its pixels have been read.
    # remove_pad ends with resize(), which returns a new in-memory image
    # independent of the file.
    with Image.open(img1) as prediction:
        restored = remove_pad(prediction, right - left, bottom - top)
    # 4-tuple box: restored is exactly (right-left) x (bottom-top), as paste requires.
    img2.paste(restored, (left, top, right, bottom))
    return img2



def run_insert(box, image, boxes):
    """Paste every prediction in ``box`` back into ``image`` at its bounding box.

    Args:
        box: sequence of prediction image paths; ``box[k]`` belongs to the
            k-th row of ``boxes`` (by position).
        image: PIL image; each ``insert`` pastes into it in place.
        boxes: DataFrame with ``xmin``, ``ymin``, ``xmax``, ``ymax`` columns.
    Returns:
        ``image`` after all predictions have been pasted.
    """
    for position, (_, row) in enumerate(boxes.iterrows()):
        # Index by row position, not by the DataFrame's index label.
        # run_extract numbers its crops by position, so a non-default index
        # would otherwise pair files with the wrong boxes.
        image = insert(box[position], image, (row.xmin, row.ymin), (row.xmax, row.ymax))
    return image
def run_extract(image, boxes, out_dir='demo'):
    """Save one padded crop per row of ``boxes`` as ``{out_dir}/{n}.png``.

    ``n`` is the row's position (0, 1, 2, ...), not its DataFrame index label.

    Args:
        image: full PIL image to crop from.
        boxes: DataFrame with ``xmin``, ``ymin``, ``xmax``, ``ymax`` columns.
        out_dir: directory for the crops (created if missing); default ``'demo'``.
    """
    import os

    # Image.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    for position, (_, row) in enumerate(boxes.iterrows()):
        crop = extract(image, (row.xmin, row.ymin), (row.xmax, row.ymax))
        crop.save(f'{out_dir}/{position}.png')


In [32]:
from pathlib import Path

# Exported text-removal learner, loaded from the current working directory
# (presumably load_learner's default file name, export.pkl).
path = Path()
remover = load_learner(path)
# One row per text region; read by run_extract / run_insert.
boxes = pd.read_csv('boundingboxes.csv')

# UI widgets: upload button + preview, generate button + result.
btn_upload = widgets.FileUpload()
in_pl = widgets.Output()
btn_run = widgets.Button(description="generate")
out_pl = widgets.Output()

# Module-level placeholders. on_click_generate assigns locals with the same
# names, so these two lists are never filled.
imglist = []
predlist = []

In [33]:
def on_click_generate(change):
    """Button callback: remove text inside every bounding box of the upload and show the result.

    Pipeline:
    1. ``run_extract`` saves one crop per row of ``boxes`` as ``demo/N.png``.
    2. ``remover.predict`` turns each crop into ``demo/predictions/N.png``.
    3. ``run_insert`` pastes the predictions back into the uploaded image.

    Args:
        change: the clicked Button (passed by ``Button.on_click``); unused.
    """
    import io
    from pathlib import Path

    def _by_number(path):
        # Crops are named 0.png, 1.png, ..., 10.png. A plain string sort puts
        # 10.png before 2.png and would pair predictions with the wrong boxes.
        stem = Path(path).stem
        return (0, int(stem), '') if stem.isdigit() else (1, 0, stem)

    uploads = btn_upload.value
    if not uploads:
        # Nothing to process, and there would be no `final` image to display.
        out_pl.clear_output()
        with out_pl:
            print('Please upload an image first.')
        return

    pred_dir = Path('demo') / 'predictions'
    pred_dir.mkdir(parents=True, exist_ok=True)

    final = None
    for name, file_info in uploads.items():
        img = Image.open(io.BytesIO(file_info['content']))
        run_extract(img, boxes)
        imglist = sorted(glob.glob("demo/*.png"), key=_by_number)
        for pic in imglist:
            pred, _, _ = remover.predict(open_image(pic))
            pred.save(str(pred_dir / Path(pic).name))
        predlist = sorted(glob.glob("demo/predictions/*.png"), key=_by_number)
        final = run_insert(predlist, img, boxes)
    out_pl.clear_output()
    with out_pl:
        display(final)
        
def on_upload(change):
    """FileUpload observer: preview the last uploaded image at 256x256 in ``in_pl``.

    Args:
        change: traitlets change notification for ``btn_upload.value``;
            unused, since the current value is read from the widget.
    """
    import io

    uploads = btn_upload.value
    if not uploads:
        # e.g. the value was reset; there is nothing to preview.
        return
    # Only the most recently listed file is previewed.
    file_info = list(uploads.values())[-1]
    img = open_image(io.BytesIO(file_info['content']))
    in_pl.clear_output()
    with in_pl:
        display(img.resize(torch.Size([img.shape[0], 256, 256])))


# React only to changes of `value`. Observing all traits (presumably also
# FileUpload internals such as metadata/data counters) re-runs the preview
# several times per upload and can fire while `value` is still empty.
btn_upload.observe(on_upload, names='value')
btn_run.on_click(on_click_generate)

In [34]:
# Assemble the demo UI: prompt, upload + preview, generate button + result.
display(widgets.VBox([
    widgets.Label('Upload a medical image with text!'),
    btn_upload,
    in_pl,
    btn_run,
    out_pl,
]))

VBox(children=(Label(value='Upload a medical image with text!'), FileUpload(value={}, description='Upload'), O…

In [None]:
!pip install voila
!jupyter serverextension enable voila --sys-prefix