# 1. Unstructured Magnitude Pruning - finding "lazy" individual weights (those with tiny magnitudes, e.g. 0.000032) and zeroing them out

In [None]:
from cswin_fpn_hybrid.resnet50_cswin.new_model import ResNetCSWinHybrid
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torchvision import datasets
import os


def load_data():
    """Build the PlantVillage train/test DataLoaders for the 2022-train / 2019-test split.

    Both splits are resized to 224x224, converted to tensors and normalised with the
    channel statistics of the 2022 training images. The training split additionally
    gets geometric, colour, blur and random-erasing augmentation.

    Relies on a module-level ``batch_size`` that must exist before this is called.

    Returns:
        (train_loader, test_loader): the train loader shuffles, the test loader does not.
    """
    # Channel mean/std of the 2022 training images.
    # (2019 stats, for the other split: [0.7553, 0.3109, 0.1059], [0.1774, 0.1262, 0.0863])
    mean_2022 = [0.7083, 0.2776, 0.0762]
    std_2022 = [0.1704, 0.1296, 0.0815]

    # Earlier "baseline" augmentation was only horizontal + vertical flips.
    train_pipeline = v2.Compose([
        v2.Resize((224, 224)),
        # geometric: flip, small rotation, slight zoom/shift
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomRotation(degrees=15),
        v2.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        # colour / signal
        v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        # mild blur so the model relies less on grain/noise
        v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        v2.ToTensor(),
        v2.Normalize(mean_2022, std_2022),
        # occlusion, applied to the normalised tensor
        v2.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
    ])
    test_pipeline = v2.Compose([
        v2.Resize((224, 224)),
        v2.ToTensor(),
        v2.Normalize(mean_2022, std_2022),
    ])
    pipelines = {'train': train_pipeline, 'test': test_pipeline}

    root = 'DeepLearning_PlantDiseases-master/Scripts/PlantVillage_2_2022train_2019test'
    # other split: 'DeepLearning_PlantDiseases-master/Scripts/PlantVillage_1_2019train_2022test'

    loaders = {}
    for split, pipeline in pipelines.items():
        folder = datasets.ImageFolder(os.path.join(root, split), pipeline)
        loaders[split] = DataLoader(folder, batch_size=batch_size,
                                    shuffle=(split == 'train'), num_workers=4)

    return loaders['train'], loaders['test']

def measure_sparsity(model):
    """Print and return the fraction of exactly-zero weights in all Conv2d/Linear layers.

    Only the ``weight`` tensors of ``nn.Conv2d`` and ``nn.Linear`` modules are counted;
    biases, norm layers and any other parameters are ignored. While a pruning mask is
    attached, ``module.weight`` is the masked tensor, so pruned entries count as zeros.

    Returns:
        Zero fraction in [0, 1]. Returns 0.0 (instead of raising ZeroDivisionError)
        when the model has no Conv2d/Linear weights.
    """
    total_params = 0
    zero_params = 0

    for module in model.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        w = getattr(module, "weight", None)
        if w is None:
            continue
        w = w.detach()
        total_params += w.numel()
        zero_params += torch.sum(w == 0).item()

    sparsity = zero_params / total_params if total_params else 0.0
    print(f"Global Sparsity: {100. * sparsity:.2f}%")
    return sparsity

def apply_pruning(model, amount_cnn, amount_trans):
    """One-shot unstructured L1 (magnitude) pruning of every Conv2d/Linear weight.

    Layers whose qualified name contains "stage3" or "stage4" are pruned at
    ``amount_trans``; every other Conv2d/Linear layer at ``amount_cnn``. Within a
    layer, that fraction of individual weights with the smallest |w| is masked to
    zero through ``torch.nn.utils.prune`` (which adds ``weight_orig``/``weight_mask``).

    Prints a summary and returns None.
    """
    print(f"\nStarting Pruning")
    print(f"Target is {amount_cnn*100}% on ResNet part, {amount_trans*100}% on Transformer part")

    counts = {"cnn": 0, "trans": 0}
    prunable = ((n, m) for n, m in model.named_modules() if isinstance(m, (nn.Conv2d, nn.Linear)))
    for layer_name, layer in prunable:
        # NOTE(review): the ResNet/Transformer split is name-based only; assumes the
        # transformer stages are registered under names containing "stage3"/"stage4".
        group = "trans" if ("stage3" in layer_name or "stage4" in layer_name) else "cnn"
        rate = amount_trans if group == "trans" else amount_cnn
        prune.l1_unstructured(layer, name='weight', amount=rate)
        counts[group] += 1

    print(f"Pruning Applied: {counts['cnn']} layers at {amount_cnn} rate, {counts['trans']} layers at {amount_trans} rate.")
    print("Zeros injected.")

def simple_evaluate(model, loader, device, threshold):
    """Quick binary evaluation used to gauge the damage done by pruning.

    Switches ``model`` to eval mode, runs it over ``loader`` without gradients, and
    predicts class 1 whenever the softmax probability of class 1 is >= ``threshold``.
    Assumes the model returns (batch, num_classes>=2) logits — TODO confirm.

    Returns:
        (accuracy, f1) from scikit-learn, F1 for positive label 1.
    """
    model.eval()
    predictions, labels = [], []

    with torch.no_grad():
        for batch, batch_labels in loader:
            logits = model(batch.to(device))
            positive_prob = torch.softmax(logits, dim=1)[:, 1]
            predictions.extend((positive_prob >= threshold).long().cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())

    return accuracy_score(labels, predictions), f1_score(labels, predictions)


def check_parameter_coverage(model, min_coverage=95):
    """Report what share of module ``weight`` tensors the pruning code can reach.

    Every module with a non-None ``weight`` attribute contributes to the total. Those
    that are ``nn.Conv2d``/``nn.Linear`` (the only types pruned in this notebook) count
    as covered; all others are printed as "Skipping". Parameters not named ``weight``
    (biases, free-standing ``nn.Parameter`` tensors) are not considered.

    Args:
        model: module to inspect.
        min_coverage: coverage in percent above which the result is reported as fine.

    Returns:
        Coverage in percent; 0.0 (instead of ZeroDivisionError) if no module has a weight.
    """
    total_params = 0
    prunable_params = 0

    print(f"\nPruning Coverage Analysis")

    for name, module in model.named_modules():
        weight = getattr(module, 'weight', None)
        if weight is None:
            continue
        params = weight.numel()
        total_params += params

        # Is it one of the types we are targeting?
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prunable_params += params
        else:
            print(f"Skipping: {name:<40} | Type: {type(module).__name__:<15} | Size: {params}")

    coverage = (prunable_params / total_params) * 100 if total_params else 0.0
    print("-" * 75)
    print(f"Total Params (Weights only): {total_params:,}")
    print(f"Targeted Params (Conv+Linear): {prunable_params:,}")
    print(f"Coverage: {coverage:.2f}%")

    if coverage > min_coverage:
        print("all good")
    else:
        print("change needed")
    return coverage



# 1. Load best Model
model = ResNetCSWinHybrid(num_classes=2, resnet_pretrained=True, cswin_pretrained=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
batch_size = 128  # read as a global by load_data()

path = 'threshold_0.27_hybrid_Tr2022_Te2019.pth'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load(path, map_location=device))

trainloader, testloader = load_data()

# 2. baseline stats (unpruned model)
# decision threshold on P(class 1):
# 0.27 for train 2022, test 2019
# 0.07 for train 2019, test 2022
threshold = 0.27
acc_before, f1_before = simple_evaluate(model, testloader, device, threshold)
# previously measured: acc_before = 0.8908, f1_before = 0.8623
print(f"No pruning baseline acc, f1: {acc_before:.4f}, {f1_before:.4f}")

# 3. apply Pruning (unstructured, same rate for both parts here)
apply_pruning(model, amount_cnn=0.08, amount_trans=0.08)
measure_sparsity(model)

# 4. Check 'Broken' Stats (Before fine-tuning)
print("Checking Pruned Accuracy (No Retraining)")
acc_after, f1_after = simple_evaluate(model, testloader, device, threshold)
print(f"Pruned Accuracy, F1: {acc_after:.4f}, {f1_after:.4f}")
# drops are in percentage points
print(f"Drop due to pruning: {(acc_before - acc_after)*100:.2f}%, {(f1_before - f1_after)*100:.2f}%")


# Sanity check, see if rights params are targeted
# check_parameter_coverage(model)

# Note: Pruning adds 'mask' buffers. To make it permanent/saveable:
# for name, module in model.named_modules():
#     if isinstance(module, (nn.Conv2d, nn.Linear)):
#         prune.remove(module, 'weight')

# 1.2.1 Switch to structured pruning instead (speed-ups from unstructured sparsity depend on hardware / sparse-kernel support)
# This is the one-shot version

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torchvision import datasets
import os
from cswin_fpn_hybrid.resnet50_cswin.new_model import ResNetCSWinHybrid
import torch.functional as F


def load_data():
    """Build the PlantVillage train/test DataLoaders for the 2019-train / 2022-test split.

    Both splits are resized to 224x224, converted to tensors and normalised with the
    channel statistics of the 2019 training images. The training split additionally
    gets geometric, colour, blur and random-erasing augmentation.

    Relies on a module-level ``batch_size`` that must exist before this is called.

    Returns:
        (train_loader, test_loader): the train loader shuffles, the test loader does not.
    """
    # Channel mean/std of the 2019 training images.
    # (2022 stats, for the other split: [0.7083, 0.2776, 0.0762], [0.1704, 0.1296, 0.0815])
    mean_2019 = [0.7553, 0.3109, 0.1059]
    std_2019 = [0.1774, 0.1262, 0.0863]

    # Earlier "baseline" augmentation was only horizontal + vertical flips.
    train_pipeline = v2.Compose([
        v2.Resize((224, 224)),
        # geometric: flip, small rotation, slight zoom/shift
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomRotation(degrees=15),
        v2.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        # colour / signal
        v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        # mild blur so the model relies less on grain/noise
        v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        v2.ToTensor(),
        v2.Normalize(mean_2019, std_2019),
        # occlusion, applied to the normalised tensor
        v2.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
    ])
    test_pipeline = v2.Compose([
        v2.Resize((224, 224)),
        v2.ToTensor(),
        v2.Normalize(mean_2019, std_2019),
    ])
    pipelines = {'train': train_pipeline, 'test': test_pipeline}

    root = 'DeepLearning_PlantDiseases-master/Scripts/PlantVillage_1_2019train_2022test'
    # other split: 'DeepLearning_PlantDiseases-master/Scripts/PlantVillage_2_2022train_2019test'

    loaders = {}
    for split, pipeline in pipelines.items():
        folder = datasets.ImageFolder(os.path.join(root, split), pipeline)
        loaders[split] = DataLoader(folder, batch_size=batch_size,
                                    shuffle=(split == 'train'), num_workers=4)

    return loaders['train'], loaders['test']

def measure_sparsity(model):
    """Print and return the fraction of exactly-zero weights in all Conv2d/Linear layers.

    Only the ``weight`` tensors of ``nn.Conv2d`` and ``nn.Linear`` modules are counted;
    biases, norm layers and any other parameters are ignored. While a pruning mask is
    attached, ``module.weight`` is the masked tensor, so pruned entries count as zeros.

    Returns:
        Zero fraction in [0, 1]. Returns 0.0 (instead of raising ZeroDivisionError)
        when the model has no Conv2d/Linear weights.
    """
    total_params = 0
    zero_params = 0

    for module in model.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        w = getattr(module, "weight", None)
        if w is None:
            continue
        w = w.detach()
        total_params += w.numel()
        zero_params += torch.sum(w == 0).item()

    sparsity = zero_params / total_params if total_params else 0.0
    print(f"Global Sparsity: {100. * sparsity:.2f}%")
    return sparsity

def apply_pruning(model, amount_cnn, amount_trans):
    """One-shot structured (L2-norm) pruning of whole output channels / neurons.

    - ``nn.Conv2d``: the ``amount_cnn`` fraction of output filters (dim 0) with the
      smallest L2 norm is masked to zero.
    - ``nn.Linear`` whose name contains "stage3"/"stage4": output neurons (dim 0)
      pruned at ``amount_trans``.
    - any other ``nn.Linear`` (classifier head / bridge): pruned at ``amount_cnn``.

    torch rounds ``amount * n_channels`` per layer, so layers with few outputs (e.g. a
    2-class head) may lose no channels at small amounts.

    NOTE(review): Conv2d layers always use ``amount_cnn`` and are counted as "cnn", even
    if their name contains stage3/stage4 — confirm this is intended for convolutions
    inside the transformer stages.
    NOTE: masking zeroes channels but keeps tensor shapes, so no memory/FLOP saving
    happens until the channels are physically removed.

    Prints a summary and returns None.
    """
    print(f"\nStarting Pruning")
    print(f"Target is {amount_cnn*100}% on ResNet part, {amount_trans*100}% on Transformer part")
    cnn_count = trans_count = 0

    for name, module in model.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        in_transformer = isinstance(module, nn.Linear) and ("stage3" in name or "stage4" in name)
        rate = amount_trans if in_transformer else amount_cnn
        prune.ln_structured(module, name='weight', amount=rate, n=2, dim=0)
        if in_transformer:
            trans_count += 1
        else:
            cnn_count += 1

    print(f"Pruning Applied: {cnn_count} layers at {amount_cnn} rate, {trans_count} layers at {amount_trans} rate.")
    print("Zeros injected.")

def simple_evaluate(model, loader, device, threshold):
    """Quick binary evaluation used to gauge the damage done by pruning.

    Switches ``model`` to eval mode and, without gradients, labels each sample as
    class 1 when softmax P(class 1) >= ``threshold``. Assumes (batch, num_classes>=2)
    logits — TODO confirm.

    Returns:
        (accuracy, f1) from scikit-learn, F1 for positive label 1.
    """
    model.eval()
    y_pred = []
    y_true = []

    with torch.no_grad():
        for images, labels in loader:
            scores = torch.softmax(model(images.to(device)), dim=1)
            is_positive = (scores[:, 1] >= threshold).long()
            y_pred.extend(is_positive.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    return accuracy, f1


def check_parameter_coverage(model, min_coverage=95):
    """Report what share of module ``weight`` tensors the pruning code can reach.

    Every module with a non-None ``weight`` attribute contributes to the total. Those
    that are ``nn.Conv2d``/``nn.Linear`` (the only types pruned in this notebook) count
    as covered; all others are printed as "Skipping". Parameters not named ``weight``
    (biases, free-standing ``nn.Parameter`` tensors) are not considered.

    Args:
        model: module to inspect.
        min_coverage: coverage in percent above which the result is reported as fine.

    Returns:
        Coverage in percent; 0.0 (instead of ZeroDivisionError) if no module has a weight.
    """
    total_params = 0
    prunable_params = 0

    print(f"\nPruning Coverage Analysis")

    for name, module in model.named_modules():
        weight = getattr(module, 'weight', None)
        if weight is None:
            continue
        params = weight.numel()
        total_params += params

        # Is it one of the types we are targeting?
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prunable_params += params
        else:
            print(f"Skipping: {name:<40} | Type: {type(module).__name__:<15} | Size: {params}")

    coverage = (prunable_params / total_params) * 100 if total_params else 0.0
    print("-" * 75)
    print(f"Total Params (Weights only): {total_params:,}")
    print(f"Targeted Params (Conv+Linear): {prunable_params:,}")
    print(f"Coverage: {coverage:.2f}%")

    if coverage > min_coverage:
        print("all good")
    else:
        print("change needed")
    return coverage



# 1. Load best Model
model = ResNetCSWinHybrid(num_classes=2, resnet_pretrained=True, cswin_pretrained=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
batch_size = 128  # read as a global by load_data()

# path = 'threshold_0.27_hybrid_Tr2022_Te2019.pth'
path = 'threshold_0.07_hybrid_Tr2019_Te2022.pth'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load(path, map_location=device))

trainloader, testloader = load_data()

# 2. baseline stats
# decision threshold on P(class 1):
# 0.27 for train 2022, test 2019
# 0.07 for train 2019, test 2022
threshold = 0.07

# The baseline evaluation is skipped to save time. The hard-coded values were measured
# on the unpruned Tr2019/Te2022 checkpoint at threshold 0.07 — re-enable the two lines
# below if the checkpoint, data split or threshold changes.
# acc_before, f1_before = simple_evaluate(model, testloader, device, threshold)
# print(f"No pruning baseline acc, f1: {acc_before:.4f}, {f1_before:.4f}")
acc_before = 0.8476
f1_before = 0.8882

# 3. apply Pruning (structured, one shot)
apply_pruning(model, amount_cnn=0.15, amount_trans=0.17)
measure_sparsity(model)

# 4. Check 'Broken' Stats (Before fine-tuning)
print("Checking Pruned Accuracy (No Retraining)")
acc_after, f1_after = simple_evaluate(model, testloader, device, threshold)
print(f"Pruned Accuracy, F1: {acc_after:.4f}, {f1_after:.4f}")
# drops are in percentage points
print(f"Drop due to pruning: {(acc_before - acc_after)*100:.2f}%, {(f1_before - f1_after)*100:.2f}%")

## 1.2.2 Iterative structured pruning (according to the paper Charis mentioned)

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import numpy as np
import copy
from sklearn.metrics import accuracy_score, f1_score
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torchvision import datasets
import os
from cswin_fpn_hybrid.resnet50_cswin.new_model import ResNetCSWinHybrid

class FocalLoss(nn.Module):
    """Focal loss on multi-class logits.

    For each sample, ``loss = alpha * (1 - p_t) ** gamma * CE``, where CE is the
    per-sample cross entropy and ``p_t = exp(-CE)`` is the predicted probability of
    the true class. ``gamma`` down-weights well-classified samples; ``alpha`` is a
    single scalar applied to every sample.

    Args:
        alpha: scalar multiplier.
        gamma: focusing exponent (gamma=0 gives alpha-scaled cross entropy).
        reduction: 'mean', 'sum', or any other value for per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_true = per_sample_ce.neg().exp()
        weighted = self.alpha * (1 - prob_true) ** self.gamma * per_sample_ce

        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        return weighted


def load_data():
    """Build the PlantVillage train/test DataLoaders for the 2019-train / 2022-test split.

    Both splits are resized to 224x224, converted to tensors and normalised with the
    channel statistics of the 2019 training images. The training split additionally
    gets geometric, colour, blur and random-erasing augmentation.

    Relies on a module-level ``batch_size`` that must exist before this is called.

    Returns:
        (train_loader, test_loader): the train loader shuffles, the test loader does not.
    """
    # 2019 training-set channel statistics
    # (2022 stats: [0.7083, 0.2776, 0.0762], [0.1704, 0.1296, 0.0815])
    channel_mean = [0.7553, 0.3109, 0.1059]
    channel_std = [0.1774, 0.1262, 0.0863]

    augmentations = [
        # geometric
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomRotation(degrees=15),
        v2.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        # colour / signal
        v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        # blur to make the model less sensitive to grain/noise
        v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    ]
    train_pipeline = v2.Compose(
        [v2.Resize((224, 224))]
        + augmentations
        + [
            v2.ToTensor(),
            v2.Normalize(channel_mean, channel_std),
            # occlusion on the normalised tensor
            v2.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
        ]
    )
    test_pipeline = v2.Compose([
        v2.Resize((224, 224)),
        v2.ToTensor(),
        v2.Normalize(channel_mean, channel_std),
    ])

    root = 'DeepLearning_PlantDiseases-master/Scripts/PlantVillage_1_2019train_2022test'
    # other split: 'DeepLearning_PlantDiseases-master/Scripts/PlantVillage_2_2022train_2019test'

    image_folders = {
        split: datasets.ImageFolder(os.path.join(root, split), pipeline)
        for split, pipeline in (('train', train_pipeline), ('test', test_pipeline))
    }
    train_loader = DataLoader(image_folders['train'], batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(image_folders['test'], batch_size=batch_size, shuffle=False, num_workers=4)
    return train_loader, test_loader


# -----------------------------------------------------------------------------

def measure_sparsity(model):
    """Print and return the fraction of exactly-zero weights in all Conv2d/Linear layers.

    Only the ``weight`` tensors of ``nn.Conv2d`` and ``nn.Linear`` modules are counted;
    biases, norm layers and any other parameters are ignored. While a pruning mask is
    attached, ``module.weight`` is the masked tensor, so pruned entries count as zeros.

    Returns:
        Zero fraction in [0, 1]. Returns 0.0 (instead of raising ZeroDivisionError)
        when the model has no Conv2d/Linear weights.
    """
    total_params = 0
    zero_params = 0

    for module in model.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        w = getattr(module, "weight", None)
        if w is None:
            continue
        w = w.detach()
        total_params += w.numel()
        zero_params += torch.sum(w == 0).item()

    sparsity = zero_params / total_params if total_params else 0.0
    print(f"Global Sparsity: {100. * sparsity:.2f}%")
    return sparsity

def simple_evaluate(model, loader, device, threshold):
    """Evaluate a binary classifier with a custom decision threshold.

    The model is put in eval mode and run without gradients. A sample is predicted as
    class 1 when softmax P(class 1) >= ``threshold`` (assumes (batch, num_classes>=2)
    logits — TODO confirm).

    Returns:
        (accuracy, f1) from scikit-learn, F1 for positive label 1.
    """
    model.eval()
    collected_preds = []
    collected_labels = []

    with torch.no_grad():
        for x, y in loader:
            class1_prob = torch.softmax(model(x.to(device)), dim=1)[:, 1]
            hard_pred = (class1_prob >= threshold).long()
            collected_preds.extend(hard_pred.cpu().numpy())
            collected_labels.extend(y.cpu().numpy())

    return (accuracy_score(collected_labels, collected_preds),
            f1_score(collected_labels, collected_preds))

def get_pruning_rate_per_step(final_target, k_steps):
    """Per-step pruning fraction that compounds to ``final_target`` after ``k_steps``.

    Assumes every step removes ``rate`` of the units that are *still unpruned* (what
    torch's PruningContainer does when structured masks are stacked). The surviving
    fraction after k steps is then (1 - rate)^k. Solving (1 - rate)^k = 1 - target
    gives rate = 1 - (1 - target)^(1/k).

    Args:
        final_target: overall fraction to remove; <= 0 means no pruning.
        k_steps: number of pruning steps, must be >= 1.

    Returns:
        The per-step rate as a float (0.0 when ``final_target <= 0``).

    Raises:
        ValueError: if ``k_steps < 1``. Previously 0 raised ZeroDivisionError and a
            negative value silently produced a negative rate.
    """
    if final_target <= 0:
        return 0.0
    if k_steps < 1:
        raise ValueError(f"k_steps must be >= 1, got {k_steps}")
    return 1 - (1 - final_target) ** (1 / k_steps)

def apply_structured_pruning_step(model, rate_cnn, rate_trans):
    """Apply one round of structured L2-norm pruning to output channels / neurons.

    - ``nn.Conv2d``: prune ``rate_cnn`` of the output filters (dim 0).
    - ``nn.Linear`` named "*stage3*"/"*stage4*": prune ``rate_trans`` of output neurons.
    - any other ``nn.Linear`` (head / bridge): prune ``rate_cnn`` of output neurons.

    If a layer already carries a mask, torch stacks the new one in a PruningContainer.
    For structured pruning it then picks ``rate`` of the channels that are not yet fully
    zeroed, so repeated calls compound. If the masks were removed with ``prune.remove``
    in between, the zeroed channels are ordinary (lowest-norm) weights again and are
    likely re-selected, so sparsity stops growing.

    Returns:
        Number of layers pruned in this step.
    """
    pruned_layers = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            rate = rate_cnn
        elif isinstance(module, nn.Linear):
            rate = rate_trans if ("stage3" in name or "stage4" in name) else rate_cnn
        else:
            continue
        prune.ln_structured(module, name='weight', amount=rate, n=2, dim=0)
        pruned_layers += 1
    return pruned_layers

def make_pruning_permanent(model):
    """Fold the pruning masks of Conv2d/Linear ``weight`` tensors into the weights.

    For every Conv2d/Linear whose ``weight`` is re-parametrised by
    ``torch.nn.utils.prune`` (it has a ``weight_orig`` parameter), ``prune.remove`` turns
    ``weight`` back into an ordinary parameter holding ``weight_orig * weight_mask``. It
    also drops the mask buffer and forward pre-hook.

    Call this once, AFTER the last prune/fine-tune step. After removal nothing keeps the
    pruned entries at zero:
    - further training can regrow them;
    - a later ``ln_structured`` call sees the zeroed channels as the lowest-norm ones
      and selects them again instead of pruning new channels.
    """
    for module in model.modules():
        # Check for weight_orig rather than prune.is_pruned(): is_pruned is also True
        # when only another parameter (e.g. bias) is pruned, and then
        # prune.remove(module, 'weight') would raise ValueError.
        if isinstance(module, (nn.Conv2d, nn.Linear)) and hasattr(module, 'weight_orig'):
            prune.remove(module, 'weight')

def fine_tune_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader`` and return the mean per-sample loss.

    Sets train mode, then for every batch: move inputs/labels to ``device``,
    zero_grad -> forward -> ``criterion`` -> backward -> ``optimizer.step()``.

    Returns:
        Sum over batches of ``loss * batch_size``, divided by ``len(loader.dataset)``.
    """
    model.train()
    total = 0.0
    for batch_inputs, batch_labels in loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        # weight by batch size so a short final batch is not over-counted
        total += batch_loss.item() * batch_inputs.size(0)
    return total / len(loader.dataset)

# -----------------------------------------------------------------------------

batch_size = 128  # read as a global by load_data()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainloader, testloader = load_data()

model = ResNetCSWinHybrid(num_classes=2, resnet_pretrained=True, cswin_pretrained=True)
path = 'threshold_0.07_hybrid_Tr2019_Te2022.pth'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load(path, map_location=device))
model.to(device)


FINAL_CNN_TARGET = 0.15    # per-layer target for Conv2d and non-stage3/4 Linear layers
FINAL_TRANS_TARGET = 0.17  # per-layer target for Linear layers in stage3/stage4
STEPS = 6                  # k=6 (Paper Recommendation)
EPOCHS_PER_STEP = 10
THRESHOLD = 0.07           # decision threshold on P(class 1) for train 2019 / test 2022

cnn_step_rate = get_pruning_rate_per_step(FINAL_CNN_TARGET, STEPS)
trans_step_rate = get_pruning_rate_per_step(FINAL_TRANS_TARGET, STEPS)

print(f"Goal: {FINAL_CNN_TARGET*100}% resnet, {FINAL_TRANS_TARGET*100}% trans")
print(f"Schedule: {STEPS} steps, {EPOCHS_PER_STEP} epochs/step")
print(f"Per step pruning rate - resnet: {cnn_step_rate:.4f}, trans: {trans_step_rate:.4f}")

acc_base, f1_base = simple_evaluate(model, testloader, device, THRESHOLD)
print(f"Baseline before acc: {acc_base:.4f}, F1: {f1_base:.4f}\n")

# Creating the optimizer before pruning is fine: torch's pruning re-registers the same
# Parameter object as `weight_orig`, so the optimizer keeps updating it.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # low LR for gentle fine-tuning
criterion = FocalLoss(gamma=2.0).to(device)  # use same loss function during initial training

for step in range(1, STEPS + 1):
    print(f"Step {step}/{STEPS}")

    # Prune. Masks from earlier steps stay attached, so torch stacks them and each step
    # removes a fraction of the channels that are still alive. Do NOT make pruning
    # permanent inside this loop: the already-zeroed channels would be re-selected and
    # sparsity would stop growing.
    print("Applying pruning")
    apply_structured_pruning_step(model, cnn_step_rate, trans_step_rate)
    measure_sparsity(model)

    # Rehab: fine-tune with the masks in place
    print(f"Finetuning for {EPOCHS_PER_STEP} epochs")
    for epoch in range(EPOCHS_PER_STEP):
        loss = fine_tune_epoch(model, trainloader, optimizer, criterion, device)
        if (epoch + 1) % 5 == 0:
            print(f"   Epoch {epoch+1}/{EPOCHS_PER_STEP} | Loss: {loss:.4f}")

    # Check status (drops are in percentage points)
    acc, f1 = simple_evaluate(model, testloader, device, THRESHOLD)
    print(f"Step {step} result: acc: {acc:.4f} (drop: {(acc_base-acc)*100:.2f}%), "
          f"F1: {f1:.4f} (drop: {(f1_base-f1)*100:.2f}%)")

# Only now fold the stacked masks into the weights.
make_pruning_permanent(model)
acc_final, f1_final = simple_evaluate(model, testloader, device, THRESHOLD)
print(f"Final result: acc: {acc_final:.4f}, {f1_final:.4f}, (drop: {(acc_base-acc_final)*100:.2f}%, {(f1_base - f1_final)*100:.2f}%)")

# Final Save
# measure_sparsity(model)
# save_folder = 'model_saves'
# os.makedirs(save_folder, exist_ok=True)
# save_path = os.path.join(save_folder, 'threshold_0.07_hybrid_Tr2019_Te2022_new_optimised.pth')
# torch.save(model.state_dict(), save_path)
# print(f"Model saved to: {save_path}")

# 2. Make pruning permanent and apply dynamic quantization to INT8 (weights are currently FP32)

In [None]:
import torch.quantization
import os
import time

def print_size_of_model(model, label=""):
    """Print and return the serialized size in bytes of ``model.state_dict()``.

    The state dict is written with ``torch.save`` into an in-memory buffer. No fixed-name
    temporary file is created in, or left behind in, the working directory (the old
    ``temp.p`` approach leaked the file if saving failed and clashed between runs).
    This is the dense serialized size: zeros introduced by pruning are stored like any
    other value and do not shrink it.

    Args:
        model: module whose state dict is measured.
        label: name shown in the printed line.

    Returns:
        Size in bytes (int).
    """
    import io

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    size = buffer.getbuffer().nbytes
    print(f"Model: {label:<15} | Size: {size/1e6:.2f} MB")
    return size

def measure_inference_speed(model, loader, device, max_images=200, warmup_iters=10):
    """Measure average forward-pass latency in milliseconds per image.

    Runs ``warmup_iters`` untimed forward passes on the first batch. It then times whole
    batches from ``loader`` until more than ``max_images`` images have been processed;
    the last batch is not truncated, so slightly more may be timed. All passes run under
    ``torch.no_grad()``. On CUDA the device is synchronised before each clock read,
    because kernel launches are asynchronous.

    Args:
        model: module to benchmark (put in eval mode).
        loader: iterable of (inputs, targets) batches.
        device: torch device or device string the inputs are moved to.
        max_images: stop after more than this many images have been timed.
        warmup_iters: number of untimed warm-up passes.

    Returns:
        Latency in ms/image.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    device = torch.device(device)
    is_cuda = device.type == "cuda"

    def _sync():
        # Wait for queued CUDA kernels so the timer measures the actual work.
        if is_cuda:
            torch.cuda.synchronize(device)

    try:
        warmup_input, _ = next(iter(loader))
    except StopIteration:
        raise ValueError("measure_inference_speed: loader yielded no batches") from None
    warmup_input = warmup_input.to(device)

    with torch.no_grad():
        # Warmup (no autograd graphs are built)
        for _ in range(warmup_iters):
            model(warmup_input)
        _sync()

        start = time.perf_counter()
        count = 0
        for inputs, _ in loader:
            inputs = inputs.to(device)
            model(inputs)
            count += inputs.size(0)
            if count > max_images:
                break
        _sync()
        end = time.perf_counter()

    latency = (end - start) / count * 1000  # ms per image
    print(f"Latency: {latency:.2f} ms/image")
    return latency

# 1. Make Pruning Permanent
print("\nFinalizing Compression")
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        # only layers that still carry a pruning re-parametrisation have weight_orig
        if hasattr(module, 'weight_orig'):
            torch.nn.utils.prune.remove(module, 'weight')

# NOTE: pruned entries are now ordinary zeros inside dense tensors, so the FP32 size
# reported below is not reduced by pruning itself.
print("Pruning masks removed. Weights are permanently sparse.")
size_fp32 = print_size_of_model(model, "Pruned FP32")

# 2. Apply Dynamic Quantization (INT8)
# Only Linear layers are quantized (INT8 weights, activations quantized on the fly).
# Conv2d would need static quantization with calibration, so the conv layers stay FP32.
# NOTE: model.cpu() moves `model` itself to the CPU (in place); move it back to `device`
# before evaluating it on the GPU again.
quantized_model = torch.quantization.quantize_dynamic(
    model.cpu(),
    {torch.nn.Linear},
    dtype=torch.qint8
)

print("\nQuantization Results")
size_int8 = print_size_of_model(quantized_model, "Quantized INT8")

reduction = (size_fp32 - size_int8) / size_fp32 * 100
print(f"Size Reduction: {reduction:.2f}%")

# 3. Save the Efficient Model, named after the checkpoint that was actually compressed.
# `path` is the checkpoint loaded by the pruning cell. A hard-coded name here previously
# saved the Tr2019/Te2022 model as "threshold_0.27_hybrid_Tr2022_Te2019_optimised.pth".
save_folder = 'model_saves'
os.makedirs(save_folder, exist_ok=True)
checkpoint_stem = os.path.splitext(os.path.basename(path))[0]
q_path = os.path.join(save_folder, f'{checkpoint_stem}_optimised.pth')
torch.save(quantized_model.state_dict(), q_path)
print(f"Efficient model saved to: {q_path}")

# 3. Test performance on newly compressed model

In [None]:
import time
from sklearn.metrics import accuracy_score,f1_score

def evaluate_quantized(model, loader, device='cpu', threshold=0.27):
    """Evaluate a dynamically quantized binary classifier on the CPU.

    The model and every input batch are moved to the CPU, because the dynamically
    quantized INT8 kernels used here run on CPU only. ``device`` is kept for backward
    compatibility and is ignored. A sample is predicted as class 1 when softmax
    P(class 1) >= ``threshold``. Accuracy, F1 and wall-clock time are then printed; the
    time includes data loading.

    Args:
        model: quantized model to evaluate.
        loader: iterable of (inputs, targets) batches.
        device: ignored (always CPU).
        threshold: decision threshold, which must match the evaluated model (0.27 for
            train 2022 / test 2019, 0.07 for train 2019 / test 2022). It used to be
            hard-coded to 0.27, which mis-scored the Tr2019/Te2022 model.

    Returns:
        (accuracy, f1).
    """
    model.to('cpu')
    model.eval()

    all_preds = []
    all_targets = []

    start = time.time()
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to('cpu'))
            probs = torch.softmax(outputs, dim=1)[:, 1]
            preds = (probs >= threshold).long()
            all_preds.extend(preds.numpy())
            all_targets.extend(targets.numpy())
    end = time.time()

    acc = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds)
    print(f"NEW Compressed Model Accuracy: {acc:.4f}")
    print(f"NEW Compressed Model F1: {f1:.4f}")
    print(f"Inference Time: {end - start:.2f} seconds")
    return acc, f1


print("\nTesting Efficient Model")
# Score with the threshold of the model that was actually compressed. THRESHOLD is set
# by the iterative-pruning cell; fall back to 0.27 (train 2022 / test 2019) otherwise.
eval_threshold = globals().get("THRESHOLD", 0.27)
acc_q, f1_q = evaluate_quantized(quantized_model, testloader, threshold=eval_threshold)

In [None]:
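# Cell 1.2.2 run: train 2019 -> test 2022
# NOTE: in this run the pruning did not accumulate across steps: global sparsity stayed at 2.93%.
# In other words, only the first step actually cut the network (2.93%); the remaining 5 steps were just extra fine-tuning epochs.
# (A likely cause is making the pruning permanent with prune.remove() between steps: the already-zeroed channels
#  then have the lowest norm and are selected again instead of new ones.)
# Giving the model these extra epochs by mistake improved performance a lot!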

# Goal: 15.0% resnet, 17.0% trans
# Schedule: 6 steps, 10 epochs/step
# Per step pruning rate - resnet: 0.0267, trans: 0.0306
# Baseline before acc: 0.8476, F1: 0.8882
#
# Step 1/6
# Applying pruning
# Global Sparsity: 2.93%
# Finetuning for 10 epochs
#    Epoch 5/10 | Loss: 0.0782
#    Epoch 10/10 | Loss: 0.0708
# Step 1 result: acc: 0.8394 (drop: 0.82%)
# Step 2/6
# Applying pruning
# Global Sparsity: 2.93%
# Finetuning for 10 epochs
#    Epoch 5/10 | Loss: 0.0663
#    Epoch 10/10 | Loss: 0.0671
# Step 2 result: acc: 0.8948 (drop: -4.72%)
# Step 3/6
# Applying pruning
# Global Sparsity: 2.93%
# Finetuning for 10 epochs
#    Epoch 5/10 | Loss: 0.0660
#    Epoch 10/10 | Loss: 0.0624
# Step 3 result: acc: 0.8996 (drop: -5.19%)
# Step 4/6
# Applying pruning
# Global Sparsity: 2.93%
# Finetuning for 10 epochs
#    Epoch 5/10 | Loss: 0.0602
#    Epoch 10/10 | Loss: 0.0581
# Step 4 result: acc: 0.8758 (drop: -2.82%)
# Step 5/6
# Applying pruning
# Global Sparsity: 2.93%
# Finetuning for 10 epochs
#    Epoch 5/10 | Loss: 0.0584
#    Epoch 10/10 | Loss: 0.0594
# Step 5 result: acc: 0.8733 (drop: -2.57%)
# Step 6/6
# Applying pruning
# Global Sparsity: 2.93%
# Finetuning for 10 epochs
#    Epoch 5/10 | Loss: 0.0531
#    Epoch 10/10 | Loss: 0.0537
# Step 6 result: acc: 0.8907 (drop: -4.31%)
# Final result: acc: 0.8907, 0.9242, (drop: -4.31%, -3.60%)