# Wide ResNet (ConvNet) model using PyTorch

In [1]:
%matplotlib inline
%load_ext lab_black
%load_ext autoreload
%autoreload 2

In [2]:
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
import os
from tqdm import tqdm

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import torchvision.transforms as tt
from torch.utils.data import random_split, DataLoader

In [5]:
def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)


def to_device(data, device):
    """Move ``data`` to ``device``.

    Lists and tuples are processed element by element (recursively) and are
    always returned as a list; any other object is moved through its own
    ``.to(device, non_blocking=True)`` method.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]


class DeviceDataLoader:
    """Wrap an iterable of batches so each batch is moved to ``device`` when yielded."""

    def __init__(self, data_loader, device):
        self.device = device
        self.data_loader = data_loader

    def __len__(self):
        """Return the length of the wrapped loader."""
        return len(self.data_loader)

    def __iter__(self):
        """Lazily yield every batch of the wrapped loader, passed through ``to_device``."""
        yield from (to_device(batch, self.device) for batch in self.data_loader)

In [6]:
# Hyper-parameters shared by the cells below.
# NOTE(review): VAL_SIZE is not referenced anywhere in the visible code - confirm
# whether it is still needed.
VAL_SIZE = 5000
# Batch size passed to both the training and the validation DataLoader.
BATCH_SIZE = 8
# Number of passes over the training loader in the training loop.
N_EPOCHS = 100
# Learning rate handed to the Adam optimiser.
LEARNING_RATE = 1e-3

In [7]:
# Two 3-tuples unpacked into tt.Normalize as (mean, std), one value per channel.
# NOTE(review): presumably the usual CIFAR-10 channel statistics - confirm.
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
# Training pipeline: random 32x32 crop after 4 pixels of reflect padding,
# random horizontal flip, then tensor conversion and normalisation.
train_transformers = tt.Compose(
    [
        tt.RandomCrop(32, padding=4, padding_mode="reflect"),
        tt.RandomHorizontalFlip(),
        tt.ToTensor(),
        tt.Normalize(*stats),
    ]
)
# Validation pipeline: no augmentation, only tensor conversion and normalisation.
validation_transformers = tt.Compose([tt.ToTensor(), tt.Normalize(*stats)])

In [8]:
# Image-folder datasets read from ../data/cifar10/{train,test}; each one applies
# the matching transform pipeline defined above.
train_dataset = ImageFolder(
    os.path.join("../data", "cifar10", "train"), transform=train_transformers
)
# Note: the "test" folder is what this notebook uses as its validation set.
validation_dataset = ImageFolder(
    os.path.join("../data", "cifar10", "test"), transform=validation_transformers
)

In [9]:
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True
)
# Validation only aggregates metrics over the whole set, so shuffling adds
# nothing; a fixed order keeps evaluation batches reproducible between epochs.
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

In [10]:
# Wrap both loaders so every batch is moved to the default device
# (CUDA when available, otherwise CPU) as it is iterated.
train_loader = DeviceDataLoader(train_dataloader, get_default_device())
val_loader = DeviceDataLoader(validation_dataloader, get_default_device())

In [11]:
def accuracy(y_preds: torch.Tensor, y_true: torch.Tensor):
    """Return the fraction of positions at which ``y_preds`` equals ``y_true``."""
    n_correct = (y_preds == y_true).sum().item()
    n_total = y_true.numel()
    return n_correct / n_total

In [12]:
def conv_2d(
    input_channels: int, output_channels: int, stride: int = 1, kernel_size: int = 3
):
    """Build a bias-free ``nn.Conv2d`` whose padding is ``kernel_size // 2``.

    Args:
        input_channels: Number of channels of the incoming feature map.
        output_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
        kernel_size: Side length of the square kernel.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        input_channels,
        output_channels,
        kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=False,
    )


def batch_norm_conv_2d(input_channels: int, output_channels: int):
    """Return a ``BatchNorm2d -> ReLU -> conv_2d`` stack as an ``nn.Sequential``.

    The batch norm operates on ``input_channels``; the convolution (default
    stride and kernel size of ``conv_2d``) maps them to ``output_channels``.
    """
    stages = [
        nn.BatchNorm2d(input_channels),
        nn.ReLU(inplace=True),
        conv_2d(input_channels, output_channels),
    ]
    return nn.Sequential(*stages)


class ResidualBlock(nn.Module):
    """Residual block: BN -> ReLU, then two convolutions plus a shortcut.

    The input is batch-normalised and passed through ReLU; that activated
    tensor feeds both the shortcut and the main branch
    (``conv_1`` followed by ``BN -> ReLU -> conv``). The main-branch output is
    scaled by 0.2 before the shortcut is added to it.

    Args:
        input_channels: Channels of the incoming feature map.
        output_channels: Channels produced by the block.
        stride: Stride of the first convolution (and of the projection
            shortcut when one is needed).
    """

    def __init__(self, input_channels: int, output_channels: int, stride: int = 1):
        super().__init__()
        self.batch_norm = nn.BatchNorm2d(input_channels)
        self.conv_1 = conv_2d(input_channels, output_channels, stride)
        self.conv_2 = batch_norm_conv_2d(output_channels, output_channels)
        # A 1x1 projection is required whenever the main branch changes the
        # tensor shape: a different channel count OR a stride other than 1.
        # Checking only the channels would leave an identity shortcut whose
        # spatial size no longer matches the strided main branch.
        if input_channels != output_channels or stride != 1:
            self.shortcut = conv_2d(
                input_channels, output_channels, stride, kernel_size=1
            )
        else:
            # nn.Identity (rather than a lambda) keeps the module picklable;
            # it has no parameters, so state_dict keys are unaffected.
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` and return the sum of both branches."""
        x = F.relu(self.batch_norm(x), inplace=True)
        # The shortcut is taken from the activated tensor, not the raw input.
        r = self.shortcut(x)
        x = self.conv_1(x)
        # Down-weight the residual branch before merging it with the shortcut.
        x = self.conv_2(x) * 0.2
        return x.add_(r)


class Flatten(nn.Module):
    """Reshape the input to two dimensions, keeping the first one."""

    def forward(self, x):
        """Return a view of ``x`` with shape ``(x.size(0), -1)``."""
        leading = x.size(0)
        return x.view(leading, -1)

In [13]:
class WideResNet(nn.Module):
    """Wide residual network assembled as a single ``nn.Sequential``.

    Layout: a 3x3 stem convolution on 3 input channels, ``n_groups`` groups of
    ``n_res_blocks_per_group`` residual blocks each, then
    ``BatchNorm2d -> ReLU -> AdaptiveAvgPool2d(1) -> Flatten -> Linear``.

    Args:
        n_groups: Number of residual groups.
        n_res_blocks_per_group: Residual blocks inside every group.
        n_classes: Size of the final linear layer's output.
        channel_multiplier: Widening factor applied to every group's width.
        initial_n_channels: Channels produced by the stem convolution.
    """

    def __init__(
        self,
        n_groups: int,
        n_res_blocks_per_group: int,
        n_classes: int,
        channel_multiplier: int,
        initial_n_channels: int,
    ):
        super().__init__()
        self.n_groups = n_groups
        self.n_classes = n_classes
        self.n_res_blocks_per_group = n_res_blocks_per_group
        self.initial_n_channels = initial_n_channels
        self.channel_multiplier = channel_multiplier
        # Both attributes are placeholders that build_model() fills in.
        self.layers = []
        self.model = None
        self.build_model()

    def conv_2d(
        self, input_channels: int, output_channels: int, kernel_size: int, stride: int
    ):
        """Return a bias-free ``nn.Conv2d`` padded by ``kernel_size // 2``."""
        half_kernel = kernel_size // 2
        return nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size,
            stride=stride,
            padding=half_kernel,
            bias=False,
        )

    def build_group(
        self,
        n_res_blocks_per_group: int,
        input_channels: int,
        output_channels: int,
        stride: int,
    ):
        """Return the list of residual blocks forming one group.

        Only the first block receives ``input_channels`` and ``stride``; the
        remaining ``n_res_blocks_per_group - 1`` blocks keep ``output_channels``
        and use the default stride.
        """
        blocks = [ResidualBlock(input_channels, output_channels, stride)]
        blocks.extend(
            ResidualBlock(output_channels, output_channels)
            for _ in range(n_res_blocks_per_group - 1)
        )
        return blocks

    def calc_n_channels(self, index: int) -> int:
        """Return the channel count at position ``index`` (0 is the stem output).

        For ``index > 0`` this is
        ``initial_n_channels * 2 ** index * channel_multiplier``.
        NOTE(review): the Wide ResNet paper's widths appear to correspond to
        ``2 ** (index - 1)``; this formula yields twice that - confirm it is
        intended.
        """
        if index != 0:
            return self.initial_n_channels * (2 ** index) * self.channel_multiplier
        return self.initial_n_channels

    def build_model(self):
        """Create all layers, store them in ``self.layers`` and ``self.model``."""
        stem = self.conv_2d(
            input_channels=3,
            output_channels=self.initial_n_channels,
            kernel_size=3,
            stride=1,
        )
        self.layers = [stem]

        for group_index in range(self.n_groups):
            # The first group keeps the resolution; later groups use stride 2.
            group_stride = 1 if group_index == 0 else 2
            group = self.build_group(
                self.n_res_blocks_per_group,
                self.calc_n_channels(group_index),
                self.calc_n_channels(group_index + 1),
                group_stride,
            )
            self.layers.extend(group)

        head = [
            nn.BatchNorm2d(self.calc_n_channels(self.n_groups)),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(self.calc_n_channels(self.n_groups), self.n_classes),
        ]
        self.layers.extend(head)
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        """Run ``x`` through the sequential model."""
        return self.model(x)

In [14]:
# Three groups of three residual blocks for a 10-class problem. With these
# settings calc_n_channels yields 16 channels after the stem and 192, 384 and
# 768 channels after groups one, two and three respectively.
model = WideResNet(
    n_groups=3,
    n_res_blocks_per_group=3,
    n_classes=10,
    channel_multiplier=6,
    initial_n_channels=16,
)

In [15]:
# Move the model to the same default device the data loaders deliver batches on.
model = to_device(model, get_default_device())

In [None]:
# Train with Adam for N_EPOCHS epochs, recording per-epoch training and
# validation loss/accuracy in `history`.
optimiser = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
history = {
    "loss": [],
    "acc": [],
    "val_loss": [],
    "val_acc": [],
}
for i in range(N_EPOCHS):
    _loss = []
    _acc = []
    _val_loss = []
    _val_acc = []
    _batch_sizes = []
    _val_batch_sizes = []

    # Training
    # Put BatchNorm layers in training mode (batch statistics, running
    # statistics updated); validation below switches them to eval mode.
    model.train()
    for Xb, yb in tqdm(train_loader):
        logits = model(Xb)
        loss = F.cross_entropy(logits, yb)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()
        # Metrics
        # .item() copies the scalar to the host, so it works for CUDA tensors
        # too (Tensor.numpy() raises for tensors that live on the GPU).
        _loss.append(loss.item())
        # The arg-max of the logits equals the arg-max of their softmax.
        _, y_preds = torch.max(logits, dim=1)
        _acc.append(accuracy(y_preds, yb))
        _batch_sizes.append(len(Xb))

    # Validation
    # Eval mode makes BatchNorm use its running statistics and stops the
    # validation data from updating them.
    model.eval()
    with torch.no_grad():
        for Xb, yb in val_loader:
            logits = model(Xb)
            _val_loss.append(F.cross_entropy(logits, yb).item())
            _, y_preds = torch.max(logits, dim=1)
            _val_acc.append(accuracy(y_preds, yb))
            _val_batch_sizes.append(len(Xb))

    # Batch-size-weighted means so that a smaller final batch does not skew the
    # epoch metrics; applied to losses and accuracies alike.
    _n_train = np.sum(_batch_sizes)
    _n_val = np.sum(_val_batch_sizes)
    history["loss"].append(float(np.sum(np.multiply(_loss, _batch_sizes)) / _n_train))
    history["acc"].append(float(np.sum(np.multiply(_acc, _batch_sizes)) / _n_train))
    history["val_acc"].append(
        float(np.sum(np.multiply(_val_acc, _val_batch_sizes)) / _n_val)
    )
    history["val_loss"].append(
        float(np.sum(np.multiply(_val_loss, _val_batch_sizes)) / _n_val)
    )
    print(f"Epoch: {i + 1}/{N_EPOCHS}, acc: {history['acc'][-1]:.4f}, loss: {history['loss'][-1]:.4f}, val_acc: {history['val_acc'][-1]:.4f},  val_loss: {history['val_loss'][-1]:.4f}\r", end="")