In [None]:
import os, time
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.models import resnet18, ResNet18_Weights

# Mount Google Drive so checkpoints written below outlive the Colab session
from google.colab import drive
drive.mount('/content/drive')


# Dataset Setup (CIFAR-10)
BATCH_SIZE = 128
IMG_SIZE = 224  # every image is resized to 224x224 before it reaches ResNet18
NUM_WORKERS = 8
NUM_CLASSES = 10  # CIFAR-10 has 10 classes

# Per-channel mean/std handed to Normalize in both pipelines
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation, tensor conversion, normalisation
train_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Validation pipeline: same resize/normalisation, no augmentation
val_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# The CIFAR-10 test split (train=False) serves as the validation set
train_ds = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=train_transform
)
val_ds = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=val_transform
)

# Loader options shared by both splits; only `shuffle` differs
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)} | Classes: {NUM_CLASSES}")

# Model setup: pretrained ResNet18 with a partially unfrozen backbone
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

model = resnet18(weights=ResNet18_Weights.DEFAULT)

# Fine-tune only layer3, layer4 and the classifier; everything else is frozen.
# Matching is by substring of the parameter name, exactly as a chain of
# `"tag" in name` checks would do.
_TRAINABLE_TAGS = ("layer3", "layer4", "fc")
for name, param in model.named_parameters():
    param.requires_grad = any(tag in name for tag in _TRAINABLE_TAGS)

# Replace the classifier head with one sized for CIFAR-10 and keep it trainable
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model.fc.requires_grad_(True)

model = model.to(device)

# The optimizer only sees parameters that are still trainable
trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
criterion = nn.CrossEntropyLoss()

#helper fcns to determine metrics of models
def count_parameters(m):
    """Return ``(total, trainable)`` element counts over ``m.parameters()``.

    ``trainable`` only includes parameters whose ``requires_grad`` is set.
    """
    total = 0
    trainable = 0
    for param in m.parameters():
        n_elems = param.numel()
        total += n_elems
        if param.requires_grad:
            trainable += n_elems
    return total, trainable

def save_model_state(m, path):
    """Write ``m.state_dict()`` to ``path`` and return the file size in bytes."""
    state = m.state_dict()
    torch.save(state, path)
    size_bytes = os.path.getsize(path)
    return size_bytes

@torch.no_grad()
def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a fraction.

    Args:
        model: module to evaluate; it is switched to eval mode but NOT moved,
            so it must already live on ``device``.
        loader: iterable yielding ``(inputs, labels)`` batches.
        device: ``torch.device`` (or device string) the batches are moved to.

    Returns:
        ``correct / total`` over every sample yielded by ``loader``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    device = torch.device(device)
    model.eval()
    # Mixed precision is only enabled on CUDA. The old
    # `torch.cuda.amp.autocast()` is deprecated and was also entered for CPU
    # and quantized evaluation, where it serves no purpose.
    use_amp = device.type == "cuda"
    correct, total = 0, 0
    for xb, yb in tqdm(loader, desc="Eval", leave=False):
        xb, yb = xb.to(device), yb.to(device)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            out = model(xb)
        preds = out.argmax(1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)
    return correct / total

def latency_ms(model, device, n=50):
    """Return the mean forward latency of ``model`` for one image, in ms.

    The model is moved to ``device`` and put in eval mode (both persist after
    the call). A random ``1 x 3 x IMG_SIZE x IMG_SIZE`` input is pushed through
    10 warm-up iterations and then ``n`` timed iterations.

    Args:
        model: module to time.
        device: ``torch.device`` (or device string) to run on.
        n: number of timed forward passes.

    Returns:
        Average wall-clock time per forward pass in milliseconds.
    """
    device = torch.device(device)
    model.to(device).eval()
    on_cuda = device.type == "cuda"
    dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)

    def _forward_once():
        # Autocast only on CUDA; replaces the deprecated torch.cuda.amp.autocast()
        with torch.autocast(device_type=device.type, enabled=on_cuda):
            model(dummy)
        if on_cuda:
            # CUDA kernels are asynchronous; wait so the clock sees real work
            torch.cuda.synchronize()

    # no_grad: this is an inference benchmark, so autograd bookkeeping for
    # trainable parameters must not be part of the measured time.
    with torch.no_grad():
        for _ in range(10):
            _forward_once()
        # perf_counter is monotonic and higher resolution than time.time()
        t0 = time.perf_counter()
        for _ in range(n):
            _forward_once()
        elapsed = time.perf_counter() - t0
    return 1000.0 * elapsed / n

# Fine-tuning loop; a resumable checkpoint is written to Drive every `save_every` batches
ckpt_dir = "/content/drive/MyDrive/checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
save_every = 200
num_epochs = 1

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False)
    for batch_idx, (xb, yb) in enumerate(train_bar):
        xb = xb.to(device)
        yb = yb.to(device)

        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()

        # Weight the batch mean by batch size so the epoch average is per-sample
        batch_loss = loss.item()
        running_loss += batch_loss * xb.size(0)

        # Only every `save_every`-th batch produces a checkpoint
        if (batch_idx + 1) % save_every != 0:
            continue
        ckpt_path = f"{ckpt_dir}/resnet18_epoch{epoch}_batch{batch_idx+1}.pth"
        snapshot = {
            'epoch': epoch,
            'batch_idx': batch_idx,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': batch_loss,
        }
        torch.save(snapshot, ckpt_path)

    epoch_loss = running_loss / len(train_ds)
    val_acc = evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

# Save the fine-tuned weights to Drive and collect the FP32 baseline metrics
final_path = os.path.join(ckpt_dir, "finetuned_resnet18_cifar10_final.pth")
saved_size = save_model_state(model, final_path)

# NOTE: `trainable_params` is rebound here from the optimizer's parameter list
# to an integer count.
total_params, trainable_params = count_parameters(model)

# CPU latency first, then (if a GPU exists) move back to `device` and time again
lat_cpu = latency_ms(model.cpu(), torch.device("cpu"))
if torch.cuda.is_available():
    lat_gpu = latency_ms(model.to(device), device)
else:
    lat_gpu = None

print("\n--- Fine-tuned Metrics ---")
print(f"Total params:        {total_params:,}")
print(f"Trainable params:    {trainable_params:,}")
print(f"Checkpoint size:     {saved_size/1e6:.2f} MB")
print(f"Latency (CPU, ms):   {lat_cpu:.3f}")
if lat_gpu:
    print(f"Latency (GPU, ms): {lat_gpu:.3f}")

#post training quantization of the model

import copy
from torch.ao.quantization import quantize_dynamic
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.qconfig import get_default_qconfig

print("\n=== Running Post-Training Quantization (PTQ) ===\n")

# Dynamic quantisation: only nn.Linear modules are quantized to int8.
# The model is moved to CPU and deep-copied first, so the FP32 weights stay intact.
_fp32_cpu_copy = copy.deepcopy(model.cpu())
dynamic_quantized_model = quantize_dynamic(
    _fp32_cpu_copy,
    qconfig_spec={nn.Linear},
    dtype=torch.qint8,
)

# Save the quantized state dict and read back its on-disk size in one step
dyn_path = os.path.join(ckpt_dir, "resnet18_cifar10_PTQ_dynamic.pth")
dyn_size = save_model_state(dynamic_quantized_model, dyn_path)

dyn_acc = evaluate(dynamic_quantized_model, val_loader, torch.device("cpu"))
dyn_lat = latency_ms(dynamic_quantized_model, torch.device("cpu"))

print(f"\n--- Dynamic PTQ Metrics ---")
print(f"Model size: {dyn_size/1e6:.2f} MB")
print(f"Accuracy:   {dyn_acc*100:.2f}%")
print(f"Latency:    {dyn_lat:.3f} ms\n")

# Static quantization (FX graph mode, fbgemm backend)
from itertools import islice

qconfig = get_default_qconfig("fbgemm")
model_to_quantize = copy.deepcopy(model.cpu())
model_to_quantize.eval()

# prepare_fx takes example_inputs as a tuple of positional arguments
example_inputs = (torch.randn(1, 3, IMG_SIZE, IMG_SIZE),)
prepared_model = prepare_fx(model_to_quantize, {"": qconfig}, example_inputs=example_inputs)

# Calibration loop.
# `tqdm(..., total=50)` alone only sets the progress-bar length; islice is what
# actually stops calibration after `num_calib_batches` batches.
# NOTE(review): calibration reuses val_loader, the same data the accuracy below
# is reported on - consider a held-out slice of the training set instead.
num_calib_batches = 50
print("Calibrating static quantizer...")
with torch.no_grad():
    for xb, _ in tqdm(islice(val_loader, num_calib_batches), total=num_calib_batches):
        prepared_model(xb)

static_quantized_model = convert_fx(prepared_model)

static_path = os.path.join(ckpt_dir, "resnet18_cifar10_PTQ_static.pth")
torch.save(static_quantized_model.state_dict(), static_path)

# The converted INT8 model is evaluated and timed on CPU
static_size = os.path.getsize(static_path)
static_acc = evaluate(static_quantized_model, val_loader, torch.device("cpu"))
static_lat = latency_ms(static_quantized_model, torch.device("cpu"))

print(f"\n--- Static PTQ Metrics ---")
print(f"Model size: {static_size/1e6:.2f} MB")
print(f"Accuracy:   {static_acc*100:.2f}%")
print(f"Latency:    {static_lat:.3f} ms\n")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Train samples: 50000 | Val samples: 10000 | Classes: 10




Epoch 1:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Epoch 1 | Loss: 0.4027 | Val Acc: 92.55%


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():



--- Fine-tuned Metrics ---
Total params:        11,181,642
Trainable params:    10,498,570
Checkpoint size:     44.81 MB
Latency (CPU, ms):   82.538
Latency (GPU, ms): 4.789

=== Running Post-Training Quantization (PTQ) ===



For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  dynamic_quantized_model = quantize_dynamic(


Eval:   0%|          | 0/79 [00:00<?, ?it/s]


--- Dynamic PTQ Metrics ---
Model size: 44.79 MB
Accuracy:   92.54%
Latency:    85.139 ms



For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  prepared_model = prepare_fx(model_to_quantize, {"": qconfig}, example_inputs=example_inputs)
  prepared = prepare(


Calibrating static quantizer...


  0%|          | 0/50 [00:00<?, ?it/s]

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  static_quantized_model = convert_fx(prepared_model)


Eval:   0%|          | 0/79 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():



--- Static PTQ Metrics ---
Model size: 11.31 MB
Accuracy:   91.80%
Latency:    31.596 ms



In [None]:
import os, time
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.models import resnet18, ResNet18_Weights

# Mount Drive (source of the fine-tuned checkpoint, destination of the INT8 model)
from google.colab import drive
drive.mount('/content/drive')


# Dataset Setup (CIFAR-10)
BATCH_SIZE = 128
IMG_SIZE = 224  # every image is resized to 224x224 before it reaches ResNet18
NUM_WORKERS = 8
NUM_CLASSES = 10  # CIFAR-10 has 10 classes

# Pipeline pieces: both splits resize and normalise; only training augments
_resize_steps = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
_augment_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
]
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]

train_transform = transforms.Compose(_resize_steps + _augment_steps + _tensor_steps)
val_transform = transforms.Compose(_resize_steps + _tensor_steps)

# The CIFAR-10 test split (train=False) is used as the validation set
train_ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
val_ds = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transform)


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the worker/prefetch settings shared by both splits."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )


train_loader = _make_loader(train_ds, shuffle=True)
val_loader = _make_loader(val_ds, shuffle=False)

print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)} | Classes: {NUM_CLASSES}")

# Model setup: choose the compute device and load the pretrained ResNet18
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
torch.backends.cudnn.benchmark = True

model = resnet18(weights=ResNet18_Weights.DEFAULT)

#utility functions
def count_parameters(m):
    """Count parameter elements of ``m``; returns ``(total, trainable)``."""
    sizes = [(p.numel(), p.requires_grad) for p in m.parameters()]
    total = sum(n for n, _ in sizes)
    trainable = sum(n for n, needs_grad in sizes if needs_grad)
    return total, trainable

def save_model_state(m, path):
    """Save the state dict of ``m`` at ``path``; return the resulting file size in bytes."""
    torch.save(m.state_dict(), path)
    file_info = os.stat(path)
    return file_info.st_size

@torch.no_grad()
def evaluate(model, loader, device):
    """Return top-1 accuracy (fraction of correct predictions) of ``model`` on ``loader``.

    The model is put in eval mode but not moved; batches are moved to ``device``.
    """
    model.eval()
    hits = 0
    seen = 0
    progress = tqdm(loader, desc="Eval", leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Plain full-precision forward: autocast is deliberately not used here
        # to avoid float16 type issues with the QAT model
        predictions = model(inputs).argmax(1)
        hits += (predictions == targets).sum().item()
        seen += targets.size(0)
    return hits / seen

def latency_ms(model, device, n=50):
    """Return the mean forward latency of ``model`` for one image, in ms.

    The model is moved to ``device`` and put in eval mode (both persist after
    the call). A random ``1 x 3 x IMG_SIZE x IMG_SIZE`` input is run through
    10 warm-up passes and then ``n`` timed passes.

    Args:
        model: module to time.
        device: ``torch.device`` to run on.
        n: number of timed forward passes.

    Returns:
        Average wall-clock time per forward pass in milliseconds.
    """
    model.to(device).eval()
    on_cuda = device.type == "cuda"
    dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)

    # no_grad: inference benchmark, so autograd bookkeeping must not be timed
    with torch.no_grad():
        for _ in range(10):
            model(dummy)
            if on_cuda:
                torch.cuda.synchronize()
        # perf_counter is monotonic and higher resolution than time.time()
        t0 = time.perf_counter()
        for _ in range(n):
            model(dummy)
            if on_cuda:
                # CUDA kernels are asynchronous; wait before reading the clock
                torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
    return 1000.0 * elapsed / n

#qat after loading finetuned model from drive

import copy
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization.qconfig import get_default_qat_qconfig

print("\n=== Running Quantization Aware Training (QAT) on saved FP32 model ===\n")

# Read the checkpoint from the directory the fine-tuning cell saves to and the
# pruning cell loads from. The previous path pointed at the Drive root, which
# nothing in this notebook writes to.
ckpt_dir = "/content/drive/MyDrive/checkpoints"
ckpt_path = os.path.join(ckpt_dir, "finetuned_resnet18_cifar10_final.pth")
print(f"Loading finetuned FP32 model from: {ckpt_path}")

# Architecture only (weights=None); every weight comes from the checkpoint
qat_base_model = resnet18(weights=None)
qat_base_model.fc = nn.Linear(qat_base_model.fc.in_features, NUM_CLASSES)

# The file holds a plain state dict, so weights_only=True is sufficient and
# avoids unpickling arbitrary objects.
state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
qat_base_model.load_state_dict(state)
qat_base_model.eval()

print("Loaded FP32 model successfully.")


# Preparing the model for QAT
qat_qconfig = get_default_qat_qconfig("fbgemm")

example_inputs = (torch.randn(1, 3, IMG_SIZE, IMG_SIZE),)

print("Preparing model for QAT (fake quantization modules)...")
qat_prepared = prepare_qat_fx(
    qat_base_model,
    {"": qat_qconfig},
    example_inputs=example_inputs
)

# Move to device for QAT training
qat_prepared.to(device)
qat_prepared.train()

# Fine-tune the fake-quantized model (quantization-aware training)

num_epochs_qat = 3
qat_lr = 5e-5
_qat_trainable = [p for p in qat_prepared.parameters() if p.requires_grad]
qat_optimizer = torch.optim.Adam(_qat_trainable, lr=qat_lr)
qat_criterion = nn.CrossEntropyLoss()

print(f"Starting QAT for {num_epochs_qat} epochs on {device}...")

for epoch in range(num_epochs_qat):
    # evaluate() below leaves the model in eval mode, so re-enter train mode each epoch
    qat_prepared.train()
    running_loss = 0.0

    qat_bar = tqdm(train_loader, desc=f"QAT Epoch {epoch+1}", leave=False)
    for xb, yb in qat_bar:
        xb = xb.to(device)
        yb = yb.to(device)

        qat_optimizer.zero_grad()
        out = qat_prepared(xb)
        loss = qat_criterion(out, yb)
        loss.backward()
        qat_optimizer.step()

        # Batch-mean loss weighted by batch size, so the epoch average is per-sample
        running_loss += xb.size(0) * loss.item()

    epoch_loss = running_loss / len(train_ds)

    # Accuracy of the still fake-quantized model on the validation split
    qat_prepared.eval()
    with torch.no_grad():
        val_acc_fake = evaluate(qat_prepared, val_loader, device)

    print(f"QAT Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Fake-Quant Val Acc: {val_acc_fake*100:.2f}%")


# Convert the fake-quantized QAT model into a real INT8 model

print("Converting QAT model to real INT8...")
# Module.cpu() returns the module itself, so the move and the conversion chain
qat_int8 = convert_fx(qat_prepared.cpu())

# ================== PTQ (DYNAMIC + STATIC) for Saved VGG16 ==================
# Loads the baseline FP32 model, performs INT8 Dynamic and Static quantization,
# evaluates Accuracy, Latency, and Model Size.

import os, time
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import vgg16_bn, VGG16_BN_Weights

# ---------------------------------------------------
# 1. Environment + Dataset
# ---------------------------------------------------
_gpu_available = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _gpu_available else "cpu")
torch.backends.cudnn.benchmark = True
print("Using device:", DEVICE)

# Input checkpoint (FP32 baseline) and output directory for quantized models
BASELINE_PATH = "/content/vgg_baseline_outputs/vgg_baseline_fp32.pth"

OUT_DIR = "/content/vgg_ptq_outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# NOTE: these rebind the constants of the same name defined earlier in this cell
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_CLASSES = 10
NUM_WORKERS = 2

# Evaluation-only preprocessing: resize, to tensor, per-channel normalisation
_vgg_mean = [0.485,0.456,0.406]
_vgg_std = [0.229,0.224,0.225]
test_tf = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_vgg_mean, _vgg_std),
    ]
)

test_ds = datasets.CIFAR10("./data", train=False, download=True, transform=test_tf)
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

print("Test samples:", len(test_ds))

# ---------------------------------------------------
# 2. Load baseline model
# ---------------------------------------------------
def build_vgg(num_classes=10):
    """Return a pretrained VGG16-BN whose last classifier layer emits ``num_classes`` logits."""
    net = vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)
    old_head = net.classifier[6]
    # Only the final Linear is swapped; the rest of the classifier is kept
    net.classifier[6] = nn.Linear(old_head.in_features, num_classes)
    return net

def load_baseline():
    """Build the VGG16-BN architecture and load the FP32 baseline weights.

    Returns:
        The model on CPU, in eval mode, with the state dict stored at
        ``BASELINE_PATH`` loaded into it.
    """
    model = build_vgg()
    # The file is consumed as a plain state dict, so weights_only=True is
    # sufficient and avoids unpickling arbitrary objects.
    state = torch.load(BASELINE_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    # Freshly built modules default to train mode. The PTQ helpers need
    # inference behaviour (fixed BatchNorm statistics, Dropout disabled).
    model.eval()
    return model

# ---------------------------------------------------
# 3. Evaluation + latency + size functions
# ---------------------------------------------------
@torch.no_grad()
def evaluate(model, device):
    """Return the top-1 accuracy of ``model`` over the module-level ``test_loader``.

    The model is switched to eval mode and moved to ``device`` first.
    """
    model.eval().to(device)
    n_correct = 0
    n_seen = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        predicted = model(images).argmax(1)
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)
    return n_correct / n_seen

def latency_ms(model, device, n=20):
    """Return the mean forward latency of ``model`` for one image, in ms.

    The model is put in eval mode and moved to ``device``. A random
    ``1 x 3 x IMG_SIZE x IMG_SIZE`` input is run through 3 warm-up passes and
    then ``n`` timed passes.

    Args:
        model: module to time.
        device: ``torch.device`` to run on.
        n: number of timed forward passes.

    Returns:
        Average wall-clock time per forward pass in milliseconds.
    """
    model.eval().to(device)
    on_cuda = device.type == "cuda"
    x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)

    # no_grad: inference benchmark, so autograd bookkeeping must not be timed
    with torch.no_grad():
        # warmup
        for _ in range(3):
            model(x)
            if on_cuda:
                torch.cuda.synchronize()

        # perf_counter is monotonic and higher resolution than time.time()
        t0 = time.perf_counter()
        for _ in range(n):
            model(x)
            if on_cuda:
                # CUDA kernels are asynchronous; wait before reading the clock
                torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
    return elapsed / n * 1000.0

def model_size_mb(path):
    """Return the size of the file at ``path`` in MiB (bytes / 1024 / 1024)."""
    bytes_per_mib = 1024 * 1024
    size_bytes = os.path.getsize(path)
    return size_bytes / bytes_per_mib

# ---------------------------------------------------
# 4. Dynamic PTQ (with reduced FC option)
# ---------------------------------------------------
def dynamic_ptq(reduce_fc=True):
    """Dynamically quantize the baseline's ``nn.Linear`` layers to INT8 and measure it.

    Args:
        reduce_fc: when true, the classifier is replaced by a smaller
            ``in_features -> 512 -> NUM_CLASSES`` head before quantization.
            That head is freshly initialised and never trained here, so the
            reported accuracy does NOT reflect the baseline weights; use it
            for size/latency comparison only.

    Returns:
        ``(accuracy, latency_ms, size_mb, save_path)`` for the quantized
        model, with accuracy and latency measured on CPU.
    """
    print("\n[Dynamic PTQ] Running...")

    model = load_baseline()

    # Optionally reduce classifier size to minimize FC overhead
    if reduce_fc:
        print(" - Applying reduced FC: 4096 -> 512 -> 10")
        # Derive the flattened feature width from the existing head instead
        # of hard-coding 25088.
        in_features = model.classifier[0].in_features
        model.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, NUM_CLASSES)
        )
        print(" - WARNING: reduced classifier is randomly initialised (untrained); "
              "accuracy below is not comparable to the baseline")

    # Quantization runs on CPU, in inference mode (Dropout disabled)
    model.cpu().eval()

    # Dynamic Quantization (weights -> int8)
    qmodel = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8
    )

    # Save model
    save_path = os.path.join(OUT_DIR, "vgg16_dynamic_int8.pth")
    torch.save(qmodel.state_dict(), save_path)

    # Metrics are taken on CPU, since the quantized ops might not be available on GPU
    cpu_device = torch.device("cpu")
    acc = evaluate(qmodel, cpu_device)
    lat = latency_ms(qmodel, cpu_device)
    size = model_size_mb(save_path)

    print(f"Dynamic PTQ | Acc={acc*100:.2f}% | Latency={lat:.3f} ms | Size={size:.2f} MB")
    return acc, lat, size, save_path

# ---------------------------------------------------
# 5. Static PTQ (full-graph int8 conversion)
# ---------------------------------------------------
def static_ptq():
    """Statically quantize the baseline VGG16-BN to INT8 (eager mode) and measure it.

    Steps: fuse Conv-BN-ReLU triples, wrap the model with quant/dequant
    stubs, calibrate observers on one test batch, convert to INT8, then save
    and evaluate on CPU.

    Returns:
        ``(accuracy, latency_ms, size_mb, save_path)`` for the quantized model.
    """
    print("\n[Static PTQ] Running...")

    model = load_baseline()
    # Folding BatchNorm into Conv in fuse_modules is only valid in eval mode
    model.cpu().eval()

    # Fuse every Conv2d -> BatchNorm2d -> ReLU run in the feature extractor
    # (previously only features.0-2 were fused).
    feature_layers = list(model.features)
    fuse_groups = [
        [f"features.{i}", f"features.{i + 1}", f"features.{i + 2}"]
        for i in range(len(feature_layers) - 2)
        if isinstance(feature_layers[i], nn.Conv2d)
        and isinstance(feature_layers[i + 1], nn.BatchNorm2d)
        and isinstance(feature_layers[i + 2], nn.ReLU)
    ]
    if fuse_groups:
        fused_model = torch.quantization.fuse_modules(model, fuse_groups, inplace=False)
    else:
        fused_model = model

    # Eager-mode static quantization needs explicit float<->int8 boundaries.
    # QuantWrapper adds a QuantStub before and a DeQuantStub after the model;
    # without them the converted INT8 layers would be fed float tensors.
    fused_model = torch.quantization.QuantWrapper(fused_model)
    fused_model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(fused_model, inplace=True)

    # Calibration with 1 batch from test set
    print(" - Calibrating...")
    xb, _ = next(iter(test_loader))
    with torch.no_grad():
        fused_model(xb)

    qmodel = torch.quantization.convert(fused_model, inplace=False)

    # Save
    save_path = os.path.join(OUT_DIR, "vgg16_static_int8.pth")
    torch.save(qmodel.state_dict(), save_path)

    # Metrics on CPU, matching dynamic_ptq: the INT8 model was built for the
    # fbgemm (CPU) backend and must not be moved to DEVICE when that is CUDA.
    cpu_device = torch.device("cpu")
    acc = evaluate(qmodel, cpu_device)
    lat = latency_ms(qmodel, cpu_device)
    size = model_size_mb(save_path)

    print(f"Static PTQ | Acc={acc*100:.2f}% | Latency={lat:.3f} ms | Size={size:.2f} MB")
    return acc, lat, size, save_path

# ---------------------------------------------------
# 6. Run Everything
# ---------------------------------------------------
# Run both PTQ variants, then print their metrics side by side
_dynamic_result = dynamic_ptq(reduce_fc=True)
dyn_acc, dyn_lat, dyn_size, dyn_path = _dynamic_result

_static_result = static_ptq()
stat_acc, stat_lat, stat_size, stat_path = _static_result

print("\n================ FINAL PTQ RESULTS ================")
print(f"Dynamic PTQ: Acc={dyn_acc*100:.2f}% | Lat={dyn_lat:.2f} ms | Size={dyn_size:.2f} MB")
print(f"Static PTQ:  Acc={stat_acc*100:.2f}% | Lat={stat_lat:.2f} ms | Size={stat_size:.2f} MB")
print("===================================================")

int8_path = "/content/drive/MyDrive/resnet18_cifar10_QAT_int8.pth"
torch.save(qat_int8.state_dict(), int8_path)

print(f"Saved INT8 QAT model \u2192 {int8_path}")

# Metric evaluation for the converted INT8 model (CPU)
#
# The VGG section above rebinds `evaluate` to a (model, device) signature that
# reads `test_loader`, so the ResNet-style call
# `evaluate(qat_int8, val_loader, cpu)` would raise TypeError. Accuracy is
# therefore computed inline on `val_loader`, independent of which `evaluate`
# is currently bound.
_cpu = torch.device("cpu")
qat_int8.eval()

_correct, _total = 0, 0
with torch.no_grad():
    for xb, yb in tqdm(val_loader, desc="Eval", leave=False):
        xb, yb = xb.to(_cpu), yb.to(_cpu)
        preds = qat_int8(xb).argmax(1)
        _correct += (preds == yb).sum().item()
        _total += yb.size(0)
int8_acc = _correct / _total

# Both latency_ms definitions in this cell accept (model, device)
int8_lat = latency_ms(qat_int8, _cpu)
int8_size = os.path.getsize(int8_path)

print("\n--- QAT INT8 Metrics ---")
print(f"Model size: {int8_size/1e6:.2f} MB")
print(f"Accuracy:   {int8_acc*100:.2f}%")
print(f"Latency:    {int8_lat:.3f} ms")

Mounted at /content/drive


100%|██████████| 170M/170M [00:13<00:00, 12.3MB/s]


Train samples: 50000 | Val samples: 10000 | Classes: 10
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 219MB/s]



=== Running Quantization Aware Training (QAT) on saved FP32 model ===

Loading finetuned FP32 model from: /content/drive/MyDrive/finetuned_resnet18_cifar10_final.pth
Loaded FP32 model successfully.
Preparing model for QAT (fake quantization modules)...


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  qat_prepared = prepare_qat_fx(
  prepared = prepare(


Starting QAT for 3 epochs on cuda...


QAT Epoch 1:   0%|          | 0/391 [00:00<?, ?it/s]



Eval:   0%|          | 0/79 [00:00<?, ?it/s]

QAT Epoch 1 | Loss: 0.1723 | Fake-Quant Val Acc: 94.00%


QAT Epoch 2:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

QAT Epoch 2 | Loss: 0.1171 | Fake-Quant Val Acc: 94.57%


QAT Epoch 3:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

QAT Epoch 3 | Loss: 0.0864 | Fake-Quant Val Acc: 94.57%
Converting QAT model to real INT8...


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  qat_int8 = convert_fx(qat_prepared)


Saved INT8 QAT model → /content/drive/MyDrive/resnet18_cifar10_QAT_int8.pth


Eval:   0%|          | 0/79 [00:00<?, ?it/s]


--- QAT INT8 Metrics ---
Model size: 11.31 MB
Accuracy:   94.72%
Latency:    34.752 ms


In [None]:
#pruning of model
import os, time, copy
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.models import resnet18, ResNet18_Weights

# Mount Drive: the fine-tuned checkpoint is read from it and the pruned model saved to it
from google.colab import drive
drive.mount('/content/drive')


BATCH_SIZE = 128
IMG_SIZE = 224
NUM_WORKERS = 8
NUM_CLASSES = 10

# Per-channel mean/std handed to Normalize in both pipelines
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline adds flip / rotation / colour jitter on top of resize + normalise
train_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
val_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# The CIFAR-10 test split (train=False) is used for validation
train_ds = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=train_transform
)
val_ds = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=val_transform
)

# Loader options shared by both splits; only `shuffle` differs
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)} | Classes: {NUM_CLASSES}")

# utility fcns
def count_parameters(m):
    """Return ``(total, trainable)`` parameter element counts for module ``m``."""
    params = list(m.parameters())
    total = sum(map(lambda p: p.numel(), params))
    trainable = sum(p.numel() for p in params if p.requires_grad)
    return total, trainable

def save_model_state(m, path):
    """Persist ``m.state_dict()`` to ``path``; the return value is the file size in bytes."""
    weights = m.state_dict()
    torch.save(weights, path)
    return os.path.getsize(path)

@torch.no_grad()
def evaluate(model, loader, device):
    """Compute the top-1 accuracy of ``model`` over ``loader`` (fraction in [0, 1]).

    The model is set to eval mode but not moved; each batch is moved to ``device``.
    """
    model.eval()
    n_correct = 0
    n_samples = 0
    for batch_x, batch_y in tqdm(loader, desc="Eval", leave=False):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        logits = model(batch_x)
        # Predicted class = index of the largest logit along dim 1
        matches = logits.argmax(1) == batch_y
        n_correct += matches.sum().item()
        n_samples += batch_y.size(0)
    return n_correct / n_samples

def latency_ms(model, device, n=50):
    """Return the mean forward latency of ``model`` for one image, in ms.

    The model is moved to ``device`` and put in eval mode (both persist after
    the call). A random ``1 x 3 x IMG_SIZE x IMG_SIZE`` input is run through
    10 warm-up passes and then ``n`` timed passes.

    Args:
        model: module to time.
        device: ``torch.device`` to run on.
        n: number of timed forward passes.

    Returns:
        Average wall-clock time per forward pass in milliseconds.
    """
    model.to(device).eval()
    on_cuda = device.type == "cuda"
    dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)

    # no_grad: the pruned model's parameters require grad, and autograd
    # bookkeeping must not be part of an inference latency number.
    with torch.no_grad():
        for _ in range(10):
            model(dummy)
            if on_cuda:
                torch.cuda.synchronize()
        # perf_counter is monotonic and higher resolution than time.time()
        t0 = time.perf_counter()
        for _ in range(n):
            model(dummy)
            if on_cuda:
                # CUDA kernels are asynchronous; wait before reading the clock
                torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
    return 1000.0 * elapsed / n

# Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# Architecture only: every weight is overwritten by the checkpoint below, so
# the pretrained ImageNet weights are not downloaded (weights=None).
model = resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# Load finetuned checkpoint
ckpt_dir = "/content/drive/MyDrive/checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_path = os.path.join(ckpt_dir, "finetuned_resnet18_cifar10_final.pth")
# The file is consumed as a plain state dict, so weights_only=True is
# sufficient and avoids unpickling arbitrary objects.
state = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state)
model.to(device)
model.eval()
print("Loaded finetuned model from drive.")

# structured pruning
!pip install -q torch-pruning
import torch_pruning as tp

pruned_model = copy.deepcopy(model)
pruned_model.to(device)
pruned_model.eval()

example_inputs = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
DG = tp.DependencyGraph().build_dependency(pruned_model, example_inputs=example_inputs)

prune_ratio = 0.3  # Fraction of channels to remove
prunable_layers = [m for m in pruned_model.modules() if isinstance(m, nn.Conv2d)]

for layer in prunable_layers:
    n_remove = int(prune_ratio * layer.out_channels)
    if n_remove <= 0:
        continue
    weight = layer.weight.data.abs().view(layer.out_channels, -1).sum(1)
    prune_idx = torch.argsort(weight)[:n_remove].cpu().tolist()
    # Updated API for torch-pruning
    group = DG.get_pruning_group(layer, tp.prune_conv_out_channels, prune_idx)
    if group is not None:
        group.prune()

total_params, trainable_params = count_parameters(pruned_model)
print(f"Pruned model params: total={total_params:,}, trainable={trainable_params:,}")

# Fine-tune the pruned model to recover accuracy
optimizer_pruned = torch.optim.Adam([p for p in pruned_model.parameters() if p.requires_grad], lr=1e-4)
criterion = nn.CrossEntropyLoss()
num_ft_epochs = 3

for epoch in range(num_ft_epochs):
    # evaluate() at the end of each epoch switches the model to eval mode, so
    # train mode has to be restored every epoch. Calling train() once before
    # the loop left epochs 2+ training with BatchNorm in inference mode.
    pruned_model.train()
    running_loss = 0.0
    for xb, yb in tqdm(train_loader, desc=f"Pruned FT Epoch {epoch+1}", leave=False):
        xb, yb = xb.to(device), yb.to(device)
        optimizer_pruned.zero_grad()
        out = pruned_model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer_pruned.step()
        # Batch-mean loss weighted by batch size, so the epoch average is per-sample
        running_loss += loss.item() * xb.size(0)
    epoch_loss = running_loss / len(train_ds)
    val_acc = evaluate(pruned_model, val_loader, device)
    print(f"Pruned FT Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

# Save the pruned model's state dict to Drive
pruned_path = os.path.join(ckpt_dir, "resnet18_cifar10_pruned.pth")
saved_size = save_model_state(pruned_model, pruned_path)
print(f"Saved pruned model: {pruned_path}, size={saved_size/1e6:.2f} MB")

# Pruned model metrics: CPU latency, then GPU latency (if available), then accuracy
pruned_model.eval()
lat_cpu = latency_ms(pruned_model.cpu(), torch.device("cpu"))
if torch.cuda.is_available():
    lat_gpu = latency_ms(pruned_model.to(device), device)
else:
    lat_gpu = None
pruned_acc = evaluate(pruned_model, val_loader, device)

print("\n--- Pruned Model Metrics ---")
print(f"Total params:        {total_params:,}")
print(f"Trainable params:    {trainable_params:,}")
print(f"Model Size:          {saved_size/1e6:.2f} MB")
print(f"Accuracy:            {pruned_acc*100:.2f}%")
print(f"Latency (CPU, ms):   {lat_cpu:.3f}")
if lat_gpu:
    print(f"Latency (GPU, ms): {lat_gpu:.3f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).




Train samples: 50000 | Val samples: 10000 | Classes: 10
Loaded finetuned model from drive.
Pruned model params: total=2,704,554, trainable=2,704,554


Pruned FT Epoch 1:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

Pruned FT Epoch 1 | Loss: 1.0158 | Val Acc: 78.54%


Pruned FT Epoch 2:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

Pruned FT Epoch 2 | Loss: 0.6271 | Val Acc: 80.70%


Pruned FT Epoch 3:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

Pruned FT Epoch 3 | Loss: 0.4838 | Val Acc: 84.42%
Saved pruned model: /content/drive/MyDrive/checkpoints/resnet18_cifar10_pruned.pth, size=10.88 MB


Eval:   0%|          | 0/79 [00:00<?, ?it/s]


--- Pruned Model Metrics ---
Total params:        2,704,554
Trainable params:    2,704,554
Model Size:          10.88 MB
Accuracy:            84.42%
Latency (CPU, ms):   21.206
Latency (GPU, ms): 2.543


In [None]:
#pruning + qat of the finetuned model
import os, time, copy
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.models import resnet18, ResNet18_Weights

# Mount Drive for checkpoints
from google.colab import drive
drive.mount('/content/drive')

# Dataset configuration (CIFAR-10, resized to IMG_SIZE x IMG_SIZE)
BATCH_SIZE = 128
IMG_SIZE = 224
NUM_WORKERS = 2
NUM_CLASSES = 10

# Values shared by the training and validation pipelines.
_TARGET_HW = (IMG_SIZE, IMG_SIZE)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation, then tensor conversion + normalisation.
_train_steps = [
    transforms.Resize(_TARGET_HW),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation pipeline: same resize and normalisation, no augmentation.
_val_steps = [
    transforms.Resize(_TARGET_HW),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
val_transform = transforms.Compose(_val_steps)

_cifar_opts = dict(root="./data", download=True)
train_ds = datasets.CIFAR10(train=True, transform=train_transform, **_cifar_opts)
val_ds = datasets.CIFAR10(train=False, transform=val_transform, **_cifar_opts)

# Loader options common to both splits; only shuffling differs.
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=4,
    persistent_workers=True,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_opts)

print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)} | Classes: {NUM_CLASSES}")

# Metric helper functions
def count_parameters(m):
    """Return ``(total, trainable)`` element counts over ``m.parameters()``.

    ``total`` adds up ``numel()`` of every parameter; ``trainable`` adds up only
    the parameters whose ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in m.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

def save_model_state(m, path):
    """Save ``m.state_dict()`` to ``path`` with ``torch.save``; return the file size in bytes."""
    weights = m.state_dict()
    torch.save(weights, path)
    return os.stat(path).st_size

@torch.no_grad()
def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as ``correct / total``.

    The model is switched to eval mode and is NOT switched back afterwards, so
    callers that continue training must call ``model.train()`` themselves.
    Each batch is moved to ``device`` before the forward pass.
    """
    model.eval()
    hits = 0
    seen = 0
    progress = tqdm(loader, desc="Eval", leave=False)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Predicted class = index of the largest value along dim 1
        predicted = torch.argmax(model(inputs), dim=1)
        hits += predicted.eq(targets).sum().item()
        seen += targets.size(0)
    return hits / seen

def latency_ms(model, device, n=50, input_shape=None):
    """Measure the mean forward-pass latency of ``model`` on ``device``.

    Args:
        model: Module to benchmark. It is moved to ``device`` and put in eval mode.
        device: ``torch.device`` on which the forward passes run.
        n: Number of timed forward passes (10 untimed warm-up passes run first).
        input_shape: Shape of the random input tensor. When ``None`` (default)
            a single image of shape ``(1, 3, IMG_SIZE, IMG_SIZE)`` is used.

    Returns:
        Average wall-clock time per forward pass, in milliseconds.
    """
    if input_shape is None:
        input_shape = (1, 3, IMG_SIZE, IMG_SIZE)
    model.to(device).eval()
    dummy = torch.randn(*input_shape, device=device)
    on_cuda = device.type == "cuda"

    # Inference benchmark: disable autograd so graph bookkeeping is not part of
    # the measured time.
    with torch.no_grad():
        # Warm-up passes, excluded from the timing.
        for _ in range(10):
            _ = model(dummy)
            if on_cuda:
                torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(n):
            _ = model(dummy)
            # Wait for the GPU work to finish so it is included in the timing.
            if on_cuda:
                torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / n


# Use the GPU when one is available; enable cuDNN benchmark mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# Rebuild the architecture without pretrained weights and replace the classifier
# head with a NUM_CLASSES-way linear layer so the checkpoint's keys/shapes match.
model = resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# load finetuned model from checkpoint
ckpt_dir = "/content/drive/MyDrive/checkpoints"
ckpt_path = os.path.join(ckpt_dir, "finetuned_resnet18_cifar10_final.pth")
# The checkpoint is passed straight to load_state_dict, i.e. it is a plain
# state_dict; weights_only=True keeps torch.load from unpickling arbitrary objects.
state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.to(device)
model.eval()
print("Loaded finetuned FP32 model.")

# Structured pruning of Conv2d output channels using the torch-pruning package
!pip install -q torch-pruning
import torch_pruning as tp

# Prune a deep copy so that `model` keeps the original fine-tuned weights
pruned_model = copy.deepcopy(model)
pruned_model.to(device)
pruned_model.eval()

# Dummy input passed to build_dependency together with the model
example_inputs_pruning = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
DG = tp.DependencyGraph().build_dependency(pruned_model, example_inputs=example_inputs_pruning)

prune_ratio = 0.3  # fraction of a layer's current output channels to remove
# List of every Conv2d module, collected once before any pruning is applied
prunable_layers = [m for m in pruned_model.modules() if isinstance(m, nn.Conv2d)]

for layer in prunable_layers:
    # out_channels is read at iteration time, i.e. after any earlier prune() calls
    n_remove = int(prune_ratio * layer.out_channels)
    if n_remove <= 0:
        continue
    # Importance score per output channel: sum of absolute weights of that filter
    weight = layer.weight.data.abs().view(layer.out_channels, -1).sum(1)
    # Indices of the n_remove lowest-scoring channels (argsort is ascending)
    prune_idx = torch.argsort(weight)[:n_remove].cpu().tolist()
    group = DG.get_pruning_group(layer, tp.prune_conv_out_channels, prune_idx)
    if group is not None:
        group.prune()
    # NOTE(review): presumably a group also prunes other layers coupled to `layer`,
    # so a conv that shares channels with an earlier one may be pruned again on its
    # own iteration — confirm against torch-pruning's group semantics.

# Parameter counts of the pruned model; reused by the metrics report at the end
total_params, trainable_params = count_parameters(pruned_model)
print(f"After pruning: total params={total_params:,}, trainable={trainable_params:,}")

# Fine-tune the pruned model on the training set
pruned_model.train()
optimizer = torch.optim.Adam([p for p in pruned_model.parameters() if p.requires_grad], lr=1e-4)
criterion = nn.CrossEntropyLoss()
num_ft_epochs = 3

for epoch in range(num_ft_epochs):
    # evaluate() at the end of the previous epoch left the model in eval mode;
    # switch back so every epoch (not just the first) trains in train mode.
    pruned_model.train()
    running_loss = 0.0
    for xb, yb in tqdm(train_loader, desc=f"Pruned FT Epoch {epoch+1}", leave=False):
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        out = pruned_model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch average is per-sample
        running_loss += loss.item() * xb.size(0)
    epoch_loss = running_loss / len(train_ds)
    val_acc = evaluate(pruned_model, val_loader, device)
    print(f"Pruned FT Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Val Acc: {val_acc*100:.2f}%")


from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization.qconfig import get_default_qat_qconfig

# Default QAT qconfig for the "fbgemm" backend, applied to the whole model ("" key)
qat_qconfig = get_default_qat_qconfig("fbgemm")
example_inputs = (torch.randn(1, 3, IMG_SIZE, IMG_SIZE),)

# Prepare pruned model for QAT (on a copy, so pruned_model stays FP32)
qat_model = copy.deepcopy(pruned_model)
qat_prepared = prepare_qat_fx(qat_model, {"": qat_qconfig}, example_inputs)
qat_prepared.to(device)
qat_prepared.train()

# Fine-tune the QAT-prepared model
qat_epochs = 3
qat_lr = 5e-5
qat_optimizer = torch.optim.Adam([p for p in qat_prepared.parameters() if p.requires_grad], lr=qat_lr)
qat_criterion = nn.CrossEntropyLoss()

for epoch in range(qat_epochs):
    # evaluate() at the end of the previous epoch left the model in eval mode;
    # switch back so every epoch (not just the first) trains in train mode.
    qat_prepared.train()
    running_loss = 0.0
    for xb, yb in tqdm(train_loader, desc=f"QAT Epoch {epoch+1}", leave=False):
        xb, yb = xb.to(device), yb.to(device)
        qat_optimizer.zero_grad()
        out = qat_prepared(xb)
        loss = qat_criterion(out, yb)
        loss.backward()
        qat_optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch average is per-sample
        running_loss += loss.item() * xb.size(0)
    epoch_loss = running_loss / len(train_ds)
    val_acc = evaluate(qat_prepared, val_loader, device)
    print(f"QAT Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Fake-Quant Val Acc: {val_acc*100:.2f}%")


# Put the QAT model in eval mode on the CPU, then convert it with convert_fx.
qat_int8 = convert_fx(qat_prepared.eval().cpu())

# Persist the converted model's state_dict next to the other checkpoints.
int8_path = os.path.join(ckpt_dir, "resnet18_cifar10_pruned_QAT_int8.pth")
int8_state = qat_int8.state_dict()
torch.save(int8_state, int8_path)
print(f"Saved INT8 QAT model → {int8_path}")

# Validation loader for CPU evaluation: loads in the main process, no pinned memory.
_cpu_loader_opts = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)
val_loader_cpu = DataLoader(val_ds, **_cpu_loader_opts)

@torch.no_grad()
def evaluate_cpu(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader`` as ``correct / total``.

    Batches are fed to the model as yielded by ``loader`` (no device transfer).
    The model is switched to eval mode and left there.
    """
    model.eval()
    hits = 0
    seen = 0
    for inputs, targets in loader:
        # Predicted class = index of the largest value along dim 1
        predicted = torch.argmax(model(inputs), dim=1)
        hits += predicted.eq(targets).sum().item()
        seen += targets.size(0)
    return hits / seen


def latency_cpu(model, n=50, input_shape=None):
    """Measure the mean forward-pass latency of ``model`` on a CPU tensor.

    Args:
        model: Module to benchmark. It is put in eval mode; it is not moved
            to another device.
        n: Number of timed forward passes (10 untimed warm-up passes run first).
        input_shape: Shape of the random input tensor. When ``None`` (default)
            a single image of shape ``(1, 3, IMG_SIZE, IMG_SIZE)`` is used.

    Returns:
        Average wall-clock time per forward pass, in milliseconds.
    """
    if input_shape is None:
        input_shape = (1, 3, IMG_SIZE, IMG_SIZE)
    model.eval()
    dummy = torch.randn(*input_shape)

    # Inference benchmark: disable autograd so graph bookkeeping is not part of
    # the measured time.
    with torch.no_grad():
        # Warm-up passes, excluded from the timing.
        for _ in range(10):
            _ = model(dummy)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(n):
            _ = model(dummy)
        elapsed = time.perf_counter() - start
    return 1000 * elapsed / n


# Metrics for the converted INT8 model: accuracy and latency via the CPU helpers,
# size from the checkpoint file written above.
int8_acc = evaluate_cpu(qat_int8, val_loader_cpu)
int8_size = os.path.getsize(int8_path)
lat_cpu = latency_cpu(qat_int8)

# total_params / trainable_params are the counts taken right after pruning.
for report_line in (
    "\n--- Pruned + QAT INT8 Metrics ---",
    f"Total params:        {total_params:,}",
    f"Trainable params:    {trainable_params:,}",
    f"Model Size:          {int8_size/1e6:.2f} MB",
    f"Accuracy:            {int8_acc*100:.2f}%",
    f"Latency (CPU, ms):   {lat_cpu:.3f}",
):
    print(report_line)




Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Train samples: 50000 | Val samples: 10000 | Classes: 10
Loaded finetuned FP32 model.
After pruning: total params=2,704,554, trainable=2,704,554


Pruned FT Epoch 1:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

Pruned FT Epoch 1 | Loss: 1.0269 | Val Acc: 77.88%


Pruned FT Epoch 2:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

Pruned FT Epoch 2 | Loss: 0.6377 | Val Acc: 77.89%


Pruned FT Epoch 3:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

Pruned FT Epoch 3 | Loss: 0.4950 | Val Acc: 83.59%


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  qat_prepared = prepare_qat_fx(qat_model, {"": qat_qconfig}, example_inputs)
  prepared = prepare(


QAT Epoch 1:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

QAT Epoch 1 | Loss: 0.4568 | Fake-Quant Val Acc: 86.42%


QAT Epoch 2:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

QAT Epoch 2 | Loss: 0.3829 | Fake-Quant Val Acc: 86.48%


QAT Epoch 3:   0%|          | 0/391 [00:00<?, ?it/s]

Eval:   0%|          | 0/79 [00:00<?, ?it/s]

QAT Epoch 3 | Loss: 0.3271 | Fake-Quant Val Acc: 87.21%


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  qat_int8 = convert_fx(qat_prepared)


Saved INT8 QAT model → /content/drive/MyDrive/checkpoints/resnet18_cifar10_pruned_QAT_int8.pth

--- Pruned + QAT INT8 Metrics ---
Total params:        2,704,554
Trainable params:    2,704,554
Model Size:          2.79 MB
Accuracy:            87.26%
Latency (CPU, ms):   13.560

Note: Quantized INT8 models DO NOT run on GPU — CPU only.
