In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score

In [None]:
# --- Run configuration -------------------------------------------------------

# Optimisation hyper-parameters.
LEARNING_RATE = 0.01
EPOCHS = 10
BATCH_SIZE = 512

# Use the GPU when PyTorch can see one; otherwise everything runs on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def load_data(batch_size=BATCH_SIZE, test_batch_size=1000, data_dir='./datafiles/'):
    """Download MNIST (if not already present) and wrap both splits in data loaders.

    Args:
        batch_size: Number of samples per training batch.
        test_batch_size: Number of samples per evaluation batch.
        data_dir: Directory the MNIST files are downloaded to and read from.

    Returns:
        Tuple ``(train_data, test_data, train_loader, test_loader)``.
    """
    # Training pipeline: tensor conversion + normalization only. No augmentation
    # is applied; any future augmentation belongs in this list (training split only).
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])

    # Test pipeline: same normalization constants as training, never augmented.
    test_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])

    # Load the dataset and apply the transformations
    train_data = torchvision.datasets.MNIST(data_dir, train=True, download=True, transform=train_transform)
    test_data = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=test_transform)

    # Note: Iterating through the dataloader yields batches of (inputs, targets)
    # Where inputs is a torch.Tensor of shape (B, 1, 28, 28) and targets is a torch.Tensor of shape (B,)

    # Only the training split is shuffled; evaluation order is left as stored.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size)

    return train_data, test_data, train_loader, test_loader

In [None]:
def plot_transformations(train_data, start_index=1000):
    """Show a 4x5 grid of consecutive transformed samples with their labels.

    Args:
        train_data: Indexable dataset whose items are ``(image, label)`` pairs.
            Each image must squeeze down to something ``imshow`` can draw.
        start_index: Index of the first sample shown; the following 19 samples
            fill the rest of the grid. Defaults to 1000.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    _, axs = plt.subplots(4, 5, figsize=(5, 6))

    for sample_index, ax in enumerate(axs.flatten(), start=start_index):
        image, label = train_data[sample_index]

        # Drop the singleton channel dimension so imshow receives a 2-D image.
        ax.imshow(image.squeeze(), cmap="gray")
        ax.set_title(f"Label: {label}")
        ax.axis("off")
    plt.show()

In [None]:
class MyVGG(nn.Module):
    """Small VGG-style CNN: three conv/ReLU/max-pool stages followed by a two-layer head.

    The classifier's input width (8 * 3 * 3) matches 1x28x28 inputs, which the
    three stages reduce to 8x3x3 feature maps. The network outputs 10 scores
    per sample.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return the layers of one stage: 3x3 conv (padding 1), ReLU, 2x2 max-pool."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

    def __init__(self):
        super().__init__()

        # Feature extractor: the image gets smaller but deeper at every stage.
        #   1x28x28 --> 2x14x14 --> 4x7x7 --> 8x3x3
        channel_plan = [(1, 2), (2, 4), (4, 8)]
        stage_layers = []
        for c_in, c_out in channel_plan:
            stage_layers.extend(self._conv_stage(c_in, c_out))
        self.features = nn.Sequential(*stage_layers)

        # Prediction head: flattened 8x3x3 features --> 30 hidden units --> 10 classes
        flat_width = 8 * 3 * 3
        self.classifier = nn.Sequential(
            nn.Linear(flat_width, 30),
            nn.ReLU(),
            nn.Linear(30, 10),
        )

    def forward(self, x):
        """Map a batch of images to a ``[batch_size, 10]`` tensor of class scores."""
        feature_maps = self.features(x)
        # Collapse [batch_size, channels, height, width] into [batch_size, channels * height * width]
        batch_size = feature_maps.size(0)
        flat_features = feature_maps.view(batch_size, -1)
        return self.classifier(flat_features)

In [None]:
def train(model, train_loader, loss_fn, optimizer, epoch=-1):
    """Run one optimisation pass over ``train_loader`` and report loss and accuracy.

    Args:
        model: Network to optimise; it is moved to ``DEVICE`` and put in train mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        optimizer: Optimizer stepping the model's parameters.
        epoch: Zero-based epoch index, used only in the printed summary
            (shown as ``epoch + 1``).

    Returns:
        Tuple ``(accuracy, loss)``. ``loss`` is the sum of the per-batch loss
        values divided by the number of batches.
    """
    model = model.to(DEVICE)
    model.train()

    running_loss = 0
    predicted_labels, true_labels = [], []

    for batch_inputs, batch_targets in train_loader:
        # Clear gradients left over from the previous step, then forward / backward / update.
        optimizer.zero_grad()
        batch_outputs = model(batch_inputs.to(DEVICE))
        batch_loss = loss_fn(batch_outputs, batch_targets.to(DEVICE))
        batch_loss.backward()
        optimizer.step()

        # Accumulate plain Python values for the end-of-epoch statistics.
        running_loss += batch_loss.item()
        batch_predictions = torch.argmax(batch_outputs, dim=-1)
        predicted_labels += batch_predictions.detach().cpu().tolist()
        true_labels += batch_targets.cpu().tolist()

    accuracy = accuracy_score(true_labels, predicted_labels)
    mean_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch + 1} --> Train loss = {mean_loss:.2f}, Train accuracy = {accuracy * 100:.3f}%")
    return accuracy, mean_loss

In [None]:
def test(model, test_loader, loss_fn, epoch=-1):
    """Evaluate ``model`` on ``test_loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is moved to ``DEVICE`` and put in eval mode.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        epoch: Zero-based epoch index, used only in the printed summary
            (shown as ``epoch + 1``).

    Returns:
        Tuple ``(accuracy, loss)``. ``loss`` is the sum of the per-batch loss
        values divided by the number of batches.
    """
    model = model.to(DEVICE)
    model.eval()

    running_loss = 0
    predicted_labels, true_labels = [], []

    for batch_inputs, batch_targets in test_loader:
        # Gradient tracking is disabled for the forward pass and the loss.
        with torch.no_grad():
            batch_outputs = model(batch_inputs.to(DEVICE))
            batch_loss = loss_fn(batch_outputs, batch_targets.to(DEVICE))

        # Accumulate plain Python values for the end-of-epoch statistics.
        running_loss += batch_loss.item()
        batch_predictions = torch.argmax(batch_outputs, dim=-1)
        predicted_labels += batch_predictions.detach().cpu().tolist()
        true_labels += batch_targets.cpu().tolist()

    accuracy = accuracy_score(true_labels, predicted_labels)
    mean_loss = running_loss / len(test_loader)
    print(f"Epoch {epoch + 1} --> Test loss = {mean_loss:.2f}, Test accuracy = {accuracy * 100:.3f}%")
    return accuracy, mean_loss

In [None]:
# Build the datasets and loaders once; the cells below reuse these four names.
(train_data, test_data, train_loader, test_loader) = load_data()

# Eyeball a grid of transformed training samples before starting to train.
plot_transformations(train_data)

In [None]:
# Seed the global RNG before the model is built so its initial weights are reproducible.
torch.manual_seed(0)

# Move the model to its target device *before* handing its parameters to the
# optimizer, so the optimizer is constructed over parameters that already live
# where training will happen.
model = MyVGG().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-8)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Per-epoch history, filled by the loop below and plotted afterwards.
train_losses = []
test_losses = []
train_metrics = []
test_metrics = []

for epoch in range(EPOCHS):
    # One optimisation pass, then an evaluation pass on the test split.
    train_acc, train_loss = train(model, train_loader, loss_fn, optimizer, epoch)
    test_acc, test_loss = test(model, test_loader, loss_fn, epoch)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_metrics.append(train_acc)
    test_metrics.append(test_acc)

fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: loss curves.
axs[0].plot(train_losses, c="r", label="Train loss")
axs[0].plot(test_losses, c="b", label="Test loss")
axs[0].legend()
axs[0].set_xlabel("Epochs")
axs[0].set_ylabel("Loss")

# Right panel: accuracy curves.
axs[1].plot(train_metrics, "o-", c="r", label="Train accuracy")
axs[1].plot(test_metrics, "o-", c="b", label="Test accuracy")
axs[1].legend()
axs[1].set_xlabel("Epochs")
axs[1].set_ylabel("Accuracy")

plt.show()