# PyTorch Training Loop Implementation

In [None]:
# From "Deep Learning Training Loop Implementation" 
# https://github.com/devforfu/pytorch_playground/blob/master/loop.ipynb

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict, OrderedDict
import math
from pathlib import Path
import re
import sys

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms as T
from torchvision.datasets import MNIST, CIFAR10
from tqdm import tqdm_notebook as tqdm

# Use the first CUDA device when PyTorch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    default_device = torch.device('cuda:0')
else:
    default_device = torch.device('cpu')

In [None]:
# Bare-bones sketch of the training loop, before any callbacks are introduced.
# create_model, create_train_valid_data, place_and_unwrap, params, device, epochs
# and loss_fn are not defined in this cell.
model = create_model(params)
phases = create_train_valid_data()
# NOTE(review): torch.nn.Module exposes parameters(), not a `params` attribute —
# confirm what create_model returns.
opt = optim.SGD(model.params, lr=1e-3)

model.to(device)
    
for epoch in range(1, epochs + 1):

    # Every epoch walks over all phases; each phase carries its own loader
    # and a `grad` flag that says whether the model is updated on it.
    for phase in phases:
        n = len(phase.loader)
        is_training = phase.grad
        # Switch the model between train and eval mode according to the phase.
        model.train(is_training)

        for batch in phase.loader:
            x, y = place_and_unwrap(batch, device)

            # Gradients are tracked only when the phase trains the model.
            with torch.set_grad_enabled(is_training):
                out = model(x)
                loss = loss_fn(out, y)

            if is_training:
                opt.zero_grad()
                loss.backward()
                opt.step()

            # Keep the latest batch loss on the phase object as a Python float.
            phase.batch_loss = loss.item()

In [None]:
# We don’t try to pack every possible feature into a single class or function; instead, the loop delegates calls to subordinate modules (callbacks).
# Each module is responsible for reacting properly to the notifications it receives.

In [None]:
class _NullCallbacks:
    """Stand-in used by `train` when no callbacks are given: every hook is a no-op."""

    def __getattr__(self, name):
        # Leave dunder lookups (copy/pickle protocols etc.) with their normal behaviour.
        if name.startswith('__'):
            raise AttributeError(name)

        def _noop(*args, **kwargs):
            return None

        return _noop


def train(model, opt, phases, callbacks=None, epochs=1, device=default_device, loss_fn=F.nll_loss):
    """Run the training loop, notifying callbacks at every stage.

    Args:
        model: Module to train; it is moved to `device` first.
        opt: Optimizer stepped after each backward pass of a training phase.
        phases: Iterable of phase objects. The loop reads `phase.loader` and
            `phase.grad`, increments `phase.batch_index` and writes
            `phase.batch_loss`.
        callbacks: Object exposing the hook methods called below
            (`training_started`, `epoch_started`, `phase_started`,
            `batch_started`, `before_forward_pass`, `after_forward_pass`,
            `before_backward_pass`, `after_backward_pass`, `batch_ended`,
            `phase_ended`, `epoch_ended`, `training_ended`). When None, the
            loop runs without notifying anything.
        epochs: Number of passes over all phases.
        device: Device the model and the batches are placed on.
        loss_fn: Callable invoked as `loss_fn(output, target)`.
    """
    model.to(device)

    # The default `callbacks=None` used to crash on the first hook call;
    # substitute a no-op object so the loop is usable without callbacks.
    cb = _NullCallbacks() if callbacks is None else callbacks

    cb.training_started(phases=phases, optimizer=opt)

    for epoch in range(1, epochs + 1):
        cb.epoch_started(epoch=epoch)

        for phase in phases:
            n = len(phase.loader)
            cb.phase_started(phase=phase, total_batches=n)
            is_training = phase.grad
            # Train/eval mode follows the phase's `grad` flag.
            model.train(is_training)

            for batch in phase.loader:

                phase.batch_index += 1
                cb.batch_started(phase=phase, total_batches=n)
                x, y = place_and_unwrap(batch, device)

                # Gradients are tracked only for phases that update the model.
                with torch.set_grad_enabled(is_training):
                    cb.before_forward_pass()
                    out = model(x)
                    cb.after_forward_pass()
                    loss = loss_fn(out, y)

                if is_training:
                    opt.zero_grad()
                    cb.before_backward_pass()
                    loss.backward()
                    cb.after_backward_pass()
                    opt.step()

                # Store the loss as a plain float so callbacks need not touch the graph.
                phase.batch_loss = loss.item()
                cb.batch_ended(phase=phase, output=out, target=y)

            cb.phase_ended(phase=phase)

        cb.epoch_ended(phases=phases, epoch=epoch)

    cb.training_ended(phases=phases)

In [None]:
# Callbacks Examples

In [None]:
# 1. Loss
# At the end of every batch, we’re computing a running loss

In [None]:
class RollingLoss(Callback):
    """Tracks an exponentially smoothed, bias-corrected loss for every phase.

    `smooth` is the weight given to the previous average; `1 - smooth` is the
    weight of the newest batch loss.
    """

    def __init__(self, smooth=0.98):
        self.smooth = smooth

    def batch_ended(self, phase, **kwargs):
        beta = self.smooth
        # Exponential moving average of the raw batch losses.
        smoothed = beta * phase.rolling_loss + (1 - beta) * phase.batch_loss
        # Dividing by (1 - beta^t) corrects the bias of the average towards
        # its starting value during the first batches.
        bias_correction = 1 - beta ** phase.batch_index
        debiased = smoothed / bias_correction
        phase.rolling_loss = smoothed
        phase.update(debiased)

    def epoch_ended(self, phases, **kwargs):
        # Record each phase's most recent loss under the 'loss' metric.
        for current in phases:
            current.update_metric('loss', current.last_loss)

In [None]:
# 2. Accuracy
# Note that the callback receives notifications at the end of each batch and at the end of each training epoch

In [None]:
def accuracy(out, y_true):
    """Return the fraction of predictions that equal the targets.

    The predicted class is the argmax of `out` over its last dimension. Both
    predictions and targets are reshaped to (batch, -1) before comparison,
    and the result is a 0-dim float tensor.
    """
    batch_size = y_true.size(0)
    predicted = out.argmax(dim=-1).view(batch_size, -1)
    expected = y_true.view(batch_size, -1)
    return (predicted == expected).float().mean()
  

class Accuracy(Callback):
    """Accumulates a sample-weighted accuracy per phase over one epoch.

    Counters are reset in `epoch_started`, updated after every batch, and
    turned into an 'accuracy' metric on each phase in `epoch_ended`.
    """

    def epoch_started(self, **kwargs):
        # Per-phase sums, keyed by phase name.
        self.values = defaultdict(float)
        self.counts = defaultdict(int)

    def batch_ended(self, phase, output, target, **kwargs):
        batch_size = target.size(0)
        acc = accuracy(output, target).detach().item()
        # Weight the batch accuracy by its size so a smaller last batch
        # does not skew the epoch average.
        self.counts[phase.name] += batch_size
        self.values[phase.name] += batch_size * acc

    def epoch_ended(self, phases, **kwargs):
        for phase in phases:
            seen = self.counts.get(phase.name, 0)
            if seen == 0:
                # The phase processed no samples this epoch; there is nothing
                # to average, and dividing would raise ZeroDivisionError.
                continue
            phase.update_metric('accuracy', self.values[phase.name] / seen)

In [None]:
# 3. Parameter Scheduler
# The idea is to use cyclic schedulers that adjust the magnitudes of the optimizer's parameters over one or several training epochs.
# Moreover, these schedulers not only decrease the learning rate as the number of processed batches grows, but also increase it for some number of steps or periodically.

In [None]:
# This way we effectively get stochastic gradient descent with warm restarts, which allows us to escape from local minima.
# The following snippet shows how one can implement a cosine annealing learning rate schedule

In [None]:
class CosineAnnealingSchedule:
    """Cosine annealing schedule with restarts.

    Every call to `update` advances an internal step counter and returns a
    multiplier between `eta_min` and `eta_max` (0.0 and 1.0 by default). The
    value follows half a cosine wave over `t_max` steps. When a cycle
    completes, the counter is reset and `t_max` is multiplied by `t_mult`.
    """

    def __init__(self, eta_min=0.0, eta_max=1.0, t_max=100, t_mult=2):
        self.eta_min, self.eta_max = eta_min, eta_max
        self.t_max, self.t_mult = t_max, t_mult
        self.iter = 0

    def update(self, **kwargs):
        """Advance one step and return the multiplier for that step."""
        self.iter += 1

        period = self.t_max
        position = self.iter % period
        cosine = math.cos(math.pi * position / period)
        eta = self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + cosine)

        if position == 0:
            # The cycle just finished (this step returns eta_max): restart
            # the counter and lengthen the next cycle.
            self.iter = 0
            self.t_max = period * self.t_mult

        return eta

In [None]:
# Stream Logger
# The last thing we would like to add is some logging to see how well our model performs during the training process. 
# The most simplistic approach is to print stats into the standard output stream. 
# However, you could save it into CSV file or even send as a notification to your mobile phone instead.

In [None]:
def merge_dicts(ds):
    """Combine an iterable of dicts into a single OrderedDict.

    Keys keep the position of their first occurrence; when a key appears in
    several dicts, the value from the last one wins.
    """
    combined = OrderedDict()
    for mapping in ds:
        combined.update(mapping.items())
    return combined

In [None]:
class StreamLogger(Callback):
    """Writes one line of metrics per logged epoch to every configured stream.

    Args:
        streams: Objects with `write` and `flush`; defaults to `[sys.stdout]`.
        log_every: Only epochs divisible by this number are logged.
    """

    def __init__(self, streams=None, log_every=1):
        self.streams = streams or [sys.stdout]
        self.log_every = log_every

    def epoch_ended(self, phases, epoch, **kwargs):
        if epoch % self.log_every:
            # Not a logging epoch.
            return

        # Gather the latest metrics of all phases into one ordered mapping.
        metrics = merge_dicts([phase.last_metrics for phase in phases])
        formatted = ', '.join(
            f'{name}={value:.4f}' for name, value in metrics.items())
        line = f'Epoch: {epoch:4d} | {formatted}\n'

        for out in self.streams:
            out.write(line)
            out.flush()

In [None]:
class ProgressBar(Callback):
    """Shows one tqdm progress bar per phase and reuses it across epochs."""

    def training_started(self, phases, **kwargs):
        # One bar per phase, keyed by phase name and sized to its loader.
        self.bars = OrderedDict(
            (phase.name, tqdm(total=len(phase.loader), desc=phase.name))
            for phase in phases
        )

    def batch_ended(self, phase, **kwargs):
        progress = self.bars[phase.name]
        progress.set_postfix_str(f'loss: {phase.last_loss:.4f}')
        progress.update(1)
        progress.refresh()

    def epoch_ended(self, **kwargs):
        # Rewind every bar so the next epoch starts counting from zero.
        for progress in self.bars.values():
            progress.n = 0
            progress.refresh()

    def training_ended(self, **kwargs):
        # Show every bar as complete before closing it.
        for progress in self.bars.values():
            progress.n = progress.total
            progress.refresh()
            progress.close()

In [None]:
# we’re ready to start using our training loop!

In [None]:
## You should definitely use transfer learning when working on your daily tasks.
## It makes your network converge much faster than training from scratch.

In [None]:
def as_sequential(model: nn.Module):
    """Wrap the direct children of `model`, in order, into an `nn.Sequential`."""
    children = model.children()
    return nn.Sequential(*children)

In [None]:
class AdaptiveConcatPool2d(nn.Module):
    """Runs adaptive max pooling and adaptive average pooling on the same
    input and concatenates the two results along dimension 1 (max first).

    The idea is taken from the fastai library.
    """

    def __init__(self, size=1):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(size)
        self.max = nn.AdaptiveMaxPool2d(size)

    def forward(self, x):
        max_pooled = self.max(x)
        avg_pooled = self.avg(x)
        return torch.cat([max_pooled, avg_pooled], dim=1)

In [None]:
class Flatten(nn.Module):
    """Reshapes a tensor to 2-D, keeping dimension 0 and merging all the rest."""

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)

In [None]:
class ResNetClassifier(nn.Module):
    """Classifier made of a torchvision backbone and a freshly built dense head.

    Args:
        n_classes: Number of outputs of the final linear layer.
        arch: Model constructor, called as `arch(True)`. For torchvision
            ResNets the positional flag requests pretrained weights
            (NOTE(review): newer torchvision prefers `weights=` — confirm
            against the installed version).
    """

    def __init__(self, n_classes, arch=models.resnet18):
        super().__init__()

        model = arch(True)
        # Keep everything except the last two children; for torchvision
        # ResNets those are the average pooling and the final linear layer.
        self.backbone = as_sequential(model)[:-2]

        # AdaptiveConcatPool2d concatenates max- and avg-pooled features, so
        # the head receives twice the backbone's channel count. Derive it from
        # the replaced `fc` layer instead of hard-coding 1024, which is only
        # right for 512-channel backbones such as resnet18/34.
        fc = getattr(model, 'fc', None)
        n_features = 2 * fc.in_features if isinstance(fc, nn.Linear) else 1024

        self.top = nn.Sequential(
            AdaptiveConcatPool2d(),
            Flatten(),

            nn.Linear(n_features, 512),
            nn.Dropout(0.25),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm1d(512),

            nn.Linear(512, 256),
            nn.Dropout(0.5),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm1d(256),

            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        """Pass `x` through the backbone and then through the head."""
        return self.top(self.backbone(x))

In [None]:
def classifier_weights(m: nn.Module, bn=(1, 1e-3)):
    """Initialize the weights of a single layer of a classification model.

    Intended to be passed to `nn.Module.apply`. The layer kind is decided from
    the class name returned by `classname(m)`:

    * names containing 'Conv': Kaiming-normal weights (fan_out), zero bias;
    * names containing 'BatchNorm': constant weight and bias taken from `bn`;
    * 'Linear': Kaiming-normal weights, zero bias.

    Layers of any other kind are left untouched. Missing tensors (for example
    `bias=False` or `affine=False`) are skipped in every branch.

    Args:
        m: Layer to initialize in place.
        bn: `(weight, bias)` constants used for batch-norm layers.
    """

    name = classname(m)

    def _constant(tensor, value):
        # Layers built without a given tensor store None in its place.
        if tensor is not None:
            nn.init.constant_(tensor, value)

    with torch.no_grad():
        if 'Conv' in name:
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            _constant(getattr(m, 'bias', None), 0.0)

        elif 'BatchNorm' in name:
            weight, bias = bn
            _constant(getattr(m, 'weight', None), weight)
            _constant(getattr(m, 'bias', None), bias)

        elif name == 'Linear':
            nn.init.kaiming_normal_(m.weight)
            # A Linear layer created with bias=False used to crash here.
            _constant(getattr(m, 'bias', None), 0.0)

In [None]:
def freeze(m, freeze=True, bn=False):
    """Freeze or unfreeze the parameters of `m` and all of its submodules.

    Args:
        m: Module whose parameters are updated in place.
        freeze: When True, parameters get `requires_grad = False`; when
            False, they are made trainable again.
        bn: When False, layers whose class name contains 'BatchNorm' are left
            untouched at any nesting depth; when True they are handled like
            every other layer.
    """
    requires_grad = not freeze
    # Walk every module (including `m` itself) and touch only the parameters
    # that module owns directly. Checking just the direct children froze
    # batch-norm layers nested inside containers and ignored parameters held
    # by `m` itself.
    for module in m.modules():
        if not bn and 'BatchNorm' in classname(module):
            continue
        for p in module.parameters(recurse=False):
            p.requires_grad = requires_grad

In [None]:
# Experiment setup: CIFAR10 data, a ResNet-based classifier with a frozen
# backbone, an SGD optimizer and the callbacks used during training.
# make_phases, CallbacksGroup, Scheduler and OneCycleSchedule are defined
# outside this view.
data_path = Path.home()/'data'/'cifar10'

# Per-channel (mean, std) pairs passed to T.Normalize below.
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

epochs = 10

# Training split: resized to 224, lightly augmented, then normalized.
train_ds = CIFAR10(
    data_path, 
    train=True, 
    download=True,
    transform=T.Compose([
        T.Resize(224),
        T.RandomAffine(5, translate=(0.05, 0.05), scale=(0.9, 1.1)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(*imagenet_stats)
    ])
)

# Validation split: same resize and normalization, no augmentation.
# No download flag here; the training dataset above is created first with download=True.
valid_ds = CIFAR10(
    data_path, 
    train=False, 
    transform=T.Compose([
        T.Resize(224),
        T.ToTensor(),
        T.Normalize(*imagenet_stats)
    ])
)

# presumably bs is the batch size and n_jobs the number of loader workers —
# confirm against make_phases.
phases = make_phases(train_ds, valid_ds, bs=200, n_jobs=8)

model = ResNetClassifier(10)
# Re-initialize only the new head, then freeze the backbone.
model.top.apply(classifier_weights)
freeze(model.backbone)

# TODO: try AdamW later for a better (correct) weight decay regularization
opt = optim.SGD(model.parameters(), lr=1e-2, momentum=0.95, nesterov=True, weight_decay=1e-2)

cb = CallbacksGroup([
    RollingLoss(),
    Accuracy(),
    Scheduler(
        # The schedule length is the number of batches in the first phase times the epochs.
        OneCycleSchedule(t=len(phases[0].loader) * epochs),
        # NOTE(review): 'inverse' presumably makes weight decay move opposite
        # to the learning rate — confirm against Scheduler.
        params_conf=[
            {'name': 'lr'},
            {'name': 'weight_decay', 'inverse': True}
        ],
        mode='batch'
    ),
    StreamLogger(),
    ProgressBar()
])

In [None]:
# Run the loop with the callbacks group; cross-entropy replaces train's default nll_loss.
train(model, opt, phases, cb, epochs=epochs, loss_fn=F.cross_entropy)

In [None]:
# Plot the learning-rate values recorded by the callback registered under 'scheduler'.
# NOTE(review): the x-axis label assumes one recorded value per training batch — confirm against Scheduler.
lr_history = pd.DataFrame(cb['scheduler'].parameter_history('lr'))
ax = lr_history.plot(figsize=(8, 6))
ax.set_xlabel('Training Batch Index')
# The trailing semicolon keeps the notebook from echoing the returned object.
ax.set_ylabel('Learning Rate');