In [1]:
#hide
#all_slow

# Steel Defect Detection

> A proof of concept (POC) built from multiple solutions to the Severstal Kaggle competition.

In [2]:
#hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
from fastai.vision.all import *
from steel_segmentation.all import * 
import cv2
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive, VBox, HBox
from IPython.display import Image

def opt_func(params, **kwargs):
    """Build an `OptimWrapper` around `torch.optim.Adam` for `params`, forwarding `kwargs`."""
    adam_cls = torch.optim.Adam
    return OptimWrapper(params, adam_cls, **kwargs)
def splitter(m):
    """Hand the encoder, decoder and segmentation head of `m` to `convert_params`, one group each."""
    layer_groups = [[m.encoder], [m.decoder], [m.segmentation_head]]
    return convert_params(layer_groups)

In [4]:
# Location of the exported learner (a `.pkl` file hosted on Dropbox).
exported_model_url = "https://www.dropbox.com/s/f52j2u4trox0i6o/efficientnet-b2.pkl?dl=1"
# Fetch it through `download_data`, passing "export.pkl" as the file name; the returned
# value is what the following cell hands to `load_learner`.
exported_model = download_data(exported_model_url, "export.pkl")

In [5]:
# NOTE(review): the file loaded here is a `.pkl` fetched from a remote URL. fastai's
# `load_learner` documentation warns that it relies on pickle, which can execute arbitrary
# code while loading — only use files from a trusted source.
# `cpu=True` is passed so the learner is loaded for CPU use.
learner = load_learner(exported_model, cpu=True)

In [55]:
# prediction
  
def get_label_defects(preds: torch.Tensor, groundtruth:bool=False) -> str:
    """Describe which defect types are present in a stack of per-class masks.

    Args:
        preds: tensor whose first axis indexes the defect classes, so `preds[i]`
            is the mask for defect type `i + 1`. The remaining axes are the mask
            dimensions.
        groundtruth: if True the sentence starts with "Ground truth", otherwise
            with "Predicted".

    Returns:
        A one-line summary with the number of distinct defect types found and,
        when there is at least one, the space-separated list of type numbers.
    """
    preds = preds.float()

    # Prepend an all-zero "background" channel so that, after argmax over the class
    # axis, index 0 means "no defect" and index k >= 1 means defect type k.
    # `new_zeros` keeps the background on the same device/dtype as `preds`.
    background = preds.new_zeros((1, *preds.shape[1:]))
    class_map = torch.cat([background, preds]).argmax(0)

    # Drop the background index explicitly. Slicing off the first element would be wrong
    # when no pixel is background (e.g. a mask fully covered by one defect), because the
    # first element would then be a real defect type.
    present = sorted(class_map.unique().tolist())
    types_defects = [str(idx) for idx in present if idx != 0]
    n_defects = len(types_defects)

    defects_word = "defects" if n_defects!=1 else "defect"
    types_word = "types" if n_defects!=1 else "type"
    preds_word = "Ground truth" if groundtruth else "Predicted"
    if n_defects > 0:
        return f"{preds_word}: n°{n_defects} {defects_word} of {types_word}: {' '.join(types_defects)}"
    else:
        return f"{preds_word}: n°0 {defects_word}"

def segment_img(img, out_widget, lbl_widget, gt=None, gt_label=None):
    """Predict on `img` with the global `learner` and render the results into widgets.

    Args:
        img: image passed to `learner.predict` and converted with `np.array`.
        out_widget: output widget that is cleared and then receives the plots.
        lbl_widget: widget whose `value` is set to the prediction summary text.
        gt: optional ground-truth mask; it is permuted with `(2, 0, 1)` before being
            summarised, so it is assumed to be channels-last — TODO confirm.
        gt_label: optional widget whose `value` receives the ground-truth summary.
            The ground-truth plot and label are produced only when both `gt` and
            `gt_label` are given.
    """
    # Only the second of the three values returned by `predict` is used here.
    _, pred_masks, _ = learner.predict(img)
    lbl_widget.value = get_label_defects(pred_masks)

    img_np = np.array(img)
    # numpy shape order is (rows, columns, channels)
    n_rows, n_cols, _ = img_np.shape
    has_ground_truth = gt is not None and gt_label is not None

    out_widget.clear_output()
    with out_widget:
        # Plain image: an all-zero 4-channel mask is passed
        empty_mask = np.zeros((n_rows, n_cols, 4))
        plot_mask_image("Original", img_np, empty_mask)

        if has_ground_truth:
            gt_channels_first = torch.Tensor(gt).permute(2, 0, 1)
            gt_label.value = get_label_defects(gt_channels_first, groundtruth=True)
            plot_mask_image("Ground Truth", img_np, gt)

        # A fresh array is built from `img` instead of reusing `img_np` — presumably
        # because `plot_mask_image` may draw on the array it receives; TODO confirm.
        pred_mask_np = pred_masks.permute(1, 2, 0).float().numpy()
        plot_mask_image("Predicted", np.array(img), pred_mask_np)

In [56]:
def get_imgid_list(img_path):
    """Return the `.name` attribute of every item in `img_path`, in iteration order."""
    names = []
    for item in img_path:
        names.append(item.name)
    return names

In [57]:
# Training images with ground truth

# Description
description_train = widgets.Label(
    'Select an image and compare the prediction with the "ground truth"'
)

# Image picker: one entry per name returned by `get_imgid_list(train_pfiles)`
style = dict(description_width='initial')
imgid_elements_train = get_imgid_list(train_pfiles)
imgid_dropdown_train = widgets.Dropdown(
    options=imgid_elements_train,
    index=0,
    description="Select an Image:",
    style=style,
)

# Green "Detect" button
btn_run_train = widgets.Button(description='Detect')
btn_run_train.style.button_color = 'lightgreen'

# Text summaries: ground truth first, prediction second
lbl_ground_train = widgets.Label()
lbl_pred_train = widgets.Label()

# Area that receives the plots
out_pl_train = widgets.Output()

# Assemble the panel from top to bottom
# NOTE(review): confirm that `VBox` actually honours a `style` argument
training_box = VBox(
    children=[
        description_train,
        HBox([imgid_dropdown_train, btn_run_train]),
        lbl_ground_train,
        lbl_pred_train,
        out_pl_train,
    ],
    style=style,
)

# Actions
def on_change_detect_train(change):
    """Run detection on the training image currently selected in the dropdown.

    Registered both as the button click handler and as the dropdown `value`
    observer, so `change` is whatever either callback supplies; it is not inspected.
    The ground-truth mask from `make_mask` is shown next to the prediction.
    """
    imageid = imgid_dropdown_train.value
    _, mask = make_mask(imageid)

    # `cv2.imread` does not raise on a missing or unreadable file: it returns None.
    # Report that in the label instead of failing later inside `PILImage.create`.
    image_np = cv2.imread(str(train_path/imageid))
    if image_np is None:
        out_pl_train.clear_output()
        lbl_ground_train.value = ""
        lbl_pred_train.value = f"Could not read image: {imageid}"
        return
    img = PILImage.create(image_np)

    segment_img(img, out_pl_train, lbl_pred_train, gt=mask, gt_label=lbl_ground_train)
    
# Run detection when the button is clicked ...
btn_run_train.on_click(on_change_detect_train)
# ... and also whenever the dropdown's `value` changes.
imgid_dropdown_train.observe(on_change_detect_train, names='value')

In [58]:
# Test images

# Description
description_test = widgets.Label(
    'Select an image from the test set (no ground truth)'
)

# Image picker: one entry per name returned by `get_imgid_list(test_pfiles)`
style = dict(description_width='initial')
imgid_elements_test = get_imgid_list(test_pfiles)
imgid_dropdown_test = widgets.Dropdown(
    options=imgid_elements_test,
    index=0,
    description="Select an Image:",
    style=style,
)

# Green "Detect" button
btn_run_test = widgets.Button(description='Detect')
btn_run_test.style.button_color = 'lightgreen'

# Prediction summary text and plot area
lbl_pred_test = widgets.Label()
out_pl_test = widgets.Output()

# Assemble the panel from top to bottom
# NOTE(review): confirm that `VBox` actually honours a `style` argument
testing_box = VBox(
    children=[
        description_test,
        HBox([imgid_dropdown_test, btn_run_test]),
        lbl_pred_test,
        out_pl_test,
    ],
    style=style,
)

# Actions
def on_change_detect_test(change):
    """Run detection on the test image currently selected in the dropdown.

    Registered both as the button click handler and as the dropdown `value`
    observer, so `change` is whatever either callback supplies; it is not inspected.
    No ground truth is passed to `segment_img`.
    """
    imageid = imgid_dropdown_test.value

    # `cv2.imread` does not raise on a missing or unreadable file: it returns None.
    # Report that in the label instead of failing later inside `PILImage.create`.
    image_np = cv2.imread(str(test_path/imageid))
    if image_np is None:
        out_pl_test.clear_output()
        lbl_pred_test.value = f"Could not read image: {imageid}"
        return
    img = PILImage.create(image_np)

    segment_img(img, out_pl_test, lbl_pred_test)
    
# Run detection when the button is clicked ...
btn_run_test.on_click(on_change_detect_test)
# ... and also whenever the dropdown's `value` changes.
imgid_dropdown_test.observe(on_change_detect_test, names='value')

In [59]:
# Upload

# Description
description_label_upload = widgets.Label('Upload an image and detect')

# Single-file upload control
btn_upload = widgets.FileUpload(multiple=False)

# Green "Detect" button
btn_run_upload = widgets.Button(description='Detect')
btn_run_upload.style.button_color = 'lightgreen'

# "Clear" button, created with the 'danger' button style
clear_upload = widgets.Button(description='Clear', button_style='danger')

# Prediction summary text and plot area
lbl_pred_upload = widgets.Label()
out_pl_upload = widgets.Output()

# Assemble the panel: description, controls row, label, plots
upload_box = VBox(
    children=[
        description_label_upload,
        HBox([btn_upload, btn_run_upload, clear_upload]),
        lbl_pred_upload,
        out_pl_upload,
    ]
)

# Actions
def on_click_detect_uploaded(change):
    """Run detection on the most recently uploaded image.

    `change` is the argument supplied by the click callback; it is not used.
    If nothing has been uploaded yet, a hint is written to the label and no
    prediction is attempted.
    """
    # `btn_upload.data` is indexed as a sequence with the last upload at the end.
    # NOTE(review): newer ipywidgets releases reportedly expose the bytes through
    # `.content.tobytes()` on the entries of `value` instead — confirm before upgrading.
    uploads = btn_upload.data
    if not uploads:
        # Without this guard, `[-1]` on an empty sequence raises IndexError.
        lbl_pred_upload.value = "Please upload an image first"
        return
    img = PILImage.create(uploads[-1])
    segment_img(img, out_pl_upload, lbl_pred_upload)
    
def on_click_clear_uploaded(change):
    """Reset the upload widget's counter and value, and clear the plot output and label text.

    `change` is the argument supplied by the click callback; it is not used.
    """
    # NOTE(review): `_counter` is a private attribute and `value.clear()` mutates the
    # widget's value in place; both presumably target an older FileUpload API — confirm
    # against the installed ipywidgets version.
    btn_upload._counter = 0
    btn_upload.value.clear()
    out_pl_upload.clear_output()
    lbl_pred_upload.value = ""
        
# "Detect" runs the prediction on the uploaded image; "Clear" resets the upload panel.
btn_run_upload.on_click(on_click_detect_uploaded)
clear_upload.on_click(on_click_clear_uploaded)

In [60]:
# tab = widgets.Tab()
# tab.children = [upload_box, multi_choice_box]
# tab.set_title(0, 'Upload validation')
# tab.set_title(1, 'Multiple choice validation')
# tab

In [61]:
# One collapsible section per panel, in display order
accordion = widgets.Accordion(children=[training_box, testing_box, upload_box])

# Titles are assigned by position, matching the order of `children` above
for _panel_idx, _panel_title in enumerate((
    'Training images with ground truth',
    'Test images validation',
    'Upload validation',
)):
    accordion.set_title(_panel_idx, _panel_title)

# Last expression of the cell: displays the widget
accordion

Accordion(children=(VBox(children=(Label(value='Select an image and compare the prediction with the "ground tr…

![logo](logo.jpg)