# Notebook 3 : Entraînement du modèle

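Ce notebook valide, étape par étape, le pipeline d'entraînement du modèle VoiceGAN.

**Objectifs :**
- Tester le modèle sur des données synthétiques (forward pass, discriminateur, pertes)
- Vérifier les dimensions des tenseurs produits
- Lancer un mini-entraînement sur un petit sous-ensemble (test de surapprentissage)
- Analyser l'évolution des pertes
- Générer des exemples de conversion (mel-spectrogrammes et audio)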

In [None]:
# Imports
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import IPython.display as ipd
from pathlib import Path

from config.model_config import Config
from src.models.voicegan import VoiceGAN
from src.training.dataset import VoiceDataset, VoiceCollator
from src.losses.losses import VoiceGANLoss
from src.preprocessing.mel_spectrogram import MelSpectrogramProcessor

%matplotlib inline

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## 1. Charger Configuration

In [None]:
# Load the YAML configuration, relative to the notebook directory.
config = Config('../config/config.yaml')

# Echo the key hyper-parameters so the run output is self-documenting.
print("Configuration:")
for _label, _value in (
    ("Batch size", config.training.batch_size),
    ("Learning rate G", config.training.learning_rate_g),
    ("Learning rate D", config.training.learning_rate_d),
    ("N-mels", config.audio.n_mels),
):
    print(f"  {_label}: {_value}")

## 2. Test: Initialisation du Modèle

In [None]:
# Build VoiceGAN from the config, then move it to the selected device.
model = VoiceGAN(
    n_mels=config.audio.n_mels,
    content_channels=config.content_encoder.channels,
    transformer_dim=config.content_encoder.transformer_dim,
    num_heads=config.content_encoder.num_heads,
    num_transformer_layers=config.content_encoder.num_layers,
    style_dim=config.style_encoder.style_dim,
)
model = model.to(device)

print(" Model created")


def _n_params(module, trainable_only=False):
    """Return the number of scalar parameters in *module* (optionally only trainable ones)."""
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


# Overall parameter counts.
total_params = _n_params(model)
trainable_params = _n_params(model, trainable_only=True)

print("\nModel parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

# Per-component breakdown.
print("\nComponent parameters:")
for _label, _submodule in (
    ("Content Encoder", model.content_encoder),
    ("Style Encoder", model.style_encoder),
    ("Generator", model.generator),
    ("Discriminator", model.discriminator),
):
    print(f"  {_label}: {_n_params(_submodule):,}")

## 3. Test: Forward Pass avec Données Synthétiques

In [None]:
# Random "mel spectrograms" shaped (batch, n_mels, time): enough to exercise
# tensor shapes without any real audio.
batch_size = 2
time_steps = 100
mel_shape = (batch_size, config.audio.n_mels, time_steps)

source_mel = torch.randn(*mel_shape).to(device)
target_mel = torch.randn(*mel_shape).to(device)

print(f"Source mel shape: {source_mel.shape}")
print(f"Target mel shape: {target_mel.shape}")

# Inference-only forward pass: no autograd graph is needed for a shape check.
with torch.no_grad():
    results = model(source_mel, target_mel)

print("\nOutput shapes:")
for _key, _label in (
    ('generated_mel', "Generated mel"),
    ('content', "Content"),
    ('style', "Style"),
):
    print(f"  {_label}: {results[_key].shape}")

print("\n Forward pass successful!")

## 4. Test: Discriminateur

In [None]:
# Run the discriminator on the generated mel from the previous cell.
with torch.no_grad():
    disc_outputs, disc_features = model.discriminate(results['generated_mel'])

# One output tensor per scale.
print("Discriminator outputs (multi-scale):")
for scale_idx, scale_out in enumerate(disc_outputs):
    print(f"  Scale {scale_idx}: {scale_out.shape}")

# One list of intermediate feature maps per scale.
print("\nDiscriminator features:")
for scale_idx, scale_feats in enumerate(disc_features):
    print(f"  Scale {scale_idx}: {len(scale_feats)} feature maps")

print("\n Discriminator working!")

## 5. Test: Calcul des Pertes

In [None]:
# Composite VoiceGAN loss, weighted by the lambdas from the config.
loss_weights = dict(
    lambda_reconstruction=config.training.lambda_reconstruction,
    lambda_adversarial=config.training.lambda_adversarial,
    lambda_identity=config.training.lambda_identity,
    lambda_content=config.training.lambda_content,
    lambda_feature_matching=config.training.lambda_feature_matching,
)
criterion = VoiceGANLoss(**loss_weights)

# Evaluate every loss term once on the synthetic batch (no gradients needed).
with torch.no_grad():
    results = model(source_mel, target_mel)
    fake_mel = results['generated_mel']

    # Score both the generated and the real mels with the discriminator.
    disc_fake_out, disc_fake_feats = model.discriminate(fake_mel)
    disc_real_out, disc_real_feats = model.discriminate(target_mel)

    # Re-encode the generated mel for the content / style consistency terms.
    content_fake = model.encode_content(fake_mel)
    style_fake = model.encode_style(fake_mel)

    g_loss, g_loss_dict = criterion.generator_loss(
        real_mel=target_mel,
        fake_mel=fake_mel,
        disc_fake_outputs=disc_fake_out,
        disc_fake_features=disc_fake_feats,
        disc_real_features=disc_real_feats,
        content_source=results['content'],
        content_fake=content_fake,
        style_target=results['style'],
        style_fake=style_fake,
    )

    d_loss, d_loss_dict = criterion.discriminator_loss(disc_real_out, disc_fake_out)

# Print each loss breakdown under its own header.
for _header, _loss_dict in (
    ("Generator losses:", g_loss_dict),
    ("\nDiscriminator losses:", d_loss_dict),
):
    print(_header)
    for _name, _val in _loss_dict.items():
        print(f"  {_name}: {_val:.4f}")

print("\n Loss computation successful!")

## 6. Test : mini-entraînement (test de surapprentissage)

In [None]:
# Load a small training subset for the overfitting test.
print("Creating mini dataset...")

# Build the training dataset from the data directory.
mini_dataset = VoiceDataset(
    data_dir='../data/train',
    audio_config=config.audio.__dict__,
    segment_length=config.audio.segment_length,
    split='train'
)

# Keep only the first 10 pairs: a tiny set the model should be able to overfit.
# NOTE(review): relies on VoiceDataset storing its samples in `.pairs`.
mini_dataset.pairs = mini_dataset.pairs[:10]

# Fail fast with an explicit message. An empty dataset would otherwise only
# surface later as an opaque sampler/loader error, or as NaN epoch averages and
# a bare StopIteration when fetching a batch for generation.
if len(mini_dataset) == 0:
    raise RuntimeError(
        "Mini dataset is empty: no samples found in '../data/train'. "
        "Check that the training data has been prepared."
    )

mini_loader = DataLoader(
    mini_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=VoiceCollator()
)

print(f"Mini dataset size: {len(mini_dataset)}")

In [None]:
# Generator-side parameters (both encoders + the generator) share one optimizer.
generator_params = [
    param
    for module in (model.content_encoder, model.style_encoder, model.generator)
    for param in module.parameters()
]
# Same Adam momentum settings for both optimizers.
adam_betas = (config.training.beta1, config.training.beta2)

optimizer_g = torch.optim.Adam(
    generator_params,
    lr=config.training.learning_rate_g,
    betas=adam_betas,
)
optimizer_d = torch.optim.Adam(
    model.discriminator.parameters(),
    lr=config.training.learning_rate_d,
    betas=adam_betas,
)

print(" Optimizers created")

In [None]:
# Mini training loop: for each batch, one generator step, then one
# discriminator step. Per-epoch mean losses are stored in `losses_history`.
num_epochs = 10
tracked_losses = ('g_total', 'g_recon', 'd_total')
losses_history = {name: [] for name in tracked_losses}

model.train()

print("Starting mini-training...\n")

for epoch in range(num_epochs):
    epoch_losses = {name: [] for name in tracked_losses}

    pbar = tqdm(mini_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

    for source_mel, target_mel in pbar:
        source_mel = source_mel.to(device)
        target_mel = target_mel.to(device)

        # === Train Generator ===
        # Freeze the discriminator while the generator loss is backpropagated.
        # Gradients still flow *through* D into the generator, but none are
        # computed for D's own weights, which optimizer_g never updates.
        # `finally` restores trainability even if the step raises.
        model.discriminator.requires_grad_(False)
        try:
            optimizer_g.zero_grad()

            results = model(source_mel, target_mel)
            fake_mel = results['generated_mel']

            disc_fake_out, disc_fake_feats = model.discriminate(fake_mel)
            disc_real_out, disc_real_feats = model.discriminate(target_mel)

            # Re-encode the generated mel for the content / style consistency terms.
            content_fake = model.encode_content(fake_mel)
            style_fake = model.encode_style(fake_mel)

            g_loss, g_loss_dict = criterion.generator_loss(
                real_mel=target_mel,
                fake_mel=fake_mel,
                disc_fake_outputs=disc_fake_out,
                disc_fake_features=disc_fake_feats,
                disc_real_features=disc_real_feats,
                content_source=results['content'],
                content_fake=content_fake,
                style_target=results['style'],
                style_fake=style_fake
            )

            g_loss.backward()
            optimizer_g.step()
        finally:
            model.discriminator.requires_grad_(True)

        # === Train Discriminator ===
        optimizer_d.zero_grad()

        disc_real_out, _ = model.discriminate(target_mel)
        # detach(): the discriminator step must not backpropagate into the generator.
        disc_fake_out, _ = model.discriminate(fake_mel.detach())

        d_loss, d_loss_dict = criterion.discriminator_loss(
            disc_real_out, disc_fake_out
        )

        d_loss.backward()
        optimizer_d.step()

        # Record losses as plain Python floats. float() accepts both numbers
        # and 0-dim tensors, so the history never holds tensors. Stored tensors
        # could keep an autograd graph or GPU memory alive, and np.mean cannot
        # convert CUDA tensors.
        step_losses = {
            'g_total': float(g_loss_dict['g_total']),
            'g_recon': float(g_loss_dict['g_recon']),
            'd_total': float(d_loss_dict['d_total']),
        }
        for name, value in step_losses.items():
            epoch_losses[name].append(value)

        pbar.set_postfix({
            'G': f"{step_losses['g_total']:.3f}",
            'D': f"{step_losses['d_total']:.3f}"
        })

    # Average this epoch's per-batch losses.
    for name in tracked_losses:
        losses_history[name].append(np.mean(epoch_losses[name]))

    print(f"Epoch {epoch+1}: G={losses_history['g_total'][-1]:.4f}, D={losses_history['d_total'][-1]:.4f}")

print("\n Mini-training completed!")

In [None]:
# Plot per-epoch training curves: generator on the left, discriminator on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_gen, ax_disc = axes

ax_gen.plot(losses_history['g_total'], label='Total', linewidth=2)
ax_gen.plot(losses_history['g_recon'], label='Reconstruction', linewidth=2)

ax_disc.plot(losses_history['d_total'], label='Discriminator', color='red', linewidth=2)
ax_disc.axhline(y=0.5, color='gray', linestyle='--', label='Target (~0.5)')

# Shared cosmetics for both panels (legend last, after all artists exist).
for _ax, _title in ((ax_gen, 'Generator Losses'), (ax_disc, 'Discriminator Loss')):
    _ax.set_title(_title)
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel('Loss')
    _ax.legend()
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Compare the first and last epochs of the generator loss.
print("\nObservations:")
g_history = losses_history['g_total']
g_first, g_last = g_history[0], g_history[-1]
print(f"  Final G loss: {g_last:.4f}")
print(f"  Final D loss: {losses_history['d_total'][-1]:.4f}")
print(f"  G loss change: {g_first - g_last:.4f}")

print(
    "   Model is learning (G loss decreasing)"
    if g_last < g_first
    else "   Check: G loss not decreasing"
)

## 7. Génération d'Exemples

In [None]:
# Generate conversion examples with the (briefly) trained model.
model.eval()

# Take one batch from the mini loader.
source_mel, target_mel = next(iter(mini_loader))
source_mel = source_mel.to(device)
target_mel = target_mel.to(device)

with torch.no_grad():
    converted_mel = model.convert(source_mel, target_mel)

# Visualize a single sample of the batch.
idx = 0  # first sample of the batch

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
panels = (
    (source_mel, 'Source (A)'),
    (target_mel, 'Target (B)'),
    (converted_mel, 'Converted (A→B)'),
)
for ax, (mel, title) in zip(axes, panels):
    im = ax.imshow(mel[idx].cpu().numpy(), aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(title)
    ax.set_ylabel('Mel Bin')
    ax.set_xlabel('Time')
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

# Report the epoch count actually used by the training cell instead of a
# hard-coded "10", so the note stays correct when num_epochs is changed.
print(f"\nNote: After only {num_epochs} epochs, the conversion won't be perfect.")
print("This is expected. Full training requires 50-100+ epochs.")

## 8. Conversion en Audio (Griffin-Lim)

In [None]:
# Mel <-> waveform processor, built with the same audio settings as the config.
mel_processor = MelSpectrogramProcessor(
    sample_rate=config.audio.sample_rate,
    n_fft=config.audio.n_fft,
    hop_length=config.audio.hop_length,
    win_length=config.audio.win_length,
    n_mels=config.audio.n_mels,
    fmin=config.audio.fmin,
    fmax=config.audio.fmax
)

# Invert the three mels of sample `idx` back to waveforms.
print("Converting to audio...")
source_audio = mel_processor.mel_to_wav(source_mel[idx].cpu())
target_audio = mel_processor.mel_to_wav(target_mel[idx].cpu())
converted_audio = mel_processor.mel_to_wav(converted_mel[idx].cpu())

# Render an audio player for each clip. Use the explicitly imported
# `ipd.display` rather than the bare `display` name, which only exists because
# IPython injects it into builtins (linters flag it as undefined).
for label, audio in (
    ("Source audio (A)", source_audio),
    ("Target audio (B)", target_audio),
    ("Converted audio (A→B)", converted_audio),
):
    print(f"\n{label}:")
    ipd.display(ipd.Audio(audio.numpy(), rate=config.audio.sample_rate))

print("\n Note: Griffin-Lim produces artifacts. Use neural vocoder for better quality.")

## 9. Résumé

### Tests réussis
1. Initialisation du modèle
2. Forward pass
3. Discriminateur multi-échelle
4. Calcul des pertes
5. Mini-entraînement
6. Génération d'exemples
7. Conversion en audio (Griffin-Lim)

### Observations
- Le modèle peut apprendre : la perte du générateur doit diminuer au fil des epochs (vérifié à la fin de la section 6)
- Après 10 epochs, la conversion est visible mais imparfaite
- Il faut davantage d'epochs pour atteindre une qualité de production

### Prochaines Étapes
1. Entraîner sur dataset complet (50-100 epochs)
2. Utiliser vocoder neuronal
3. Évaluer avec métriques objectives
4. Fine-tuning des hyperparamètres

In [None]:
# Final banner pointing at the full training script.
print("\n".join([
    " Model training pipeline validated!",
    "\nReady for full training with:",
    "  python ../scripts/train.py --data_dir ../data/",
]))