In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.layers import *
from local.data.pipeline import *
from local.data.source import *
from local.data.core import *
from local.data.external import *
from local.notebook.showdoc import show_doc
from local.optimizer import *
from local.learner import *
from local.callback.progress import *

In [None]:
#default_exp metrics
# default_cls_lvl 3

# Metrics

> Definition of the metrics that can be used in training models

## Core metric

This is where the function that converts scikit-learn metrics to fastai metrics is defined. You should skip this section unless you want to know all about the internals of fastai.

In [None]:
import sklearn.metrics as skm

In [None]:
#export core
def flatten_check(inp, targ, detach=True):
    "Check that `inp` and `targ` have the same number of elements and flatten them (through `to_detach` when `detach`)."
    # `contiguous()` lets `view(-1)` work on non-contiguous tensors (e.g. expanded ones)
    inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)
    # Honor `detach`: it was previously accepted but ignored (both tensors always went through `to_detach`)
    if detach: inp,targ = to_detach(inp),to_detach(targ)
    # Raises if the two flattened tensors differ in length
    test_eq(len(inp), len(targ))
    return inp,targ

In [None]:
# Same number of elements (5*4 == 20): both tensors come back flattened to shape [20]
x1,x2 = torch.randn(5,4),torch.randn(20)
x1,x2 = flatten_check(x1,x2)
test_eq(x1.shape, [20])
test_eq(x2.shape, [20])
# 20 vs 21 elements: the length check inside `flatten_check` must fail
x1,x2 = torch.randn(5,4),torch.randn(21)
test_fail(lambda: flatten_check(x1,x2))

In [None]:
#export
class AccMetric(Metric):
    "Stores predictions and targets on CPU in accumulate to perform final calculations with `func`."
    def __init__(self, func, dim_argmax=None, sigmoid=False, thresh=None, to_np=False, invert_arg=False, **kwargs):
        self.func,self.dim_argmax,self.sigmoid,self.thresh = func,dim_argmax,sigmoid,thresh
        self.to_np,self.invert_args,self.kwargs = to_np,invert_arg,kwargs

    def reset(self):
        "Forget everything accumulated so far."
        self.targs,self.preds = [],[]

    def accumulate(self, learn):
        "Transform `learn.pred` (argmax, then sigmoid, then threshold, each optional) and store it with `learn.yb`."
        # Compare against None: `dim_argmax=0` and `thresh=0` are valid settings but are falsy
        pred = learn.pred.argmax(dim=self.dim_argmax) if self.dim_argmax is not None else learn.pred
        if self.sigmoid: pred = torch.sigmoid(pred)
        if self.thresh is not None: pred = (pred >= self.thresh)
        pred,targ = flatten_check(pred, learn.yb)
        self.preds.append(pred)
        self.targs.append(targ)

    @property
    def value(self):
        "Apply `func` to all the stored predictions and targets, concatenated."
        preds,targs = torch.cat(self.preds),torch.cat(self.targs)
        if self.to_np: preds,targs = preds.numpy(),targs.numpy()
        # `invert_arg=True` calls `func(targs, preds)`; the default order is `func(preds, targs)`
        return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)

`func` is only applied to the accumulated predictions/targets when the `value` attribute is asked for (so at the end of a validation/training phase, in use with `Learner` and its `Recorder`). The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).

For classification problems with single label, predictions need to be transformed with a softmax then an argmax before being compared to the targets. Since a softmax doesn't change the order of the numbers, we can just apply the argmax. Pass along `dim_argmax` to have this done by `AccMetric` (usually -1 will work pretty well).

For classification problems with multiple labels, or if your targets are onehot-encoded, predictions may need to pass through a sigmoid (if it wasn't included in your model) then be compared to a given threshold (to decide between 0 and 1), this is done by `AccMetric` if you pass `sigmoid=True` and/or a value for `thresh`.

If you want to use a metric function from sklearn.metrics, you will need to convert predictions and labels to numpy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_preds`, which is the opposite of ours, so you will need to pass `invert_arg=True` to make `AccMetric` do the inversion for you.

In [None]:
#For testing: a fake learner and a metric that isn't an average
class TstLearner():
    "Bare stand-in for a learner: only holds the `pred` and `yb` attributes, both starting as None."
    def __init__(self):
        self.pred = None
        self.yb = None

def _l2_mean(x,y): return torch.sqrt((x-y).float().pow(2).mean())

#Go through a fake cycle with various batch sizes and computes the value of met
def compute_val(met, x1, x2):
    "Reset `met`, feed it `x1`/`x2` in three uneven slices (0-6, 6-15, 15-20) through a fake learner, return `met.value`."
    met.reset()
    bounds = [0,6,15,20]
    learn = TstLearner()
    for start,stop in zip(bounds[:-1], bounds[1:]):
        learn.pred,learn.yb = x1[start:stop],x2[start:stop]
        met.accumulate(learn)
    return met.value

In [None]:
#test plain accumulation: the batched value matches the one computed on the full tensors
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = AccMetric(_l2_mean)
test_close(compute_val(tst, x1, x2), _l2_mean(x1, x2))
test_eq(torch.cat(tst.preds), x1.view(-1))
test_eq(torch.cat(tst.targs), x2.view(-1))

#test argmax
x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))
tst = AccMetric(_l2_mean, dim_argmax=-1)
test_close(compute_val(tst, x1, x2), _l2_mean(x1.argmax(dim=-1), x2))

#test thresh
x1,x2 = torch.randn(20,5),torch.randint(0, 2, (20,5)).byte()
tst = AccMetric(_l2_mean, thresh=0.5)
test_close(compute_val(tst, x1, x2), _l2_mean((x1 >= 0.5), x2))

#test sigmoid
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = AccMetric(_l2_mean, sigmoid=True)
test_close(compute_val(tst, x1, x2), _l2_mean(torch.sigmoid(x1), x2))

#test to_np
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = AccMetric(lambda x,y: isinstance(x, np.ndarray) and isinstance(y, np.ndarray), to_np=True)
assert compute_val(tst, x1, x2)

#test invert_arg
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = AccMetric(lambda x,y: torch.sqrt(x.pow(2).mean()))
test_close(compute_val(tst, x1, x2), torch.sqrt(x1.pow(2).mean()))
tst = AccMetric(lambda x,y: torch.sqrt(x.pow(2).mean()), invert_arg=True)
test_close(compute_val(tst, x1, x2), torch.sqrt(x2.pow(2).mean()))

In [None]:
#export
def skm_to_fastai(func, is_class=True, thresh=None, axis=-1, sigmoid=None, **kwargs):
    "Wrap `func` from sklearn.metrics in an `AccMetric` so it can be used as a fastai metric"
    # Classification without a threshold: argmax over `axis`; with a threshold: sigmoid by default
    single_label = is_class and thresh is None
    multi_label = is_class and thresh is not None
    if sigmoid is None: sigmoid = multi_label
    return AccMetric(func, dim_argmax=axis if single_label else None, sigmoid=sigmoid, thresh=thresh,
                     to_np=True, invert_arg=True, **kwargs)

This is the quickest way to use a scikit-learn metric in a fastai training loop. `is_class` indicates if you are in a classification problem or not. In this case:
- leaving `thresh` to `None` indicates it's a single-label classification problem and predictions will pass through an argmax over `axis` before being compared to the targets
- setting a value for `thresh` indicates it's a multi-label classification problem and predictions will pass through a sigmoid (can be deactivated with `sigmoid=False`) and be compared to `thresh` before being compared to the targets

If `is_class=False`, it indicates you are in a regression problem, and predictions are compared to the targets without being modified. In all cases, `kwargs` are extra keyword arguments passed to `func`.

In [None]:
# Single-label case (no `thresh`): compared against sklearn on the argmax of the predictions
tst_single = skm_to_fastai(skm.precision_score)
x1,x2 = torch.randn(20,2),torch.randint(0, 2, (20,))
test_close(compute_val(tst_single, x1, x2), skm.precision_score(x2, x1.argmax(dim=-1)))

In [None]:
# With `thresh`: compared against sklearn on sigmoid(predictions) >= thresh
tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2)
x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))
test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, torch.sigmoid(x1) >= 0.2))

# With `sigmoid=False` the raw predictions are thresholded directly
tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2, sigmoid=False)
x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))
test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, x1 >= 0.2))

In [None]:
# Regression case (`is_class=False`): compared against sklearn on the flattened, unmodified predictions
tst_reg = skm_to_fastai(skm.r2_score, is_class=False)
x1,x2 = torch.randn(20,5),torch.randn(20,5)
test_close(compute_val(tst_reg, x1, x2), skm.r2_score(x2.view(-1), x1.view(-1)))

In [None]:
def metric_func(met, **kwargs):
    "Call `met` with `kwargs` and return what it returns."
    return met(**kwargs)

In [None]:
def Precision(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):
    "Build a fastai metric around `skm.precision_score`; every argument but `axis` is forwarded to it."
    skm_kwargs = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastai(skm.precision_score, axis=axis, **skm_kwargs)

In [None]:
# `metric_func` calls `Precision` right away, so `precision` holds its return value (built with default arguments)
precision = metric_func(Precision)

In [None]:
# The original first line, `partial(precision, `, was left unfinished and made this cell a SyntaxError.
# `precision` is the value returned by `metric_func(Precision)`, so it is used directly here.
tst_single = precision
x1,x2 = torch.randn(20,2),torch.randint(0, 2, (20,))
test_close(compute_val(tst_single, x1, x2), skm.precision_score(x2, x1.argmax(dim=-1)))

In [None]:
# pr   <- unfinished scratch expression; commented out because the name `pr` is never defined in this notebook

In [None]:
# Scratch note (was not valid Python): keyword defaults of skm.precision_score used above:
#   labels=None, pos_label=1, average='binary', sample_weight=None

## Single-label classification

> Warning: All functions defined in this section are intended for single-label classification and targets that aren't one-hot encoded. For multi-label problems or one-hot encoded targets, use the `_multi` version of them.

In [None]:
#export
def accuracy(inp, targ):
    "Compute accuracy with `targ` when `inp` is bs * n_classes"
    # Predicted class = index of the largest value along the last dimension
    pred = inp.argmax(dim=-1)
    pred,targ = flatten_check(pred, targ)
    return (pred == targ).float().mean()

In [None]:
#For testing
def change_targ(targ, n, c):
    "Return a copy of `targ` where `n` randomly chosen entries are shifted to a different value modulo `c`."
    chosen = torch.randperm(len(targ))[:n]
    res = targ.clone()
    # Adding a shift in [1, c-1] modulo `c` always lands on a different class
    for pos in chosen:
        shift = random.randint(1,c-1)
        res[pos] = (res[pos]+shift)%c
    return res

In [None]:
# Targets equal to the argmax give a perfect score
x = torch.randn(4,5)
y = x.argmax(dim=1)
test_eq(accuracy(x,y), 1)
# Changing 2 of the 4 targets halves it
y1 = change_targ(y, 2, 5)
test_eq(accuracy(x,y1), 0.5)
# Extra dimension: `y` and `y1` stacked give 8 targets, 6 of which match
test_eq(accuracy(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.75)

In [None]:
#export
def error_rate(inp, targ):
    "Complement of `accuracy`: 1 - `accuracy`"
    acc = accuracy(inp, targ)
    return 1 - acc

In [None]:
# Mirror of the `accuracy` tests: each expected value is 1 minus the accuracy
x = torch.randn(4,5)
y = x.argmax(dim=1)
test_eq(error_rate(x,y), 0)
y1 = change_targ(y, 2, 5)
test_eq(error_rate(x,y1), 0.5)
test_eq(error_rate(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.25)

In [None]:
#export
def top_k_accuracy(inp, targ, k=5):
    "Computes the Top-k accuracy (`targ` is in the top `k` predictions of `inp`)"
    # Indices of the `k` largest values along the last dimension
    top_idx = inp.topk(k=k, dim=-1)[1]
    hits = top_idx == targ.unsqueeze(dim=-1).expand_as(top_idx)
    return hits.sum(dim=-1).float().mean()

In [None]:
x = torch.randn(6,5)
y = torch.arange(0,6)
# With 5 columns and the default k=5, every index 0..4 is in the top k
test_eq(top_k_accuracy(x[:5],y[:5]), 1)
# The last target is 5, which is not a valid column index, so that row always misses
test_eq(top_k_accuracy(x, y), 5/6)

In [None]:
#export
def ap_score(axis=-1, average='macro', pos_label=1, sample_weight=None):
    "Average Precision for single-label classification problems"
    skm_kwargs = dict(average=average, pos_label=pos_label, sample_weight=sample_weight)
    return skm_to_fastai(skm.average_precision_score, axis=axis, **skm_kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details.

In [None]:
#export
def balanced_accuracy(axis=-1, sample_weight=None, adjusted=False):
    "Balanced Accuracy for single-label binary classification problems"
    skm_kwargs = dict(sample_weight=sample_weight, adjusted=adjusted)
    return skm_to_fastai(skm.balanced_accuracy_score, axis=axis, **skm_kwargs)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html#sklearn.metrics.balanced_accuracy_score) for more details.

In [None]:
#export
def brier_score(axis=-1, sample_weight=None, pos_label=None):
    "Brier score for single-label classification problems"
    # Wrap `brier_score_loss`: the previous code wrapped `balanced_accuracy_score`, which has no `pos_label` argument
    # NOTE(review): sklearn documents `brier_score_loss` as taking probabilities, while the default
    # `skm_to_fastai` path applies an argmax over `axis` first — confirm that is the intended input
    return skm_to_fastai(skm.brier_score_loss, axis=axis,
                         sample_weight=sample_weight, pos_label=pos_label)

See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details.

In [None]:
#export
def cohen_kappa(axis=-1, labels=None, weights=None, sample_weight=None):
    "Cohen kappa for single-label classification problems"
    # Wrap `cohen_kappa_score` and forward this function's own arguments: the previous code wrapped
    # `balanced_accuracy_score` and referenced an undefined `pos_label`
    return skm_to_fastai(skm.cohen_kappa_score, axis=axis,
                         labels=labels, weights=weights, sample_weight=sample_weight)

In [None]:
# Scratch note (kept as a comment: `cohen_kappa_score` and `y2` are not defined in this notebook):
#   skm.cohen_kappa_score(y1, y2, labels=None, weights=None, sample_weight=None)

In [None]:
# NOTE(review): an identical `Precision` function is already defined earlier in this notebook
def Precision(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):
    "Build a fastai metric around `skm.precision_score`; every argument but `axis` is forwarded to it."
    skm_kwargs = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to_fastai(skm.precision_score, axis=axis, **skm_kwargs)

## Multi-label classification

In [None]:
def change_1h_targ(targ, n):
    "Return a copy of `targ` where `n` randomly chosen entries are replaced by 1 minus their value."
    chosen = torch.randperm(targ.numel())[:n]
    flat = targ.clone().view(-1)
    for pos in chosen:
        flat[pos] = 1-flat[pos]
    return flat.view(targ.shape)

In [None]:
#export
def accuracy_thresh(inp, targ, thresh=0.5, sigmoid=True):
    "Compute accuracy when `inp` and `targ` are the same size."
    inp,targ = flatten_check(inp,targ)
    scores = inp.sigmoid() if sigmoid else inp
    # A prediction counts as positive when strictly above `thresh`
    matches = (scores>thresh)==targ.byte()
    return matches.float().mean()

In [None]:
#Default thresh, with sigmoid
x = torch.randn(4,5)
y = torch.sigmoid(x) >= 0.5
test_eq(accuracy_thresh(x,y), 1)
test_eq(accuracy_thresh(x,1-y), 0)
# Flipping 5 of the 20 targets leaves 15 correct
y1 = change_1h_targ(y, 5)
test_eq(accuracy_thresh(x,y1), 0.75)

#Different thresh
y = torch.sigmoid(x) >= 0.2
test_eq(accuracy_thresh(x,y, thresh=0.2), 1)
test_eq(accuracy_thresh(x,1-y, thresh=0.2), 0)
y1 = change_1h_targ(y, 5)
test_eq(accuracy_thresh(x,y1, thresh=0.2), 0.75)

#No sigmoid
y = x >= 0.5
test_eq(accuracy_thresh(x,y, sigmoid=False), 1)
test_eq(accuracy_thresh(x,1-y, sigmoid=False), 0)
y1 = change_1h_targ(y, 5)
test_eq(accuracy_thresh(x,y1, sigmoid=False), 0.75)

## Regression

In [None]:
@metric()
def rmse(inp, targ):
    "Root mean squared error between `inp` and `targ`"
    mean_sq_err = F.mse_loss(inp, targ)
    return torch.sqrt(mean_sq_err)

In [None]:
# Accumulate over three uneven slices, then compare with the value computed on the full tensors
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = rmse
tst.reset()
vals = [0,6,15,20]
learn = TstLearner()
for i in range(3): 
    learn.pred,learn.yb = x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]]
    tst.accumulate(learn)
test_eq(tst.value, torch.sqrt(F.mse_loss(x1,x2)))

In [None]:
@metric()
def exp_rmspe(inp, targ):
    "Root mean square percentage error of the exponential of `inp` and `targ`"
    pred,true = torch.exp(inp),torch.exp(targ)
    # Error relative to the (exponentiated) target
    rel_err = (true - pred)/true
    return torch.sqrt(rel_err.pow(2).mean())

In [None]:
# Accumulate over three uneven slices, then compare with the formula applied to the full tensors
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = exp_rmspe
tst.reset()
vals = [0,6,15,20]
learn = TstLearner()
for i in range(3): 
    learn.pred,learn.yb = x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]]
    tst.accumulate(learn)
test_eq(tst.value, torch.sqrt((((torch.exp(x2) - torch.exp(x1))/torch.exp(x2))**2).mean()))

In [None]:
def mae(inp,targ):
    "Mean absolute error between `inp` and `targ`."
    inp,targ = flatten_check(inp,targ)
    abs_err = torch.abs(inp - targ)
    return abs_err.mean()

In [None]:
# Compare with the mean absolute difference written out directly
x1,x2 = torch.randn(4,5),torch.randn(4,5)
test_eq(mae(x1,x2), torch.abs(x1-x2).mean())

In [None]:
def mse(inp,targ):
    "Mean squared error between `inp` and `targ`."
    inp,targ = flatten_check(inp,targ)
    return F.mse_loss(inp, targ)

In [None]:
# Compare with the mean squared difference written out directly
x1,x2 = torch.randn(4,5),torch.randn(4,5)
test_close(mse(x1,x2), (x1-x2).pow(2).mean())

In [None]:
def msle(inp, targ):
    "Mean squared logarithmic error between `inp` and `targ`."
    inp,targ = flatten_check(inp,targ)
    log_inp,log_targ = torch.log(1 + inp),torch.log(1 + targ)
    return F.mse_loss(log_inp, log_targ)

In [None]:
x1,x2 = torch.randn(4,5),torch.randn(4,5)
# relu makes both inputs non-negative, so log(1 + x) is defined everywhere
x1,x2 = torch.relu(x1),torch.relu(x2)
test_close(msle(x1,x2), (torch.log(x1+1)-torch.log(x2+1)).pow(2).mean())

In [None]:
@metric(to_np=True)
def explained_variance(inp, targ):
    "Explained variance between `inp` and `targ`"
    # Use the `skm` alias this notebook imports (`import sklearn.metrics as skm`); `skmets` is not defined here.
    # sklearn takes the targets first, hence the swapped argument order.
    return skm.explained_variance_score(targ,inp)

In [None]:
from sklearn.metrics import explained_variance_score
# Accumulate over three uneven slices, then compare with sklearn applied to the full flattened arrays
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = explained_variance
tst.reset()
vals = [0,6,15,20]
learn = TstLearner()
for i in range(3):
    learn.pred,learn.yb = x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]]
    tst.accumulate(learn)
# Call the function imported above instead of going through the undefined `skmets` alias
test_close(tst.value, explained_variance_score(x2.view(-1).numpy(),x1.view(-1).numpy()))

In [None]:
@metric(to_np=True)
def r2_score(inp, targ):
    "R2 score (coefficient of determination)"
    # Use the `skm` alias this notebook imports (`import sklearn.metrics as skm`); `skmets` is not defined here.
    # sklearn takes the targets first, hence the swapped argument order.
    return skm.r2_score(targ,inp)

In [None]:
# Accumulate over three uneven slices, then compare with sklearn applied to the full flattened tensors
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = r2_score
tst.reset()
vals = [0,6,15,20]
learn = TstLearner()
for i in range(3):
    learn.pred,learn.yb = x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]]
    tst.accumulate(learn)
# `skm` is the alias imported at the top of this notebook; `skmets` is not defined here
test_close(tst.value, skm.r2_score(x2.view(-1),x1.view(-1)))

## Segmentation 

In [None]:
def foreground_acc(inp, targ, bkg_idx=0):
    "Computes non-background accuracy for multiclass segmentation"
    targ = targ.squeeze(1)
    pred = inp.argmax(dim=1)
    # Only positions whose target is not `bkg_idx` take part in the average
    fg = targ != bkg_idx
    return (pred[fg]==targ[fg]).float().mean()

In [None]:
# Targets equal to the argmax over dim 1 (with a singleton dim added back) give a perfect score
x = torch.randn(4,5,3,3)
y = x.argmax(dim=1)[:,None]
test_eq(foreground_acc(x,y), 1)
y[0] = 0 #the 0s are ignored so we get the same value
test_eq(foreground_acc(x,y), 1)

In [None]:
def fbeta(y_pred:Tensor, y_true:Tensor, thresh:float=0.2, beta:float=2, eps:float=1e-9, sigmoid:bool=True)->Rank0Tensor:
    "Computes the f_beta between `y_pred` and `y_true`, averaged over the first dimension"
    beta2 = beta ** 2
    scores = y_pred.sigmoid() if sigmoid else y_pred
    pred_pos = (scores>thresh).float()
    true_pos = y_true.float()
    # True positives, precision and recall are all computed along dim 1
    TP = (pred_pos*true_pos).sum(dim=1)
    prec = TP/(pred_pos.sum(dim=1)+eps)
    rec = TP/(true_pos.sum(dim=1)+eps)
    per_item = (prec*rec)/(prec*beta2+rec+eps)*(1+beta2)
    return per_item.mean()


def dice(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8)->Rank0Tensor:
    "Dice coefficient metric for binary target. If iou=True, returns iou metric, classic for segmentation problems."
    bs = targs.shape[0]
    pred,true = input.argmax(dim=1).view(bs,-1),targs.view(bs,-1)
    intersect = (pred * true).sum().float()
    union = (pred+true).sum().float()
    # When `union` is not positive, the score is defined as 1
    if not union > 0: return union.new([1.]).squeeze()
    if iou: return intersect / (union-intersect+eps)
    return 2. * intersect / union

class ConfusionMatrix(Callback):
    "Computes the confusion matrix."

    def on_train_begin(self, **kwargs):
        # 0 means "not known yet": the real value is read from the first batch of outputs
        self.n_classes = 0

    def on_epoch_begin(self, **kwargs):
        self.cm = None

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        preds = last_output.argmax(-1).view(-1).cpu()
        targs = last_target.cpu()
        if self.n_classes == 0:
            self.n_classes = last_output.shape[-1]
            self.x = torch.arange(0, self.n_classes)
        # One boolean mask per class for predictions and for targets, combined by broadcasting
        pred_is_cls = preds==self.x[:, None]
        targ_is_cls = targs==self.x[:, None, None]
        batch_cm = (pred_is_cls & targ_is_cls).sum(dim=2, dtype=torch.float32)
        if self.cm is None: self.cm =  batch_cm
        else:               self.cm += batch_cm

    def on_epoch_end(self, **kwargs):
        self.metric = self.cm

@dataclass
class CMScores(ConfusionMatrix):
    "Base class for metrics which rely on the calculation of the precision and/or recall score."
    average:Optional[str]="binary"      # `binary`, `micro`, `macro`, `weighted` or None
    pos_label:int=1                     # 0 or 1
    eps:float=1e-9

    def _recall(self):
        "Diagonal of `self.cm` over its `dim=1` sums; reduced with the averaging weights unless `average` is None."
        rec = torch.diag(self.cm) / self.cm.sum(dim=1)
        if self.average is None: return rec
        else:
            # For "micro", recall is reduced with the "weighted" weights
            if self.average == "micro": weights = self._weights(avg="weighted")
            else: weights = self._weights(avg=self.average)
            return (rec * weights).sum()

    def _precision(self):
        "Diagonal of `self.cm` over its `dim=0` sums; reduced with the averaging weights unless `average` is None."
        prec = torch.diag(self.cm) / self.cm.sum(dim=0)
        if self.average is None: return prec
        else:
            weights = self._weights(avg=self.average)
            return (prec * weights).sum()

    def _weights(self, avg:str):
        "Per-class weights for averaging mode `avg`; implicitly returns None for any other value."
        # "binary" only makes sense with 2 classes: otherwise switch (and store) "macro"
        if self.n_classes != 2 and avg == "binary":
            avg = self.average = "macro"
            warn("average=`binary` was selected for a non binary case. Value for average has now been set to `macro` instead.")
        if avg == "binary":
            if self.pos_label not in (0, 1):
                self.pos_label = 1
                warn("Invalid value for pos_label. It has now been set to 1.")
            # One-hot weights: only the class at `pos_label` contributes
            if self.pos_label == 1: return Tensor([0,1])
            else: return Tensor([1,0])
        elif avg == "micro": return self.cm.sum(dim=0) / self.cm.sum()
        elif avg == "macro": return torch.ones((self.n_classes,)) / self.n_classes
        elif avg == "weighted": return self.cm.sum(dim=1) / self.cm.sum()


class Recall(CMScores):
    "Compute the Recall."
    def on_epoch_end(self, last_metrics, **kwargs):
        recall = self._recall()
        return add_metrics(last_metrics, recall)

# NOTE(review): this class reuses the name of the `Precision` function defined earlier in this notebook
class Precision(CMScores):
    "Compute the Precision."
    def on_epoch_end(self, last_metrics, **kwargs):
        precision = self._precision()
        return add_metrics(last_metrics, precision)

@dataclass
class FBeta(CMScores):
    "Compute the F`beta` score."
    beta:float=2

    def on_train_begin(self, **kwargs):
        self.n_classes = 0
        self.beta2 = self.beta ** 2
        # Keep the requested averaging in `self.avg`; unless it is "micro", `self.average` is cleared
        # so that `_precision`/`_recall` return their un-reduced values
        self.avg = self.average
        if self.average != "micro": self.average = None

    def on_epoch_end(self, last_metrics, **kwargs):
        prec = self._precision()
        rec = self._recall()
        metric = (1 + self.beta2) * prec * rec / (prec * self.beta2 + rec + self.eps)
        metric[metric != metric] = 0  # removing potential "nan"s
        # Apply the averaging that was set aside in `on_train_begin`
        if self.avg: metric = (self._weights(avg=self.avg) * metric).sum()
        return add_metrics(last_metrics, metric)

    # Restore the `average` value that `on_train_begin` may have cleared
    def on_train_end(self, **kwargs): self.average = self.avg

@dataclass
class KappaScore(ConfusionMatrix):
    "Compute the rate of agreement (Cohens Kappa)."
    weights:Optional[str]=None      # None, `linear`, or `quadratic`

    def on_epoch_end(self, last_metrics, **kwargs):
        col_tot = self.cm.sum(dim=0)
        row_tot = self.cm.sum(dim=1)
        # Outer product of the marginal totals, scaled by the total count
        expected = torch.einsum('i,j->ij', (col_tot, row_tot)) / col_tot.sum()
        if self.weights is None:
            # Unweighted: every off-diagonal cell weighs 1, the diagonal weighs 0
            w = torch.ones((self.n_classes, self.n_classes))
            w[self.x, self.x] = 0
        elif self.weights in ("linear", "quadratic"):
            # Every row holds 0..n_classes-1, so `w - w.t()` is the difference between the two class indices
            w = torch.zeros((self.n_classes, self.n_classes))
            w += torch.arange(self.n_classes, dtype=torch.float)
            diff = w - torch.t(w)
            w = torch.abs(diff) if self.weights == "linear" else diff ** 2
        else: raise ValueError('Unknown weights. Expected None, "linear", or "quadratic".')
        k = torch.sum(w * self.cm) / torch.sum(w * expected)
        return add_metrics(last_metrics, 1-k)

@dataclass
class MatthewsCorreff(ConfusionMatrix):
    "Compute the Matthews correlation coefficient."
    def on_epoch_end(self, last_metrics, **kwargs):
        # Marginal totals of the confusion matrix along each dimension
        t_tot,p_tot = self.cm.sum(dim=1),self.cm.sum(dim=0)
        n_samples = p_tot.sum()
        cov_ytyp = torch.trace(self.cm) * n_samples - torch.dot(t_tot, p_tot)
        cov_ypyp = n_samples ** 2 - torch.dot(p_tot, p_tot)
        cov_ytyt = n_samples ** 2 - torch.dot(t_tot, t_tot)
        mcc = cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)
        return add_metrics(last_metrics, mcc)

class Perplexity(Callback):
    "Perplexity metric for language models."
    def on_epoch_begin(self, **kwargs):
        self.loss = 0.
        self.len = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        # Each batch's loss is weighted by the size of dim 1 of its target
        n = last_target.size(1)
        self.loss += n * CrossEntropyFlat()(last_output, last_target)
        self.len += n

    def on_epoch_end(self, last_metrics, **kwargs):
        avg_loss = self.loss / self.len
        return add_metrics(last_metrics, torch.exp(avg_loss))

def auc_roc_score(input:Tensor, targ:Tensor):
    "Using trapezoid method to calculate the area under roc curve"
    fpr, tpr = roc_curve(input, targ)
    # Trapezoid rule: width of each fpr step times the mean of the two neighbouring tpr values
    widths = fpr[1:] - fpr[:-1]
    heights = (tpr[1:] + tpr[:-1]) / 2.
    return (widths * heights).sum(-1)

def roc_curve(input:Tensor, targ:Tensor):
    "Returns the false positive and true positive rates"
    targ = (targ == 1)
    # Sort the scores, and their labels, from highest to lowest
    desc_score_indices = torch.flip(input.argsort(-1), [-1])
    input = input[desc_score_indices]
    targ = targ[desc_score_indices]
    # Thresholds sit wherever two consecutive sorted scores differ, plus the very last position
    d = input[1:] - input[:-1]
    distinct_value_indices = torch.nonzero(d).transpose(0,1)[0]
    threshold_idxs = torch.cat((distinct_value_indices, LongTensor([len(targ) - 1]).to(targ.device)))
    tps = torch.cumsum(targ * 1, dim=-1)[threshold_idxs]
    fps = (1 + threshold_idxs - tps)
    if tps[0] != 0 or fps[0] != 0:
        # Prepend a zero so the curve starts at the origin. `new_zeros` keeps the dtype and device
        # of `fps`/`tps`; the previous `LongTensor([0])` was never moved to the data's device.
        fps = torch.cat((fps.new_zeros(1), fps))
        tps = torch.cat((tps.new_zeros(1), tps))
    fpr, tpr = fps.float() / fps[-1], tps.float() / tps[-1]
    return fpr, tpr

@dataclass
class AUROC(Callback):
    "Calculate the auc score based on the roc curve. Restricted to the binary classification task."
    def on_epoch_begin(self, **kwargs):
        self.targs = LongTensor([])
        self.preds = Tensor([])

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        # Keep only the softmax value of the last class along dim 1
        last_cls_prob = F.softmax(last_output, dim=1)[:,-1]
        self.preds = torch.cat((self.preds, last_cls_prob.cpu()))
        self.targs = torch.cat((self.targs, last_target.cpu().long()))

    def on_epoch_end(self, last_metrics, **kwargs):
        score = auc_roc_score(self.preds, self.targs)
        return add_metrics(last_metrics, score)

class MultiLabelFbeta(LearnerCallback):
    "Computes the fbeta score for multilabel classification"
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
    _order = -20
    def __init__(self, learn, beta=2, eps=1e-15, thresh=0.3, sigmoid=True, average="micro"):
        super().__init__(learn)
        self.eps,self.thresh,self.sigmoid = eps,thresh,sigmoid
        self.average,self.beta2 = average,beta**2

    def on_train_begin(self, **kwargs):
        self.c = self.learn.data.c
        # "none" reports one score per class, any other mode a single averaged score
        if self.average == "none": names = [f"fbeta_{c}" for c in self.learn.data.classes]
        else: names = [f'{self.average}_fbeta']
        self.learn.recorder.add_metric_names(names)

    def on_epoch_begin(self, **kwargs):
        dvc = self.learn.data.device
        # Per-class counters, reset at the start of every epoch
        self.tp = torch.zeros(self.c).to(dvc)
        self.total_pred = torch.zeros(self.c).to(dvc)
        self.total_targ = torch.zeros(self.c).to(dvc)

    def on_batch_end(self, last_output, last_target, **kwargs):
        scores = last_output.sigmoid() if self.sigmoid else last_output
        pred, targ = scores > self.thresh, last_target.byte()
        m = pred*targ
        self.tp += m.sum(0).float()
        self.total_pred += pred.sum(0).float()
        self.total_targ += targ.sum(0).float()

    def fbeta_score(self, precision, recall):
        num = (1 + self.beta2)*(precision*recall)
        return num/((self.beta2*precision + recall) + self.eps)

    def on_epoch_end(self, last_metrics, **kwargs):
        # `eps` is added to both totals before they are used as denominators
        self.total_pred += self.eps
        self.total_targ += self.eps
        if self.average == "micro":
            tp_sum = self.tp.sum()
            res = self.fbeta_score(tp_sum / self.total_pred.sum(), tp_sum / self.total_targ.sum())
        elif self.average in ("macro", "weighted", "none"):
            per_class = self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ))
            if self.average == "macro": res = per_class.mean()
            elif self.average == "weighted": res = (per_class*self.total_targ).sum() / self.total_targ.sum()
            else: res = listify(per_class)
        else:
            raise Exception("Choose one of the average types: [micro, macro, weighted, none]")

        return add_metrics(last_metrics, res)