# 🧠 WavLM Weighted Layer Sum Training

Questo notebook addestra WavLM con **Weighted Layer Sum**, un'architettura SOTA che combina tutti i 25 hidden states del Transformer con pesi apprendibili.

**Vantaggi:**
- Layer bassi: informazioni acustiche (formanti, pitch)
- Layer alti: informazioni fonetiche/semantiche
- Pesi apprendibili: il modello impara la combinazione ottimale

## 1. Setup Ambiente

In [None]:
# 1.1 Verifica GPU
!nvidia-smi

import torch
print(f"\n{'='*50}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA disponibile: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# 1.2 Mount Google Drive (Colab only).
# Later cells read the project zip from /content/drive/MyDrive and write the
# training outputs back under the same mount point.
from google.colab import drive
drive.mount('/content/drive')
print("‚úÖ Drive montato")

In [None]:
# 1.3 Estrai progetto da zip
import os
import zipfile
from pathlib import Path

ZIP_PATH = '/content/drive/MyDrive/phonemeRef.zip'
EXTRACT_PATH = '/content/DeepLearning-Phoneme'

if not os.path.exists(ZIP_PATH):
    raise FileNotFoundError(f"‚ùå File non trovato: {ZIP_PATH}\nCarica phonemeRef.zip su Google Drive")

print(f"üì¶ Estrazione {ZIP_PATH}...")
with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
    zip_ref.extractall('/content/')

# Trova cartella estratta
extracted = [f for f in os.listdir('/content/') if os.path.isdir(f'/content/{f}') and 'Phoneme' in f]
if extracted:
    EXTRACT_PATH = f'/content/{extracted[0]}'

os.chdir(EXTRACT_PATH)
print(f"‚úÖ Progetto in: {EXTRACT_PATH}")
!ls -la

In [None]:
# 1.4 Install dependencies
# Both commands run pip in quiet mode (-q); torchcodec is installed by its own
# command, separately from the main package list.
!pip install -q transformers datasets evaluate jiwer accelerate soundfile librosa pyyaml tqdm audiomentations
!pip install -q torchcodec
print("\n‚úÖ Dipendenze installate")

## 2. Preparazione Dataset

In [None]:
# 2.1 Load the dataset and print a quick overview
import pandas as pd
from pathlib import Path

# Candidate CSV files in priority order: the first one found on disk is used.
DATASET_OPTIONS = [
    'data/processed/combined_augmented.csv',
    'data/processed/combined_dataset.csv',
    'data/processed/phonemeref_processed.csv',
]

DATASET_CSV = next(
    (candidate for candidate in DATASET_OPTIONS if Path(candidate).exists()),
    None,
)
if DATASET_CSV is None:
    raise FileNotFoundError("‚ùå Nessun dataset trovato!")

df = pd.read_csv(DATASET_CSV)
available_columns = set(df.columns)

print(f"üìä Dataset: {DATASET_CSV}")
print(f"   Samples: {len(df):,}")
print(f"\n=== Distribuzione ===")
# Both breakdowns are optional: they are printed only when the column exists.
if 'source' in available_columns:
    print(df['source'].value_counts())
if 'is_correct' in available_columns:
    print(f"\n=== Corretti vs Errori ===")
    print(df['is_correct'].value_counts())

In [None]:
# 2.2 IPA quality check: flag placeholder, annotated, too-short and missing
# transcriptions, then save a cleaned copy of the dataset.
import pandas as pd
import json
import re

df = pd.read_csv(DATASET_CSV)

# pandas reads an empty 'ipa_clean' cell as NaN. On the raw column
# `.str.len() < 2` evaluates to False for NaN, which let rows without any
# transcription through as "valid". Working on a NaN-free string view makes
# a missing transcription count as length 0, i.e. too short.
ipa = df['ipa_clean'].fillna('').astype(str)

# 1. Placeholder transcriptions of the form [word]
placeholder_mask = ipa.str.contains(r'^\[.*\]$', regex=True, na=False)

# 2. Dictionary annotations leaked into the transcription (adj., n., v., ...)
annotation_mask = ipa.str.contains(
    r'adj\.|n\.|v\.|adv\.|interj\.|for \d|unstressed|stressed|esp\.|also|Brit\.|;',
    regex=True, na=False
)

# 3. Transcriptions shorter than 2 characters (missing ones included)
short_mask = ipa.str.len() < 2

invalid_mask = placeholder_mask | annotation_mask | short_mask
invalid_count = invalid_mask.sum()
# Guard the percentage so that an empty CSV reports 0.0% instead of raising
# ZeroDivisionError.
invalid_pct = 100 * invalid_count / len(df) if len(df) else 0.0

print(f"üîç Analisi qualit√† IPA:")
print(f"   Totale samples: {len(df):,}")
print(f"   IPA placeholder [word]: {placeholder_mask.sum():,}")
print(f"   IPA con annotazioni (adj., v., etc.): {annotation_mask.sum():,}")
print(f"   IPA troppo corti (<2): {short_mask.sum():,}")
print(f"   Totale invalidi: {invalid_count:,} ({invalid_pct:.1f}%)")

if invalid_count > 0:
    print(f"\n‚ö†Ô∏è ATTENZIONE: {invalid_count} samples hanno IPA problematici!")

    # Show a few offending rows
    print("\n   Esempi di IPA invalidi:")
    examples = df[invalid_mask][['word', 'ipa_clean']].head(10)
    for _, row in examples.iterrows():
        print(f"   - {row['word']}: '{row['ipa_clean']}'")

    # Drop the invalid rows and point the rest of the notebook at the clean file
    df_clean = df[~invalid_mask].copy()
    DATASET_CLEAN = 'data/processed/phonemeref_clean.csv'
    df_clean.to_csv(DATASET_CLEAN, index=False)
    print(f"\n‚úÖ Dataset pulito salvato: {DATASET_CLEAN}")
    print(f"   Samples validi: {len(df_clean):,}")
    DATASET_CSV = DATASET_CLEAN
else:
    print("\n‚úÖ Tutti gli IPA sono validi!")

In [None]:
# 2.3 Fix audio paths and drop rows whose audio file is missing
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# Reload from DATASET_CSV, which the previous cell may have re-pointed to the
# cleaned CSV it wrote.
df = pd.read_csv(DATASET_CSV)

def fix_path(path_str):
    """Map a dataset audio path (possibly a Windows one) to a project-relative path.

    The value is stringified, backslashes become forward slashes, and the
    first matching rule below decides the result:

    1. starts with ``data/``   -> returned unchanged;
    2. starts with ``audio/``  -> prefixed with ``data/raw/phonemeref_data/``;
    3. contains ``/audio/``    -> everything before the first ``/audio/`` is
       replaced by ``data/raw/phonemeref_data``;
    4. contains ``data/``      -> cut so that it starts at the first ``data/``;
    5. otherwise               -> returned as normalized.

    NOTE(review): rule 3 fires before rule 4, so an absolute path containing
    both ``data/`` and ``/audio/`` is always re-rooted onto
    ``data/raw/phonemeref_data``, whatever folder it originally lived in.
    Confirm this is intended for audio stored outside that folder.
    """
    normalized = str(path_str).replace('\\', '/')
    dataset_root = 'data/raw/phonemeref_data'

    # Rule 1: already project-relative.
    if normalized.startswith('data/'):
        return normalized

    # Rule 2: relative to the dataset root.
    if normalized.startswith('audio/'):
        return f'{dataset_root}/{normalized}'

    # Rule 3: keep only what follows the first '/audio/' component.
    _, audio_sep, after_audio = normalized.partition('/audio/')
    if audio_sep:
        return f'{dataset_root}/audio/{after_audio}'

    # Rule 4: strip everything in front of the first 'data/'.
    _, data_marker, after_data = normalized.partition('data/')
    if data_marker:
        return data_marker + after_data

    # Rule 5: nothing recognizable, leave the normalized path alone.
    return normalized

# Normalize every audio path so it resolves from the project root
df['audio_path'] = df['audio_path'].apply(fix_path)

# === DROP ROWS WHOSE AUDIO FILE IS MISSING ===
print("üîç Verifica esistenza file audio...")
missing_files = []
existing_mask = []

# Iterate over the two needed columns directly: df.iterrows() builds a full
# Series per row, which is needlessly slow when only 'word' and 'audio_path'
# are read. '?' stands in for the word when the column is absent.
words = df['word'] if 'word' in df.columns else ['?'] * len(df)
for word, audio_path in tqdm(zip(words, df['audio_path']), total=len(df), desc="Checking files"):
    exists = Path(audio_path).exists()
    existing_mask.append(exists)
    if not exists:
        missing_files.append((word, audio_path))

# Explicit bool dtype keeps the mask usable for indexing even when it is empty.
existing_mask = pd.Series(existing_mask, index=df.index, dtype=bool)
n_missing = len(missing_files)
n_total = len(df)
n_existing = n_total - n_missing


def _pct(count):
    """Return count as a percentage of n_total (0.0 when the dataset is empty)."""
    return 100 * count / n_total if n_total else 0.0


print(f"\nüìä Risultato verifica:")
print(f"   Totale samples: {n_total:,}")
print(f"   File esistenti: {n_existing:,} ({_pct(n_existing):.1f}%)")
print(f"   File mancanti: {n_missing:,} ({_pct(n_missing):.1f}%)")

if n_missing > 0:
    print(f"\n‚ö†Ô∏è Esempi file mancanti:")
    for word, path in missing_files[:10]:
        print(f"   ‚ùå {word}: {path}")

    # Keep only the rows whose audio file is actually on disk
    df_clean = df[existing_mask].copy()
    print(f"\n‚úÖ Rimossi {n_missing} samples con file mancanti")
    print(f"   Dataset finale: {len(df_clean):,} samples")
    df = df_clean
else:
    print("\n‚úÖ Tutti i file audio esistono!")

# Fail fast instead of writing an empty "ready" dataset: an empty CSV would
# only surface later as an obscure error inside the training script.
if df.empty:
    raise ValueError(
        "Nessun sample con file audio esistente: controlla la colonna "
        "'audio_path' e la cartella di estrazione del progetto"
    )

# Final distribution check
if 'source' in df.columns:
    print(f"\nüìä Distribuzione finale:")
    print(df['source'].value_counts())

# Save and point the rest of the notebook at the ready-to-train CSV
DATASET_FINAL = 'data/processed/phonemeref_ready.csv'
df.to_csv(DATASET_FINAL, index=False)
print(f"\n‚úÖ Dataset pronto: {DATASET_FINAL}")
DATASET_CSV = DATASET_FINAL

In [None]:
# 2.4 Sanity-check vocab.json: every non-special symbol should look like IPA
import json
from pathlib import Path

vocab_path = Path('data/processed/vocab.json')
if vocab_path.exists():
    with open(vocab_path, encoding='utf-8') as f:
        vocab = json.load(f)

    print(f"üìä Vocab: {len(vocab)} simboli")

    # Special tokens expected in the vocabulary
    special = ['[PAD]', '[UNK]', '|']

    # Stress/length/rhoticity marks and punctuation accepted next to letters.
    # Written as Unicode escapes (U+02C8 primary stress, U+02CC secondary
    # stress, U+02D0 length mark, U+02B3 superscript r) so the comparison
    # cannot be broken by the encoding this file is saved or exported with.
    ipa_marks = {'\u02c8', '\u02cc', '\u02d0', '\u02b3', "'", '-', ' '}

    non_ipa = []
    ipa_chars = []
    for char in vocab:
        if char in special:
            continue
        # A valid symbol is a single letter (ASCII or not) or one of the marks.
        # The length check matters: the previous `char.lower() in 'abc...z'`
        # was a substring test, so '' and multi-character tokens such as 'no'
        # were silently accepted as letters.
        if (len(char) == 1 and char.isalpha()) or char in ipa_marks:
            ipa_chars.append(char)
        else:
            non_ipa.append(char)

    print(f"\n   Caratteri speciali: {special}")
    print(f"   Caratteri IPA: {len(ipa_chars)}")

    if non_ipa:
        print(f"\n   ‚ö†Ô∏è Caratteri sospetti: {non_ipa}")
    else:
        print(f"\n   ‚úÖ Tutti i caratteri sembrano IPA validi")

    # Positions 3..14 are shown as a sample (index 3 skips the first three entries)
    print(f"\n   Esempio simboli: {list(vocab)[3:15]}...")
else:
    raise FileNotFoundError("‚ùå vocab.json non trovato!")

## 3. Configurazione Training

In [None]:
# 3.1 Training configuration (tuned for a Tesla T4)
import yaml
import os

# === MAIN CONFIGURATION ===
# Checkpoints and the final model are written straight to Drive.
DRIVE_OUTPUT_DIR = '/content/drive/MyDrive/phoneme_wavlm_weighted'
# Location the training script is later pointed at via --config.
CONFIG_PATH = 'configs/training_config_weighted.yaml'

config = {
    'seed': 42,
    'model': {
        'name': 'microsoft/wavlm-large',
        'freeze_feature_encoder': True
    },
    'data': {
        'csv_path': DATASET_CSV,
        'vocab_path': 'data/processed/vocab.json',
        'audio_base_path': '.',
        'val_size': 0.05,
        'test_size': 0.05,
        'sampling_rate': 16000
    },
    'training': {
        'output_dir': DRIVE_OUTPUT_DIR,
        'num_train_epochs': 10,
        'per_device_train_batch_size': 8,
        'per_device_eval_batch_size': 8,
        'gradient_accumulation_steps': 2,
        'dataloader_num_workers': 0,
        'dataloader_pin_memory': False,
        'learning_rate': 3e-5,
        'warmup_steps': 500,
        'weight_decay': 0.01,
        'optim': 'adamw_torch',
        'max_grad_norm': 1.0,
        'fp16': True,
        'bf16': False,
        'eval_strategy': 'epoch',
        'save_strategy': 'epoch',
        'save_total_limit': 3,
        'load_best_model_at_end': True,
        # Lower PER is better, hence greater_is_better=False.
        'metric_for_best_model': 'per',
        'greater_is_better': False,
        'logging_steps': 100,
        'disable_tqdm': False,
        'group_by_length': True,
    }
}

os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
# open(..., 'w') does not create parent folders: make sure configs/ exists so
# the cell also works when the extracted project ships without that folder.
os.makedirs(os.path.dirname(CONFIG_PATH), exist_ok=True)

# Save the config where the training script expects it
with open(CONFIG_PATH, 'w', encoding='utf-8') as f:
    yaml.dump(config, f, default_flow_style=False)

print("="*60)
print("üìã CONFIGURAZIONE WAVLM WEIGHTED (LARGE)")
print("="*60)
print(f"üìÅ Output: {DRIVE_OUTPUT_DIR}")
print(f"üìä Dataset: {DATASET_CSV}")
print(f"üî¢ Epochs: {config['training']['num_train_epochs']}")
print(f"üì¶ Batch: {config['training']['per_device_train_batch_size']} x {config['training']['gradient_accumulation_steps']}")
print(f"üìà LR: {config['training']['learning_rate']}")
print("="*60)

In [None]:
# 3.2 Inspect the checkpoints already present in the output folder
from pathlib import Path
import json


def _checkpoint_step(checkpoint_dir):
    """Return the step number in a 'checkpoint-<step>' folder name (-1 if not numeric)."""
    suffix = checkpoint_dir.name.rsplit('-', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


output_dir = Path(DRIVE_OUTPUT_DIR)
checkpoints = []

if output_dir.exists():
    # Sort by numeric step: a plain name sort is lexicographic and would place
    # 'checkpoint-1000' before 'checkpoint-500', so "the last three" below
    # would not be the most recent ones.
    checkpoints = sorted(
        (
            d for d in output_dir.iterdir()
            if d.is_dir() and d.name.startswith("checkpoint-")
        ),
        key=lambda d: (_checkpoint_step(d), d.name),
    )

print(f"üìÅ Output: {output_dir}")
print("-"*50)

if checkpoints:
    print(f"‚úÖ {len(checkpoints)} checkpoint trovati:")

    last_epoch = 0
    best_per = None

    # Only the three most recent checkpoints are inspected.
    for cp in checkpoints[-3:]:
        state_file = cp / "trainer_state.json"
        if state_file.exists():
            with open(state_file, encoding='utf-8') as f:
                state = json.load(f)
            epoch = state.get('epoch', 0)
            step = state.get('global_step', 0)
            best = state.get('best_metric', None)

            last_epoch = max(last_epoch, epoch)
            # Compare against None explicitly: a best metric of 0.0 is a
            # legitimate value and must not be treated as "missing".
            if best is not None:
                best_per = best

            info = f"Epoch {epoch:.1f}, Step {step}"
            if best is not None:
                info += f", Best PER: {best:.4f}"
            print(f"   üìÅ {cp.name}: {info}")

    target_epochs = config['training']['num_train_epochs']
    if last_epoch >= target_epochs:
        print(f"\n‚ö†Ô∏è TRAINING GI√Ä COMPLETATO! (epoch {last_epoch} >= {target_epochs})")
    else:
        print(f"\n‚úÖ Training pu√≤ continuare per {target_epochs - last_epoch:.0f} epoche")
else:
    print("‚ùå Nessun checkpoint - Training partir√† da zero")

## 4. Training

In [None]:
# 4.1 Avvia Training con script train_weighted.py
import os
from pathlib import Path

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# === OPZIONI ===
RESUME = "auto"

drive_path = Path(DRIVE_OUTPUT_DIR)
existing_checkpoints = []
if drive_path.exists():
    existing_checkpoints = sorted([
        d for d in drive_path.iterdir() 
        if d.is_dir() and d.name.startswith("checkpoint-")
    ])

if RESUME == "auto":
    do_resume = len(existing_checkpoints) > 0
else:
    do_resume = bool(RESUME)

print("="*60)
print("üöÄ AVVIO TRAINING WAVLM WEIGHTED (LARGE)")
print("="*60)
print(f"üìä Dataset: {DATASET_CSV}")
print(f"üìÅ Output: {DRIVE_OUTPUT_DIR}")
print(f"üîÑ Resume: {do_resume}")
print("="*60)

# Comando
cmd = f"python scripts/training/train_weighted.py --config configs/training_config_weighted.yaml --data-csv {DATASET_CSV}"
if do_resume:
    cmd += " --resume"

!{cmd}

## 5. Valutazione

In [None]:
# 5.1 Plot the training curves recorded in trainer_state.json
import json
import matplotlib.pyplot as plt
from pathlib import Path


def _checkpoint_step(checkpoint_dir):
    """Return the step number in a 'checkpoint-<step>' folder name (-1 if not numeric)."""
    suffix = checkpoint_dir.name.rsplit('-', 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


# Look for trainer_state.json: final model first, then the output root
state_path = None
for loc in [
    Path(DRIVE_OUTPUT_DIR) / 'final_model_weighted' / 'trainer_state.json',
    Path(DRIVE_OUTPUT_DIR) / 'trainer_state.json',
]:
    if loc.exists():
        state_path = loc
        break

# Fall back to the most recent checkpoint
if not state_path:
    # Order by numeric step: a lexicographic name sort would rank
    # 'checkpoint-500' after 'checkpoint-1000' and plot a stale state.
    checkpoints = sorted(
        (
            d for d in Path(DRIVE_OUTPUT_DIR).iterdir()
            if d.is_dir() and d.name.startswith("checkpoint-")
        ),
        key=lambda d: (_checkpoint_step(d), d.name),
    ) if Path(DRIVE_OUTPUT_DIR).exists() else []
    if checkpoints:
        state_path = checkpoints[-1] / 'trainer_state.json'

if state_path and state_path.exists():
    with open(state_path, encoding='utf-8') as f:
        state = json.load(f)

    log_history = state.get('log_history', [])

    # Extract (step, value) series; entries carrying 'eval_loss' are evaluation
    # logs and are excluded from the training-loss series.
    train_loss = [(h['step'], h['loss']) for h in log_history if 'loss' in h and 'eval_loss' not in h]
    eval_loss = [(h['step'], h['eval_loss']) for h in log_history if 'eval_loss' in h]
    eval_per = [(h['step'], h['eval_per']) for h in log_history if 'eval_per' in h]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    if train_loss:
        steps, losses = zip(*train_loss)
        axes[0].plot(steps, losses, 'b-', alpha=0.7)
        axes[0].set_xlabel('Step')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training Loss')
        axes[0].grid(True, alpha=0.3)

    if eval_loss:
        steps, losses = zip(*eval_loss)
        axes[1].plot(steps, losses, 'r-o')
        axes[1].set_xlabel('Step')
        axes[1].set_ylabel('Eval Loss')
        axes[1].set_title('Validation Loss')
        axes[1].grid(True, alpha=0.3)

    if eval_per:
        steps, pers = zip(*eval_per)
        # PER is logged as a fraction; plot it as a percentage.
        axes[2].plot(steps, [p*100 for p in pers], 'g-o')
        axes[2].set_xlabel('Step')
        axes[2].set_ylabel('PER (%)')
        axes[2].set_title('Phoneme Error Rate')
        axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{DRIVE_OUTPUT_DIR}/training_curves.png', dpi=150)
    plt.show()

    if eval_per:
        # Computed from eval_per itself rather than from a variable that only
        # exists because an earlier `if` branch happened to run.
        best_per = min(per for _, per in eval_per)
        print(f"\nüèÜ Migliore PER: {best_per*100:.2f}%")
else:
    print("‚ùå trainer_state.json non trovato - training non ancora completato?")

In [None]:
# 5.2 Valutazione su SpeechOcean762
MODEL_PATH = f"{DRIVE_OUTPUT_DIR}/final_model_weighted"

if Path(MODEL_PATH).exists():
    print(f"üî¨ Valutazione modello: {MODEL_PATH}")
    !python scripts/evaluation/evaluate_speechocean.py --model-path {MODEL_PATH}
else:
    print(f"‚ö†Ô∏è Modello non trovato: {MODEL_PATH}")
    print("   Esegui prima il training!")

In [None]:
# 5.3 Layer-weight analysis (which hidden layers matter most after training)
import torch
import torch.nn.functional as F
from pathlib import Path

MODEL_PATH = f"{DRIVE_OUTPUT_DIR}/final_model_weighted"

try:
    # Load the saved weights on CPU only to read the learned layer weights
    checkpoint = torch.load(f"{MODEL_PATH}/pytorch_model.bin", map_location='cpu')

    # Prefer the exact key; otherwise accept the same parameter saved under a
    # module prefix ('<prefix>.layer_weights').
    # NOTE(review): the prefix fallback assumes at most one such parameter
    # exists in the checkpoint - confirm against train_weighted.py.
    if 'layer_weights' in checkpoint:
        weight_key = 'layer_weights'
    else:
        weight_key = next(
            (key for key in checkpoint if key.endswith('.layer_weights')),
            None,
        )

    if weight_key is not None:
        weights = checkpoint[weight_key]
        # as_tensor avoids re-wrapping an existing tensor (torch.tensor(tensor)
        # copies and emits a UserWarning); float32 keeps softmax usable on CPU
        # even if the checkpoint was saved in half precision.
        normalized = F.softmax(torch.as_tensor(weights).detach().float(), dim=0)

        print("üìä LAYER WEIGHTS (dopo training)")
        print("="*50)
        for i, w in enumerate(normalized.tolist()):
            # One bar character per 2% of softmax mass
            bar = "‚ñà" * int(w * 50)
            print(f"Layer {i:2d}: {w:.4f} {bar}")

        print(f"\nüìä Layer pi√π importante: {normalized.argmax().item()}")
    else:
        print("‚ö†Ô∏è layer_weights non trovato nel checkpoint")
except Exception as e:
    # Best-effort inspection cell: report the problem instead of stopping the notebook
    print(f"‚ö†Ô∏è Errore caricamento: {e}")

## 6. Salvataggio Finale

In [None]:
# 6.1 List what has been saved on Drive
from pathlib import Path

print("="*60)
print("üìÅ CONTENUTO SU GOOGLE DRIVE")
print("="*60)
print(f"Cartella: {DRIVE_OUTPUT_DIR}")
print("-"*60)

drive_path = Path(DRIVE_OUTPUT_DIR)
if drive_path.exists():
    for item in sorted(drive_path.iterdir()):
        if item.is_dir():
            # Count regular files only: rglob("*") also yields sub-directories,
            # which inflated the number printed under the "files" label.
            n_files = sum(1 for entry in item.rglob("*") if entry.is_file())
            print(f"  üìÅ {item.name}/ ({n_files} files)")
        else:
            size_mb = item.stat().st_size / 1e6
            print(f"  üìÑ {item.name} ({size_mb:.1f} MB)")

    final_model = drive_path / "final_model_weighted"
    if final_model.exists():
        print("\n‚úÖ Modello finale presente!")
    else:
        print("\n‚ö†Ô∏è Modello finale non trovato")
else:
    print("‚ùå Cartella non trovata")

In [None]:
# 6.2 Create a zip of the final model for download
import os

FINAL_MODEL = f'{DRIVE_OUTPUT_DIR}/final_model_weighted'
# NOTE: this rebinds ZIP_PATH, the same name cell 1.3 uses for the project
# archive; cell 1.3 assigns it again at its top, so re-running it is safe.
ZIP_PATH = f'{DRIVE_OUTPUT_DIR}/final_model_weighted.zip'

if os.path.exists(FINAL_MODEL):
    # cd into the model folder first so the archive stores paths relative to it
    !cd {FINAL_MODEL} && zip -r {ZIP_PATH} .
    print(f"\n‚úÖ Zip creato: {ZIP_PATH}")
    # Show the resulting archive with a human-readable size
    !ls -lh {ZIP_PATH}
else:
    print("‚ùå Modello finale non trovato")

---
## 🎉 Fine

Il modello è salvato su Google Drive:
- `final_model_weighted/` - Modello trainato
- `final_model_weighted.zip` - Per download rapido
- `training_curves.png` - Grafici
- `checkpoint-*/` - Checkpoint intermedi