# Milestone M6 — Sensitivity (Fisher diagonal) scores and mask calibration

In [3]:
import os
import torch
import matplotlib.pyplot as plt
from src.utils import set_seed, get_device, ensure_dir
from src.data import load_cifar100, create_dataloader
from src.model import build_model
from src.masking import (
    compute_sensitivity_scores, create_mask, 
    get_mask_sparsity, save_mask
)

ModuleNotFoundError: No module named 'src'

In [None]:
# Experiment settings for the sensitivity / mask-calibration run.
config = dict(
    seed=42,
    data_dir='./data',
    output_dir='./outputs',
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='head_only',
    dropout=0.0,
    batch_size=32,
    num_batches_calibration=50,  # number of batches used for the Fisher estimate
)

# Every (ratio, rule) combination below produces one saved mask.
SPARSITY_RATIOS = [
    0.05,
    0.1,
    0.2,
]
MASK_RULES = [
    'least_sensitive',
    'most_sensitive',
    'random',
    'highest_magnitude',
    'lowest_magnitude',
]

set_seed(config['seed'])
device = get_device()

print("Device:", device)


In [None]:
# Calibration data: the CIFAR-100 training split at 224 px.
# The second value returned by load_cifar100 is not used in this notebook.
train_dataset, _ = load_cifar100(
    data_dir=config['data_dir'],
    image_size=224,
)
calibration_loader = create_dataloader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
)

# Build the model from the shared config and place it on the selected device.
model = build_model(config)
model.to(device)

In [None]:
print("Computing sensitivity scores...")
# Sensitivity scores via the 'fisher' method, estimated over a fixed
# number of calibration batches.
scores = compute_sensitivity_scores(
    model,
    calibration_loader,
    device,
    num_batches=config['num_batches_calibration'],
    method='fisher',
)

# Persist the raw scores so masks can be rebuilt without recomputing them.
checkpoint_dir = os.path.join(config['output_dir'], 'checkpoints')
scores_path = os.path.join(checkpoint_dir, 'fisher_scores.pt')
ensure_dir(checkpoint_dir)
torch.save(scores, scores_path)
print(f"Saved scores to {scores_path}")

In [None]:
# Build and save one mask per (sparsity ratio, selection rule) pair.
mask_dir = os.path.join(config['output_dir'], 'checkpoints', 'masks')
ensure_dir(mask_dir)

for ratio in SPARSITY_RATIOS:
    # Round instead of truncating: in floating point, e.g. 0.29 * 100 is
    # 28.999999999999996, so int() would label a 29% mask as "28pct".
    # The same integer is used for the file name and the log line so they
    # always agree.
    pct = round(ratio * 100)
    for rule in MASK_RULES:
        mask = create_mask(scores, model, sparsity_ratio=ratio, rule=rule, seed=config['seed'])
        # Sparsity of the mask as built, reported next to the requested ratio.
        actual_sparsity = get_mask_sparsity(mask)

        mask_path = os.path.join(mask_dir, f'mask_{rule}_{pct}pct.pt')
        save_mask(mask, mask_path)

        print(f"{rule:20s} @ {pct}% → actual: {actual_sparsity*100:.2f}% | saved: {mask_path}")

In [None]:
# Pool every score tensor into one flat array for the histogram.
# Each tensor is detached and moved to CPU before concatenation, so the
# pooled copy is not allocated on the accelerator and .numpy() cannot
# fail on a tensor that tracks gradients.
all_scores = torch.cat([s.detach().flatten().cpu() for s in scores.values()]).numpy()

# The figures directory is not created by the checkpoint/mask cells;
# without this, savefig raises FileNotFoundError on a fresh output_dir.
figures_dir = os.path.join(config['output_dir'], 'figures')
ensure_dir(figures_dir)

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(all_scores, bins=100, log=True, alpha=0.7)
ax.set_xlabel('Fisher Score')
ax.set_ylabel('Count (log)')
ax.set_title('Distribution of Fisher Diagonal Scores')
plt.tight_layout()
plt.savefig(os.path.join(figures_dir, 'fisher_scores_dist.png'), dpi=150)
plt.show()