# MNIST CNN Training Pipeline
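This notebook trains a simple CNN on the MNIST dataset, evaluates it on the test set, and saves the model weights. It then exports the quantized weights, the requantization parameters, and a few sample test images as `.hex` files for later FPGA integration.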

In [4]:
# 1. Imports & Setup
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# (Optional) Seed every random number generator used in this notebook
# (PyTorch, NumPy, Python's random) with the same value for reproducibility.
SEED = 42
for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    seed_rng(SEED)

In [5]:
# 2. Dataset Preparation
# Every sample goes through ToTensor before it reaches the model.
transform = transforms.ToTensor()
batch_size = 64

# Build the train and test splits with identical settings; only `train` differs.
mnist_train, mnist_test = (
    datasets.MNIST(root='../data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffle the training batches only; the test order stays fixed.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [6]:
# 3. Model Definition
# '../src' is added to the module search path before the import below;
# presumably that is where mnist_model.py lives (confirm against the repo layout).
import sys
sys.path.append('../src')
from mnist_model import MnistCNN
# Instantiate model and print its layer summary
model = MnistCNN()
print(model)

MnistCNN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1568, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [7]:
# 4. Training Configuration
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Hyperparameters and optimisation objects
num_epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Best-model bookkeeping.
# NOTE(review): no cell visible in this notebook updates these two values.
best_accuracy, best_model_wts = 0.0, None

In [8]:
# 5. Training Loop
# One full pass over train_loader per epoch. Loss and hit counts are
# accumulated per sample so the printed numbers are epoch-wide averages.
for epoch in range(num_epochs):
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Optimisation step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion yields a per-batch mean, so weight it by the batch size
        loss_sum += images.size(0) * loss.item()
        # Index of the largest logit is the predicted class
        predicted = outputs.max(1)[1]
        n_correct += predicted.eq(labels).sum().item()
        n_seen += labels.size(0)

    epoch_acc = n_correct / n_seen
    epoch_loss = loss_sum / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

Epoch 1/10 - Loss: 0.2312, Accuracy: 0.9298
Epoch 2/10 - Loss: 0.0628, Accuracy: 0.9806
Epoch 3/10 - Loss: 0.0438, Accuracy: 0.9864
Epoch 4/10 - Loss: 0.0342, Accuracy: 0.9896
Epoch 5/10 - Loss: 0.0275, Accuracy: 0.9911
Epoch 6/10 - Loss: 0.0215, Accuracy: 0.9930
Epoch 7/10 - Loss: 0.0179, Accuracy: 0.9941
Epoch 8/10 - Loss: 0.0128, Accuracy: 0.9960
Epoch 9/10 - Loss: 0.0118, Accuracy: 0.9961
Epoch 10/10 - Loss: 0.0109, Accuracy: 0.9964


In [9]:
# 6. Evaluation on Test Set
# Put the model in eval mode and disable gradient tracking while counting
# how many test samples are classified correctly.
model.eval()
hits = 0
seen = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        # Index of the largest logit is the predicted class
        predicted = model(images).max(1)[1]
        hits += predicted.eq(labels).sum().item()
        seen += labels.size(0)
test_acc = hits / seen
print(f'Test Accuracy: {test_acc:.4f}')

Test Accuracy: 0.9899


In [10]:
# 7. Saving the Model Weights
# Only the state_dict (parameter tensors) is written, not the module object.
model_path = '../model/best_mnist_cnn.pth'
# torch.save does not create missing parent directories, so make sure the
# target folder exists first (a fresh checkout may not have ../model yet).
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)
print(f'Model weights saved to {model_path}')

Model weights saved to ../model/best_mnist_cnn.pth


## Next Steps: Export for FPGA, Quantization, Reference Inference
- Quantize and export weights as needed
- Run sample inference and compare outputs
- Document accuracy and training details

In [11]:
from quantize import quantize_model_weights
from convert_to_hex import save_model_weights_hex

# Rebuild the network and restore the trained weights from disk.
# map_location="cpu" lets the checkpoint load even when it was saved from a
# CUDA device and this cell is run on a machine without a GPU; the freshly
# constructed model is not moved to another device in this cell.
model = MnistCNN()
model.load_state_dict(torch.load("../model/best_mnist_cnn.pth", map_location="cpu"))
model.eval()

# Quantize the parameters and write them out as .hex files under ../../weights/.
# num_bits=8 / bias_bits=32 presumably select 8-bit weights and 32-bit biases
# (confirm against quantize.py).
quantized_params = quantize_model_weights(model, num_bits=8, bias_bits=32)
save_model_weights_hex(quantized_params, base_path="../../weights/")



Saved conv1.weight to ../../weights/conv1_weight.hex
Saved conv1.bias to ../../weights/conv1_bias.hex
Saved conv2.weight to ../../weights/conv2_weight.hex
Saved conv2.bias to ../../weights/conv2_bias.hex
Saved fc1.weight to ../../weights/fc1_weight.hex
Saved fc1.bias to ../../weights/fc1_bias.hex
Saved fc2.weight to ../../weights/fc2_weight.hex
Saved fc2.bias to ../../weights/fc2_bias.hex
Saved requant params to ../../weights/requant_params.json


In [12]:
# 8. Save MNIST test images to hex files for FPGA image_streamer

import os

def save_mnist_hex(image_tensor, filename):
    """
    Save a single image tensor as a .hex file, one pixel per line.

    The tensor is expected to hold intensities in [0, 1] (e.g. a 1x28x28 MNIST
    sample). Singleton dimensions are squeezed away and the remaining pixels
    are written in row-major order, each as a 2-digit lowercase hex value
    (00-ff) followed by a newline.

    Args:
        image_tensor: torch tensor of pixel intensities in [0, 1].
        filename: path of the text file to create or overwrite.
    """
    # Round to the nearest integer before the uint8 cast: a plain cast
    # truncates, so a value such as 254.7 would be stored as 254. Clamping
    # keeps slightly out-of-range inputs from wrapping around in uint8.
    scaled = (image_tensor.detach().cpu().squeeze() * 255).round().clamp(0, 255)
    img = scaled.to(torch.uint8).numpy()
    with open(filename, "w") as f:
        # reshape(-1) walks the pixels row by row, whatever the image size
        for pixel in img.reshape(-1):
            f.write(f"{int(pixel):02x}\n")

# How many leading test samples to export
N = 10

# Output folder for the hex-encoded images; created if it is missing
img_out_dir = "../../images/"
os.makedirs(img_out_dir, exist_ok=True)

# Write one .hex file per sample, with index and label encoded in the name
for i in range(N):
    sample = mnist_test[i]
    img, label = sample  # (1,28,28) tensor and its class label
    file_name = f"test_img_{i:04d}_label{label}.hex"
    out_file = os.path.join(img_out_dir, file_name)
    save_mnist_hex(img, out_file)
    print(f"Saved {out_file}")


Saved ../../images/test_img_0000_label7.hex
Saved ../../images/test_img_0001_label2.hex
Saved ../../images/test_img_0002_label1.hex
Saved ../../images/test_img_0003_label0.hex
Saved ../../images/test_img_0004_label4.hex
Saved ../../images/test_img_0005_label1.hex
Saved ../../images/test_img_0006_label4.hex
Saved ../../images/test_img_0007_label9.hex
Saved ../../images/test_img_0008_label5.hex
Saved ../../images/test_img_0009_label9.hex


In [20]:
import json

# 9. Pack requantization parameters into 32-bit hex words
# Reads the per-layer scale_int/shift pairs from requant_params.json and
# writes one 8-hex-digit word per layer to requant_params.hex.
with open("../../weights/requant_params.json") as f:
    params = json.load(f)

layers = ["conv1.bias", "conv2.bias", "fc1.bias", "fc2.bias"]

# Build every word before opening the output file, so that a bad value aborts
# the cell without leaving a half-written .hex file behind.
words = []
for l in layers:
    scale = params[l]["scale_int"]
    shift = params[l]["shift"]
    # scale occupies bits [31:16]; a value outside 0..0xFFFF would spill past
    # bit 31 (or turn the word negative) and break the 8-hex-digit line format.
    if not 0 <= scale <= 0xFFFF:
        raise ValueError(
            f"{l}: scale_int={scale} does not fit the 16-bit field [31:16]"
        )
    # pack scale and shift in one 32-bit word: [31:16]=scale, [15:0]=shift
    word = (scale << 16) | (shift & 0xFFFF)
    words.append(word)

with open("../../weights/requant_params.hex", "w") as f:
    for word in words:
        f.write(f"{word:08x}\n")
