# Setup generico ML/DL
Template per esperimenti di deep learning - Tesi

In [None]:
# ============================================================
# EXECUTION FLAGS - Control what to run
# ============================================================
# Boolean switches toggling optional notebook stages
# (RUN_TRAINING_CV gates the cross-validation cell).
RUN_DATASET_ANALYSIS, RUN_TRAINING_CV, RUN_XAI = False, True, False

## 1. Mount Google Drive

In [None]:
from google.colab import drive
import os

# Mount Google Drive only if it is not mounted yet in this runtime.
if os.path.ismount('/content/drive'):
    print("✓ Drive already connected")
else:
    drive.mount('/content/drive')


## 2. Import Libraries


In [None]:

# Libreria per grafici
import matplotlib.pyplot as plt

# Framework deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Librerie per manipolazione dati
import numpy as np          # calcoli matematici su array/matrici
import pandas as pd         # manipolazione dati (tabelle, CSV)

import os
import librosa

from transformers import HubertModel, Wav2Vec2FeatureExtractor
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, confusion_matrix, roc_auc_score
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')
import time

# AMP (Automatic Mixed Precision)
from torch.cuda.amp import autocast, GradScaler
# Scheduler per learning rate adaptation
from torch.optim.lr_scheduler import ReduceLROnPlateau

print("✓ Librerie importate con successo")

## 3. Verify GPU


In [None]:

# Select the compute device: CUDA when available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"✓ Device: {device}")
if not torch.cuda.is_available():
    print("ATTENZIONE: GPU non disponibile!")
else:
    print(f"  GPU: {torch.cuda.get_device_name(0)}")


## 4. Configure Reproducibility

In [None]:
SEED = 42

import random

# Seed every RNG this notebook may touch.
# The Python stdlib RNG is included because third-party code (presumably
# audiomentations, used for augmentation) draws from `random` rather than
# NumPy/PyTorch — confirm against the audiomentations version in use.
random.seed(SEED)                 # Python stdlib
np.random.seed(SEED)              # NumPy (CPU)
torch.manual_seed(SEED)           # PyTorch CPU

# PyTorch GPU
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

print("✓ Seed impostati per riproducibilità")


## 5. Set Display Options

In [None]:
# Pandas display options
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_columns', 50)
pd.set_option('display.width', 1000)
pd.set_option('display.precision', 4)  # Decimali per float (opzionale)

# Matplotlib inline
%matplotlib inline

# Dimensione default plot più grande e leggibile
import matplotlib
matplotlib.rcParams['figure.figsize'] = (10, 6)
matplotlib.rcParams['font.size'] = 10  # Font leggibile (opzionale)

print("✓ Opzioni display configurate")


## 6. Define and Verify Project Structure


In [None]:
# Project base path (on the mounted Google Drive)
BASE_PATH = '/content/drive/MyDrive/Tesi/'

# Specific sub-paths (string concatenation relies on the trailing '/')
DATA_PATH = BASE_PATH + 'data/'
DATA_RAW_PATH = DATA_PATH + 'raw/'
DATA_PROCESSED_PATH = DATA_PATH + 'processed/'
MODELS_PATH = BASE_PATH + 'models/'
RESULTS_PATH = BASE_PATH + 'results/'

def setup_experiment(experiment_name, models_root=None, results_root=None):
    """
    Create the per-experiment folder structure and return its paths.

    Call at the start of every experiment notebook. Existing folders are
    reused (the call is idempotent).

    Args:
        experiment_name: unique, non-empty name (e.g. 'cnn_baseline', 'lstm_v2').
        models_root: parent folder for model checkpoints; defaults to MODELS_PATH.
        results_root: parent folder for results; defaults to RESULTS_PATH.

    Returns:
        dict with keys 'models', 'results', 'data_raw', 'data_processed'.

    Raises:
        ValueError: if experiment_name is empty. os.path.join(root, '') would
            return the shared root itself, mixing experiments together.
    """
    if not experiment_name:
        raise ValueError("experiment_name must be a non-empty string")

    models_root = MODELS_PATH if models_root is None else models_root
    results_root = RESULTS_PATH if results_root is None else results_root

    # Sub-folders for this experiment
    exp_models_path = os.path.join(models_root, experiment_name)
    exp_results_path = os.path.join(results_root, experiment_name)

    # Create them if missing
    os.makedirs(exp_models_path, exist_ok=True)
    os.makedirs(exp_results_path, exist_ok=True)

    paths = {
        'models': exp_models_path,
        'results': exp_results_path,
        'data_raw': DATA_RAW_PATH,
        'data_processed': DATA_PROCESSED_PATH
    }

    print(f"✓ Esperimento '{experiment_name}' configurato")
    print(f"  Models: {exp_models_path}")
    print(f"  Results: {exp_results_path}")

    return paths

# Check that the base project folders exist on Drive and report each one.
print("Verifica struttura base:")
print("=" * 60)

base_paths = [
    ('Base', BASE_PATH),
    ('Data', DATA_PATH),
    ('Data Raw', DATA_RAW_PATH),
    ('Data Processed', DATA_PROCESSED_PATH),
    ('Models', MODELS_PATH),
    ('Results', RESULTS_PATH)
]

all_ok = True
for path_name, path in base_paths:
    found = os.path.exists(path)
    print(f"✓ {path_name:<20} {path}" if found else f"✗ {path_name:<20} NOT FOUND: {path}")
    all_ok = all_ok and found

print("=" * 60)
print("✓ Struttura base completa!" if all_ok else "ATTENZIONE: crea le cartelle mancanti in Drive")


## Setup experiment

In [None]:
# Experiment identifier: selects the models/<name> and results/<name> folders.
EXPERIMENT_NAME = 'hubert_attention_pooling'
paths = setup_experiment(experiment_name=EXPERIMENT_NAME)
print(f"✓ Path configurati per '{EXPERIMENT_NAME}'")


In [None]:
# ============================================================
# INSTALLAZIONE DIPENDENZE
# ============================================================

!pip install -q transformers==4.47.1
!pip install -q librosa==0.10.2
!pip install -q audiomentations==0.35.0

print("✓ Librerie installate!")


In [None]:
# Global experiment parameters
CONFIG = dict(
    # Data
    csv_path=DATA_RAW_PATH + 'dataset_free_speech.csv',
    audio_dir=DATA_RAW_PATH + 'audio/',

    # Preprocessing
    target_sr=16000,  # HuBERT expects 16 kHz; informative only, the code uses processor.sampling_rate
    max_duration=30,  # seconds; longer clips are truncated

    # Model
    model_name='facebook/hubert-base-ls960',
    freeze_layers=9,  # freeze encoder layers 0-8, train 9-11
    pooling_type='attention',  # 'attention' or 'mean'

    # Training
    n_folds=5,  # 5-fold cross-validation
    batch_size=4,
    num_epochs=30,
    learning_rate=5e-5,  # typical fine-tuning learning rate
    weight_decay=5e-4,  # AdamW regularization

    scheduler=dict(
        factor=0.5,   # ReduceLROnPlateau halves the LR when val_loss stops improving
        patience=3,   # epochs without improvement before reducing
        min_lr=1e-6,
    ),

    # Classifier head
    hidden_dim=256,
    dropout=0.25,
    num_classes=2,

    early_stopping=dict(
        warmup_epochs=8,
        max_loss_threshold=0.75,
        use_composite_score=True,
        composite_weights=dict(
            auc=0.45,     # 45% discrimination
            balacc=0.35,  # 35% class balance
            loss=-0.2,    # -20%: penalizes validation loss
        ),
        patience=10,
    ),

    # Augmentation
    use_augmentation=True,
    aug_prob=0.165,

    # Reproducibility
    seed=42,
)

print("✓ Configurazione caricata:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

In [None]:
from transformers import Wav2Vec2FeatureExtractor

# Feature extractor of the pretrained checkpoint: pads batches in collate_fn
# and its sampling_rate drives audio loading in SpeechDataset.
processor = Wav2Vec2FeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=CONFIG['model_name']
)
print(f"✓ Processor caricato (sampling_rate: {processor.sampling_rate} Hz)")


In [None]:
# ============================================================
# DATASET LOADING
# ============================================================

# Metadata CSV (semicolon-separated)
df = pd.read_csv(CONFIG['csv_path'], sep=';')

# Binary label: 1 = 'Paziente' (patient), 0 = any other subject type (control)
df['label'] = (df['Tipo soggetto'] == 'Paziente').astype(int)

# ============================================================
# EXCLUDED FILES (chosen from a qualitative analysis)
# ============================================================
problematic_files = [
    'D_AP_F_51_2024_10_23_Italian.wav'
]

print(f"Dataset prima dell'esclusione: {len(df)} campioni")
df = df[~df['FileName'].isin(problematic_files)].reset_index(drop=True)
print(f"Dataset dopo esclusione: {len(df)} campioni ({len(problematic_files)} esclusi)")
for fname in problematic_files:
    print(f"  - Escluso: {fname}")

# ============================================================
# ONSET MAP (file-specific leading trim, in seconds)
# ============================================================
onset_stats_path = os.path.join(DATA_PATH, 'dataset_free_speech_analysis', 'audio_statistics.csv')

if not os.path.exists(onset_stats_path):
    print(f"\n⚠ WARNING: Onset map non trovato")
    onset_map = {}
else:
    stats_df = pd.read_csv(onset_stats_path)

    # Conservative margin: trim only half of the detected leading segment
    TRIM_REDUCTION_FACTOR = 0.5
    stats_df['trim_amount_adjusted'] = stats_df['trim_amount'] * TRIM_REDUCTION_FACTOR

    # filename -> adjusted trim amount (seconds)
    onset_map = dict(zip(stats_df['filename'], stats_df['trim_amount_adjusted']))

    print(f"\n✓ Onset map caricato (CONSERVATIVO): {len(onset_map)} file")
    print(f"  Trim originale: {stats_df['trim_amount'].mean():.2f}s ± {stats_df['trim_amount'].std():.2f}s")
    print(f"  Trim adjusted:  {stats_df['trim_amount_adjusted'].mean():.2f}s ± {stats_df['trim_amount_adjusted'].std():.2f}s")

# Check that every referenced audio file exists (warning only)
missing_files = [
    name for name in df['FileName']
    if not os.path.exists(os.path.join(CONFIG['audio_dir'], name))
]

if not missing_files:
    print("✓ Tutti i file audio trovati!")
else:
    print(f"ATTENZIONE: {len(missing_files)} file audio non trovati!")
    print("Primi 5:", missing_files[:5])

# Dataset summary
print(f"\n=== DATASET INFO ===")
print(f"Totale campioni: {len(df)}")
print(f"Controlli (0): {(df['label']==0).sum()}")
print(f"Pazienti (1): {(df['label']==1).sum()}")
print(f"Ratio Paziente/Controllo: {(df['label']==1).sum()/(df['label']==0).sum():.2f}")


In [None]:
class SpeechDataset(Dataset):
    """
    Dataset yielding one raw mono waveform + label per DataFrame row.

    Per item: load the audio at the processor's sampling rate, optionally
    remove a file-specific leading segment (onset_map), optionally augment,
    then truncate to max_duration seconds. Padding is done later in collate_fn.

    Args:
        dataframe: must contain 'FileName' and 'label' columns.
        audiodir: folder containing the audio files.
        processor: feature extractor; only its `sampling_rate` is used here.
        max_duration: maximum clip length in seconds (int or float).
        augment: apply audiomentations noise/gain augmentation.
        onset_map: optional {filename: seconds to trim from the start}.
        aug_prob: probability of the noise augmentation (gain uses 0.85x of
            it); when None, falls back to the global CONFIG['aug_prob'].
    """
    def __init__(self, dataframe, audiodir, processor, max_duration=30, augment=False,
                 onset_map=None, aug_prob=None):
        self.df = dataframe.reset_index(drop=True)
        self.audiodir = audiodir
        self.processor = processor
        self.target_sr = processor.sampling_rate
        self.max_duration = max_duration
        # int(): a fractional max_duration must still give a valid slice bound
        self.max_samples = int(self.target_sr * max_duration)
        self.augment = augment
        self.onset_map = onset_map if onset_map is not None else {}

        if self.augment:
            from audiomentations import Compose, AddGaussianNoise, Gain

            if aug_prob is None:
                aug_prob = CONFIG['aug_prob']

            self.augmentor = Compose([
                # 1. Additive noise (simulates the recording environment)
                AddGaussianNoise(
                    min_amplitude=0.002, max_amplitude=0.015, p=aug_prob),

                # 2. Volume/gain (simulates microphone distance)
                Gain(min_gain_in_db=-6, max_gain_in_db=6, p=aug_prob*0.85),
            ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audiopath = os.path.join(self.audiodir, row['FileName'])
        label = row['label']
        filename = row['FileName']

        # Load audio resampled to the target rate (returned rate is unused)
        waveform, _ = librosa.load(audiopath, sr=self.target_sr, mono=True)

        # ============================================================
        # STEP 1: ONSET TRIMMING (file-specific, drops the leading segment)
        # ============================================================
        if filename in self.onset_map:
            trim_amount_sec = self.onset_map[filename]
            trim_samples = int(trim_amount_sec * self.target_sr)

            # Trim only when it leaves some audio (never trim past the end)
            if 0 < trim_samples < len(waveform):
                waveform = waveform[trim_samples:]

        # ============================================================
        # STEP 2: AUGMENTATION (optional, on the already trimmed audio)
        # ============================================================
        if self.augment:
            waveform = self.augmentor(samples=waveform, sample_rate=self.target_sr)

        # ============================================================
        # STEP 3: TRUNCATION to max_duration (after trimming)
        # ============================================================
        if len(waveform) > self.max_samples:
            waveform = waveform[:self.max_samples]

        # Raw numpy waveform: the HF processor is applied in collate_fn
        return {
            'input_values': waveform,
            'labels': label,
        }

print("✓ SpeechDataset definito")


In [None]:
def collate_fn(batch):
    """
    Turn a list of SpeechDataset items into one padded batch.

    Waveforms are padded by the global HF `processor`, which also returns an
    attention_mask. Labels become a long tensor under the 'labels' key.
    """
    waveforms, targets = [], []
    for item in batch:
        waveforms.append(item['input_values'])
        targets.append(item['labels'])

    encoded = processor(
        waveforms,
        sampling_rate=processor.sampling_rate,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    )
    encoded['labels'] = torch.tensor(targets, dtype=torch.long)
    return encoded

print("✓ Collate function definita")


In [None]:
# ============================================================
# MODEL: HuBERT + ATTENTION POOLING + MLP CLASSIFIER
#
# Sources:
# - HuBERT: Hsu et al. 2021 "HuBERT: Self-Supervised Speech Representation Learning"
# - Attention Pooling: Okabe et al. 2018 "Attentive Statistics Pooling for Deep Speaker Recognition"
# ============================================================

class AttentionPooling(nn.Module):
    """
    Attention-based pooling that aggregates a sequence of embeddings.

    A linear layer scores each time step; the scores are softmax-normalized
    over time and used as weights for a weighted sum of the inputs.

    Input:  x (batch, time, features), optional mask (batch, time)
    Output: pooled (batch, features), attention_weights (batch, time, 1)
    """
    def __init__(self, input_dim):
        super().__init__()
        self.attention = nn.Linear(input_dim, 1)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, time, input_dim) embeddings.
            mask: optional (batch, time) tensor; True/non-zero marks valid
                frames. Invalid (padded) frames get exactly zero weight, so
                they do not contribute to the pooled vector. None = all valid.

        Returns:
            (pooled, attn_weights) of shapes (batch, input_dim), (batch, time, 1).
        """
        attn_scores = self.attention(x)  # (batch, time, 1)
        if mask is not None:
            valid = mask.to(torch.bool).unsqueeze(-1)  # (batch, time, 1)
            # Most negative finite value of the score dtype (also safe in fp16 under autocast)
            attn_scores = attn_scores.masked_fill(~valid, torch.finfo(attn_scores.dtype).min)
        attn_weights = F.softmax(attn_scores, dim=1)  # normalize over time
        pooled = torch.sum(attn_weights * x, dim=1)  # (batch, input_dim)
        return pooled, attn_weights


class HuBERTClassifier(nn.Module):
    """
    HuBERT encoder + Attention Pooling + MLP classifier.

    Architecture:
    - HuBERT-Base (12 layer, 768-d embeddings)
    - Encoder layers 0..freeze_layers-1: FROZEN
    - Remaining encoder layers: FINE-TUNED
    - Attention (or mean) pooling over non-padded frames
    - MLP: hidden_size -> hidden_dim -> num_classes

    NOTE(review): only `hubert.encoder.layers[:freeze_layers]` are frozen; the
    CNN feature encoder, feature projection and positional conv embedding stay
    trainable. If the intent is to keep all low-level acoustic parts fixed,
    freeze those too — confirm.
    """
    def __init__(self, model_name, freeze_layers=9, hidden_dim=256,
                 dropout=0.3, num_classes=2, pooling_type='attention'):
        super().__init__()

        # HuBERT encoder
        self.hubert = HubertModel.from_pretrained(model_name)
        self.hidden_size = self.hubert.config.hidden_size  # 768 for base

        # Freeze the first N transformer encoder layers
        for layer_idx in range(freeze_layers):
            for param in self.hubert.encoder.layers[layer_idx].parameters():
                param.requires_grad = False

        # Pooling layer
        self.pooling_type = pooling_type
        if pooling_type == 'attention':
            self.pooling = AttentionPooling(self.hidden_size)
        # else: masked mean pooling (implemented in forward)

        # Classifier MLP
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_values, attention_mask=None, return_attention=False):
        # HuBERT encoding
        outputs = self.hubert(input_values, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state  # (batch, frames, hidden_size)

        # Convert the sample-level mask into a frame-level mask so padded
        # frames are excluded from pooling (same conversion used by HF's
        # HubertForSequenceClassification).
        frame_mask = None
        if attention_mask is not None:
            frame_mask = self.hubert._get_feature_vector_attention_mask(
                embeddings.shape[1], attention_mask
            )

        # Pooling
        if self.pooling_type == 'attention':
            pooled, attn_weights = self.pooling(embeddings, mask=frame_mask)
        else:
            if frame_mask is None:
                pooled = embeddings.mean(dim=1)
            else:
                frame_weights = frame_mask.unsqueeze(-1).to(embeddings.dtype)
                pooled = (embeddings * frame_weights).sum(dim=1) / frame_weights.sum(dim=1).clamp(min=1.0)
            attn_weights = None

        logits = self.classifier(pooled)

        if return_attention and attn_weights is not None:
            return logits, attn_weights
        return logits

print("✓ Modello HuBERTClassifier definito")


In [None]:
# ============================================================
# TRAINING FUNCTIONS
# ============================================================

def train_epoch(model, dataloader, criterion, optimizer, device, scaler):
    """
    Run one training epoch with Automatic Mixed Precision (AMP).

    The forward pass and the loss run inside `autocast()`, which picks the
    precision per operation. The backward pass goes through `scaler`
    (GradScaler) so small fp16 gradients do not underflow. Metrics are
    computed on CPU from Python floats / NumPy arrays.

    Args:
        model: module returning logits of shape (batch, num_classes).
        dataloader: yields dicts with 'input_values', 'attention_mask', 'labels'.
        criterion: loss applied to (logits, labels).
        optimizer: optimizer updating the model parameters.
        device: device the batch tensors are moved to.
        scaler: torch.cuda.amp.GradScaler used for gradient scaling.

    Returns:
        (avg_loss, accuracy, balanced_accuracy); avg_loss is the mean of the
        per-batch losses.
    """
    model.train()
    batch_losses = []
    predictions, targets = [], []

    for batch in tqdm(dataloader, desc="Training", leave=False):
        inputs = batch['input_values'].to(device)
        mask = batch['attention_mask'].to(device)
        y = batch['labels'].to(device)

        optimizer.zero_grad()

        with autocast():
            logits = model(inputs, attention_mask=mask)
            loss = criterion(logits, y)

        # Scaled backward + optimizer step + scale update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_losses.append(loss.item())
        predictions.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(y.cpu().numpy())

    avg_loss = sum(batch_losses) / len(dataloader)
    return (
        avg_loss,
        accuracy_score(targets, predictions),
        balanced_accuracy_score(targets, predictions),
    )

print("✓ train_epoch definita")

def validate_epoch(model, dataloader, criterion, device):
    """
    Evaluate `model` on `dataloader` without gradient tracking.

    The forward pass runs under `autocast()` as in training, so validation uses
    the same precision. Class-1 probabilities (for ROC-AUC) come from logits
    cast to float32 before the softmax.

    Returns:
        (avg_loss, accuracy, balanced_accuracy, weighted_f1, roc_auc,
         confusion_matrix, preds, labels, probs); the last three are flat
        lists in dataloader order.
    """
    model.eval()
    batch_losses = []
    predictions, targets, positive_probs = [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation", leave=False):
            inputs = batch['input_values'].to(device)
            mask = batch['attention_mask'].to(device)
            y = batch['labels'].to(device)

            with autocast():
                logits = model(inputs, attention_mask=mask)
                loss = criterion(logits, y)

            batch_losses.append(loss.item())
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(y.cpu().numpy())
            positive_probs.extend(torch.softmax(logits.float(), dim=1)[:, 1].cpu().numpy())

    avg_loss = sum(batch_losses) / len(dataloader)
    return (
        avg_loss,
        accuracy_score(targets, predictions),
        balanced_accuracy_score(targets, predictions),
        f1_score(targets, predictions, average='weighted'),
        roc_auc_score(targets, positive_probs),
        confusion_matrix(targets, predictions),
        predictions,
        targets,
        positive_probs,
    )

print("✓ validate_epoch definita")


print("✓ Training functions definite")


## K-Fold Cross Validation

In [None]:
# ============================================================
# MAIN TRAINING LOOP - 5-FOLD CROSS-VALIDATION
# ============================================================


def _save_best_checkpoint(model, scaler, model_path, epoch, val_loss,
                          val_auc, val_bal_acc, composite_score=None):
    """Write the current best model of a fold to `model_path`.

    Stores model and GradScaler state plus the validation metrics that made
    it the best; `composite_score` is stored only when given.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'epoch': epoch,
        'val_loss': val_loss,
        'val_auc': float(val_auc),
        'val_bal_acc': float(val_bal_acc),
    }
    if composite_score is not None:
        checkpoint['composite_score'] = float(composite_score)
    torch.save(checkpoint, model_path)


def _wait_for_checkpoint(model_path, max_wait=10):
    """Poll once per second (up to `max_wait` s) until `model_path` exists and
    is larger than 1 KB; print a warning if it never does.

    Presumably guards against sync latency of the Google Drive mount where
    the models folder lives.
    """
    waited = 0
    while waited < max_wait:
        if os.path.exists(model_path):
            file_size = os.path.getsize(model_path)
            if file_size > 1024:
                print(f"File verificato: {file_size / 1e6:.1f} MB")
                break
        time.sleep(1)
        waited += 1

    if waited >= max_wait:
        print(f"  WARNING: File non trovato dopo {max_wait}s!")


# Setup
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

if RUN_TRAINING_CV:
    # Stratified K-Fold (keeps the patient/control ratio in every fold)
    skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['seed'])

    # Per-fold results storage
    fold_results = []
    all_fold_histories = []

    # Early-stopping settings are constant for the whole run: read them once
    WARMUP_EPOCHS = CONFIG['early_stopping']['warmup_epochs']
    MAX_LOSS_THRESHOLD = CONFIG['early_stopping']['max_loss_threshold']
    USE_COMPOSITE = CONFIG['early_stopping']['use_composite_score']
    ES_PATIENCE = CONFIG['early_stopping']['patience']

    print(f"\n{'='*60}")
    print(f"INIZIO TRAINING - {CONFIG['n_folds']}-FOLD CROSS-VALIDATION")
    print(f"{'='*60}\n")

    for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['label']), 1):
        print(f"\n{'='*60}")
        print(f"FOLD {fold}/{CONFIG['n_folds']}")
        print(f"{'='*60}")

        # Train/val split
        train_df = df.iloc[train_idx]
        val_df = df.iloc[val_idx]

        print(f"Train: {len(train_df)} samples (Pz:{(train_df['label']==1).sum()}, Ctrl:{(train_df['label']==0).sum()})")
        print(f"Val:   {len(val_df)} samples (Pz:{(val_df['label']==1).sum()}, Ctrl:{(val_df['label']==0).sum()})")

        # Datasets
        train_dataset = SpeechDataset(
            train_df,
            CONFIG['audio_dir'],
            processor=processor,
            max_duration=CONFIG['max_duration'],
            augment=CONFIG['use_augmentation'],
            onset_map=onset_map
        )

        val_dataset = SpeechDataset(
            val_df,
            CONFIG['audio_dir'],
            processor=processor,
            max_duration=CONFIG['max_duration'],
            augment=False,  # No augmentation in validation
            onset_map=onset_map
        )

        # DataLoaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=True,
            num_workers=0,
            collate_fn=collate_fn
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=0,
            collate_fn=collate_fn
        )

        # Model
        model = HuBERTClassifier(
            model_name=CONFIG['model_name'],
            freeze_layers=CONFIG['freeze_layers'],
            hidden_dim=CONFIG['hidden_dim'],
            dropout=CONFIG['dropout'],
            num_classes=CONFIG['num_classes'],
            pooling_type=CONFIG['pooling_type']
        ).to(device)

        # Loss and optimizer WITH CLASS WEIGHTS
        # Inverse-frequency weights to compensate class imbalance
        n_ctrl = (train_df['label'] == 0).sum()
        n_pz = (train_df['label'] == 1).sum()
        total = len(train_df)

        weight_ctrl = total / (2 * n_ctrl)
        weight_pz = total / (2 * n_pz)

        class_weights = torch.tensor(
            [weight_ctrl, weight_pz],
            dtype=torch.float32
        ).to(device)

        print(f"  Class weights: Ctrl={weight_ctrl:.3f}, Pz={weight_pz:.3f}")

        criterion = nn.CrossEntropyLoss(weight=class_weights)
        optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])

        # AMP scaler (one per fold)
        scaler = GradScaler()

        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=CONFIG['scheduler']['factor'],
            patience=CONFIG['scheduler']['patience'],
            min_lr=CONFIG['scheduler']['min_lr']
        )

        # Where this fold's best checkpoint is written and read back
        model_path = os.path.join(paths['models'], f'best_model_fold{fold}.pth')
        # True only once THIS run has saved a checkpoint for this fold, so a
        # stale file left on Drive by a previous run is never reloaded.
        best_model_saved = False

        # Training loop state
        best_val_loss = float('inf')
        best_val_auc = 0.0
        best_val_bal_acc = 0.0
        best_composite_score = -float('inf')
        patience_counter = 0

        history = {
            'train_loss': [],
            'train_bal_acc': [],
            'val_loss': [],
            'val_acc': [],
            'val_bal_acc': [],
            'val_auc': []
        }

        for epoch in range(1, CONFIG['num_epochs'] + 1):
            print(f"\nEpoch {epoch}/{CONFIG['num_epochs']}")

            # Train (uses the AMP scaler)
            train_loss, train_acc, train_bal_acc = train_epoch(
                model, train_loader, criterion, optimizer, device, scaler
            )

            # Validate (autocast only, no scaler needed)
            val_loss, val_acc, val_bal_acc, val_f1, val_auc, val_cm, _, _, _ = validate_epoch(
                model, val_loader, criterion, device
            )

            # Log
            history['train_loss'].append(train_loss)
            history['train_bal_acc'].append(train_bal_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['val_bal_acc'].append(val_bal_acc)
            history['val_auc'].append(val_auc)

            print(f" Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | Bal_Acc: {train_bal_acc:.4f}")
            print(f" Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | Bal_Acc: {val_bal_acc:.4f} | F1: {val_f1:.4f} | AUC: {val_auc:.4f}")
            print(f" LR: {optimizer.param_groups[0]['lr']:.2e}")

            # Confusion matrix components
            tn, fp, fn, tp = val_cm.ravel()
            print(f"   CM: [[TN={tn}, FP={fp}], [FN={fn}, TP={tp}]]")

            # Scheduler step: reduces the LR when val_loss does not improve
            scheduler.step(val_loss)
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Current LR: {current_lr:.2e}")

            # ===== EARLY STOPPING WITH COMPOSITE SCORE =====
            if USE_COMPOSITE:
                weights = CONFIG['early_stopping']['composite_weights']
                composite_score = (
                    weights['auc'] * val_auc +
                    weights['balacc'] * val_bal_acc +
                    weights['loss'] * val_loss  # negative weight in the config (-0.2)
                )
                print(f"Composite Score: {composite_score:.4f} (AUC:{val_auc:.4f} BalAcc:{val_bal_acc:.4f} Loss:{val_loss:.4f})")
            else:
                composite_score = val_auc  # Fallback

            if epoch <= WARMUP_EPOCHS:
                # PHASE 1 (warmup): keep the best-AUC model among epochs
                # whose val_loss is under the threshold.
                if val_loss < MAX_LOSS_THRESHOLD and val_auc > best_val_auc:
                    best_val_auc = val_auc
                    best_val_loss = val_loss
                    best_val_bal_acc = val_bal_acc
                    best_composite_score = composite_score
                    patience_counter = 0

                    _save_best_checkpoint(model, scaler, model_path, epoch,
                                          val_loss, val_auc, val_bal_acc)
                    best_model_saved = True
                    print(f" [WARMUP] Best model saved! AUC: {val_auc:.4f}, BalAcc: {val_bal_acc:.4f}, Loss: {val_loss:.4f}")
                    _wait_for_checkpoint(model_path)
                else:
                    patience_counter += 1
                    if val_loss >= MAX_LOSS_THRESHOLD:
                        print(f"  [WARMUP] Loss {val_loss:.4f} > {MAX_LOSS_THRESHOLD}")

            elif USE_COMPOSITE:
                # PHASE 2 (post-warmup): composite score balancing AUC, BalAcc and loss
                if val_loss < MAX_LOSS_THRESHOLD:
                    if composite_score > best_composite_score:
                        best_composite_score = composite_score
                        best_val_auc = val_auc
                        best_val_loss = val_loss
                        best_val_bal_acc = val_bal_acc
                        patience_counter = 0

                        _save_best_checkpoint(model, scaler, model_path, epoch,
                                              val_loss, val_auc, val_bal_acc,
                                              composite_score=composite_score)
                        best_model_saved = True
                        print(f" [COMPOSITE] Best model saved! Score: {composite_score:.4f}")
                        print(f"   AUC: {val_auc:.4f}, BalAcc: {val_bal_acc:.4f}, Loss: {val_loss:.4f}")
                        _wait_for_checkpoint(model_path)
                    else:
                        patience_counter += 1
                        print(f"  Composite score {composite_score:.4f} ≤ best {best_composite_score:.4f}")
                else:
                    # Validation loss above the threshold
                    patience_counter += 1
                    print(f"  Loss {val_loss:.4f} > {MAX_LOSS_THRESHOLD} (explosion)")

            else:
                # PHASE 2 with composite disabled: same AUC-based rule as
                # warmup, including saving the model and advancing patience.
                if val_loss < MAX_LOSS_THRESHOLD and val_auc > best_val_auc:
                    best_val_auc = val_auc
                    best_val_loss = val_loss
                    best_val_bal_acc = val_bal_acc
                    best_composite_score = composite_score
                    patience_counter = 0

                    _save_best_checkpoint(model, scaler, model_path, epoch,
                                          val_loss, val_auc, val_bal_acc)
                    best_model_saved = True
                    print(f" [AUC] Best model saved! AUC: {val_auc:.4f}, BalAcc: {val_bal_acc:.4f}, Loss: {val_loss:.4f}")
                    _wait_for_checkpoint(model_path)
                else:
                    patience_counter += 1

            # Final patience check
            if patience_counter >= ES_PATIENCE:
                print(f"⏹  Early stopping at epoch {epoch} (patience exhausted)")
                break

        # Restore this run's best model (weights only; the scaler is not needed for evaluation)
        if best_model_saved:
            print(f"  ⏳ Caricamento best model...")
            max_wait = 30
            waited = 0

            while not os.path.exists(model_path) and waited < max_wait:
                time.sleep(1)
                waited += 1

            if os.path.exists(model_path):
                try:
                    checkpoint = torch.load(model_path, map_location=device)
                    model.load_state_dict(checkpoint['model_state_dict'])
                    print(f"  ✓ Best model restored (epoch {checkpoint['epoch']}, val_loss: {checkpoint['val_loss']:.4f})")
                except Exception as e:
                    # Best-effort: fall back to the in-memory model
                    print(f"  ❌ Errore nel caricare il modello: {e}")
                    print(f"  → Uso il modello corrente")
            else:
                print(f"  ⚠️  Best model non trovato dopo {max_wait}s!")
                print(f"  → Uso il modello corrente")
        else:
            print("  ⚠️  Nessun best model salvato in questo fold (criteri di salvataggio mai soddisfatti)")
            print(f"  → Uso il modello corrente")

        final_val_loss, final_val_acc, final_val_bal_acc, final_val_f1, final_val_auc, final_cm, preds, labels, probs = validate_epoch(
            model, val_loader, criterion, device
        )

        # Confusion matrix components
        tn, fp, fn, tp = final_cm.ravel()

        print(f"\n{'='*60}")
        print(f"FOLD {fold} FINAL RESULTS")
        print(f"{'='*60}")
        print(f"Val Accuracy: {final_val_acc:.4f}")
        print(f"Val Balanced Accuracy: {final_val_bal_acc:.4f}")
        print(f"Val F1-Score: {final_val_f1:.4f}")
        print(f"Val AUC-ROC: {final_val_auc:.4f}")
        print(f"\nConfusion Matrix:")
        print(final_cm)
        print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")

        # Store results (probs/labels kept for the ROC plot)
        fold_results.append({
            'fold': fold,
            'val_acc': final_val_acc,
            'val_bal_acc': final_val_bal_acc,
            'val_f1': final_val_f1,
            'val_auc': final_val_auc,
            'tn': int(tn),
            'fp': int(fp),
            'fn': int(fn),
            'tp': int(tp),
            'probs': probs,   # Probabilities for the ROC curve
            'labels': labels  # Ground-truth labels for the ROC curve
        })
        all_fold_histories.append(history)

    print(f"\n\n{'='*60}")
    print("CROSS-VALIDATION COMPLETA")
    print(f"{'='*60}\n")
else:
    print("-> Skipping Cross-Validation (RUN_TRAINING_CV = False)")
    print("   Set RUN_TRAINING_CV = True to run 5-fold CV")


In [None]:
# ============================================================
# RISULTATI FINALI CROSS-VALIDATION
# ============================================================
# Summarizes the per-fold results collected by the CV cell (fold_results,
# all_fold_histories): aggregate metrics, a CSV of per-fold metrics, learning
# curves, and a final figure with the mean ROC curve and the aggregated
# confusion matrix.

if RUN_TRAINING_CV:
    # Aggregate metrics across folds (np.std is the population std, ddof=0)
    val_accs = [r['val_acc'] for r in fold_results]
    val_bal_accs = [r['val_bal_acc'] for r in fold_results]
    val_f1s = [r['val_f1'] for r in fold_results]
    val_aucs = [r['val_auc'] for r in fold_results]

    print(f"=== RISULTATI {CONFIG['n_folds']}-FOLD CROSS-VALIDATION ===\n")
    print(f"Accuracy:          {np.mean(val_accs):.4f} ± {np.std(val_accs):.4f}")
    print(f"Balanced Accuracy: {np.mean(val_bal_accs):.4f} ± {np.std(val_bal_accs):.4f}")
    print(f"F1-Score:          {np.mean(val_f1s):.4f} ± {np.std(val_f1s):.4f}")
    print(f"AUC-ROC:           {np.mean(val_aucs):.4f} ± {np.std(val_aucs):.4f}")

    print("\n=== PER-FOLD BREAKDOWN ===")
    for r in fold_results:
        print(f"Fold {r['fold']}: Acc={r['val_acc']:.4f}, Bal_Acc={r['val_bal_acc']:.4f}, F1={r['val_f1']:.4f}, AUC={r['val_auc']:.4f}")

    # Save per-fold scalar metrics only. 'probs'/'labels' hold per-sample
    # predictions kept for the ROC plot below; inside a CSV cell they would be
    # written as one stringified sequence (numpy elides long arrays with '...'),
    # which is neither readable nor reliably parseable, so they are excluded.
    results_df = pd.DataFrame(fold_results).drop(columns=['probs', 'labels'], errors='ignore')
    results_df.to_csv(os.path.join(paths['results'], 'cv_results.csv'), index=False)

    print(f"\n✓ Risultati salvati in: {paths['results']}/cv_results.csv")

    # ============================================================
    # PLOT LEARNING CURVES CON TRAIN/VAL GAP
    # ============================================================

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # One consistent colour per fold across all panels
    colors = plt.cm.tab10(np.linspace(0, 1, len(all_fold_histories)))

    # (axis, history key, title, y label, line style, y limits or None)
    curve_panels = [
        (axes[0, 0], 'train_loss', 'Training Loss per Fold', 'Loss', '-', None),
        (axes[0, 1], 'val_loss', 'Validation Loss per Fold', 'Loss', '--', None),
        (axes[1, 0], 'val_bal_acc', 'Validation Balanced Accuracy per Fold',
         'Balanced Accuracy', '-', [0.4, 1.0]),
        (axes[1, 1], 'val_auc', 'Validation AUC-ROC per Fold', 'AUC-ROC', '-', [0.4, 1.0]),
    ]

    for ax, key, title, ylabel, linestyle, ylim in curve_panels:
        for fold_idx, history in enumerate(all_fold_histories, 1):
            ax.plot(history[key], alpha=0.6, color=colors[fold_idx - 1],
                    linestyle=linestyle, label=f'Fold {fold_idx}')
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)
        if ylim is not None:
            ax.set_ylim(ylim)

    plt.tight_layout()

    # Save in several formats (PNG + vector PDF for the thesis)
    plt.savefig(os.path.join(paths['results'], 'learning_curves.png'), dpi=300, bbox_inches='tight')
    plt.savefig(os.path.join(paths['results'], 'learning_curves.pdf'), bbox_inches='tight')
    plt.show()

    print("✓ Plot salvati:")
    print(f"  - {paths['results']}/learning_curves.png (300 DPI)")
    print(f"  - {paths['results']}/learning_curves.pdf (vettoriale)")

    # ============================================================
    # PLOT FINALE: ROC MEDIA + CONFUSION MATRIX AGGREGATA
    # ============================================================

    from sklearn.metrics import roc_curve, auc as sklearn_auc

    print("\n" + "="*60)
    print("GENERAZIONE PLOT FINALE: ROC MEDIA + CM AGGREGATA")
    print("="*60)

    # ------------------------------
    # 1. MEAN ROC CURVE
    # ------------------------------

    # Common FPR grid so per-fold curves can be averaged point-wise
    mean_fpr = np.linspace(0, 1, 100)

    tprs = []  # per-fold TPR interpolated onto mean_fpr
    aucs = []  # per-fold AUC

    fig_final, axes = plt.subplots(1, 2, figsize=(16, 6))

    for result in fold_results:
        fpr, tpr, _ = roc_curve(result['labels'], result['probs'])
        aucs.append(sklearn_auc(fpr, tpr))

        tpr_interp = np.interp(mean_fpr, fpr, tpr)
        tpr_interp[0] = 0.0  # force the curve through (0, 0)
        tprs.append(tpr_interp)

        # Individual fold curve, drawn faintly behind the mean
        axes[0].plot(fpr, tpr, alpha=0.3, linewidth=1)

    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0  # force the curve through (1, 1)
    mean_auc = np.mean(aucs)
    std_auc = np.std(aucs)

    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

    axes[0].plot(mean_fpr, mean_tpr, color='blue', linewidth=3,
                 label=f'Mean ROC (AUC = {mean_auc:.3f} ± {std_auc:.3f})')

    # ±1 std band around the mean curve
    axes[0].fill_between(mean_fpr, tprs_lower, tprs_upper,
                         color='grey', alpha=0.3, label='±1 std')

    # Diagonal = random classifier
    axes[0].plot([0, 1], [0, 1], 'k--', linewidth=1.5, label='Chance (AUC = 0.5)')

    axes[0].set_xlim([0.0, 1.0])
    axes[0].set_ylim([0.0, 1.05])
    axes[0].set_xlabel('False Positive Rate', fontsize=12, fontweight='bold')
    axes[0].set_ylabel('True Positive Rate', fontsize=12, fontweight='bold')
    axes[0].set_title(f'Mean ROC Curve ({CONFIG["n_folds"]}-Fold CV)',
                      fontsize=13, fontweight='bold')
    axes[0].legend(loc='lower right', fontsize=9)
    axes[0].grid(True, alpha=0.3)
    axes[0].set_aspect('equal')

    # ------------------------------
    # 2. AGGREGATED CONFUSION MATRIX
    # ------------------------------

    # Sum the per-fold counts (rows = true label, columns = predicted label)
    cm_aggregated = np.array([
        [sum(r['tn'] for r in fold_results), sum(r['fp'] for r in fold_results)],
        [sum(r['fn'] for r in fold_results), sum(r['tp'] for r in fold_results)]
    ])

    total_predictions = cm_aggregated.sum()

    im = axes[1].imshow(cm_aggregated, interpolation='nearest', cmap='Blues')
    axes[1].figure.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

    # Annotate each cell with its count and share of all predictions; use
    # white text on the darker half of the colour scale for contrast.
    thresh = cm_aggregated.max() / 2.
    for i in range(2):
        for j in range(2):
            count = cm_aggregated[i, j]
            percentage = count / total_predictions * 100
            axes[1].text(j, i, f'{count}\n({percentage:.1f}%)',
                         ha="center", va="center", fontsize=14, fontweight='bold',
                         color="white" if count > thresh else "black")

    axes[1].set_xticks([0, 1])
    axes[1].set_yticks([0, 1])
    axes[1].set_xticklabels(['Control', 'Patient'], fontsize=11)
    axes[1].set_yticklabels(['Control', 'Patient'], fontsize=11)
    axes[1].set_ylabel('True Label', fontsize=12, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    axes[1].set_title(f'Aggregated Confusion Matrix ({CONFIG["n_folds"]}-Fold CV)\n'
                      f'N = {total_predictions} total predictions',
                      fontsize=13, fontweight='bold')

    plt.tight_layout()

    plt.savefig(os.path.join(paths['results'], 'roc_cm_final.png'), dpi=300, bbox_inches='tight')
    plt.savefig(os.path.join(paths['results'], 'roc_cm_final.pdf'), bbox_inches='tight')
    plt.show()

    print("✓ Plot finale salvati:")
    print(f"  - {paths['results']}/roc_cm_final.png (300 DPI)")
    print(f"  - {paths['results']}/roc_cm_final.pdf (vettoriale)")
    print(f"\n  ROC Media: AUC = {mean_auc:.3f} ± {std_auc:.3f}")
    print(f"    CM Aggregata:\n{cm_aggregated}")
    print(f"   TN={cm_aggregated[0,0]}, FP={cm_aggregated[0,1]}, "
          f"FN={cm_aggregated[1,0]}, TP={cm_aggregated[1,1]}")
else:
    print("⏭  Skipping CV results (not executed)")


## Analisi statistica dataset

In [None]:
# ============================================================
# AUDIO DATASET ANALYSIS - Thesis Version
# ============================================================
# Comprehensive analysis of speech audio dataset:
# - Duration statistics (total and effective after onset trimming)
# - Speech onset detection using energy-based method
# - Audio quality assessment (silence ratio, RMS energy, clipping)
# - Problematic file identification
# ============================================================

import os
import librosa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


# ============================================================
# SILENCE RATIO COMPUTATION
# ============================================================

def compute_silence_ratio(rms, method='adaptive'):
    """
    Return the fraction of low-energy ("silent") frames and the threshold used.

    Parameters
    ----------
    rms : np.ndarray
        Per-frame RMS energy values.
    method : str, default='adaptive'
        How the silence threshold is chosen:
        - 'adaptive': 10% of the peak RMS value (falls back to 0.01 when the
          peak is not positive). Because it scales with the peak, it follows
          differences in recording level between files.
        - 'absolute': fixed threshold of 0.01 (may need manual tuning).
        - 'percentile': 5th percentile of the RMS values.

    Returns
    -------
    silence_ratio : float
        Fraction of frames whose RMS is strictly below the threshold, in [0, 1].
    silence_threshold : float
        The threshold that was applied.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above. This is only checked
        when ``rms`` is non-empty and not entirely NaN.

    Notes
    -----
    Empty or all-NaN input short-circuits to ``(0.0, 0.0)``.
    Example: a peak RMS of 0.8 gives a threshold of 0.08; frames below 0.08
    count as silence.
    """
    frame_count = len(rms)
    if frame_count == 0 or np.isnan(rms).all():
        return 0.0, 0.0

    if method == 'absolute':
        threshold = 0.01
    elif method == 'percentile':
        threshold = np.percentile(rms, 5)
    elif method == 'adaptive':
        peak = np.max(rms)
        threshold = peak * 0.1 if peak > 0 else 0.01
    else:
        raise ValueError(f"Unknown method: {method}")

    quiet_frames = (rms < threshold).sum()
    return quiet_frames / frame_count, threshold


# ============================================================
# SPEECH ONSET DETECTION
# ============================================================

def detect_speech_onset(y, sr, energy_threshold_percentile=20,
                        min_speech_duration=0.3, pre_speech_margin=0.5):
    """
    Detect the speech onset time using a per-file adaptive energy threshold.

    A frame is a speech candidate when its RMS energy is strictly above the
    ``energy_threshold_percentile``-th percentile of the file's own RMS values.
    The onset is the start of the first uninterrupted run of candidate frames
    lasting at least ``min_speech_duration`` seconds.

    Parameters
    ----------
    y : np.ndarray
        Audio waveform (time-domain samples).
    sr : int
        Sample rate in Hz.
    energy_threshold_percentile : float, default=20
        Percentile of the RMS distribution used as the energy threshold.
        Lower values are more permissive (more frames count as speech);
        higher values are more conservative.
    min_speech_duration : float, default=0.3
        Minimum duration (seconds) of the continuous high-energy run required
        to declare an onset, so that isolated noise spikes do not trigger it.
        Internally at least one frame is always required.
    pre_speech_margin : float, default=0.5
        Seconds to keep before the detected onset when computing ``trim_start``.

    Returns
    -------
    dict
        onset_time : float or None
            Onset time in seconds; None if detection fails.
        trim_start : float or None
            ``max(0, onset_time - pre_speech_margin)``; None if detection fails.
        margin_used : float
            ``onset_time - trim_start`` (smaller than requested when the onset
            is near the file start); 0.0 if detection fails.
        success : bool
            True if an onset was found.

    Notes
    -----
    RMS is computed with 25 ms frames and a 10 ms hop. Detection fails when:
    - the RMS sequence is empty, all-NaN or all-zero;
    - the audio has no more frames than the required minimum run;
    - no sufficiently long run of high-energy frames exists.
    """
    def _failure():
        # Fresh dict per call so callers can never share/mutate a common object
        return {
            'onset_time': None,
            'trim_start': None,
            'margin_used': 0.0,
            'success': False
        }

    frame_length = int(0.025 * sr)  # 25 ms frame
    hop_length = int(0.010 * sr)    # 10 ms hop

    rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]

    if len(rms) == 0 or np.all(np.isnan(rms)) or np.all(rms == 0):
        return _failure()

    energy_threshold = np.percentile(rms, energy_threshold_percentile)
    is_speech = rms > energy_threshold

    # Require at least one frame: a zero-length window is vacuously "all speech"
    # and would report an onset at t=0 regardless of the signal.
    min_frames = max(1, int(min_speech_duration * sr / hop_length))

    if min_frames >= len(is_speech):
        return _failure()

    # speech_in_window[i] = number of speech frames in is_speech[i:i + min_frames],
    # computed in O(n) with prefix sums. A window that is completely full marks
    # an uninterrupted run of min_frames speech frames starting at frame i.
    prefix = np.concatenate(([0], np.cumsum(is_speech)))
    speech_in_window = prefix[min_frames:] - prefix[:-min_frames]
    run_starts = np.flatnonzero(speech_in_window == min_frames)

    if run_starts.size == 0:
        return _failure()

    onset_time = librosa.frames_to_time(int(run_starts[0]), sr=sr, hop_length=hop_length)
    trim_start = max(0, onset_time - pre_speech_margin)
    margin_used = onset_time - trim_start

    return {
        'onset_time': onset_time,
        'trim_start': trim_start,
        'margin_used': margin_used,
        'success': True
    }


# ============================================================
# DATASET ANALYSIS
# ============================================================

def _describe_distribution(series):
    """Return mean/std/median/min/max/25th/75th percentile of a numeric Series."""
    return {
        'mean': series.mean(),
        'std': series.std(),
        'median': series.median(),
        'min': series.min(),
        'max': series.max(),
        'q25': series.quantile(0.25),
        'q75': series.quantile(0.75)
    }


def _analyze_single_audio(audio_path, file_name):
    """Load one audio file and compute its duration, onset, energy, ZCR and clipping metrics.

    Prints warnings for heavy clipping (>1% of samples), failed onset detection,
    less than 5 s left after trimming, and invalid RMS values. Raises on load
    failure or on an empty file.
    """
    y, sr = librosa.load(audio_path, sr=None, mono=True)

    if len(y) == 0:
        raise ValueError("Empty audio file")

    duration_total = len(y) / sr

    # Clipping: samples whose absolute amplitude reaches 0.99 (near full scale ±1.0)
    clipping_threshold = 0.99
    clipped_samples = np.sum(np.abs(y) >= clipping_threshold)
    clipping_ratio = clipped_samples / len(y)

    if clipping_ratio > 0.01:
        print(f"\n  Warning: {file_name}: {clipping_ratio:.1%} clipping detected")

    onset = detect_speech_onset(y, sr)

    if onset['success']:
        onset_time = onset['onset_time']
        trim_start = onset['trim_start']
    else:
        print(f"\n  Warning: {file_name}: Onset detection failed")
        onset_time = None
        trim_start = 0.0  # nothing trimmed when no onset was found
    duration_after_trim = duration_total - trim_start

    if duration_after_trim < 5.0:
        print(f"\n  Warning: {file_name}: Only {duration_after_trim:.1f}s after trim")

    # Quality metrics use librosa's default RMS framing, not the 25 ms / 10 ms
    # framing used inside the onset detector.
    rms = librosa.feature.rms(y=y)[0]

    if len(rms) == 0 or np.all(np.isnan(rms)):
        print(f"\n  Warning: {file_name}: Invalid RMS values")
        rms = np.array([0.0])

    silence_ratio, silence_threshold = compute_silence_ratio(rms, method='adaptive')
    zcr_mean = np.mean(librosa.feature.zero_crossing_rate(y)[0])

    return {
        'sample_rate': sr,
        'duration_total': duration_total,
        'onset_time': onset_time,
        'onset_detected': onset['success'],
        'trim_amount': trim_start,
        'duration_after_trim': duration_after_trim,
        'rms_mean': np.mean(rms),
        'rms_std': np.std(rms),
        'rms_max': np.max(rms),
        'silence_ratio': silence_ratio,
        'silence_threshold': silence_threshold,
        'zcr_mean': zcr_mean,
        'clipping_ratio': clipping_ratio,
        'clipped_samples': clipped_samples,
        'samples_total': len(y)
    }


def _group_summary(valid, onset_valid, label):
    """Summary statistics restricted to files with the given label (0=Control, 1=Patient)."""
    group = valid[valid['label'] == label]
    group_onsets = onset_valid.loc[onset_valid['label'] == label, 'onset_time']
    return {
        'n': (valid['label'] == label).sum(),
        'duration_mean': group['duration_total'].mean(),
        'duration_std': group['duration_total'].std(),
        'duration_after_trim_mean': group['duration_after_trim'].mean(),
        'onset_mean': group_onsets.mean() if len(group_onsets) > 0 else 0.0,
        'silence_mean': group['silence_ratio'].mean(),
        'clipping_mean': group['clipping_ratio'].mean()
    }


def analyze_audio_dataset_simple(df, audio_dir):
    """
    Perform comprehensive analysis of an audio dataset.

    Computes per-file statistics and aggregate summary metrics for:
    - Duration (total and effective after onset trimming)
    - Speech onset timing
    - Audio quality (silence ratio, RMS energy, clipping)
    - Group-level comparisons (controls vs patients)

    Files that fail to load are reported and kept in ``stats_df`` with an
    ``error`` column; they are excluded from ``summary``.

    Parameters
    ----------
    df : pd.DataFrame
        Dataset metadata with required columns:
        - FileName: audio file name
        - label: binary label (0=Control, 1=Patient)
        - Tipo soggetto: descriptive label (Controllo/Paziente)
    audio_dir : str
        Directory path containing the audio files.

    Returns
    -------
    stats_df : pd.DataFrame
        Per-file statistics with columns:
        - filename, label, tipo_soggetto
        - sample_rate, duration_total, duration_after_trim
        - onset_time, onset_detected, trim_amount
        - rms_mean, rms_std, rms_max
        - silence_ratio, silence_threshold
        - zcr_mean (zero crossing rate)
        - clipping_ratio, clipped_samples, samples_total
        - error (only for files that could not be analysed)
    summary : dict
        Aggregate statistics organized by category:
        - n_samples, n_controls, n_patients, n_errors
        - sample_rate (min, max, unique)
        - duration_total (mean, std, median, min, max, q25, q75)
        - onset (mean, std, median, max, detection_success_rate)
        - duration_after_trim (mean, std, median, min, max, q25, q75)
        - quality (silence_ratio, rms, clipping metrics)
        - by_group (separate statistics for controls and patients)

    Raises
    ------
    ValueError
        If no file could be analysed successfully (including an empty ``df``),
        since no summary can be computed.
    """
    audio_stats = []

    print("="*80)
    print("AUDIO DATASET ANALYSIS")
    print("="*80)
    print(f"Analyzing {len(df)} audio files...")
    print(f"Onset detection: Energy-based adaptive threshold method")
    print(f"Silence detection: 10% of peak energy method\n")

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
        file_name = row['FileName']
        record = {
            'filename': file_name,
            'label': row['label'],
            'tipo_soggetto': row['Tipo soggetto'],
        }
        try:
            record.update(_analyze_single_audio(os.path.join(audio_dir, file_name), file_name))
        except Exception as e:
            # Best-effort: report the failure, keep it in the table, continue
            print(f"\n  Error processing {file_name}: {e}")
            record['error'] = str(e)
        audio_stats.append(record)

    stats_df = pd.DataFrame(audio_stats)

    # 'duration_total' only exists if at least one file was analysed.
    if 'duration_total' not in stats_df.columns:
        raise ValueError(
            f"No audio file could be analysed ({len(stats_df)} attempted); "
            "see the errors printed above"
        )

    valid = stats_df.dropna(subset=['duration_total'])
    # With error rows present, 'onset_detected' is object dtype (bools mixed
    # with NaN); every valid row holds a real bool, so cast before using it
    # as a mask and as a success rate.
    onset_flags = valid['onset_detected'].astype(bool)
    onset_valid = valid[onset_flags]
    onset_times = onset_valid['onset_time']
    has_onsets = len(onset_valid) > 0

    summary = {
        'n_samples': len(valid),
        'n_controls': (valid['label'] == 0).sum(),
        'n_patients': (valid['label'] == 1).sum(),
        'n_errors': len(stats_df) - len(valid),

        'sample_rate': {
            'min': int(valid['sample_rate'].min()),
            'max': int(valid['sample_rate'].max()),
            'unique': valid['sample_rate'].nunique()
        },

        'duration_total': _describe_distribution(valid['duration_total']),

        'onset': {
            'mean': onset_times.mean() if has_onsets else 0.0,
            'std': onset_times.std() if has_onsets else 0.0,
            'median': onset_times.median() if has_onsets else 0.0,
            'max': onset_times.max() if has_onsets else 0.0,
            'detection_success_rate': onset_flags.mean()
        },

        'duration_after_trim': _describe_distribution(valid['duration_after_trim']),

        'quality': {
            'silence_ratio_mean': valid['silence_ratio'].mean(),
            'silence_ratio_std': valid['silence_ratio'].std(),
            'rms_mean': valid['rms_mean'].mean(),
            'rms_std': valid['rms_mean'].std(),
            'clipping_mean': valid['clipping_ratio'].mean(),
            'clipping_max': valid['clipping_ratio'].max(),
            'n_clipped': (valid['clipping_ratio'] > 0.01).sum()
        },

        'by_group': {
            'controls': _group_summary(valid, onset_valid, 0),
            'patients': _group_summary(valid, onset_valid, 1)
        }
    }

    return stats_df, summary


# ============================================================
# PROBLEMATIC FILE IDENTIFICATION
# ============================================================

def identify_problematic_audio(stats_df):
    """
    Identify audio files with quality issues or insufficient content.

    Each file is scored against several quality criteria; the cumulative score
    determines its severity (CRITICAL or WARNING). Files with no issue are not
    reported.

    Parameters
    ----------
    stats_df : pd.DataFrame
        Per-file statistics from analyze_audio_dataset_simple().

    Returns
    -------
    pd.DataFrame
        Problematic files sorted by decreasing severity_score (ties keep the
        input order), with columns:
        - filename, tipo_soggetto
        - severity: 'CRITICAL' or 'WARNING'
        - severity_score: numerical score (higher = more severe)
        - issues: ' | '-joined description of all detected issues
        - duration_total, duration_after_trim
        - silence_ratio, rms_mean, clipping_ratio
        When no file is problematic, an empty DataFrame with these same
        columns is returned, so callers can filter on them unconditionally.

    Quality Criteria
    ----------------
    - File could not be analysed (missing duration_total): score 10
    - Silence ratio > 60% (score +1) or > 80% (score +2)
    - RMS energy < 0.01 (score +2) - very low volume
    - Total duration < 60s (score +1)
    - Duration after trim < 30s (score +2) - insufficient content
    - Clipping ratio > 1% (score +1) or > 5% (score +2)
    - Onset detection failed (score +1)

    Severity Classification
    -----------------------
    - CRITICAL: severity_score >= 3 (multiple issues or severe single issue)
    - WARNING: severity_score < 3 (minor or single issue)
    """
    report_columns = [
        'filename', 'tipo_soggetto', 'severity', 'severity_score', 'issues',
        'duration_total', 'duration_after_trim', 'silence_ratio', 'rms_mean',
        'clipping_ratio',
    ]
    problematic = []

    for _, row in stats_df.iterrows():
        if pd.isna(row.get('duration_total')):
            # Analysis failed for this file: nothing else can be checked
            problematic.append({
                'filename': row['filename'],
                'tipo_soggetto': row['tipo_soggetto'],
                'severity': 'CRITICAL',
                'severity_score': 10,
                'issues': 'File loading failed',
                'duration_total': np.nan,
                'duration_after_trim': np.nan,
                'silence_ratio': np.nan,
                'rms_mean': np.nan,
                'clipping_ratio': np.nan
            })
            continue

        issues = []
        severity_score = 0

        if row['silence_ratio'] > 0.6:
            issues.append(f"High silence: {row['silence_ratio']:.1%}")
            severity_score += 2 if row['silence_ratio'] > 0.8 else 1

        if row['rms_mean'] < 0.01:
            issues.append(f"Very low energy: {row['rms_mean']:.4f}")
            severity_score += 2

        if row['duration_total'] < 60:
            issues.append(f"Short total: {row['duration_total']:.1f}s")
            severity_score += 1

        if row['duration_after_trim'] < 30:
            issues.append(f"Short after trim: {row['duration_after_trim']:.1f}s")
            severity_score += 2

        if row['clipping_ratio'] > 0.01:
            issues.append(f"Clipping: {row['clipping_ratio']:.1%}")
            severity_score += 2 if row['clipping_ratio'] > 0.05 else 1

        if not row.get('onset_detected', True):
            issues.append("Onset detection failed")
            severity_score += 1

        if issues:
            problematic.append({
                'filename': row['filename'],
                'tipo_soggetto': row['tipo_soggetto'],
                'severity': 'CRITICAL' if severity_score >= 3 else 'WARNING',
                'severity_score': severity_score,
                'issues': ' | '.join(issues),
                'duration_total': row['duration_total'],
                'duration_after_trim': row['duration_after_trim'],
                'silence_ratio': row['silence_ratio'],
                'rms_mean': row['rms_mean'],
                'clipping_ratio': row['clipping_ratio']
            })

    # Fixed columns even when empty; stable sort makes tie order deterministic.
    prob_df = pd.DataFrame(problematic, columns=report_columns)
    return prob_df.sort_values('severity_score', ascending=False, kind='stable')


# ============================================================
# ONSET VALIDATION (VISUAL)
# ============================================================

def validate_onset_detection(df, audio_dir, stats_df, n_samples=4, random_state=None):
    """
    Generate visual validation plots for onset detection results.

    Creates waveform plots with marked onset times for manual inspection
    and validation of the automatic detection algorithm.

    Parameters
    ----------
    df : pd.DataFrame
        Original dataset metadata (kept for API compatibility; not read here).
    audio_dir : str
        Directory containing audio files.
    stats_df : pd.DataFrame
        Statistics from analysis. Columns read: onset_detected, label,
        filename, onset_time, trim_amount, tipo_soggetto, duration_total,
        duration_after_trim.
    n_samples : int, default=4
        Maximum number of files to visualize (capped at the number of files
        with a detected onset).
    random_state : int, optional
        Seed for a reproducible selection. None (default) keeps the previous
        behavior and draws from NumPy's global RNG via DataFrame.sample.

    Returns
    -------
    matplotlib.figure.Figure or None
        Figure with one subplot per selected file showing the waveform with
        onset markers. Returns None if no valid files are available.

    Notes
    -----
    - Balances classes: up to n_samples // 2 controls (label 0) and
      n_samples // 2 patients (label 1). Remaining slots are filled from any
      valid file. The total never exceeds n_samples.
    - Only includes files where onset detection succeeded.
    - Red line: detected onset time.
    - Green line: trim start ('trim_amount' column).
    - Shaded region: portion that would be trimmed.
    """
    valid_files = stats_df[stats_df['onset_detected'].eq(True)]

    n_samples = min(n_samples, len(valid_files))
    if n_samples <= 0:
        print("Warning: No valid files for onset validation")
        return None

    # One RNG object shared by every draw, so a seed gives a reproducible
    # selection without re-seeding (and re-correlating) each .sample call.
    rs = np.random.RandomState(random_state) if random_state is not None else None

    # Balanced part: up to half of the plots from each class.
    per_group = n_samples // 2
    samples = []
    for label in (0, 1):
        group = valid_files[valid_files['label'] == label]
        k = min(per_group, len(group))
        if k > 0:
            samples.extend(group.sample(k, random_state=rs).to_dict('records'))

    # Fill the remaining slots (odd n_samples, or an under-populated class).
    n_missing = n_samples - len(samples)
    if n_missing > 0:
        chosen = {s['filename'] for s in samples}
        remaining = valid_files[~valid_files['filename'].isin(chosen)]
        k = min(n_missing, len(remaining))
        if k > 0:
            samples.extend(remaining.sample(k, random_state=rs).to_dict('records'))

    # squeeze=False always yields a 2-D array of Axes, even for one subplot.
    fig, axes = plt.subplots(len(samples), 1, figsize=(15, 3 * len(samples)), squeeze=False)

    for ax, sample in zip(axes[:, 0], samples):
        audio_path = os.path.join(audio_dir, sample['filename'])

        try:
            y, sr = librosa.load(audio_path, sr=None, mono=True)
            t = np.arange(len(y)) / sr  # 't' avoids shadowing the time module

            ax.plot(t, y, alpha=0.7, linewidth=0.5, color='steelblue')

            onset = sample.get('onset_time')
            trim = sample.get('trim_amount')
            # pd.notna rejects both None and NaN (records from a float column carry NaN).
            if pd.notna(onset):
                ax.axvline(onset, color='red', linestyle='--',
                           linewidth=2, label=f"Onset: {onset:.2f}s")

                if pd.notna(trim):
                    ax.axvline(trim, color='green', linestyle='--',
                               linewidth=2, alpha=0.6, label=f"Trim start: {trim:.2f}s")
                    ax.axvspan(0, trim, alpha=0.2, color='red')

            ax.set_title(f"{sample['filename'][:50]} - {sample['tipo_soggetto']} - "
                         f"Duration: {sample['duration_total']:.1f}s → {sample['duration_after_trim']:.1f}s after trim")
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Amplitude')
            if ax.get_legend_handles_labels()[0]:  # avoid "no artists with labels" warning
                ax.legend(loc='upper right')
            ax.grid(True, alpha=0.3)

        except Exception as e:
            # Best-effort: one unreadable file must not prevent the other panels.
            ax.text(0.5, 0.5, f"Error loading: {e}",
                    ha='center', va='center', transform=ax.transAxes)

    plt.tight_layout()
    return fig


# ============================================================
# VISUALIZATION
# ============================================================

def plot_analysis(stats_df, summary):
    """
    Create comprehensive 6-panel visualization of dataset statistics.

    Parameters
    ----------
    stats_df : pd.DataFrame
        Per-file statistics. Rows with NaN 'duration_total' are dropped.
        Columns read: duration_total, duration_after_trim, silence_ratio,
        onset_detected, onset_time, tipo_soggetto.
    summary : dict
        Aggregate statistics. Reads summary['duration_total'] mean/median and
        summary['onset'] mean/median/detection_success_rate.

    Returns
    -------
    matplotlib.figure.Figure
        6-panel figure with:
        1. Duration by group (boxplot)
        2. Duration distribution (histogram)
        3. Onset time distribution (histogram)
        4. Duration after trimming by group (boxplot)
        5. Silence ratio by group (boxplot)
        6. Duration vs Silence scatter (outlier detection)
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    valid = stats_df.dropna(subset=['duration_total'])

    def group_boxplot(ax, column, title, ylabel):
        """Draw a per-group boxplot of `column` and apply the shared axis cleanup."""
        valid.boxplot(column=column, by='tipo_soggetto', ax=ax)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('')
        ax.tick_params(axis='x', labelrotation=0)

    # Panel 1: Duration by group
    ax = axes[0, 0]
    group_boxplot(ax, 'duration_total', 'Total Duration by Group', 'Duration (seconds)')
    ax.axhline(y=30, color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='30s')
    ax.axhline(y=60, color='green', linestyle='--', linewidth=2, label='60s')
    ax.axhline(y=90, color='orange', linestyle='--', linewidth=1.5, alpha=0.7, label='90s')
    ax.legend(loc='upper right', fontsize=8)

    # Panel 2: Duration histogram
    ax = axes[0, 1]
    ax.hist(valid['duration_total'], bins=30, edgecolor='black', alpha=0.7, color='skyblue')
    ax.axvline(summary['duration_total']['mean'], color='red', linestyle='--',
               linewidth=2, label=f"Mean: {summary['duration_total']['mean']:.1f}s")
    ax.axvline(summary['duration_total']['median'], color='green', linestyle='--',
               linewidth=2, label=f"Median: {summary['duration_total']['median']:.1f}s")
    ax.set_title('Duration Distribution')
    ax.set_xlabel('Duration (seconds)')
    ax.set_ylabel('Frequency')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # Panel 3: Onset distribution (only files where detection succeeded)
    ax = axes[0, 2]
    onset_valid = valid.loc[valid['onset_detected'].eq(True), 'onset_time'].dropna()
    if len(onset_valid) > 0:
        ax.hist(onset_valid, bins=30, edgecolor='black', alpha=0.7, color='coral')
        ax.axvline(summary['onset']['mean'], color='red', linestyle='--',
                   linewidth=2, label=f"Mean: {summary['onset']['mean']:.2f}s")
        ax.axvline(summary['onset']['median'], color='green', linestyle='--',
                   linewidth=2, label=f"Median: {summary['onset']['median']:.2f}s")
        ax.legend(fontsize=8)
    ax.set_title(f"Speech Onset Time (detected in {summary['onset']['detection_success_rate']:.1%})")
    ax.set_xlabel('Onset Time (seconds)')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

    # Panel 4: Duration after trim by group
    ax = axes[1, 0]
    group_boxplot(ax, 'duration_after_trim', 'Duration After Onset Trimming', 'Duration (seconds)')
    ax.axhline(y=60, color='green', linestyle='--', linewidth=2, label='60s reference')
    ax.axhline(y=30, color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='30s reference')
    ax.axhline(y=90, color='orange', linestyle='--', linewidth=1.5, label='90s reference')
    ax.legend(loc='upper right', fontsize=8)

    # Panel 5: Silence ratio by group
    ax = axes[1, 1]
    group_boxplot(ax, 'silence_ratio', 'Silence Ratio by Group (10% of peak method)', 'Silence Ratio')
    ax.axhline(y=0.6, color='red', linestyle='--', linewidth=2,
               label='60% reference')
    ax.legend(loc='upper right', fontsize=8)
    ax.set_ylim([0, 1])

    # Panel 6: Duration vs Silence scatter
    ax = axes[1, 2]
    colors = {'Controllo': 'blue', 'Paziente': 'red'}
    for tipo in valid['tipo_soggetto'].unique():
        subset = valid[valid['tipo_soggetto'] == tipo]
        ax.scatter(subset['duration_total'], subset['silence_ratio'],
                   label=tipo, alpha=0.6, s=50, color=colors.get(tipo, 'gray'),
                   edgecolors='black', linewidth=0.5)
    ax.axhline(y=0.6, color='red', linestyle='--', linewidth=1.5, alpha=0.5)
    ax.axvline(x=60, color='green', linestyle='--', linewidth=1.5, alpha=0.5)
    ax.set_xlabel('Duration (seconds)')
    ax.set_ylabel('Silence Ratio')
    ax.set_title('Duration vs Silence (outlier detection)')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # pandas' DataFrame.boxplot(by=...) replaces the figure suptitle with
    # "Boxplot grouped by ...", so the real title must be set after the boxplots.
    fig.suptitle('Audio Dataset Analysis - Summary Statistics',
                 fontsize=16, fontweight='bold')
    fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave headroom for the suptitle
    return fig


# ============================================================
# SUMMARY REPORT (CONSOLE OUTPUT)
# ============================================================

def print_summary_report(summary, problematic_df):
    """
    Print the aggregate dataset statistics as a formatted console report.

    Sections, in order: dataset overview, sample rate, total duration,
    speech onset detection, duration after trimming, audio quality,
    controls-vs-patients comparison, and a count of problematic files.

    Parameters
    ----------
    summary : dict
        Aggregate statistics. Keys read: n_samples, n_controls, n_patients,
        n_errors, sample_rate, duration_total, onset, duration_after_trim,
        quality, by_group['controls'|'patients'].
    problematic_df : pd.DataFrame
        Problematic-file table. Its 'severity' column is read only when the
        frame is non-empty.
    """
    rule = "=" * 80

    def heading(title):
        # Every section title is preceded by a blank line and one space.
        print(f"\n {title}")

    def duration_rows(stats):
        print(f"  Mean ± SD:         {stats['mean']:.1f}s ± {stats['std']:.1f}s")
        print(f"  Median:            {stats['median']:.1f}s")
        print(f"  Range:             [{stats['min']:.1f}s, {stats['max']:.1f}s]")
        print(f"  Quartiles:         Q25={stats['q25']:.1f}s, Q75={stats['q75']:.1f}s")

    print("\n" + rule)
    print("SUMMARY REPORT")
    print(rule)

    heading("Dataset Overview")
    for caption, key in (
        ("  Total samples:     ", 'n_samples'),
        ("  Controls:          ", 'n_controls'),
        ("  Patients:          ", 'n_patients'),
    ):
        print(f"{caption}{summary[key]}")
    if summary['n_errors'] > 0:
        print(f"  Errors:            {summary['n_errors']}")

    heading("Sample Rate")
    rate = summary['sample_rate']
    if rate['unique'] != 1:
        print(f"  Range:             {rate['min']} - {rate['max']} Hz")
        print(f"  Mixed rates:       {rate['unique']} different sample rates detected")
    else:
        print(f"  All files:         {rate['min']} Hz")

    heading("Duration Statistics (Total)")
    duration_rows(summary['duration_total'])

    heading("Speech Onset Detection")
    onset = summary['onset']
    print(f"  Detection success: {onset['detection_success_rate']:.1%}")
    print(f"  Mean onset:        {onset['mean']:.2f}s ± {onset['std']:.2f}s")
    print(f"  Median onset:      {onset['median']:.2f}s")
    print(f"  Max onset:         {onset['max']:.2f}s")

    heading("Duration After Trimming")
    duration_rows(summary['duration_after_trim'])

    heading("Audio Quality")
    quality = summary['quality']
    print(f"  Silence ratio:     {quality['silence_ratio_mean']:.1%} ± {quality['silence_ratio_std']:.1%}")
    print(f"  RMS energy:        {quality['rms_mean']:.4f} ± {quality['rms_std']:.4f}")
    print(f"  Clipping mean:     {quality['clipping_mean']:.2%}")
    if quality['n_clipped'] > 0:
        print(f"  Files with >1% clipping: {quality['n_clipped']}")

    heading("Group Comparison")
    controls = summary['by_group']['controls']
    patients = summary['by_group']['patients']

    def compare(caption, fmt):
        # fmt is applied lazily so each group's value is read right before printing.
        print(f"  {caption}:")
        print(f"    Controls:        {fmt(controls)}")
        print(f"    Patients:        {fmt(patients)}")

    compare("Sample size", lambda g: g['n'])
    compare("Duration (total)", lambda g: f"{g['duration_mean']:.1f}s ± {g['duration_std']:.1f}s")
    compare("Duration (after trim)", lambda g: f"{g['duration_after_trim_mean']:.1f}s")
    compare("Silence ratio", lambda g: f"{g['silence_mean']:.1%}")

    if len(problematic_df) == 0:
        print(f"\n✓ No problematic audio files detected")
    else:
        severity = problematic_df['severity']
        print(f"\n⚠️  Problematic Audio Files")
        print(f"  Total:             {len(problematic_df)}")
        print(f"  Critical:          {severity.eq('CRITICAL').sum()}")
        print(f"  Warnings:          {severity.eq('WARNING').sum()}")

    print("\n" + rule)


# ============================================================
# SAVE TEXT REPORT
# ============================================================

def save_text_report(summary, problematic_df, output_path):
    """
    Write the aggregate dataset statistics to a UTF-8 text file.

    The layout follows the console summary, but uses two-decimal precision
    and adds a "Clipping max" line. A "PROBLEMATIC FILES" section is written
    only when `problematic_df` is non-empty.

    Parameters
    ----------
    summary : dict
        Aggregate statistics. Same keys as the console report, plus
        summary['quality']['clipping_max'].
    problematic_df : pd.DataFrame
        Problematic-file table. Its 'severity' column is read only when the
        frame is non-empty.
    output_path : str
        Destination file; opened in 'w' mode, so an existing file is overwritten.
    """
    rule = "=" * 80

    with open(output_path, 'w', encoding='utf-8') as f:
        write = f.write

        def banner(title):
            # Section header: blank line, rule, title, rule.
            write("\n" + rule + "\n" + title + "\n" + rule + "\n")

        def duration_rows(stats):
            write(f"Mean ± SD:         {stats['mean']:.2f}s ± {stats['std']:.2f}s\n")
            write(f"Median:            {stats['median']:.2f}s\n")
            write(f"Range:             [{stats['min']:.2f}s, {stats['max']:.2f}s]\n")
            write(f"Q25:               {stats['q25']:.2f}s\n")
            write(f"Q75:               {stats['q75']:.2f}s\n")

        write(f"{rule}\nAUDIO DATASET ANALYSIS REPORT\n{rule}\n")
        write(f"\nGenerated: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

        banner("DATASET OVERVIEW")
        write(f"Total samples:     {summary['n_samples']}\n")
        write(f"Controls:          {summary['n_controls']}\n")
        write(f"Patients:          {summary['n_patients']}\n")
        if summary['n_errors'] > 0:
            write(f"Errors:            {summary['n_errors']}\n")

        banner("SAMPLE RATE")
        rate = summary['sample_rate']
        if rate['unique'] != 1:
            write(f"Range:             {rate['min']} - {rate['max']} Hz\n")
            write(f"Mixed rates:       {rate['unique']} different sample rates\n")
        else:
            write(f"All files:         {rate['min']} Hz\n")

        banner("DURATION STATISTICS (Total)")
        duration_rows(summary['duration_total'])

        banner("SPEECH ONSET DETECTION")
        onset = summary['onset']
        write(f"Detection success: {onset['detection_success_rate']:.1%}\n")
        write(f"Mean onset:        {onset['mean']:.2f}s ± {onset['std']:.2f}s\n")
        write(f"Median onset:      {onset['median']:.2f}s\n")
        write(f"Max onset:         {onset['max']:.2f}s\n")

        banner("DURATION AFTER TRIMMING")
        duration_rows(summary['duration_after_trim'])

        banner("AUDIO QUALITY METRICS")
        quality = summary['quality']
        write(f"Silence ratio:     {quality['silence_ratio_mean']:.2%} ± {quality['silence_ratio_std']:.2%}\n")
        write(f"RMS energy:        {quality['rms_mean']:.4f} ± {quality['rms_std']:.4f}\n")
        write(f"Clipping mean:     {quality['clipping_mean']:.3%}\n")
        write(f"Clipping max:      {quality['clipping_max']:.3%}\n")
        if quality['n_clipped'] > 0:
            write(f"Files with >1% clipping: {quality['n_clipped']}\n")

        banner("GROUP COMPARISON (Controls vs Patients)")
        controls = summary['by_group']['controls']
        patients = summary['by_group']['patients']

        def compare(caption, fmt):
            # fmt is applied lazily so each group's value is read right before writing.
            write(f"\n{caption}:\n")
            write(f"  Controls:        {fmt(controls)}\n")
            write(f"  Patients:        {fmt(patients)}\n")

        compare("Sample Size", lambda g: g['n'])
        compare("Duration (total)", lambda g: f"{g['duration_mean']:.2f}s ± {g['duration_std']:.2f}s")
        compare("Duration (after trim)", lambda g: f"{g['duration_after_trim_mean']:.2f}s")
        compare("Silence ratio", lambda g: f"{g['silence_mean']:.2%}")

        if len(problematic_df) != 0:
            severity = problematic_df['severity']
            banner("PROBLEMATIC FILES")
            write(f"Total:             {len(problematic_df)}\n")
            write(f"Critical:          {severity.eq('CRITICAL').sum()}\n")
            write(f"Warnings:          {severity.eq('WARNING').sum()}\n")

        banner("END OF REPORT")


# ============================================================
# EXECUTE ANALYSIS
# ============================================================

if RUN_DATASET_ANALYSIS:
    # Output directory for all dataset-analysis artifacts
    output_dir = os.path.join(DATA_PATH, 'dataset_free_speech_analysis')
    os.makedirs(output_dir, exist_ok=True)

    # Number of recordings shown in the onset-validation figure
    n_onset_samples = 4

    # Run comprehensive analysis
    stats_df, summary = analyze_audio_dataset_simple(df, CONFIG['audio_dir'])

    # Identify problematic files
    problematic_df = identify_problematic_audio(stats_df)

    # Display summary report in console
    print_summary_report(summary, problematic_df)

    # Generate and save visualization. Save through the returned Figure handle
    # rather than pyplot's implicit "current figure" state.
    fig = plot_analysis(stats_df, summary)
    plot_path = os.path.join(output_dir, 'audio_analysis.png')
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.show()

    # Visual validation of onset detection
    print("\n" + "="*80)
    print("ONSET DETECTION VALIDATION")
    print("="*80)
    print(f"Generating visual validation for {n_onset_samples} random samples...")

    val_path = None
    val_fig = validate_onset_detection(df, CONFIG['audio_dir'], stats_df, n_samples=n_onset_samples)
    if val_fig is not None:
        val_path = os.path.join(output_dir, 'onset_validation.png')
        val_fig.savefig(val_path, dpi=150, bbox_inches='tight')
        plt.show()
        print(f"Visual validation saved: {val_path}")

    # Save outputs
    csv_path = os.path.join(output_dir, 'audio_statistics.csv')
    stats_df.to_csv(csv_path, index=False)

    report_path = os.path.join(output_dir, 'audio_analysis_report.txt')
    save_text_report(summary, problematic_df, report_path)

    prob_path = None
    if len(problematic_df) > 0:
        prob_path = os.path.join(output_dir, 'problematic_audio.csv')
        problematic_df.to_csv(prob_path, index=False)

    # Final summary
    print("\n" + "="*80)
    print("OUTPUT FILES")
    print("="*80)
    print(f"Per-file statistics (CSV):    {csv_path}")
    print(f"Summary report (text):        {report_path}")
    print(f"Visualization (6-panel):      {plot_path}")
    if val_path is not None:
        print(f"Onset validation (visual):    {val_path}")
    if prob_path is not None:
        print(f"Problematic files (CSV):      {prob_path}")

    print("\nAnalysis complete.")

else:
    print("-> Skipping dataset analysis (RUN_DATASET_ANALYSIS = False)")


## xAI

In [None]:
# ============================================================
# EXPLAINABLE AI - SETUP
# ============================================================
# Implementazione di 2 tecniche complementari:
# 1. Integrated Gradients (Sundararajan et al. 2017)
# 2. Attention Rollout (Abnar & Zuidema 2020)
# ============================================================

import os
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import librosa
import librosa.display
from scipy import stats

import warnings
warnings.filterwarnings('ignore')

# Setup plotting style. The 'seaborn-v0_8-*' style names exist only in
# matplotlib >= 3.6; older releases ship the same style as 'seaborn-darkgrid'.
plt.style.use(
    'seaborn-v0_8-darkgrid'
    if 'seaborn-v0_8-darkgrid' in plt.style.available
    else 'seaborn-darkgrid'
)
sns.set_palette("husl")

print("✓ Setup completato")

# ============================================================
# MODIFICA AL MODELLO: Estrazione Attention Weights
# ============================================================
# Aggiungiamo la capacità di estrarre attention da tutti i layer
# SENZA modificare i pesi (solo inference mode)
# ============================================================

class HuBERTClassifierExplainable(nn.Module):
    """
    Variant of HuBERTClassifier used for Explainable AI.

    Differences from the training model:
    - `return_all_attentions` flag: return the self-attention maps of every
      HuBERT encoder layer (12 for a base-size checkpoint).
    - `return_embeddings` flag: return the last hidden state, the pooled
      vector and all intermediate hidden states.
    - No change to the weights or the architecture.

    Usage:
    - Plain inference: forward(input_values) returns logits only.
    - XAI: forward(..., return_all_attentions=True, return_embeddings=True)
      returns a dict (see `forward`).
    """

    def __init__(self, model_name, freeze_layers=9, hidden_dim=256,
                 dropout=0.3, num_classes=2, pooling_type='attention'):
        super().__init__()

        # HuBERT encoder
        self.hubert = HubertModel.from_pretrained(model_name)
        self.hidden_size = self.hubert.config.hidden_size  # 768 for base-size checkpoints

        # Freeze the first `freeze_layers` encoder layers (mirrors the training setup).
        # The slice clamps to the available layers instead of raising IndexError;
        # max(..., 0) keeps negative values a no-op, as range() did.
        for layer in self.hubert.encoder.layers[:max(freeze_layers, 0)]:
            for param in layer.parameters():
                param.requires_grad = False

        # Pooling layer: learned attention pooling, otherwise mean pooling in forward()
        self.pooling_type = pooling_type
        if pooling_type == 'attention':
            self.pooling = AttentionPooling(self.hidden_size)

        # Classifier MLP
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_values, attention_mask=None,
                return_all_attentions=False, return_embeddings=False):
        """
        Forward pass with optional explainability outputs.

        Args:
            input_values: raw waveform, presumably (batch, time).
            attention_mask: padding mask, passed to HuBERT only.
                NOTE(review): pooling ignores it, so padded frames are
                pooled too when batch > 1.
            return_all_attentions: if True, also return the attention maps of
                all HuBERT layers (and the pooling attention weights, if any).
            return_embeddings: if True, also return hidden states.

        Returns:
            If both flags are False:
                logits: (batch, num_classes)

            Otherwise a dict with keys:
                - 'logits': (batch, num_classes)
                - 'hubert_attentions': tuple with one tensor per encoder layer,
                  each (batch, heads, time, time) [return_all_attentions]
                - 'pooling_attention': weights returned by AttentionPooling
                  [return_all_attentions and attention pooling]
                - 'embeddings': last hidden state (batch, time, hidden) [return_embeddings]
                - 'pooled': pooled vector (batch, hidden) [return_embeddings]
                - 'hidden_states': tuple of all HuBERT hidden states
                  (embedding output + one per layer) [return_embeddings]
        """
        # HuBERT encoding, optionally recording attentions / hidden states
        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=return_all_attentions,  # key switch: record attentions
            output_hidden_states=return_embeddings
        )

        embeddings = outputs.last_hidden_state  # (batch, time, hidden)

        # Pooling over the time axis
        if self.pooling_type == 'attention':
            pooled, attn_weights = self.pooling(embeddings)
        else:
            pooled = embeddings.mean(dim=1)
            attn_weights = None

        # Classification
        logits = self.classifier(pooled)

        # Plain logits unless an explainability output was requested
        if not (return_all_attentions or return_embeddings):
            return logits

        result = {'logits': logits}

        if return_all_attentions:
            result['hubert_attentions'] = outputs.attentions  # one tensor per layer
            if attn_weights is not None:
                result['pooling_attention'] = attn_weights

        if return_embeddings:
            result['embeddings'] = embeddings
            result['pooled'] = pooled
            # Previously requested from HuBERT but discarded; now exposed.
            result['hidden_states'] = outputs.hidden_states

        return result

print("✓ HuBERTClassifierExplainable definito")


## IG

In [None]:
# TECHNIQUE 1: INTEGRATED GRADIENTS
# Paper: Sundararajan et al. (2017) "Axiomatic Attribution for Deep Networks"


import os
import numpy as np
import torch
import torch.nn.functional as F
import librosa
import librosa.display
import matplotlib.pyplot as plt
from transformers import Wav2Vec2FeatureExtractor



class IntegratedGradients:
    """
    Integrated Gradients attribution (Sundararajan et al., 2017).

    Formula:
        IG(x) = (x - baseline) × ∫₀¹ ∂F(baseline + α(x - baseline))/∂x dα

    The integral is approximated with a weighted trapezoid rule on the
    non-uniform grid α = linspace(0, 1) ** 3, which places more points
    close to the baseline.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier whose forward(x) returns logits (batch, num_classes).
        Switched to eval mode at construction.
    device : torch.device
        Device for inputs, baselines and interpolated points.
    onset_map : dict, optional
        Maps "<filename>.wav" to the number of seconds to trim from the start
        of the recording in `visualize_on_spectrogram`.
    processor_name : str, default "facebook/hubert-base-ls960"
        Hugging Face id of the feature extractor used by
        `visualize_on_spectrogram`. It is loaded lazily, once per instance.
    """

    def __init__(self, model, device, onset_map=None,
                 processor_name="facebook/hubert-base-ls960"):
        self.model = model
        self.device = device
        self.onset_map = onset_map if onset_map is not None else {}
        self.processor_name = processor_name
        self._processor = None  # cached feature extractor, see _get_processor()
        self.model.eval()  # ALWAYS eval mode: disables dropout during attribution

    def _get_processor(self):
        """Return the feature extractor, loading it from the hub only on first use."""
        processor = getattr(self, "_processor", None)
        if processor is None:
            name = getattr(self, "processor_name", "facebook/hubert-base-ls960")
            processor = Wav2Vec2FeatureExtractor.from_pretrained(name)
            self._processor = processor
        return processor

    def compute_attributions(
        self,
        inputaudio,
        targetclass,
        nsteps=50,
        internalbatchsize=1,
        baseline=None,
    ):
        """
        Compute Integrated Gradients attributions for a single input.

        Parameters
        ----------
        inputaudio : torch.Tensor
            Input of shape (1, T). The completeness check reads row 0, so a
            single example is assumed.
        targetclass : int
            Index of the logit to attribute.
        nsteps : int, default 50
            Number of integration intervals (nsteps + 1 α points). Must be >= 1.
        internalbatchsize : int, default 1
            Interpolated points per forward/backward pass. Must be >= 1.
        baseline : torch.Tensor, optional
            Reference input with the same shape as `inputaudio`; zeros if None.

        Returns
        -------
        attributions : torch.Tensor
            (1, T) attributions, detached, on CPU.
        completeness_error_abs : float
            |(F(x) - F(baseline)) - sum(attributions)|.
        completeness_error_rel : float
            The absolute error divided by |F(x) - F(baseline)| + 1e-8.

        Raises
        ------
        ValueError
            If nsteps < 1 or internalbatchsize < 1.
        """
        if nsteps < 1:
            raise ValueError(f"nsteps must be >= 1, got {nsteps}")
        if internalbatchsize < 1:
            raise ValueError(f"internalbatchsize must be >= 1, got {internalbatchsize}")

        inputaudio = inputaudio.to(self.device)

        if baseline is None:
            baseline = torch.zeros_like(inputaudio, device=self.device)
        else:
            baseline = baseline.to(self.device)

        # Non-uniform α grid: a larger power puts more points near 0
        alpha_power = 3.0
        alphas = torch.linspace(0.0, 1.0, steps=nsteps + 1, device=self.device) ** alpha_power

        all_grads = []
        for i in range(0, len(alphas), internalbatchsize):
            batch_alphas = alphas[i : i + internalbatchsize]  # [b]
            interpolated = baseline + batch_alphas.view(-1, 1) * (inputaudio - baseline)  # [b, T]

            interpolated = interpolated.detach().requires_grad_(True)

            logits = self.model(interpolated)  # [b, C]
            # Summing over the batch yields per-row gradients, assuming the
            # model treats rows independently (eval mode, no batch-coupled layers).
            target = logits[:, targetclass].sum()

            grads = torch.autograd.grad(
                outputs=target,
                inputs=interpolated,
                create_graph=False
            )[0]  # [b, T]
            all_grads.append(grads.detach())

            if torch.isnan(grads).any() or torch.isinf(grads).any():
                print("GRADS NAN/INF! step index:", i, "alpha range:", batch_alphas[0].item(), "->", batch_alphas[-1].item())
                print("interpolated min/max:", interpolated.min().item(), interpolated.max().item())

            del interpolated, logits, target, grads

        all_grads = torch.cat(all_grads, dim=0)  # [nsteps+1, T]

        # WEIGHTED trapezoid rule (required because the alphas are not uniform)
        dalpha = (alphas[1:] - alphas[:-1]).view(-1, 1)              # (n_steps, 1)
        trap = (all_grads[:-1] + all_grads[1:]) / 2.0               # (n_steps, T)
        integral = (trap * dalpha).sum(dim=0, keepdim=True)         # (1, T)

        # Debug print of the first grid points. Only indices that exist are shown,
        # so small nsteps no longer raise IndexError (alphas has nsteps + 1 entries).
        print(*(f"alpha[{k}]= {alphas[k].item()}" for k in (1, 2, 10) if k < len(alphas)))

        attributions = (inputaudio - baseline) * integral           # (1, T)

        # Completeness axiom check: sum(attributions) should match F(x) - F(baseline)
        with torch.no_grad():
            out_in = self.model(inputaudio)[0, targetclass]
            out_base = self.model(baseline)[0, targetclass]
            delta_out = out_in - out_base
            sum_attr = attributions.sum()
            completeness_error_abs = (delta_out - sum_attr).abs().item()

            print(
                f"logit_input={out_in.item():.4f}  logit_base={out_base.item():.4f}  delta_out={delta_out.item():.4f}"
            )
            print(f"sum_attr={sum_attr.item():.4f}  abs_err={completeness_error_abs:.4f}")

            completeness_error_rel = completeness_error_abs / (delta_out.abs().item() + 1e-8)

        return attributions.detach().cpu(), completeness_error_abs, completeness_error_rel

    def visualize_on_spectrogram(self, audio_path, target_class, filename, sr=16000,
                                 save_dir=None, max_duration=30):
        """
        Compute IG for one recording and plot it under its mel-spectrogram.

        Pipeline:
        1. Load the audio.
        2. Optionally trim the onset using onset_map.
        3. Truncate to `max_duration` seconds.
        4. Run the feature extractor.
        5. Compute IG against a zero-waveform baseline.
        6. Resample the attributions to the mel frame rate and replicate them
           across mel bands, giving a time-only map.

        Parameters
        ----------
        audio_path : str
            Path of the audio file to load.
        target_class : int
            Logit to attribute. Titled "Paziente" if 1, otherwise "Controllo".
        filename : str
            Name used for the onset_map lookup (as filename + ".wav"), the
            title and the output file name, so presumably without extension.
        sr : int, default 16000
            Sampling rate used for loading and for the feature extractor.
        save_dir : str, optional
            If given, the figure is saved as IG_<filename>.png in this directory.
        max_duration : float, default 30
            Maximum number of seconds kept after trimming.

        Returns
        -------
        fig : matplotlib.figure.Figure
        attribution_map : np.ndarray
            Array of shape (n_mels, n_frames).
        metrics : dict
            Completeness errors and summary statistics of the temporal profile.
        """
        # Load audio
        waveform, _ = librosa.load(audio_path, sr=sr, mono=True)

        # STEP 1: ONSET TRIMMING (if an onset map is available)
        if (filename + ".wav") in self.onset_map:
            trim_amount_sec = self.onset_map[filename + ".wav"]
            trim_samples = int(trim_amount_sec * sr)
            if 0 < trim_samples < len(waveform):
                waveform = waveform[trim_samples:]
                print(f"Trimmed: {trim_amount_sec:.2f}s removed")

        # STEP 2: TRUNCATION to max_duration seconds
        max_samples = int(sr * max_duration)
        original_length = len(waveform)
        if len(waveform) > max_samples:
            waveform = waveform[:max_samples]
            print(f"Truncated: {len(waveform)/sr:.2f}s (was {original_length/sr:.2f}s)")

        # STEP 3: Feature extractor (cached per instance)
        processor = self._get_processor()
        inputs = processor(
            waveform,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True
        )
        input_values = inputs.input_values.to(self.device)  # [1, T]
        waveform_processed = input_values.squeeze(0).detach().cpu().numpy()

        # Baseline: a silent waveform passed through the same feature extractor
        baseline_wave = np.zeros_like(waveform)
        baseline_inputs = processor(baseline_wave, sampling_rate=sr, return_tensors="pt", padding=True)
        baseline_values = baseline_inputs.input_values.to(self.device)

        print(f"Baseline creata (shape: {tuple(baseline_values.shape)})")

        print("input dtype:", input_values.dtype, "device:", input_values.device)
        print("model dtype:", next(self.model.parameters()).dtype, "device:", next(self.model.parameters()).device)

        # DEBUG: check whether the baseline -> input path is smooth
        alphas_test = [0.0, 1e-6, 1e-4, 1e-2, 0.1, 0.5, 1.0]
        with torch.no_grad():
            for a in alphas_test:
                x = baseline_values + a * (input_values - baseline_values)
                li = self.model(x)[0, target_class].item()
                print(f"alpha={a:.0e}  logit={li:.6f}")

        # Prediction and confidence
        with torch.no_grad():
            logits = self.model(input_values)
            probs = torch.softmax(logits, dim=1)
            confidence = probs[0, target_class].item()

        # Adaptive n_steps / internal_batch_size: saturated (high-confidence)
        # predictions get a finer integration grid.
        if confidence > 0.95:
            n_steps = 600
            internal_batch_size = 1
            print(f"High confidence ({confidence:.3f}) -> n_steps={n_steps}, batch=1")
        elif confidence > 0.85:
            n_steps = 500
            internal_batch_size = 1
            print(f"Medium-high confidence ({confidence:.3f}) -> n_steps={n_steps}, batch=1")
        else:
            n_steps = 30
            internal_batch_size = 2
            print(f"Normal confidence ({confidence:.3f}) -> n_steps={n_steps}, batch=2")

        # Integrated Gradients
        attributions, completeness_err_abs, completeness_err_rel = self.compute_attributions(
            input_values,
            target_class,
            nsteps=n_steps,
            internalbatchsize=internal_batch_size,
            baseline=baseline_values,
        )

        # Mel-spectrogram (of the feature-extractor output)
        mel_spec = librosa.feature.melspectrogram(
            y=waveform_processed,
            sr=sr,
            n_mels=128,
            hop_length=512,
            n_fft=2048
        )
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Resample attributions to the mel time resolution
        n_frames = mel_spec.shape[1]

        # Shape robustness: we want [1, T]
        if attributions.dim() == 1:
            attributions = attributions.unsqueeze(0)  # [1, T]
        elif attributions.dim() != 2:
            raise ValueError(f"attributions shape non supportata: {tuple(attributions.shape)}")

        # interpolate(mode="linear") requires [N, C, L]
        attr_3d = attributions.unsqueeze(1)  # [1, 1, T]
        attribution_resampled = F.interpolate(
            attr_3d,
            size=n_frames,
            mode="linear",
            align_corners=False
        ).squeeze(0).squeeze(0).detach().cpu().numpy()  # [n_frames]

        # Expand along the frequency axis (same value on every band: time-only map)
        attribution_map = np.tile(attribution_resampled, (mel_spec.shape[0], 1))

        # Robust symmetric color scale: avoids a misleading per-sample z-score,
        # clips outliers via percentiles and keeps sign and magnitude meaningful.
        p_low, p_high = np.percentile(attribution_resampled, [1, 99])
        lim = float(max(abs(p_low), abs(p_high)) + 1e-12)  # +eps avoids lim == 0

        # VISUALIZATION
        fig, axes = plt.subplots(2, 1, figsize=(16, 8))
        duration = input_values.shape[1] / sr

        # Plot 1: mel-spectrogram
        img1 = librosa.display.specshow(
            mel_spec_db,
            sr=sr,
            hop_length=512,
            x_axis="time",
            y_axis="mel",
            ax=axes[0],
            cmap="viridis"
        )
        axes[0].set_title(f"Mel-Spectrogram - {filename}", fontsize=13, fontweight="bold")
        axes[0].set_ylabel("Mel Frequency", fontsize=11)
        fig.colorbar(img1, ax=axes[0], format="%+2.0f dB")

        # Plot 2: temporal attribution profile
        img2 = axes[1].imshow(
            attribution_map,
            aspect="auto",
            cmap="RdBu_r",
            extent=[0, duration, 0, sr/2],
            vmin=-lim,
            vmax=lim,
            origin="lower",
            interpolation="bilinear"
        )

        axes[1].set_title(
            f'Integrated Gradients Attribution - Temporal Profile (Class: {"Paziente" if target_class==1 else "Controllo"})',
            fontsize=13,
            fontweight="bold"
        )
        axes[1].set_xlabel("Time (s)", fontsize=11)

        # Hide the Y scale: the map carries no frequency information
        axes[1].set_yticks([])
        axes[1].set_ylabel("")

        fig.colorbar(img2, ax=axes[1], label="Attribution")

        plt.tight_layout()

        # Save
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f"IG_{filename}.png")
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            print(f"Salvato: {save_path}")

        time_axis = np.linspace(0, duration, len(attribution_resampled))

        metrics = {
            "completeness_error_abs": completeness_err_abs,
            "completeness_error_rel": completeness_err_rel,
            "mean_attribution": float(attribution_resampled.mean()),
            "std_attribution": float(attribution_resampled.std()),
            "max_attribution_time": float(time_axis[np.argmax(attribution_resampled)]),
            "positive_ratio": float((attribution_resampled > 0).mean()),
        }

        return fig, attribution_map, metrics


print("IntegratedGradients definito")

## Attention Rollout

In [None]:
#
# TECHNIQUE 2: ATTENTION ROLLOUT
#
# Paper: Abnar & Zuidema (2020) "Quantifying Attention Flow in Transformers"
#

class AttentionRollout:
    """
    Attention Rollout to trace how information flows over time through HuBERT.

    The wrapped model is called with ``return_all_attentions=True`` and
    ``return_embeddings=False`` and must return a mapping with a
    ``'hubert_attentions'`` entry (one attention tensor per transformer layer,
    each (batch=1, num_heads, frames, frames)) and, optionally, a
    ``'pooling_attention'`` entry (weights of the final attention pooling,
    (batch=1, frames, 1)).
    """

    def __init__(self, model, device, onset_map=None):
        """
        Args:
            model: explainable classifier; it is switched to eval mode here.
            device: torch device the inputs are moved to.
            onset_map: optional dict ``{'<name>.wav': seconds_to_trim}`` used by
                ``visualize_rollout`` to drop the initial part of the recording.
        """
        self.model = model
        self.device = device
        self.onset_map = onset_map if onset_map is not None else {}
        # Feature extractor, loaded lazily on the first visualization and reused.
        self._processor = None
        self.model.eval()

    @staticmethod
    def _discard_lowest(attn_fused, discard_ratio):
        """Return a copy of ``attn_fused`` whose ``discard_ratio`` smallest entries are set to 0."""
        num_discard = int(attn_fused.numel() * discard_ratio)
        if num_discard == 0:
            return attn_fused
        # clone(): never modify the attention tensors returned by the model
        flat = attn_fused.flatten().clone()
        lowest = torch.topk(flat, num_discard, largest=False).indices
        flat[lowest] = 0
        return flat.view_as(attn_fused)

    def compute_rollout(self, input_audio, head_fusion="mean", discard_ratio=0.0):
        """
        Compute Attention Rollout.

        Each layer's head-fused attention A_l is mixed with the identity to
        account for the residual connection, Ã_l = rownorm((I + A_l) / 2), and
        the layers are composed from the input upwards:

            R = Ã_L · Ã_{L-1} · ... · Ã_1

        Row i of R gives how much each input frame contributes to output frame i
        of the last layer. The final attention pooling (if present) then
        weights these rows: scores = p^T · R.

        Args:
            input_audio: waveform tensor (1, time).
            head_fusion: 'mean', 'max' or 'min' - how the heads of a layer are
                aggregated.
            discard_ratio: fraction in [0, 1) of the smallest fused attention
                weights of each layer set to zero before rollout (0 disables it).

        Returns:
            rollout_scores: numpy array (frames,), non-negative and summing to 1;
                importance of each HuBERT frame.
            fused_attentions: list with the head-fused, row-normalized
                (frames, frames) attention of every layer (for debugging).

        Raises:
            ValueError: unknown ``head_fusion``, ``discard_ratio`` outside
                [0, 1), or no attention returned by the model.
        """
        if head_fusion not in ("mean", "max", "min"):
            raise ValueError(f"Unknown head_fusion: {head_fusion}")
        if not 0.0 <= discard_ratio < 1.0:
            raise ValueError(f"discard_ratio must be in [0, 1), got {discard_ratio}")

        input_audio = input_audio.to(self.device)

        # Forward pass that also returns the attention weights
        with torch.no_grad():
            outputs = self.model(
                input_audio,
                return_all_attentions=True,
                return_embeddings=False
            )

        hubert_attentions = outputs['hubert_attentions']
        pooling_attention = outputs.get('pooling_attention', None)

        if len(hubert_attentions) == 0:
            raise ValueError("The model returned no HuBERT attention maps")

        #
        # STEP 1: fuse the heads of each layer
        #
        fused_attentions = []
        for attn in hubert_attentions:
            attn = attn.squeeze(0)  # drop batch dim -> (num_heads, frames, frames)

            if head_fusion == "mean":
                attn_fused = attn.mean(dim=0)
            elif head_fusion == "max":
                attn_fused = attn.max(dim=0).values
            else:  # "min"
                attn_fused = attn.min(dim=0).values

            # Set the lowest weights to ZERO (the previous clamp raised them instead)
            if discard_ratio > 0:
                attn_fused = self._discard_lowest(attn_fused, discard_ratio)

            # Row normalization (each row sums to 1)
            attn_fused = attn_fused / (attn_fused.sum(dim=-1, keepdim=True) + 1e-8)
            fused_attentions.append(attn_fused)

        #
        # STEP 2: rollout, composing layers from the FIRST to the LAST:
        #   R_1 = Ã_1,  R_l = Ã_l · R_{l-1}   =>   R = Ã_L · ... · Ã_1
        #
        reference = fused_attentions[0]
        identity = torch.eye(reference.shape[0], device=reference.device, dtype=reference.dtype)

        rollout = identity
        for attn_fused in fused_attentions:
            residual_attn = (identity + attn_fused) / 2.0
            # Re-normalize: rows that were fully discarded in STEP 1 become identity rows
            residual_attn = residual_attn / (residual_attn.sum(dim=-1, keepdim=True) + 1e-8)
            rollout = torch.matmul(residual_attn, rollout)

        #
        # STEP 3: final attention pooling
        #   scores_j = sum_i p_i * R[i, j]   (matrix product, not element-wise)
        #
        if pooling_attention is not None:
            # reshape(-1) instead of squeeze(): still 1-D when there is a single frame
            weights = pooling_attention.reshape(-1).to(rollout.dtype)
            rollout_scores = torch.matmul(weights, rollout)  # (frames,)
        else:
            # Fallback: uniform weight over the output frames
            rollout_scores = rollout.mean(dim=0)

        # Normalize to a distribution (sum = 1)
        rollout_scores = rollout_scores / (rollout_scores.sum() + 1e-8)

        return rollout_scores.detach().float().cpu().numpy(), fused_attentions

    def visualize_rollout(self, audio_path, filename, sr=16000, save_dir=None):
        """
        Plot Attention Rollout over the waveform and the mel-spectrogram.

        Preprocessing: optional onset trimming from ``self.onset_map`` (key
        ``filename + '.wav'``), truncation to 30 s, then the
        ``facebook/hubert-base-ls960`` feature extractor.

        Args:
            audio_path: path of the audio file.
            filename: file name without the ``.wav`` extension (onset lookup,
                titles and output file name).
            sr: sampling rate used when loading the audio.
            save_dir: if given, the figure is saved as ``Rollout_<filename>.png``.

        Returns:
            fig: the matplotlib figure (the caller is responsible for closing it).
            rollout_scores: numpy array (frames,) returned by ``compute_rollout``.
            metrics: dict with ``entropy``, ``max_attention_time``,
                ``top5_times`` and ``attention_concentration``.
        """
        # Local imports: specshow lives in the librosa.display submodule and
        # entropy in scipy.stats; neither is imported in the visible setup cells.
        import librosa.display
        from scipy.stats import entropy

        waveform, _ = librosa.load(audio_path, sr=sr, mono=True)

        # STEP 1: ONSET TRIMMING
        trim_amount_sec = self.onset_map.get(filename + '.wav')
        if trim_amount_sec is not None:
            trim_samples = int(trim_amount_sec * sr)
            if 0 < trim_samples < len(waveform):
                waveform = waveform[trim_samples:]
                print(f"Trimmed: {trim_amount_sec:.2f}s removed")

        # STEP 2: TRUNCATION to 30 s (same value as CONFIG['max_duration'])
        max_samples = sr * 30
        original_length = len(waveform)
        if original_length > max_samples:
            waveform = waveform[:max_samples]
            print(f"Truncated: {len(waveform)/sr:.2f}s (was {original_length/sr:.2f}s)")

        # STEP 3: feature extractor, loaded once and cached on the instance
        if getattr(self, '_processor', None) is None:
            from transformers import Wav2Vec2FeatureExtractor
            self._processor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/hubert-base-ls960')

        inputs = self._processor(waveform, sampling_rate=sr, return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(self.device)

        # Plot the PROCESSED waveform so that it lines up with the model input
        waveform_processed = input_values.squeeze(0).cpu().numpy()

        rollout_scores, _ = self.compute_rollout(input_values, head_fusion="mean")
        peak = float(rollout_scores.max())

        #
        # VISUALIZATION
        #
        fig, axes = plt.subplots(3, 1, figsize=(16, 10))

        duration = input_values.shape[1] / sr
        time_axis = np.linspace(0, duration, len(rollout_scores))

        # Plot 1: waveform with the attention profile overlaid on a twin axis
        waveform_time = np.linspace(0, duration, len(waveform_processed))
        axes[0].plot(waveform_time, waveform_processed, linewidth=0.5, color='black', alpha=0.7)
        axes[0].set_ylabel('Amplitude', fontsize=11)
        axes[0].set_title(f'Waveform - {filename}', fontsize=13, fontweight='bold')
        axes[0].set_xlim(0, duration)
        axes[0].grid(True, alpha=0.3)

        ax_twin = axes[0].twinx()
        ax_twin.fill_between(time_axis, 0, rollout_scores, alpha=0.4, color='red', label='Attention Rollout')
        ax_twin.set_ylabel('Attention Weight', fontsize=11, color='red')
        ax_twin.tick_params(axis='y', labelcolor='red')
        # Guard: an all-zero profile would give a degenerate (0, 0) y-range
        ax_twin.set_ylim(0, peak * 1.2 if peak > 0 else 1.0)
        ax_twin.legend(loc='upper right')

        # Plot 2: mel-spectrogram with the attention profile as a curve
        mel_spec = librosa.feature.melspectrogram(y=waveform_processed, sr=sr, n_mels=128, hop_length=512)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        img = librosa.display.specshow(
            mel_spec_db, sr=sr, hop_length=512,
            x_axis='time', y_axis='mel', ax=axes[1], cmap='viridis'
        )
        axes[1].set_title('Mel-Spectrogram with Attention Overlay', fontsize=13, fontweight='bold')
        fig.colorbar(img, ax=axes[1], format='%+2.0f dB')

        mel_times = np.linspace(0, duration, mel_spec.shape[1])
        rollout_resampled = np.interp(mel_times, time_axis, rollout_scores)
        # The scores sum to 1 over many frames, so a constant factor left the
        # curve flat near 0 Hz: scale relative to the peak, mapped to sr/4 Hz.
        if peak > 0:
            overlay = rollout_resampled / peak * (sr / 4)
        else:
            overlay = np.zeros_like(rollout_resampled)
        axes[1].plot(mel_times, overlay,
                     color='red', linewidth=3, alpha=0.8, label='Attention Rollout')
        axes[1].legend(loc='upper right')

        # Plot 3: detailed temporal attention profile
        axes[2].plot(time_axis, rollout_scores, linewidth=2, color='darkred', marker='o',
                     markersize=3, alpha=0.7)
        axes[2].fill_between(time_axis, 0, rollout_scores, alpha=0.3, color='red')
        axes[2].set_title('Attention Rollout: Temporal Importance Profile',
                          fontsize=13, fontweight='bold')
        axes[2].set_xlabel('Time (s)', fontsize=11)
        axes[2].set_ylabel('Attention Weight', fontsize=11)
        axes[2].grid(True, alpha=0.3)
        axes[2].set_xlim(0, duration)

        # Highlight the top-k frames
        topk = 5
        top_indices = np.argsort(rollout_scores)[-topk:]
        for top_idx in top_indices:
            axes[2].axvline(time_axis[top_idx], color='orange', linestyle='--',
                            alpha=0.5, linewidth=1.5)
        axes[2].text(0.02, 0.95, f'Top-{topk} timesteps highlighted',
                     transform=axes[2].transAxes, fontsize=10,
                     verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        plt.tight_layout()

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f'Rollout_{filename}.png')
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"   Salvato: {save_path}")

        #
        # INTERPRETATION METRICS
        #
        mean_score = float(rollout_scores.mean())
        metrics = {
            # Dispersion of the attention (scipy normalizes the input to sum 1)
            'entropy': float(entropy(rollout_scores + 1e-8)),
            'max_attention_time': float(time_axis[np.argmax(rollout_scores)]),
            'top5_times': time_axis[top_indices].tolist(),
            'attention_concentration': peak / mean_score if mean_score > 0 else float('nan'),
        }

        return fig, rollout_scores, metrics

print('AttentionRollout definito')


## Recupero validation set fold 4

In [None]:
# ============================================================
# XAI OUTPUT PATH CONFIGURATION
# ============================================================
# Requires `paths` (with 'models' and 'results' entries) from the experiment
# setup cells: this cell does NOT create it. It derives the fold-4 model path,
# creates the XAI output directory and checks that the checkpoint exists.
# ============================================================

if RUN_XAI:
    # Fail early with an explicit message instead of a bare NameError below
    if 'paths' not in globals():
        raise NameError("`paths` non definito: eseguire prima le celle di setup dell'esperimento")

    # Path of the fold-4 model checkpoint
    MODEL_PATH = os.path.join(paths['models'], 'final_model', 'best_model_fold4.pth')

    # XAI output directory (sub-folder of results)
    OUTPUT_DIR = os.path.join(paths['results'], 'explainability')
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("="*60)
    print("CONFIGURAZIONE PATH XAI")
    print("="*60)
    print(f"Modello: {MODEL_PATH}")
    print(f"Output:  {OUTPUT_DIR}")
    print("="*60)
    print()

    # The checkpoint must exist before any XAI cell tries to load it
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Modello non trovato: {MODEL_PATH}")

    print("✓ Path verificati\n")
else:
    print("- Skipping XAI paths setup: RUN_XAI = False")

In [None]:
# ============================================================
# LOAD MODEL + GET PREDICTIONS ON VALIDATION SET FOLD
# ============================================================
# The model chosen for XAI (fold 4) MUST be analysed on its own
# validation fold, i.e. on recordings it was not trained on.
# ============================================================

if RUN_XAI:
    # ====================
    # STEP 1: model predictions on the validation set
    # ====================

    def get_predictions_on_valset(model, val_df, processor, device, config):
        """
        Run the model on every row of ``val_df`` and attach its predictions.

        Each file is loaded at 16 kHz, optionally onset-trimmed (global
        ``onset_map``, keyed by file name), truncated to 30 s and passed through
        the ``facebook/hubert-base-ls960`` feature extractor, the same checkpoint
        used by ``AttentionRollout.visualize_rollout``.

        Args:
            model: classifier; ``model(input_values)`` is expected to return
                logits (1, num_classes) - TODO confirm against the model's
                default forward outputs.
            val_df: DataFrame with at least ``FileName`` and ``label`` columns.
            processor: NOTE(review): not used - kept for interface compatibility;
                the hard-coded extractor above is used instead, as before.
            device: torch device used for inference.
            config: dict providing ``audio_dir``.

        Returns:
            ``val_df`` (modified in place) with the added columns ``pred_label``,
            ``confidence`` (softmax probability of the predicted class) and
            ``correct``.
        """
        model.eval()

        sample_rate = 16000
        max_samples = sample_rate * 30  # same as CONFIG['max_duration']

        # Loaded ONCE, outside the loop (it was re-created for every file)
        extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/hubert-base-ls960')

        predictions = []
        confidences = []

        for _, row in tqdm(val_df.iterrows(), total=len(val_df), desc="Predizioni validation"):
            filename = row['FileName']
            audio_path = os.path.join(config['audio_dir'], filename)
            waveform, _ = librosa.load(audio_path, sr=sample_rate, mono=True)

            # Same preprocessing as SpeechDataset
            # STEP 1: ONSET TRIMMING
            if onset_map is not None and filename in onset_map:
                trim_amount_sec = onset_map[filename]
                trim_samples = int(trim_amount_sec * sample_rate)
                if 0 < trim_samples < len(waveform):
                    waveform = waveform[trim_samples:]
                    print(f"   🔧 Trimmed: {trim_amount_sec:.2f}s removed")

            # STEP 2: TRUNCATION to 30 s
            original_length = len(waveform)
            if original_length > max_samples:
                waveform = waveform[:max_samples]
                print(f"   ✂️  Truncated: {len(waveform)/sample_rate:.2f}s (was {original_length/sample_rate:.2f}s)")

            # STEP 3: feature extraction
            inputs = extractor(
                waveform,
                sampling_rate=sample_rate,
                return_tensors="pt",
                padding=True
            )
            input_values = inputs.input_values.to(device)

            with torch.no_grad():
                logits = model(input_values)
                probs = F.softmax(logits, dim=1)
                pred_label = logits.argmax(dim=1).item()
                confidence = probs[0, pred_label].item()

            predictions.append(pred_label)
            confidences.append(confidence)

        val_df['pred_label'] = predictions
        val_df['confidence'] = confidences
        val_df['correct'] = (val_df['label'] == val_df['pred_label'])

        return val_df


    # ====================
    # Load the model
    # ====================
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = HuBERTClassifierExplainable(
        model_name=CONFIG['model_name'],
        freeze_layers=CONFIG['freeze_layers'],
        hidden_dim=CONFIG['hidden_dim'],
        dropout=CONFIG['dropout'],
        num_classes=CONFIG['num_classes'],
        pooling_type=CONFIG['pooling_type']
    ).to(device)

    MODEL_PATH = os.path.join(paths['models'], 'final_model', 'best_model_fold4.pth')

    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"✓ Modello caricato: {MODEL_PATH}")

    # ====================
    # Rebuild the FOLD 4 validation set
    # ====================
    # StratifiedKFold with the SAME seed reproduces the training split,
    # assuming `df` has the same rows/order used during cross-validation.
    skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['seed'])

    val_df_fold4 = None
    for fold_num, (_, val_idx) in enumerate(skf.split(df, df['label']), 1):
        if fold_num == 4:
            val_df_fold4 = df.iloc[val_idx].reset_index(drop=True)
            break

    if val_df_fold4 is None:
        raise ValueError(f"Fold 4 non disponibile: CONFIG['n_folds'] = {CONFIG['n_folds']}")

    print(f"\n✓ Validation set Fold 4 recuperato:")
    print(f"   Totale campioni: {len(val_df_fold4)}")
    print(f"   Controlli: {(val_df_fold4['label']==0).sum()}")
    print(f"   Pazienti: {(val_df_fold4['label']==1).sum()}")

    # ====================
    # Predictions
    # ====================
    val_df_fold4 = get_predictions_on_valset(model, val_df_fold4, processor, device, CONFIG)


    # ====================
    # STEP 2: prediction analysis
    # ====================
    print("\n" + "="*60)
    print("ANALISI PREDIZIONI VALIDATION SET FOLD 4")
    print("="*60)

    print(f"\nAccuracy: {val_df_fold4['correct'].mean():.3f}")
    print(f"Confidenza media: {val_df_fold4['confidence'].mean():.3f}")

    print("\n--- BREAKDOWN PER CLASSE ---")
    for label in [0, 1]:
        label_name = "Controllo" if label == 0 else "Paziente"
        subset = val_df_fold4[val_df_fold4['label'] == label]
        correct = subset['correct'].sum()
        total = len(subset)

        # Guard: a class absent from the fold would divide by zero below
        if total == 0:
            print(f"\n{label_name} (n=0): nessun campione")
            continue

        print(f"\n{label_name} (n={total}):")
        print(f"  Corretti: {correct}/{total} ({correct/total*100:.1f}%)")
        print(f"  Confidenza media: {subset['confidence'].mean():.3f}")


    # ====================
    # STEP 3: sample selection for XAI
    # ====================
    print("\n" + "="*60)
    print("SELEZIONE CAMPIONI PER XAI")
    print("="*60)

    # Strategy:
    # - 2 correct Controls (highest confidence)
    # - 2 correct Patients (highest confidence)
    # - 1 error (any class, for failure-mode analysis)

    def _to_sample_record(row, id_prefix):
        """Build the selection record consumed by the XAI execution cell."""
        return {
            'path': os.path.join(CONFIG['audio_dir'], row['FileName']),
            'label': int(row['label']),
            'pred_label': int(row['pred_label']),
            'confidence': row['confidence'],
            'id': f"{id_prefix}_{row['FileName'].replace('.wav', '')}",
            'filename': row['FileName']
        }

    selected_samples = []

    # 1. Correct Controls (top 2 by confidence)
    controls_correct = val_df_fold4[
        (val_df_fold4['label'] == 0) & val_df_fold4['correct']
    ].nlargest(2, 'confidence')

    control_records = [_to_sample_record(row, 'Control_correct') for _, row in controls_correct.iterrows()]
    selected_samples.extend(control_records)

    print(f"\n✓ Selezionati {len(controls_correct)} Controlli corretti")
    # Print only the records just added (selected_samples[-2:] was wrong with < 2 picks)
    for s in control_records:
        print(f"   - {s['filename']} (conf: {s['confidence']:.3f})")

    # 2. Correct Patients (top 2 by confidence)
    patients_correct = val_df_fold4[
        (val_df_fold4['label'] == 1) & val_df_fold4['correct']
    ].nlargest(2, 'confidence')

    patient_records = [_to_sample_record(row, 'Patient_correct') for _, row in patients_correct.iterrows()]
    selected_samples.extend(patient_records)

    print(f"\n✓ Selezionati {len(patients_correct)} Pazienti corretti")
    for s in patient_records:
        print(f"   - {s['filename']} (conf: {s['confidence']:.3f})")

    # 3. Errors (at most 1: the most confident, i.e. the "surest" mistake)
    errors = val_df_fold4[~val_df_fold4['correct']]

    if len(errors) > 0:
        error_sample = errors.nlargest(1, 'confidence').iloc[0]
        selected_samples.append(_to_sample_record(error_sample, 'Error'))
        print(f"\n✓ Selezionato 1 errore per failure mode analysis")
        print(f"   - {error_sample['FileName']} (true: {error_sample['label']}, pred: {error_sample['pred_label']}, conf: {error_sample['confidence']:.3f})")
    else:
        print("\n⚠️ Nessun errore trovato (accuracy 100%)")

    print(f"\n{'='*60}")
    print(f"TOTALE CAMPIONI SELEZIONATI: {len(selected_samples)}")
    print(f"{'='*60}")

    # Save the selection
    selected_df = pd.DataFrame(selected_samples)
    selected_csv_path = os.path.join(OUTPUT_DIR, 'selected_samples.csv')
    selected_df.to_csv(selected_csv_path, index=False)
    print(f"\n✓ Selezione salvata: {OUTPUT_DIR}/selected_samples.csv")
else:
    print("- Skipping XAI validation set loading: RUN_XAI = False")


## Esecuzione

In [None]:
# ============================================================
# EXPLAINABLE AI - FULL RUN
# ============================================================
# Applies 2 techniques to the selected samples:
# 1. Integrated Gradients
# 2. Attention Rollout
# Each technique is best-effort per sample: a failure is reported, its
# metrics are stored as None and the loop continues.
# ============================================================

if RUN_XAI:
    import gc

    def _fmt_metric(value, spec='.3f'):
        """Format a metric that may be missing (None/NaN) after a failed technique."""
        return format(value, spec) if pd.notna(value) else 'n/a'

    def _free_memory():
        """Release cached CUDA blocks and run the Python garbage collector."""
        torch.cuda.empty_cache()
        gc.collect()

    print("\n" + "="*60)
    print("INIZIO ANALISI EXPLAINABLE AI")
    print("="*60)
    print(f"Campioni: {len(selected_samples)}")
    print(f"Tecniche: 2 (IG, Rollout)")
    print(f"Output: {OUTPUT_DIR}")
    print("="*60)

    # Initialize the XAI techniques
    print("\nInizializzazione tecniche XAI...")
    ig = IntegratedGradients(model, device, onset_map=onset_map)
    rollout = AttentionRollout(model, device, onset_map=onset_map)
    print("✓ Tecniche inizializzate\n")

    all_results = []

    # ============================================================
    # PER-SAMPLE ANALYSIS
    # ============================================================
    for idx, sample in enumerate(selected_samples, 1):
        # Free memory before every sample
        _free_memory()

        audio_path = sample['path']
        true_label = sample['label']
        pred_label = sample['pred_label']
        confidence = sample['confidence']
        sample_id = sample['id']
        filename = sample['filename']
        file_stem = filename.replace('.wav', '')

        print(f"\n{'='*60}")
        print(f"[{idx}/{len(selected_samples)}] Analisi: {filename}")
        print(f"{'='*60}")
        print(f"Label reale: {'Paziente' if true_label==1 else 'Controllo'}")
        print(f"Predizione:  {'Paziente' if pred_label==1 else 'Controllo'}")
        print(f"Confidenza:  {confidence:.3f}")
        print(f"Corretto:    {'✓' if true_label==pred_label else '✗'}")
        print("-"*60)

        # XAI target class: the class the model actually predicted
        target_class = pred_label

        # ============================================================
        # 1. INTEGRATED GRADIENTS
        # ============================================================
        print("\n[1/2] Integrated Gradients...")
        try:
            fig_ig, _, metrics_ig = ig.visualize_on_spectrogram(
                audio_path,
                target_class,
                filename=file_stem,
                save_dir=os.path.join(OUTPUT_DIR, 'integrated_gradients')
            )
            plt.close(fig_ig)
            print("      ✓ Completato")
            del fig_ig
        except Exception as e:
            print(f"      ✗ Errore: {str(e)}")
            # A failure mid-plot would otherwise leave the figure open
            plt.close('all')
            # Same key as the success path ('completeness_error_abs')
            metrics_ig = {'completeness_error_abs': None, 'mean_attribution': None, 'max_attribution_time': None}

        _free_memory()

        # ============================================================
        # 2. ATTENTION ROLLOUT
        # ============================================================
        print("[2/2] Attention Rollout...")
        try:
            fig_rollout, _, metrics_rollout = rollout.visualize_rollout(
                audio_path,
                filename=file_stem,
                save_dir=os.path.join(OUTPUT_DIR, 'attention_rollout')
            )
            plt.close(fig_rollout)
            print("      ✓ Completato")
            del fig_rollout
        except Exception as e:
            print(f"      ✗ Errore: {str(e)}")
            plt.close('all')
            metrics_rollout = {'entropy': None, 'max_attention_time': None, 'attention_concentration': None}

        _free_memory()

        # ============================================================
        # RESULT AGGREGATION
        # ============================================================
        result = {
            'sample_id': sample_id,
            'filename': filename,
            'true_label': 'Paziente' if true_label == 1 else 'Controllo',
            'pred_label': 'Paziente' if pred_label == 1 else 'Controllo',
            'confidence': confidence,
            'correct': true_label == pred_label,

            # IG metrics
            'ig_completeness_error': metrics_ig.get('completeness_error_abs'),
            'ig_mean_attribution': metrics_ig.get('mean_attribution'),
            'ig_max_time': metrics_ig.get('max_attribution_time'),

            # Rollout metrics
            'rollout_entropy': metrics_rollout.get('entropy'),
            'rollout_max_time': metrics_rollout.get('max_attention_time'),
            'rollout_concentration': metrics_rollout.get('attention_concentration'),
        }
        all_results.append(result)

        _free_memory()

        print(f"\n✓ Campione {idx}/{len(selected_samples)} completato")

        # Short pause between samples (kept from the original pipeline)
        time.sleep(2)

    # ============================================================
    # SAVE AGGREGATED RESULTS
    # ============================================================
    print("\n" + "="*60)
    print("SALVATAGGIO RISULTATI")
    print("="*60)

    results_df = pd.DataFrame(all_results)
    results_csv_path = os.path.join(OUTPUT_DIR, 'xai_metrics_summary.csv')
    results_df.to_csv(results_csv_path, index=False)
    print(f"\n✓ Metriche salvate: {results_csv_path}")

    # ============================================================
    # CLINICAL VALIDATION: Controls vs Patients
    # ============================================================
    print("\n" + "="*60)
    print("VALIDAZIONE CLINICA: Analisi Comparativa")
    print("="*60)

    # Only correctly classified samples are compared
    correct_samples = results_df[results_df['correct'] == True]

    if len(correct_samples) > 0:
        controls = correct_samples[correct_samples['true_label'] == 'Controllo']
        patients = correct_samples[correct_samples['true_label'] == 'Paziente']

        print(f"\n📊 CONFRONTO CONTROLLI vs PAZIENTI (solo corretti)")
        print("-"*60)

        if len(controls) > 0 and len(patients) > 0:
            print(f"\nControlli (n={len(controls)}):")
            print(f"  Rollout entropy:        {controls['rollout_entropy'].mean():.3f} ± {controls['rollout_entropy'].std():.3f}")
            print(f"  Attention concentration: {controls['rollout_concentration'].mean():.3f} ± {controls['rollout_concentration'].std():.3f}")

            print(f"\nPazienti (n={len(patients)}):")
            print(f"  Rollout entropy:        {patients['rollout_entropy'].mean():.3f} ± {patients['rollout_entropy'].std():.3f}")
            print(f"  Attention concentration: {patients['rollout_concentration'].mean():.3f} ± {patients['rollout_concentration'].std():.3f}")

            # Statistical test (only with enough samples per group)
            if len(controls) >= 5 and len(patients) >= 5:
                from scipy.stats import mannwhitneyu

                try:
                    stat, p_value = mannwhitneyu(
                        controls['rollout_entropy'].dropna(),
                        patients['rollout_entropy'].dropna(),
                        alternative='two-sided'
                    )
                    print(f"\nMann-Whitney U test (entropy):")
                    print(f"  p-value = {p_value:.4f} {'✓ significativo (p<0.05)' if p_value < 0.05 else '(non significativo)'}")
                except ValueError:
                    # Raised by scipy e.g. when a group is empty after dropna()
                    print("\nTest statistico non eseguibile (campioni insufficienti)")

    # ============================================================
    # FAILURE MODE ANALYSIS (if an error was selected)
    # ============================================================
    errors = results_df[results_df['correct'] == False]

    if len(errors) > 0:
        print("\n" + "="*60)
        print("FAILURE MODE ANALYSIS")
        print("="*60)

        for _, row in errors.iterrows():
            print(f"\n Errore: {row['filename']}")
            print(f"   True: {row['true_label']}, Pred: {row['pred_label']}, Conf: {row['confidence']:.3f}")
            print(f"\n   Osservazioni XAI:")
            # Metrics can be None/NaN when Rollout failed for this sample
            print(f"   - Entropy: {_fmt_metric(row['rollout_entropy'])}")
            print(f"   - Concentration: {_fmt_metric(row['rollout_concentration'])}")
            print(f"\n   → Controllare visualizzazioni per identificare cause dell'errore")

    # ============================================================
    # FINAL SUMMARY
    # ============================================================
    print("\n" + "="*60)
    print("✓ ANALISI XAI COMPLETATA")
    print("="*60)

    print(f"\n📁 STRUTTURA OUTPUT:")
    print(f"   {OUTPUT_DIR}/")
    print(f"   ├── integrated_gradients/    ({len(selected_samples)} visualizzazioni)")
    print(f"   ├── attention_rollout/       ({len(selected_samples)} visualizzazioni)")
    print(f"   ├── xai_metrics_summary.csv")
    print(f"   └── selected_samples.csv")

    print(f"\n TOTALE VISUALIZZAZIONI: {len(selected_samples) * 2}")
    print(f"\n✓ Tutti i file salvati in: {OUTPUT_DIR}")

    print("\n" + "="*60)
    print("PROSSIMO STEP: Verifica visualizzazioni e crea slide")
    print("="*60)
else:
    print("- Skipping XAI execution: RUN_XAI = False")


In [None]:
# ============================================================
# INTER-METHOD CONSENSUS: Correlation Analysis
# ============================================================
# Checks agreement between the IG and Rollout attributions using the
# per-sample metrics read from OUTPUT_DIR/xai_metrics_summary.csv
# (presumably written by the preceding XAI cell — confirm).
# Global names (metrics_df, valid_samples, corr_matrix, avg_corr, ...)
# are kept as in earlier versions because later cells may reuse them.
# ============================================================

if RUN_XAI:
    import numpy as np
    from scipy.stats import pearsonr, spearmanr
    import matplotlib.pyplot as plt
    import seaborn as sns

    print("="*60)
    print("CONSENSO INTER-METODO: Correlation Analysis")
    print("="*60)

    # Load the saved per-sample metrics.
    metrics_df = pd.read_csv(os.path.join(OUTPUT_DIR, 'xai_metrics_summary.csv'))

    # Keep only samples where both methods produced results. Every metric
    # column used below must be non-null: a single NaN in 'rollout_max_time'
    # would otherwise make pearsonr return NaN for the whole analysis.
    required_cols = ['ig_max_time', 'rollout_max_time', 'rollout_entropy']
    valid_samples = metrics_df.dropna(subset=required_cols)

    print(f"\nCampioni con entrambi i metodi: {len(valid_samples)}/{len(metrics_df)}")

    if len(valid_samples) >= 3:
        # ============================================================
        # METRIC 1: Agreement on peak times
        # ============================================================
        print("\n" + "-"*60)
        print("CONCORDANZA TEMPORALE (Peak Times)")
        print("-"*60)

        # Peak time per sample for each method.
        # NOTE(review): a previous comment called these "normalized by
        # duration", but the printouts below format them as seconds —
        # confirm the unit against the step that writes the CSV.
        peak_times = {
            'IG': valid_samples['ig_max_time'].to_numpy(),
            'Rollout': valid_samples['rollout_max_time'].to_numpy(),
        }
        methods = ['IG', 'Rollout']

        # With only two methods there is a single off-diagonal correlation;
        # the 2x2 matrix is kept for downstream cells that may read it.
        corr, pval = pearsonr(peak_times['IG'], peak_times['Rollout'])
        corr_matrix = np.array([[1.0, corr], [corr, 1.0]])

        if np.isnan(corr):
            # pearsonr yields NaN when either input is constant (its warning
            # is suppressed by the notebook-wide warnings filter).
            print("IG vs Rollout: r non definito (peak times costanti per almeno un metodo)")
        else:
            sig = "***" if pval < 0.001 else "**" if pval < 0.01 else "*" if pval < 0.05 else "ns"
            print(f"IG vs Rollout: r = {corr:.3f} (p = {pval:.4f}) {sig}")

            # Rank-based agreement, less sensitive to outlying peak times.
            rho, rho_pval = spearmanr(peak_times['IG'], peak_times['Rollout'])
            print(f"IG vs Rollout (Spearman): rho = {rho:.3f} (p = {rho_pval:.4f})")

        # Print the correlation.
        corr_ig_rollout = corr_matrix[0, 1]
        print(f"\nCorrelazione IG vs Rollout: r = {corr_ig_rollout:.3f}")

        # ============================================================
        # METRIC 2: Patterns on correct predictions vs errors
        # ============================================================
        print("\n" + "-"*60)
        print("PATTERN SU CORRETTI vs ERRORI")
        print("-"*60)

        # .eq(True)/.eq(False) match the exact semantics of `== True/False`
        # (NaN in 'correct' belongs to neither group).
        correct_samples = valid_samples[valid_samples['correct'].eq(True)]
        error_samples = valid_samples[valid_samples['correct'].eq(False)]

        if len(correct_samples) > 0:
            print(f"\nCorretti (n={len(correct_samples)}):")
            print(f"  IG max time:         {correct_samples['ig_max_time'].mean():.2f}s ± {correct_samples['ig_max_time'].std():.2f}s")
            print(f"  Rollout entropy:     {correct_samples['rollout_entropy'].mean():.3f} ± {correct_samples['rollout_entropy'].std():.3f}")

        if len(error_samples) > 0:
            print(f"\nErrori (n={len(error_samples)}):")
            print(f"  IG max time:         {error_samples['ig_max_time'].mean():.2f}s ± {error_samples['ig_max_time'].std():.2f}s")
            print(f"  Rollout entropy:     {error_samples['rollout_entropy'].mean():.3f} ± {error_samples['rollout_entropy'].std():.3f}")

        # ============================================================
        # METRIC 3: Controls vs patients (correct predictions only)
        # ============================================================
        print("\n" + "-"*60)
        print("PATTERN CONTROLLI vs PAZIENTI (solo corretti)")
        print("-"*60)

        controls = correct_samples[correct_samples['true_label'] == 'Controllo']
        patients = correct_samples[correct_samples['true_label'] == 'Paziente']

        if len(controls) > 0 and len(patients) > 0:
            control_entropy = controls['rollout_entropy'].mean()
            patient_entropy = patients['rollout_entropy'].mean()
            control_time = controls['rollout_max_time'].mean()
            patient_time = patients['rollout_max_time'].mean()

            print(f"\nControlli (n={len(controls)}):")
            print(f"  Rollout entropy:    {control_entropy:.3f}")
            print(f"  Rollout peak time:  {control_time:.2f}s")

            print(f"\nPazienti (n={len(patients)}):")
            print(f"  Rollout entropy:    {patient_entropy:.3f}")
            print(f"  Rollout peak time:  {patient_time:.2f}s")

            # Absolute differences (entropy also as % of the control mean).
            entropy_diff = abs(patient_entropy - control_entropy)
            time_diff = abs(patient_time - control_time)

            print("\nDifferenze:")
            if control_entropy != 0:
                print(f"  Entropy: {entropy_diff:.3f} ({entropy_diff/control_entropy*100:.1f}%)")
            else:
                # Relative difference is undefined when the control mean is 0.
                print(f"  Entropy: {entropy_diff:.3f}")
            print(f"  Peak time: {time_diff:.2f}s")

        # ============================================================
        # INTERPRETATION
        # ============================================================
        print("\n" + "="*60)
        print("INTERPRETAZIONE")
        print("="*60)

        # Single IG-vs-Rollout correlation (name kept for compatibility;
        # it is not an average).
        avg_corr = corr_matrix[0, 1]

        print(f"\nCorrelazione inter-metodo: {avg_corr:.3f}")

        if np.isnan(avg_corr):
            # NaN must not be reported as "low agreement": it is undefined.
            print("    Concordanza non valutabile: correlazione non definita")
        elif avg_corr > 0.7:
            print("✓ ALTA concordanza tra metodi (r > 0.7)")
            print("  → I 2 metodi identificano regioni temporali simili")
            print("  → Maggior affidabilità delle attribution")
        elif avg_corr > 0.4:
            print("◐ MEDIA concordanza tra metodi (0.4 < r < 0.7)")
            print("  → I metodi catturano aspetti parzialmente sovrapposti")
            print("  → Interpretare con cautela le differenze")
        else:
            print("    BASSA concordanza tra metodi (r < 0.4)")
            print("  → I metodi catturano aspetti diversi")

        print("\n" + "="*60)

    else:
        print("\n    Troppi pochi campioni validi per correlation analysis")
        print(f"   Serve almeno 3 campioni, trovati: {len(valid_samples)}")
else:
    print("- Skipping XAI consensus analysis: RUN_XAI = False")