In [None]:
#default_exp metrics
# all_slow

# Metrics

> A collection of metrics used to evaluate the segmentation models.

In [1]:
# hide
# Automatically reload edited project modules (e.g. `steel_segmentation`) before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
# export
import torch
import numpy as np
from fastai.torch_core import TensorBase, flatten_check
from fastai.metrics import Metric

In [3]:
from fastai.vision.all import *
import numpy as np
from torch.nn.modules.loss import _Loss
import segmentation_models_pytorch as smp
from steel_segmentation.utils import get_train_df
from steel_segmentation.transforms import SteelDataBlock, SteelDataLoaders

In [4]:
# Build DataLoaders from the pivoted training dataframe and grab one batch to test the metrics on.
path = Path("../data")
train_pivot = get_train_df(path=path, pivot=True)
block = SteelDataBlock(path)
dls = SteelDataLoaders(block, train_pivot, bs=8)
xb, yb = dls.one_batch()
for batch_tensor in (xb, yb):
    print(batch_tensor.shape, batch_tensor.device)

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ..\aten\src\ATen\native\BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


torch.Size([8, 3, 224, 1568]) cuda:0
torch.Size([8, 4, 224, 1568]) cpu


In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [17]:
model = smp.Unet("resnet18", classes=4).to(device)

# Inference-only forward pass that produces example outputs for the metrics below.
# `torch.no_grad()` skips recording the autograd graph, so the intermediate
# activations of this large batch are not kept in memory.
with torch.no_grad():
    logits = model(xb)
probs = torch.sigmoid(logits)
preds = (probs > 0.5).float()

## Kaggle Dice metric
The competition [evaluation metric](https://www.kaggle.com/c/severstal-steel-defect-detection/overview/evaluation) is defined as:

> This competition is evaluated on the mean Dice coefficient. The Dice coefficient can be used to compare the pixel-wise agreement between a predicted segmentation and its corresponding ground truth. The formula is given by:

$$
D(X,Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}
$$

> where X is the predicted set of pixels and Y is the ground truth. The Dice coefficient is defined to be 1 when both X and Y are empty. The leaderboard score is the mean of the Dice coefficients for each <ImageId, ClassId> pair in the test set.

This section collects the metrics that can be used to evaluate the performance of the trained segmentation models.

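To test the metrics without a real training loop, `compute_val` feeds predictions and targets, in a few batches of different sizes, to a fake learner (`TstLearner`) and returns the metric value.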

In [7]:
# For testing: a fake learner used to feed predictions and targets to metrics by hand
@delegates()
class TstLearner(Learner):
    "Bare-bones `Learner` for driving metrics manually; the real `Learner.__init__` is never run."
    def __init__(self, dls=None, model=None, **kwargs):
        # Only the attributes that the metrics read are initialised here.
        self.pred = None
        self.xb = None
        self.yb = None
        self.loss_func = BCEWithLogitsLossFlat()
        
# Go through a fake cycle with various batch sizes and compute the value of `met`
def compute_val(met, pred, y, splits=None):
    """Simulate a validation epoch for metric `met` and return `met.value`.

    `pred` and `y` are cut along their first (batch) dimension into consecutive
    chunks; each chunk is passed to `met.accumulate` through a `TstLearner`.

    Args:
        met: fastai-style metric exposing `reset`, `accumulate` and `value`.
        pred: predictions for all samples.
        y: targets for all samples, aligned with `pred` on the first dimension.
        splits: increasing chunk boundary indices. Defaults to three uneven
            chunks covering every sample (`[0, 6, 15, 20]` for 20 samples).

    Returns:
        Whatever `met.value` returns after all chunks are accumulated.
    """
    met.reset()
    if splits is None:
        n = len(pred)
        # Derived from the sample count so every sample is used: fixed split
        # points would drop samples past the last boundary, or yield empty
        # batches when there are fewer samples than expected.
        splits = [0, n * 3 // 10, n * 3 // 4, n]
    learn = TstLearner()
    for start, end in zip(splits[:-1], splits[1:]):
        learn.pred = pred[start:end]
        learn.yb = (y[start:end],)
        met.accumulate(learn)
    return met.value

## Multiclass Dice

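The `fastai` library provides `DiceMulti`, a Dice metric for multi-class segmentation. Like the other fastai segmentation metrics, it expects the target as a single-channel mask of class indices (here obtained with `argmax` over the channel axis), not as a multi-channel mask.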

In [8]:
# fastai's built-in multi-class Dice metric; `axis=1` is the channel axis of the predictions.
multidice_obj = DiceMulti(axis=1)

In [11]:
# DiceMulti needs class-index targets, so the 4-channel `yb` is collapsed with argmax first.
compute_val(multidice_obj, pred=preds.detach().cpu(), y=yb.argmax(dim=1))

0.1798790120410166

Here we slightly modify `DiceMulti` so that it also accepts a 4-channel mask as target; such a target is collapsed to class indices with `argmax`.

In [12]:
# export
class ModDiceMulti(Metric):
    """Averaged Dice metric (Macro F1) for multiclass target in segmentation.

    Per-class intersections and per-class sums of predicted + target pixel
    counts are accumulated over all batches; `value` returns the mean of the
    per-class Dice scores, ignoring classes absent from both.
    Unlike fastai's `DiceMulti`, the target may be either a class-index mask or
    a multi-channel (e.g. 4-channel) mask, which is collapsed with `argmax`.

    Args:
        axis: class/channel dimension of the predictions (and of multi-channel targets).
        with_logits: if True, `learn.pred` holds raw logits, which go through a
            sigmoid and are thresholded at `thresh` before the argmax.
        thresh: probability threshold used when `with_logits=True`.
    """

    def __init__(self, axis=1, with_logits=False, thresh=0.5):
        self.axis = axis
        self.with_logits = with_logits
        self.thresh = thresh

    def reset(self):
        # Running totals keyed by class index. `union` actually holds |P| + |T|
        # (the Dice denominator); the name is kept for compatibility.
        self.inter, self.union = {}, {}

    def accumulate(self, learn):
        if self.with_logits:
            prob = torch.sigmoid(learn.pred)
            # argmax is not supported on bool tensors, hence the float cast.
            pred = (prob > self.thresh).float().argmax(dim=self.axis)
        else:
            pred = learn.pred.argmax(dim=self.axis)

        y = learn.yb[0]
        # Added to deal with 4-channels masks: collapse them to class indices.
        # NOTE(review): argmax of an all-zero mask is 0, so pixels with no positive
        # channel are counted as class 0 (likewise for thresholded predictions)
        # — confirm this is intended.
        if pred.shape != y.shape:
            y = y.argmax(dim=self.axis)

        pred, targ = flatten_check(pred, y)
        for c in range(learn.pred.shape[self.axis]):
            # Boolean masks summed as integers give exact pixel counts; a float32
            # sum cannot represent counts above 2**24 exactly.
            p = TensorBase(pred == c)  # TensorBase wrap may be redundant (old fastai bug)
            t = TensorBase(targ == c)
            c_inter = (p & t).sum().item()
            c_union = (p.sum() + t.sum()).item()
            self.inter[c] = self.inter.get(c, 0) + c_inter
            self.union[c] = self.union.get(c, 0) + c_union

    @property
    def value(self):
        # Per-class Dice; NaN for a class absent from both predictions and targets,
        # so it is skipped by the mean rather than counted as 0 or 1.
        scores = [2. * self.inter[c] / self.union[c] if self.union[c] > 0 else np.nan
                  for c in self.inter]
        self.binary_dice_scores = np.array(scores, dtype=float)
        # Avoid numpy's "Mean of empty slice" warning; the result is NaN either way.
        if np.isnan(self.binary_dice_scores).all():
            return np.float64(np.nan)
        return np.nanmean(self.binary_dice_scores)

In [18]:
# Feed raw logits: the metric applies the sigmoid and 0.5 threshold itself,
# and collapses the 4-channel `yb` to class indices internally.
dice_obj = ModDiceMulti(axis=1, with_logits=True)
compute_val(dice_obj, pred=logits.detach().cpu(), y=yb)

0.2130325182791189

In [20]:
# Feed already-thresholded predictions: the metric only takes the argmax over channels.
dice_obj = ModDiceMulti(axis=1, with_logits=False)
compute_val(dice_obj, pred=preds.detach().cpu(), y=yb)

0.2130325182791189