# Training Loop Example

The notebook contains an example of simple training loop initialization and training.

In [72]:
import torch
from torch import nn
from torch import optim
from torchvision import models
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

In [178]:
from tqdm import tqdm_notebook as tqdm

In [179]:
import sys
from pathlib import Path
try:
    old_path
except NameError:
    old_path = sys.path.copy()
    sys.path = [Path.cwd().parent.as_posix()] + old_path

In [None]:
from loop.callbacks import History

In [None]:
def flat_model(model):
    """Converts model with nested modules into single list of modules"""
    
    def flatten(m):
        children = list(m.children())
        if not children:
            return [m]
        return sum([flatten(child) for child in children], [])
    
    return nn.Sequential(*flatten(model))

In [None]:
def as_sequential(model):
    return nn.Sequential(*list(model.children()))

In [None]:
def get_output_shape(model):
    """Pass a dummy input through the sequential model to get the output tensor shape."""
    first, *rest = flat_model(model)
    shape = first.in_channels, 128, 128
    dummy_input = torch.zeros(shape)
    out = model(dummy_input[None])
    return list(out.size())[1:]

In [None]:
class AdaptiveConcatPool2d(nn.Module):
    
    def __init__(self, size=1):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(size)
        self.max = nn.AdaptiveMaxPool2d(size)
        
    def forward(self, x):
        return torch.cat([self.max(x), self.avg(x)], 1)

In [None]:
class Flatten(nn.Module):
    
    def forward(self, x):
        return x.view(x.size(0), -1)

In [None]:
def init_default(module, init_fn=nn.init.kaiming_normal_):
    if init_fn is not None:
        if hasattr(module, 'weight'):
            init_fn(module.weight)
        if hasattr(module, 'bias') and hasattr(module.bias, 'data'):
            module.bias.data.fill_(0.)
    return module

In [None]:
def conv2d(ni, no, kernel=3, stride=1, padding=None, bias=False):
    padding = padding or kernel//2
    layer = nn.Conv2d(ni, no, kernel, stride, padding, bias=bias)
    return init_default(layer)

In [None]:
def batchnorm2d(nf):
    bn = nn.BatchNorm2d(nf)
    with torch.no_grad():
        bn.bias.fill_(1e-3)
        bn.weight.fill_(0)
    return bn

In [None]:
def init_weights(m):
    with torch.no_grad():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 1e-3)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.zeros_(m.bias)

In [None]:
def leaky_linear(ni, no, dropout=None, bn=True):
    layers = []
    if bn:
        layers.append(nn.BatchNorm1d(ni))
    if dropout is not None and droupout > 0:
        layers.append(nn.Dropout(dropout))
    layers.append(nn.Linear(ni, no))
    layers.append(nn.LeakyReLU(0.01, True))
    return nn.Sequential(*layers)

In [None]:
class Classifier(nn.Module):
    
    def __init__(self, n_classes, arch=models.resnet18, init_fn=init_weights):
        super().__init__()
        
        model = arch(True)
        seq_model = as_sequential(model)
        backbone, classifier = seq_model[:-2], seq_model[-2:]
        out_shape = get_output_shape(backbone)
        input_size = out_shape[0] * 2
        
        self.backbone = backbone
        self.cat = AdaptiveConcatPool2d()
        self.flatten = Flatten()
        self.block1 = leaky_linear(input_size, 512)
        self.block2 = leaky_linear(512, 256)
        self.out = nn.Linear(256, n_classes)
        self.init(init_fn)
    
    def forward(self, x):
        x = self.backbone(x)
        x = self.cat(x)
        x = self.flatten(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.out(x)
        return x
    
    def init(self, fn=None):
        if fn is None:
            return
        self.apply(fn)                

In [None]:
model = Classifier(10)

In [16]:
img = torch.zeros((4, 3, 128, 128))
out = model(img)

In [17]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.optim import Optimizer

In [18]:
from multiprocessing import cpu_count

In [19]:
def default(x, fallback=None):
    return x if x is not None else fallback

In [77]:
class Callback:
    
    def training_started(self):
        pass
    
    def training_ended(self):
        pass
    
    def epoch_started(self, **kwargs):
        pass
    
    def phase_started(self, **kwargs):
        pass

    def phase_ended(self, **kwargs):
        pass
    
    def epoch_ended(self, **kwargs):
        pass
    
    def batch_started(self, **kwargs):
        pass
    
    def batch_ended(self, **kwargs):
        pass
    
    def before_forward_pass(self, **kwargs):
        pass
    
    def after_forward_pass(self, **kwargs):
        pass
    
    def before_backward_pass(self, **kwargs):
        pass
    
    def after_backward_pass(self, **kwargs):
        pass

In [78]:
class RollingLoss(Callback):
    
    def __init__(self, smooth=0.98):
        self.smooth = smooth
    
    def batch_ended(self, phase, **kwargs):
        prev = phase.rolling_loss
        a = self.smooth
        avg_loss = a*prev + (1 - a)*phase.batch_loss
        debias_loss = avg_loss/(1 - a**phase.batch_index)
        phase.rolling_loss = avg_loss
        phase.last_loss = debias_loss

In [104]:
class StreamLogger(Callback):
    """
    Writes performance metrics collected during the training process into list
    of streams.

    Parameters:
        streams: A list of file-like objects with `write()` method.

    """
    def __init__(self, streams=None, log_every=1):
        self.streams = streams or [sys.stdout]
        self.log_every = log_every
    
    def epoch_ended(self, phases, epoch, **kwargs):
        losses = [f'{phase.name}={phase.last_loss:.4f}' for phase in phases]
        loss_string = ', '.join(losses)
        string = f'Epoch: {epoch:4d} | {loss_string}\n'
        for stream in self.streams:
            stream.write(string)
            stream.flush()

In [182]:
class ProgressBar(Callback):
    
    def phase_started(self, phase, total_batches, **kwargs):
        if not hasattr(self, 'bar'):
            self.bar = tqdm(total=total_batches, desc=phase.name)
        else:
            self.bar.n = 0
            self.bar.total = total_batches
            self.bar.desc = phase.name
        self.bar.refresh()
    
    def batch_ended(self, phase, **kwargs):
        self.bar.set_postfix_str(f'loss: {phase.last_loss:.4f}')
        self.bar.update(1)
        
    def trianing_ended(self, **kwargs):
        self.bar.close()
        self.bar = None

In [80]:
class CallbacksGroup:
    
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.named_callbacks = {cb.__class__.__name__: cb for cb in callbacks}
    
    def set_loop(self, loop):
        self.loop = loop
        
    def training_started(self, **kwargs):
        for cb in self.callbacks:
            cb.training_started()
    
    def training_ended(self, **kwargs):
        for cb in self.callbacks:
            cb.training_ended()
            
    def phase_started(self, **kwargs):
        for cb in self.callbacks:
            cb.phase_started(**kwargs)

    def phase_ended(self, **kwargs):
        for cb in self.callbacks:
            cb.phase_ended(**kwargs)
    
    def epoch_started(self, **kwargs):
        for cb in self.callbacks:
            cb.epoch_started(**kwargs)
    
    def epoch_ended(self, **kwargs):
        for cb in self.callbacks:
            cb.epoch_ended(**kwargs)
    
    def batch_started(self, **kwargs):
        for cb in self.callbacks:
            cb.batch_started(**kwargs)
            
    def batch_ended(self, **kwargs):
        for cb in self.callbacks:
            cb.batch_ended(**kwargs)
    
    def before_forward_pass(self, **kwargs):
        for cb in self.callbacks:
            cb.before_forward_pass(**kwargs)
    
    def after_forward_pass(self, **kwargs):
        for cb in self.callbacks:
            cb.after_forward_pass(**kwargs)
    
    def before_backward_pass(self, **kwargs):
        for cb in self.callbacks:
            cb.before_forward_pass(**kwargs)
    
    def after_backward_pass(self, **kwargs):
        for cb in self.callbacks:
            cb.after_backward_pass(**kwargs)

In [81]:
class Phase:
    """
    Model training loop phase.

    Each model's training loop iteration could be separated into (at least) two
    phases: training and validation. The instances of this class track
    metrics and counters, related to the specific phase, and keep the reference
    to subset of data, used during phase.
    """
    def __init__(self, name: str, loader: DataLoader, grad: bool=True):
        self.name = name
        self.loader = loader
        self.grad = grad
        self.batch_loss = None
        self.batch_index = 0
        self.rolling_loss = 0
        self.last_loss = 0

In [82]:
class ExpandChannels:
    
    def __init__(self, num_of_channels=3):
        self.nc = num_of_channels
    
    def __call__(self, x):
        return x.expand((self.nc,) + x.shape[1:])

In [83]:
def get_transforms():
    return Compose([
        ToTensor(), 
        ExpandChannels(3),
        Normalize((0.1307,), (0.3081,))
    ])

In [84]:
def place_and_unwrap(batch, dev):
    x, *y = batch
    x = x.to(dev)
    y = [tensor.to(dev) for tensor in y]
    if len(y) == 1:
        [y] = y
    return x, y

In [183]:
n = 4
n_jobs = default(n, cpu_count())
epochs = 10
batch_size = 4096
device = torch.device('cuda:1')

transforms = get_transforms()
root = Path('~/data/mnist').expanduser()
train_ds = MNIST(root, train=True, download=True, transform=transforms)
valid_ds = MNIST(root, train=False, transform=transforms)

phases = [
    Phase('train', DataLoader(train_ds, batch_size, shuffle=True)),
    Phase('valid', DataLoader(valid_ds, batch_size), grad=False)
]

model = Classifier(10)
model.to(device)
opt = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
cb = CallbacksGroup([RollingLoss(), StreamLogger(), ProgressBar()])
cb.training_started()

for epoch in range(epochs):
    cb.epoch_started(epoch=epoch)

    for phase in phases:
        n = len(phase.loader)
        cb.phase_started(phase=phase, total_batches=n)
        is_training = phase.grad
        model.train(is_training)
        
        for batch in phase.loader:

            phase.batch_index += 1
            cb.batch_started(phase=phase, total_batches=n)
            x, y = place_and_unwrap(batch, device)
            
            with torch.set_grad_enabled(is_training):
                cb.before_forward_pass()
                out = model(x)
                cb.after_forward_pass()
                loss = loss_fn(out, y)
            
            if is_training:
                opt.zero_grad()
                cb.before_backward_pass()
                loss.backward()
                cb.after_backward_pass()
                opt.step()
            
            phase.batch_loss = loss.item()
            cb.batch_ended(phase=phase, output=out, target=y)
            
        cb.phase_ended(phase=phase)
    
    cb.epoch_ended(phases=phases, epoch=epoch)

HBox(children=(IntProgress(value=0, description='train', max=15), HTML(value='')))

Epoch:    0 | train=0.3807, valid=4.4185
Epoch:    1 | train=0.1899, valid=2.2400
Epoch:    2 | train=0.1181, valid=1.4651
Epoch:    3 | train=0.0798, valid=1.0762
Epoch:    4 | train=0.0563, valid=0.8417
Epoch:    5 | train=0.0407, valid=0.6880
Epoch:    6 | train=0.0302, valid=0.5790
Epoch:    7 | train=0.0227, valid=0.4964
Epoch:    8 | train=0.0178, valid=0.4346


KeyboardInterrupt: 