In [None]:
import os
import random
import sys

import numpy as np

# Make the project-local packages imported below (pruning/, data/, models/) importable.
# Assumes the notebook kernel is started from the project root.
proj_root = os.path.abspath(".")
if proj_root not in sys.path:
    sys.path.insert(0, proj_root)

import torch

print("Project root:", proj_root)

# Seed every RNG the pipeline may touch (Python, NumPy, torch on all devices) so that
# data shuffling and finetuning are reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


In [None]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, Image
import json
# Pruning pipeline functions
from pruning.unstructured_pruning import (
    run_sensitivity_analysis,
    analyze_sensitivity,
    apply_pruning_masks_sparse,
    verify_masks_coo,
    finetune_pruned_model,
    profile_pruned_model,
)

from pruning.utils import evaluate_accuracy

# Datasets
from data.dataset import (
    cifar10_trainloader,
    ciaf10_testloader,
    cifar100_trainloader,
    ciaf100_testloader,
)

# Model
from models.vgg_16_bn import get_model

batch_size = 128  # mini-batch size passed to every train/test loader factory in the next cell


In [None]:
datasets = {
    "cifar10": {
        "trainloader": cifar10_trainloader(batch_size=batch_size),
        "testloader": ciaf10_testloader(batch_size=batch_size)
    },
    "cifar100": {
        "trainloader": cifar100_trainloader(batch_size=batch_size),
        "testloader": ciaf100_testloader(batch_size=batch_size)
    }
}


In [None]:
testloader = datasets["cifar10"]["testloader"]
model = get_model()
profile_results = profile_pruned_model(model, testloader, device, "cifar10")

In [None]:
def run_full_pipeline(dataset_name):
    """Run the unstructured-pruning pipeline end to end for one dataset.

    Stages: layer-wise sensitivity analysis -> sparsity planning -> mask application
    (masks + sparse COO weights) -> mask verification -> finetuning of the pruned model
    -> reload of the saved checkpoint into a fresh model -> profiling.

    Args:
        dataset_name: key of the notebook-level ``datasets`` dict
            (``"cifar10"`` or ``"cifar100"`` in this notebook).

    Returns:
        tuple: ``(df_sensitivity, sparsity_plan, verification_log,
        finetuned_model_path, profile_results)``.
    """
    print(f"\nRunning pipeline for {dataset_name}\n")

    trainloader = datasets[dataset_name]["trainloader"]
    testloader = datasets[dataset_name]["testloader"]

    # 1. Load model
    model = get_model().to(device)

    # 2. Sensitivity analysis (results are written to CSV and also returned)
    sensitivity_csv = f"results/sensitivity_layerwise_{dataset_name}.csv"
    os.makedirs(os.path.dirname(sensitivity_csv), exist_ok=True)
    df_sensitivity = run_sensitivity_analysis(model, testloader, device, save_path=sensitivity_csv)
    print(f"Sensitivity CSV saved at {sensitivity_csv}")

    # 3. Analyze & plan sparsity. overall_sparsity is presumably a fraction
    #    (it is printed multiplied by 100 as a percentage).
    overall_sparsity, sparsity_plan = analyze_sensitivity(sensitivity_csv, dataset_name, model)
    print(f"Overall estimated sparsity for {dataset_name}: {overall_sparsity*100:.2f}%")

    # 4. Apply pruning masks & generate sparse COO weights.
    # NOTE(review): the plan is read back from this CSV path rather than passing
    # `sparsity_plan` directly — presumably analyze_sensitivity writes it there; confirm.
    mask_dict, sparse_weights_dict = apply_pruning_masks_sparse(
        model,
        sparsity_plan_csv=f"results/plans/sparsity_plan_{dataset_name}.csv",
        dataset_name=dataset_name,
        target_sparsity=overall_sparsity,
        device=device
    )

    # 5. Verify the saved masks against the saved sparse weights.
    verification_log = verify_masks_coo(
        mask_path=f"results/pruning_masks/{dataset_name}_unstructured_mask.pt",
        sparse_path=f"results/sparse_weights/{dataset_name}_sparse_weights.pt"
    )
    print(f"Verification log entries: {len(verification_log)}")

    # 6. Finetune the pruned model (masks are passed in, presumably to keep pruned
    #    weights at zero); the checkpoint is written to save_path.
    # round() instead of a bare int(): int(0.29 * 100) == 28 because of float error.
    sparsity_tag = int(round(overall_sparsity * 100))
    finetuned_model_path = (
        f"results/models/{dataset_name}_vgg16_unstructured_finetuned_{sparsity_tag}.pt"
    )
    os.makedirs(os.path.dirname(finetuned_model_path), exist_ok=True)
    best_val_acc = finetune_pruned_model(
        model,
        trainloader=trainloader,
        val_loader=testloader,
        masks=mask_dict,
        device=device,
        optimizer_type="sgd",
        lr=1e-3,
        momentum=0.9,
        epochs=5,
        save_path=finetuned_model_path
    )
    print(f"Best validation accuracy after finetuning {dataset_name}: {best_val_acc:.2f}%")

    # 7. Reload the saved checkpoint into a fresh model and profile it.
    # map_location lets a checkpoint written on another device load on the current one.
    finetuned_model = get_model().to(device)
    finetuned_model.load_state_dict(torch.load(finetuned_model_path, map_location=device))
    profile_results = profile_pruned_model(finetuned_model, testloader, device, dataset_name)

    return df_sensitivity, sparsity_plan, verification_log, finetuned_model_path, profile_results


In [None]:
# Full pipeline on CIFAR-10: sensitivity table, sparsity plan, mask-verification log,
# finetuned checkpoint path and profiling results.
df_sens_10, plan_10, log_10, finetuned_path_10, profile_10 = run_full_pipeline("cifar10")

In [None]:
# Full pipeline on CIFAR-100: sensitivity table, sparsity plan, mask-verification log,
# finetuned checkpoint path and profiling results.
df_sens_100, plan_100, log_100, finetuned_path_100, profile_100 = run_full_pipeline("cifar100")


In [None]:
def plot_sensitivity_curves(dataset_name, csv_path=None, save_dir="plots", plot_drop=True):
    """Plot per-layer accuracy (or accuracy drop) versus sparsity from a sensitivity CSV.

    Args:
        dataset_name: dataset tag used in the default CSV path, the title and the PNG name.
        csv_path: CSV with columns ``layer``, ``sparsity_pct`` and ``top1_drop``/``top1``.
            Defaults to ``results/sensitivity_layerwise_<dataset_name>.csv``, the path
            run_full_pipeline passes to run_sensitivity_analysis.
        save_dir: output directory for the PNG (created if missing).
        plot_drop: plot ``top1_drop`` when True, otherwise absolute ``top1``.

    Returns:
        str: path of the saved PNG.
    """
    if csv_path is None:
        # Must match the save_path used in run_full_pipeline ("results/" prefix included).
        csv_path = f"results/sensitivity_layerwise_{dataset_name}.csv"
    df = pd.read_csv(csv_path)
    os.makedirs(save_dir, exist_ok=True)

    y_col = "top1_drop" if plot_drop else "top1"
    fig, ax = plt.subplots(figsize=(12, 7))
    for layer, group in df.groupby("layer"):
        # Sort so each curve is drawn left-to-right even if CSV rows are not ordered.
        group = group.sort_values("sparsity_pct")
        ax.plot(group["sparsity_pct"], group[y_col], marker="o", linewidth=1,
                markersize=4, label=layer)

    ax.set_xlabel("Sparsity (%)")
    ax.set_ylabel("Accuracy Drop (%)" if plot_drop else "Top-1 Accuracy (%)")
    ax.set_title(f"Layer-wise Sensitivity — {dataset_name}")
    ax.grid(alpha=0.2)
    fig.tight_layout()
    # Legend outside the axes to avoid overlapping curves; bbox_inches="tight" keeps it in the PNG.
    ax.legend(fontsize=7, bbox_to_anchor=(1.02, 1), loc="upper left", ncol=1)
    out_path = os.path.join(save_dir, f"sensitivity_{dataset_name}.png")
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    display(Image(out_path))
    print(f"Saved: {out_path}")
    return out_path

# Example usage: run after the sensitivity CSV for each dataset exists at the function's
# default csv_path. The last call's returned PNG path is shown as the cell output.
plot_sensitivity_curves("cifar10")
plot_sensitivity_curves("cifar100")


In [None]:
# Cell: Plot assigned per-layer sparsity (bar chart) + computed overall sparsity
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
from IPython.display import Image, display

def plot_sparsity_plan(dataset_name, plan_csv=None, model=None, save_dir="plots"):
    """Bar-plot the assigned per-layer sparsity and compute the parameter-weighted overall sparsity.

    Args:
        dataset_name: dataset tag used in the default CSV path, the title and the PNG name.
        plan_csv: CSV with columns ``layer``, ``params`` and ``assigned_sparsity`` (percent).
            Defaults to ``results/plans/sparsity_plan_<dataset_name>.csv``, the path
            run_full_pipeline passes to apply_pruning_masks_sparse.
        model: unused; kept so existing callers remain compatible.
        save_dir: output directory for the PNG (created if missing).

    Returns:
        tuple: ``(png_path, overall)``, where ``overall`` is the parameter-weighted mean
        of ``assigned_sparsity`` expressed as a fraction (0.0 for an empty plan).
    """
    if plan_csv is None:
        # Must match the sparsity_plan_csv path used in run_full_pipeline.
        plan_csv = f"results/plans/sparsity_plan_{dataset_name}.csv"
    plan = pd.read_csv(plan_csv)
    os.makedirs(save_dir, exist_ok=True)

    # Largest layers first for a more readable chart.
    plan_sorted = plan.sort_values("params", ascending=False).reset_index(drop=True)
    positions = np.arange(len(plan_sorted))
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(positions, plan_sorted["assigned_sparsity"])
    ax.set_xticks(positions)
    ax.set_xticklabels(plan_sorted["layer"], rotation=90, fontsize=8)
    ax.set_xlabel("Layer")
    ax.set_ylabel("Assigned Sparsity (%)")
    ax.set_title(f"Assigned per-layer sparsity — {dataset_name}")
    fig.tight_layout()
    out_path = os.path.join(save_dir, f"sparsity_plan_{dataset_name}.png")
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    display(Image(out_path))
    print(f"Saved: {out_path}")

    # Sanity check: parameter-weighted overall sparsity implied by the plan.
    total_params = plan_sorted["params"].sum()
    weighted = (plan_sorted["params"] * plan_sorted["assigned_sparsity"] / 100.0).sum()
    # Guard against an empty plan (or all-zero param counts) to avoid 0/0.
    overall = weighted / total_params if total_params > 0 else 0.0
    print(f"Weighted overall sparsity (from plan): {overall*100:.2f}%")
    return out_path, overall

# Example usage: plot each dataset's sparsity plan from the function's default plan_csv path.
# The last call's returned (png_path, overall) tuple is shown as the cell output.
plot_sparsity_plan("cifar10")
plot_sparsity_plan("cifar100")


In [None]:
# Cell: Visualize actual mask sparsity per parameter (post-pruning) and histogram
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from IPython.display import Image, display

def plot_mask_sparsity(dataset_name, mask_path=None, save_dir="plots"):
    """Plot the measured sparsity of every saved pruning mask.

    Writes two PNGs to ``save_dir``: a horizontal bar chart of the (up to) 40 sparsest
    masks and a histogram of sparsity across all masks.

    Args:
        dataset_name: dataset tag used in the default mask path, the titles and the PNG names.
        mask_path: file loadable with ``torch.load`` that yields a mapping of
            name -> mask tensor. Defaults to
            ``results/pruning_masks/<dataset_name>_unstructured_mask.pt``.
        save_dir: output directory for the PNGs (created if missing).

    Returns:
        pandas.DataFrame: one row per mask, in load order, with columns ``name``,
        ``total_params``, ``num_zero`` and ``sparsity_pct`` (percentage of zero entries).
    """
    if mask_path is None:
        mask_path = f"results/pruning_masks/{dataset_name}_unstructured_mask.pt"
    # Load onto CPU: masks saved during a GPU run would otherwise fail to load on a
    # CPU-only kernel, and only element counting happens here.
    masks = torch.load(mask_path, map_location="cpu")
    os.makedirs(save_dir, exist_ok=True)

    rows = []
    for name, mask in masks.items():
        total = mask.numel()
        zeros = (mask == 0).sum().item()
        # Zero-element tensors would raise ZeroDivisionError; report them as 0% sparse.
        pct = 100.0 * zeros / total if total else 0.0
        rows.append((name, total, zeros, pct))
    df = pd.DataFrame(rows, columns=["name", "total_params", "num_zero", "sparsity_pct"])
    df_sorted = df.sort_values("sparsity_pct", ascending=False).reset_index(drop=True)

    # Bar plot: top 40 masks by sparsity (or all of them if there are fewer).
    topk = min(40, len(df_sorted))
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.barh(np.arange(topk), df_sorted["sparsity_pct"].values[:topk])
    ax.set_yticks(np.arange(topk))
    ax.set_yticklabels(df_sorted["name"].values[:topk], fontsize=8)
    ax.set_xlabel("Sparsity (%) (zeros / total)")
    ax.set_title(f"Mask sparsity per parameter (top {topk}) — {dataset_name}")
    ax.invert_yaxis()  # sparsest mask at the top
    fig.tight_layout()
    out_path = os.path.join(save_dir, f"mask_sparsity_{dataset_name}.png")
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    display(Image(out_path))
    print(f"Saved: {out_path}")

    # Histogram of sparsity across all masks.
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(df["sparsity_pct"], bins=20)
    ax.set_xlabel("Layer sparsity (%)")
    ax.set_ylabel("Count")
    ax.set_title(f"Sparsity distribution across layers — {dataset_name}")
    hist_path = os.path.join(save_dir, f"mask_sparsity_hist_{dataset_name}.png")
    fig.tight_layout()
    fig.savefig(hist_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    display(Image(hist_path))
    print(f"Saved: {hist_path}")

    return df

# Example usage: plot the masks saved at each dataset's default mask_path and keep the
# per-mask sparsity tables.
df_mask10 = plot_mask_sparsity("cifar10")
df_mask100 = plot_mask_sparsity("cifar100")


In [None]:
# Cell: Visualize actual mask sparsity per parameter (post-pruning) and histogram
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from IPython.display import Image, display

def plot_mask_sparsity(dataset_name, mask_path=None, save_dir="plots"):
    """Plot the measured sparsity of every saved pruning mask.

    NOTE(review): the preceding cell also defines ``plot_mask_sparsity`` for the same
    purpose; this definition shadows it. Keep a single copy.

    Writes two PNGs to ``save_dir``: a horizontal bar chart of the (up to) 40 sparsest
    masks and a histogram of sparsity across all masks.

    Args:
        dataset_name: dataset tag used in the default mask path, the titles and the PNG names.
        mask_path: file loadable with ``torch.load`` that yields a mapping of
            name -> mask tensor. Defaults to
            ``results/pruning_masks/<dataset_name>_unstructured_mask.pt``.
        save_dir: output directory for the PNGs (created if missing).

    Returns:
        pandas.DataFrame: one row per mask, in load order, with columns ``name``,
        ``total_params``, ``num_zero`` and ``sparsity_pct`` (percentage of zero entries).
    """
    if mask_path is None:
        mask_path = f"results/pruning_masks/{dataset_name}_unstructured_mask.pt"
    # Load onto CPU: masks saved during a GPU run would otherwise fail to load on a
    # CPU-only kernel, and only element counting happens here.
    masks = torch.load(mask_path, map_location="cpu")
    os.makedirs(save_dir, exist_ok=True)

    rows = []
    for name, mask in masks.items():
        total = mask.numel()
        zeros = (mask == 0).sum().item()
        # Zero-element tensors would raise ZeroDivisionError; report them as 0% sparse.
        pct = 100.0 * zeros / total if total else 0.0
        rows.append((name, total, zeros, pct))
    df = pd.DataFrame(rows, columns=["name", "total_params", "num_zero", "sparsity_pct"])
    df_sorted = df.sort_values("sparsity_pct", ascending=False).reset_index(drop=True)

    # Bar plot: top 40 masks by sparsity (or all of them if there are fewer).
    topk = min(40, len(df_sorted))
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.barh(np.arange(topk), df_sorted["sparsity_pct"].values[:topk])
    ax.set_yticks(np.arange(topk))
    ax.set_yticklabels(df_sorted["name"].values[:topk], fontsize=8)
    ax.set_xlabel("Sparsity (%) (zeros / total)")
    ax.set_title(f"Mask sparsity per parameter (top {topk}) — {dataset_name}")
    ax.invert_yaxis()  # sparsest mask at the top
    fig.tight_layout()
    out_path = os.path.join(save_dir, f"mask_sparsity_{dataset_name}.png")
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    display(Image(out_path))
    print(f"Saved: {out_path}")

    # Histogram of sparsity across all masks.
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(df["sparsity_pct"], bins=20)
    ax.set_xlabel("Layer sparsity (%)")
    ax.set_ylabel("Count")
    ax.set_title(f"Sparsity distribution across layers — {dataset_name}")
    hist_path = os.path.join(save_dir, f"mask_sparsity_hist_{dataset_name}.png")
    fig.tight_layout()
    fig.savefig(hist_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    display(Image(hist_path))
    print(f"Saved: {hist_path}")

    return df

# Example usage: plot the masks saved at each dataset's default mask_path and keep the
# per-mask sparsity tables.
# NOTE(review): the preceding cell makes these same two calls; running both regenerates
# and overwrites the same PNGs and rebinds df_mask10 / df_mask100. Consider deleting one cell.
df_mask10 = plot_mask_sparsity("cifar10")
df_mask100 = plot_mask_sparsity("cifar100")


In [None]:
# Cell: Re-display figures already written to plots/ by the plotting cells above,
# without recomputing anything (the files must already exist).
from IPython.display import Image, display

for png_path in (
    "plots/sensitivity_cifar10.png",
    "plots/sensitivity_cifar100.png",
    "plots/sparsity_plan_cifar10.png",
    "plots/sparsity_plan_cifar100.png",
):
    display(Image(png_path))
del png_path  # don't leave the loop variable behind in the notebook namespace
