In [1]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [8]:
# Train, val split
# Seed the global torch RNG so the weight initialisation and the DataLoader
# shuffling order are reproducible under "Restart & Run All".
SEED = 0
VAL_SIZE = 5000
torch.manual_seed(SEED)

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
# Derive the training size from the dataset so the two lengths always sum to
# len(train_data) (55000 / 5000 for MNIST); a dedicated generator keeps the
# split itself fixed regardless of any other RNG use.
train, val = random_split(
    train_data,
    [len(train_data) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SEED),
)
# Reshuffle the training set every epoch: without shuffle=True the model sees
# the identical batch sequence each epoch. Validation order does not matter.
train_loader = DataLoader(train, batch_size=64, shuffle=True)
val_loader = DataLoader(val, batch_size=32)


In [9]:
# Model Architecture
# Baseline: a plain MLP, 784 -> 64 -> 64 -> 10, with ReLU between the linear layers.
# NOTE(review): `model` is rebound to the residual network further down, so as
# the notebook stands this baseline is never trained.
model = nn.Sequential(*[
    nn.Linear(28 * 28, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 10),
])

In [10]:
# Defining a more flexible model
class ResNet(nn.Module): # --> Model with residual connections
    """Three-layer MLP with one residual (skip) connection around the second layer.

    forward computes:
        h1 = relu(l1(x)); h2 = relu(l2(h1)); logits = l3(dropout(h1 + h2))
    The skip sum `h1 + h2` is why l2 maps hidden -> hidden.

    Args:
        in_features: size of the flattened input (28 * 28 for MNIST).
        hidden: width of both hidden layers.
        n_classes: number of output logits.
        dropout: dropout probability applied to the residual sum.
    """

    def __init__(self, in_features=28 * 28, hidden=64, n_classes=10, dropout=0.1):
        super().__init__()
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, n_classes)
        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Map inputs whose last dimension is `in_features` to `n_classes` logits."""
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        do = self.do(h2 + h1)  # residual: second layer's output plus its input
        logits = self.l3(do)
        return logits

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet().to(device)

In [11]:
# Optimizer
LEARNING_RATE = 1e-2
# Materialise the parameters: `model.parameters()` is a one-shot generator, and
# SGD consumes it, so binding the generator to `params` would leave an
# already-exhausted iterator behind for any later cell that reuses the name.
params = list(model.parameters())
# Plain SGD (no momentum, no weight decay) over every parameter of `model`.
optimizer = optim.SGD(params, lr=LEARNING_RATE)

In [12]:
# Criterion (note: `loss` names the criterion object, not a loss value).
# CrossEntropyLoss fuses log-softmax and negative log-likelihood: it takes raw
# logits plus integer class targets and averages the result over the batch.
loss = nn.CrossEntropyLoss(reduction='mean')

In [13]:
# Training and Validation

def run_epoch(net, loader, criterion, opt=None):
    """Make one pass over `loader` and return (mean loss, accuracy) per sample.

    If `opt` is given, `net` is put in train mode (Dropout active) and updated
    with one optimizer step per batch. Otherwise it is put in eval mode and
    evaluated with gradient tracking disabled.

    Both metrics are weighted by batch size, so a short final batch does not
    skew them. This assumes `criterion` averages over the batch (the
    CrossEntropyLoss default). Returns (nan, nan) for an empty loader.
    """
    training = opt is not None
    net.train(training)
    # Follow the model to whatever device it lives on instead of hard-coding .cuda().
    device = next(net.parameters()).device

    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            # x: b * 1 * 28 * 28 -> flattened to b * 784 for the linear layers
            b = x.size(0)
            x = x.view(b, -1).to(device)
            y = y.to(device)

            logits = net(x)           # 1: forward prop
            J = criterion(logits, y)  # 2: batch-mean objective

            if training:
                opt.zero_grad()  # 3: clear gradients left from the previous step
                J.backward()     # 4: accumulate dJ/dparams
                opt.step()       # 5: step in the opposite direction of the gradient

            loss_sum += J.item() * b
            correct += (logits.argmax(dim=1) == y).sum().item()
            seen += b

    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, correct / seen


epochs = 5
for epoch in range(epochs):
    train_loss, train_acc = run_epoch(model, train_loader, loss, optimizer)
    print(f'Epoch {epoch+1}', end=':\n')
    print(f'Training Loss: {train_loss:.2f}', end=' | ')
    print(f'Training Accuracy: {train_acc:.2f}')

    val_loss, val_acc = run_epoch(model, val_loader, loss)
    print(f'Validation Loss: {val_loss:.2f}', end=' | ')
    print(f'Validation Accuracy: {val_acc:.2f}\n')

Epoch 1:
Training Loss: 1.18 | Training Accuracy: 0.70
Validation Loss: 0.54 | Validation Accuracy: 0.87

Epoch 2:
Training Loss: 0.49 | Training Accuracy: 0.86
Validation Loss: 0.38 | Validation Accuracy: 0.89

Epoch 3:
Training Loss: 0.39 | Training Accuracy: 0.89
Validation Loss: 0.33 | Validation Accuracy: 0.90

Epoch 4:
Training Loss: 0.34 | Training Accuracy: 0.90
Validation Loss: 0.30 | Validation Accuracy: 0.91

Epoch 5:
Training Loss: 0.31 | Training Accuracy: 0.91
Validation Loss: 0.28 | Validation Accuracy: 0.92

