In [41]:
%reload_ext autoreload
%autoreload 2

In [42]:
#export
from collections import OrderedDict
from typing import Union, Tuple
from enum import IntEnum
from operator import itemgetter

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from loop.config import defaults
from loop.utils import merge_dicts, to_list

In [4]:
#export
class Phase:
    """
    Model training loop phase.

    Each model's training loop iteration could be separated into (at least) two
    phases: training and validation. The instances of this class track
    metrics and counters, related to the specific phase, and keep the reference
    to subset of data, used during phase.
    """
    def __init__(self, name: str, loader: 'DataLoader', grad: bool=True):
        self.name = name
        self.loader = loader
        self.grad = grad
        self.batch_loss = None
        self.batch_index = 0
        self.rolling_loss = 0
        self.losses = []
        self.metrics = OrderedDict()

    @property
    def last_loss(self):
        return self.losses[-1] if self.losses else None

    @property
    def last_metrics(self):
        metrics = OrderedDict()
        metrics[f'{self.name}_loss'] = self.last_loss
        for name, values in self.metrics.items():
            metrics[f'{self.name}_{name}'] = values[-1]
        return metrics

    @property
    def metrics_history(self):
        metrics = OrderedDict()
        for name, values in self.metrics.items():
            metrics[f'{self.name}_{name}'] = values
        return metrics

    def update(self, loss):
        self.losses.append(loss)

    def update_metric(self, name, value):
        if name not in self.metrics:
            self.metrics[name] = []
        self.metrics[name].append(value)
        
    @staticmethod
    def make_train_valid(trn_ds, val_ds, 
                         bs: int=defaults.bs,
                         num_workers: Union[Tuple, int]=0):
        """Creates two loop's phases, train and valid.
        
        The phases are thin wrappers on top of data loaders intended to track
        additional information gathered during model's fitting process, like 
        loss, performance metrics, etc.
        """
        trn, val = unwrap(num_workers, 2)
        phs = OrderedDict()
        phs['train'] = Phase('train', DataLoader(trn_ds, bs, shuffle=True, num_workers=trn))
        phs['valid'] = Phase('valid', DataLoader(val_ds, bs, num_workers=val), grad=False)
        return phs

In [5]:
def mock_loader(): return DataLoader(TensorDataset(torch.randn((1000, 10))))
train = Phase('train', mock_loader())
assert train.name == 'train'
assert train.loader is not None

In [6]:
#export
def is_scalar(obj):
    return isinstance(obj, (int, float, str, complex))

In [7]:
#export
def unwrap(obj, pad=1):
    """Convenience function to unwrap collections and broadcast scalars."""
    if is_scalar(obj): 
        return [obj]*pad
    return obj

In [8]:
from torchvision.datasets import MNIST

In [9]:
root = defaults.datasets/'mnist'
trn_ds = MNIST(root, train=True)
val_ds = MNIST(root, train=False)

In [10]:
phases = Phase.make_train_valid(trn_ds, val_ds)
assert len(phases) == 2
assert phases['train'].loader is not None
assert phases['valid'].loader is not None
assert phases['train'].grad 
assert not phases['valid'].grad

In [11]:
#export
class Callback:
    """The base class inherited by callbacks.
    
    Provides a lot of hooks invoked on various stages of the training loop
    execution. The signature of functions is as broad as possible to allow
    flexibility and customization in descendant classes.
    """
    def training_started(self, **kwargs): pass
    def training_ended(self, **kwargs): pass
    def epoch_started(self, **kwargs): pass
    def epoch_ended(self, **kwargs): pass
    def phase_started(self, **kwargs): pass
    def phase_ended(self, **kwargs): pass
    def batch_started(self, **kwargs): pass
    def batch_ended(self, **kwargs): pass
    def before_forward(self, **kwargs): pass
    def after_forward(self, **kwargs): pass
    def before_backward(self, **kwargs): pass
    def after_backward(self, **kwargs): pass
    def interrupted(self, **kwargs): pass

In [38]:
from enum import IntFlag

class Order(IntFlag):
    Unknown = -1
    Internal = 0
    Loss = 10
    Metrics = 100
    Schedule = 200
    History = 300
    Logging = 1000
    
    def __call__(self, index=0):
        return Order(self.value + index)
    
    @staticmethod
    def sort(items):
        ordered = [(getattr(item, Order.Unknown), item) for item in items]
        ordered.sort(key=itemgetter(0))
        return [item for _, item in ordered]

In [35]:
#export
class RollingLoss(Callback):
    """A callback that tracks model's loss.
    
    The loss is interpolated between current and next value to get 
    a smoother loss curve.
    """
    order = Order.Loss()
    
    def __init__(self, smooth=0.98):
        self.smooth = smooth
    
    def batch_ended(self, phase, **kwargs):
        prev = phase.rolling_loss
        a = self.smooth
        avg_loss = a*prev + (1 - a)*phase.batch_loss
        debias_loss = avg_loss / (1 - a**phase.batch_index)
        phase.rolling_loss = avg_loss
        phase.update(debias_loss)
        
    def epoch_ended(self, phases, **kwargs):
        for phase in phases:
            phase.update_metric('loss', phase.last_loss)

In [None]:
#export
class History(Callback):
    """A callback that collects model's metrics during its training."""
    
    order = Order.History()
    
    def training_started(self, **kwargs):
        self.recorded = None
        self.epochs = []
    
    def epoch_ended(self, epoch, **kwargs):
        self.epochs.append(epoch)
        
    def training_ended(self, phases, **kwargs):
        epochs = {'epoch': np.array(self.epochs).astype(int)}
        metrics = [epochs] + [p.metrics_history for p in phases]
        data = pd.DataFrame(merge_dicts(metrics))
        data.reset_index(inplace=True, drop=True)
        self.recorded = data
        
    def plot(self, x='epoch', ys='loss', ax=None):
        return self.recorded.plot(x='epoch', y=to_list(ys), ax=ax)

In [None]:
#export
class Group(Callback):
    def __init__(self, cbs):
        self._init(self, cbs)
        self._model = None
        
    def _init(self, cbs):
        if not cbs:
            self.callbacks = []
            self.named_callbacks = {}
        else:
            cbs = Order.sort(cbs)
            for cb in cbs:
                cb.group = self
            self.callbacks = cbs
            self.named_callbacks = {to_snake_case(classname(cb)): cb for cb in cbs}
        
    def add(self, cb, *cbs):
        cbs = [cb] + list(cbs)
        self._init(cbs)
        
    def set_model(self, model):
        self._model = model

    def training_started(self, **kwargs): self('training_started', **kwargs)    
    def training_ended(self, **kwargs): self('training_ended', **kwargs)
    def epoch_started(self, **kwargs): self('epoch_started', **kwargs)
    def epoch_ended(self, **kwargs): self('epoch_ended', **kwargs)
    def phase_started(self, **kwargs): self('phase_started', **kwargs)
    def phase_ended(self, **kwargs): self('phase_ended', **kwargs)
    def batch_started(self, **kwargs): self('batch_started', **kwargs)
    def batch_ended(self, **kwargs): self('batch_ended', **kwargs)
    def before_forward(self, **kwargs): self('before_forward', **kwargs)
    def after_forward(self, **kwargs): self('after_forward', **kwargs)
    def before_backward(self, **kwargs): self('before_forward', **kwargs)
    def after_backward(self, **kwargs): self('after_backward', **kwargs)
    def interrupted(self, **kwargs): self('interrupted', **kwargs)
        
    def __getattr__(self, item):
        if item in vars(self):
            return self.__dict__[item]
        if self._model is not None:
            return getattr(self._model)
        raise AttributeError(item)
    
    def __getitem__(self, item):
        item = to_snake_case(item)
        if item in self.named_callbacks:
            return self.named_callbacks[item]
        raise KeyError(
            f'callback name is not found: {item}; '
            f'available callbacks are: {list(sorted(self.named_callbacks))}')
    
    def __call__(self, name, **kwargs):
        for cb in self.callbacks:
            method = getattr(cb, name, None)
            if method is None:
                continue
            method(**kwargs)

In [46]:
def default_optimizer(model, **params):
    if 'lr' not in params:
        params['lr'] = 0.001
    return optim.Adam(model.parameters(), **params)

In [None]:
def create_callbacks(cbs, default: bool=True):
    if cbs is None:
        cbs = [RollingLoss(), History()] if default else []
    else:
        cbs = list(cbs)
        if default:
            cbs += [RollingLoss(), History()]
    return Group(cbs)

In [None]:
def place_and_unwrap(batch, dev):
    """Places tensors from batch onto proper device and 
    returns (x, y) pair where all target tensors (if many)
    are put into y list.
    """
    batch = [t.to(dev) for t in batch]
    x, *ys = batch
    return x, to_list(ys)

In [55]:
class TrainingInterrupted(Exception):
    def __init__(self, context=None):
        self.context = context

In [56]:
class Loop:
    def __init__(self, model: nn.Module,
                 default_cb: bool=True, opt_fn: 'callable'=default_optimizer,
                 opt_params=None, device: 'device'=defaults.device,
                 loss_fn=defaults.loss_fn):
        
        breakpoint()
        model.to(device)
        opt = opt_fn(model, **(opt_params or {}))
        cb = create_callbacks(cbs, default_cb)
        cb.set_model(model)
        
        self.model = model
        self.opt = opt
        self.cb = cb
        self.loss_fn = loss_fn 
        self.device = device
        
    def fit_datasets(trn_ds, val_ds, epochs: int=1, batch_size: int=defaults.bs):
        phases = Phase.make_train_valid(
            trn_ds, val_ds, bs=batch_size, num_workers=defaults.n_jobs)
        self.train(phases, epochs)
        
    def train(phases: list, epochs: int=1):
        try:
            self.cb.training_started(phases=phases)
            for epoch in range(1, epochs + 1):
                self.train_one_epoch(phases, epoch)
            self.cb.training_ended(phases=phases)
        except TrainingInterrupted as e:
            self.cb.interrupted(reason=e)
    
    def train_one_epoch(phases: list, curr_epoch: int=1):
        self.cb.epoch_started(epoch=curr_epoch)

        for phase in phases:
            n = len(phase.loader)
            cb.phase_started(phase=phase, total_batches=n)
            is_training = phase.grad
            model.train(is_training)
            
            for batch in phase.loader:
                phase.batch_index += 1
                cb.batch_started(phase=phase, total_batches=n)
                x, y = place_and_unwrap(batch, self.device)
                
                with torch.set_grad_enabled(is_training):
                    cb.before_forward()
                    out = model(x)
                    cb.after_forward()
                    loss = self.loss_fn(out, y)
                
                if is_training:
                    opt.zero_grad()
                    cb.before_backward()
                    loss.backward()
                    opt.step()
                
                phase.batch_loss = loss.item()
                cb.batch_ended(phase=phase, output=out, target=y)
                
            cb.phase_ended(phase=phase)
                    
        self.cb.epoch_ended(phases=phases, epoch=curr_epoch)

In [57]:
from torch import nn
from torch import optim
from torch.nn import functional as F

class TinyNet(nn.Module):
    def __init__(self, n_out=10):
        super().__init__()
        self.conv1 = nn.Conv2d(28, 64, 3)
        self.conv2 = nn.Conv2d(64, 128, 3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, n_out)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
loop = Loop(TinyNet(), loss_fn=F.cross_entropy)

> <ipython-input-56-d2b2b96ab060>(8)__init__()
-> model.to(device)
(Pdb) device
