In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.quantization
from tqdm import tqdm
import pandas as pd
from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
import matplotlib.patches as mpatches
import torch.nn.utils.prune as prune
import torch.onnx
import torch_pruning as tp
import tempfile
from models.LeNet5 import LeNet5
from pathlib import Path

In [2]:
# Pick the compute device in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():  # was `torch.torch.cuda...` (typo)
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Directory that receives the trained LeNet checkpoints; created if missing.
saved_model_path = Path("./saved_models/lenet/")
saved_model_path.mkdir(parents=True, exist_ok=True)

In [3]:
# ======= Data Preparation
# Seed torch's global RNG before any loader is created.
torch.manual_seed(42)

# Convert images to tensors, then normalise with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# MNIST splits stored under ./data; only the training split requests a download.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Batches of 64; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64)

In [4]:
# ======= model training function
def train_model(model, loader, criterion, epochs, return_logs=False):
    """Train ``model`` with Adam (lr=0.001) and report progress every 100 batches.

    The model is moved to the notebook-level ``DEVICE`` and switched to
    training mode before the first epoch.

    Args:
        model: Module to optimise; it is modified in place.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss
            tensor that is back-propagated.
        epochs: Number of full passes over ``loader``.
        return_logs: When True, also collect and return per-batch and
            per-image training records. Defaults to False, which keeps the
            original return value and skips the logging work entirely.

    Returns:
        The trained ``model``. If ``return_logs`` is True, a tuple
        ``(model, batch_log, image_log)`` where ``batch_log`` is a list of
        dicts with keys ``epoch``, ``batch``, ``loss``, ``accuracy`` and
        ``image_log`` is a list of dicts with keys ``epoch``, ``batch``,
        ``true``, ``pred``, ``conf``.
    """
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()
    batch_log, image_log = [], []

    for epoch in range(epochs):
        for batch_idx, (images, labels) in enumerate(tqdm(loader)):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            should_print = batch_idx % 100 == 0
            if not (should_print or return_logs):
                # No metric is needed for this batch, so skip the scalar
                # read-backs (.item()) altogether.
                continue

            # Metrics are computed on detached outputs so they never extend
            # the autograd graph.
            detached = outputs.detach()
            loss_value = loss.item()
            preds = detached.argmax(dim=1)
            acc = preds.eq(labels).sum().item() / len(labels)

            if return_logs:
                batch_log.append({"epoch": epoch+1, "batch": batch_idx, "loss": loss_value, "accuracy": acc})

            if should_print:
                print(f"Batch training loss {loss_value:.4f} | training accuracy {acc:.4f} at step {batch_idx}")

            if return_logs:
                probs = F.softmax(detached, dim=1)
                confs, pred_labels = probs.max(dim=1)
                # One bulk conversion per tensor instead of three .item()
                # calls for every image in the batch.
                for true, pred, conf in zip(labels.tolist(), pred_labels.tolist(), confs.tolist()):
                    image_log.append({"epoch": epoch+1, "batch": batch_idx, "true": true, "pred": pred, "conf": conf})

    if return_logs:
        return model, batch_log, image_log
    return model

In [5]:
# ====== Baseline Model Training
# Cross-entropy loss on the network outputs; a fresh LeNet5 is trained for
# two passes over the training loader.
criterion = nn.CrossEntropyLoss()
baseline_model = train_model(
    LeNet5(),
    train_loader,
    criterion,
    epochs=2,
)

  0%|▏                                          | 4/938 [00:00<02:05,  7.45it/s]

Batch training loss 2.2973 | training accuracy 0.0625 at step 0


 11%|████▌                                    | 103/938 [00:05<00:36, 22.86it/s]

Batch training loss 0.4365 | training accuracy 0.8438 at step 100


 22%|████████▊                                | 202/938 [00:09<00:32, 22.76it/s]

Batch training loss 0.1114 | training accuracy 0.9844 at step 200


 32%|█████████████▏                           | 302/938 [00:13<00:28, 22.42it/s]

Batch training loss 0.1357 | training accuracy 0.9531 at step 300


 43%|█████████████████▋                       | 404/938 [00:18<00:23, 23.07it/s]

Batch training loss 0.0941 | training accuracy 0.9844 at step 400


 54%|█████████████████████▉                   | 503/938 [00:22<00:19, 22.46it/s]

Batch training loss 0.0324 | training accuracy 1.0000 at step 500


 64%|██████████████████████████▎              | 602/938 [00:27<00:14, 22.63it/s]

Batch training loss 0.3778 | training accuracy 0.9375 at step 600


 75%|██████████████████████████████▊          | 704/938 [00:31<00:10, 22.53it/s]

Batch training loss 0.0973 | training accuracy 0.9688 at step 700


 86%|███████████████████████████████████      | 803/938 [00:36<00:05, 24.13it/s]

Batch training loss 0.1411 | training accuracy 0.9375 at step 800


 96%|███████████████████████████████████████▍ | 902/938 [00:40<00:01, 22.44it/s]

Batch training loss 0.0077 | training accuracy 1.0000 at step 900


100%|█████████████████████████████████████████| 938/938 [00:42<00:00, 22.24it/s]
  0%|▏                                          | 3/938 [00:00<00:36, 25.68it/s]

Batch training loss 0.0254 | training accuracy 1.0000 at step 0


 11%|████▌                                    | 105/938 [00:04<00:32, 25.66it/s]

Batch training loss 0.0468 | training accuracy 0.9844 at step 100


 22%|████████▉                                | 204/938 [00:08<00:32, 22.49it/s]

Batch training loss 0.1072 | training accuracy 0.9531 at step 200


 32%|█████████████▏                           | 303/938 [00:12<00:28, 22.48it/s]

Batch training loss 0.1595 | training accuracy 0.9688 at step 300


 43%|█████████████████▌                       | 402/938 [00:17<00:23, 22.49it/s]

Batch training loss 0.0323 | training accuracy 0.9844 at step 400


 54%|██████████████████████                   | 504/938 [00:21<00:19, 22.38it/s]

Batch training loss 0.0378 | training accuracy 0.9844 at step 500


 64%|██████████████████████████▎              | 603/938 [00:26<00:14, 22.34it/s]

Batch training loss 0.0775 | training accuracy 0.9688 at step 600


 75%|██████████████████████████████▋          | 702/938 [00:30<00:10, 22.25it/s]

Batch training loss 0.0152 | training accuracy 1.0000 at step 700


 86%|███████████████████████████████████▏     | 804/938 [00:35<00:05, 22.62it/s]

Batch training loss 0.0429 | training accuracy 0.9844 at step 800


 96%|███████████████████████████████████████▍ | 903/938 [00:39<00:01, 22.82it/s]

Batch training loss 0.0812 | training accuracy 0.9844 at step 900


100%|█████████████████████████████████████████| 938/938 [00:41<00:00, 22.80it/s]


In [6]:
# Persist the baseline twice: once as a state_dict (weights only) and once as
# the whole module object.
baseline_weights_file = saved_model_path / 'lenet5_baseline_weights.pth'
baseline_model_file = saved_model_path / 'lenet5_baseline_model.pth'

torch.save(baseline_model.state_dict(), baseline_weights_file)
torch.save(baseline_model, baseline_model_file)