In [None]:
#| include: false
#all_slow

In [None]:
#| include: false
!pip install -Uqq fastai --upgrade
!pip install -Uqq fastcore --upgrade

In [None]:
#| include: false
!pip install voila
!jupyter serverextension enable --sys-prefix voila

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from steel_segmentation.all import *
from fastai.vision.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp

import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive, VBox
from IPython.display import Image

In [None]:
# NOTE(review): these helpers are not called in this notebook; they are presumably
# required so that `load_learner` can resolve them when unpickling the exported
# learner. Keep their names unchanged — confirm against the training notebook.
def opt_func(params, **kwargs):
    """Create a ``torch.optim.Adam`` optimizer over ``params`` wrapped in fastai's ``OptimWrapper``.

    Extra keyword arguments are forwarded to ``OptimWrapper`` unchanged.
    """
    adam_cls = torch.optim.Adam
    return OptimWrapper(params, adam_cls, **kwargs)

def splitter(m):
    """Split model ``m`` into three parameter groups: encoder, decoder, segmentation head."""
    submodules = (m.encoder, m.decoder, m.segmentation_head)
    return convert_params([[sub] for sub in submodules])

In [None]:
# `load_learner` unpickles the export file, and unpickling can execute arbitrary
# code — only load .pkl files that come from a trusted source.
learner = load_learner(models_dir.joinpath("exports", "efficientnet-b2.pkl"))

In [None]:
#| output: false
# Restrict the picker to images: the click handler feeds the upload straight to
# `PILImage.create`, which fails on non-image files. Only the last upload is used,
# so a single-file picker is stated explicitly.
btn_upload = widgets.FileUpload(accept='image/*', multiple=False)
img_widget = widgets.Image()  # NOTE(review): not referenced by any visible cell
lbl_pred = widgets.Label()    # shows the text summary of predicted defects
out_pl = widgets.Output()     # hosts the original / predicted mask plots
btn_run = widgets.Button(description='Classify')

In [None]:
def get_defects(preds) -> str:
    """Summarise which defect types appear in a predicted segmentation mask.

    Args:
        preds: mask tensor indexed as (class, height, width) — the second value
            returned by ``learner.predict`` in this notebook. Assumed to be a binary
            (bool or 0/1) mask with one channel per defect type, where channel ``i``
            is defect type ``i + 1``. TODO confirm against the exported learner.

    Returns:
        A label such as ``"Predicted: n°2 defects of types: 1 3"``, or
        ``"Predicted: n°0 defects"`` when no channel has a positive pixel.
    """
    # Test each channel independently instead of taking argmax over channels:
    # argmax maps an all-zero pixel to index 0, which is indistinguishable from
    # defect type 1, and overlapping defects would hide all but one class.
    present = (preds.float() > 0).flatten(1).any(1)
    types_defects = [str(i + 1) for i, hit in enumerate(present.tolist()) if hit]
    n_defects = len(types_defects)
    defects_word = "defects" if n_defects != 1 else "defect"
    types_word = "types" if n_defects != 1 else "type"
    if n_defects > 0:
        return f"Predicted: n°{n_defects} {defects_word} of {types_word}: {' '.join(types_defects)}"
    return f"Predicted: n°0 {defects_word}"

def segment_img(img):
    """Run the learner on ``img`` and display the original and predicted masks.

    Side effects: clears and redraws ``out_pl`` and sets ``lbl_pred.value`` to the
    summary returned by ``get_defects``.

    Args:
        img: image passed to ``learner.predict`` and converted with ``np.array``;
            in this notebook it is a fastai ``PILImage``.
    """
    _, preds, _ = learner.predict(img)  # (rles, preds, probs); only preds is used
    title = get_defects(preds)
    img_np = np.array(img)
    # numpy images are laid out as (height, width, ...); the empty mask must match
    # that spatial size and carry one channel per predicted class.
    height, width = img_np.shape[:2]
    empty_mask = np.zeros((height, width, preds.shape[0]))
    pred_mask = preds.permute(1, 2, 0).float().numpy()

    out_pl.clear_output()
    with out_pl:
        # Each plot gets its own array, as in the original code, in case
        # plot_mask_image draws on the image it is given.
        plot_mask_image("Original", img_np.copy(), empty_mask)
        plot_mask_image("Predicted", img_np, pred_mask)

    lbl_pred.value = title
    
def _uploaded_image_bytes():
    """Return the bytes of the most recently uploaded file, or ``None`` if there is none.

    ipywidgets 7 exposes ``FileUpload.data`` (a list of bytes); ipywidgets 8 removed it,
    and ``FileUpload.value`` is a tuple of dicts whose ``"content"`` is a memoryview.
    """
    data = getattr(btn_upload, "data", None)
    if data is not None:  # ipywidgets 7.x
        return data[-1] if data else None
    uploads = btn_upload.value  # ipywidgets 8.x
    return uploads[-1]["content"].tobytes() if uploads else None

def on_click_classify(change):
    """Button callback: segment the last uploaded image and show the result.

    ``change`` is the argument supplied by ``Button.on_click`` (the button); unused.
    """
    img_bytes = _uploaded_image_bytes()
    if img_bytes is None:
        # Errors raised inside widget callbacks typically don't appear in the cell
        # output, so report a missing upload on the label instead of raising IndexError.
        lbl_pred.value = "Upload an image first"
        return
    segment_img(PILImage.create(img_bytes))

# NOTE: on_click appends a handler; re-running this cell without re-creating
# btn_run registers a second handler, so one click would classify twice.
btn_run.on_click(on_click_classify)

In [None]:
# Stack the app vertically: title, upload picker, run button, plots, text summary.
VBox(children=[
    widgets.Label('Detect steel defects with image segmentation'),
    btn_upload,
    btn_run,
    out_pl,
    lbl_pred,
])

VBox(children=(Label(value='Detect steel defects with image segmentation'), FileUpload(value={}, description='…