In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import QuantStub, DeQuantStub
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from collections import OrderedDict

# Select the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device={device}")

Device=cuda


In [2]:
from models.cifar_resnet32 import ResNet32

# Build the network first, then move it to CUDA when available (CPU otherwise).
model = ResNet32()
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
print(model)

ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=

In [18]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy

def load_dataset(path='./data', batch_size=64, num_workers=4):
    """Build the CIFAR-10 training and test ``DataLoader`` objects.

    Args:
        path: Root directory handed to ``datasets.CIFAR10`` (``download=True``
            is passed for both splits).
        batch_size: Batch size of the training loader. The test loader uses
            twice this value.
        num_workers: ``num_workers`` forwarded to both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        and applies the augmentation pipeline; the test loader does neither.
    """
    print("Loading the CIFAR10 dataset")

    # One normalisation step shared by both pipelines so that the training and
    # evaluation statistics cannot drift apart.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010),
    )

    # Augmentations run on the PIL image, before the tensor conversion.
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(root=path, train=True, download=True, transform=transform)
    test_dataset  = datasets.CIFAR10(root=path, train=False, download=True, transform=test_transform)

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # machine without CUDA just makes the DataLoader emit a warning.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size * 2,  # no gradients are kept at eval time, so a larger batch fits
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    print(
        f"Loaded train data: {len(train_loader.dataset)} samples, {len(train_loader)} batches\n"
        f"Loaded test data:  {len(test_loader.dataset)} samples, {len(test_loader)} batches"
    )

    return train_loader, test_loader

In [None]:
# Build the CIFAR-10 loaders with a training batch size of 128 (load_dataset
# doubles the batch size for the test loader) and num_workers=4 for each loader.
train_loader, test_loader = load_dataset(batch_size=128, num_workers=4)

In [5]:
import matplotlib.pyplot as plt
import numpy as np

def plot_metrics(metrics):
  """Plot loss and accuracy curves side by side in one figure.

  Args:
    metrics: Mapping that may contain the keys 'train_loss', 'test_loss',
      'train_acc' and 'test_acc', each holding one value per epoch.
      Series that are missing, ``None`` or empty are skipped.

  Every series is drawn against its own 1-based epoch axis, so the series do
  not need to have the same length (for example when only evaluation was run
  and the training lists are empty).
  """
  # (y-label, title, ((metrics key, legend label, marker), ...)) per panel.
  panels = (
      ('Loss', 'Training vs Test Loss',
       (('train_loss', 'Train Loss', 'o'),
        ('test_loss', 'Test Loss', 's'))),
      ('Accuracy (%)', 'Training vs Test Accuracy',
       (('train_acc', 'Train Accuracy', 'o'),
        ('test_acc', 'Test Accuracy', 's'))),
  )

  plt.figure(figsize=(12, 5))

  for position, (ylabel, title, series) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)

    has_curve = False
    for key, label, marker in series:
      values = metrics.get(key)
      # Explicit None/length check: works for lists and arrays alike.
      if values is None or len(values) == 0:
        continue
      plt.plot(range(1, len(values) + 1), values, label=label, marker=marker)
      has_curve = True

    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    # A legend without any labelled curve only produces a matplotlib warning.
    if has_curve:
      plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)

  plt.tight_layout()
  plt.show()

def plot_weight_histogram(model, bins=50):
  """Show a histogram of all Conv2d and Linear weights of ``model``.

  Args:
    model: Module whose sub-modules are scanned; only instances of
      ``torch.nn.Conv2d`` and ``torch.nn.Linear`` contribute their ``weight``.
    bins: Number of histogram bins passed to ``plt.hist``.
  """
  # Collect one flat array per layer and join them in a single call rather
  # than extending a Python list element by element.
  layer_weights = [
      module.weight.detach().cpu().numpy().ravel()
      for module in model.modules()
      if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
  ]
  # np.concatenate rejects an empty list, so keep the empty-array result for
  # models without any matching layer.
  all_weights = np.concatenate(layer_weights) if layer_weights else np.array([])

  plt.figure(figsize=(8,6))
  plt.hist(all_weights, bins=bins, color='skyblue', edgecolor='black')
  plt.title("Weight Distribution Histogram")
  plt.xlabel("Weight Value")
  plt.ylabel("Frequency")
  plt.grid(True, alpha=0.3)
  plt.show()

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim

def train_model(model,
                train_loader,
                test_loader,
                device='cpu',
                epochs=10,
                lr=1e-3,
                train=True,
                test=True):
    """Train and/or evaluate ``model`` for a number of epochs.

    Args:
        model: Module to train/evaluate; it is moved to ``device``.
        train_loader: Iterable of ``(inputs, labels)`` batches used for training.
        test_loader: Iterable of ``(inputs, labels)`` batches used for evaluation.
        device: Device the model and every batch are moved to.
        epochs: Number of epochs to run.
        lr: Learning rate of the SGD optimizer (momentum 0.9, weight decay 5e-4).
        train: Run the training pass each epoch.
        test: Run the evaluation pass each epoch.

    Returns:
        Dict with the keys ``train_loss``, ``train_acc``, ``test_loss`` and
        ``test_acc``. Each holds one value per epoch for the passes that were
        enabled and stays empty otherwise. Losses are per-sample means,
        accuracies are percentages. An empty loader yields ``nan`` entries.
    """
    model.to(device)

    metrics = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }

    criterion = nn.CrossEntropyLoss()

    # Only build the optimizer when training is requested: SGD raises
    # ValueError for a model without parameters, which would otherwise make
    # evaluation-only runs of such models impossible.
    optimizer = None
    if train:
        optimizer = optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=5e-4
        )

    for e in range(epochs):
        print(f"Epoch [{e+1}/{epochs}] ", end='')

        # ---- TRAIN ----
        if train:
            train_loss, train_acc = _run_epoch(
                model, train_loader, criterion, device, optimizer)

            metrics["train_loss"].append(train_loss)
            metrics["train_acc"].append(train_acc)

            print(f"Train Loss: {train_loss:.4f}, "
                  f"Train Acc: {train_acc:.2f}% ", end='')

        # ---- TEST / VALIDATION ----
        if test:
            test_loss, test_acc = _run_epoch(
                model, test_loader, criterion, device)

            metrics["test_loss"].append(test_loss)
            metrics["test_acc"].append(test_acc)

            print(f"Test/Val Loss: {test_loss:.4f}, "
                  f"Test/Val Acc: {test_acc:.2f}%", end='')

        # Always end the epoch line, including when the evaluation pass is off.
        print()

    return metrics


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in eval mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, seen, correct = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean, so weight it by the batch size
            # to get an exact per-sample mean over unevenly sized batches.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            seen += batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()

    if seen == 0:
        # Empty loader: report "no data" instead of dividing by zero.
        return float("nan"), float("nan")
    return loss_sum / seen, 100.0 * correct / seen

In [15]:
from models.cifar_resnet32 import ResNet32
import torch

# Build ResNet32 with num_classes=10 and load the state dict stored in
# checkpoints/resnet32_fp32_best.pt; map_location="cpu" maps the stored
# tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model = ResNet32(num_classes=10)
model.load_state_dict(torch.load("checkpoints/resnet32_fp32_best.pt", map_location="cpu"))

# Element counts over model.parameters(): all of them, and only those with
# requires_grad set.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")

# Size estimate at 4 bytes per parameter, expressed in units of 1024**2 bytes.
# Only model.parameters() is counted here; buffers are not included.
model_size_MB = total_params * 4 / (1024 ** 2)
print(f"Model Size (FP32): {model_size_MB:.2f} MB")

Total Parameters: 466,906
Trainable Parameters: 466,906
Model Size (FP32): 1.78 MB


In [16]:
# Show which device the `device` variable from the first cell refers to.
print(f"Running on: {device}")

Running on: cuda


In [17]:
# Run train_model on the checkpoint-initialised model for up to 100 epochs,
# with both the training and the evaluation pass enabled.
# `lr` is not passed, so train_model's default (lr=1e-3) is used.
train, test = True, True
epochs = 100
fp32_metrics = train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    train=train,
    test=test,
    device=device,
    epochs=epochs
)

Epoch [1/100] Train Loss: 0.4591, Train Acc: 84.46% Test/Val Loss: 0.2618, Test/Val Acc: 91.25%
Epoch [2/100] Train Loss: 0.4100, Train Acc: 86.01% Test/Val Loss: 0.2581, Test/Val Acc: 91.49%
Epoch [3/100] Train Loss: 0.4044, Train Acc: 86.28% Test/Val Loss: 0.2531, Test/Val Acc: 91.55%
Epoch [4/100] Train Loss: 0.3849, Train Acc: 86.98% Test/Val Loss: 0.2496, Test/Val Acc: 91.79%
Epoch [5/100] Train Loss: 0.3834, Train Acc: 87.07% Test/Val Loss: 0.2485, Test/Val Acc: 91.66%
Epoch [6/100] Train Loss: 0.3741, Train Acc: 87.26% Test/Val Loss: 0.2450, Test/Val Acc: 91.79%
Epoch [7/100] Train Loss: 0.3695, Train Acc: 87.46% Test/Val Loss: 0.2515, Test/Val Acc: 91.74%
Epoch [8/100] Train Loss: 0.3595, Train Acc: 87.66% Test/Val Loss: 0.2498, Test/Val Acc: 91.72%
Epoch [9/100] Train Loss: 0.3606, Train Acc: 87.71% Test/Val Loss: 0.2460, Test/Val Acc: 91.81%
Epoch [10/100] Train Loss: 0.3558, Train Acc: 87.82% Test/Val Loss: 0.2445, Test/Val Acc: 91.78%
Epoch [11/100] Train Loss: 0.3610, Trai

KeyboardInterrupt: 

In [None]:
import sys
# Run train_fp32.py in a subprocess using the same Python interpreter as this
# kernel (sys.executable); -u keeps the script's output unbuffered.
!{sys.executable} -u train_fp32.py

In [None]:
# Load the metrics stored in checkpoints/resnet32_fp32_metrics.pt
# (presumably written by train_fp32.py, TODO confirm) and plot them.
metrics = torch.load("checkpoints/resnet32_fp32_metrics.pt")
plot_metrics(metrics)
# Histogram of the Conv2d/Linear weights of whichever `model` is currently
# bound in the kernel.
plot_weight_histogram(model)

In [None]:
# Reload the FP32 ResNet32 checkpoint as a fresh starting point.
# map_location="cpu" keeps the load working on machines without CUDA and
# matches how the same checkpoint is loaded earlier in the notebook.
model = ResNet32(num_classes=10)
model.load_state_dict(torch.load("checkpoints/resnet32_fp32_best.pt", map_location="cpu"))

# Switch to training mode; as the last expression this also displays the model.
model.train()

In [None]:
import torch.nn.utils.prune as prune

# Apply L1 unstructured pruning with amount=0.3 to the `weight` of every
# Conv2d and Linear module; all other module types are skipped.
for name, module in model.named_modules():
    if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        continue
    prune.l1_unstructured(module, name="weight", amount=0.3)

In [None]:
import sys
# Run train_qat.py in a subprocess using the same Python interpreter as this
# kernel (sys.executable); -u keeps the script's output unbuffered.
!{sys.executable} -u train_qat.py

In [None]:
import os
# List the entries of the checkpoints directory to see which files exist.
print(os.listdir("checkpoints"))

In [None]:
import torch

# Load the metrics stored in checkpoints/resnet32_qat_metrics.pt
# (presumably written by train_qat.py, TODO confirm) and plot them.
qat_metrics = torch.load("checkpoints/resnet32_qat_metrics.pt")
plot_metrics(qat_metrics)