In [1]:
import torch
import torchvision
import torch.nn as nn
from enum import Enum
import torch.optim as optim
import torchvision.transforms as transforms
from tqdm import tqdm

In [2]:
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

DEVICE = torch.device("cpu")
DEVICE

device(type='cpu')

In [3]:
class HyperParameters(Enum):
    DEVICE = DEVICE
    EPOCHS = 20
    BATCH_SIZE = 8
    LEARNING_RATE = 0.001

In [4]:
transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5071, 0.4867, 0.4408),
        std=(0.2675, 0.2565, 0.2761)
    )
])

### Data preparation

In [5]:
from pathlib import Path

dataset_path = Path("./data/cifar-100-python.tar.gz")
if dataset_path.exists():
    should_download_data = False
else:
    should_download_data = True

should_download_data

False

In [6]:
train_set = torchvision.datasets.CIFAR100(root="./data",
                                          train=True,
                                          download=should_download_data,
                                          transform= transforms)

train_loader = torch.utils.data.DataLoader(dataset=train_set, 
                                           batch_size= HyperParameters.BATCH_SIZE.value,
                                           shuffle=True,
                                           num_workers=2)

test_set = torchvision.datasets.CIFAR100(root="./data",
                                         train=False,
                                         download=should_download_data,
                                         transform= transforms)

test_loader = torch.utils.data.DataLoader(dataset= test_set,
                                          batch_size= HyperParameters.BATCH_SIZE.value,
                                          num_workers=2,
                                          shuffle=False)

### Manual Normalisation

In [7]:
sum_channels = torch.zeros(3)
sum_sq_channels = torch.zeros(3)
num_pixels = 0

for imgs, _ in train_loader:

    B, C, H, W = imgs.shape
    num_pixels += B * H * W
    sum_channels += imgs.sum(dim=[0,2,3])
    sum_sq_channels += (imgs ** 2).sum(dim=[0,2,3])

# 3. Compute mean & std
mean = sum_channels / num_pixels
var  = (sum_sq_channels / num_pixels) - (mean ** 2)
std  = torch.sqrt(var)

print("Mean:", mean.tolist())
print("Std: ", std.tolist())

Mean: [-9.27779110497795e-05, -0.0005891519831493497, 0.00042679786565713584]
Std:  [0.999381959438324, 0.9997610449790955, 1.000183343887329]


### Model

In [8]:
class CIFAR100Net(nn.Module):
    def __init__(self, num_classes=100, dropout_prob=0.5):
        super(CIFAR100Net, self).__init__()

        # ––– Convolutional Layer 1 –––
        self.conv1   = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1     = nn.BatchNorm2d(64)
        self.relu    = nn.ReLU(inplace=True)
        self.drop1   = nn.Dropout2d(dropout_prob)
        self.pool    = nn.MaxPool2d(2, 2)

        # ––– Convolutional Layer 2 –––
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2     = nn.BatchNorm2d(128)
        self.drop2   = nn.Dropout2d(dropout_prob)

        # ––– Convolutional Layer 3 –––
        self.conv3   = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3     = nn.BatchNorm2d(256)
        self.drop3   = nn.Dropout2d(dropout_prob)

        # ––– Dynamically compute flatten size –––
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 32, 32)
            x = self._forward_features(dummy)
            flat_features = x.numel()

        # ––– Fully-connected layers –––
        self.fc1     = nn.Linear(flat_features, 512)
        self.bn_fc1  = nn.BatchNorm1d(512)
        self.drop_fc = nn.Dropout(dropout_prob)
        self.fc2     = nn.Linear(512, num_classes)
    
    def _forward_features(self, x):
        # Shared conv → BN → ReLU → Dropout → Pool (only for block1 pooling)
        x = self.drop1(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(x)
        x = self.drop2(self.relu(self.bn2(self.conv2(x))))
        x = self.drop3(self.relu(self.bn3(self.conv3(x))))
        x = self.pool(x)
        return x
    
    def forward(self, x):
        # 1) Feature extraction via conv blocks
        x = self._forward_features(x)
        # 2) Flatten
        x = x.flatten(start_dim=1)       # ← safe on MPS

        # 3) Classifier head
        x = self.drop_fc(self.relu(self.bn_fc1(self.fc1(x))))
        x = self.fc2(x)
        return x


In [9]:
model = CIFAR100Net().to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr= HyperParameters.LEARNING_RATE.value)

### Training Loop

In [None]:
def train(model, criterion, optimizer, train_loader, EPOCHS):
    
    model.train()

    total_loss = []
    for epoch in range(EPOCHS):
        running_loss = 0.0

        progress_bar = tqdm(train_loader,
                                desc = f"Epoch [{epoch+1}/{EPOCHS}]",
                                unit = "batch")
        
        for i, (images, labels) in enumerate(progress_bar, 1):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            output = model(images)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  
        
            running_loss += loss.item() 
            progress_bar.set_postfix(avg_loss=running_loss / i)
        
        total_loss.append(running_loss)

    return total_loss

In [11]:
train(model, criterion, optimizer, train_loader, 20)

Epoch [1/20]: 100%|██████████| 6250/6250 [04:05<00:00, 25.51batch/s, avg_loss=4.21]
Epoch [2/20]: 100%|██████████| 6250/6250 [08:00<00:00, 13.00batch/s, avg_loss=3.86]
Epoch [3/20]: 100%|██████████| 6250/6250 [05:33<00:00, 18.73batch/s, avg_loss=3.71]
Epoch [4/20]: 100%|██████████| 6250/6250 [04:54<00:00, 21.19batch/s, avg_loss=3.61]
Epoch [5/20]: 100%|██████████| 6250/6250 [06:28<00:00, 16.09batch/s, avg_loss=3.53]
Epoch [6/20]:   7%|▋         | 455/6250 [00:42<09:03, 10.67batch/s, avg_loss=3.51] 


KeyboardInterrupt: 

### Evaluation

In [None]:
def evaluate(model, test_loader, EPOCH):
    model.eval()
    correct = total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f"Test Accuracy: {100 * correct / total:.2f}%")