In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

import os
os.environ['CUDA_VISIBLE_DEVICES']=''

import numpy as np
import pandas as pd
import glob
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.externals import joblib
from skimage.transform import resize
from tqdm import tqdm_notebook as tqdm
from torch.autograd import Variable
import torch
import ipywidgets as ipy

from common_blocks.augmentation import resize_pad_seq
from common_blocks.utils import plot_list, read_images
from common_blocks.models import weighted_focal_loss
from common_blocks.metrics import compute_eval_metric

# Placeholder paths: replace these with the locations of your own files.
# NOTE: both names are re-assigned immediately below, so as written these two
# placeholder values are never actually used.
METADATA_FILEPATH = 'YOUR/metadata.csv'
OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = 'YOUR/validation_results.pkl'

# Hard-coded absolute paths that override the placeholders above.
# METADATA_FILEPATH is read with pd.read_csv and OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH
# with joblib.load in the data-loading cell further down.
METADATA_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/files/metadata.csv'
OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/kuba/experiments/sal_1036_cv_829_lb_837/out_of_fold_train_predictions.pkl'

In [None]:
def load_img(path):
    """Open the image file at `path` with PIL and return its contents as a numpy array."""
    return np.array(Image.open(path))

def filter_size(sizes, size_range):
    """Return the positions of the entries of `sizes` that lie inside `size_range`.

    Args:
        sizes: iterable of comparable values (one size per sample).
        size_range: pair ``(lower, upper)``; both bounds are inclusive.

    Returns:
        List of integer positions, in ascending order, whose value satisfies
        ``lower <= value <= upper``.
    """
    lower, upper = size_range
    return [position
            for position, value in enumerate(sizes)
            if lower <= value <= upper]

# Preprocessing object built by the project helper `resize_pad_seq` with the
# arguments (102, 'edge', 13).
# NOTE(review): these look like a resize target, a padding mode and a padding
# amount -- confirm against common_blocks.augmentation. `image_prep` is not
# referenced by any other cell visible in this notebook.
image_prep = resize_pad_seq(102, 'edge', 13)

In [None]:
# Per-sample metadata table; later cells look rows up by the `id` column and
# read the `file_path_image` / `file_path_mask` columns from it.
metadata = pd.read_csv(METADATA_FILEPATH)

# Out-of-fold results: a mapping with the sample ids under 'ids' and the
# matching predictions under 'images'.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code --
# only point this at files you trust.
oof_train = joblib.load(OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH)
ids, predictions = oof_train['ids'], oof_train['images']

In [None]:
# Probability above which a pixel of the channel-1 map counts as foreground.
THRESHOLD = 0.5

# Index the metadata by id once, instead of re-scanning the whole frame with a
# boolean mask for every sample. Keeping only the first row per id reproduces
# the previous `metadata[metadata['id']==idx] ... .values[0]` selection.
metadata_by_id = metadata.drop_duplicates(subset='id').set_index('id')

# Parallel per-sample lists; position i in each list refers to the same sample.
predicted_maps, masks, images, iouts, sizes = [],[],[],[],[]
# zip() has no length, so tqdm needs an explicit `total` to show progress/ETA.
for idx, pred in tqdm(zip(ids, predictions), total=min(len(ids), len(predictions))):
    row = metadata_by_id.loc[idx]
    # Bring both prediction channels to the 101x101 resolution of the masks.
    predicted_map = np.zeros((2,101,101))
    predicted_map[0,:,:] = resize(pred[0,:,:],(101,101),mode='constant')
    predicted_map[1,:,:] = resize(pred[1,:,:],(101,101),mode='constant')
    # Channel 1 is thresholded into the binary predicted mask.
    predicted_mask = (predicted_map[1,:,:] > THRESHOLD).astype(int)
    # Ground truth is binarised: every non-zero pixel is foreground.
    mask = (load_img(row['file_path_mask']) > 0).astype(int)
    image = load_img(row['file_path_image'])
    iout = compute_eval_metric(mask, predicted_mask)
    # Number of foreground pixels in the ground-truth mask.
    size = np.sum(mask)
    images.append(image)
    masks.append(mask)
    predicted_maps.append(predicted_map)
    iouts.append(iout)
    sizes.append(size)

In [None]:
# Positions (into the per-sample lists built above) of the samples whose
# ground-truth mask has between 1 and 300 foreground pixels, bounds inclusive.
# Empty masks (size 0) are therefore excluded.
size_idxs = filter_size(sizes, size_range=(1, 300))

In [None]:
@ipy.interact(idx = ipy.IntSlider(min=0,max=len(size_idxs)-1,value=0,step=1),
              alpha = ipy.FloatSlider(min=0,max=1,value=1.0,step=0.05),
              gamma = ipy.FloatSlider(min=0,max=10,value=0.0,step=0.1),
              max_weight = ipy.FloatSlider(min=1,max=1000.0,value=100.0,step=1.0),
              focus_threshold = ipy.FloatSlider(min=0,max=1,value=0.0,step=0.1),
              use_size_weight = ipy.Checkbox(value=True),
              use_border_weight = ipy.Checkbox(value=True),
              border_size = ipy.IntSlider(min=0,max=30,value=10,step=1),
              border_weight = ipy.FloatSlider(min=0,max=10.,value=10.0,step=0.25))
def present(idx, alpha, gamma,focus_threshold,
            max_weight,use_size_weight, use_border_weight,border_size, border_weight):
    """Print BCE loss, weighted focal loss and IOUT for one sample and plot it.

    The stored probability maps are converted back to logits, wrapped in a
    batch of one, and scored against a two-channel target (channel 0 =
    background, channel 1 = foreground) built from the ground-truth mask.

    Args:
        idx: position inside `size_idxs` (the size-filtered sample list).
        alpha, gamma, focus_threshold, max_weight, use_size_weight,
        use_border_weight, border_size, border_weight: forwarded unchanged as
            keyword arguments to `weighted_focal_loss`; see that function for
            their meaning.

    Returns:
        None. The losses are printed and the image, the channel-1 prediction
        and the mask are drawn with `plot_list`.
    """
    data_idx = size_idxs[idx]
    predicted_map = predicted_maps[data_idx]
    # Probabilities of exactly 0 or 1 would give -inf/+inf logits and in turn
    # NaN losses, so clip them slightly inside the open interval (0, 1).
    eps = 1e-7
    clipped_map = np.clip(predicted_map, eps, 1.0 - eps)
    logit = np.log(clipped_map/(1.0-clipped_map))
    # Add a leading batch axis of size 1.
    output = np.expand_dims(logit,axis=0)

    mask = masks[data_idx]

    # Two-channel target mirroring the two prediction channels.
    target = np.zeros_like(output)
    target[:,1,:,:] = mask
    target[:,0,:,:] = (mask == 0).astype(np.uint8)

    iout = iouts[data_idx]
    output = Variable(torch.Tensor(output))
    target = Variable(torch.Tensor(target))
    image = images[data_idx]

    focal_loss = weighted_focal_loss(output, target,
                               alpha=alpha, gamma=gamma,
                               max_weight=max_weight,
                               use_size_weight=use_size_weight,
                               use_border_weight=use_border_weight,
                               focus_threshold=focus_threshold,
                               border_size=border_size, border_weight=border_weight)
    # A reduced loss is a 1-element tensor on old PyTorch and a 0-dim tensor on
    # PyTorch >= 0.4, where `.numpy()[0]` raises IndexError. Flattening first
    # works for both shapes.
    focal_loss = float(np.ravel(focal_loss.data.cpu().numpy())[0])

    bce_loss = torch.nn.BCEWithLogitsLoss()(output, target)
    bce_loss = float(np.ravel(bce_loss.data.cpu().numpy())[0])

    print('BCE {:.4f}, Focal Loss {:.4f}, IOUT {:.2f}'.format(bce_loss, focal_loss, iout))
    plot_list(images=[image, predicted_map[1,:,:]],labels=[mask])