In [None]:
import os
from typing import Dict, List, Tuple

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import tensorboard
import torch
from lightning.pytorch.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
)
from lightning.pytorch.loggers import TensorBoardLogger
from PIL import Image
from rich import print
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
from tqdm import tqdm

%load_ext autoreload
%autoreload 2
%load_ext rich
%load_ext tensorboard

# Set random seed for reproducibility
# `seed` is the single source of truth for seeding; later cells re-seed with it
# (for example around the train/validation split).
seed = 42
L.seed_everything(seed)

# Set device
# Uses the GPU when CUDA is available and falls back to the CPU otherwise.
# NOTE(review): in the visible cells `device` is only printed; the Trainer
# selects its own accelerator via accelerator="auto".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Set data directory
# Datasets are stored in ./data relative to the current working directory.
DATA_DIR = os.path.join(os.getcwd(), "data")

## Load and preprocess the dataset


In [None]:
# Human-readable class names, one per integer label.
# NOTE(review): presumably ordered to match the dataset's label ids 0-9 —
# confirm against `CIFAR10.classes`.
label_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]  # fmt: skip


In [None]:
# Load the CIFAR-10 training split without any transform, downloading it into
# DATA_DIR when it is not already there. The untransformed dataset is used by
# the next cells to preview samples and to compute normalisation statistics.
train_dataset = CIFAR10(DATA_DIR, train=True, download=True)
# Bare expression so the notebook renders the dataset summary as cell output.
train_dataset


In [None]:
# Display some images randomly
# Five distinct samples are drawn; each is fetched once and titled with its
# class name (plus the raw integer label) instead of the bare integer.
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
sample_indices = np.random.choice(len(train_dataset), 5, replace=False)
for ax, idx in zip(axes, sample_indices):
    image, label = train_dataset[idx]
    ax.imshow(image)
    ax.set_title(f"Label: {label_classes[label]} ({label})")
    ax.axis("off")

plt.show()


In [None]:
# Per-channel mean and standard deviation of the training images after scaling
# the raw pixel values by 1/255. These feed `transforms.Normalize` below.
# The scaled copy of the whole training set is large, so it is built once,
# reused for both statistics, and released afterwards (kernel memory persists
# across cells).
# NOTE(review): assumes `train_dataset.data` is laid out as (N, H, W, C) so that
# reducing over axes (0, 1, 2) leaves one value per channel — confirm.
train_pixels = train_dataset.data / 255.0
DATA_MEANS = train_pixels.mean(axis=(0, 1, 2))
DATA_STD = train_pixels.std(axis=(0, 1, 2))
del train_pixels
print(f"Data mean: {DATA_MEANS}, Data std: {DATA_STD}")


In [None]:
# Evaluation pipeline: convert the PIL image to a tensor and standardise each
# channel with the training-set statistics. No augmentation is applied.
test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(DATA_MEANS, DATA_STD)]
)

# Training pipeline: random horizontal flip and a mild random resized crop
# (back to 32x32), followed by exactly the same tensor conversion and
# normalisation steps as the evaluation pipeline.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        *test_transform.transforms,
    ]
)


In [None]:
# Read the CIFAR-10 dataset and split it into train and validation sets

# Number of training images held out for validation.
N_VAL = 5000

# The training split is instantiated twice so that the training subset gets
# augmentation while the validation subset only gets normalisation.
train_dataset = CIFAR10(DATA_DIR, train=True, download=True, transform=train_transform)
val_dataset = CIFAR10(DATA_DIR, train=True, download=True, transform=test_transform)
test_dataset = CIFAR10(DATA_DIR, train=False, download=True, transform=test_transform)

# Generate validation set without any data leakage while applying different transformations.
# One seeded permutation is drawn from a dedicated generator and shared by both
# subsets, so the split is disjoint by construction and does not depend on (or
# reset) the global RNG state.
split_rng = torch.Generator().manual_seed(seed)
shuffled_indices = torch.randperm(len(train_dataset), generator=split_rng).tolist()
train_indices, val_indices = shuffled_indices[:-N_VAL], shuffled_indices[-N_VAL:]

# Sanity checks: nothing leaks between the subsets and nothing is dropped.
assert set(train_indices).isdisjoint(val_indices), "train/val indices overlap"
assert len(train_indices) + len(val_indices) == len(train_dataset)

train_set = torch.utils.data.Subset(train_dataset, train_indices)
val_set = torch.utils.data.Subset(val_dataset, val_indices)


In [None]:
BATCH_SIZE = 128

# Generate data loaders
# Settings common to the train, validation and test loaders.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

# Training batches are reshuffled every epoch and the last incomplete batch is
# dropped.
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **_loader_kwargs)

# Evaluation loaders keep a fixed order and every sample.
val_loader = DataLoader(val_set, shuffle=False, drop_last=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)


In [None]:
# Visual check of the augmentation pipeline: the first grid row shows augmented
# samples, the second row shows the same images with only the evaluation
# transform applied.
NUM_IMAGES = 4
# Sample without replacement so the same image cannot appear twice in the grid.
RNG_IDX = np.random.choice(len(train_dataset), NUM_IMAGES, replace=False)
images = [train_dataset[idx][0] for idx in RNG_IDX]
orig_images = [
    test_transform(Image.fromarray(train_dataset.data[idx])) for idx in RNG_IDX
]

# One row per variant: `nrow` is the number of images per grid row, so tying it
# to NUM_IMAGES keeps augmented/original pairs vertically aligned when the
# constant changes. `normalize=True` rescales the standardised tensors back to
# a displayable range.
img_grid = make_grid(
    torch.stack(images + orig_images, dim=0),
    nrow=NUM_IMAGES,
    normalize=True,
    pad_value=0.5,
)
# make_grid returns channels-first; matplotlib expects channels-last.
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.title("Augmentation examples on CIFAR10")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()


## Build the ResNet model

In [None]:
class ResidualBlock(nn.Module):
    """Two-convolution residual block with an identity or projection shortcut.

    The main branch (`left`) is conv3x3 -> BatchNorm -> ReLU -> conv3x3 ->
    BatchNorm. The shortcut branch (`right`) is the identity when the input
    and output shapes already agree; otherwise it is a strided 1x1 convolution
    followed by BatchNorm that matches the channel count and spatial size. The
    two branches are summed and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        first_stride: Stride of the first convolution (and of the projection
            shortcut). Values above 1 downsample the feature map.
    """

    def __init__(self, in_channels: int, out_channels: int, first_stride: int = 1):
        super().__init__()

        self.left = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=first_stride,
                padding=1,
                bias=False,  # the following BatchNorm supplies the shift
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
        )

        # A projection is needed whenever the main branch changes the tensor
        # shape: either spatially (stride > 1) or in the channel dimension.
        # Previously a channel change with stride 1 tripped an assertion.
        needs_projection = first_stride > 1 or in_channels != out_channels
        if needs_projection:
            self.right = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=first_stride,
                    bias=False,
                ),
                nn.BatchNorm2d(num_features=out_channels),
            )
        else:
            self.right = nn.Identity()

        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the block.

        x: shape (batch, in_channels, height, width)

        Return: shape (batch, out_channels, ceil(height / first_stride),
        ceil(width / first_stride))
        """
        left = self.left(x)
        right = self.right(x)
        return self.relu(left + right)


In [None]:
class BlockGroup(nn.Module):
    """A stack of `ResidualBlock`s operating at one feature width.

    Only the first block receives `in_channels` and `first_stride`; every
    following block maps `out_channels` to `out_channels` with the default
    stride.

    Args:
        n_blocks: Total number of residual blocks in the group.
        in_channels: Channels of the tensor entering the group.
        out_channels: Channels produced by every block of the group.
        first_stride: Stride applied by the first block only.
    """

    def __init__(
        self, n_blocks: int, in_channels: int, out_channels: int, first_stride: int = 1
    ):
        super().__init__()
        # The leading block handles the change in width / resolution.
        stack = [ResidualBlock(in_channels, out_channels, first_stride=first_stride)]
        # The remaining blocks keep the shape unchanged.
        for _ in range(n_blocks - 1):
            stack.append(ResidualBlock(out_channels, out_channels))
        self.blocks = nn.Sequential(*stack)

    def forward(self, x):
        """
        Run the input through every block in order.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        """
        return self.blocks(x)

In [None]:
class ResNet(nn.Module):
    """Residual network for small images, assembled from `BlockGroup`s.

    Layout: a 3x3 convolution + BatchNorm stem, a sequence of block groups,
    then global average pooling, flattening and a linear classifier.

    Args:
        n_blocks_per_group: Number of residual blocks in each group.
        out_features_per_group: Output channels of each group.
        first_strides_per_group: Stride of the first block of each group.
        n_classes: Number of output logits.
        in_feats0: Channels produced by the stem (input width of group 0).

    Raises:
        ValueError: If the per-group sequences are empty or differ in length.
    """

    def __init__(
        self,
        n_blocks_per_group: List[int],
        out_features_per_group: List[int],
        first_strides_per_group: List[int],
        n_classes: int = 10,
        in_feats0: int = 16,
    ):
        super().__init__()

        # Validate up front: `zip` below would otherwise silently drop the
        # trailing entries of the longer sequences.
        n_groups = len(out_features_per_group)
        if n_groups == 0:
            raise ValueError("At least one block group is required")
        if not (
            len(n_blocks_per_group) == n_groups
            and len(first_strides_per_group) == n_groups
        ):
            raise ValueError(
                "n_blocks_per_group, out_features_per_group and "
                "first_strides_per_group must have the same length, got "
                f"{len(n_blocks_per_group)}, {n_groups} and "
                f"{len(first_strides_per_group)}"
            )

        self.in_feats0 = in_feats0
        self.n_classes = n_classes
        self.n_blocks_per_group = n_blocks_per_group
        self.out_features_per_group = out_features_per_group
        self.first_strides_per_group = first_strides_per_group

        # Stem: 3x3 convolution with padding 1 keeps the spatial size. The
        # earlier kernel_size=1 combined with padding=1 enlarged every side by
        # two pixels and surrounded the feature map with a zero border.
        self.in_layers = nn.Sequential(
            nn.Conv2d(3, in_feats0, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(num_features=in_feats0),
        )

        # Each group consumes the previous group's output width; the first
        # group consumes the stem's. `list(...)` lets callers pass tuples too.
        all_in_feats = [in_feats0] + list(out_features_per_group[:-1])

        self.residual_layers = nn.Sequential(
            *(
                BlockGroup(*args)
                for args in zip(
                    n_blocks_per_group,
                    all_in_feats,
                    out_features_per_group,
                    first_strides_per_group,
                )
            )
        )

        self.out_layers = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(out_features_per_group[-1], n_classes),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; applied recursively via `self.apply`.

        Convolutions get Kaiming-normal weights (fan-out mode, ReLU gain) and
        BatchNorm layers start with weight 1 and bias 0. Other module types
        (including the final linear layer) keep their default initialisation.
        """
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute class logits.

        x: shape (batch, 3, height, width)

        Return: shape (batch, n_classes)
        """
        x = self.out_layers(self.residual_layers(self.in_layers(x)))
        return x


In [None]:
class CIFARModule(L.LightningModule):
    """Lightning wrapper that trains a `ResNet` classifier with cross-entropy.

    Args:
        model_hparams: Keyword arguments forwarded to `ResNet`.
        optimizer_name: "Adam" (instantiates `torch.optim.AdamW`) or "SGD".
        optimizer_hparams: Keyword arguments forwarded to the optimizer.
    """

    def __init__(
        self,
        model_hparams: Dict,
        optimizer_name: str,
        optimizer_hparams: Dict,
    ):
        super().__init__()

        # Stores the constructor arguments under `self.hparams`.
        self.save_hyperparameters()

        self.model = ResNet(**model_hparams)

        self.loss_fn = nn.CrossEntropyLoss()

        # Dummy input with the expected input shape; Lightning uses it for the
        # model summary and for logging the computational graph.
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, x):
        """Return the logits of the wrapped model for a batch of images."""
        return self.model(x)

    def configure_optimizers(self):
        """Build the optimizer and a step-wise learning-rate schedule.

        Returns:
            A one-element optimizer list and a one-element scheduler list.

        Raises:
            ValueError: If `optimizer_name` is neither "Adam" nor "SGD".
        """
        optimizer_name = self.hparams.optimizer_name
        if optimizer_name == "Adam":
            # Note: the "Adam" option deliberately maps to AdamW here.
            optimizer = torch.optim.AdamW(
                self.parameters(), **self.hparams.optimizer_hparams
            )

        elif optimizer_name == "SGD":
            optimizer = torch.optim.SGD(
                self.parameters(), **self.hparams.optimizer_hparams
            )

        else:
            # An explicit exception (rather than `assert False`) still fires
            # when assertions are disabled.
            raise ValueError(f"Unknown optimizer: {optimizer_name}")

        # Multiply the learning rate by 0.1 at milestones 100 and 150.
        # NOTE(review): training runs shorter than 100 epochs never reach the
        # first milestone, so the learning rate stays constant for them.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], gamma=0.1
        )

        return [optimizer], [scheduler]

    def _shared_step(self, batch, stage: str):
        """Compute loss and accuracy for one batch and log both per epoch.

        Args:
            batch: An `(images, labels)` pair.
            stage: Metric name prefix, one of "train", "val" or "test".

        Returns:
            The cross-entropy loss for the batch.
        """
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_fn(preds, labels)

        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Epoch-level aggregates only; shown in the progress bar and logger.
        self.log(
            f"{stage}_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            f"{stage}_acc",
            acc,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def training_step(self, batch, batch_idx):
        """Log `train_loss` / `train_acc` and return the loss to optimise."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log `val_loss` / `val_acc` for one validation batch."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log `test_loss` / `test_acc` for one test batch."""
        return self._shared_step(batch, "test")

## Train the CNN model

In [None]:
# Three groups of three residual blocks at widths 16/32/64; the second and
# third groups downsample with stride 2. Trained with SGD + momentum.
model = CIFARModule(
    model_hparams=dict(
        n_blocks_per_group=[3, 3, 3],
        out_features_per_group=[16, 32, 64],
        first_strides_per_group=[1, 2, 2],
        n_classes=10,
    ),
    optimizer_name="SGD",
    optimizer_hparams=dict(lr=0.1, weight_decay=1e-4, momentum=0.9),
)

# 6n + 2 layers: with n = 3 blocks per group that is 3 groups x 3 blocks x 2
# convolutions, plus the stem convolution and the final linear layer = 20.
print(f"Total number of parameters: {sum(p.numel() for p in model.parameters())}")

In [None]:
# Configure the logger through its public constructor arguments instead of
# overwriting private attributes on the Trainer's default logger afterwards.
# `save_dir` matches the Trainer's default root directory (the current working
# directory), so runs still land under ./lightning_logs.
tb_logger = TensorBoardLogger(
    save_dir=os.getcwd(),
    log_graph=True,  # log the computational graph (uses example_input_array)
    default_hp_metric=False,  # do not write the placeholder hp_metric
)

trainer = L.Trainer(
    accelerator="auto",
    max_epochs=10,
    logger=tb_logger,
    callbacks=[
        # Keep the weights of the epoch with the highest validation accuracy.
        ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
        # Record learning rate, momentum and weight decay once per epoch.
        LearningRateMonitor(
            logging_interval="epoch", log_momentum=True, log_weight_decay=True
        ),
    ],
)


In [None]:
# Re-seed right before training so the run does not depend on how much
# randomness earlier cells consumed. Uses the notebook-wide `seed` constant
# rather than a second hard-coded literal.
L.seed_everything(seed)
trainer.fit(model, train_loader, val_loader)


## Evaluating the trained model

In [None]:
# Run the test loop on the validation loader. No `ckpt_path` is passed, so the
# `model` object is evaluated as it is in memory; metrics are logged under the
# `test_*` names defined in `CIFARModule.test_step`.
val_result = trainer.test(model, dataloaders=val_loader, verbose=True)

In [None]:
# Run the test loop on the held-out CIFAR-10 test loader. No `ckpt_path` is
# passed, so the `model` object is evaluated as it is in memory.
test_result = trainer.test(model, dataloaders=test_loader, verbose=True)