In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import SVHN dataset
from torchvision.datasets import SVHN

torch.cuda.empty_cache()

In [64]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [65]:
# Data preparation with normalization specific to SVHN
transform_train = transforms.Compose([
    transforms.Resize((224, 224)), # Resize images to 224x224 for models like ResNet
    transforms.ToTensor(), # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.4377, 0.4438, 0.4728], std=[0.1980, 0.2010, 0.1970]),
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4377, 0.4438, 0.4728], std=[0.1980, 0.2010, 0.1970]),
])

batch_size = 32

# Load SVHN dataset
train_set = datasets.SVHN(root='./data', split='train', download=True, transform=transform_train)
test_set = datasets.SVHN(root='./data', split='test', download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

classes = train_set.labels

Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat


In [66]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet32(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.in_channels = 16

        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)

        self.layer1 = self._make_layer(16, 5, stride=1)
        self.layer2 = self._make_layer(32, 5, stride=2)
        self.layer3 = self._make_layer(64, 5, stride=2)

        # adaptive pooling ensures correct size regardless of input
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, out_channels, blocks, stride):
        strides = [stride] + [1]*(blocks-1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn(self.conv(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))  # fixed pooling
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [67]:
# Load a pre-defined model (ResNet32) without pre-trained weights
model = ResNet32(num_classes=len(classes)).to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total trainable parameters: {num_params:,}')

Total trainable parameters: 5,227,961


In [68]:
criterion = nn.CrossEntropyLoss() # Loss function for multi-class classification
optimiser = optim.Adam(model.parameters(), lr=0.001) # Adam optimiser

In [None]:
best_acc = 0.0
patience = 8
patience_counter = 0

# Training loop
for epoch in range(100):
    model.train()
    running_loss = 0.0

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}', unit='batch')

    for i, (inputs, labels) in enumerate(progress_bar):
        inputs, labels = inputs.to(device), labels.to(device)

        optimiser.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimiser.step()

        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            progress_bar.set_postfix({'loss': running_loss / 100})
            running_loss = 0.0
    
    # Validation loop
    model.eval()
    correct = 0
    total = 0
    val_loss = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Evaluating'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_acc = correct / total
    val_loss /= len(test_loader)

    if val_acc > best_acc:
        best_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pth")
        print("  ✓ New best model saved!")
    else:
        patience_counter += 1

    # ---- PERIODIC CHECKPOINT ----
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch+1}.pth")
        print(f"  ✓ Checkpoint saved at epoch {epoch+1}")

    # ---- EARLY STOPPING ----
    if patience_counter >= patience:
        print("Early stopping triggered!")
        break

    print(f"Epoch {epoch+1}: val_acc={val_acc:.4f}, val_loss={val_loss:.4f}")

print('Training complete.')

Epoch 1/10:   6%|▋         | 145/2290 [00:42<10:31,  3.40batch/s, loss=4.14]


KeyboardInterrupt: 

In [None]:
# Save model state dict
save_path = "svhn_cnn.pth"
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

In [None]:
# Re-create the model architecture first
model = ResNet32(num_classes=len(classes)).to(device)

# Load weights
model.load_state_dict(torch.load("svhn_cnn.pth", map_location=device))
model.eval()

print("Model loaded and ready!")


In [None]:
model.eval()
correct = 0
total = 0

# Evaluation loop
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Evaluating'):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

In [None]:
def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy().transpose(1, 2, 0)
    plt.imshow(npimg)
    plt.axis('off')

def show_predictions(model, test_loader, device, max_correct=10, max_wrong=10):
    model.eval()
    correct_imgs = []
    wrong_imgs = []

    with torch.no_grad():
        for images, labels in test_loader:
            labels[labels == 10] = 0
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            for img, pred, true in zip(images, predicted, labels):
                if pred == true and len(correct_imgs) < max_correct:
                    correct_imgs.append((img.cpu(), int(pred), int(true)))
                elif pred != true and len(wrong_imgs) < max_wrong:
                    wrong_imgs.append((img.cpu(), int(pred), int(true)))

                if len(correct_imgs) == max_correct and len(wrong_imgs) == max_wrong:
                    break
            if len(correct_imgs) == max_correct and len(wrong_imgs) == max_wrong:
                break

    # Display them
    print("\nCorrect Predictions:\n")
    plt.figure(figsize=(12, 4))
    for i, (img, pred, true) in enumerate(correct_imgs):
        plt.subplot(2, 5, i+1)
        imshow(img)
        plt.title(f"Pred: {pred}")
    plt.show()

    print("\nIncorrect Predictions:\n")
    plt.figure(figsize=(12, 4))
    for i, (img, pred, true) in enumerate(wrong_imgs):
        plt.subplot(2, 5, i+1)
        imshow(img)
        plt.title(f"Pred: {pred}, True: {true}")
    plt.show()

# Run it:
show_predictions(model, test_loader, device)
