<img style="float: right" width="150" height="150" src="logo.jpg">

# Rilevamento difetti superficiali nell’acciaio

> Caso d'uso di IP4FVG - Nodo Data Analytics and Artificial Intelligence.

Questa applicazione permette di individuare e classificare i difetti superficiali dell'acciaio tramite algoritmi di Deep Learning per la *machine vision* (segmentazione di immagini).

In [1]:
#hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.vision.all import *
from steel_segmentation.all import * 
import cv2
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive, VBox, HBox
from IPython.display import Image

def opt_func(params, **kwargs):
    """Wrap `torch.optim.Adam` for `params` in `OptimWrapper`, forwarding `kwargs`.

    NOTE(review): presumably kept at module level so the exported learners that
    reference an `opt_func` can be unpickled by `load_learner` — keep the name.
    """
    adam_cls = torch.optim.Adam
    return OptimWrapper(params, adam_cls, **kwargs)
def splitter(m):
    """Split model `m` into three parameter groups: encoder, decoder, segmentation head.

    The nested list of modules is handed to `convert_params` (presumably the
    fastai helper that turns it into optimizer parameter groups — verify).
    """
    groups = [[m.encoder], [m.decoder], [m.segmentation_head]]
    return convert_params(groups)

# Shared widget style: 'initial' lets each description take its natural width,
# so long labels such as 'Minimum pixel for defects:' are not cut off.
style = {'description_width': 'initial'}

In [3]:
# loading data and models
# Fetch the dataset archive ("data.zip") and extract it into the working
# directory ("."); `data_path` is whatever `untar_data` returns for it.
data_url = "https://www.dropbox.com/s/wzdtphbtaouaet5/data.zip?dl=1"
data_path = untar_data(data_url, fname="data.zip", dest=".")

# Exported learners (.pkl) for the two selectable models: an EfficientNet-based
# export and a ResNet-based export, saved under the given local file names.
effnet_model_url = "https://www.dropbox.com/s/f52j2u4trox0i6o/efficientnet-b2.pkl?dl=1"
effnet_model_path = download_data(effnet_model_url, "effnet_export.pkl")
resnet_model_url = "https://www.dropbox.com/s/qumn1dshh9b0154/resnet_export.pkl?dl=1"
resnet_model_path = download_data(resnet_model_url, "resnet_export.pkl")

In [38]:
# config selection
# Exported learners, indexed by the value of the model radio-button options
# (0 -> EfficientNet export, 1 -> ResNet export).
models = [effnet_model_path, resnet_model_path]

def select_learner(selection):
    """Load the exported learner stored at `models[selection]`, on the CPU."""
    chosen_path = models[selection]
    return load_learner(chosen_path, cpu=True)

# Model picker: each option's value is an index into `models`
# (0 -> EfficientNet export, 1 -> ResNet export); "Model 2" is preselected.
model_radio = widgets.RadioButtons(
    options=[
        ("Model 1", 0), 
        ("Model 2", 1)
    ],
    value=1,
    description="Deep Learning model:",
    style=style
)

# Probability threshold in [0, 1]; read by `segment_img` at detection time and
# passed to `post_process` to binarize the predicted probabilities.
thresh_sel = widgets.FloatSlider(
    value=0.5,
    min=0.,
    max=1.,
    step=0.1,
    description='Threshold:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    style=style
)

# Minimum connected-component size in pixels; read by `segment_img` and passed
# to `post_process` as `min_size` (components with at most this many pixels
# are dropped from the prediction).
pixels_sel = widgets.IntSlider(
    value=0,
    min=0,
    max=5000,
    step=100,
    description='Minimum pixel for defects:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
    style=style
)

# Load the initially selected model; the global `learner` is what
# `segment_img` uses for every prediction.
learner = select_learner(model_radio.value)


def on_change_model(change):
    """Swap the global `learner` for the model the user just selected."""
    global learner
    # `change['new']` is the newly selected option value (an index into `models`).
    learner = select_learner(change['new'])


model_radio.observe(on_change_model, names='value')

# Configuration panel: model choice plus the two post-processing controls.
config_box = VBox([model_radio, thresh_sel, pixels_sel])

In [33]:
# prediction
def post_process(probability, threshold, min_size):
    """
    Binarize per-class probability maps and drop small connected components.

    Each channel is thresholded (pixels strictly greater than `threshold`
    become foreground) and split into connected components with
    `cv2.connectedComponents` (OpenCV's default connectivity); only components
    with more than `min_size` pixels are kept.

    Args:
        probability: per-class probability maps indexed as
            ``probability[channel]`` -> 2-D (H, W) map. A torch tensor (any
            device, may require grad) or anything ``np.asarray`` accepts.
        threshold: probability cut-off; values strictly above it are foreground.
        min_size: components with ``min_size`` pixels or fewer are discarded.

    Returns:
        A float32 ``torch.Tensor`` of shape (C, H, W) holding 0/1 values, with
        H and W taken from the input instead of a hard-coded 256x1600.
    """
    if isinstance(probability, torch.Tensor):
        # Detach and move to CPU so the NumPy conversion works for any tensor.
        probs = probability.detach().cpu().numpy()
    else:
        probs = np.asarray(probability)

    def post_process_channel(channel):
        # Same rule as cv2.THRESH_BINARY with maxval=1: foreground iff > threshold.
        mask = (channel > threshold).astype(np.uint8)
        num_component, component = cv2.connectedComponents(mask)
        # Pixel count of every component label in a single pass
        # (instead of one full-image comparison per label).
        sizes = np.bincount(component.ravel(), minlength=num_component)
        keep = sizes > min_size
        keep[0] = False  # label 0 is the background, never a defect
        # Look up each pixel's label in the keep-table; output size follows input.
        predictions = keep[component].astype(np.float32)
        return torch.from_numpy(predictions)

    if probs.shape[0] == 0:
        # torch.stack([]) would raise; return an empty (0, H, W) stack instead.
        return torch.zeros((0,) + tuple(probs.shape[1:]), dtype=torch.float32)

    return torch.stack([post_process_channel(ch) for ch in probs], dim=0)

def get_label_defects(preds: torch.Tensor, groundtruth:bool=False) -> str:
    """
    Build an HTML summary of the defect classes present in a stack of masks.

    Channel ``i`` of `preds` is reported as defect type ``i + 1`` when at least
    one of its pixels is positive.

    Args:
        preds: masks with the class channel first, (C, H, W). Bool, int or
            float values (anything ``torch.as_tensor`` accepts).
        groundtruth: label the summary "Ground truth" instead of "Predicted".

    Returns:
        An HTML string meant for an ``ipywidgets.HTML`` value.
    """
    preds = torch.as_tensor(preds).float()
    # Check every channel independently. The previous argmax over
    # [background, *preds] dropped its first unique label unconditionally, so a
    # mask fully covered by one defect reported 0 defects, and a class that only
    # appeared where it overlapped an earlier class was never reported.
    present = (preds.flatten(1) > 0).any(dim=1)
    types_defects = [str(idx + 1) for idx, found in enumerate(present.tolist()) if found]
    n_defects = len(types_defects)

    defects_word = "defects" if n_defects != 1 else "defect"
    types_word = "types" if n_defects != 1 else "type"
    preds_word = "Ground truth" if groundtruth else "Predicted"
    if n_defects > 0:
        return f"<b>{preds_word}</b>: n°<i>{n_defects}</i> {defects_word} of {types_word}: <i>{' '.join(types_defects)}</i>"
    return f"<b>{preds_word}</b>: n°0 {defects_word}"

def segment_img(
    img, 
    out_widget, lbl_widget, 
    thresh_sel, pixels_sel,
    gt_mask=None, gt_label=None
):
    """
    Segment `img` with the global `learner` and display the results.

    Plots the original image, optionally the ground-truth mask, and the
    post-processed prediction into `out_widget`, and writes an HTML summary of
    the predicted defect types into `lbl_widget`.

    Args:
        img: image passed to `learner.predict` (a fastai `PILImage` in this notebook).
        out_widget: `ipywidgets.Output` that receives the plots.
        lbl_widget: widget whose `.value` receives the prediction summary.
        thresh_sel: widget whose `.value` is the probability threshold.
        pixels_sel: widget whose `.value` is the minimum component size (pixels).
        gt_mask: optional ground-truth mask with the class channel last
            (H, W, C); shown only when `gt_label` is also given.
        gt_label: optional widget whose `.value` receives the ground-truth summary.
    """
    threshold = thresh_sel.value
    min_size = pixels_sel.value

    img_np = np.array(img)
    height, width = img_np.shape[:2]

    out_widget.clear_output()
    # Run inference inside the Output context: on some front-ends (e.g.
    # JupyterLab) errors raised in widget callbacks are otherwise not shown
    # anywhere, so a failing prediction would silently do nothing.
    with out_widget:
        _rles, _preds, probs = learner.predict(img)
        n_classes = probs.shape[0]

        # Empty overlay with one channel per predicted class.
        plot_mask_image("Original", img_np, np.zeros((height, width, n_classes)))

        if gt_mask is not None and gt_label is not None:
            gt_label.value = get_label_defects(
                torch.Tensor(gt_mask).permute(2, 0, 1),
                groundtruth=True) + " | "
            plot_mask_image("Ground Truth", img_np, gt_mask)

        post_processed_preds = post_process(probs, threshold, min_size)
        lbl_widget.value = get_label_defects(post_processed_preds)
        # (C, H, W) -> (H, W, C) to match the layout of the ground-truth mask.
        plot_mask_image("Predicted", img_np, post_processed_preds.permute(1, 2, 0).float().numpy())
        
def get_imgid_list(img_path):
    """Return the bare file name (`.name`) of every path in `img_path`, in order."""
    names = []
    for path in img_path:
        names.append(path.name)
    return names

In [6]:
# Training images with ground truth

# Description
description_train = widgets.Label('Predict the defects and compare the results with the labels (ground truth).')

# Dropdown element
imgid_elements_train = get_imgid_list(train_pfiles)
imgid_dropdown_train = widgets.Dropdown(
    options=imgid_elements_train, index=0, description="Select an Image:", style=style)
# Detect button
btn_run_train = widgets.Button(description='Detect')
btn_run_train.style.button_color = 'lightgreen'
# Label for text output groundtruth
lbl_ground_train = widgets.HTML()
# Label for text output
lbl_pred_train = widgets.HTML()
# Plot output
out_pl_train = widgets.Output()

# final GUI
training_box = VBox([
    description_train,
    HBox([imgid_dropdown_train, btn_run_train]), 
    HBox([
        lbl_ground_train,
        lbl_pred_train
    ]),
    out_pl_train],
    style=style
)

# Actions
def on_change_detect_train(change):
    """Segment the selected training image and compare it with its ground-truth mask."""
    imageid = imgid_dropdown_train.value
    _, mask = make_mask(imageid)

    # Let fastai decode the file: cv2.imread returns channels in BGR order,
    # whereas PILImage.create treats an ndarray as RGB (and the model was
    # presumably trained on RGB images loaded by fastai — confirm).
    img = PILImage.create(train_path/imageid)

    segment_img(
        img, 
        out_pl_train, lbl_pred_train, 
        thresh_sel, pixels_sel, 
        gt_mask=mask, gt_label=lbl_ground_train
    )

btn_run_train.on_click(on_change_detect_train)
imgid_dropdown_train.observe(on_change_detect_train, names='value')

In [7]:
# Test images

# Description
description_test = widgets.Label('Predict the defects from test images (without labels).')

# Dropdown element (uses the module-level `style` dict defined with the imports)
imgid_elements_test = get_imgid_list(test_pfiles)
imgid_dropdown_test = widgets.Dropdown(
    options=imgid_elements_test, index=0, description="Select an Image:", style=style)
# Detect button
btn_run_test = widgets.Button(description='Detect')
btn_run_test.style.button_color = 'lightgreen'
# Label for text output predictions
lbl_pred_test = widgets.HTML()
# Plot output
out_pl_test = widgets.Output()

# final GUI
testing_box = VBox([
    description_test,
    HBox([imgid_dropdown_test, btn_run_test]), 
    lbl_pred_test,
    out_pl_test], 
    style=style)

# Actions
def on_change_detect_test(change):
    """Segment the selected (unlabelled) test image."""
    imageid = imgid_dropdown_test.value

    # Let fastai decode the file: cv2.imread returns channels in BGR order,
    # whereas PILImage.create treats an ndarray as RGB.
    img = PILImage.create(test_path/imageid)

    segment_img(
        img, 
        out_pl_test, lbl_pred_test,
        thresh_sel, pixels_sel
    )

btn_run_test.on_click(on_change_detect_test)
imgid_dropdown_test.observe(on_change_detect_test, names='value')

In [8]:
# Upload

# Description
description_label_upload = widgets.Label('Upload an image and predict the defects.')
# Upload button
btn_upload = widgets.FileUpload(multiple=False)
# Detect button
btn_run_upload = widgets.Button(description='Detect')
btn_run_upload.style.button_color = 'lightgreen'
# Clear button
clear_upload = widgets.Button(description='Clear', button_style='danger')
# Label for text output
lbl_pred_upload = widgets.HTML()
# Plot output
out_pl_upload = widgets.Output()

# final GUI
upload_box = VBox([description_label_upload, 
                   HBox([btn_upload, btn_run_upload, clear_upload]),
                   lbl_pred_upload, 
                   out_pl_upload])

# Actions
def on_click_detect_uploaded(change):
    """Segment the most recently uploaded image, if there is one."""
    # "Detect" can be pressed before any upload, or after "Clear" (which only
    # empties `btn_upload.value`; `data` may still hold the previous bytes).
    # Without this guard `data[-1]` raises IndexError or stale data is reused.
    if not btn_upload.value or not btn_upload.data:
        out_pl_upload.clear_output()
        lbl_pred_upload.value = "Please upload an image first."
        return
    # ipywidgets 7 API; newer releases expose the bytes as `.content.tobytes()`.
    img = PILImage.create(btn_upload.data[-1])
    segment_img(
        img, 
        out_pl_upload, lbl_pred_upload,
        thresh_sel, pixels_sel
    )

def on_click_clear_uploaded(change):
    """Reset the upload widget and clear the previous results."""
    btn_upload._counter = 0
    btn_upload.value.clear()
    out_pl_upload.clear_output()
    lbl_pred_upload.value = ""

btn_run_upload.on_click(on_click_detect_uploaded)
clear_upload.on_click(on_click_clear_uploaded)

In [9]:
# tab = widgets.Tab()
# tab.children = [upload_box, multi_choice_box]
# tab.set_title(0, 'Upload validation')
# tab.set_title(1, 'Multiple choice validation')
# tab

In [37]:
# Section titles, matched by position with the accordion's children below.
box_lbls = ['Configuration', 'Detect and compare', 'Validate', 'Upload']

# Top-level GUI: one collapsible section per box; the training
# ("Detect and compare") section starts open (selected_index=1).
accordion = widgets.Accordion(
    children=[config_box, training_box, testing_box, upload_box],
    selected_index=1
)

for i, lbl in enumerate(box_lbls):
    accordion.set_title(i, lbl)

# Last expression of the cell, so the notebook renders the accordion.
accordion

Accordion(children=(VBox(children=(RadioButtons(description='Select a model:', index=1, options=(('Model 1', 0…



**Riferimenti:**
- Documentazione tecnica: [steel_segmentation](https://marcomatteo.github.io/steel_segmentation/)
- Dataset utilizzato: [Severstal Competition, Kaggle](https://www.kaggle.com/c/severstal-steel-defect-detection/overview)