In [None]:
!pip install lightning
!rm logs -rf

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import lightning as pl
from torchmetrics import Accuracy
from pytorch_lightning.loggers import TensorBoardLogger

# Define the LightningModule
class MNISTClassifier(pl.LightningModule):
    """Fully connected MNIST classifier: 784 -> 128 -> 128 -> 10 with sigmoid activations.

    Inputs to ``forward`` must already be flattened to 784 features per sample
    (the first layer is ``nn.Linear(784, 128)``). Training minimises
    cross-entropy with SGD (momentum 0.9). During validation the logits and
    labels of every batch are buffered, and loss and accuracy are computed once
    over the whole buffer at the end of the validation epoch.

    Args:
        learning_rate: Learning rate handed to the SGD optimizer.
    """

    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 128),
            nn.Sigmoid(),
            nn.Linear(128, 128),
            nn.Sigmoid(),
            nn.Linear(128, 10),
        )
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        # Per-epoch validation buffers: filled in validation_step and emptied
        # again in on_validation_epoch_end.
        self.val_logits = []
        self.val_labels = []
        self.val_accuracy = Accuracy('multiclass', num_classes=10)

    def forward(self, x):
        """Return the raw class logits for a batch of inputs."""
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the cross-entropy loss of one batch."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Buffer the logits and labels of one validation batch.

        No per-batch loss is computed here: the loss and accuracy are evaluated
        once over all buffered batches in ``on_validation_epoch_end``.
        """
        x, y = batch
        self.val_logits.append(self(x))
        self.val_labels.append(y)

    def on_validation_epoch_end(self):
        """Log ``val_loss`` and ``val_acc`` over all buffered batches, then clear the buffers."""
        # Nothing was buffered: torch.cat would raise on an empty list.
        if not self.val_logits:
            return

        # Concatenate all stored logits and labels
        val_logits = torch.cat(self.val_logits, dim=0)
        val_labels = torch.cat(self.val_labels, dim=0)

        # Mean-reduced cross-entropy over the concatenated tensors is the
        # average loss over every buffered sample.
        avg_val_loss = self.loss_fn(val_logits, val_labels)

        # Calculate accuracy using torchmetrics
        avg_val_acc = self.val_accuracy(val_logits, val_labels)

        # Log the average validation loss and accuracy
        self.log('val_loss', avg_val_loss, prog_bar=True)
        self.log('val_acc', avg_val_acc, prog_bar=True)

        # Clear the buffers and metric state for the next epoch
        self.val_logits.clear()
        self.val_labels.clear()
        self.val_accuracy.reset()

    def configure_optimizers(self):
        """Return an SGD optimizer with momentum 0.9 and the configured learning rate."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate, momentum=0.9)

def train_and_log(model, model_name, train_dataloader=None, val_dataloader=None,
                  max_epochs=5, accelerator="auto"):
    """Fit ``model`` and log the run to TensorBoard.

    Args:
        model: The LightningModule to train.
        model_name: Experiment name passed to the TensorBoard logger, whose
            save directory is ``logs/``.
        train_dataloader: Training data. When omitted, the module-level
            ``train_loader`` is used, as in the original notebook.
        val_dataloader: Validation data. When omitted, the module-level
            ``val_loader`` is used.
        max_epochs: Number of epochs to train for.
        accelerator: Accelerator flag forwarded to the Trainer. ``"auto"``
            lets Lightning pick the available hardware, so the notebook also
            runs on machines without a GPU.
    """
    # Take the logger from the same `lightning` package that provides
    # `pl.Trainer`; the notebook's module-level import comes from the separate
    # `pytorch_lightning` package and is re-imported by several cells.
    from lightning.pytorch.loggers import TensorBoardLogger

    # Fall back to the notebook-level loaders only when none were passed in.
    if train_dataloader is None:
        train_dataloader = train_loader
    if val_dataloader is None:
        val_dataloader = val_loader

    tb_logger = TensorBoardLogger(save_dir="logs/", name=model_name)

    trainer = pl.Trainer(max_epochs=max_epochs,
                         accelerator=accelerator,
                         devices=1,
                         log_every_n_steps=1000,
                         logger=tb_logger)

    # Train the model
    trainer.fit(model, train_dataloader, val_dataloader)

# Data preparation for the fully connected model: convert to tensors,
# normalise with mean 0.5 / std 0.5, then flatten each image into one vector
# (the classifier's first layer expects 784 features).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
        transforms.Lambda(lambda image: image.flatten()),
    ]
)

# MNIST training and test splits, downloaded to ./mnist on first use.
train_dataset = MNIST(root='./mnist', train=True, transform=transform, download=True)
val_dataset = MNIST(root='./mnist', train=False, transform=transform, download=True)

# Only the training data is shuffled; validation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Build the fully connected model with its default learning rate and train it;
# the run is logged under the experiment name 'fcc'.
model = MNISTClassifier()
train_and_log(model, model_name='fcc')

In [None]:
# Data preparation for the convolutional models: same tensor conversion and
# normalisation as before, but without a flatten step.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Rebuild the datasets so they pick up the new transform.
train_dataset = MNIST(root='./mnist', train=True, transform=transform, download=True)
val_dataset = MNIST(root='./mnist', train=False, transform=transform, download=True)

# Rebinding train_loader / val_loader here changes what train_and_log trains on
# by default, because it reads these module-level names at call time.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

### exercise: Convolutional network


In [None]:
from pytorch_lightning.loggers import TensorBoardLogger

class MNISTCNN(MNISTClassifier):
    """Small convolutional MNIST classifier reusing MNISTClassifier's train/val logic.

    Architecture: three 3x3 conv + ReLU blocks (1 -> 32 -> 32 -> 32 channels),
    2x2 max pooling after the second and after the third block, a 1x1 conv
    down to 10 class channels, and a mean over the two spatial dimensions
    that turns the class maps into per-sample logits.

    Args:
        learning_rate: Learning rate handed to the inherited SGD optimizer.
    """

    def __init__(self, learning_rate=0.001):
        # Start the constructor chain *after* MNISTClassifier on purpose, so the
        # fully connected `layers` of the parent are never built.
        super(MNISTClassifier, self).__init__()

        def conv3x3_relu(in_channels, out_channels):
            """3x3 convolution that keeps the spatial size (padding 1), followed by ReLU."""
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            )

        # edit the following to create layers
        # conv 3x3, 1 -> 32 + ReLU
        self.stem = conv3x3_relu(1, 32)
        # conv 3x3, 32 -> 32 + ReLU
        self.layer0 = conv3x3_relu(32, 32)
        # conv 3x3, 32 -> 32 + ReLU
        self.layer1 = conv3x3_relu(32, 32)
        # conv 1x1, 32 -> 10, no activation function
        self.fc = nn.Sequential(
            nn.Conv2d(32, 10, kernel_size=1, stride=1, padding=0),
        )
        # END OF edit the following to create layers

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # State the inherited training/validation hooks rely on; it has to be
        # set here because MNISTClassifier.__init__ is skipped above.
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.val_logits = []
        self.val_labels = []
        self.val_accuracy = Accuracy('multiclass', num_classes=10)

    def forward(self, x):
        """Return class logits for a batch of single-channel images.

        The mean over dimensions 2 and 3 implies a 4-D input, and the first
        convolution takes one input channel.
        """
        features = self.stem(x)
        features = self.max_pool(self.layer0(features))
        features = self.max_pool(self.layer1(features))
        class_maps = self.fc(features)
        # Average each class map over its spatial dimensions.
        return class_maps.mean(dim=(2, 3))

# Train the plain convolutional model; the run is logged under the name 'cnn'.
model = MNISTCNN(learning_rate=0.001)
train_and_log(model, model_name='cnn')

In [None]:
class MNISTBNCNN(MNISTCNN):
    """MNISTCNN variant with batch normalisation after every 3x3 convolution.

    Only the layers are redefined; the forward pass is inherited from MNISTCNN
    and uses ``stem``, ``layer0``, ``layer1``, ``fc`` and ``max_pool``.

    Args:
        learning_rate: Learning rate handed to the inherited SGD optimizer.
    """

    def __init__(self, learning_rate=0.001):
        # Start the constructor chain after MNISTClassifier so neither parent
        # builds its own layers; everything the model needs is created below.
        super(MNISTClassifier, self).__init__()

        # edit the following to create layers
        # conv 3x3, 1 -> 32 + BatchNorm + ReLU
        # add nn.BatchNorm2d right after conv2d
        self.stem = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
                                  nn.BatchNorm2d(32),
                                  nn.ReLU()
                                  )
        # conv 3x3, 32 -> 32 + BatchNorm + ReLU
        # add nn.BatchNorm2d right after conv2d
        self.layer0 = nn.Sequential(nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU()
                                    )

        # conv 3x3, 32 -> 32 + BatchNorm + ReLU
        # add nn.BatchNorm2d right after conv2d
        self.layer1 = nn.Sequential(nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU()
                                    )

        # conv 1x1, 32 -> 10, no normalisation and no activation function:
        # its output is averaged into the logits by the inherited forward pass.
        self.fc = nn.Sequential(nn.Conv2d(32, 10, kernel_size=1, stride=1, padding=0))
        # END OF edit the following to create layers

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # State required by the inherited training/validation hooks.
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.val_logits = []
        self.val_labels = []
        self.val_accuracy = Accuracy('multiclass', num_classes=10)

# Train the batch-norm variant; the run is logged under the name 'cnn_batchnorm'.
model = MNISTBNCNN(learning_rate=0.001)
train_and_log(model, model_name='cnn_batchnorm')

**exercise**: skip connection

Hint: add a block's input to its output, i.e. `x = x + f(x)`.

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger

class MNISTSKIP(MNISTCNN):
    """MNISTCNN with a skip (residual) connection around ``layer0``.

    The layers are inherited unchanged; only the forward pass differs.
    ``layer1`` is still applied without a skip connection.
    """

    def forward(self, x):
        """Return class logits, adding the stem output to the output of ``layer0``."""
        # edit the following to make skip connection
        stem_out = self.stem(x)
        # Residual form x + f(x): layer0 keeps channels and spatial size, so
        # its input and output can be summed element-wise.
        residual_out = stem_out + self.layer0(stem_out)
        pooled = self.max_pool(residual_out)
        pooled = self.max_pool(self.layer1(pooled))
        class_maps = self.fc(pooled)
        # END OF edit the following to make skip connection
        return class_maps.mean(dim=(2, 3))

# Train the skip-connection variant; the run is logged under the name 'skip'.
model = MNISTSKIP(learning_rate=0.001)
train_and_log(model, model_name='skip')

In [None]:
# Optional: uncomment the next line if TensorBoard is not installed.
#!pip install tensorboard
# Load the TensorBoard notebook extension so the dashboard can render inline.
%load_ext tensorboard

# NOTE(review): neither of the two imports below is used in this cell.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
# Open the dashboard on the logs/ directory.
%tensorboard --logdir logs/