In [None]:
#default_exp metrics

In [None]:
#hide
#missing
#!git clone https://github.com/marcomatteo/steel_segmentation.git
#!pip install -e steel_segmentation


In [None]:
#hide
!pip install -Uqq fastai --upgrade
!pip install -Uqq fastcore --upgrade

# Metrics

> A collection of Metrics used in the segmentation models

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/marcomatteo/steel_segmentation/blob/master/nbs/05_metrics.ipynb)

In [None]:
# hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from steel_segmentation.metadata import *
from steel_segmentation.masks import *
from steel_segmentation.datasets import *
from steel_segmentation.dataloaders import *

import fastai
from fastai.vision.all import *
from fastcore.foundation import *

import torch
import torch.nn.functional as F

from collections import defaultdict

import segmentation_models_pytorch as smp

This section collects all the metrics that can be used to evaluate the performance of the trained segmentation models.

In [None]:
# missing
# Grab a single training batch (bs=20) to exercise the metrics in the cells below
dls = get_segmnt_dls(train_pivot, bs=20)
x, targs = dls.train.one_batch()
x.shape, targs.shape

(torch.Size([20, 3, 256, 1600]), torch.Size([20, 4, 256, 1600]))

In [None]:
# missing
# Everything in this demo runs on CPU: move the batch there and build a CPU model
x = x.cpu()
model = smp.Unet("resnet34", 
                 encoder_weights="imagenet", 
                 classes=4, 
                 activation=None)
# map_location="cpu" lets a checkpoint that was saved from a GPU be read on a
# CPU-only machine; without it `torch.load` tries to restore CUDA tensors.
# NOTE(review): `torch.load` unpickles the file, so only load trusted checkpoints.
loaded_params = torch.load(models_dir/"kaggle-UNET-ResNet34.pth", map_location="cpu")
model.load_state_dict(loaded_params["state_dict"], strict=True)

<All keys matched successfully>

In [None]:
# missing
# Inference only: put normalisation/dropout layers in eval mode and skip autograd,
# otherwise the forward pass uses batch statistics and keeps a graph for the whole batch.
model.eval()
with torch.no_grad():
    logits = model(x)
probs = torch.sigmoid(logits)
preds = (probs > 0.5).float()

Simulated training with `compute_val` and a test Learner with `TstLearner`.

In [None]:
#For testing: a fake learner that only carries what a metric reads
@delegates()
class TstLearner(Learner):
    "Minimal `Learner` stand-in for metric tests; `Learner.__init__` is deliberately not called."
    def __init__(self,dls=None,model=None,**kwargs):
        # `dls`, `model` and `kwargs` are accepted for signature compatibility and ignored
        self.pred = self.xb = self.yb = None
        self.loss_func = BCEWithLogitsLossFlat()
        
#Run a fake validation cycle over uneven mini-batches and return the value of met
def compute_val(met, pred, y):
    "Feed `pred`/`y` to `met` in the slices [0:6], [6:15], [15:20] and return `met.value`."
    met.reset()
    bounds = [0,6,15,20]
    learn = TstLearner()
    # consecutive pairs of bounds give three batches of sizes 6, 9 and 5
    for start, stop in zip(bounds[:-1], bounds[1:]):
        learn.pred = pred[start:stop]
        learn.yb = (y[start:stop],)
        met.accumulate(learn)
    return met.value

## Multiclass Dice

The `fastai` library comes with a Dice metric for multiple-channel masks. Like the other segmentation metrics in this framework, it expects the targets as a flattened (class-index) mask.

In [None]:
multidice_obj = DiceMulti()

In [None]:
# missing
compute_val(multidice_obj, pred=preds, y=targs.argmax(1))

0.5713036010962022

Here we slightly change the `DiceMulti` for a 4-channel mask as targets.

In [None]:
# export
class ModDiceMulti(Metric):
    "Macro-averaged Dice (Macro F1) for multiclass segmentation; accepts class-index or multi-channel targets"

    def __init__(self, axis=1, with_logits=False):
        # axis: class/channel dimension; with_logits: sigmoid + 0.5 threshold before argmax
        self.axis, self.with_logits = axis, with_logits

    def reset(self):
        # per-class running sums, keyed by class index
        self.inter = {}
        self.union = {}

    def accumulate(self, learn):
        scores = learn.pred
        if self.with_logits:
            scores = (torch.sigmoid(scores) > 0.5).float()
        pred = scores.argmax(dim=self.axis)

        targ = learn.yb[0]
        # A target that still carries the channel axis is reduced to class indices too
        if pred.shape != targ.shape:
            targ = targ.argmax(dim=self.axis)

        pred, targ = flatten_check(pred, targ)
        n_classes = learn.pred.shape[self.axis]
        for cls in range(n_classes):
            # one-vs-rest binary masks; the TensorBase cast may be redundant (old fastai bug)
            p = TensorBase(torch.where(pred == cls, 1, 0))
            t = TensorBase(torch.where(targ == cls, 1, 0))
            # "union" is |p| + |t|, i.e. the Dice denominator
            self.inter[cls] = self.inter.get(cls, 0) + (p*t).float().sum().item()
            self.union[cls] = self.union.get(cls, 0) + (p+t).float().sum().item()

    @property
    def value(self):
        # classes never seen in predictions or targets become NaN and are skipped by nanmean
        per_class = [
            2.*self.inter[c]/self.union[c] if self.union[c] > 0 else np.nan
            for c in self.inter
        ]
        self.binary_dice_scores = np.array(per_class)
        return np.nanmean(self.binary_dice_scores)

In [None]:
dice_obj = ModDiceMulti(with_logits=True)

In [None]:
# missing
compute_val(dice_obj, pred=logits, y=targs)

0.5713036010962022

In [None]:
dice_obj = ModDiceMulti()

In [None]:
# missing
compute_val(dice_obj, pred=preds, y=targs)

0.5713036010962022

The metric accepts two different target formats:
- a flattened (class-index) mask, used by fastai segmentation models
- a 4-channel mask, used by pytorch segmentation models

In [None]:
# Fake scores for 20 samples x 4 channels: channel 0 always has the highest value,
# so the argmax prediction is class 0 for every sample.
x1a = torch.ones(20,1,1,1)
x1b = torch.clone(x1a)*0.5
x1c = torch.clone(x1a)*0.3
x1d = torch.clone(x1a)*0.1
x1 = torch.cat((x1a,x1b,x1c,x1d),dim=1)   # Prediction: shape (20, 4, 1, 1)

x2 = torch.zeros(20,1,1)       # Target as class indices: 20xClass0
x2chs = torch.zeros(20,4,1,1)  # Target as 4-channel mask: all zeros, argmax reads it as Class0

# Every prediction matches the target: Dice metric = 1
test_eq(compute_val(dice_obj, x1, x2), 1.)
test_eq(compute_val(dice_obj, x1, x2chs), 1.)

x2_ch0 = torch.zeros(20,1,1,1)
x2_ch1 = torch.ones(20,1,1,1)
x2_ch2 = torch.zeros(20,1,1,1)
x2_ch3 = torch.zeros(20,1,1,1)
x2_chs = (x2_ch0, x2_ch1, x2_ch2, x2_ch3)

x2 = torch.ones(20,1,1)          # Target as class indices: 20xClass1
x2chs = torch.cat(x2_chs, dim=1) # Target as 4-channel mask: 20xClass1

# No prediction matches the target: Dice metric = 0
test_eq(compute_val(dice_obj, x1, x2), 0.)
test_eq(compute_val(dice_obj, x1, x2chs), 0.)

A different scenario, with a multiclass batch:
- Class0 x 10
- Class1 x 4
- Class2 x 3
- Class3 x 3

In [None]:
# Target as class indices: 10xClass0, 4xClass1, 3xClass2, 3xClass3
x2a = torch.zeros(10,1,1)
x2b = torch.ones(4,1,1)
x2c = torch.ones(3,1,1) * 2
x2d = torch.ones(3,1,1) * 3
x2 = torch.cat((x2a,x2b,x2c,x2d),dim=0) # shape (20, 1, 1)
computed_dice = compute_val(dice_obj, x1, x2)

# Same target as a 4-channel mask: group i has channel i set to 1
batch_sizes = [10, 4, 3, 3]
x2_chs = [torch.zeros(n, 4, 1, 1) for i, n in enumerate(batch_sizes)]
for i, x2_ch in enumerate(x2_chs):
    x2_ch[:, i] = 1
x2chs = torch.cat(x2_chs, dim=0) # shape (20, 4, 1, 1)
computed_dice_chs = compute_val(dice_obj, x1, x2chs)

# Dice: 2*TP/(2*TP+FP+FN)
dice1 = (2*10)/(2*10+4+3+3)              # class 0: TP=10, FP=4+3+3 (x1 always predicts class 0), FN=0
dice2 = 0
dice3 = 0
dice4 = 0

# Dice metric = 0.1666
test_eq(computed_dice,     (dice1+dice2+dice3+dice4)/4)
test_eq(computed_dice_chs, (dice1+dice2+dice3+dice4)/4)
test_eq(computed_dice, computed_dice_chs)

computed_dice

0.16666666666666666

## Kaggle Dice metric
The competition [evaluation metric](https://www.kaggle.com/c/severstal-steel-defect-detection/overview/evaluation) is defined as:

> This competition is evaluated on the mean Dice coefficient. The Dice coefficient can be used to compare the pixel-wise agreement between a predicted segmentation and its corresponding ground truth. The formula is given by:

$$
Dice(X,Y) = \frac{2 \cdot |X \cap Y|}{|X| + |Y|}
$$

> where X is the predicted set of pixels and Y is the ground truth. The Dice coefficient is defined to be 1 when both X and Y are empty. The leaderboard score is the mean of the Dice coefficients for each <ImageId, ClassId> pair in the test set.

In [None]:
#export
class KaggleDice(Metric):
    """
    Multi-class Dice used in Severstal comp,
    is 1 when prediction and mask are empty.

    One Dice score is computed for every (image, class) pair and `value`
    is the mean over all of them.

    axis: class/channel dimension of the predictions.
    with_logits: if True, apply sigmoid and a 0.5 threshold before the argmax.
    eps: lower bound on the denominator, only there to avoid a division by zero.
    """
    def __init__(self, axis=1, with_logits=False, eps=1e-9): 
        self.axis = axis
        self.eps = eps
        self.with_logits = with_logits
        
    def reset(self): self.inter, self.union = defaultdict(list), defaultdict(list)

    def accumulate(self, learn):
        n_imgs = learn.pred.shape[0]
        if n_imgs == 0: return  # nothing to score in an empty batch

        if self.with_logits:
            pred = (torch.sigmoid(learn.pred) > 0.5).float().argmax(dim=self.axis)
        else:
            pred = learn.pred.argmax(dim=self.axis)
        
        y = learn.yb[0]
        # A target that still carries the channel axis is reduced to class indices too
        if pred.shape != y.shape:
            y = y.argmax(dim=self.axis)
        
        # The class count must come from the raw predictions: after the argmax
        # the class axis is gone and `pred.shape[self.axis]` is a spatial size.
        n_classes = learn.pred.shape[self.axis]
            
        # flatten_check validates the element counts; the per-image layout is then
        # restored so that every image gets its own score.
        preds, targs = flatten_check(pred, y)
        preds, targs = preds.view(n_imgs, -1), targs.view(n_imgs, -1)
        for i in range(n_classes):
            p = TensorBase(torch.where(preds == i, 1, 0))
            t = TensorBase(torch.where(targs == i, 1, 0))

            # one entry per image; "union" is |p| + |t|, i.e. the Dice denominator
            self.inter[i].append((p*t).sum(-1).float())
            self.union[i].append((p+t).sum(-1).float())

    @property
    def value(self):
        per_class = []
        for c in range(len(self.inter)):
            # batches may differ in size, so concatenate instead of stacking
            inter = torch.cat(self.inter[c])
            union = torch.cat(self.union[c])
            
            val = 2.*inter/union.clamp(min=self.eps)
            # prediction and mask both empty -> Dice defined as 1
            val[union == 0] = 1
            
            per_class.append(val.cpu().numpy())
            
        binary_dice_scores = np.concatenate(per_class) if per_class else np.array([])
        self.binary_dice_scores = binary_dice_scores
        if binary_dice_scores.size == 0: return np.nan  # nothing accumulated yet
        return np.nanmean(binary_dice_scores)

In [None]:
dice_kobj = KaggleDice(with_logits=True)

In [None]:
# missing
compute_val(dice_kobj, pred=logits, y=targs)

0.9962328915329041

In [None]:
dice_kobj = KaggleDice()

In [None]:
# missing
compute_val(dice_kobj, pred=preds, y=targs)

0.9962328915329041

Here we adapt a metric found in a Kaggle discussion. These metrics work, but can be problematic when the validation phase involves more than 1000 examples.

In [None]:
#export
def single_dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient between two binary masks (binary segmentation).

    Both arrays are flattened before comparison. `smooth` is added to the
    numerator and to the denominator, so with the default value two empty
    masks score 1.
    """
    truth = np.ndarray.flatten(y_true)
    guess = np.ndarray.flatten(y_pred)
    overlap = np.sum(truth * guess)
    numerator = 2. * overlap + smooth
    denominator = np.sum(truth) + np.sum(guess) + smooth
    return numerator / denominator

def single_dice_coef_channel(y_true, y_pred, smooth=1):
    """Mean smoothed Dice over the first four channels (multichannel segmentation).

    Inputs are indexed as `[:, ch, :, :]`, so they must be 4-dimensional with
    the channel on axis 1. Each channel is scored with `single_dice_coef`.
    """
    total = single_dice_coef(y_true[:,0,:,:], y_pred[:,0,:,:], smooth)
    for ch in range(1, 4):
        total = total + single_dice_coef(y_true[:,ch,:,:], y_pred[:,ch,:,:], smooth)
    return total/4

In [None]:
#exports
# 4-channel smoothed Dice on predictions binarised with `thresh=0.5`.
# NOTE(review): no `activation` is passed, so the threshold presumably applies to the raw
# model output; if that output is logits, 0.5 is not the 0.5-probability cut-off — confirm.
KaggleDiceCoefMulti = AccumMetric(single_dice_coef_channel, to_np=True, flatten=False, thresh=0.5)

In [None]:
#missing
compute_val(KaggleDiceCoefMulti, logits, targs)

0.5636921149524751

In [None]:
#exports
# Same 4-channel smoothed Dice, with no `thresh`: presumably the predictions are scored as given (TODO confirm)
FastKaggleCoefDiceMulti = AccumMetric(single_dice_coef_channel, to_np=True, flatten=False)

In [None]:
x0, y0 = torch.zeros(20, 4, 1, 1), torch.zeros(20, 4, 1, 1)

test_eq(compute_val(FastKaggleCoefDiceMulti, x0, y0), 1.)
test_close(compute_val(FastKaggleCoefDiceMulti, x1, x2chs), 0.38935)

In [None]:
#missing
compute_val(FastKaggleCoefDiceMulti, preds, targs)

0.5962340563656034

In [None]:
#exports
# Single-channel smoothed Dice on predictions binarised with `thresh=0.5`.
# NOTE(review): no `activation` is passed, so the threshold presumably applies to the raw
# model output; if that output is logits, 0.5 is not the 0.5-probability cut-off — confirm.
KaggleDiceCoef = AccumMetric(single_dice_coef, to_np=True, flatten=False, thresh=0.5)

In [None]:
#missing
for ch in range(4):
    print(compute_val(KaggleDiceCoef, logits[:,ch], targs[:,ch]))

0.556828003457217
1.0
0.6978286745507603
0.000111781801922647


In [None]:
#exports
# Same single-channel smoothed Dice, with no `thresh`: presumably the predictions are scored as given (TODO confirm)
FastKaggleDiceCoef = AccumMetric(single_dice_coef, to_np=True, flatten=False)

In [None]:
#missing
for ch in range(4):
    print(compute_val(FastKaggleDiceCoef, preds[:,ch], targs[:,ch]))

0.6492999628298848
1.0
0.7355494646115235
8.679802100512108e-05


In [None]:
x0, y0 = torch.zeros(20, 1, 1), torch.zeros(20, 1, 1)
test_eq(compute_val(FastKaggleDiceCoef, x0, y0), 1.)
test_eq(compute_val(FastKaggleDiceCoef, x1[:,0], x2), 0.975)

In [None]:
# hide
from nbdev.export import notebook2script
notebook2script()

Converted 01_metadata.ipynb.
Converted 02_masks.ipynb.
Converted 03_datasets.ipynb.
Converted 04_dataloaders.ipynb.
Converted 05_metrics.ipynb.
Converted 06_loss.ipynb.
Converted 07_trainer.ipynb.
Converted 08_predict.ipynb.
Converted 09_visualize.ipynb.
Converted index.ipynb.
