In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
#export
from collections import defaultdict, OrderedDict
from enum import IntFlag
from operator import itemgetter
import sys
from typing import Union, Tuple

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from loop.config import defaults
from loop.utils import merge_dicts, to_list, to_snake_case, classname, autoformat

In [3]:
class NamedList:
    def __init__(self, od: OrderedDict):
        self.od = od
        self.names = list(od)
        
    def __iter__(self):
        self.index = -1
        return self
    
    def __next__(self):
        self.index += 1
        if self.index >= len(self):
            raise StopIteration()
        return self[self.index]
    
    def __len__(self): 
        return len(self.od)
    
    def __getitem__(self, item):
        if isinstance(item, str):
            return self.od[item]
        elif isinstance(item, int):
            return self.od[self.names[item]]
        return TypeError(f'invalid index type: {type(item)}')

In [6]:
#export
class Phase:
    """
    Model training loop phase.

    Each model's training loop iteration could be separated into (at least) two
    phases: training and validation. The instances of this class track
    metrics and counters, related to the specific phase, and keep the reference
    to subset of data, used during phase.
    """
    def __init__(self, name: str, loader: 'DataLoader', grad: bool=True):
        self.name = name
        self.loader = loader
        self.grad = grad
        self.batch_loss = None
        self.batch_index = 0
        self.rolling_loss = 0
        self.losses = []
        self.metrics = OrderedDict()

    @property
    def last_loss(self):
        return self.losses[-1] if self.losses else None

    @property
    def last_metrics(self):
        metrics = OrderedDict()
        metrics[f'{self.name}_loss'] = self.last_loss
        for name, values in self.metrics.items():
            metrics[f'{self.name}_{name}'] = values[-1]
        return metrics

    @property
    def metrics_history(self):
        metrics = OrderedDict()
        for name, values in self.metrics.items():
            metrics[f'{self.name}_{name}'] = values
        return metrics

    def update(self, loss):
        self.losses.append(loss)

    def update_metric(self, name, value):
        if name not in self.metrics:
            self.metrics[name] = []
        self.metrics[name].append(value)
        
    @staticmethod
    def make_train_valid(trn_ds, val_ds, 
                         bs: int=defaults.bs,
                         num_workers: Union[Tuple, int]=0):
        """Creates two loop's phases, train and valid.
        
        The phases are thin wrappers on top of data loaders intended to track
        additional information gathered during model's fitting process, like 
        loss, performance metrics, etc.
        """
        trn, val = broadcast(num_workers, 2)
        phs = OrderedDict()
        phs['train'] = Phase('train', DataLoader(trn_ds, bs, shuffle=True, num_workers=trn))
        phs['valid'] = Phase('valid', DataLoader(val_ds, bs, num_workers=val), grad=False)
        return NamedList(phs)

In [7]:
od = OrderedDict()
od['first'] = 1
od['last'] = 2
nl = NamedList(od)
assert len(nl) == len(od)
assert [nl['first'], nl['last']] == list(nl)

In [8]:
def mock_loader(): return DataLoader(TensorDataset(torch.randn((1000, 10))))
train = Phase('train', mock_loader())
assert train.name == 'train'
assert train.loader is not None

In [9]:
from torchvision.datasets import MNIST

In [10]:
root = defaults.datasets/'mnist'
trn_ds = MNIST(root, train=True)
val_ds = MNIST(root, train=False)

In [11]:
phases = Phase.make_train_valid(trn_ds, val_ds)
assert len(phases) == 2
assert phases['train'].loader is not None
assert phases['valid'].loader is not None
assert phases['train'].grad 
assert not phases['valid'].grad

In [12]:
#export
class Callback:
    """The base class inherited by callbacks.
    
    Provides a lot of hooks invoked on various stages of the training loop
    execution. The signature of functions is as broad as possible to allow
    flexibility and customization in descendant classes.
    """
    def training_started(self, **kwargs): pass
    def training_ended(self, **kwargs): pass
    def epoch_started(self, **kwargs): pass
    def epoch_ended(self, **kwargs): pass
    def phase_started(self, **kwargs): pass
    def phase_ended(self, **kwargs): pass
    def batch_started(self, **kwargs): pass
    def batch_ended(self, **kwargs): pass
    def before_forward(self, **kwargs): pass
    def after_forward(self, **kwargs): pass
    def before_backward(self, **kwargs): pass
    def after_backward(self, **kwargs): pass
    def interrupted(self, **kwargs): pass

In [13]:
#export
class Order(IntFlag):
    Unknown = -1
    Internal = 0
    Loss = 10
    Metrics = 100
    Schedule = 200
    History = 300
    Logging = 1000
    
    def __call__(self, index=0):
        return Order(self.value + index)
    
    @staticmethod
    def sort(items):
        ordered = [(getattr(item, 'order', Order.Unknown), item) for item in items]
        ordered.sort(key=itemgetter(0))
        return [item for _, item in ordered]

In [14]:
#export
class RollingLoss(Callback):
    """A callback that tracks model's loss.
    
    The loss is interpolated between current and next value to get 
    a smoother loss curve.
    """
    order = Order.Loss()
    
    def __init__(self, smooth=0.98):
        self.smooth = smooth
    
    def batch_ended(self, phase, **kwargs):
        prev = phase.rolling_loss
        a = self.smooth
        avg_loss = a*prev + (1 - a)*phase.batch_loss
        debias_loss = avg_loss / (1 - a**phase.batch_index)
        phase.rolling_loss = avg_loss
        phase.update(debias_loss)
        
    def epoch_ended(self, phases, **kwargs):
        for phase in phases:
            phase.update_metric('loss', phase.last_loss)

In [15]:
#export
class History(Callback):
    """A callback that collects model's metrics during its training."""
    
    order = Order.History()
    
    def training_started(self, **kwargs):
        self.recorded = None
        self.epochs = []
    
    def epoch_ended(self, epoch, **kwargs):
        self.epochs.append(epoch)
        
    def training_ended(self, phases, **kwargs):
        epochs = {'epoch': np.array(self.epochs).astype(int)}
        metrics = [epochs] + [p.metrics_history for p in phases]
        data = pd.DataFrame(merge_dicts(metrics))
        data.reset_index(inplace=True, drop=True)
        self.recorded = data
        
    def plot(self, x='epoch', ys='loss', ax=None):
        return self.recorded.plot(x='epoch', y=to_list(ys), ax=ax)

In [16]:
def autoformat(v):
    """Tryies to convert value into a string using the best possible representation."""
    
    return (f'{v:d}' if isinstance(v, (int, np.int16, np.int32, np.int64)) else
            f'{v:.4f}' if isinstance(v, (float, np.float16, np.float32, np.float64)) else
            f'{str(v)}')

In [17]:
class StreamLogger(Callback):
    """
    Writes performance metrics collected during the training process into list
    of streams.
    
    Parameters:
        streams: A list of file-like objects with `write()` method.
    
    """
    order = Order.Logging()
    
    def __init__(self, streams: list=None, log_every: int=1):
        self.streams = streams or [sys.stdout]
        self.log_every = log_every
        
    def epoch_ended(self, phases, epoch, **kwargs):
        metrics = merge_dicts([p.last_metrics for p in phases])
        values = [f'{k}={autoformat(v)}' for k, v in metrics.items()]
        values_string = ', '.join(values)
        string = f'Epoch: {epoch:4d} | {values_string}\n'
        for stream in self.streams:
            stream.write(string)
            stream.flush()

In [18]:
def accuracy(out, y_true):
    y_hat = out.argmax(dim=-1).view(y_true.size(0), -1)
    y_true = y_true.view(y_true.size(0), -1)
    match = y_hat == y_true
    return match.float().mean()

In [19]:
def from_torch(tensor):
    obj = tensor.detach().cpu()
    if not obj.shape:
        return obj.item()
    return obj

In [20]:
class Average(Callback):
    
    def __init__(self, metric_fn: 'callable', name: str=None):
        self.metric_fn = metric_fn
        self.name = name or self.metric_fn.__name__
    
    def epoch_started(self, **kwargs):
        self.values = defaultdict(int)
        self.counts = defaultdict(int)
        
    def batch_ended(self, phase, output, target, **kwargs):
        metric = from_torch(self.metric_fn(output, target))
        self.counts[phase.name] += target.size(0)
        self.values[phase.name] += target.size(0) * metric
        
    def epoch_ended(self, phases, **kwargs):
        for phase in phases:
            metric = self.values[phase.name] / self.counts[phase.name]
            phase.update_metric(self.name, metric)

In [21]:
#export
class Group(Callback):
    def __init__(self, cbs):
        self._init(cbs)
        self._model = None
        
    def _init(self, cbs):
        if not cbs:
            self.callbacks = []
            self.named_callbacks = {}
        else:
            cbs = Order.sort(cbs)
            for cb in cbs:
                cb.group = self
            self.callbacks = cbs
            self.named_callbacks = {to_snake_case(classname(cb)): cb for cb in cbs}
        
    def add(self, cb, *cbs):
        cbs = [cb] + list(cbs)
        self._init(cbs)
        
    def set_model(self, model):
        self._model = model

    def training_started(self, **kwargs): self('training_started', **kwargs)    
    def training_ended(self, **kwargs): self('training_ended', **kwargs)
    def epoch_started(self, **kwargs): self('epoch_started', **kwargs)
    def epoch_ended(self, **kwargs): self('epoch_ended', **kwargs)
    def phase_started(self, **kwargs): self('phase_started', **kwargs)
    def phase_ended(self, **kwargs): self('phase_ended', **kwargs)
    def batch_started(self, **kwargs): self('batch_started', **kwargs)
    def batch_ended(self, **kwargs): self('batch_ended', **kwargs)
    def before_forward(self, **kwargs): self('before_forward', **kwargs)
    def after_forward(self, **kwargs): self('after_forward', **kwargs)
    def before_backward(self, **kwargs): self('before_forward', **kwargs)
    def after_backward(self, **kwargs): self('after_backward', **kwargs)
    def interrupted(self, **kwargs): self('interrupted', **kwargs)
        
    def __getattr__(self, item):
        if item in vars(self):
            return self.__dict__[item]
        if self._model is not None:
            return getattr(self._model, item)
        raise AttributeError(item)
    
    def __getitem__(self, item):
        item = to_snake_case(item)
        if item in self.named_callbacks:
            return self.named_callbacks[item]
        raise KeyError(
            f'callback name is not found: {item}; '
            f'available callbacks are: {list(sorted(self.named_callbacks))}')
    
    def __call__(self, name, **kwargs):
        for cb in self.callbacks:
            method = getattr(cb, name, None)
            if method is None:
                continue
            method(**kwargs)

In [22]:
#export
def default_optimizer(model, **params):
    if 'lr' not in params:
        params['lr'] = 0.001
    return optim.Adam(model.parameters(), **params)

In [23]:
#export
def create_callbacks(cbs, default: bool=True):
    defaults = [RollingLoss(), History(), StreamLogger()] if default else []
    cbs = list(cbs or [])
    cbs += defaults
    return Group(cbs)

In [24]:
#export
def unwrap_if_single(obj):
    """Converts obj collection into scalar if it contains single element only."""
    return obj[0] if len(obj) == 1 else obj

In [25]:
#export
def place_and_unwrap(batch, dev):
    """Places tensors from batch onto proper device and converts targets
    into proper shape depending on number of tensors.
    """
    batch = [t.to(dev) for t in batch]
    x, *ys = batch
    return x, unwrap_if_single(ys)

In [26]:
#export
class TrainingInterrupted(Exception):
    def __init__(self, context=None):
        self.context = context

In [27]:
#export
class Loop:
    def __init__(self, model: nn.Module, cbs: list=None,
                 default_cb: bool=True, opt_fn: 'callable'=default_optimizer,
                 opt_params: dict=None, device: 'device'=defaults.device,
                 loss_fn: 'callable'=defaults.loss_fn):
        
        model.to(device)
        opt = opt_fn(model, **(opt_params or {}))
        cb = create_callbacks(cbs, default_cb)
        cb.set_model(model)
        
        self.model = model
        self.opt = opt
        self.cb = cb
        self.loss_fn = loss_fn 
        self.device = device
        
    def fit_datasets(self, trn_ds, val_ds, epochs: int=1, batch_size: int=defaults.bs):
        phases = Phase.make_train_valid(
            trn_ds, val_ds, bs=batch_size, num_workers=defaults.n_jobs)
        self.train(phases, epochs)
        
    def train(self, phases: list, epochs: int=1):
        try:
            self.cb.training_started(phases=phases)
            for epoch in range(1, epochs + 1):
                self.train_one_epoch(phases, epoch)
            self.cb.training_ended(phases=phases)
        except TrainingInterrupted as e:
            self.cb.interrupted(reason=e)
    
    def train_one_epoch(self, phases: list, curr_epoch: int=1):
        cb, model, opt = self.cb, self.model, self.opt
        
        cb.epoch_started(epoch=curr_epoch)

        for phase in phases:
            n = len(phase.loader)
            cb.phase_started(phase=phase, total_batches=n)
            is_training = phase.grad
            model.train(is_training)
            
            for batch in phase.loader:
                phase.batch_index += 1
                cb.batch_started(phase=phase, total_batches=n)
                x, y = place_and_unwrap(batch, self.device)
                
                with torch.set_grad_enabled(is_training):
                    cb.before_forward()
                    out = model(x)
                    cb.after_forward()
                    loss = self.loss_fn(out, y)
                
                if is_training:
                    opt.zero_grad()
                    cb.before_backward()
                    loss.backward()
                    opt.step()
                
                phase.batch_loss = loss.item()
                cb.batch_ended(phase=phase, output=out, target=y)
                
            cb.phase_ended(phase=phase)
                    
        cb.epoch_ended(phases=phases, epoch=curr_epoch)

In [28]:
from torch import nn
from torch import optim
from torch.nn import functional as F

class TinyNet(nn.Module):
    def __init__(self, n_out=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, n_out)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [29]:
from torchvision.datasets import MNIST
from torchvision import transforms as T

In [30]:
def get_mnist():
    root = defaults.datasets/'mnist'

    mnist_stats = ([0.15]*1, [0.15]*1)

    trn_ds = MNIST(root, train=True, transform=T.Compose([
        T.Resize(32),
        T.RandomAffine(5, translate=(0.05, 0.05), scale=(0.9, 1.1)),
        T.ToTensor(),
        T.Normalize(*mnist_stats)
    ]))
    val_ds = MNIST(root, train=False, transform=T.Compose([
        T.Resize(32),
        T.ToTensor(),
        T.Normalize(*mnist_stats)
    ]))
    
    return trn_ds, val_ds

In [31]:
loop = Loop(TinyNet(), cbs=[Average(accuracy)], loss_fn=F.cross_entropy)

In [32]:
loop.fit_datasets(*get_mnist(), epochs=10, batch_size=1048)

Epoch:    1 | train_loss=2.2221, train_accuracy=0.1807, valid_loss=2.0829, valid_accuracy=0.2290
Epoch:    2 | train_loss=2.0750, train_accuracy=0.2664, valid_loss=1.9681, valid_accuracy=0.3166
Epoch:    3 | train_loss=1.9072, train_accuracy=0.3431, valid_loss=1.8346, valid_accuracy=0.3995
Epoch:    4 | train_loss=1.7552, train_accuracy=0.3929, valid_loss=1.7355, valid_accuracy=0.4380
Epoch:    5 | train_loss=1.6456, train_accuracy=0.4248, valid_loss=1.6583, valid_accuracy=0.4589
Epoch:    6 | train_loss=1.5627, train_accuracy=0.4476, valid_loss=1.5953, valid_accuracy=0.5147
Epoch:    7 | train_loss=1.4994, train_accuracy=0.4670, valid_loss=1.5420, valid_accuracy=0.5079
Epoch:    8 | train_loss=1.4458, train_accuracy=0.4869, valid_loss=1.4946, valid_accuracy=0.5427
Epoch:    9 | train_loss=1.3994, train_accuracy=0.5025, valid_loss=1.4515, valid_accuracy=0.5340
Epoch:   10 | train_loss=1.3556, train_accuracy=0.5220, valid_loss=1.4106, valid_accuracy=0.5796
