# VGG

This notebook accompanies the blog post [VGG: Deep but Simple](https://derekzhouai.github.io/posts/building-vgg-for-mnist/). It builds a VGG-11 network and trains it on the FashionMNIST dataset.

## Model Architecture

In [None]:
import torch.nn as nn

class VGGBlock(nn.Module):
    """One VGG stage: ``num_convs`` x (3x3 conv + ReLU) followed by a 2x2 max-pool.

    Every convolution uses ``kernel_size=3, padding=1``; the trailing
    max-pool uses ``kernel_size=2, stride=2``.

    Args:
        in_channels: Channels of the incoming feature map (consumed by the
            first convolution only).
        out_channels: Channels produced by every convolution in the block.
        num_convs: How many conv + ReLU pairs to stack before pooling.
    """

    def __init__(self, in_channels, out_channels, num_convs):
        super().__init__()
        # Channel width before each conv and after the last one; only the
        # first conv maps in_channels -> out_channels, the rest keep the width.
        widths = [in_channels] + [out_channels] * num_convs
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
            ]
        stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv/ReLU stack and the final max-pool."""
        return self.block(x)

In [None]:
class VGG11(nn.Module):
    """VGG-11 classifier: five conv blocks (8 conv layers) + 3 linear layers.

    The five blocks each end in a 2x2 max-pool, so a 224x224 input is reduced
    to 7x7 with 512 channels. The first linear layer is sized for exactly
    512 * 7 * 7 = 25088 features, so inputs must be 224x224.

    Args:
        num_classes: Size of the final output layer (number of class scores).
        in_channels: Channels of the input image. Defaults to 1 (grayscale),
            which matches the original hard-coded behaviour.
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        self.net = nn.Sequential(
            # Block 1: 1 x (conv + ReLU) + max-pool
            VGGBlock(in_channels, 64, 1),  # Cx224x224 → 64x112x112

            # Block 2: 1 x (conv + ReLU) + max-pool
            VGGBlock(64, 128, 1),    # 64x112x112 → 128x56x56

            # Block 3: 2 x (conv + ReLU) + max-pool
            VGGBlock(128, 256, 2),   # 128x56x56 → 256x28x28

            # Block 4: 2 x (conv + ReLU) + max-pool
            VGGBlock(256, 512, 2),   # 256x28x28 → 512x14x14

            # Block 5: 2 x (conv + ReLU) + max-pool
            VGGBlock(512, 512, 2),   # 512x14x14 → 512x7x7

            nn.Flatten(),            # 512x7x7 → 25088

            # Fully connected layer 1 (with ReLU and dropout)
            nn.Linear(25088, 4096),  # 25088 → 4096
            nn.ReLU(),
            nn.Dropout(0.5),

            # Fully connected layer 2 (with ReLU and dropout)
            nn.Linear(4096, 4096),   # 4096 → 4096
            nn.ReLU(),
            nn.Dropout(0.5),

            # Output layer: raw class scores, no activation
            nn.Linear(4096, num_classes),  # 4096 → num_classes
        )

    def forward(self, x):
        """Return the class scores for a batch of 224x224 images."""
        return self.net(x)

## Model Training

### Preparing the Data

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def get_data_loaders(batch_size=128):
    """Create the FashionMNIST training and test data loaders.

    Images are resized to 224 (28x28 → 224x224) and converted to tensors.
    The datasets are stored under ``./data`` and downloaded if missing.
    Only the training loader shuffles; both use 4 worker processes.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Resize first, then convert to a tensor.
    pipeline = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

    # Build the train split first, then the test split.
    train_set, test_set = (
        datasets.FashionMNIST(root='./data', train=is_train, download=True, transform=pipeline)
        for is_train in (True, False)
    )

    # Settings shared by both loaders; only the shuffle flag differs.
    shared = dict(batch_size=batch_size, num_workers=4)
    training_loader = DataLoader(train_set, shuffle=True, **shared)
    testing_loader = DataLoader(test_set, shuffle=False, **shared)
    return training_loader, testing_loader

In [None]:
# Build both loaders and print the number of samples in each split as a sanity check.
train_loader, test_loader = get_data_loaders(batch_size=128)
print(f"Number of training samples: {len(train_loader.dataset)}")
print(f"Number of test samples: {len(test_loader.dataset)}")

In [None]:
# Pull one batch from the training loader and print the shapes of the
# inputs (X) and labels (y).
X, y = next(iter(train_loader))
print(f"X.shape: {X.shape}")
print(f"y.shape: {y.shape}")

### Training Loop

In [None]:
import torch

def evaluate(model, loader, loss, device):
    """Compute the average loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left in that mode) and the whole
    pass runs with gradients disabled.

    Args:
        model: Module mapping a batch of inputs to per-class scores; the
            prediction is the argmax over dimension 1.
        loader: Iterable yielding ``(X, y)`` batches.
        loss: Callable ``loss(y_hat, y)`` returning a scalar tensor. Each
            batch loss is multiplied by the batch size before summing, so
            this assumes the loss is a per-batch mean.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A ``(mean_loss, accuracy)`` tuple of Python floats, both averaged
        over the total number of samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples, since the averages are
            undefined (previously an opaque ``ZeroDivisionError``).
    """
    model.eval()
    total_loss, total_correct, total_num = 0.0, 0, 0

    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            batch_size = X.size(0)

            y_hat = model(X)
            batch_loss = loss(y_hat, y)

            # Weight by batch size so a smaller final batch does not skew
            # the per-sample average.
            total_loss += batch_loss.item() * batch_size
            total_correct += (y_hat.argmax(dim=1) == y).sum().item()
            total_num += batch_size

    if total_num == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")

    return total_loss / total_num, total_correct / total_num

In [None]:
def train(model, num_epochs, batch_size, lr, device):
    """Train ``model`` with plain SGD and cross-entropy, reporting each epoch.

    The loaders come from ``get_data_loaders(batch_size)``. After every
    epoch the running training loss/accuracy and the test loss/accuracy
    (via ``evaluate``) are printed on one line. Nothing is returned; the
    model is updated in place.

    Args:
        model: Network to train; it is moved to ``device`` first.
        num_epochs: Number of full passes over the training loader.
        batch_size: Batch size passed to ``get_data_loaders``.
        lr: Learning rate for ``torch.optim.SGD``.
        device: Device used for the model and every batch.
    """
    model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    train_loader, test_loader = get_data_loaders(batch_size)

    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        loss_sum = 0.0
        hits = 0
        seen = 0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Standard optimisation step.
            optimizer.zero_grad()
            scores = model(inputs)
            batch_loss = criterion(scores, targets)
            batch_loss.backward()
            optimizer.step()

            # Running statistics, weighted by the number of samples in the batch.
            n_batch = inputs.size(0)
            loss_sum += batch_loss.item() * n_batch
            hits += (scores.argmax(dim=1) == targets).sum().item()
            seen += n_batch

        train_loss = loss_sum / seen
        train_acc = hits / seen
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train => Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Test => Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")

In [None]:
# Pick the compute device: prefer CUDA, then Apple's MPS backend, else CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)
# Model and training hyperparameters for this run.
model = VGG11()
num_epochs = 10
batch_size = 128
lr = 0.1

# Train in place; per-epoch train/test metrics are printed by train().
train(model, num_epochs, batch_size, lr, device)

In [None]:
# Count every parameter element in the model, and separately those that
# have requires_grad set.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

## Model Testing

In [None]:
import torch
import matplotlib.pyplot as plt

# Human-readable FashionMNIST class names; the list position is the label
# index used to look a name up.
FASHION_CLASSES = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]

def show_fashionmnist_preds(model, test_loader, device, n=8):
    """
    Show up to n FashionMNIST test images with predicted and true labels.

    The images come from the first batch yielded by ``test_loader``. If that
    batch holds fewer than ``n`` images, only the available ones are shown;
    if there is nothing to show (``n <= 0`` or an empty batch), the function
    returns without drawing.

    Args:
        model: Classifier producing per-class scores; it is moved to
            ``device`` and put in eval mode.
        test_loader: Loader yielding ``(X, y)`` batches.
        device: Device used for the forward pass.
        n: Maximum number of images to display.
    """
    model.to(device)
    model.eval()
    with torch.no_grad():
        X, y = next(iter(test_loader))
        # Clamp to what the batch actually holds; otherwise the plotting
        # loop below would index past the end of the sliced batch.
        n = max(0, min(n, X.size(0)))
        if n == 0:
            return
        X, y = X[:n], y[:n]
        pred = model(X.to(device)).argmax(1).cpu()

    plt.figure(figsize=(2*n, 2.6))
    for i in range(n):
        plt.subplot(1, n, i + 1)
        # Drop the channel dimension so imshow receives a 2-D image.
        img = X[i].squeeze().cpu()
        plt.imshow(img, cmap="gray")
        p_idx, t_idx = pred[i].item(), y[i].item()
        plt.title(f"P:{FASHION_CLASSES[p_idx]}\nT:{FASHION_CLASSES[t_idx]}", fontsize=9)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

# Plot predicted vs. true labels for the first 8 images of the first test batch.
show_fashionmnist_preds(model, test_loader, device, n=8)