In [2]:
import torch
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from audiocraft.modules.conditioners import ConditioningAttributes
import os
import numpy as np
from pathlib import Path
import librosa
from IPython.display import Audio, display
import warnings
import gc
import json
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm

warnings.filterwarnings('ignore')

  return torch._C._cuda_getDeviceCount() > 0
2025-11-26 14:14:04.384248: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-26 14:14:04.520675: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-26 14:14:05.917003: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [3]:
DATASET_PATH = "./dataset/agbadja"
OUTPUT_PATH = "./musicgen_outputs"
MODEL_SAVE_PATH = "./musicgen_agbadja_model"
LOG_FILE = "./training_log.json"

EPOCHS = 5
LEARNING_RATE = 1e-5  
GRADIENT_ACCUMULATION_STEPS = 4
SAVE_EVERY = 2
MAX_GRAD_NORM = 1.0

os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

print(" Configuration:")
print(f" Dataset: {DATASET_PATH}")
print(f" Outputs: {OUTPUT_PATH}")
print(f" Model: {MODEL_SAVE_PATH}")
print(f" Epochs: {EPOCHS}")
print(f" Learning Rate: {LEARNING_RATE}")

 Configuration:
 Dataset: ./dataset/agbadja
 Outputs: ./musicgen_outputs
 Model: ./musicgen_agbadja_model
 Epochs: 5
 Learning Rate: 1e-05


In [5]:
import pandas as pd

metadata_path = "./dataset/metadata.csv"

try:
    metadata_df = pd.read_csv(metadata_path)
    print(f" Metadata charg√©: {len(metadata_df)} entr√©es\n")
    
    audio_captions = {}
    for _, row in metadata_df.iterrows():
        filename = Path(row['audio_path']).name
        audio_captions[filename] = row['caption']
    
    print(f"{len(audio_captions)} descriptions charg√©es\n")
    
except Exception as e:
    print(f"  Erreur chargement metadata: {e}")
    audio_captions = {}

# Charger les fichiers audio
def load_audio_files(dataset_path):
    audio_extensions = ['.wav', '.mp3', '.flac', '.ogg']
    audio_files = []
    
    for ext in audio_extensions:
        audio_files.extend(list(Path(dataset_path).glob(f'*{ext}')))
    
    return sorted(audio_files)

audio_files = load_audio_files(DATASET_PATH)
print(f" {len(audio_files)} fichiers audio trouv√©s dans {DATASET_PATH}\n")

if len(audio_files) == 0:
    raise ValueError(f" Aucun fichier audio dans {DATASET_PATH}")

# Afficher les fichiers avec leurs descriptions
for i, file in enumerate(audio_files, 1):
    caption = audio_captions.get(file.name, "(pas de description)")
    print(f"   {i:2d}. {file.name}")
    print(f"        {caption[:60]}..." if len(caption) > 60 else f"        {caption}")


 Metadata charg√©: 16 entr√©es

16 descriptions charg√©es

 16 fichiers audio trouv√©s dans ./dataset/agbadja

    1. agbadja_1.wav
        traditional Agbadja instrumental with energetic percussions ...
    2. agbadja_10.wav
        agbadja instrumental with layered percussion and steady cult...
    3. agbadja_11.wav
        Agbadja rhythm instrumental focused on drums and traditional...
    4. agbadja_12.wav
        deep traditional Agbadja groove with evolving rhythmic textu...
    5. agbadja_13.wav
        instrumental Agbadja beat highlighting dynamic drum sequence...
    6. agbadja_14.wav
        Agbadja percussive progression with cultural rhythmic identi...
    7. agbadja_15.wav
        traditional Agbadja instrumental loop with warm percussive t...
    8. agbadja_16.wav
        authentic Agbadja rhythm with layered percussion and stable ...
    9. agbadja_2.wav
        Agbadja style instrumental featuring steady drums and melodi...
   10. agbadja_3.wav
        traditional Agba

In [7]:
print(" EXEMPLES DE DESCRIPTIONS (metadata.csv):\n")
print("="*70)

for i, (filename, caption) in enumerate(list(audio_captions.items())[:5], 1):
    print(f"\n{i}. Fichier: {filename}")
    print(f"   Description: \"{caption}\"")

print("\n" + "="*70)

 EXEMPLES DE DESCRIPTIONS (metadata.csv):


1. Fichier: agbadja_1.wav
   Description: "traditional Agbadja instrumental with energetic percussions and rhythmic patterns"

2. Fichier: agbadja_2.wav
   Description: "Agbadja style instrumental featuring steady drums and melodic rhythmic flow"

3. Fichier: agbadja_3.wav
   Description: "traditional Agbadja groove recorded with percussive texture and cultural rhythm"

4. Fichier: agbadja_4.wav
   Description: "Agbadja instrumental loop with steady tempo and ancestral percussive elements"

5. Fichier: agbadja_5.wav
   Description: "deep Agbadja percussive ensemble creating a traditional rhythmic atmosphere"



In [8]:
def analyze_audio_dataset(audio_files, max_files=None):
    durations = []
    sample_rates = []
    
    files_to_analyze = audio_files if max_files is None else audio_files[:max_files]
    print(f" Analyse de {len(files_to_analyze)} fichiers...\n")
    
    for audio_file in tqdm(files_to_analyze, desc="Analyse"):
        try:
            y, sr = librosa.load(str(audio_file), sr=None, mono=True)
            duration = librosa.get_duration(y=y, sr=sr)
            durations.append(duration)
            sample_rates.append(sr)
        except Exception as e:
            print(f" Erreur avec {audio_file.name}: {e}")
    
    if durations:
        print(f"\n Statistiques du dataset:")
        print(f"   Dur√©e moyenne: {np.mean(durations):.2f}s")
        print(f"   Dur√©e min: {np.min(durations):.2f}s")
        print(f"   Dur√©e max: {np.max(durations):.2f}s")
        print(f"   Sample rate moyen: {int(np.mean(sample_rates))} Hz")
        return np.mean(durations), int(np.mean(sample_rates))
    return None, None

avg_duration, avg_sr = analyze_audio_dataset(audio_files)

 Analyse de 16 fichiers...



Analyse: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16/16 [00:01<00:00,  9.63it/s]


 Statistiques du dataset:
   Dur√©e moyenne: 30.00s
   Dur√©e min: 30.00s
   Dur√©e max: 30.04s
   Sample rate moyen: 44100 Hz





In [9]:
MODEL_SIZE = 'small'  

print(f" Chargement de MusicGen-{MODEL_SIZE}...")
model = MusicGen.get_pretrained(f'facebook/musicgen-{MODEL_SIZE}')
print(f" Mod√®le charg√© avec succ√®s\n")

# Configuration du device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.lm.to(device)
print(f" Device: {device}")

# Param√®tres du mod√®le
total_params = sum(p.numel() for p in model.lm.parameters())
trainable_params = sum(p.numel() for p in model.lm.parameters() if p.requires_grad)
print(f"Param√®tres totaux: {total_params:,}")
print(f"Param√®tres entra√Ænables: {trainable_params:,}")

 Chargement de MusicGen-small...
 Mod√®le charg√© avec succ√®s

 Device: cpu
Param√®tres totaux: 420,371,456
Param√®tres entra√Ænables: 420,371,456


In [10]:
class AudioDataset(Dataset):
    def __init__(self, audio_files, audio_captions, target_sr=32000, duration=10):
        self.audio_files = audio_files
        self.audio_captions = audio_captions
        self.target_sr = target_sr
        self.duration = duration
        self.target_length = int(target_sr * duration)
        
        print(f"   Dataset cr√©√©:")
        print(f"   Fichiers: {len(audio_files)}")
        print(f"   Captions: {len(audio_captions)}")
        print(f"   Sample Rate: {target_sr} Hz")
        print(f"   Dur√©e par segment: {duration}s")
    
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        audio_path = str(self.audio_files[idx])
        filename = self.audio_files[idx].name
        caption = self.audio_captions.get(filename, "")
        
        try:
            # Charger l'audio
            waveform_np, sr = librosa.load(audio_path, sr=self.target_sr, mono=True)
            waveform = torch.from_numpy(waveform_np).float().unsqueeze(0)
            
            # Ajuster la longueur
            if waveform.shape[1] > self.target_length:
                # D√©couper al√©atoirement
                start = np.random.randint(0, waveform.shape[1] - self.target_length)
                waveform = waveform[:, start:start + self.target_length]
            elif waveform.shape[1] < self.target_length:
                # Padding
                padding = self.target_length - waveform.shape[1]
                waveform = torch.nn.functional.pad(waveform, (0, padding))
            
            return waveform, caption
            
        except Exception as e:
            print(f" Erreur: {filename}: {e}")
            return torch.zeros(1, self.target_length), ""

# Cr√©er le dataset et dataloader
dataset = AudioDataset(audio_files, audio_captions, target_sr=32000, duration=10)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
print(f" DataLoader pr√™t ({len(dataset)} samples avec captions)")

   Dataset cr√©√©:
   Fichiers: 16
   Captions: 16
   Sample Rate: 32000 Hz
   Dur√©e par segment: 10s
 DataLoader pr√™t (16 samples avec captions)


In [12]:
print(" TEST DE LA BOUCLE D'ENTRA√éNEMENT")
print("="*70)

try:
    # Pr√©parer le mod√®le
    model.compression_model.eval()
    for param in model.compression_model.parameters():
        param.requires_grad = False
    
    model.lm.train()
    
    # Test avec un seul batch
    test_batch, test_caption = next(iter(dataloader))
    test_batch = test_batch.to(device)
    print(f" Batch de test charg√©: {test_batch.shape}")
    print(f" Caption: {test_caption[0][:80]}...")
    
    # Encoder
    with torch.no_grad():
        encoded = model.compression_model.encode(test_batch)
        codes = encoded[0] if isinstance(encoded, tuple) else encoded
        
        if codes.dim() > 2:
            codes = codes[0].unsqueeze(0)
    
    print(f" Codes encod√©s: {codes.shape}")
    
    # Cr√©er l'objet de conditioning
    conditioning = ConditioningAttributes(text={'description': test_caption[0]})
    
    # Forward pass AVEC conditioning (text)
    lm_output = model.lm.compute_predictions(
        codes=codes,
        conditions=[conditioning]
    )
    
    # Extraire les logits de l'objet LMOutput
    logits = lm_output.logits
    
    print(f" Logits g√©n√©r√©s: {logits.shape}")
    
    # Loss - logits: [B, K, T, vocab_size], codes: [B, K, T]
    # On calcule la loss sur la premi√®re codebook (K=0)
    B, K, T, vocab_size = logits.shape
    loss = torch.nn.functional.cross_entropy(
        logits[:, 0, :, :].reshape(-1, vocab_size),  # [B*T, vocab_size]
        codes[:, 0, :].reshape(-1)                     # [B*T]
    )
    
    print(f" Loss calcul√©e: {loss.item():.4f}")
    print("\n" + "="*70)
    print(" TEST R√âUSSI! La boucle d'entra√Ænement fonctionne correctement")
    print("="*70)
    
    # Nettoyer
    del test_batch, codes, logits, loss
    gc.collect()
    
except Exception as e:
    print(f"\n ERREUR: {e}")


    import traceback        

    traceback.print_exc()    
    print("\n V√©rifiez les erreurs ci-dessus avant de lancer l'entra√Ænement")

 TEST DE LA BOUCLE D'ENTRA√éNEMENT
 Batch de test charg√©: torch.Size([1, 1, 320000])
 Caption: traditional Agbadja instrumental loop with warm percussive tones...
 Codes encod√©s: torch.Size([1, 4, 500])
 Logits g√©n√©r√©s: torch.Size([1, 4, 500, 2048])
 Loss calcul√©e: 3.8217

 TEST R√âUSSI! La boucle d'entra√Ænement fonctionne correctement


In [16]:
def train_model():
    training_log = {
        'start_time': datetime.now().isoformat(),
        'config': {
            'model': MODEL_SIZE,
            'epochs': EPOCHS,
            'learning_rate': LEARNING_RATE,
            'gradient_accumulation': GRADIENT_ACCUMULATION_STEPS,
            'num_files': len(audio_files),
            'device': str(device)
        },
        'epochs': []
    }
    
    # Configuration du mod√®le
    model.compression_model.eval()
    for param in model.compression_model.parameters():
        param.requires_grad = False
    
    model.lm.train()
    
    # Optimiseur
    optimizer = torch.optim.AdamW(
        model.lm.parameters(),
        lr=LEARNING_RATE,
        weight_decay=0.01
    )
    
    print(f" Configuration pr√™te")
    print(f"   Mod√®le: MusicGen-{MODEL_SIZE}")
    print(f"   Epochs: {EPOCHS}")
    print(f"   Learning Rate: {LEARNING_RATE}")
    print(f"   Device: {device}\n")
    
    try:
        for epoch in range(EPOCHS):
            epoch_loss = 0
            batch_count = 0
            successful_batches = 0
            
            print(f"\n{'='*70}")
            print(f" EPOCH {epoch + 1}/{EPOCHS}")
            print(f"{'='*70}")
            
            pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}", ncols=100)
            
            for batch_idx, batch_data in enumerate(pbar):
                try:
                    audio_batch, captions = batch_data
                    audio_batch = audio_batch.to(device)
                    
                    # Encoder l'audio
                    with torch.no_grad():
                        encoded = model.compression_model.encode(audio_batch)
                        codes = encoded[0] if isinstance(encoded, tuple) else encoded
                        
                        if codes.dim() > 2:
                            codes = codes[0].unsqueeze(0)
                    
                    if codes.shape[0] == 0:
                        continue
                    
                    # Utiliser le caption du metadata.csv pour conditioning
                    B, K, T = codes.shape
                    caption = captions[0] if len(captions) > 0 else ""
                    
                    # Cr√©er l'objet de conditioning
                    conditioning = ConditioningAttributes(text={'description': caption})
                    
                    # Forward pass AVEC le texte de description
                    lm_output = model.lm.compute_predictions(
                        codes=codes,
                        conditions=[conditioning]
                    )
                    
                    # Extraire les logits de l'objet LMOutput
                    logits = lm_output.logits
                    
                    # Loss - logits: [B, K, T, vocab_size], codes: [B, K, T]
                    # On calcule la loss sur la premi√®re codebook (K=0)
                    B_train, K_train, T_train, vocab_size_train = logits.shape
                    loss = torch.nn.functional.cross_entropy(
                        logits[:, 0, :, :].reshape(-1, vocab_size_train),  # [B*T, vocab_size]
                        codes[:, 0, :].reshape(-1)                           # [B*T]
                    )
                    
                    loss = loss / GRADIENT_ACCUMULATION_STEPS
                    loss.backward()
                    
                    # Update
                    if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                        torch.nn.utils.clip_grad_norm_(
                            model.lm.parameters(),
                            max_norm=MAX_GRAD_NORM
                        )
                        optimizer.step()
                        optimizer.zero_grad()
                    
                    # Tracking
                    epoch_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
                    batch_count += 1
                    successful_batches += 1
                    
                    pbar.set_postfix({
                        'loss': f'{loss.item() * GRADIENT_ACCUMULATION_STEPS:.4f}',
                        'success': f'{successful_batches}/{len(dataloader)}'
                    })
                    
                    # Nettoyage m√©moire
                    del audio_batch, codes, logits, loss
                    if batch_idx % 5 == 0:
                        gc.collect()
                    
                except Exception as e:
                    print(f"\n Erreur batch {batch_idx}: {str(e)}")
                    import traceback
                    print(f"D√©tails: {traceback.format_exc()[:500]}")
                    optimizer.zero_grad()
                    gc.collect()
                    continue
            
            # Statistiques epoch
            avg_loss = epoch_loss / batch_count if batch_count > 0 else float('inf')
            
            epoch_stats = {
                'epoch': epoch + 1,
                'avg_loss': avg_loss,
                'successful_batches': successful_batches,
                'total_batches': len(dataloader)
            }
            training_log['epochs'].append(epoch_stats)
            
            print(f"\n Epoch {epoch + 1} termin√©")
            print(f"   Loss moyenne: {avg_loss:.4f}")
            print(f"   Batches r√©ussis: {successful_batches}/{len(dataloader)}")
            
            # Sauvegarder checkpoint
            if (epoch + 1) % SAVE_EVERY == 0:
                checkpoint_path = os.path.join(
                    MODEL_SAVE_PATH,
                    f'checkpoint_epoch_{epoch + 1}.pt'
                )
                
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.lm.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                    'config': training_log['config']
                }, checkpoint_path)
                
                print(f" Checkpoint sauvegard√©: checkpoint_epoch_{epoch + 1}.pt")
            
            # Sauvegarder le log
            with open(LOG_FILE, 'w') as f:
                json.dump(training_log, f, indent=2)
        
        # Sauvegarder le mod√®le final
        final_path = os.path.join(MODEL_SAVE_PATH, 'final_model.pt')
        torch.save({
            'model_state_dict': model.lm.state_dict(),
            'training_log': training_log
        }, final_path)
        
        print(f"\n{'='*70}")
        print(f" ENTRA√éNEMENT TERMIN√â AVEC SUCC√àS")
        print(f"{'='*70}")
        print(f" Mod√®le final: {final_path}")
        print(f" Log: {LOG_FILE}")
        
        return training_log
        
    except KeyboardInterrupt:
        print("\n  Entra√Ænement interrompu par l'utilisateur")
        return training_log
    
    except Exception as e:
        print(f"\n ERREUR CRITIQUE: {e}")
        import traceback
        traceback.print_exc()
        return training_log



print(" Pr√™t √† lancer l'entra√Ænement")# LANCER L'ENTRA√éNEMENTprint(" Ex√©cutez cette cellule pour commencer\n")

 Pr√™t √† lancer l'entra√Ænement


In [None]:
training_log = train_model()

 Configuration pr√™te
   Mod√®le: MusicGen-small
   Epochs: 5
   Learning Rate: 1e-05
   Device: cpu


 EPOCH 1/5


Epoch 1:  19%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé                      | 3/16 [00:18<01:21,  6.28s/it, loss=2.4129, success=3/16]

---
## üéµ G√âN√âRATION DE MUSIQUE AVEC PROMPTS PERSONNALIS√âS

Maintenant que le mod√®le est entra√Æn√©, vous pouvez g√©n√©rer de la musique en utilisant des prompts personnalis√©s.
Le mod√®le a appris le style Agbadja et peut cr√©er des variations bas√©es sur vos descriptions.

In [None]:
# üéº FONCTION DE G√âN√âRATION AVEC PROMPT PERSONNALIS√â
def generate_music(prompt, duration=10, temperature=0.8, top_k=200, cfg_coef=4.0, output_name="generation"):
    """
    G√©n√®re de la musique bas√©e sur un prompt personnalis√©
    
    Args:
        prompt (str): Description de la musique √† g√©n√©rer
        duration (int): Dur√©e en secondes (d√©faut: 10)
        temperature (float): Cr√©ativit√© (0.5=conservateur, 1.0=cr√©atif) (d√©faut: 0.8)
        top_k (int): Diversit√© des choix (100-250) (d√©faut: 200)
        cfg_coef (float): Force du prompt (3.0-6.0) (d√©faut: 4.0)
        output_name (str): Nom du fichier de sortie
    
    Returns:
        str: Chemin du fichier audio g√©n√©r√©
    """
    
    print("üéµ G√âN√âRATION DE MUSIQUE")
    print("=" * 70)
    print(f"üìù Prompt: {prompt}")
    print(f"‚è±Ô∏è  Dur√©e: {duration}s")
    print(f"üå°Ô∏è  Temperature: {temperature}")
    print(f"üé≤ Top-K: {top_k}")
    print(f"üéØ CFG Coefficient: {cfg_coef}")
    print("=" * 70 + "\n")
    
    # Mettre le mod√®le en mode √©valuation
    model.lm.eval()
    
    # Configurer les param√®tres de g√©n√©ration
    model.set_generation_params(
        duration=duration,
        temperature=temperature,
        top_k=top_k,
        top_p=0.0,
        cfg_coef=cfg_coef
    )
    
    try:
        with torch.no_grad():
            # G√©n√©rer
            print("‚è≥ G√©n√©ration en cours...")
            wav = model.generate(
                descriptions=[prompt],
                progress=True
            )
            
            # Sauvegarder
            output_path = os.path.join(OUTPUT_PATH, output_name)
            audio_write(
                output_path,
                wav[0].cpu(),
                model.sample_rate,
                strategy="loudness",
                loudness_compressor=True
            )
            
            output_file = output_path + '.wav'
            file_size_kb = os.path.getsize(output_file) / 1024
            
            print(f"\n‚úÖ SUCC√àS!")
            print(f"üìÅ Fichier: {output_name}.wav ({file_size_kb:.1f} KB)")
            print(f"üìÇ Dossier: {OUTPUT_PATH}/")
            
            # Nettoyer
            del wav
            gc.collect()
            
            return output_file
            
    except Exception as e:
        print(f"‚ùå ERREUR: {e}")
        import traceback
        traceback.print_exc()
        return None

print("‚úÖ Fonction generate_music() pr√™te")
print("\nüí° Exemple d'utilisation:")
print('   generate_music("traditional Agbadja drums with energetic rhythm", duration=10)')

In [None]:
# üé® EXEMPLES DE G√âN√âRATION

# Exemple 1: Style traditionnel pur
print("ü•Å Exemple 1: Style traditionnel Agbadja")
file1 = generate_music(
    prompt="traditional Agbadja percussion, West African ceremonial drums, djembe and talking drum ensemble",
    duration=10,
    temperature=0.7,
    output_name="agbadja_traditional"
)

# Exemple 2: Avec √©nergie accrue
print("\n" + "="*70 + "\n")
print("‚ö° Exemple 2: Style √©nergique")
file2 = generate_music(
    prompt="energetic Agbadja drums with powerful polyrhythmic",
    duration=10,
    temperature=0.8,
    output_name="agbadja_energetic"
)

# Exemple 3: Ambiance c√©r√©monielle
print("\n" + "="*70 + "\n")
print("üåô Exemple 3: Ambiance c√©r√©monielle")
file3 = generate_music(
    prompt="ceremonial Agbadja ritual drums, deep resonant tones, ancestral rhythm patterns",
    duration=10,
    temperature=0.7,
    output_name="agbadja_ceremonial"
)

print("\n" + "="*70)
print("‚úÖ 3 EXEMPLES G√âN√âR√âS")
print("="*70)
print(f"üìÅ Fichiers dans {OUTPUT_PATH}/")
print("   1. agbadja_traditional.wav")
print("   2. agbadja_energetic.wav")
print("   3. agbadja_ceremonial.wav")

In [None]:
# üéß √âCOUTER LES G√âN√âRATIONS
print("üéß √âCOUTE DES G√âN√âRATIONS\n")

# Lister tous les fichiers g√©n√©r√©s
generated_files = sorted(Path(OUTPUT_PATH).glob('*.wav'))

if len(generated_files) == 0:
    print("‚ùå Aucun fichier g√©n√©r√© trouv√©")
else:
    print(f"üìÅ {len(generated_files)} fichier(s) trouv√©(s) dans {OUTPUT_PATH}/\n")
    
    for i, filepath in enumerate(generated_files, 1):
        print(f"{'='*70}")
        print(f"üéµ {i}. {filepath.name}")
        print(f"{'='*70}")
        
        try:
            # Charger et afficher l'audio
            y, sr = librosa.load(str(filepath), sr=None)
            duration = librosa.get_duration(y=y, sr=sr)
            file_size = filepath.stat().st_size / 1024
            
            print(f"   Dur√©e: {duration:.2f}s")
            print(f"   Taille: {file_size:.1f} KB")
            print(f"   Sample Rate: {sr} Hz\n")
            
            display(Audio(y, rate=sr))
            
        except Exception as e:
            print(f"‚ùå Erreur de lecture: {e}\n")

In [None]:
# ‚ú® G√âN√âRATION PERSONNALIS√âE INTERACTIVE

print("‚ú® G√âN√âRATION PERSONNALIS√âE")
print("="*70)
print("Entrez votre propre prompt pour g√©n√©rer de la musique!")
print("="*70 + "\n")

# PERSONNALISEZ ICI VOTRE PROMPT
user_prompt = "traditional Agbadja drums with fast tempo and complex polyrhythms"

# PARAM√àTRES AJUSTABLES
duration = 10          # Dur√©e en secondes
temperature = 0.8      # 0.5 = conservateur, 1.0 = cr√©atif
top_k = 200           # 100-250, contr√¥le la diversit√©
cfg_coef = 4.0        # 3.0-6.0, force du prompt
output_name = "custom_generation"

print(f"üìù Votre prompt: {user_prompt}")
print(f"‚öôÔ∏è  Param√®tres:")
print(f"   - Dur√©e: {duration}s")
print(f"   - Temperature: {temperature}")
print(f"   - Top-K: {top_k}")
print(f"   - CFG Coef: {cfg_coef}\n")

# G√©n√©rer
generated_file = generate_music(
    prompt=user_prompt,
    duration=duration,
    temperature=temperature,
    top_k=top_k,
    cfg_coef=cfg_coef,
    output_name=output_name
)

# √âcouter
if generated_file:
    print("\nüéß √âcoute du r√©sultat:")
    y, sr = librosa.load(generated_file, sr=None)
    display(Audio(y, rate=sr))

---
## üìä VISUALISATION DES R√âSULTATS D'ENTRA√éNEMENT

In [None]:
# üìà COURBE D'ENTRA√éNEMENT
import matplotlib.pyplot as plt

if training_log and len(training_log.get('epochs', [])) > 0:
    losses = [e['avg_loss'] for e in training_log['epochs']]
    epochs_nums = [e['epoch'] for e in training_log['epochs']]
    success_rates = [e['successful_batches']/e['total_batches']*100 for e in training_log['epochs']]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Courbe de loss
    ax1.plot(epochs_nums, losses, marker='o', linewidth=2, markersize=8, color='#2E86AB')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('√âvolution de la Loss', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    
    # Taux de r√©ussite
    ax2.plot(epochs_nums, success_rates, marker='s', linewidth=2, markersize=8, color='#06A77D')
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Taux de r√©ussite (%)', fontsize=12)
    ax2.set_title('Taux de r√©ussite des batches', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 105)
    
    plt.tight_layout()
    
    plot_path = os.path.join(OUTPUT_PATH, 'training_curves.png')
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\nüìä Statistiques d'entra√Ænement:")
    print(f"   Loss initiale: {losses[0]:.4f}")
    print(f"   Loss finale: {losses[-1]:.4f}")
    if losses[0] > 0:
        improvement = ((losses[0] - losses[-1]) / losses[0]) * 100
        print(f"   Am√©lioration: {improvement:.2f}%")
    print(f"   Taux de r√©ussite final: {success_rates[-1]:.1f}%")
    print(f"\nüíæ Graphique sauvegard√©: {plot_path}")
else:
    print("‚ö†Ô∏è  Aucune donn√©e d'entra√Ænement disponible")
    print("üí° Ex√©cutez d'abord la cellule d'entra√Ænement")

---
## üí° GUIDE D'UTILISATION

### üöÄ Comment utiliser ce notebook:

1. **Ex√©cutez les cellules 1-7** pour charger le dataset et le mod√®le
2. **Cellule 8**: Lance l'entra√Ænement (peut prendre plusieurs heures)
3. **Cellule 10-11**: G√©n√®re des exemples avec prompts pr√©d√©finis
4. **Cellule 12**: Personnalisez votre prompt pour cr√©er votre propre musique

### üé® Conseils pour les prompts:

**√âl√©ments √† inclure:**
- Style musical: "traditional Agbadja", "ceremonial drums"
- Instruments: "djembe", "talking drums", "dundun"
- Tempo: "fast", "slow", "steady"
- Ambiance: "energetic", "calm", "powerful", "ritual"
- Patterns: "polyrhythmic", "complex patterns", "simple rhythm"

**Exemples de bons prompts:**
```
"traditional Agbadja percussion with fast polyrhythmic djembe patterns"
"ceremonial West African drums, deep dundun bass with talking drum accents"
"energetic Agbadja ensemble, complex polyrhythms, powerful ceremonial energy"
```

### ‚öôÔ∏è Param√®tres de g√©n√©ration:

- **temperature** (0.5-1.0): Plus bas = plus fid√®le au dataset, plus haut = plus cr√©atif
- **top_k** (100-250): Contr√¥le la diversit√© des choix
- **cfg_coef** (3.0-6.0): Force avec laquelle le mod√®le suit le prompt
- **duration**: Dur√©e en secondes (recommand√©: 10-30s)

### üìÇ Fichiers g√©n√©r√©s:

- **Mod√®le**: `./musicgen_agbadja_model/final_model.pt`
- **Checkpoints**: `./musicgen_agbadja_model/checkpoint_epoch_X.pt`
- **G√©n√©rations**: `./musicgen_outputs/*.wav`
- **Log**: `./training_log.json`