# 1. Importing packages

In [1]:
from typing import Optional, Tuple

import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import datasets, transforms

# config inline plotting of matplotlib, use svg format for better quality and smaller file size
%config InlineBackend.figure_format = 'svg'

# 2. Define the Model Architecture 

## 2.1 AlexNet

In [2]:
class AlexNetSmall(nn.Module):
    """Compact AlexNet-style classifier for 3-channel images.

    Five 3x3 convolutions (stride 1, padding 1) with ReLU; the 1st, 2nd and
    5th are followed by 2x2 max pooling, so a 3x32x32 input becomes a
    32x4x4 feature map. A dropout-regularised MLP (512 -> 512 -> 256 ->
    num_classes) turns the flattened features into class logits.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes = 10):
        super().__init__()

        # (in_channels, out_channels, pool_after) per convolution. Expanding this
        # plan yields the same nn.Sequential indices as a hand-written list, so
        # state_dict keys (features.0, features.3, ...) stay stable.
        conv_plan = [
            (3, 32, True),
            (32, 64, True),
            (64, 96, False),
            (96, 64, False),
            (64, 32, True),
        ]
        feature_layers = []
        for c_in, c_out, pool_after in conv_plan:
            feature_layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            feature_layers.append(nn.ReLU(inplace=True))
            if pool_after:
                feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*feature_layers)

        # 32 channels at 4x4 resolution after three 2x poolings of a 32x32 input.
        flat_dim = 32 * 4 * 4
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(flat_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        feature_map = self.features(x)
        return self.classifier(torch.flatten(feature_map, 1))

    def _initialize_weights(self):
        """Kaiming-normal conv weights, N(0, 0.01) linear weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

## 2.2 Basic CNN

In [None]:
class BasicCNN(nn.Module):
    """Four-stage Conv-BatchNorm-ReLU-MaxPool network with an MLP head.

    Channel widths grow 3 -> 32 -> 64 -> 128 -> 256 while each stage halves
    the spatial size, so a 3x32x32 input becomes a 256x2x2 feature map.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes = 10):
        super().__init__()

        # Consecutive pairs of this list give (in_channels, out_channels) per stage.
        # Every stage contributes exactly four modules, so indices match the
        # original hand-written Sequential (conv at 4k, batch norm at 4k + 1).
        widths = [3, 32, 64, 128, 256]
        stage_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stage_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),  # stabilises and speeds up training
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.features = nn.Sequential(*stage_layers)

        # 256 channels at 2x2 after four 2x poolings of a 32x32 input.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(widths[-1] * 2 * 2, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        flat = torch.flatten(self.features(x), 1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Kaiming-normal convs, identity BatchNorm, N(0, 0.01) linears, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

## 2.3 ResNet-18

In [4]:
# Residual block
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block (as used in ResNet-18/34).

    Main path: 3x3 conv (given stride) -> BN -> ReLU -> 3x3 conv -> BN.
    The shortcut is the identity, or a 1x1 strided conv + BN projection when
    the stride or channel count changes. The sum is passed through ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block (times ``expansion``).
        stride: stride of the first convolution and of the projection shortcut.
    """

    # Output-channel multiplier; 1 for the basic block.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        projected_channels = self.expansion * out_channels
        needs_projection = stride != 1 or in_channels != projected_channels
        if needs_projection:
            # Match the main path's shape so the two tensors can be added.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, projected_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(projected_channels),
            )
        else:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu(residual + self.shortcut(x))


# ResNet18 model
class ResNet18(nn.Module):
    """ResNet-18 with a 3x3, stride-1 stem (no max pooling).

    Four stages of two ``ResidualBlock``s each (64/128/256/512 channels;
    stages 2-4 downsample by 2), followed by global average pooling and a
    linear classifier.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes = 10):
        super().__init__()
        # Channel count fed into the next block; make_layer advances it as it builds.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # Stages must be built in this order because make_layer mutates self.in_channels.
        self.layer1 = self.make_layer(64, 2, stride=1)
        self.layer2 = self.make_layer(128, 2, stride=2)
        self.layer3 = self.make_layer(256, 2, stride=2)
        self.layer4 = self.make_layer(512, 2, stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * ResidualBlock.expansion, num_classes)

    def make_layer(self, out_channels, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1.

        Side effect: updates ``self.in_channels`` to the stage's output width.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(ResidualBlock(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * ResidualBlock.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        out = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = torch.flatten(self.avg_pool(out), 1)
        return self.fc(pooled)


# 3. Define the PyTorch Lightning Modules

## 3.1 Load and Prepare Data

In [5]:
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule for CIFAR-10 with a reproducible validation split.

    The official training split is divided into train / validation subsets
    (``1 - val_fraction`` / ``val_fraction``; 90 / 10 by default) using a
    permutation seeded by ``split_seed``. The split is therefore identical
    across ``setup`` calls and runs. Training samples use the selected
    augmentation pipeline. Validation and test samples use only
    ``ToTensor`` + ``Normalize``, so validation metrics are not perturbed by
    random augmentation.

    Args:
        data_dir: directory where CIFAR-10 is downloaded / stored.
        batch_size: batch size for all data loaders.
        num_workers: worker processes per data loader.
        augmentation_strength: ``'basic'`` or ``'strong'``.
        val_fraction: fraction of the training split held out for validation.
        split_seed: seed of the train/validation permutation.

    Raises:
        ValueError: for an unknown ``augmentation_strength`` or a
            ``val_fraction`` outside the open interval (0, 1).
    """

    # Per-channel normalisation constants shared by every split.
    # NOTE(review): these std values are widely copied for CIFAR-10 but are
    # reported to differ from the per-pixel training-set std
    # (~0.247, 0.243, 0.261); kept as-is so results remain comparable.
    MEAN = (0.4914, 0.4822, 0.4465)
    STD = (0.2023, 0.1994, 0.2010)

    def __init__(self, data_dir: str = './data', batch_size: int = 128, num_workers: int = 4,
                 augmentation_strength: str = 'basic', val_fraction: float = 0.1,
                 split_seed: int = 42):
        super().__init__()
        if not 0.0 < val_fraction < 1.0:
            raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction}")
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.augmentation_strength = augmentation_strength
        self.val_fraction = val_fraction
        self.split_seed = split_seed

        if augmentation_strength == 'basic':
            augmentations = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        elif augmentation_strength == 'strong':
            augmentations = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),  # up to +/-15 degrees
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # up to 10% shift per axis
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            ]
        else:
            raise ValueError(f"Unknown augmentation strength: {augmentation_strength}")

        # ToTensor/Normalize hold no per-sample state, so both pipelines can share them.
        to_normalized_tensor = [transforms.ToTensor(), transforms.Normalize(self.MEAN, self.STD)]
        self.transform = transforms.Compose(augmentations + to_normalized_tensor)
        self.test_transform = transforms.Compose(list(to_normalized_tensor))

    def prepare_data(self):
        """Download the CIFAR-10 train and test splits (no transforms applied)."""
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Create the datasets required for ``stage``.

        Args:
            stage: ``'fit'``, ``'validate'``, ``'test'``, ``'predict'`` or ``None``
                (``None`` builds every split).
        """
        if stage in ('fit', 'validate', None):
            # Two views of the same training images: augmented for training and
            # un-augmented for validation (this reads the training files twice).
            train_view = datasets.CIFAR10(self.data_dir, train=True, transform=self.transform)
            val_view = datasets.CIFAR10(self.data_dir, train=True, transform=self.test_transform)

            num_samples = len(train_view)
            train_size = int((1 - self.val_fraction) * num_samples)
            # A dedicated generator keeps the split independent of the global RNG state.
            generator = torch.Generator().manual_seed(self.split_seed)
            order = torch.randperm(num_samples, generator=generator).tolist()
            self.train_dataset = torch.utils.data.Subset(train_view, order[:train_size])
            self.val_dataset = torch.utils.data.Subset(val_view, order[train_size:])

        if stage in ('test', None):
            self.test_dataset = datasets.CIFAR10(self.data_dir, train=False, transform=self.test_transform)

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """Loader over the held-out validation subset (fixed order)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Loader over the official CIFAR-10 test split (fixed order)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)


## 3.2 Define the Base Model with Metric and Hyperparameter Logging

In [6]:
# Define the base model and log the hyperparameters, training, validation, and test metrics, and extra parameters information

# Base model class
class BaseModel(pl.LightningModule):
    """LightningModule that trains one of the CIFAR-10 architectures in this notebook.

    Args:
        model_name: ``'alex_net'``, ``'basic_cnn'`` or ``'res_net_18'``.
        learning_rate: initial learning rate of the optimizer.
        optimizer: ``'Adam'`` (implemented with ``torch.optim.AdamW``) or ``'SGD'``.
        augmentation_strength: only recorded in the hyperparameters and the
            TensorBoard text log; the augmentation itself is applied by the
            data module.

    Raises:
        ValueError: for an unknown ``model_name`` (an unknown ``optimizer`` is
            rejected later, in ``configure_optimizers``).
    """

    def __init__(self,
                 model_name: str,
                 learning_rate: float = 1e-3,
                 optimizer: str = 'Adam',
                 augmentation_strength: str = 'basic'
                 ):
        super().__init__()
        self.learning_rate = learning_rate
        self.model_name = model_name
        self.optimizer = optimizer
        self.augmentation_strength = augmentation_strength

        builders = {
            'alex_net': AlexNetSmall,
            'basic_cnn': BasicCNN,
            'res_net_18': ResNet18,
        }
        if model_name not in builders:
            raise ValueError(f"Unknown model name: {model_name}")
        self.model = builders[model_name](num_classes=10)

        # Store the constructor arguments in self.hparams and the checkpoint/logger.
        self.save_hyperparameters()

        # Separate metric objects per split so their accumulated state never mixes.
        self.train_acc = Accuracy(task='multiclass', num_classes=10)
        self.val_acc = Accuracy(task='multiclass', num_classes=10)
        self.test_acc = Accuracy(task='multiclass', num_classes=10)

    def forward(self, x):
        """Return class logits from the wrapped architecture."""
        return self.model(x)

    def _cosine_t_max(self, default: int = 200) -> int:
        """Epoch horizon for the cosine schedule: the trainer's ``max_epochs`` when known.

        A fixed horizon longer than the run (e.g. 200 for a 20-epoch run) would
        leave the learning rate almost undecayed.
        """
        try:
            trainer = self.trainer
        except RuntimeError:  # recent Lightning raises when the module is not attached
            trainer = None
        max_epochs = getattr(trainer, 'max_epochs', None)
        if isinstance(max_epochs, int) and max_epochs > 0:
            return max_epochs
        return default

    def configure_optimizers(self):
        """Build the optimizer and a per-epoch cosine-annealing LR schedule.

        Raises:
            ValueError: if ``hparams.optimizer`` is neither ``'Adam'`` nor ``'SGD'``.
        """
        if self.hparams.optimizer == 'Adam':
            # NOTE: 'Adam' is implemented with AdamW (decoupled weight decay,
            # PyTorch default weight_decay=0.01), as in the original experiments.
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        elif self.hparams.optimizer == 'SGD':
            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        else:
            raise ValueError(f"Unknown optimizer: {self.hparams.optimizer}")
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self._cosine_t_max())
        return {
            "optimizer": optimizer,
            # CosineAnnealingLR is stepped once per epoch; it needs no monitored metric.
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
        }

    def _log_extra_params(self) -> None:
        """Write the run's LR / optimizer / augmentation to the logger once.

        Silently skipped when there is no logger or the logger's experiment is
        not TensorBoard-like; the values are still in the saved hyperparameters.
        """
        logger = self.logger
        experiment = getattr(logger, 'experiment', None) if logger is not None else None
        if experiment is None or not hasattr(experiment, 'add_text'):
            return
        step = self.global_step
        experiment.add_scalar('parameters/learning_rate', self.learning_rate, step)
        experiment.add_text('parameters/optimizer', self.optimizer, step)
        experiment.add_text('parameters/augmentation_strength', self.hparams.augmentation_strength, step)

    def on_fit_start(self) -> None:
        self._log_extra_params()

    def on_test_start(self) -> None:
        self._log_extra_params()

    def _shared_step(self, batch: Tuple[torch.Tensor, torch.Tensor], metric: Accuracy) -> torch.Tensor:
        """Compute the cross-entropy loss for ``batch`` and update ``metric``."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Accuracy takes the argmax over classes, so raw logits give the same
        # result as softmax probabilities.
        metric(logits, y)
        return loss

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Return the training loss and log the loss and accuracy."""
        loss = self._shared_step(batch, self.train_acc)
        self.log('train_loss', loss, prog_bar=True)
        # Logging the Metric object lets Lightning handle compute() / reset().
        self.log('train_acc', self.train_acc, prog_bar=True)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Return the validation loss and log the loss and accuracy."""
        loss = self._shared_step(batch, self.val_acc)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_acc, prog_bar=True)
        return loss

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Return the test loss and log the loss and accuracy."""
        loss = self._shared_step(batch, self.test_acc)
        self.log('test_loss', loss)
        self.log('test_acc', self.test_acc)
        return loss

# 4. Train Model

In [7]:
def train_model(
        model_name: str,
        learning_rate: float = 1e-3,
        optimizer: str = 'Adam',
        augmentation_strength: str = 'basic',
        max_epochs: int = 20,
        seed: int = 42,
):
    """Train one configuration, then evaluate its best checkpoint on the test split.

    Args:
        model_name: architecture key understood by ``BaseModel``.
        learning_rate: initial learning rate.
        optimizer: ``'Adam'`` or ``'SGD'``.
        augmentation_strength: ``'basic'`` or ``'strong'`` training augmentation.
        max_epochs: maximum number of training epochs.
        seed: global seed passed to ``pl.seed_everything`` for reproducibility.

    Returns:
        The list of metric dicts returned by ``trainer.test``.
    """
    # Without seeding, `deterministic=True` below does not make runs reproducible.
    pl.seed_everything(seed, workers=True)

    data_module = CIFAR10DataModule(augmentation_strength=augmentation_strength)

    # Pass augmentation_strength through so it is recorded correctly in the
    # hyperparameters and logs (otherwise it always shows the default 'basic').
    model = BaseModel(
        model_name=model_name,
        learning_rate=learning_rate,
        optimizer=optimizer,
        augmentation_strength=augmentation_strength,
    )

    version = f"lr_{learning_rate}_opt_{optimizer}_aug_{augmentation_strength}"

    # One TensorBoard run per (model, hyperparameter) combination.
    logger = TensorBoardLogger('lightning_logs', name=model_name, version=version)

    # Keep the 3 best checkpoints by validation loss, in a per-run directory so
    # runs of the same model with different hyperparameters do not collide.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=f'checkpoints/{model_name}/{version}',
        filename='{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min',
    )

    # Stop when validation loss has not improved for 10 consecutive epochs.
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, mode='min')

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='auto',
        devices="auto",
        strategy="auto",
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping],
        deterministic=True,  # reproducible together with seed_everything above; may be slower
    )

    trainer.fit(model, datamodule=data_module)

    # Evaluate the best checkpoint (lowest val_loss) rather than the last-epoch weights.
    return trainer.test(model, datamodule=data_module, ckpt_path='best')


# 5. Main Process

In [None]:
def main():
    """Grid-search every architecture / learning rate / optimizer / augmentation combination."""
    param_grid = {
        'models': ['basic_cnn', 'alex_net', 'res_net_18'],
        'learning_rate': [1e-2, 1e-3, 1e-4],
        'optimizer': ['Adam', 'SGD'],
        'augmentation_strength': ['basic', 'strong'],
    }

    # Enumerate all combinations up front; the ordering matches nested loops
    # with models outermost and augmentation strength innermost.
    runs = [
        (model_name, lr, optimizer, augmentation_strength)
        for model_name in param_grid['models']
        for lr in param_grid['learning_rate']
        for optimizer in param_grid['optimizer']
        for augmentation_strength in param_grid['augmentation_strength']
    ]

    for model_name, lr, optimizer, augmentation_strength in runs:
        print(
            f"Training model: {model_name}, lr: {lr}, optimizer: {optimizer}, augmentation_strength: {augmentation_strength}")
        train_model(model_name, lr, optimizer, augmentation_strength)


if __name__ == "__main__":
    # Guard the entry point so importing this code as a module does not start
    # the grid search; in a Jupyter kernel __name__ is "__main__", so the cell
    # still runs it.
    main()

![Training loss](./train_loss_all.png)

![Validation loss](./val_loss_all.png)

![Training accuracy](./train_acc_all.png)

![Validation accuracy](./val_acc_all.png)