#  Notebook 4: Évaluation du Modèle

Ce notebook évalue les performances du modèle VoiceGAN entraîné.

**Objectifs :**
- Charger un modèle entraîné
- Calculer des métriques objectives (MCD, similarité de style)
- Visualiser des exemples de conversion
- Analyser la qualité
- Effectuer des comparaisons qualitatives

In [None]:
# Imports
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import librosa.display
import IPython.display as ipd
from pathlib import Path
from tqdm.notebook import tqdm
import pandas as pd

from config.model_config import Config
from src.models.voicegan import VoiceGAN
from src.evaluation.metrics import VoiceConversionMetrics
from src.preprocessing.audio_processor import AudioProcessor
from src.preprocessing.mel_spectrogram import MelSpectrogramProcessor

%matplotlib inline

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

## 1. Configuration

In [None]:
# Input locations (adjust CHECKPOINT_PATH to the checkpoint you want to evaluate).
CONFIG_PATH = '../config/config.yaml'
CHECKPOINT_PATH = '../checkpoints/best_model.pt'
TEST_DIR = Path('../data/test')

# Project configuration object built from the YAML file.
config = Config(CONFIG_PATH)

for label, value in (
    ("Config loaded from", CONFIG_PATH),
    ("Checkpoint", CHECKPOINT_PATH),
    ("Test directory", TEST_DIR),
):
    print(f"{label}: {value}")

## 2. Charger le Modèle

In [None]:
def load_model(checkpoint_path, config, device):
    """Build a VoiceGAN from *config* and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Path of the file produced by ``torch.save``; it must be
            a mapping containing a ``'model_state_dict'`` entry.
        config: Project ``Config`` providing the ``audio``, ``content_encoder``
            and ``style_encoder`` sections used for the architecture.
        device: Target device for the model and for tensor remapping on load.

    Returns:
        ``(model, checkpoint)``: the model moved to *device* and put in eval
        mode, plus the raw loaded checkpoint mapping.
    """
    audio_cfg = config.audio
    content_cfg = config.content_encoder

    model = VoiceGAN(
        n_mels=audio_cfg.n_mels,
        content_channels=content_cfg.channels,
        transformer_dim=content_cfg.transformer_dim,
        num_heads=content_cfg.num_heads,
        num_transformer_layers=content_cfg.num_layers,
        style_dim=config.style_encoder.style_dim,
    )

    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; if the
    # checkpoint stores arbitrary Python objects this call may fail — confirm.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # nn.Module.to() and .eval() both return the module itself.
    model = model.to(device).eval()
    return model, state

print("Loading model...")
model, checkpoint = load_model(CHECKPOINT_PATH, config, device)

# Training progress saved with the weights; keys may be absent in older checkpoints.
print("\n Model loaded!")
for key, label in (('epoch', 'Epoch'), ('global_step', 'Global step')):
    print(f"  {label}: {checkpoint.get(key, 'unknown')}")

## 3. Initialiser Processeurs et Métriques

In [None]:
# Build the audio front-end and the metric calculator from the audio section of the config.
audio_cfg = config.audio

# Waveform loading and fixed-length segmentation.
audio_processor = AudioProcessor(
    sample_rate=audio_cfg.sample_rate,
    segment_length=audio_cfg.segment_length,
)

# Waveform <-> mel-spectrogram conversion; each keyword mirrors the config field of the same name.
mel_fields = ('sample_rate', 'n_fft', 'hop_length', 'win_length', 'n_mels', 'fmin', 'fmax')
mel_processor = MelSpectrogramProcessor(
    **{field: getattr(audio_cfg, field) for field in mel_fields}
)

# Objective metric calculator.
metrics_calculator = VoiceConversionMetrics(sample_rate=audio_cfg.sample_rate)

print(" Processors and metrics initialized")

## 4. Évaluation sur Dataset de Test

In [None]:
def evaluate_test_set(model, test_dir, num_samples=20, max_speakers=2,
                      samples_per_pair=3, seed=0):
    """Evaluate *model* on cross-speaker conversions drawn from *test_dir*.

    *test_dir* must contain one sub-directory per speaker holding ``*.wav``
    files. Speakers and files are sorted by name, so the selection does not
    depend on filesystem ordering. For every ordered pair of distinct speakers
    among the first *max_speakers*, up to *samples_per_pair* source utterances
    are converted. Each conversion uses a randomly drawn utterance of the target
    speaker as the style reference.

    Relies on the notebook globals ``audio_processor``, ``mel_processor``,
    ``metrics_calculator`` and ``device``.

    Args:
        model: Trained VoiceGAN exposing ``convert`` and ``encode_style``.
        test_dir: Root directory of the test set (``Path`` or ``str``).
        num_samples: Maximum number of files considered per speaker.
        max_speakers: Number of speakers used (``None`` = all). Defaults to 2
            to keep the demo fast.
        samples_per_pair: Number of source utterances converted per pair.
        seed: Seed used to pick target reference utterances. ``None`` gives a
            non-reproducible draw.

    Returns:
        ``pd.DataFrame`` with one row per successful conversion. Its columns
        are ``source_speaker``, ``target_speaker`` plus the keys returned by
        ``metrics_calculator.compute_all_metrics``. The frame is empty when
        nothing could be evaluated. Per-file failures are printed and skipped.
    """
    rng = np.random.default_rng(seed)
    # Sorting makes the speaker/file selection independent of filesystem order.
    speakers = sorted(d for d in Path(test_dir).iterdir() if d.is_dir())[:max_speakers]

    results = []

    for source_speaker in tqdm(speakers, desc="Source speakers"):
        source_files = sorted(source_speaker.glob('*.wav'))[:num_samples]

        for target_speaker in speakers:
            if source_speaker == target_speaker:
                continue  # only cross-speaker conversions are evaluated

            target_files = sorted(target_speaker.glob('*.wav'))[:num_samples]
            if not target_files:
                # Drawing a reference from an empty list would abort the whole run.
                print(f"Skipping {source_speaker.name} -> {target_speaker.name}: "
                      f"no .wav files for target speaker")
                continue

            for source_file in source_files[:samples_per_pair]:
                target_file = target_files[rng.integers(len(target_files))]

                try:
                    source_audio = audio_processor.load_audio(source_file)
                    target_audio = audio_processor.load_audio(target_file)

                    # random=False: presumably a fixed (non-random) crop — see AudioProcessor.
                    source_audio = audio_processor.segment_audio(source_audio, random=False)
                    target_audio = audio_processor.segment_audio(target_audio, random=False)

                    # Mel-spectrograms with a leading dimension of 1 (batch of one).
                    source_mel = mel_processor.wav_to_mel(source_audio).unsqueeze(0).to(device)
                    target_mel = mel_processor.wav_to_mel(target_audio).unsqueeze(0).to(device)

                    with torch.no_grad():
                        converted_mel = model.convert(source_mel, target_mel)
                        # Style embeddings of the output and of the reference, for similarity.
                        converted_style = model.encode_style(converted_mel)
                        target_style = model.encode_style(target_mel)

                    # NOTE(review): source and target are different utterances
                    # (non-parallel data), so frame-wise metrics against
                    # target_mel (e.g. MCD) also reflect content mismatch unless
                    # compute_all_metrics aligns the sequences — confirm.
                    metrics = metrics_calculator.compute_all_metrics(
                        converted_mel.squeeze(0).cpu(),
                        target_mel.squeeze(0).cpu(),
                        converted_style.squeeze(0).cpu(),
                        target_style.squeeze(0).cpu(),
                    )

                    results.append({
                        'source_speaker': source_speaker.name,
                        'target_speaker': target_speaker.name,
                        **metrics,
                    })

                except Exception as e:
                    # Best-effort evaluation: report the failing file and keep going.
                    print(f"Error processing {source_file}: {e}")

    return pd.DataFrame(results)

print("Evaluating on test set...")
results_df = evaluate_test_set(model, TEST_DIR, num_samples=5)

print(f"\n✅ Evaluated {len(results_df)} samples")
if results_df.empty:
    # DataFrame.describe() raises on a frame without columns (nothing was evaluated).
    print("\nNo results: check that TEST_DIR contains at least two speaker "
          "folders with .wav files.")
else:
    print("\nResults:")
    print(results_df.describe())

In [None]:
# Plot the distribution of each metric over the evaluated samples.
metrics_to_plot = ['mcd', 'cosine_similarity', 'spectral_convergence', 'log_spectral_distance']
titles = ['MCD (lower is better)', 'Cosine Similarity (higher is better)', 
          'Spectral Convergence (lower is better)', 'Log Spectral Distance (lower is better)']

if results_df.empty:
    # Nothing was evaluated: indexing the metric columns would raise KeyError.
    print("No results to plot.")
else:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    for ax, metric, title in zip(axes, metrics_to_plot, titles):
        # Drop NaNs (e.g. a metric that failed) so histogram binning stays finite.
        values = results_df[metric].dropna()
        ax.hist(values, bins=20, edgecolor='black', alpha=0.7)
        if not values.empty:
            mean_value = values.mean()
            ax.axvline(mean_value, color='red', linestyle='--',
                       linewidth=2, label=f'Mean: {mean_value:.3f}')
            ax.legend()  # only when a labelled artist exists
        ax.set_title(title)
        ax.set_xlabel('Value')
        ax.set_ylabel('Count')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Statistical summary (pandas reductions skip NaN by default).
    print("\n=== METRICS SUMMARY ===")
    for metric in metrics_to_plot:
        column = results_df[metric]
        print(f"\n{metric.upper()}:")
        print(f"  Mean: {column.mean():.4f}")
        print(f"  Std: {column.std():.4f}")
        print(f"  Min: {column.min():.4f}")
        print(f"  Max: {column.max():.4f}")

## 5. Exemples Visuels de Conversion

In [None]:
def visualize_conversion(source_file, target_file, model):
    """Convert *source_file* to the voice of *target_file*, then plot and play the result.

    Draws a 3x2 grid with one row each for source, target and converted. The
    left column shows the waveform and the right column the mel-spectrogram.
    All three spectrograms share one color scale so that intensities can be
    compared across panels. Audio players are then displayed for the three
    signals.

    Relies on the notebook globals ``audio_processor``, ``mel_processor``,
    ``config`` and ``device``.

    Args:
        source_file: Path of the utterance providing the content (A).
        target_file: Path of the utterance providing the voice/style (B).
        model: Trained VoiceGAN exposing ``convert``.

    Returns:
        ``(source_mel, target_mel, converted_mel)`` tensors on *device*, each
        with a leading dimension of 1.
    """
    sr = config.audio.sample_rate

    # Load and take the non-random segment of each clip.
    source_audio = audio_processor.segment_audio(audio_processor.load_audio(source_file), random=False)
    target_audio = audio_processor.segment_audio(audio_processor.load_audio(target_file), random=False)

    source_mel = mel_processor.wav_to_mel(source_audio).unsqueeze(0).to(device)
    target_mel = mel_processor.wav_to_mel(target_audio).unsqueeze(0).to(device)

    with torch.no_grad():
        converted_mel = model.convert(source_mel, target_mel)

    # Back to a waveform so the result can be plotted and listened to.
    converted_audio = mel_processor.mel_to_wav(converted_mel.squeeze(0).cpu())

    rows = [
        ('Source (A)', source_audio, source_mel),
        ('Target (B)', target_audio, target_mel),
        ('Converted (A→B)', converted_audio, converted_mel),
    ]
    mel_images = [mel.squeeze(0).cpu().numpy() for _, _, mel in rows]
    # Shared normalization: with per-panel autoscaling, identical colors would
    # mean different energies in different panels.
    vmin = float(min(img.min() for img in mel_images))
    vmax = float(max(img.max() for img in mel_images))

    fig, axes = plt.subplots(3, 2, figsize=(16, 12))
    for (label, audio, _), mel_img, (wave_ax, mel_ax) in zip(rows, mel_images, axes):
        librosa.display.waveshow(audio.numpy(), sr=sr, ax=wave_ax)
        wave_ax.set_title(f'{label} - Waveform')
        wave_ax.set_ylabel('Amplitude')

        mel_ax.imshow(mel_img, aspect='auto', origin='lower', cmap='viridis',
                      vmin=vmin, vmax=vmax)
        mel_ax.set_title(f'{label} - Mel-Spectrogram')
        mel_ax.set_ylabel('Mel Bin')

    axes[2, 0].set_xlabel('Time (s)')
    axes[2, 1].set_xlabel('Time')

    plt.tight_layout()
    plt.show()

    # Audio playback (ipd.display works outside the IPython builtin namespace too).
    print("\nListen to:")
    for idx, (label, audio, _) in enumerate(rows, start=1):
        print(f"\n{idx}. {label}:")
        ipd.display(ipd.Audio(audio.numpy(), rate=sr))

    return source_mel, target_mel, converted_mel

# Pick one source and one target speaker for the visual example.
# Sorting keeps the choice stable: iterdir()/glob() yield entries in arbitrary order.
speakers = sorted(d for d in TEST_DIR.iterdir() if d.is_dir())[:2]
if len(speakers) < 2:
    raise RuntimeError(f"Need at least two speaker folders in {TEST_DIR}, found {len(speakers)}")

source_candidates = sorted(speakers[0].glob('*.wav'))
target_candidates = sorted(speakers[1].glob('*.wav'))
if not source_candidates or not target_candidates:
    raise RuntimeError(f"No .wav files found in {speakers[0]} and/or {speakers[1]}")
source_file = source_candidates[0]
target_file = target_candidates[0]

print(f"Source: {source_file.name}")
print(f"Target: {target_file.name}")
print("")

source_mel, target_mel, converted_mel = visualize_conversion(source_file, target_file, model)

## 6. Analyse par Paire de Locuteurs

In [None]:
# Per speaker-pair breakdown (skipped when no sample was evaluated).
if len(results_df) > 0:
    results_df['pair'] = results_df['source_speaker'] + ' → ' + results_df['target_speaker']
    by_pair = results_df.groupby('pair')

    pair_stats = by_pair.agg({
        'mcd': ['mean', 'std'],
        'cosine_similarity': ['mean', 'std']
    }).round(4)

    print("\n=== PERFORMANCE BY SPEAKER PAIR ===")
    print(pair_stats)

    # Unrounded means for the bar charts; each is ordered so the best pair comes first.
    pair_mcd = by_pair['mcd'].mean().sort_values()
    pair_sim = by_pair['cosine_similarity'].mean().sort_values(ascending=False)

    fig, axes = plt.subplots(1, 2, figsize=(16, 5))
    panels = [
        (axes[0], pair_mcd, {}, 'MCD (lower is better)', 'MCD by Speaker Pair'),
        (axes[1], pair_sim, {'color': 'orange'},
         'Cosine Similarity (higher is better)', 'Style Similarity by Speaker Pair'),
    ]
    for ax, series, bar_kwargs, xlabel, title in panels:
        positions = range(len(series))
        ax.barh(positions, series.values, **bar_kwargs)
        ax.set_yticks(positions)
        ax.set_yticklabels(series.index)
        ax.set_xlabel(xlabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 7. Interprétation des Résultats

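### Guide d'interprétation

Seuils indicatifs, identiques à ceux appliqués par `interpret_results` ci-dessous.

**MCD (Mel Cepstral Distortion)** — plus bas = meilleur
- < 5 dB : Excellent
- de 5 à 7 dB (7 exclu) : Bon
- de 7 à 10 dB (10 exclu) : Acceptable
- ≥ 10 dB : Nécessite une amélioration

**Similarité cosinus (style)** — plus haut = meilleur
- > 0.9 : Excellent
- > 0.8 et ≤ 0.9 : Bon
- > 0.7 et ≤ 0.8 : Acceptable
- ≤ 0.7 : Nécessite une amélioration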

In [None]:
# Qualitative evaluation
def _mcd_quality(mean_mcd):
    """Return the quality label for an average MCD in dB (lower is better)."""
    if mean_mcd < 5:
        return "Excellent quality"
    if mean_mcd < 7:
        return "Good quality"
    if mean_mcd < 10:
        return "Acceptable quality"
    return "Needs improvement"


def _style_quality(mean_sim):
    """Return the style-transfer label for an average cosine similarity (higher is better)."""
    if mean_sim > 0.9:
        return "Excellent style transfer"
    if mean_sim > 0.8:
        return "Good style transfer"
    if mean_sim > 0.7:
        return "Acceptable style transfer"
    return "Style transfer needs improvement"


def _recommendations(mean_mcd, mean_sim):
    """Return tuning hints for every metric that falls below the 'Good' band.

    The conditions are the exact complements of the 'Good' thresholds above
    (MCD < 7, similarity > 0.8). Values sitting on a boundary are therefore
    labelled 'Acceptable' and also receive advice.
    """
    tips = []
    if mean_mcd >= 7:
        tips += ["Increase lambda_reconstruction", "Train for more epochs"]
    if mean_sim <= 0.8:
        tips += ["Increase lambda_identity", "Verify style encoder learning"]
    return tips


def interpret_results(results_df):
    """Print a qualitative assessment of the evaluation results.

    Args:
        results_df: Output of ``evaluate_test_set``; must contain the ``mcd``
            and ``cosine_similarity`` columns to be interpreted.

    Returns:
        None. Everything is printed. When there is nothing to interpret
        (missing columns, or all values NaN), a notice is printed instead of
        a verdict.
    """
    print("=== QUALITY ASSESSMENT ===")

    if not {'mcd', 'cosine_similarity'}.issubset(results_df.columns):
        print("\nNo results to interpret.")
        return

    mean_mcd = results_df['mcd'].mean()
    mean_sim = results_df['cosine_similarity'].mean()
    if pd.isna(mean_mcd) or pd.isna(mean_sim):
        # NaN would otherwise fail every comparison and yield a bogus verdict.
        print("\nNo results to interpret.")
        return

    print(f"\nMCD: {mean_mcd:.2f} dB")
    print(f"  →  {_mcd_quality(mean_mcd)}")

    print(f"\nStyle Similarity: {mean_sim:.3f}")
    print(f"  →  {_style_quality(mean_sim)}")

    print("\n=== RECOMMENDATIONS ===")
    for tip in _recommendations(mean_mcd, mean_sim):
        print(f"  • {tip}")

# Print the qualitative verdict for the samples evaluated above.
interpret_results(results_df=results_df)

## 8. Sauvegarde des Résultats

In [None]:
# Save the evaluation results to disk.
import json
import math

output_dir = Path('../outputs/evaluation')
output_dir.mkdir(parents=True, exist_ok=True)

# Per-sample results as CSV.
results_path = output_dir / 'evaluation_results.csv'
results_df.to_csv(results_path, index=False)
print(f"Results saved to: {results_path}")


def _json_stat(column, reducer):
    """Return ``results_df[column].<reducer>()`` as a float, or ``None``.

    ``None`` is returned when the column is missing (nothing was evaluated) or
    when the statistic is not finite. For example, ``std`` of a single sample
    is NaN, and ``json`` would write it as ``NaN``, which is not valid JSON.
    """
    if column not in results_df.columns:
        return None
    value = float(getattr(results_df[column], reducer)())
    return value if math.isfinite(value) else None


# Aggregate statistics.
summary = {
    'num_samples': len(results_df),
    'mean_mcd': _json_stat('mcd', 'mean'),
    'std_mcd': _json_stat('mcd', 'std'),
    'mean_cosine_similarity': _json_stat('cosine_similarity', 'mean'),
    'std_cosine_similarity': _json_stat('cosine_similarity', 'std'),
    'mean_spectral_convergence': _json_stat('spectral_convergence', 'mean'),
    'mean_log_spectral_distance': _json_stat('log_spectral_distance', 'mean'),
}

summary_path = output_dir / 'summary.json'
with summary_path.open('w', encoding='utf-8') as f:
    # allow_nan=False guarantees strictly valid JSON output.
    json.dump(summary, f, indent=2, allow_nan=False)

print(f"Summary saved to: {summary_path}")

print("\n Evaluation completed!")

## 9. Résumé de l'Évaluation

### Métriques Calculées 
1. MCD (Mel Cepstral Distortion)
2. Cosine Similarity
3. Spectral Convergence
4. Log Spectral Distance

### Analyses Effectuées 
1. Distribution des métriques
2. Performance par paire de locuteurs
3. Exemples visuels et auditifs
4. Interprétation qualitative

### Pour le Rapport
- Inclure les graphiques de distribution
- Présenter les exemples de conversion
- Discuter les résultats par rapport à l'état de l'art
- Mentionner les limitations et améliorations possibles