In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
#export
from collections import OrderedDict
from typing import Union, Tuple

In [3]:
#export
from loop.config import defaults

In [23]:
#export
class Phase:
    """
    Model training loop phase.

    Each model's training loop iteration could be separated into (at least) two
    phases: training and validation. The instances of this class track
    metrics and counters, related to the specific phase, and keep the reference
    to subset of data, used during phase.
    """
    def __init__(self, name: str, loader: 'DataLoader', grad: bool=True):
        self.name = name
        self.loader = loader
        self.grad = grad
        self.batch_loss = None
        self.batch_index = 0
        self.rolling_loss = 0
        self.losses = []
        self.metrics = OrderedDict()

    @property
    def last_loss(self):
        return self.losses[-1] if self.losses else None

    @property
    def last_metrics(self):
        metrics = OrderedDict()
        metrics[f'{self.name}_loss'] = self.last_loss
        for name, values in self.metrics.items():
            metrics[f'{self.name}_{name}'] = values[-1]
        return metrics

    @property
    def metrics_history(self):
        metrics = OrderedDict()
        for name, values in self.metrics.items():
            metrics[f'{self.name}_{name}'] = values
        return metrics

    def update(self, loss):
        self.losses.append(loss)

    def update_metric(self, name, value):
        if name not in self.metrics:
            self.metrics[name] = []
        self.metrics[name].append(value)
        
    @staticmethod
    def make_train_valid(trn_ds, val_ds, 
                         bs: int=defaults.bs,
                         num_workers: Union[Tuple, int]=0):
        """Creates two loop's phases, train and valid.
        
        The phases are thin wrappers on top of data loaders intended to track
        additional information gathered during model's fitting process, like 
        loss, performance metrics, etc.
        """
        trn, val = unwrap(num_workers, 2)
        defs = [
            ('train', DataLoader(trn_ds, bs, shuffle=True, num_workers=trn)),
            ('valid', DataLoader(val_ds, bs, shuffle=False, num_workers=val)]
        phs = OrderedDict()
        for name, loader in defs:
            phs[name] = Phase(name, loader)
        return phs

SyntaxError: invalid syntax (<ipython-input-23-3ae1ed1dfdf9>, line 61)

In [24]:
#export
import torch
from torch.utils.data import DataLoader, TensorDataset

In [6]:
def mock_loader(): return DataLoader(TensorDataset(torch.randn((1000, 10))))
train = Phase('train', mock_loader())
assert train.name == 'train'
assert train.loader is not None

In [7]:
#export
def is_scalar(obj):
    return isinstance(obj, (int, float, str, complex))

In [8]:
#export
def unwrap(obj, pad=1):
    """Convenience function to unwrap collections and broadcast scalars."""
    if is_scalar(obj): 
        return [obj]*pad
    return obj

In [10]:
from torchvision.datasets import MNIST

In [16]:
root = defaults.datasets/'mnist'
trn_ds = MNIST(root, train=True)
val_ds = MNIST(root, train=False)

In [22]:
phases = Phase.make_train_valid(trn_ds, val_ds)
assert len(phases) == 2
assert phases['train'].loader is not None
assert phases['valid'].loader is not None
assert phases['train'].grad 
assert not phases['valid'].grad

AssertionError: 

OrderedDict([('train', <__main__.Phase at 0x7fdf48212e10>),
             ('valid', <__main__.Phase at 0x7fdf48212c18>)])