In [None]:
# ===============================================================
# 0.  Set‑up: installs and imports
# ===============================================================
# Installs the `onnx` and `onnxruntime` packages that this notebook imports
# right below; `--quiet` keeps pip's progress output out of the cell.
# NOTE(review): versions are unpinned, and `%pip install` is generally
# preferred over `!pip` in notebooks because it targets the running kernel's
# environment — consider switching.
!pip install --quiet onnx onnxruntime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR          # bonus scheduler
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader
import onnx, onnxruntime as ort
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# ===============================================================
# 1.  Dataset — normalize pixels to [0,1]  
# ===============================================================
transform = transforms.Compose([
    transforms.ToTensor(),                 # already scales to [0,1]
])

full_mnist = datasets.MNIST(root=".", download=True, train=True, transform=transform)

# ===============================================================
# 2.  Train / Test split — 2/3 vs 1/3 
# ===============================================================
train_len   = int(len(full_mnist) * 2/3)
test_len    = len(full_mnist) - train_len      # remainder, so the two parts always sum to the full set

# Use a dedicated, seeded generator so the split is the same after every
# "Restart & Run All"; otherwise train/test membership (and every reported
# metric) changes from run to run.
split_generator = torch.Generator().manual_seed(42)
train_set, test_set = random_split(full_mnist, [train_len, test_len],
                                   generator=split_generator)

batch_size = 128
# Pinned host memory only helps host->GPU copies; requesting it without CUDA
# just produces a warning, so enable it only when a GPU is present.
use_pinned_memory = torch.cuda.is_available()
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=use_pinned_memory)
test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False,
                          num_workers=2, pin_memory=use_pinned_memory)


In [None]:
# ===============================================================
# 3.  CNN model with the required architecture 
# ===============================================================
class MNIST_CNN(nn.Module):
    """Two-stage convolutional network followed by a small dense head.

    For a single-channel 28x28 input the tensor sizes are:
    conv 3x3 -> 26x26x32, pool -> 13x13x32, conv 3x3 -> 11x11x64,
    pool -> 5x5x64, flatten -> 1600, dense -> 128, dense -> 10.

    The final layer has no activation, so ``forward`` returns raw logits.
    """

    def __init__(self):
        super().__init__()

        # ----- feature extractor: (conv -> ReLU -> max-pool) twice -----
        conv_stack = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        # ----- classifier: flatten, one hidden layer with dropout, logits -----
        dense_head = [
            nn.Flatten(),
            nn.Linear(in_features=1600, out_features=128),   # 5*5*64 = 1600
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.classifier = nn.Sequential(*dense_head)

    def forward(self, x):
        """Run the feature extractor, then the classifier; returns logits."""
        return self.classifier(self.features(x))

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network and place its parameters on the selected device.
model = MNIST_CNN()
model = model.to(device)


In [None]:
# ===============================================================
# 4.  Training hyper‑parameters 
# ===============================================================
epochs = 10          # number of full passes over the training loader
lr = 1e-3            # initial learning rate handed to Adam

# Loss applied to the model's raw logits.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the model.
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# BONUS 5.  Learning‑rate scheduler
# Multiplies the learning rate by 0.3 after every 3 calls to scheduler.step().
scheduler = StepLR(optimizer=optimizer, step_size=3, gamma=0.3)


In [None]:
# ===============================================================
# Training & evaluation loops
# ===============================================================
def run_epoch(loader, train=True):
    """Run one full pass over ``loader`` and return ``(loss, accuracy)``.

    Uses the module-level ``model``, ``device``, ``criterion`` and
    ``optimizer``.

    Args:
        loader: iterable yielding ``(inputs, targets)`` batches.
        train: when True the model is put in training mode and the optimizer
            is stepped after every batch; when False the model is put in
            evaluation mode and no gradients are computed.

    Returns:
        A tuple ``(loss, accuracy)``. ``loss`` is the sum of
        ``loss.item() * batch_size`` over all batches divided by the number
        of samples seen (the per-sample average when the criterion reduces
        by mean); ``accuracy`` is the fraction of samples whose arg-max
        logit equals the target.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.train(mode=train)
    total_loss, correct, total = 0.0, 0, 0

    # Autograd is only needed when backward() will be called. Disabling it
    # for evaluation avoids building a graph for every test batch, which
    # saves memory and time without changing the computed loss/accuracy.
    with torch.set_grad_enabled(train):
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            if train:
                optimizer.zero_grad()

            logits = model(x)
            loss   = criterion(logits, y)

            if train:
                loss.backward()
                optimizer.step()

            # Weight by batch size so a smaller final batch does not skew
            # the epoch average.
            total_loss += loss.item() * x.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total   += y.size(0)

    return total_loss / total, correct / total


# One training pass followed by one evaluation pass per epoch; the scheduler
# is advanced once per epoch, after both passes.
for epoch in range(1, epochs + 1):
    train_loss, train_acc = run_epoch(train_loader, train=True)
    test_loss, test_acc = run_epoch(test_loader, train=False)
    scheduler.step()

    report_parts = (
        f"Epoch {epoch:02d}",
        f"Train Loss {train_loss:.4f} Acc {train_acc:.3f}",
        f"Test Loss {test_loss:.4f}  Acc {test_acc:.3f}",
    )
    print(" | ".join(report_parts))


In [None]:
# ===============================================================
# 6.  Export model to ONNX 
# ===============================================================
# Make the inference behaviour explicit: dropout must be disabled in the
# exported graph regardless of which cell last touched the model's mode.
model.eval()

# Example input used to trace the model; only its shape/dtype/device matter.
dummy = torch.randn(1, 1, 28, 28, device=device)
onnx_path = "mnist_cnn.onnx"
torch.onnx.export(model, dummy, onnx_path,
                  input_names=['input'], output_names=['logits'],
                  # axis 0 is marked dynamic so the exported graph accepts any batch size
                  dynamic_axes={'input': {0: 'batch'}, 'logits': {0: 'batch'}},
                  opset_version=13)

# Fail here, with a clear error, if the written file is not a well-formed
# ONNX model, instead of failing later inside the runtime session.
onnx.checker.check_model(onnx.load(onnx_path))
print(f"ONNX model saved to {onnx_path}")


In [None]:
# ===============================================================
# 7.  Load ONNX model & test on 5 random images 
# ===============================================================
# -- ONNX Runtime session, restricted to the CPU execution provider
ort_sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

# draw 5 random positions in *test_set* and fetch the (image, label) pairs
idx = torch.randint(0, len(test_set), (5,))
picked = [test_set[i] for i in idx]
images = torch.stack([sample[0] for sample in picked])   # (5,1,28,28)
labels = torch.tensor([sample[1] for sample in picked])

# run inference: the session returns a list of outputs, the first is the logits
onnx_outputs = ort_sess.run(None, {'input': images.numpy()})
logits = onnx_outputs[0]
preds = torch.tensor(logits).argmax(dim=1)

# visual check: one panel per sample, title green when prediction matches label
fig, axs = plt.subplots(1, 5, figsize=(12,2))
for ax, img, gt, pd in zip(axs, images, labels, preds):
    title_color = "green" if gt==pd else "red"
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f"GT:{gt}  Pred:{pd}", color=title_color)
    ax.axis('off')
plt.show()


In [None]:
# The download helper only exists inside Google Colab. Guard the import so
# "Run All" still completes when the notebook is executed elsewhere.
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; skipping browser download of mnist_cnn.onnx")
else:
    files.download("mnist_cnn.onnx")   # pops a browser “Save as…” dialog