In [None]:
# default_exp metric

# Metric

In [None]:
#hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#export
from fastcore.basics import GetAttr, store_attr, ifnone
from types import FunctionType
from fastai.metrics import Metric, AvgMetric, AccumMetric

In [None]:
#export

class _LearnerProxy(GetAttr):
    """View of `learn` restricted to one prediction/target pair.

    `pred`, `y` and `yb` are replaced by the entries selected with `pred_idx`
    and `target_idx`; any other attribute lookup falls through to the wrapped
    `learn` via `GetAttr` (`_default = 'learn'`).
    """
    _default = 'learn'
    def __init__(self, learn, pred_idx, target_idx):
        # Keeps `learn`, `pred_idx` and `target_idx` as attributes of the proxy.
        store_attr()
        self.pred = self.learn.pred[pred_idx]
        self.y = self.learn.y[target_idx]
        # Metrics that read the target tuple `yb` rather than `y` (fastai's
        # `AvgMetric` does) would otherwise fall through to the real learner
        # and receive every target instead of the routed one.
        self.yb = (self.y,)


class RoutedAccumMetric(AccumMetric, GetAttr):
    "AccumMetric with predictions and targets for a specific model head."
    _default = 'metric'
    def __init__(self, metric, pred_idx, target_idx):
        # `AccumMetric.__init__` is not called, so the wrapper keeps no
        # accumulator state of its own: anything not defined on this class is
        # resolved on `self.metric` through `GetAttr`.
        self.metric = metric
        self._name = metric.name
        self.pred_idx = pred_idx
        self.target_idx = target_idx

    def reset(self):
        "Reset the wrapped metric"
        return self.metric.reset()

    def accumulate(self, learn):
        "Let the wrapped metric accumulate from a view of `learn` limited to the routed prediction and target"
        return self.metric.accumulate(_LearnerProxy(learn, self.pred_idx, self.target_idx))

    def __call__(self, preds, targs):
        "Calculate the wrapped metric on one batch, using `preds[pred_idx]` and `targs[target_idx]`"
        return self.metric(preds[self.pred_idx], targs[self.target_idx])

    @property
    def value(self):
        "Value of the wrapped metric"
        # Delegate explicitly: the inherited `AccumMetric.value` would run
        # against this wrapper, which holds none of the state it reads.
        return self.metric.value

    @property
    def name(self):
        "Display name; initialised from the wrapped metric's name"
        return self._name

    @name.setter
    def name(self, value):
        self._name = value
    
class _RoutedFunc:
    "Picklable callable forwarding `preds[pred_idx]` and `targs[target_idx]` to a metric function."
    def __init__(self, metric, pred_idx, target_idx):
        self.metric, self.pred_idx, self.target_idx = metric, pred_idx, target_idx
        # Report the wrapped function's name instead of this helper's.
        self.__name__ = metric.__name__

    def __call__(self, preds, *targs, **kwargs):
        return self.metric(preds[self.pred_idx], targs[self.target_idx], **kwargs)


def route_to_metric(metric, pred_idx, target_idx):
    """Routes model output at `pred_idx` and target at index `target_idx` to metric

    `metric` may be:
    - a class, which is instantiated without arguments and then handled as below;
    - a plain function, which is wrapped in an `AvgMetric` that calls it with
      the selected prediction and target;
    - a `Metric` instance, which is wrapped in a `RoutedAccumMetric`.

    Raises `ValueError` for anything else.
    """
    if isinstance(metric, type):
        metric = metric()
    if isinstance(metric, FunctionType):
        # A lambda/closure here could not be pickled, which would make every
        # object holding the resulting metric unpicklable; a small callable
        # class does the same routing and survives pickling.
        return AvgMetric(_RoutedFunc(metric, pred_idx, target_idx))
    if isinstance(metric, Metric):
        return RoutedAccumMetric(metric, pred_idx, target_idx)
    raise ValueError("Unsupported metric type; must be either function or Metric")


In [None]:
#hide

import torch
from fastai.metrics import accuracy, R2Score, F1Score

# Fixed seed so the random fixtures below (and any failure) are reproducible.
torch.manual_seed(0)

bs = 8
# Head 0: integer-valued class labels; head 1: continuous values.
target1, pred1 = torch.randint(0, 5, [bs], dtype=float), torch.randint(0, 5, [bs], dtype=float)
target2, pred2 = torch.randn(bs), torch.randn(bs)

preds = [pred1, pred2]
targets = [target1, target2]

# A routed metric fed every head must agree with the inner metric fed only its own head.
inner_f1_macro = F1Score(average='macro')
f1_macro = route_to_metric(inner_f1_macro, 0, 0)
test_close(f1_macro(preds, targets), inner_f1_macro(preds[0], targets[0]))

inner_r2 = R2Score()
r2 = route_to_metric(inner_r2, 1, 1)
test_close(r2(preds, targets), inner_r2(preds[1], targets[1]))

In [None]:
#export

def mtl_metrics(*metrics_list):
    """Build routed metrics: every metric in the n-th list gets the n-th prediction and the n-th target."""
    routed = []
    for head_idx, head_metrics in enumerate(metrics_list):
        for head_metric in head_metrics:
            routed.append(route_to_metric(head_metric, head_idx, head_idx))
    return routed

In [None]:
#hide
# `R2Score` has to be called to obtain the metric object (as with `F1Score`);
# passed uncalled it is a plain function and would be routed as a metric function.
routed_metrics = mtl_metrics([F1Score(average='macro'), accuracy], [R2Score()])
test_eq(3, len(routed_metrics))
# Metric objects end up in `RoutedAccumMetric`; the plain function `accuracy` does not.
test_eq([isinstance(m, RoutedAccumMetric) for m in routed_metrics], [True, False, True])

Example usage for routing the first model prediction to `F1Score` and `accuracy`, while routing the second one to `R2Score`.
```py
learn = Learner(...
    metrics = mtl_metrics([F1Score(average='macro'), accuracy], [R2Score()])
)
```