In [1]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [2]:
# Baseline MLP for MNIST, kept for reference only (disabled; the ResNet class below is the model actually trained)
# model = nn.Sequential(
#     nn.Linear(28 * 28, 64),
#     nn.ReLU(),
#     nn.Linear(64, 64),
#     nn.ReLU(),
#     nn.Linear(64, 10),
# )

In [3]:
# ResNet Model
class ResNet(nn.Module):
    """Small fully connected classifier with a single skip connection.

    The network maps a flattened 28 * 28 input (784 features) to 10 raw
    class scores. The output of the second hidden layer is summed with the
    output of the first hidden layer before dropout and the final
    projection.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 64)  # input -> hidden
        self.l2 = nn.Linear(64, 64)       # hidden -> hidden
        self.l3 = nn.Linear(64, 10)       # hidden -> class scores
        self.do = nn.Dropout(0.1)

    def forward(self, x):
        """Return unnormalised class scores (logits) for the batch ``x``."""
        first_hidden = torch.relu(self.l1(x))
        second_hidden = torch.relu(self.l2(first_hidden))
        # Skip connection: add the first hidden activation back in
        # before regularising with dropout.
        merged = second_hidden + first_hidden
        return self.l3(self.do(merged))


model = ResNet()

In [4]:
# Optimizer Definition
# Collect the trainable parameters once and hand that same collection to the
# optimiser. A list (rather than the bare generator returned by
# model.parameters()) can be iterated again later without being exhausted.
params = list(model.parameters())
# Plain stochastic gradient descent with a fixed learning rate of 0.01.
optimiser = optim.SGD(params, lr=1e-2)

In [5]:
# Objective: cross-entropy between the raw logits and the integer class labels,
# averaged over the batch (the default reduction, spelled out here).
loss = nn.CrossEntropyLoss(reduction="mean")

In [6]:
# Train / Val Data Split
# Download (if needed) the MNIST training set into ./data; each image is
# converted to a tensor by ToTensor().
train_data = datasets.MNIST('data', train = True, download = True, transform = transforms.ToTensor())
# Randomly hold out 5000 samples for validation and train on the other 55000.
train, val = random_split(train_data, [55000, 5000])
# Reshuffle the training samples at the start of every epoch so SGD does not
# see the exact same mini-batches in the exact same order each time.
train_loader = DataLoader(train, batch_size = 32, shuffle = True)
# Validation order does not affect the reported loss, so no shuffling here.
val_loader = DataLoader(val, batch_size = 32)

In [7]:
# Training Loops
# Each epoch makes one optimisation pass over train_loader followed by one
# evaluation pass over val_loader, printing the mean per-batch metrics.
nb_epochs = 5
for epoch in range(nb_epochs):
  losses = list()
  accuracies = list()
  # Training mode: dropout is active while the weights are being updated.
  model.train()

  for batch in train_loader:
    x, y = batch

    # x: b x 1 x 28 x 28 -> flatten every image to one row of features
    b = x.size(0)
    x = x.view(b, -1)

    # Step 1: Forward
    logit = model(x)

    # Step 2: Compute the objective function
    J = loss(logit, y)

    # Step 3: Cleaning the gradients
    model.zero_grad()

    # Step 4: Compute the partial derivatives of J with respect to parameters
    J.backward()

    # Step 5: Step in the opposite direction of the gradient
    optimiser.step()

    losses.append(J.item())
    # Fraction of samples in this batch whose arg-max class equals the label,
    # stored as a plain Python float like the loss values.
    accuracies.append(y.eq(logit.detach().argmax(dim=1)).float().mean().item())

  print(f"Epoch {epoch + 1}, Loss: {torch.tensor(losses).mean():.2f}")
  print(f"Epoch {epoch + 1}, Accuracy: {torch.tensor(accuracies).mean():.2f}")

  # Validation
  losses = list()
  # Evaluation mode: switches dropout off so the validation loss is
  # deterministic and reflects the full network. model.train() at the top of
  # the next epoch switches it back on.
  model.eval()
  # No gradients are needed for evaluation, so skip building the autograd graph
  # for the forward pass and the loss alike.
  with torch.no_grad():
    for batch in val_loader:
      x, y = batch

      # x: b x 1 x 28 x 28 -> flatten every image to one row of features
      b = x.size(0)
      x = x.view(b, -1)

      # Step 1: Forward
      logit = model(x)

      # Step 2: Compute the objective function
      J = loss(logit, y)

      losses.append(J.item())

  print(f"Epoch {epoch + 1}, Val Loss: {torch.tensor(losses).mean():.2f}")

Epoch 1, Loss: 0.82
Epoch 1, Accuracy: 0.79
Epoch 1, Val Loss: 0.44
Epoch 2, Loss: 0.38
Epoch 2, Accuracy: 0.89
Epoch 2, Val Loss: 0.36
Epoch 3, Loss: 0.32
Epoch 3, Accuracy: 0.91
Epoch 3, Val Loss: 0.30
Epoch 4, Loss: 0.28
Epoch 4, Accuracy: 0.92
Epoch 4, Val Loss: 0.27
Epoch 5, Loss: 0.24
Epoch 5, Accuracy: 0.93
Epoch 5, Val Loss: 0.25
