In [None]:
#default_exp metrics
from nbdev.showdoc import show_doc

# Metrics

> Definitions of the metrics that can be used when training models. See the [fastai2 docs](https://dev.fast.ai/metrics) for more information.

In [None]:
#hide
import torch

In [None]:
#export
from fastai.metrics import Dice as fastai_Dice
from fastai.torch_core import *
from fastai.imports import *
from fastai.learner import *
from fastcore.foundation import patch

## Dice coefficient
See [this discussion](https://stats.stackexchange.com/questions/273537/f1-dice-score-vs-iou) of the F1/Dice score versus IoU.

In [None]:
#export
class Dice(fastai_Dice):
    "Dice coefficient metric for a binary segmentation target"
    def accumulate(self, learn):
        # Hard prediction: index of the highest score along `self.axis`.
        hard_pred = learn.pred.argmax(dim=self.axis)
        flat_pred, flat_targ = flatten_check(hard_pred, learn.yb[0])
        # Re-wrap both as plain `TensorBase`, presumably so that fastai tensor subclasses
        # (e.g. TensorImage / TensorMask, as used in the test below) do not interfere
        # with the arithmetic that follows.
        p = TensorBase(flat_pred)
        t = TensorBase(flat_targ)
        # Running sums: elementwise products (overlap) and elementwise sums.
        self.inter += (p * t).float().sum().item()
        self.union += (p + t).float().sum().item()

In [None]:
# For testing: a minimal fake `Learner` carrying only the attributes metrics read
@delegates()
class TstLearner(Learner):
    "Bare `Learner` stand-in exposing only `pred`, `xb` and `yb`."
    def __init__(self, dls=None, model=None, **kwargs):
        # `Learner.__init__` is deliberately not called; only these three attributes are needed.
        self.pred = None
        self.xb = None
        self.yb = None

In [None]:
def compute_val(met, x1, x2, bounds=(0, 6, 15, 20)):
    "Run `met` through a fake training cycle with batches of various sizes and return `met.value`."
    # Batch k covers the half-open slice [bounds[k], bounds[k+1]) of `x1`/`x2`; the default
    # splits 20 samples into batches of 6, 9 and 5. The number of batches follows from
    # `bounds` instead of being hard-coded separately.
    met.reset()
    learn = TstLearner()
    for start, end in zip(bounds[:-1], bounds[1:]):
        learn.pred, learn.yb = x1[start:end], (x2[start:end],)
        met.accumulate(learn)
    return met.value

In [None]:
# Fix the seed so the random test data (and hence the expected values) are reproducible.
torch.manual_seed(42)
x1 = torch.randn(20,2,3,3)             # scores: 20 samples, 2 classes, 3x3 images
x2 = torch.randint(0, 2, (20, 3, 3))   # binary target masks
pred = x1.argmax(1)
inter = (pred*x2).float().sum().item()
union = (pred+x2).float().sum().item()
x1 = TensorImage(x1)
x2 = TensorMask(x2)
test_eq(compute_val(Dice(), x1, x2), 2*inter/union)

## Intersection-Over-Union

In [None]:
#export
class Iou(Dice):
    "Implementation of the IoU (Jaccard coefficient) that is lighter in RAM"
    @property
    def value(self):
        # `union` accumulates sum(pred + targ) and `inter` sum(pred * targ), so the
        # divisor below is the size of the true union for binary masks.
        denom = self.union - self.inter
        # Guard on the actual divisor: for binary masks this is equivalent to
        # `union > 0`, but it avoids a ZeroDivisionError for non-binary predictions
        # where sum(pred + targ) == sum(pred * targ) > 0.
        return self.inter/denom if denom > 0 else None

In [None]:
test_eq(compute_val(Iou(), x1, x2), inter/(union-inter))

### Patch `Recorder` to plot metrics

In [None]:
#export
#from https://forums.fast.ai/t/plotting-metrics-after-learning/69937
@patch
@delegates(subplots)
def plot_metrics(self: Recorder, nrows=None, ncols=None, figsize=None, **kwargs):
    "Plot train/valid losses together on one panel and every other recorded value on its own panel."
    metrics = np.stack(self.values)
    # Take exactly as many names (after 'epoch') as there are recorded columns, so a
    # trailing non-value column such as 'time' is dropped without assuming it exists.
    names = self.metric_names[1:metrics.shape[1] + 1]
    # The two losses share one panel, so there is one panel fewer than series;
    # keep at least one panel so a train-loss-only history does not divide by zero.
    n = max(len(names) - 1, 1)
    if nrows is None and ncols is None:
        nrows = int(math.sqrt(n))
        ncols = int(np.ceil(n / nrows))
    elif nrows is None: nrows = int(np.ceil(n / ncols))
    elif ncols is None: ncols = int(np.ceil(n / nrows))
    figsize = figsize or (ncols * 6, nrows * 4)
    fig, axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
    axs = list(axs.flatten())
    # Hide any surplus panels in the grid, then keep only the ones that will be used.
    for ax in axs[n:]: ax.set_axis_off()
    axs = axs[:n]
    # Series 0 (train loss) and 1 (valid loss) both go on the first panel; series i >= 2 on panel i-1.
    for i, (name, ax) in enumerate(zip(names, [axs[0]] + axs)):
        ax.plot(metrics[:, i], color='#1f77b4' if i == 0 else '#ff7f0e', label='valid' if i > 0 else 'train')
        ax.set_title(name if i > 1 else 'losses')
        ax.legend(loc='best')
    plt.show()

## Export -

In [None]:
#hide
from nbdev.export import *
# Convert the project's notebooks (their `#export` cells) into the library's Python modules;
# the output below lists each converted notebook.
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted 07_tta.ipynb.
Converted 08_gui.ipynb.
Converted 09_gt.ipynb.
Converted add_information.ipynb.
Converted deepflash2.ipynb.
Converted gt_estimation.ipynb.
Converted index.ipynb.
Converted model_library.ipynb.
Converted predict.ipynb.
Converted train.ipynb.
Converted tutorial.ipynb.
