In [None]:
#hide
#all_slow

# Steel Defect Detection

> Proof of concept (POC) that combines several solutions from the Severstal: Steel Defect Detection Kaggle competition.

In [1]:
#hide
# Notebook configuration: re-import edited library modules (e.g. steel_segmentation)
# before each cell runs, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.vision.all import *
from steel_segmentation.all import * 

import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive, VBox
from IPython.display import Image

In [3]:
# NOTE(review): presumably these must exist at module level under these exact names
# so that `load_learner` can unpickle a Learner that references them.
def opt_func(params, **kwargs):
    "Build a `torch.optim.Adam` optimizer wrapped in fastai's `OptimWrapper`; kwargs are forwarded as-is."
    optimizer_cls = torch.optim.Adam
    return OptimWrapper(params, optimizer_cls, **kwargs)

def splitter(m):
    "Group the model's encoder, decoder and segmentation head, passed through fastai's `convert_params`."
    groups = [
        [m.encoder],
        [m.decoder],
        [m.segmentation_head],
    ]
    return convert_params(groups)

In [4]:
# Load the exported fastai Learner. `load_learner` unpickles the file, so only load
# trusted checkpoints; presumably `opt_func`/`splitter` (defined in the previous cell)
# must already be in scope so the pickle can resolve them by name.
learner = load_learner("efficientnet-b2.pkl")

In [40]:
#hide_output
# Input controls: a single-file uploader plus buttons to reset it and to run the model.
btn_upload = widgets.FileUpload(multiple=False)
clear_upload = widgets.Button(description='Clear')
btn_run = widgets.Button(description='Classify')

# Output areas: a panel for the plots and a one-line text summary.
out_pl = widgets.Output()
lbl_pred = widgets.Label()

In [42]:
def get_defects(preds) -> str:
    """Return a human-readable summary of the defect classes present in `preds`.

    `preds` is the decoded prediction tensor; its class axis is axis 0
    (assumes (n_classes, H, W) -- TODO confirm against the learner's output).
    Class channel k (0-based) is reported as defect type k + 1. A pixel counts
    towards a class when that channel holds the first maximum at that pixel and
    is strictly positive.

    Returns e.g. "Predicted: n°2 defects of types: 1 3" or "Predicted: n°0 defects".
    """
    preds = preds.float()

    # Prepend an all-zero "background" channel: argmax over the class axis then
    # yields 0 wherever no class channel is positive, and k where channel k-1
    # holds the first maximum. `new_zeros` keeps the background on the same
    # device/dtype as `preds`, so `torch.cat` also works for non-CPU tensors.
    background = preds.new_zeros((1, *preds.shape[1:]))
    class_map = torch.cat([background, preds]).argmax(0)

    # Drop the background index explicitly instead of slicing off the first
    # element: when every pixel is a defect, 0 is absent and slicing would
    # discard a real defect class.
    found = sorted(int(o) for o in class_map.unique().cpu().tolist() if o != 0)
    types_defects = [str(o) for o in found]
    n_defects = len(types_defects)

    defects_word = "defects" if n_defects != 1 else "defect"
    types_word = "types" if n_defects != 1 else "type"
    if n_defects > 0:
        return f"Predicted: n°{n_defects} {defects_word} of {types_word}: {' '.join(types_defects)}"
    return f"Predicted: n°0 {defects_word}"

def segment_img(img):
    """Segment `img` with `learner`, redraw the plots and update the summary label.

    Side effects: clears and refills the `out_pl` output widget with an
    "Original" panel (all-zero mask) and a "Predicted" panel, then sets
    `lbl_pred.value` to the text returned by `get_defects`.
    """
    _, preds, _ = learner.predict(img)
    title = get_defects(preds)

    img_np = np.array(img)
    # numpy image arrays are laid out as (height, width[, channels]); slicing
    # the first two dims also works for 2-D (grayscale) arrays.
    height, width = img_np.shape[:2]

    # assumes preds is (n_classes, H, W) -- TODO confirm. Move to CPU before
    # .numpy() in case the learner returns device tensors.
    pred_mask = preds.permute(1, 2, 0).float().cpu().numpy()
    # All-zero mask with the same channel count as the prediction, for the reference panel.
    empty_mask = np.zeros((height, width, pred_mask.shape[-1]))

    out_pl.clear_output()
    with out_pl:
        plot_mask_image("Original", img_np, empty_mask)
        # Separate array for the second panel in case plot_mask_image modifies
        # its image argument in place (unverified).
        plot_mask_image("Predicted", np.array(img), pred_mask)

    lbl_pred.value = title
    
def on_click_classify(change):
    """Button callback: segment the most recently uploaded image.

    `change` (the argument passed by `Button.on_click`) is unused.
    Reads raw bytes from `btn_upload.data`, the ipywidgets 7 FileUpload API;
    newer releases expose the bytes via `.content.tobytes()` instead.
    """
    uploads = btn_upload.data
    if not uploads:
        # Nothing uploaded (or the upload was cleared): indexing [-1] would raise
        # IndexError inside the widget callback, which the user never sees.
        lbl_pred.value = "Please upload an image first."
        return
    img = PILImage.create(uploads[-1])
    segment_img(img)
    
def on_click_clear(change):
    """Button callback: reset the uploader, the plot output and the summary label.

    `change` (the argument passed by `Button.on_click`) is unused.
    """
    # In ipywidgets 7.x, FileUpload rebuilds `value` from `metadata`/`data`
    # whenever `_counter` changes. Empty those lists *before* resetting the
    # counter; otherwise the old file reappears in `value`, and
    # `on_click_classify` (which reads `data`) keeps segmenting it.
    btn_upload.metadata = []
    btn_upload.data = []
    btn_upload.value.clear()
    btn_upload._counter = 0

    out_pl.clear_output()
    lbl_pred.value = ""
        
# Attach the handlers: "Clear" resets the app, "Classify" runs the segmentation.
clear_upload.on_click(on_click_clear)
btn_run.on_click(on_click_classify)

In [44]:
# Assemble the app (kept as the cell's last expression so Jupyter displays it):
# title, upload/clear/run controls, then the plot panel and the text summary.
VBox(children=[
    widgets.Label('Detect steel defects with image segmentation'),
    btn_upload,
    clear_upload,
    btn_run,
    out_pl,
    lbl_pred,
])

VBox(children=(Label(value='Detect steel defects with image segmentation'), FileUpload(value={'0a1cade03.jpg':…