In [3]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
#export
import torch
from torch import tensor

In [13]:
#export
from exp.nb_utils import snakify_class_name

In [14]:
#export
class Metric:
    """Base interface for metrics.

    Subclasses override `reset` (clear state), `accumulate` (fold in one step of
    a learner) and the `value` property. In this base class `reset` and
    `accumulate` do nothing and `value` is abstract.
    """

    def reset(self):
        "Clear accumulated state; does nothing in the base class."
        return None

    def accumulate(self, learner):
        "Fold one step of `learner` into the state; does nothing in the base class."
        return None

    @property
    def value(self):
        "Current metric value; must be provided by a subclass."
        raise NotImplementedError

    @property
    def name(self):
        "Display name computed by `snakify_class_name` from this instance and the \"Metric\" suffix."
        return snakify_class_name(self, "Metric")

In [15]:
#export
class AvgMetric(Metric):
    """Batch-size-weighted running average of `func(learner.pred, learner.yb)`.

    On each `accumulate`, the result of `func` is weighted by the batch size
    (the leading dimension of `learner.xb`). This makes `value` the per-sample
    mean even when batches differ in size.

    Args:
        func: callable taking `(pred, yb)` and returning a tensor (presumably a
            scalar per batch — TODO confirm); its `__name__` is used as `name`.
    """
    def __init__(self, func):
        self.func = func
        self.reset()

    def reset(self):
        "Zero the running total and the sample count."
        self.total = tensor(0.)
        self.count = 0

    def accumulate(self, learner):
        "Add this batch's metric, weighted by batch size, to the running total."
        bs = learner.xb.shape[0]
        # Detach so the in-place `+=` does not turn `total` into an autograd node
        # that keeps every batch's graph (and activations) alive.
        self.total += self.func(learner.pred, learner.yb).detach().cpu() * bs
        self.count += bs

    @property
    def value(self):
        "Weighted mean over all accumulated samples, or None if nothing was accumulated."
        return self.total / self.count if self.count else None

    @property
    def name(self):
        "The wrapped function's `__name__`."
        return self.func.__name__

In [22]:
#export
class AvgLoss(Metric):
    """Batch-size-weighted running average of `learner.loss`.

    Each batch contributes `mean(loss) * batch_size`. As a result, `value` is the
    per-sample mean loss over everything accumulated since the last `reset`.
    """
    def __init__(self):
        # Initialise state so the metric is usable without an explicit reset(),
        # matching AvgMetric.
        self.reset()

    def reset(self):
        "Zero the running total and the sample count."
        self.total = tensor(0.)
        self.count = 0

    def accumulate(self, learner):
        "Add this batch's mean loss, weighted by batch size, to the running total."
        bs = learner.xb.shape[0]
        # `+=` (not `=`) so every batch contributes to the average; detach so the
        # running total does not retain the training graph.
        self.total += learner.loss.detach().cpu().mean() * bs
        self.count += bs

    @property
    def value(self):
        "Weighted mean loss over all accumulated samples, or None if nothing was accumulated."
        return self.total / self.count if self.count else None

    @property
    def name(self): return "loss"

In [23]:
#export
class AvgSmoothLoss(Metric):
    """Exponential moving average of `learner.loss` with bias correction.

    Each `accumulate` sets `val = beta * val + (1 - beta) * mean(loss)`. `value`
    divides by `1 - beta**count` to undo the bias toward the zero starting value.

    Args:
        beta: smoothing factor; larger values weight history more heavily.
    """
    def __init__(self, beta=0.98):
        self.beta = beta
        # Initialise state so accumulate/value work without an explicit reset().
        self.reset()

    def reset(self):
        "Restart the moving average from zero."
        self.val = tensor(0.)
        self.count = 0

    def accumulate(self, learner):
        "Blend this batch's mean loss into the moving average."
        self.count += 1
        # torch.lerp(start, end, weight) = start + weight * (end - start), i.e.
        # beta * val + (1 - beta) * loss. Detach so successive `val`s do not chain
        # autograd graphs across batches.
        self.val = torch.lerp(learner.loss.detach().cpu().mean(),
                              self.val,
                              self.beta)

    @property
    def value(self):
        "Bias-corrected moving average, or None before any batch was accumulated."
        # With count == 0 the denominator 1 - beta**0 is 0 and the result would be NaN.
        if not self.count:
            return None
        return self.val / (1 - self.beta**self.count)

    @property
    def name(self): return "smooth_loss"

In [4]:
!python notebook2script.py metrics.ipynb

Converted metrics.ipynb to exp/nb_metrics.py
