In [6]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
import utils

# Data preprocessing
- Resize the data
- Implement augmentation

In [2]:
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import lightning as L

class TinyImageNetDataModule(L.LightningDataModule):
    def __init__(self, data_dir='tiny-imagenet-200/train', test_dir='tiny-imagenet-200/val/processed_val', batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.test_dir = test_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Data augmentation
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                [0.229, 0.224, 0.225])
        ])

        self.val_test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])

        full_dataset = datasets.ImageFolder(self.data_dir, transform=self.train_transform)
        train_size = int(0.7 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        self.train_ds, self.val_ds = random_split(full_dataset, [train_size, val_size])
        self.val_ds.dataset.transform = self.val_test_transform

        self.test_ds = datasets.ImageFolder(self.test_dir, transform=self.val_test_transform)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size, shuffle=False, num_workers=4)


# Baseline Model

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import torchvision.models as models


class ResNet50Baseline(L.LightningModule):
    def __init__(self, num_classes=200, lr=1e-4):
        super().__init__()
        self.save_hyperparameters()
        
        # Load pretrained ResNet50
        self.model = models.resnet50(pretrained=True)
        self.model.fc = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(self.model.fc.in_features, num_classes)
        )
        
        # Freeze all layers first
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze only the last 3 layers (starting from the end)
        unfreeze_count = 3
        trainable_layers = list(self.model.named_parameters())[-unfreeze_count:]
        for name, param in trainable_layers:
            param.requires_grad = True
            print(f"Unfroze layer: {name}")

        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.hparams.lr,
            weight_decay=1e-4  # L2 regularization
        )
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = self.train_acc(outputs, labels)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = self.val_acc(outputs, labels)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', acc, on_step=False, on_epoch=True)


In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    mode='max',
    save_top_k=1,
    dirpath='checkpoints/',
    filename='resnet50-{epoch:02d}-{val_acc:.2f}'
)

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min',
    verbose=True
)

logger = TensorBoardLogger("tb_logs", name="resnet50Baseline")

trainer = L.Trainer(
    max_epochs=7,
    accelerator="mps",  # Use GPU on Mac
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=logger
)

datamodule = TinyImageNetDataModule()
model = ResNet50Baseline()
trainer.fit(model, datamodule=datamodule)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Unfroze layer: layer4.2.bn3.bias
Unfroze layer: fc.1.weight
Unfroze layer: fc.1.bias


/Users/lynnning/miniforge3/envs/resnet_pruning/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/lynnning/Desktop/group project/checkpoints exists and is not empty.

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | ResNet             | 23.9 M | train
1 | criterion | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
---------------------------------------------------------
411 K     Trainable params
23.5 M    Non-trainable params
23.9 M    Total params
95.671    Total estimated model params size (MB)
156       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/lynnning/miniforge3/envs/resnet_pruning/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

/Users/lynnning/miniforge3/envs/resnet_pruning/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0: 100%|██████████| 2188/2188 [04:57<00:00,  7.35it/s, v_num=10]

Metric val_loss improved. New best score: 2.706


Epoch 1: 100%|██████████| 2188/2188 [05:35<00:00,  6.52it/s, v_num=10]

Metric val_loss improved by 0.694 >= min_delta = 0.0. New best score: 2.013


Epoch 2: 100%|██████████| 2188/2188 [05:35<00:00,  6.53it/s, v_num=10]

Metric val_loss improved by 0.238 >= min_delta = 0.0. New best score: 1.774


Epoch 3: 100%|██████████| 2188/2188 [05:37<00:00,  6.48it/s, v_num=10]

Metric val_loss improved by 0.100 >= min_delta = 0.0. New best score: 1.675


Epoch 4: 100%|██████████| 2188/2188 [05:38<00:00,  6.46it/s, v_num=10]

Metric val_loss improved by 0.071 >= min_delta = 0.0. New best score: 1.604


Epoch 5: 100%|██████████| 2188/2188 [05:40<00:00,  6.42it/s, v_num=10]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 1.590


Epoch 6: 100%|██████████| 2188/2188 [05:35<00:00,  6.51it/s, v_num=10]

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 1.570
`Trainer.fit` stopped: `max_epochs=7` reached.


Epoch 6: 100%|██████████| 2188/2188 [05:36<00:00,  6.51it/s, v_num=10]


In [None]:

ckpt_path = "checkpoints/resnet50-epoch=06-val_acc=0.62.ckpt"

# Continue training
trainer = L.Trainer(
    max_epochs=12,  # New total number of epochs
    accelerator="mps",
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=logger
)

model = ResNet50Baseline.load_from_checkpoint(ckpt_path)
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)


In [12]:
torch.save(model.state_dict(), 'resnet50_baseline_weights.pth')


In [10]:
import torch
import torch.nn.functional as F

def evaluate_metrics(model, dataloader, device):
    model.eval()
    top1_correct = 0
    top5_correct = 0
    total = 0
    total_loss = 0
    criterion = torch.nn.CrossEntropyLoss()

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)

            # Top-1
            _, predicted = torch.max(outputs, 1)
            top1_correct += (predicted == labels).sum().item()

            # Top-5
            top5_preds = torch.topk(outputs, k=5, dim=1).indices
            top5_correct += sum([labels[i] in top5_preds[i] for i in range(labels.size(0))])
            
            total += labels.size(0)

    top1_acc = top1_correct / total
    top5_acc = top5_correct / total
    avg_loss = total_loss / total
    return avg_loss, top1_acc, top5_acc

import time

def measure_inference_time(model, dataloader, device):
    model.eval()
    images, labels = next(iter(dataloader))
    images = images.to(device)

    start_time = time.time()
    with torch.no_grad():
        _ = model(images)
    end_time = time.time()

    inference_time = end_time - start_time
    return inference_time, images.size(0)

import os

model_path = 'resnet50_baseline_weights.pth'
torch.save(model.state_dict(), model_path)
model_size_MB = os.path.getsize(model_path) / 1e6




In [13]:
import json

test_dataloader = datamodule.test_dataloader()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(device)

val_loss, top1_acc, top5_acc = evaluate_metrics(model, test_dataloader, device)

# Measure inference time
inference_time, batch_size = measure_inference_time(model, test_dataloader, device)

baseline_metrics = {
    "categorical_crossentropy": val_loss,
    "top1_accuracy": top1_acc,
    "top5_accuracy": top5_acc,
    "model_size_MB": model_size_MB,
    "inference_time_per_batch_sec": inference_time,
    "batch_size": batch_size
}
print(baseline_metrics)
print(json.dumps(baseline_metrics, indent=4))

with open('baseline_resnet50_metrics.json', 'w') as f:
    json.dump(baseline_metrics, f, indent=4)


{'categorical_crossentropy': 1.6259361061096191, 'top1_accuracy': 0.6243, 'top5_accuracy': 0.8429, 'model_size_MB': 95.993338, 'inference_time_per_batch_sec': 0.01288914680480957, 'batch_size': 32}
{
    "categorical_crossentropy": 1.6259361061096191,
    "top1_accuracy": 0.6243,
    "top5_accuracy": 0.8429,
    "model_size_MB": 95.993338,
    "inference_time_per_batch_sec": 0.01288914680480957,
    "batch_size": 32
}


In [None]:
model = model.to(device)
model.eval()


ResNet50Baseline(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
   

# Structured Pruning

In [None]:
import torch
import torch_pruning as tp
from torchvision import models

ckpt_path = "checkpoints/resnet50-epoch=07-val_acc=0.65.ckpt"
lightning_model = ResNet50Baseline.load_from_checkpoint(ckpt_path)
lightning_model.eval()
lightning_model.to("mps")

resnet = lightning_model.model

for name, param in resnet.named_parameters():
    if name.startswith("conv1") or name.startswith("bn1") or \
       name.startswith("layer1") or name.startswith("layer2"):
        param.requires_grad = False
    else:
        param.requires_grad = True

example_inputs = torch.randn(1, 3, 224, 224).to("mps")

importance = tp.importance.GroupMagnitudeImportance(p=2)

ignored_layers = [m for m in resnet.modules() if isinstance(m, torch.nn.Linear)]
for name, m in resnet.named_modules():
    if name.startswith("conv1") or name.startswith("bn1") or \
       name.startswith("layer1") or name.startswith("layer2"):
        ignored_layers.append(m)

pruner = tp.pruner.BasePruner(
    model=resnet,
    example_inputs=example_inputs,
    importance=importance,
    pruning_ratio=0.3,
    ignored_layers=ignored_layers,
    round_to=8
)

print("Before pruning:")
tp.utils.print_tool.before_pruning(resnet)
base_macs, base_params = tp.utils.count_ops_and_params(resnet, example_inputs)

pruner.step()

print("After pruning:")
tp.utils.print_tool.after_pruning(resnet)
macs, params = tp.utils.count_ops_and_params(resnet, example_inputs)
print(f"MACs: {base_macs/1e9:.2f}G → {macs/1e9:.2f}G")
print(f"#Params: {base_params/1e6:.2f}M → {params/1e6:.2f}M")

torch.save(resnet.state_dict(), "resnet50_structured_pruned_weights.pth")
torch.save(resnet, "resnet50_structured_pruned_model.pth")

print("Pruned model and weights saved successfully.")




Unfroze layer: layer4.2.bn3.bias
Unfroze layer: fc.1.weight
Unfroze layer: fc.1.bias
Before pruning:
After pruning:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (re

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import lightning as L

class StructuredPrunedResNet50FineTune(L.LightningModule):
    def __init__(self, pruned_model, num_classes=200, lr=1e-4):
        super().__init__()
        self.save_hyperparameters(ignore=['pruned_model'])  # save lr and num_classes only
        self.model = pruned_model

        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.criterion(out, y)
        acc = self.train_acc(out, y)
        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc", acc, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.criterion(out, y)
        acc = self.val_acc(out, y)
        self.log("val_loss", loss, on_epoch=True)
        self.log("val_acc", acc, on_epoch=True)


In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from torchvision.models import resnet

# Allowlist ResNet so torch.load can deserialize it
torch.serialization.add_safe_globals({"ResNet": resnet.ResNet})

# Now load the full model
model = torch.load("resnet50_structured_pruned_model.pth", weights_only=False)
model.to("mps")
model.eval()


# Wrap in Lightning
fine_tune_model = StructuredPrunedResNet50FineTune(pruned_model=model)

# Callbacks
checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    dirpath="checkpoints_pruned/",
    filename="resnet50_pruned-{epoch:02d}-{val_acc:.2f}"
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=0,
    mode="min",
    verbose=True
)

logger = TensorBoardLogger("tb_logs", name="resnet50StructuredPruned")

# Trainer
trainer = L.Trainer(
    max_epochs=7,
    accelerator="mps",
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=logger
)

datamodule = TinyImageNetDataModule()
trainer.fit(fine_tune_model, datamodule=datamodule)


In [14]:
ckpt_path = "checkpoints_pruned/resnet50_pruned-epoch=05-val_acc=0.63.ckpt"
trainer = L.Trainer(
    max_epochs=7,
    accelerator="mps",
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=logger
)
pruned_model = torch.load("resnet50_structured_pruned_model.pth", weights_only=False)
model = StructuredPrunedResNet50FineTune.load_from_checkpoint(ckpt_path, pruned_model=pruned_model)
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)



GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/lynnning/miniforge3/envs/resnet_pruning/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/lynnning/Desktop/group project/checkpoints_pruned exists and is not empty.
Restoring states from the checkpoint path at checkpoints_pruned/resnet50_pruned-epoch=05-val_acc=0.63.ckpt

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | ResNet             | 12.4 M | eval 
1 | criterion | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
---------------------------------------------------------
11.0 M    Trainable params
1.4 M     Non-trainable params
12.4 M    Total params
49.613    Total estimated model params size (MB)
3         Modules in train mode
153     

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/lynnning/miniforge3/envs/resnet_pruning/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

/Users/lynnning/miniforge3/envs/resnet_pruning/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 6: 100%|██████████| 2188/2188 [06:07<00:00,  5.96it/s, v_num=0]

Metric val_loss improved by 0.651 >= min_delta = 0.0. New best score: 0.894
`Trainer.fit` stopped: `max_epochs=7` reached.


Epoch 6: 100%|██████████| 2188/2188 [06:07<00:00,  5.96it/s, v_num=0]


In [None]:
# Save the fine-tuned model and weights
torch.save(model.model, "resnet50_structured_pruned_finetuned_model.pth")
torch.save(model.state_dict(), "resnet50_structured_pruned_finetuned_weights.pth")


In [23]:
import os
import time
import json
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Load pruned model architecture
model = torch.load("resnet50_structured_pruned_finetuned_model.pth", weights_only=False)
model.eval()

# Data
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

test_dataset = datasets.ImageFolder("tiny-imagenet-200/val/processed_val", transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

criterion = nn.CrossEntropyLoss()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

# Warmup for timing
batch_images, _ = next(iter(test_loader))
batch_images = batch_images.to(device)
_ = model(batch_images)  # warmup
start = time.time()
_ = model(batch_images)
end = time.time()
inference_time = end - start

# Evaluation loop
top1_correct = 0
top5_correct = 0
total = 0
total_loss = 0.0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        total_loss += loss.item() * images.size(0)

        _, top1_preds = torch.max(outputs, 1)
        top1_correct += (top1_preds == labels).sum().item()

        top5_preds = torch.topk(outputs, 5, dim=1).indices
        for i in range(labels.size(0)):
            if labels[i] in top5_preds[i]:
                top5_correct += 1

        total += labels.size(0)

# Metrics
top1_acc = top1_correct / total
top5_acc = top5_correct / total
avg_loss = total_loss / total
model_size_MB = os.path.getsize("resnet50_structured_pruned_finetuned_weights.pth") / 1e6

metrics = {
    "categorical_crossentropy": avg_loss,
    "top1_accuracy": top1_acc,
    "top5_accuracy": top5_acc,
    "model_size_MB": model_size_MB,
    "inference_time_per_batch_sec": inference_time,
    "batch_size": 32
}

print(json.dumps(metrics, indent=4))

with open("resnet50_structured_pruned_metrics.json", "w") as f:
    json.dump(metrics, f, indent=4)


{
    "categorical_crossentropy": 1.576888065624237,
    "top1_accuracy": 0.6196,
    "top5_accuracy": 0.847,
    "model_size_MB": 49.889478,
    "inference_time_per_batch_sec": 0.03519296646118164,
    "batch_size": 32
}


In [20]:
test_path = Path('/Users/lakshya/quantization/tiny-imagenet-200/val/processed_val_2')
transformations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
test_dataset = datasets.ImageFolder(root=test_path, transform=transformations)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [21]:
from pathlib import Path
model_path = Path("./saved_models/resnet")

struc_prun_model = torch.load(model_path / "gradient_30%_model.pth", weights_only=False)
struc_prun_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [23]:
struc_prun_metrics = utils.evaluate_model(struc_prun_model, test_loader, 'resnet50')

100%|███████████████| 157/157 [25:41<00:00,  9.82s/it, Loss=1.5018, Top1=63.29%]


In [26]:
grad_prun_metrics = struc_prun_metrics
grad_prun_metrics

{'top1_acc': 0.6329,
 'top5_acc': 0.8474,
 'total_inference_time': 1541.7002170085907,
 'average_inference_time': 9.819746605150259,
 'average_loss': 95.5161116350988,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [27]:
struc_prun_model = torch.load(model_path / "resnet50_structured_pruned_finetuned_model.pth", 
                              weights_only=False)
struc_prun_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [28]:
struc_prun_metrics = utils.evaluate_model(struc_prun_model, test_loader, 'resnet50')

100%|███████████████| 157/157 [17:48<00:00,  6.81s/it, Loss=1.5769, Top1=61.96%]


In [29]:
struc_prun_metrics

{'top1_acc': 0.6196,
 'top5_acc': 0.847,
 'total_inference_time': 1068.744558095932,
 'average_inference_time': 6.807290178954981,
 'average_loss': 100.32848511228136,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [31]:
import json
metrics_folder = Path("./model_metrics/resnet50")
metrics_folder.mkdir(parents=True, exist_ok=True)
with (metrics_folder / "baseline_grad_prun_metrics.json").open("w") as file:
    json.dump(grad_prun_metrics, file, indent=1)
with (metrics_folder / "baseline_struct_prun_metrics.json").open("w") as file:
    json.dump(struc_prun_metrics, file, indent=1)

# Further Pruning: Gradient Pruning
Use SparseML

In [None]:
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import lightning as L
import torch

from sparseml.pytorch.optim import ScheduledModifierManager

class StructuredThenSparseResNet50(L.LightningModule):
    def __init__(self, pruned_model_path, lr=1e-4, recipe_path=None):
        super().__init__()
        self.save_hyperparameters()

        self.model = torch.load(pruned_model_path, map_location="cpu",weights_only=False)
        self.model.eval()

        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=200)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=200)

        self.recipe_path = recipe_path
        self.manager = None

        if recipe_path:
            self.manager = ScheduledModifierManager.from_yaml(recipe_path)
            self.manager.initialize(self.model, optimizer=None)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        if self.manager:
            self.manager.initialize(self.model, optimizer)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = self.train_acc(outputs, labels)

        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        acc = self.val_acc(outputs, labels)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)


In [56]:
sparse_model = StructuredThenSparseResNet50(
    pruned_model_path="resnet50_structured_pruned_finetuned_model.pth",
    recipe_path="sparseml_unstructured_pruning_recipe.yaml"
)


In [None]:
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor='val_acc',
    mode='max',
    patience=0,  # stop immediately when no improvement
    verbose=True
)

logger = TensorBoardLogger("tb_logs", name="resnet50StructuredThenSparse")

checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    mode='max',
    save_top_k=1,
    dirpath='checkpoints/',
    filename='resnet50-structured-sparseml-{epoch:02d}-{val_acc:.2f}'
)

trainer = L.Trainer(
    max_epochs=5,
    accelerator="mps",
    callbacks=[checkpoint_callback, early_stopping],
    logger=logger
)

trainer.fit(sparse_model, datamodule=datamodule)


In [46]:
sparse_model.manager.finalize(sparse_model.model)


In [59]:
torch.save(sparse_model.model, "resnet50_structured_pruned_SparseML40%_finalized_model.pth")
torch.save(sparse_model.model.state_dict(), "resnet50_structured_pruned_SparseML40%_finalized_weights.pth")

In [61]:
import os
import time
import json
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np



# 2. Load the finalized weights
state_dict = torch.load("resnet50_structured_pruned_SparseML40%_finalized_weights.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()



test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

test_dataset = datasets.ImageFolder("tiny-imagenet-200/val/processed_val", transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

criterion = nn.CrossEntropyLoss()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

top1_correct = 0
top5_correct = 0
total = 0
total_loss = 0.0

batch_images, _ = next(iter(test_loader))
batch_images = batch_images.to(device)
_ = model(batch_images)  # warmup
start = time.time()
_ = model(batch_images)
end = time.time()
inference_time = end - start

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        total_loss += loss.item() * images.size(0)

        _, top1_preds = torch.max(outputs, 1)
        top1_correct += (top1_preds == labels).sum().item()

        top5_preds = torch.topk(outputs, 5, dim=1).indices
        for i in range(labels.size(0)):
            if labels[i] in top5_preds[i]:
                top5_correct += 1

        total += labels.size(0)

top1_acc = top1_correct / total
top5_acc = top5_correct / total
avg_loss = total_loss / total
model_size_MB = os.path.getsize("resnet50_structured_pruned_SparseML40%_finalized_weights.pth") / 1e6

structued_SML_metrics = {
    "categorical_crossentropy": avg_loss,
    "top1_accuracy": top1_acc,
    "top5_accuracy": top5_acc,
    "model_size_MB": model_size_MB,
    "inference_time_per_batch_sec": inference_time,
    "batch_size": 32
}

print(json.dumps(structued_SML_metrics, indent=4))

with open("resnet50_structued_pruned_SML_metrics.json", "w") as f:
    json.dump(structued_SML_metrics, f, indent=4)


{
    "categorical_crossentropy": 1.6869331655979156,
    "top1_accuracy": 0.6318,
    "top5_accuracy": 0.8495,
    "model_size_MB": 49.890486,
    "inference_time_per_batch_sec": 0.03762316703796387,
    "batch_size": 32
}
