# In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from copy import deepcopy

# ==========================================
# 1. BINARY MODEL SETUP (SST-2 & CoLA)
# ==========================================
print("Caricamento modelli nativamente binari...")

# Hub checkpoint identifiers: the two classifiers to merge (presumably
# fine-tuned on SST-2 and CoLA, going by their names) and the base checkpoint.
model_id_1, model_id_2, base_id = (
    "textattack/bert-base-uncased-SST-2",
    "textattack/bert-base-uncased-CoLA",
    "bert-base-uncased",
)

# One tokenizer, taken from the first checkpoint, is shared by the whole script.
tokenizer = AutoTokenizer.from_pretrained(model_id_1)

# model1 / model2 are loaded in that order from the two classifier checkpoints.
model1, model2 = (
    AutoModelForSequenceClassification.from_pretrained(checkpoint)
    for checkpoint in (model_id_1, model_id_2)
)

# The base checkpoint is loaded with a two-label classification head.
model_base = AutoModelForSequenceClassification.from_pretrained(base_id, num_labels=2)

def get_dataloader(dataset_name, config=None, split="validation", size=200):
    """Build a DataLoader over the first ``size`` examples of a dataset split.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        config: Optional configuration name passed to ``load_dataset``.
        split: Split to load (defaults to ``"validation"``).
        size: Maximum number of examples to keep; fewer are kept when the
            split is shorter.

    Returns:
        A ``DataLoader`` with batch size 16 over the tokenized subset, with
        the dataset formatted as torch tensors.

    The ``"sentence"`` column is tokenized with the module-level
    ``tokenizer``, truncated and padded to 128 tokens.
    """
    full_split = load_dataset(dataset_name, config, split=split)
    n_examples = min(size, len(full_split))
    head = full_split.select(range(n_examples))

    def encode_batch(batch):
        # Fixed-length padding keeps every encoded example the same size.
        return tokenizer(
            batch["sentence"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    encoded = head.map(encode_batch, batched=True).with_format("torch")
    return DataLoader(encoded, batch_size=16)

# Evaluation loaders for the two GLUE tasks, built in this order with the
# default split and subset size of get_dataloader.
dl_sst2, dl_cola = (get_dataloader("glue", task) for task in ("sst2", "cola"))

def evaluate_dual(model):
    """Measure the accuracy of ``model`` on both evaluation loaders.

    The model is switched to eval mode and moved to CUDA when available,
    otherwise to CPU. Only ``input_ids`` and ``attention_mask`` are fed to
    the model; predictions are the argmax over the logits.

    Args:
        model: Model whose forward output exposes ``.logits``.

    Returns:
        A tuple ``(accuracy on dl_sst2, accuracy on dl_cola)``.

    Raises:
        ZeroDivisionError: If a loader yields no examples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)

    @torch.no_grad()
    def accuracy_on(loader):
        # Running counts of correct predictions and of examples seen.
        hits = 0
        seen = 0
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            predicted = outputs.logits.argmax(dim=1)
            hits += (predicted == labels).sum().item()
            seen += labels.size(0)
        return hits / seen

    sst2_accuracy = accuracy_on(dl_sst2)
    cola_accuracy = accuracy_on(dl_cola)
    return sst2_accuracy, cola_accuracy

# ==========================================
# 2. SPECTRUM SNR COMPUTATION
# ==========================================
def compute_snr(p):
    """Return a spectral signal-to-noise ratio for one parameter tensor.

    The tensor is converted to a float32 NumPy array. Tensors with fewer
    than two dimensions get the placeholder score ``0.0``; tensors with more
    than two dimensions are flattened to ``(shape[0], -1)``.

    Eigenvalues of ``W.T @ W / min(rows, cols)`` are floored at 1e-9 and
    split by a threshold ``var(W) * (1 + sqrt(1 / Q))**2`` with
    ``Q = max(rows, cols) / min(rows, cols)``. The result is the sum of the
    eigenvalues above the threshold divided by the sum of those at or below
    it (plus 1e-9 to avoid division by zero).

    NOTE(review): the threshold has the form of a Marchenko-Pastur upper
    edge, but the Gram matrix is always columns x columns and is divided by
    the smaller dimension; confirm this normalisation is the intended one
    for non-square weights.

    Args:
        p: A torch tensor (typically a model parameter).

    Returns:
        ``0.0`` for tensors with fewer than two dimensions, otherwise a
        NumPy floating scalar.
    """
    weights = p.detach().cpu().float().numpy()
    if weights.ndim < 2:
        return 0.0
    if weights.ndim > 2:
        weights = weights.reshape(weights.shape[0], -1)

    n_rows, n_cols = weights.shape
    smaller = min(n_rows, n_cols)
    Q = max(n_rows, n_cols) / smaller

    # Threshold separating "signal" eigenvalues from "noise" eigenvalues.
    edge = np.var(weights) * (1 + np.sqrt(1 / Q)) ** 2

    gram = (weights.T @ weights) / smaller
    # Floor the spectrum so tiny negative round-off values cannot appear.
    spectrum = np.maximum(np.linalg.eigvalsh(gram), 1e-9)

    signal = np.sum(spectrum[spectrum > edge])
    noise = np.sum(spectrum[spectrum <= edge])
    return signal / (noise + 1e-9)

print("Analisi SNR in corso...")
# Score only matrix-shaped encoder weights. 1-D parameters (e.g. the LayerNorm
# weights inside each encoder layer) have no eigen-spectrum: compute_snr would
# give them a 0.0 placeholder, which skews the percentile threshold computed
# in perform_merge and always classifies them as low-SNR. Parameters left out
# here are simply averaged by perform_merge.
snr_m1 = {
    n: compute_snr(p)
    for n, p in model1.named_parameters()
    if 'layer' in n and 'weight' in n and p.ndim >= 2
}
snr_m2 = {
    n: compute_snr(p)
    for n, p in model2.named_parameters()
    if 'layer' in n and 'weight' in n and p.ndim >= 2
}

# ==========================================
# 3. MERGING STRATEGIES
# ==========================================
def perform_merge(strategy, p_threshold):
    """Merge ``model1`` and ``model2`` into a new model using an SNR strategy.

    A deep copy of ``model1`` receives a new state dict built key by key.
    Keys without an SNR score (not in ``snr_m1``) are always the 50/50
    average of the two models. For scored keys, a key is "high SNR" when the
    larger of its two scores is strictly above the
    ``100 - p_threshold`` percentile of those per-key maxima.

    Strategies:
        ``"standard"``: 50/50 average for every key.
        ``"spectrum_reset"``: average high-SNR keys; low-SNR keys take the
            ``model_base`` tensor (or ``model1``'s when the base lacks it).
        ``"snr_winner"``: high-SNR keys take the tensor of whichever model
            scores higher (``model2`` on ties); low-SNR keys are averaged.

    Args:
        strategy: One of the strategy names above.
        p_threshold: Percentage (0-100) of scored keys treated as high SNR.

    Returns:
        The merged model (a modified deep copy of ``model1``).

    Raises:
        ValueError: If ``strategy`` is not a known name, or if
            ``100 - p_threshold`` is outside the range accepted by
            ``np.percentile``.
    """
    known_strategies = ("standard", "spectrum_reset", "snr_winner")
    if strategy not in known_strategies:
        # Fail fast: an unknown name would otherwise leave every scored key
        # out of the new state dict and surface only as a missing-keys error
        # from load_state_dict.
        raise ValueError(
            f"Unknown merge strategy {strategy!r}; expected one of {known_strategies}"
        )

    merged = deepcopy(model1)
    sd1, sd2, sdb = model1.state_dict(), model2.state_dict(), model_base.state_dict()

    # Threshold over the per-key best score of the two models.
    max_snrs = [max(snr_m1[k], snr_m2[k]) for k in snr_m1]
    t_val = np.percentile(max_snrs, 100 - p_threshold)

    new_sd = {}
    for k, w1 in sd1.items():
        w2 = sd2[k]
        if strategy == "standard" or k not in snr_m1:
            new_sd[k] = 0.5 * w1 + 0.5 * w2
            continue

        s1, s2 = snr_m1[k], snr_m2[k]
        is_high = max(s1, s2) > t_val
        if strategy == "spectrum_reset":
            new_sd[k] = (0.5 * w1 + 0.5 * w2) if is_high else sdb.get(k, w1)
        else:  # "snr_winner"
            new_sd[k] = (w1 if s1 > s2 else w2) if is_high else (0.5 * w1 + 0.5 * w2)

    merged.load_state_dict(new_sd)
    return merged

# ==========================================
# 4. EXPERIMENT RUN
# ==========================================
# Percentages of scored layers treated as high-SNR in the two SNR strategies.
thresholds = [25, 50, 70]

# Baseline: plain averaging (the threshold does not affect this strategy).
print("Valutazione Standard...")
res_std = evaluate_dual(perform_merge("standard", 0))

# One (SST-2 accuracy, CoLA accuracy) pair per threshold for each strategy.
print("Valutazione Spectrum Reset...")
res_reset = [
    evaluate_dual(perform_merge("spectrum_reset", pct))
    for pct in thresholds
]

print("Valutazione SNR Winner...")
res_winner = [
    evaluate_dual(perform_merge("snr_winner", pct))
    for pct in thresholds
]

# ==========================================
# 5. RESULTS PLOTTING
# ==========================================
x_labels = [f"Top {t}%" for t in thresholds]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7))


def _plot_task_accuracy(ax, task_index, title):
    """Draw the baseline line and both strategy curves for one task.

    ``task_index`` selects which element of each result pair is plotted
    (0 for the first evaluated task, 1 for the second).
    """
    ax.axhline(y=res_std[task_index], color='gray', linestyle='--', label='Standard Merge (Baseline)')
    ax.plot(x_labels, [r[task_index] for r in res_reset], 'o-', linewidth=2, label='Spectrum Base Reset')
    ax.plot(x_labels, [r[task_index] for r in res_winner], 's-', linewidth=2, color='green', label='SNR Winner Strategy')
    ax.set_title(title, fontsize=14)
    ax.set_ylabel("Accuracy Score")
    ax.grid(True, alpha=0.3)
    ax.legend()


# Left panel: SST-2 (sentiment). Right panel: CoLA (linguistic acceptability).
_plot_task_accuracy(ax1, 0, "SST-2 Accuracy (Sentiment Analysis)")
_plot_task_accuracy(ax2, 1, "CoLA Accuracy (Grammar Checker)")

plt.tight_layout()
plt.show()

# SNR peaks chart: one point per scored parameter, in dictionary order.
plt.figure(figsize=(15, 5))
plt.plot([*snr_m1.values()], label='SNR SST-2', color='teal', marker='.')
plt.plot([*snr_m2.values()], label='SNR CoLA', color='orange', marker='.')
plt.title("SNR Peaks across BERT Layers")
plt.xlabel("Layer Index (Flattened)")
plt.ylabel("Spectral SNR")
plt.legend()
plt.show()
