In [None]:
# default_exp metrics

# Metrics

> Definition of the metrics that can be used in training models. See [fastai2 docs](https://dev.fast.ai/metrics) for more information.

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#export 
from fastai2.metrics import Dice
from fastai2.torch_core import *
from fastai2.imports import *
from fastai2.learner import *
from fastcore.foundation import patch

## Dice coefficient / F1 score wrapper class
See [this Cross Validated thread](https://stats.stackexchange.com/questions/273537/f1-dice-score-vs-iou) for a discussion of the F1/Dice score and how it compares to IoU.

In [None]:
#export
class Dice_f1(Dice):
    "Dice coefficient metric for binary target in segmentation"
    def accumulate(self, learn):
        "Add the current batch's overlap and size sums to the running `inter`/`union` totals"
        # Hard predictions: index of the largest score along `self.axis`.
        hard_pred = learn.pred.argmax(dim=self.axis)
        # Only the first item of `learn.y` is compared with the predictions
        # (presumably the mask of a multi-part target -- TODO confirm).
        pred, targ = flatten_check(hard_pred, learn.y[0])
        overlap = pred * targ
        self.inter += overlap.float().sum().item()
        combined = pred + targ
        self.union += combined.float().sum().item()

## Intersection-Over-Union wrapper class

In [None]:
#export
class Iou(Dice_f1):
    "Implementation of the IoU (Jaccard coefficient) that is lighter in RAM"
    @property
    def value(self):
        "IoU computed from the running `inter`/`union` sums, or `None` while `union` is not positive"
        if self.union > 0:
            # `union` is fed by `Dice_f1.accumulate` with sum(pred + targ), so the
            # overlap is subtracted once to avoid counting it twice.
            return self.inter / (self.union - self.inter)
        return None

### Patch to show metrics

In [None]:
#export
#from https://forums.fast.ai/t/plotting-metrics-after-learning/69937
@patch
@delegates(subplots)
def plot_metrics(self: Recorder, nrows=None, ncols=None, figsize=None, **kwargs):
    "Plot the columns of `self.values`: the first two share a 'losses' panel, every further one gets its own panel"
    history = np.stack(self.values)
    series_names = self.metric_names[1:-1]
    # The first two series are drawn on the same axes, hence one panel fewer than series.
    n_panels = len(series_names) - 1

    # Fill in whichever grid dimension(s) the caller left unspecified.
    if nrows is None:
        if ncols is None:
            nrows = int(math.sqrt(n_panels))
            ncols = int(np.ceil(n_panels / nrows))
        else:
            nrows = int(np.ceil(n_panels / ncols))
    elif ncols is None:
        ncols = int(np.ceil(n_panels / nrows))

    if not figsize:
        figsize = (ncols * 6, nrows * 4)
    fig, grid = subplots(nrows, ncols, figsize=figsize, **kwargs)

    # Keep the first `n_panels` axes; switch the surplus grid cells off.
    slots = []
    for pos, cell in enumerate(grid.flatten()):
        if pos < n_panels:
            slots.append(cell)
        else:
            slots.append(cell.set_axis_off())
    slots = slots[:n_panels]

    # Repeat the first axes so that series 0 and series 1 land on the same panel.
    targets = [slots[0]] + slots
    for col, (title, ax) in enumerate(zip(series_names, targets)):
        first = col == 0
        ax.plot(history[:, col],
                color='#1f77b4' if first else '#ff7f0e',
                label='train' if first else 'valid')
        ax.set_title('losses' if col <= 1 else title)
        ax.legend(loc='best')
    plt.show()

Test

In [None]:
#tbd

## Export 

In [None]:
#hide
# Run nbdev's notebook-to-script conversion; the recorded output of this cell
# lists the notebooks that were processed.
from nbdev.export import *
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted index.ipynb.
