# 1. Unstructured Magnitude Pruning - finding "lazy" weights (those with near-zero magnitude, e.g. 0.000032) and zeroing them out

In [None]:
from cswin_fpn_hybrid.resnet50_cswin.new_model import ResNetCSWinHybrid
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torchvision import datasets
import os


def load_data():
    """Build the training and test DataLoaders.

    Images are read with ``datasets.ImageFolder`` from the ``train`` and
    ``test`` sub-folders of the hard-coded dataset directory below.

    The batch size is NOT a parameter: it is read from the module-level
    ``batch_size`` variable, which must be defined before this is called.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles;
        the test loader does not. Both use 4 worker processes.
    """
    target_size = (224, 224)

    # Per-channel normalisation statistics of the 2022 training set.
    # For the 2019 training set use instead:
    #   mean = [0.7553, 0.3109, 0.1059], std = [0.1774, 0.1262, 0.0863]
    norm_mean = [0.7083, 0.2776, 0.0762]
    norm_std = [0.1704, 0.1296, 0.0815]

    # Training pipeline ("new" augmentation). The earlier baseline used only
    # RandomHorizontalFlip + RandomVerticalFlip + ToTensor + Normalize.
    train_pipeline = v2.Compose([
        v2.Resize(target_size),
        # Geometric transforms
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomRotation(degrees=15),
        v2.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # slight zoom/shift
        # Colour/signal transforms
        v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        # Noise robustness: blur helps the model ignore grain
        v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        v2.ToTensor(),
        v2.Normalize(norm_mean, norm_std),
        # Occlusion: applied after normalisation, on the tensor
        v2.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
    ])

    # Test pipeline: deterministic resize + normalisation only.
    test_pipeline = v2.Compose([
        v2.Resize(target_size),
        v2.ToTensor(),
        v2.Normalize(norm_mean, norm_std),
    ])

    root = 'DeepLearning_PlantDiseases-master/Scripts/PlantVillage_2_2022train_2019test'
    # Alternative split:
    # root = 'DeepLearning_PlantDiseases-master/Scripts/PlantVillage_1_2019train_2022test'

    train_set = datasets.ImageFolder(os.path.join(root, 'train'), train_pipeline)
    test_set = datasets.ImageFolder(os.path.join(root, 'test'), test_pipeline)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, test_loader

def measure_sparsity(model):
    """Print and return the fraction of exactly-zero weights.

    Only the ``weight`` tensors of ``nn.Conv2d`` and ``nn.Linear`` modules are
    counted; biases and every other layer type are ignored.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are inspected.

    Returns:
        Zero weights divided by total weights, as a float in [0, 1].
        Returns 0.0 when the model contains no Conv2d/Linear layers
        (previously this raised ZeroDivisionError).
    """
    total_params = 0
    zero_params = 0

    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            weights = module.weight.detach()
            total_params += weights.numel()
            zero_params += int((weights == 0).sum().item())

    if total_params == 0:
        # Nothing to measure: report 0% instead of dividing by zero.
        print(f"Global Sparsity: {0.0:.2f}%")
        return 0.0

    print(f"Global Sparsity: {100. * zero_params / total_params:.2f}%")
    return zero_params / total_params

def apply_pruning(model, amount_cnn, amount_trans):
    """Apply L1 unstructured pruning to every Conv2d and Linear layer.

    Layers whose qualified module name contains ``"stage3"`` or ``"stage4"``
    are treated as the transformer part and pruned at ``amount_trans``; all
    other Conv2d/Linear layers are pruned at ``amount_cnn``.

    The model is modified in place through ``torch.nn.utils.prune``;
    nothing is returned.

    Args:
        model: The network to prune.
        amount_cnn: Pruning amount passed to ``prune.l1_unstructured`` for
            the non-stage3/stage4 layers.
        amount_trans: Pruning amount for the stage3/stage4 layers.
    """
    print("\nStarting Pruning")
    print(f"Target is {amount_cnn*100}% on ResNet part, {amount_trans*100}% on Transformer part")

    transformer_tags = ("stage3", "stage4")
    cnn_layers = 0
    transformer_layers = 0

    for layer_name, layer in model.named_modules():
        # Only convolution and fully-connected layers are pruned.
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue

        is_transformer = any(tag in layer_name for tag in transformer_tags)
        rate = amount_trans if is_transformer else amount_cnn
        prune.l1_unstructured(layer, name='weight', amount=rate)

        if is_transformer:
            transformer_layers += 1
        else:
            cnn_layers += 1

    print(f"Pruning Applied: {cnn_layers} layers at {amount_cnn} rate, {transformer_layers} layers at {amount_trans} rate.")
    print("Zeros injected.")

def simple_evaluate(model, loader, device, threshold):
    """Quick accuracy/F1 check, used to see how much pruning hurt the model.

    The softmax probability of class index 1 is compared against
    ``threshold``; samples at or above it are predicted as class 1,
    the rest as class 0.

    Args:
        model: Network to evaluate; switched to eval mode here. It is not
            moved to ``device`` by this function.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the input batches are moved to.
        threshold: Decision threshold on the class-1 probability.

    Returns:
        ``(accuracy, f1)`` as computed by scikit-learn.
    """
    model.eval()
    predicted = []
    expected = []

    with torch.no_grad():
        for batch, labels in loader:
            logits = model(batch.to(device))
            positive_prob = torch.softmax(logits, dim=1)[:, 1]
            decisions = (positive_prob >= threshold).long()

            predicted.extend(decisions.cpu().numpy())
            expected.extend(labels.cpu().numpy())

    return accuracy_score(expected, predicted), f1_score(expected, predicted)


def check_parameter_coverage(model):
    """Print how many weight parameters the pruning step actually targets.

    Every sub-module that exposes a non-None attribute literally named
    ``weight`` is counted towards the total. Weights that belong to
    ``nn.Conv2d`` or ``nn.Linear`` count as "targeted"; every other weighted
    module is listed as skipped. Biases are not counted.

    A model without any weighted module is reported as 0% coverage
    (previously this raised ZeroDivisionError).

    Args:
        model: The ``nn.Module`` to analyse.
    """
    total_params = 0
    prunable_params = 0

    print("\nPruning Coverage Analysis")

    for name, module in model.named_modules():
        weight = getattr(module, 'weight', None)
        if weight is None:
            continue

        params = weight.numel()
        total_params += params

        # Is it one of the types we are targeting?
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prunable_params += params
        else:
            # These are the ones we are skipping
            print(f"Skipping: {name:<40} | Type: {type(module).__name__:<15} | Size: {params}")

    # Guard against models that expose no weights at all.
    coverage = (prunable_params / total_params) * 100 if total_params else 0.0
    print("-" * 75)
    print(f"Total Params (Weights only): {total_params:,}")
    print(f"Targeted Params (Conv+Linear): {prunable_params:,}")
    print(f"Coverage: {coverage:.2f}%")

    if coverage > 95:
        print("all good")
    else:
        print("changed needed")



# 1. Load best Model
model = ResNetCSWinHybrid(num_classes=2, resnet_pretrained=True, cswin_pretrained=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# NOTE: load_data() reads this module-level variable; define it before calling load_data().
batch_size = 128

path = 'threshold_0.27_hybrid_Tr2022_Te2019.pth'
# map_location remaps the checkpoint tensors onto `device`, so a checkpoint
# saved on a GPU machine still loads when only the CPU is available.
model.load_state_dict(torch.load(path, map_location=device))

trainloader, testloader = load_data()

# 2. baseline stats
# Decision threshold on the class-1 probability:
# 0.27 for train 2022, test 2019
# 0.07 for train 2019, test 2022
threshold = 0.27
acc_before, f1_before = simple_evaluate(model, testloader, device, threshold)
# Previously observed values for this checkpoint:
# acc_before = 0.8908
# f1_before = 0.8623
print(f"No pruning baseline acc, f1: {acc_before:.4f}, {f1_before:.4f}")

# 3. apply Pruning (same rate for the CNN and the transformer part)
apply_pruning(model, amount_cnn=0.08, amount_trans=0.08)
measure_sparsity(model)

# 4. Check 'Broken' Stats (Before fine-tuning)
print("Checking Pruned Accuracy (No Retraining)")
acc_after, f1_after = simple_evaluate(model, testloader, device, threshold)
print(f"Pruned Accuracy, F1: {acc_after:.4f}, {f1_after:.4f}")
print(f"Drop due to pruning: {(acc_before - acc_after)*100:.2f}%, {(f1_before - f1_after)*100:.2f}%")


# Sanity check: see if the right params are targeted
# check_parameter_coverage(model)

# Note: Pruning adds 'mask' buffers. To make it permanent/saveable:
# for name, module in model.named_modules():
#     if isinstance(module, (nn.Conv2d, nn.Linear)):
#         prune.remove(module, 'weight')

# 2. Make pruning permanent and apply quantization to INT8 (weights are currently FP32)

In [None]:
import torch.quantization
import os
import time

def print_size_of_model(model, label=""):
    """Print and return the serialized size of ``model.state_dict()``.

    The state dict is saved with ``torch.save`` to a scratch file inside a
    private temporary directory, measured, and discarded. Using a temporary
    directory (instead of ``temp.p`` in the working directory) means an
    existing ``temp.p`` is never overwritten or deleted, and the scratch file
    is cleaned up even if saving or measuring fails.

    Args:
        model: Module whose ``state_dict()`` is measured.
        label: Name shown in the printed line.

    Returns:
        Size of the saved state dict in bytes.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Same file name as before, just in an isolated directory.
        checkpoint_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), checkpoint_path)
        size = os.path.getsize(checkpoint_path)

    print(f"Model: {label:<15} | Size: {size/1e6:.2f} MB")
    return size

def measure_inference_speed(model, loader, device):
    """Measure average inference latency in milliseconds per image.

    Runs 10 warm-up forward passes on the first batch, then times forward
    passes over ``loader`` until more than 200 images have been processed
    (or the loader is exhausted). The timed region includes fetching batches
    from the loader and moving them to ``device``.

    Differences from a naive timing loop:
      * warm-up also runs under ``torch.no_grad()`` so no autograd graph is
        built for it;
      * when the inputs live on a CUDA device, ``torch.cuda.synchronize()`` is
        called before starting and before stopping the clock, because CUDA
        kernels are launched asynchronously;
      * ``time.perf_counter()`` is used, a monotonic clock intended for
        measuring short durations.

    Args:
        model: Network to benchmark; switched to eval mode here. It is not
            moved to ``device`` by this function.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the input batches are moved to.

    Returns:
        Latency in ms per image.
    """
    model.eval()

    with torch.no_grad():
        # Warmup
        dummy_input, _ = next(iter(loader))
        dummy_input = dummy_input.to(device)
        uses_cuda = dummy_input.is_cuda
        for _ in range(10):
            _ = model(dummy_input)

        if uses_cuda:
            torch.cuda.synchronize()  # finish warm-up work before timing starts
        start = time.perf_counter()
        count = 0
        for inputs, _ in loader:
            inputs = inputs.to(device)
            _ = model(inputs)
            count += inputs.size(0)
            if count > 200:  # Measure first 200 images only to save time
                break
        if uses_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        end = time.perf_counter()

    latency = (end - start) / count * 1000  # ms per image
    print(f"Latency: {latency:.2f} ms/image")
    return latency

# 1. Make Pruning Permanent
# For every Conv2d/Linear layer that still carries pruning re-parametrization
# (detected through its `weight_orig` attribute), prune.remove() is called on
# its 'weight' so the layer goes back to a single plain weight.
print("\nFinalizing Compression")
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        if hasattr(module, 'weight_orig'):
            torch.nn.utils.prune.remove(module, 'weight')

print("Pruning masks removed. Weights are permanently sparse.")
# NOTE(review): the zeroed weights are presumably still stored densely, so
# pruning alone is not expected to shrink this number - confirm with the output.
size_fp32 = print_size_of_model(model, "Pruned FP32")

# 2. Apply Dynamic Quantization (INT8)
# We quantize Linear layers only. Conv2d quantization usually requires 'Static' quantization
# which is more complex, but let's try standard dynamic first as it's the easiest win for Transformers.
# NOTE: model.cpu() moves the original `model` to the CPU as well, so after
# this cell `model` no longer lives on `device`.
quantized_model = torch.quantization.quantize_dynamic(
    model.cpu(),
    {torch.nn.Linear},
    dtype=torch.qint8
)

print("\nQuantization Results")
size_int8 = print_size_of_model(quantized_model, "Quantized INT8")

# Relative size saving of the INT8 state dict versus the pruned FP32 one, in percent.
reduction = (size_fp32 - size_int8) / size_fp32 * 100
print(f"Size Reduction: {reduction:.2f}%")

# 3. Save the Efficient Model
# Only the state dict of the quantized model is saved.
# NOTE(review): loading it back presumably requires rebuilding the model and
# applying the same quantize_dynamic() call before load_state_dict() - confirm.
save_folder = 'model_saves'
os.makedirs(save_folder, exist_ok=True)
q_path = os.path.join(save_folder, 'threshold_0.27_hybrid_Tr2022_Te2019_optimised.pth')
torch.save(quantized_model.state_dict(), q_path)
print(f"Efficient model saved to: {q_path}")

# 3. Test performance on newly compressed model

In [None]:
import time
from sklearn.metrics import accuracy_score,f1_score

def evaluate_quantized(model, loader, device='cpu', threshold=0.27):
    """Evaluate the compressed model on the CPU and print accuracy, F1 and time.

    The softmax probability of class index 1 is compared against
    ``threshold``; samples at or above it are predicted as class 1.

    Args:
        model: The (quantized) network; moved to the CPU and put in eval mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Accepted for signature compatibility but ignored - evaluation
            always runs on the CPU.
        threshold: Decision threshold on the class-1 probability. Defaults to
            0.27 (train 2022 / test 2019); use 0.07 for train 2019 / test 2022.

    Returns:
        None. Results are printed.
    """
    # MUST USE CPU due to pytorch
    model.to('cpu')
    model.eval()

    all_preds = []
    all_targets = []

    # The measured time covers the whole loop, including data loading.
    start = time.time()
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to('cpu')
            outputs = model(inputs)

            probs = torch.softmax(outputs, dim=1)[:, 1]
            preds = (probs >= threshold).long()

            all_preds.extend(preds.numpy())
            all_targets.extend(targets.numpy())

    end = time.time()
    acc = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds)
    print(f"NEW Compressed Model Accuracy: {acc:.4f}")
    print(f"NEW Compressed Model F1: {f1:.4f}")
    print(f"Inference Time: {end - start:.2f} seconds")

# Evaluate the quantized model from the previous cell on the test loader.
# evaluate_quantized() prints accuracy, F1 and the elapsed time itself.
print("\nTesting Efficient Model")
evaluate_quantized(quantized_model, testloader)