# Spectrum-Guided Model Merging: BERT on SST-2 and Emotion

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from copy import deepcopy

# ==========================================
# 1. MODEL AND DATASET SETUP (unchanged)
# ==========================================
print("Caricamento modelli...")
model_id_1 = "textattack/bert-base-uncased-SST-2"          # SST-2 fine-tune (per its id)
model_id_2 = "bhadresh-savani/bert-base-uncased-emotion"   # Emotion fine-tune (per its id)
base_id = "bert-base-uncased"                              # shared pre-trained starting point

_load_classifier = AutoModelForSequenceClassification.from_pretrained

# One tokenizer (the SST-2 checkpoint's) is used for every model; this assumes all three
# checkpoints share the bert-base-uncased vocabulary — inferred from the ids, verify.
tokenizer = AutoTokenizer.from_pretrained(model_id_1)
model1 = _load_classifier(model_id_1)
# num_labels=2 forces a 2-way head. With ignore_mismatched_sizes=True, any checkpoint weight
# whose shape no longer matches is freshly initialized instead of raising (see the
# transformers from_pretrained docs). NOTE(review): if the emotion head has a different label
# count, model2's classifier is untrained and gets averaged into the merged head — confirm intended.
model2 = _load_classifier(model_id_2, num_labels=2, ignore_mismatched_sizes=True)
# NOTE(review): presumably the base checkpoint has no classification head, so this 2-way head
# is randomly initialized as well — verify.
model_base = _load_classifier(base_id, num_labels=2)

def get_dataloader(dataset_name, config=None, split="validation", size=200):
    """Build an unshuffled DataLoader over the first ``size`` rows of a dataset split.

    Text is read from the ``"sentence"`` column when the batch has one, otherwise from
    ``"text"``. It is tokenized with the module-level ``tokenizer``, truncated and padded to
    exactly 128 tokens, and served as torch tensors in batches of 16.
    """
    dataset = load_dataset(dataset_name, config, split=split)
    n_rows = min(size, len(dataset))
    subset = dataset.select(range(n_rows))

    def _encode(batch):
        column = "text"
        if "sentence" in batch:
            column = "sentence"
        return tokenizer(batch[column], truncation=True, padding="max_length", max_length=128)

    encoded = subset.map(_encode, batched=True)
    return DataLoader(encoded.with_format("torch"), batch_size=16)

# Evaluation loaders: the GLUE "sst2" config and the dair-ai/emotion validation split.
dl_sst2, dl_emotion = (
    get_dataloader("glue", config="sst2"),
    get_dataloader("dair-ai/emotion", split="validation"),
)

def evaluate_dual(model, *, loaders=None):
    """Return the accuracy of ``model`` on each evaluation loader.

    The model is switched to eval mode and moved in place to CUDA when available,
    otherwise to CPU. Before comparison, every label greater than 1 is collapsed onto
    class 1 so that a 2-way head can be scored. Labels 0/1 pass through unchanged.

    Args:
        model: Sequence-classification model whose forward accepts
            ``(input_ids, attention_mask=...)`` and returns an object with ``.logits``.
        loaders: Optional iterable of loaders to score. Defaults to
            ``(dl_sst2, dl_emotion)``. Each batch must be a mapping with
            ``input_ids``, ``attention_mask`` and ``label`` tensors.

    Returns:
        A tuple with one accuracy per loader (``(sst2_acc, emotion_acc)`` by default).
        A loader that yields no examples scores ``nan`` rather than raising
        ZeroDivisionError.
    """
    if loaders is None:
        loaders = (dl_sst2, dl_emotion)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)

    def run_eval(dl):
        correct, total = 0, 0
        with torch.no_grad():
            for batch in dl:
                ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                y = batch['label'].to(device)
                preds = model(ids, attention_mask=mask).logits.argmax(dim=1)
                # Same mapping as torch.where(y > 1, 1, y), without allocating a tensor per batch.
                # NOTE(review): on the presumably multi-class Emotion labels, this scores every
                # class >= 1 as the same "positive" class — confirm this is the intended protocol.
                y_mapped = y.clamp(max=1)
                correct += (preds == y_mapped).sum().item()
                total += y.size(0)
        return correct / total if total else float("nan")

    return tuple(run_eval(dl) for dl in loaders)

# ==========================================
# 2. SPECTRUM SNR COMPUTATION
# ==========================================
def compute_snr(p):
    """Return a Marchenko-Pastur signal-to-noise ratio for a weight tensor.

    The tensor is detached, copied to CPU float32 and flattened to 2-D (first dim kept).
    The eigenvalues of its Gram matrix are then split at the Marchenko-Pastur upper edge
    ``lambda_plus = sigma^2 * (1 + sqrt(1/Q))^2``. Here ``Q = max(N, M) / min(N, M)`` and
    ``sigma^2`` is the empirical variance of the weights. The result is the sum of the
    eigenvalues above the edge divided by the sum of those at or below it.
    Tensors with fewer than 2 dimensions return 0.0.
    """
    W = p.detach().cpu().float().numpy()
    if W.ndim < 2:
        return 0.0
    if W.ndim > 2:
        W = W.reshape(W.shape[0], -1)
    # The edge above assumes a Gram matrix normalised by the LARGER dimension
    # (aspect ratio min/max <= 1). Orient so rows >= cols, then use the small cols x cols
    # Gram matrix divided by the row count. The previous code divided by min(N, M), which
    # inflated every eigenvalue of a non-square matrix by Q and made pure noise look like
    # signal. For wide matrices it also built the larger, rank-deficient Gram matrix.
    if W.shape[0] < W.shape[1]:
        W = W.T
    N, M = W.shape
    Q = N / M
    sigma_2 = np.var(W)
    lambda_plus = sigma_2 * (1 + np.sqrt(1 / Q)) ** 2
    corr = (W.T @ W) / N
    eigs = np.linalg.eigvalsh(corr)
    eigs = np.maximum(eigs, 1e-9)
    signal = np.sum(eigs[eigs > lambda_plus])
    noise = np.sum(eigs[eigs <= lambda_plus])
    return signal / (noise + 1e-9)

# Per-tensor SNR for every parameter whose name contains both "layer" and "weight",
# computed separately for each fine-tuned model. NOTE(review): this filter presumably also
# matches 1-D tensors (e.g. LayerNorm scales), which compute_snr scores as 0.0.
snr_m1, snr_m2 = (
    {
        name: compute_snr(param)
        for name, param in model.named_parameters()
        if "layer" in name and "weight" in name
    }
    for model in (model1, model2)
)

# ==========================================
# 3. MERGING STRATEGIES (updated)
# ==========================================
def perform_merge(strategy, p_threshold):
    """Merge ``model1`` and ``model2`` into a new model using an SNR-guided strategy.

    A tensor tracked in ``snr_m1`` counts as "high SNR" when ``max(snr_m1[k], snr_m2[k])``
    is strictly above the ``(100 - p_threshold)``-th percentile of those maxima, i.e.
    roughly the top ``p_threshold`` percent.

    Strategies for SNR-tracked tensors:
        "standard":       plain 50/50 average (``p_threshold`` has no effect).
        "spectrum_reset": average if high SNR, otherwise use the base model's tensor
                          (model1's when the base state dict lacks the key).
        "snr_winner":     if high SNR, take the tensor of the model with the larger SNR
                          (model2 on ties); otherwise average.
    Every other state-dict entry is averaged when both models hold floating-point
    tensors of the same shape. Otherwise (integer buffers, mismatched shapes, keys
    missing from model2) it is copied from model1.

    Args:
        strategy: One of "standard", "spectrum_reset", "snr_winner".
        p_threshold: Percentage (0-100) of SNR-tracked tensors to treat as high SNR,
            e.g. 25, 50, 70.

    Returns:
        A deep copy of ``model1`` with the merged state dict loaded.

    Raises:
        ValueError: If ``strategy`` is not recognised. Previously this surfaced as a
            "Missing key(s)" error from ``load_state_dict``.
    """
    if strategy not in ("standard", "spectrum_reset", "snr_winner"):
        raise ValueError(f"Unknown merge strategy: {strategy!r}")

    merged = deepcopy(model1)
    sd1, sd2, sdb = model1.state_dict(), model2.state_dict(), model_base.state_dict()

    def average(k):
        return 0.5 * sd1[k] + 0.5 * sd2[k]

    # To keep the top p% we need the (100 - p)-th percentile as the cut-off.
    max_snrs = [max(snr_m1[k], snr_m2[k]) for k in snr_m1]
    t_val = np.percentile(max_snrs, 100 - p_threshold) if max_snrs else np.inf

    new_sd = {}
    for k, t1 in sd1.items():
        if k in snr_m1:
            s1, s2 = snr_m1[k], snr_m2[k]
            is_high_snr = max(s1, s2) > t_val
            if strategy == "standard":
                new_sd[k] = average(k)
            elif strategy == "spectrum_reset":
                # Average when the SNR is high; otherwise reset to the base weights.
                new_sd[k] = average(k) if is_high_snr else sdb.get(k, t1)
            else:  # "snr_winner"
                # High SNR: keep the model with the larger SNR. Low SNR: plain average.
                if is_high_snr:
                    new_sd[k] = t1 if s1 > s2 else sd2[k]
                else:
                    new_sd[k] = average(k)
        else:
            # Tensors without an SNR (biases, embeddings, buffers, head, ...). Averaging
            # integer buffers would silently turn them into floats, and mismatched shapes
            # cannot be averaged at all, so those are kept from model1.
            t2 = sd2.get(k)
            can_average = (
                t2 is not None
                and t1.shape == t2.shape
                and torch.is_floating_point(t1)
                and torch.is_floating_point(t2)
            )
            new_sd[k] = average(k) if can_average else t1

    merged.load_state_dict(new_sd)
    return merged

# ==========================================
# 4. RUN AND PLOT
# ==========================================
threshold_pcts = [25, 50, 70]  # requested top-SNR percentages

print("Valutazione Standard Merging...")
res_std = evaluate_dual(perform_merge("standard", 0))

print("Valutazione Spectrum Reset...")
res_reset = [evaluate_dual(perform_merge("spectrum_reset", pct)) for pct in threshold_pcts]

print("Valutazione SNR Winner...")
res_winner = [evaluate_dual(perform_merge("snr_winner", pct)) for pct in threshold_pcts]

# Visualisation: one panel per evaluation set. Result index 0 is SST-2 and index 1 is
# Emotion, matching evaluate_dual's return order.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
x_labels = [f"Top {t}%" for t in threshold_pcts]

panels = (
    (ax1, 0, "Performance su SST-2 (Sentiment)"),
    (ax2, 1, "Performance su Emotion"),
)
for ax, idx, title in panels:
    ax.axhline(y=res_std[idx], color='gray', linestyle='--', label='Standard Merge (50/50)')
    ax.plot(x_labels, [scores[idx] for scores in res_reset], 'o-', label='Spectrum Base Reset')
    ax.plot(x_labels, [scores[idx] for scores in res_winner], 's-', label='SNR Winner (Top) / Mean (Low)', color='green')
    ax.set_title(title)
    ax.set_ylabel("Accuracy")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

import matplotlib.pyplot as plt
import numpy as np

def plot_snr_peaks(snr_m1, snr_m2):
    """Draw a grouped bar chart comparing the per-tensor SNR of two models.

    Bars follow the key order of ``snr_m1``. Values from ``snr_m2`` are looked up by the
    same keys, so the dicts must share their keys (a missing key raises KeyError) but may
    be ordered differently. A star above each pair marks the larger value: gold when the
    first model (SST-2) wins, red otherwise (including ties).
    """
    layers = list(snr_m1)
    values1 = [snr_m1[name] for name in layers]
    # Align by key rather than by dict order so each pair of bars compares the same tensor.
    values2 = [snr_m2[name] for name in layers]

    # Shorten names for the axis, e.g. "bert.encoder.layer.1.attention..." -> "L1.attention...".
    short_names = [n.replace('bert.encoder.layer.', 'L').replace('.weight', '') for n in layers]

    x = np.arange(len(short_names))
    width = 0.35

    _, ax = plt.subplots(figsize=(15, 7))
    ax.bar(x - width/2, values1, width, label='SST-2 (Sentiment)', color='teal', alpha=0.8)
    ax.bar(x + width/2, values2, width, label='Emotion', color='orange', alpha=0.8)

    ax.set_ylabel('SNR Value')
    ax.set_title('Confronto dei Picchi di SNR per Layer')
    ax.set_xticks(x)
    ax.set_xticklabels(short_names, rotation=90, fontsize=8)
    ax.legend()

    # Mark the "winner" of each tensor with a star just above the taller bar.
    for i, (v1, v2) in enumerate(zip(values1, values2)):
        ax.text(i, max(v1, v2) + 0.1, '★', ha='center', va='bottom', color='gold' if v1 > v2 else 'red')

    ax.grid(axis='y', linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

# Run this only after snr_m1 and snr_m2 have been computed in the notebook.
plot_snr_peaks(snr_m1=snr_m1, snr_m2=snr_m2)