# 1. Iterative structured pruning

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import numpy as np
import copy
from sklearn.metrics import accuracy_score, f1_score
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torchvision import datasets
import os
from phase_performance_hybrids.resnet50_cswin.model_v2 import ResNetCSWinHybrid
# For the parallel model use instead: from phase_performance_hybrids.resnet50_cswin.model_v3 import ResNetCSWinHybridV3 as ResNetCSWinHybrid (and switch to the model_v3 version of apply_structured_pruning_step below)

class FocalLoss(nn.Module):
    """Focal loss for multi-class classification.

    Each sample's cross-entropy is scaled by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the true class. Confidently
    classified samples are therefore down-weighted.

    Args:
        alpha: Constant factor applied to every sample's loss.
        gamma: Focusing exponent; ``0`` gives plain (alpha-scaled) cross-entropy.
        reduction: ``'mean'``, ``'sum'``, or any other value for per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss for logits ``inputs`` and class-index ``targets``."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) equals the softmax probability assigned to the target class.
        p_true = per_sample_ce.neg().exp()
        modulating = (1 - p_true).pow(self.gamma)
        losses = self.alpha * modulating * per_sample_ce

        if self.reduction == 'mean':
            return losses.mean()
        if self.reduction == 'sum':
            return losses.sum()
        return losses


def load_data(*, data_dir='../cross_year_configurations_data/PlantVillage_2_2022train_2019test',
              mean=(0.7083, 0.2776, 0.0762), std=(0.1704, 0.1296, 0.0815), num_workers=4):
    """Build the train/test DataLoaders for a cross-year PlantVillage split.

    Images are read with ``datasets.ImageFolder`` from ``<data_dir>/train`` and
    ``<data_dir>/test``, resized to 224x224, converted to float tensors in
    [0, 1] and normalized with ``mean``/``std``. The train split additionally
    gets random horizontal and vertical flips.

    The batch size is read from the module-level ``batch_size`` global at call
    time, so that global must be defined before calling this function.

    Keyword Args:
        data_dir: Root folder containing ``train`` and ``test`` subfolders.
            The default is the train-2022 / test-2019 configuration; for
            train-2019 / test-2022 use
            ``'../cross_year_configurations_data/PlantVillage_1_2019train_2022test'``.
        mean, std: Per-channel normalization statistics. The defaults are the
            2022 statistics; for the 2019 split use
            ``mean=(0.7553, 0.3109, 0.1059), std=(0.1774, 0.1262, 0.0863)``.
        num_workers: Worker processes for each DataLoader.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    # v2.ToTensor is deprecated. ToImage + ToDtype(scale=True) is its documented
    # replacement and produces the same float tensor scaled to [0, 1].
    to_float_tensor = [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
    normalize = v2.Normalize(list(mean), list(std))

    data_transforms = {
        'train': v2.Compose([
            v2.Resize((224, 224)),
            # Basic augmentation.
            v2.RandomHorizontalFlip(),
            v2.RandomVerticalFlip(),
            *to_float_tensor,
            normalize,
            # Heavy augmentation used for some experiments. To use it, replace
            # the flips above with this pipeline (keep the order):
            #   v2.RandomHorizontalFlip(p=0.5),
            #   v2.RandomRotation(degrees=15),
            #   v2.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            #   v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
            #   v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            #   *to_float_tensor, normalize,
            #   v2.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
        ]),
        'test': v2.Compose([
            v2.Resize((224, 224)),
            *to_float_tensor,
            normalize,
        ]),
    }

    dsets = {split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
             for split in ('train', 'test')}

    train_loader = DataLoader(dsets['train'], batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(dsets['test'], batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader



def measure_sparsity(model):
    """Print and return the fraction of exactly-zero Conv2d/Linear weights.

    Only the ``weight`` tensors of ``nn.Conv2d`` and ``nn.Linear`` modules
    (including subclasses) are counted. Biases and all other parameters are
    ignored. The weight is read as currently exposed by ``module.weight``.

    Args:
        model: Any ``nn.Module``.

    Returns:
        The zero fraction in ``[0, 1]``, or ``0.0`` when the model has no
        Conv2d/Linear weights (avoids a division by zero).
    """
    total_params = 0
    zero_params = 0
    for module in model.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        weight = getattr(module, "weight", None)
        if weight is None:
            continue
        w = weight.detach()
        total_params += w.numel()
        zero_params += int((w == 0).sum().item())

    sparsity = zero_params / total_params if total_params else 0.0
    print(f"Global Sparsity: {100. * sparsity:.2f}%")
    return sparsity


def simple_evaluate(model, loader, device, threshold):
    """Return ``(accuracy, f1)`` of a binary classifier at a probability threshold.

    A sample is predicted as class 1 when the softmax probability of class 1
    is ``>= threshold``, otherwise class 0. Inference runs under
    ``torch.no_grad()`` with the model switched to eval mode, and the model is
    left in eval mode afterwards.

    Args:
        model: Module returning logits indexable as ``[:, 1]`` after a softmax
            over dim 1.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the inputs are moved to.
        threshold: Decision threshold on the class-1 probability.

    Returns:
        ``(accuracy_score, f1_score)`` from scikit-learn with default arguments.
    """
    model.eval()
    predictions, labels = [], []
    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            logits = model(batch_inputs.to(device))
            positive_prob = torch.softmax(logits, dim=1)[:, 1]
            predictions.extend((positive_prob >= threshold).long().cpu().numpy())
            labels.extend(batch_targets.cpu().numpy())
    return accuracy_score(labels, predictions), f1_score(labels, predictions)

def get_pruning_rate_per_step(final_target, k_steps):
    """Return the per-step pruning fraction that compounds to ``final_target``.

    If every step prunes ``rate`` of the units that survived the previous
    steps, the surviving fraction after ``k_steps`` steps is
    ``(1 - rate) ** k_steps``. Solving ``(1 - rate) ** k_steps == 1 - final_target``
    gives ``rate = 1 - (1 - final_target) ** (1 / k_steps)``.
    This assumes pruning masks accumulate between steps rather than being made
    permanent after each one.

    Args:
        final_target: Overall fraction to prune, at most ``1``. Values ``<= 0``
            yield ``0.0``.
        k_steps: Number of pruning steps; must be ``>= 1``.

    Returns:
        The per-step fraction as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``final_target > 1`` (the formula would yield a complex
            number) or ``k_steps < 1``.
    """
    if final_target <= 0:
        return 0.0
    if final_target > 1:
        raise ValueError(f"final_target must be <= 1, got {final_target}")
    if k_steps < 1:
        raise ValueError(f"k_steps must be >= 1, got {k_steps}")
    return 1 - (1 - final_target) ** (1 / k_steps)

# for model_v2 (sequential)
def apply_structured_pruning_step(model, rate_cnn, rate_trans, exclude=()):
    """Apply one round of L2-norm structured pruning to Conv2d/Linear layers.

    For each eligible layer, the output units along ``dim=0`` with the
    smallest L2 norm are masked via ``prune.ln_structured``. Conv2d output
    units are filters; Linear output units are neurons. Masks are not made
    permanent here, so later calls prune on top of earlier ones. Use
    ``make_pruning_permanent`` once pruning is finished.

    NOTE: with ``dim=0``, pruning a row of a classifier layer zeroes the
    weights feeding that class's logit. The model in this notebook has
    ``num_classes=2``. Pass the classifier's module name in ``exclude`` to
    protect it; the exact name depends on the model and must be checked.

    Args:
        model: Model to prune in place (written for the sequential model_v2 hybrid).
        rate_cnn: Fraction of units to prune per call in every ``nn.Conv2d``
            and in ``nn.Linear`` layers outside ``stage3``/``stage4``.
        rate_trans: Fraction for ``nn.Linear`` layers whose qualified name
            contains ``stage3`` or ``stage4`` (presumably the transformer stages).
        exclude: Substrings of qualified module names to skip, e.g.
            ``("head",)``. The default empty tuple prunes every eligible layer.

    Returns:
        Number of modules pruned in this call.
    """
    if isinstance(exclude, str):
        # Guard against exclude="head", which would otherwise be iterated per character.
        exclude = (exclude,)

    count = 0
    for name, module in model.named_modules():
        if any(token in name for token in exclude):
            continue
        if isinstance(module, nn.Conv2d):
            amount = rate_cnn
        elif isinstance(module, nn.Linear):
            is_transformer = "stage3" in name or "stage4" in name
            amount = rate_trans if is_transformer else rate_cnn
        else:
            continue
        prune.ln_structured(module, name='weight', amount=amount, n=2, dim=0)
        count += 1
    return count

# for model_v3 (parallel)
# def apply_structured_pruning_step(model, rate_cnn, rate_trans):
#     count = 0
#     for name, module in model.named_modules():

#         if "head" in name:
#             continue
#         if isinstance(module, nn.Conv2d):
#             prune.ln_structured(module, name='weight', amount=rate_cnn, n=2, dim=0)
#             count += 1
#         elif isinstance(module, nn.Linear):
#             is_transformer_part = any(k in name for k in [
#                 "stage",        # CSWin Stages 1-4
#                 "merge",        # CSWin Merges
#                 "cross_attn",   # The Cross Attention in Fusion
#                 "fusion",       # The FFN in Fusion
#                 "cswin_proj",   # Projections
#                 "fused_to"      # Projections
#             ])

#             if is_transformer_part:
#                 prune.ln_structured(module, name='weight', amount=rate_trans, n=2, dim=0)
#             else:
#                 # Treat everything else (e.g. Bridge projections) as CNN
#                 prune.ln_structured(module, name='weight', amount=rate_cnn, n=2, dim=0)
#             count += 1

#     return count

def make_pruning_permanent(model):
    """Fold active ``weight`` pruning masks of Conv2d/Linear layers into the weights.

    For every ``nn.Conv2d``/``nn.Linear`` that ``prune.is_pruned`` reports as
    pruned, ``prune.remove(module, 'weight')`` removes the pruning
    reparametrization. The masked entries remain as zeros in a plain
    ``weight`` parameter.

    In this notebook it is called once, after the final pruning step. Calling
    it between steps would turn pruned units into ordinary zero-norm weights,
    which a later L2-norm pruning call would likely select again instead of
    pruning new units.
    """
    prunable_types = (nn.Conv2d, nn.Linear)
    for module in model.modules():
        if isinstance(module, prunable_types) and prune.is_pruned(module):
            prune.remove(module, 'weight')

def fine_tune_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader`` and return the mean sample loss.

    The model is switched to train mode. Every ``(inputs, labels)`` batch gets
    a forward pass, a ``criterion`` backward pass and an ``optimizer`` step.
    Each batch loss is weighted by its batch size and the total is divided by
    ``len(loader.dataset)``. This assumes ``criterion`` returns a batch-mean
    loss, as the FocalLoss above does with ``reduction='mean'``.
    """
    model.train()
    loss_sum = 0.0
    for batch_inputs, batch_labels in loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * batch_inputs.size(0)
    return loss_sum / len(loader.dataset)

# -----------------------------------------------------------------------------

batch_size = 128  # module-level: load_data() reads this global
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainloader, testloader = load_data()

# The checkpoint below overwrites every weight, so the architecture is built
# without pretrained backbones to skip a pointless download. The quantization
# cell builds the same class this way before loading a state_dict.
model = ResNetCSWinHybrid(num_classes=2, resnet_pretrained=False, cswin_pretrained=False)
path = '../phase_performance_hybrids/results_model_saves_resnet50_cswin/results_from_model_v2_heavy_augmentation/threshold_0.27_hybridv2_Tr2022_Te2019.pth'
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(path, map_location=device))
model.to(device)


FINAL_CNN_TARGET = 0.15    # overall fraction of CNN-side units to prune
FINAL_TRANS_TARGET = 0.17  # overall fraction for stage3/stage4 Linear layers
STEPS = 6
EPOCHS_PER_STEP = 10
# Decision threshold on P(class 1). It must match the checkpoint and data split:
#   0.27 for train 2022 / test 2019 (the checkpoint above and load_data's default data_dir)
#   0.07 for train 2019 / test 2022
THRESHOLD = 0.27

cnn_step_rate = get_pruning_rate_per_step(FINAL_CNN_TARGET, STEPS)
trans_step_rate = get_pruning_rate_per_step(FINAL_TRANS_TARGET, STEPS)

print(f"Goal: {FINAL_CNN_TARGET*100}% resnet, {FINAL_TRANS_TARGET*100}% trans")
print(f"Schedule: {STEPS} steps, {EPOCHS_PER_STEP} epochs/step")
print(f"Per step pruning rate - resnet: {cnn_step_rate:.4f}, trans: {trans_step_rate:.4f}")

acc_base, f1_base = simple_evaluate(model, testloader, device, THRESHOLD)
print(f"Baseline before acc: {acc_base:.4f}, F1: {f1_base:.4f}\n")

# Low LR for gentle fine-tuning.
# NOTE(review): the optimizer is created before pruning. It therefore relies on
# torch.nn.utils.prune keeping the original weight Parameter objects (as
# ``weight_orig``). Confirm this if the pruning method changes.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = FocalLoss(gamma=2.0).to(device)  # same loss function as the initial training

for step in range(1, STEPS + 1):
    print(f"Step {step}/{STEPS}")

    # Prune: masks accumulate across steps and are made permanent only after the loop.
    print("Applying pruning")
    apply_structured_pruning_step(model, cnn_step_rate, trans_step_rate)
    measure_sparsity(model)

    # Rehab: fine-tune to recover accuracy lost to pruning.
    print(f"Finetuning for {EPOCHS_PER_STEP} epochs")
    for epoch in range(EPOCHS_PER_STEP):
        loss = fine_tune_epoch(model, trainloader, optimizer, criterion, device)
        if (epoch + 1) % 5 == 0:
            print(f"   Epoch {epoch+1}/{EPOCHS_PER_STEP} | Loss: {loss:.4f}")

    # Check status (drops are in percentage points relative to the baseline).
    acc, f1 = simple_evaluate(model, testloader, device, THRESHOLD)
    print(f"Step {step} result: acc: {acc:.4f}, F1: {f1:.4f} (drop: {(acc_base-acc)*100:.2f}%)")

make_pruning_permanent(model)
acc_final, f1_final = simple_evaluate(model, testloader, device, THRESHOLD)
print(f"Final result: acc: {acc_final:.4f}, F1: {f1_final:.4f} (drop: acc {(acc_base-acc_final)*100:.2f}%, F1 {(f1_base - f1_final)*100:.2f}%)")

# Final save (disabled). The quantization cell below loads this exact path.
# measure_sparsity(model)
# save_folder = '../phase_efficiency/iterative_structured_pruning/train22_test19_new_model'
# os.makedirs(save_folder, exist_ok=True)
# save_path = os.path.join(save_folder, 'threshold_0.27_hybridv2_heavyAug_Tr2022_Te2019_iterStructPrunOnly.pth')
# torch.save(model.state_dict(), save_path)
# print(f"Model saved to: {save_path}")

# 2. Dynamic Quantization (INT8)

In [None]:
import torch.quantization
import os
import time
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torchvision import datasets

def print_size_of_model(model, label=""):
    """Print and return the on-disk size of ``model.state_dict()`` in bytes.

    The state dict is saved with ``torch.save`` to ``temp.p`` inside a fresh
    temporary directory. The directory is always deleted, even if saving
    fails, and the working directory is never touched.

    Args:
        model: Module whose ``state_dict()`` is measured.
        label: Name shown in the printed line.

    Returns:
        The file size in bytes.
    """
    import tempfile  # local import: this cell's import list stays unchanged

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name: torch.save derives internal record
        # names from it, so the measured size matches the previous behaviour.
        tmp_path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), tmp_path)
        size = os.path.getsize(tmp_path)
    print(f"Model: {label:<15} | Size: {size/1e6:.2f} MB")
    return size

def measure_inference_speed(model, loader, device, max_images=200, warmup_iters=10):
    """Print and return the mean inference latency in milliseconds per image.

    The model is put in eval mode and warmed up with ``warmup_iters`` forward
    passes on the first batch. Timing then covers whole batches from
    ``loader`` until at least ``max_images`` images have been processed, so
    the count is rounded up to a whole batch. All passes run under
    ``torch.no_grad()``. On CUDA the device is synchronized before each clock
    read, because kernel launches are asynchronous and would otherwise be
    timed only partially.

    Args:
        model: Module to benchmark.
        loader: Iterable of ``(inputs, targets)`` batches; must be non-empty.
        device: Device (``str`` or ``torch.device``) the inputs are moved to.
        max_images: Minimum number of images to time.
        warmup_iters: Untimed forward passes run before measuring.

    Returns:
        Latency in ms per image.
    """
    model.eval()
    is_cuda = torch.device(device).type == "cuda"

    def _sync():
        # Wait for queued CUDA kernels so the clock reflects finished work.
        if is_cuda:
            torch.cuda.synchronize(device)

    with torch.no_grad():
        warmup_input, _ = next(iter(loader))
        warmup_input = warmup_input.to(device)
        for _ in range(warmup_iters):
            model(warmup_input)
        _sync()

        count = 0
        start = time.perf_counter()
        for inputs, _ in loader:
            inputs = inputs.to(device)
            model(inputs)
            count += inputs.size(0)
            if count >= max_images:
                break
        _sync()
        elapsed = time.perf_counter() - start

    latency = elapsed / count * 1000  # ms per image
    print(f"Latency: {latency:.2f} ms/image")
    return latency


CHECKPOINT_PATH = (
    '../phase_efficiency/iterative_structured_pruning/train22_test19_new_model/'
    'threshold_0.27_hybridv2_heavyAug_Tr2022_Te2019_iterStructPrunOnly.pth'
)

# Rebuild the architecture (no pretrained download needed) and load the
# pruned FP32 weights onto the CPU.
model = ResNetCSWinHybrid(num_classes=2, resnet_pretrained=False, cswin_pretrained=False)
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

size_fp32 = print_size_of_model(model, "Pruned FP32")

# Dynamic INT8 quantization targets only nn.Linear layers, the easy win for the
# transformer parts. Conv2d layers are left as they are, since quantizing them
# usually requires the more involved static quantization. The model is moved
# to CPU first; the evaluation cell notes the quantized model must run on CPU.
quantized_model = torch.quantization.quantize_dynamic(
    model.cpu(), {torch.nn.Linear}, dtype=torch.qint8
)

print("\nQuantization Results")
size_int8 = print_size_of_model(quantized_model, "Quantized INT8")

reduction = (size_fp32 - size_int8) / size_fp32 * 100
print(f"Size Reduction: {reduction:.2f}%")

# Save the efficient (pruned + quantized) model's state_dict.
save_folder = '../phase_efficiency/quantization_int8/train22_test19_new_model'
os.makedirs(save_folder, exist_ok=True)
q_path = os.path.join(
    save_folder, 'threshold_0.27_hybridv2_heavyAug_Tr2022_Te2019_pruned_quantized.pth'
)
torch.save(quantized_model.state_dict(), q_path)
print(f"Efficient model saved to: {q_path}")

# 3. Test performance of the newly compressed model

In [None]:
import time
from sklearn.metrics import accuracy_score,f1_score

def evaluate_quantized(model, loader, device='cpu', threshold=0.27):
    """Evaluate a (dynamically quantized) binary classifier on CPU.

    A sample is predicted as class 1 when its softmax probability for class 1
    is ``>= threshold``. Accuracy, F1 and the wall-clock inference time are
    printed.

    Args:
        model: Model returning logits; it is moved to CPU and set to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Ignored and kept only for backward compatibility. Evaluation
            always runs on CPU.
        threshold: Decision threshold on P(class 1). Use your optimal
            threshold: 0.27 for train 2022 / test 2019 (the default), 0.07 for
            train 2019 / test 2022.

    Returns:
        ``(accuracy, f1)`` as computed by scikit-learn.
    """
    # Must use CPU: the quantized model is evaluated on CPU only (PyTorch limitation).
    model.to('cpu')
    model.eval()

    all_preds = []
    all_targets = []

    start = time.perf_counter()
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to('cpu'))
            probs = torch.softmax(outputs, dim=1)[:, 1]
            preds = (probs >= threshold).long()
            all_preds.extend(preds.numpy())
            all_targets.extend(targets.numpy())
    elapsed = time.perf_counter() - start

    acc = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds)
    print(f"NEW Compressed Model Accuracy: {acc:.4f}")
    print(f"NEW Compressed Model F1: {f1:.4f}")
    print(f"Inference Time: {elapsed:.2f} seconds")
    return acc, f1

print("\nTesting Efficient Model")
# Assigned so the notebook does not echo the returned tuple.
acc_quant, f1_quant = evaluate_quantized(quantized_model, testloader, threshold=0.27)

# 4. Knowledge Distillation
## To choose which model is loaded as the teacher, edit distillation_config.py (line 69).

In [None]:
import os
import sys
import torch
from phase_efficiency.knowledge_distillation.scripts.train_knowledge_distillation import train_with_distillation, load_data
from phase_efficiency.knowledge_distillation.scripts.distillation_config import (
    EFFICIENTNET_VARIANTS,
    DISTILLATION_CONFIGS,
    LR_CONFIGS,
    TEACHER_MODELS,
    DATA_CONFIGS
)


def run_single_experiment(
    experiment_name,
    teacher_checkpoint,
    student_name,
    trainloader,
    testloader,
    distillation_config,
    lr_config
):
    """Run a single distillation experiment and save the trained student.

    Prints the configuration, then calls ``train_with_distillation`` with the
    ``temperature``, ``alpha`` and ``epochs`` from ``distillation_config``.
    The student's ``state_dict`` is saved to
    ``cswin_fpn_hybrid/model_saves/experiments/<experiment_name>_<student_name>.pth``.

    NOTE(review): ``lr_config`` is only printed, and
    ``distillation_config['batch_size']`` is unused; neither is passed to
    ``train_with_distillation``. Confirm where that function takes its
    learning rates from before trusting the printed LR values.

    Returns:
        ``True`` on success. ``False`` if anything raised; the error and full
        traceback are printed so one failing run does not abort a sweep.
    """
    import traceback  # local import: keeps this cell's import list unchanged

    print("\n" + "="*80)
    print(f"EXPERIMENT: {experiment_name}")
    print("="*80)
    print(f"Student Model: {student_name}")
    print(f"Teacher: {teacher_checkpoint}")
    print(f"Temperature: {distillation_config['temperature']}")
    print(f"Alpha: {distillation_config['alpha']}")
    print(f"Epochs: {distillation_config['epochs']}")
    print(f"Backbone LR: {lr_config['backbone_lr']}")
    print(f"Head LR: {lr_config['head_lr']}")
    print("="*80 + "\n")

    try:
        losses, trained_student = train_with_distillation(
            teacher_checkpoint=teacher_checkpoint,
            student_name=student_name,
            trainloader=trainloader,
            testloader=testloader,
            num_classes=2,
            epochs=distillation_config['epochs'],
            temperature=distillation_config['temperature'],
            alpha=distillation_config['alpha']
        )

        # Save experiment-specific model
        save_folder = 'cswin_fpn_hybrid/model_saves/experiments'
        os.makedirs(save_folder, exist_ok=True)
        model_save_path = os.path.join(save_folder, f'{experiment_name}_{student_name}.pth')
        torch.save(trained_student.state_dict(), model_save_path)
        print(f"\nExperiment model saved: {model_save_path}")

        return True

    except Exception as e:
        # Deliberately best-effort so a sweep continues; keep the traceback for debugging.
        print(f"\n[ERROR] Experiment {experiment_name} failed: {str(e)}")
        traceback.print_exc()
        return False


def run_student_comparison():
    """Distill the default teacher into the first three EfficientNet variants.

    Every student shares the ``'default'`` teacher checkpoint, distillation
    config and LR config. Each run is named ``student_comparison_<variant>``.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("EXPERIMENT SET: Student Model Comparison")
    print(banner)

    trainloader, testloader = load_data()
    shared_kwargs = dict(
        teacher_checkpoint=TEACHER_MODELS['default'],
        trainloader=trainloader,
        testloader=testloader,
        distillation_config=DISTILLATION_CONFIGS['default'],
        lr_config=LR_CONFIGS['default'],
    )

    for variant in EFFICIENTNET_VARIANTS[:3]:  # only the first three variants
        run_single_experiment(
            experiment_name=f"student_comparison_{variant}",
            student_name=variant,
            **shared_kwargs,
        )


def run_temperature_ablation():
    """Sweep the distillation temperature with alpha held fixed.

    Trains ``efficientnet_b0`` from the default teacher once per temperature
    in ``[2.0, 4.0, 6.0, 8.0]``, using alpha 0.3 and 100 epochs. Each run is
    named ``temp_ablation_T<temperature>``.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("EXPERIMENT SET: Temperature Ablation Study")
    print(banner)

    trainloader, testloader = load_data()
    teacher_checkpoint = TEACHER_MODELS['default']
    student_name = 'efficientnet_b0'
    lr_config = LR_CONFIGS['default']

    # Fixed alpha: set this to the best value found by the alpha ablation
    # (previously 0.7).
    fixed_alpha = 0.3

    for temperature in [2.0, 4.0, 6.0, 8.0]:
        run_single_experiment(
            experiment_name=f"temp_ablation_T{temperature}",
            teacher_checkpoint=teacher_checkpoint,
            student_name=student_name,
            trainloader=trainloader,
            testloader=testloader,
            distillation_config={
                'temperature': temperature,
                'alpha': fixed_alpha,
                'epochs': 100,  # intended as a shorter run for the ablation
                'batch_size': 128,
            },
            lr_config=lr_config,
        )


def run_alpha_ablation():
    """Sweep alpha (the distillation vs. hard-label weight) at a fixed temperature.

    Trains ``efficientnet_b0`` from the default teacher once per alpha in
    ``[0.3, 0.5, 0.7, 0.9]``, using temperature 4.0 and 100 epochs. Each run
    is named ``alpha_ablation_A<alpha>``.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("EXPERIMENT SET: Alpha Ablation Study")
    print(banner)

    trainloader, testloader = load_data()
    teacher_checkpoint = TEACHER_MODELS['default']
    student_name = 'efficientnet_b0'
    lr_config = LR_CONFIGS['default']

    fixed_temperature = 4.0

    for alpha in [0.3, 0.5, 0.7, 0.9]:
        run_single_experiment(
            experiment_name=f"alpha_ablation_A{alpha}",
            teacher_checkpoint=teacher_checkpoint,
            student_name=student_name,
            trainloader=trainloader,
            testloader=testloader,
            distillation_config={
                'temperature': fixed_temperature,
                'alpha': alpha,
                'epochs': 100,  # intended as a shorter run for the ablation
                'batch_size': 128,
            },
            lr_config=lr_config,
        )


def run_preset_configs():
    """Run ``efficientnet_b0`` distillation once per entry in ``DISTILLATION_CONFIGS``.

    All runs use the default teacher and LR config. Each is named
    ``preset_<config_name>``.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("EXPERIMENT SET: Preset Configurations")
    print(banner)

    trainloader, testloader = load_data()
    teacher_checkpoint = TEACHER_MODELS['default']
    student_name = 'efficientnet_b0'
    lr_config = LR_CONFIGS['default']

    for preset_name, preset_config in DISTILLATION_CONFIGS.items():
        run_single_experiment(
            experiment_name=f"preset_{preset_name}",
            teacher_checkpoint=teacher_checkpoint,
            student_name=student_name,
            trainloader=trainloader,
            testloader=testloader,
            distillation_config=preset_config,
            lr_config=lr_config,
        )


if __name__ == '__main__':

    print("\n" + "="*80)
    print("Knowledge Distillation Experiment Suite")
    print("Teacher: ResNetCSWinHybrid (new_model)")
    print("Student: EfficientNet variants")
    print("="*80)

    # Select which experiment set(s) to run. This replaces a commented-out
    # dispatch on ``args.experiment``, which is never defined in this notebook.
    # Options: 'student_comparison', 'temperature', 'alpha', 'presets', 'all'.
    EXPERIMENT = 'temperature'

    experiment_sets = {
        'student_comparison': [run_student_comparison],
        'temperature': [run_temperature_ablation],
        'alpha': [run_alpha_ablation],
        'presets': [run_preset_configs],
        'all': [
            run_student_comparison,
            run_temperature_ablation,
            run_alpha_ablation,
            run_preset_configs,
        ],
    }
    if EXPERIMENT not in experiment_sets:
        raise ValueError(f"Unknown EXPERIMENT {EXPERIMENT!r}; choose from {sorted(experiment_sets)}")

    for run_experiment_set in experiment_sets[EXPERIMENT]:
        run_experiment_set()

    print("\n" + "="*80)
    print("All Experiments Complete!")
    print("="*80)