In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
def get_data_loaders(batch_size=1024, train_fraction=0.85, seed=42,
                     data_root='./data', num_workers=1):
    """Build the MNIST train / validation / test data loaders.

    The official MNIST training set is split into train and validation
    subsets with a seeded ``random_split`` so the split is reproducible.
    Only the training loader shuffles; validation and test loaders iterate
    in a fixed order so evaluation artifacts are stable between runs.

    Args:
        batch_size: Number of samples per batch for every loader.
        train_fraction: Fraction of the MNIST training set kept for
            training; the remainder becomes the validation subset.
        seed: Seed for the generator that drives the train/val split.
        data_root: Directory the datasets are downloaded to / read from.
        num_workers: Worker processes per loader (0 loads in-process).

    Returns:
        dict mapping ``"train"``, ``"val"`` and ``"test"`` to a
        ``torch.utils.data.DataLoader``.
    """
    print("Creating data loaders")

    # Options shared by all three loaders.
    common_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        # DataLoader rejects persistent_workers=True when there are no workers.
        'persistent_workers': num_workers > 0,
        # Pinned host memory only speeds up host-to-GPU copies; requesting it
        # on a CPU-only machine just produces a warning.
        'pin_memory': torch.cuda.is_available(),
    }

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): presumably the MNIST training-set mean / std - confirm.
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset_full = datasets.MNIST(
        root=data_root, train=True, download=True, transform=data_transform
    )
    test_dataset = datasets.MNIST(
        root=data_root, train=False, download=True, transform=data_transform
    )

    train_size = int(train_fraction * len(train_dataset_full))
    val_size = len(train_dataset_full) - train_size

    # A dedicated, seeded generator keeps the split independent of global RNG state.
    train_dataset, val_dataset = torch.utils.data.random_split(
        train_dataset_full,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )

    dataloaders = {
        "train": torch.utils.data.DataLoader(train_dataset, shuffle=True, **common_kwargs),
        # Evaluation does not benefit from shuffling; a fixed order makes
        # per-sample outputs (e.g. the first N misclassifications) reproducible.
        "val": torch.utils.data.DataLoader(val_dataset, shuffle=False, **common_kwargs),
        "test": torch.utils.data.DataLoader(test_dataset, shuffle=False, **common_kwargs),
    }

    print("Data loaders created successfully")

    return dataloaders

In [3]:
class Net(nn.Module):
    """Small convolutional classifier: two conv layers, one max-pool, two linear layers.

    ``forward`` returns log-probabilities over 10 classes (log-softmax on dim 1).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 32 -> 64 channels, 3x3 kernels, stride 1.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 * 12 * 12, which matches a 1x28x28 input after the two
        # unpadded 3x3 convolutions and the 2x2 pooling (assumed MNIST-sized input).
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Keep the batch dimension, flatten everything else for the linear head.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [4]:
def train(model, device, dataloaders, num_epochs, learning_rate=1.0, gamma=0.7,
          save_path="mnist_cnn.pt"):
    """Train ``model`` and report loss / accuracy for each epoch and phase.

    Every epoch runs a training pass over ``dataloaders["train"]`` followed by
    a gradient-free evaluation pass over ``dataloaders["val"]``. The learning
    rate is multiplied by ``gamma`` once per epoch, after the training pass.
    When all epochs are done the model's ``state_dict`` is written to
    ``save_path``.

    Args:
        model: Module whose forward pass returns log-probabilities (the loss
            is ``F.nll_loss``).
        device: Device the input batches are moved to.
        dataloaders: Mapping with ``"train"`` and ``"val"`` loaders; each must
            expose ``.dataset`` so per-sample averages can be computed.
        num_epochs: Number of train + val cycles to run.
        learning_rate: Initial learning rate passed to Adadelta.
        gamma: Multiplicative per-epoch learning-rate decay for ``StepLR``.
        save_path: Destination file for the trained ``state_dict``.
    """
    print("\nStarting training...")

    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    for epoch in range(num_epochs):
        print()
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 30)

        for phase in ["train", "val"]:
            is_training = phase == "train"
            # train(True) enables dropout; train(False) is equivalent to eval().
            model.train(is_training)

            # Plain Python numbers: no device tensor is kept alive between
            # batches and the epoch summary works even if no batch was seen.
            running_loss = 0.0
            running_correct = 0

            # Wrap dataloader with tqdm for progress bar
            loop = tqdm(dataloaders[phase], desc=f"{phase.capitalize()} Phase", leave=False)

            for inputs, labels in loop:
                inputs, labels = inputs.to(device), labels.to(device)
                batch_size = inputs.size(0)

                if is_training:
                    # Gradients only exist (and only need clearing) when training.
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    loss = F.nll_loss(outputs, labels)
                    preds = outputs.argmax(dim=1)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                batch_loss = loss.item()
                batch_correct = (preds == labels).sum().item()

                # nll_loss returns the batch mean, so weight it by the batch
                # size to get a correct per-sample average at the end.
                running_loss += batch_loss * batch_size
                running_correct += batch_correct

                # Update tqdm description dynamically
                loop.set_postfix({
                    "Loss": f"{batch_loss:.4f}",
                    "Batch Acc": f"{batch_correct / batch_size:.4f}"
                })

            if is_training:
                scheduler.step()

            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_correct / num_samples

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    torch.save(model.state_dict(), save_path)
    print("\nTraining complete.")
            

In [5]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import os

def log_confusion_matrix(model, dataloader, device):
    """Plot the confusion matrix of ``model`` over ``dataloader`` and save it as a PNG.

    The model is switched to eval mode, predictions are collected without
    gradient tracking, and the figure is written to
    ``artifacts/confusion_matrix.png`` (the directory is created if missing).

    Args:
        model: Module whose output is arg-maxed over dim 1 to get class predictions.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.
    """
    predicted = []
    expected = []

    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_outputs = model(batch_inputs.to(device))
            # Bring both sides back to host memory for scikit-learn.
            predicted.extend(batch_outputs.argmax(dim=1).cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())

    matrix = confusion_matrix(expected, predicted)
    ConfusionMatrixDisplay(confusion_matrix=matrix).plot(cmap='Blues')

    os.makedirs("artifacts", exist_ok=True)
    # savefig / close act on the current figure, i.e. the one plot() just created.
    plt.savefig("artifacts/confusion_matrix.png")
    plt.close()


In [6]:
def log_misclassified_images(model, dataloader, device, max_images=25):
    """Save a grid of the first misclassified samples found in ``dataloader``.

    Iterates the loader in order, collecting up to ``max_images`` samples whose
    predicted class differs from the label, and writes them to
    ``artifacts/misclassified.png``. Each tile is titled
    ``P: <predicted> / T: <true>``. Nothing is written when no error is found
    or when ``max_images`` is not positive.

    Args:
        model: Module whose output is arg-maxed over dim 1 to get class predictions.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        max_images: Upper bound on the number of misclassified samples shown.
    """
    model.eval()
    errors = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            remaining = max_images - len(errors)
            if remaining <= 0:
                break

            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = model(inputs).argmax(dim=1)

            # Visit only the misclassified positions instead of every sample.
            wrong_idx = (preds != labels).nonzero(as_tuple=True)[0]
            for i in wrong_idx[:remaining].tolist():
                errors.append((inputs[i].cpu(), preds[i].item(), labels[i].item()))

    if not errors:
        return

    # Keep the 5x5 layout for up to 25 images; add rows beyond that so that
    # no collected image is silently dropped.
    n_cols = 5
    n_rows = max(5, -(-len(errors) // n_cols))  # ceil division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows))

    # Hide every axis up front so unused tiles render blank instead of as
    # empty plots with ticks.
    for ax in axes.flat:
        ax.axis('off')

    for ax, (img, pred, true) in zip(axes.flat, errors):
        ax.imshow(img.squeeze(), cmap="gray")
        ax.set_title(f"P: {pred} / T: {true}")

    fig.tight_layout()

    # Do not rely on another function having created the output directory.
    os.makedirs("artifacts", exist_ok=True)
    img_path = "artifacts/misclassified.png"
    fig.savefig(img_path)
    plt.close(fig)


In [7]:
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and log evaluation artifacts.

    Prints the average negative log-likelihood and the accuracy over the
    loader's dataset, then calls ``log_confusion_matrix`` and
    ``log_misclassified_images`` on the same loader.

    Args:
        model: Module whose forward pass returns log-probabilities.
        device: Device the batches are moved to.
        test_loader: Loader yielding ``(inputs, labels)``; must expose ``.dataset``.
    """
    print("\nStarting testing...")

    model.eval()
    summed_loss = 0
    num_correct = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            log_probs = model(batch_inputs)
            # Sum (not mean) per batch so the final division gives a per-sample average.
            summed_loss += F.nll_loss(log_probs, batch_labels, reduction='sum').item()
            num_correct += (log_probs.argmax(dim=1) == batch_labels).sum()

    num_samples = len(test_loader.dataset)
    mean_loss = summed_loss / num_samples
    accuracy = num_correct.double() / num_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        mean_loss, num_correct, num_samples, 100. * accuracy))

    log_confusion_matrix(model, test_loader, device)
    log_misclassified_images(model, test_loader, device)



In [8]:
def training_workflow():
    """Run the whole pipeline: seed, pick a device, build loaders, train, then test.

    Uses CUDA when available and falls back to the CPU otherwise. Training
    runs for 5 epochs and the final evaluation uses the ``"test"`` loader.
    """
    torch.manual_seed(42)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f'Device: {device}\n')

    loaders = get_data_loaders()
    net = Net().to(device)

    train(net, device, loaders, num_epochs=5)
    test(net, device, loaders["test"])

In [9]:
# Entry point: run the full pipeline (data loading, training, test evaluation).
training_workflow()

Device: cuda

Creating data loaders
Data loaders created successfully

Starting training...


Epoch 1/5
------------------------------


                                                                                           

train Loss: 0.7914 Acc: 0.7515


                                                                                       

val Loss: 0.2091 Acc: 0.9360


Epoch 2/5
------------------------------


                                                                                           

train Loss: 0.1694 Acc: 0.9507


                                                                                       

val Loss: 0.1059 Acc: 0.9682


Epoch 3/5
------------------------------


                                                                                           

train Loss: 0.1168 Acc: 0.9649


                                                                                       

val Loss: 0.0755 Acc: 0.9784


Epoch 4/5
------------------------------


                                                                                           

train Loss: 0.0917 Acc: 0.9722


                                                                                       

val Loss: 0.0643 Acc: 0.9818


Epoch 5/5
------------------------------


                                                                                           

train Loss: 0.0814 Acc: 0.9760


                                                                                       

val Loss: 0.0595 Acc: 0.9824

Training complete.

Starting testing...

Test set: Average loss: 0.0469, Accuracy: 9843/10000 (98.43%)

