In [1]:
# Cell 1: Imports et configuration
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import random
from sklearn.metrics import confusion_matrix, classification_report, f1_score, accuracy_score, precision_score, recall_score
import timm
import warnings
warnings.filterwarnings('ignore')

# Configuration optimis√©e
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
print(f"üî• Using device: {device}")

# V√©rification des donn√©es
print("üìÅ Checking data...")
data_path = r"D:\voice_processing\data\SpeechCommands\speech_commands_v0.02"
if os.path.exists(data_path):
    print(f"‚úÖ Data folder found: {data_path}")
    classes = [d for d in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, d)) and not d.startswith('_')]
    print(f"üéØ Detected classes: {len(classes)}")
    print(f"üìù Classes: {classes[:10]}...")
else:
    print("‚ùå Data folder not found! Automatic download will occur...")

# Configuration du mod√®le
class ModelConfig:
    def __init__(self):
        # Dropout configuration
        self.input_dropout = 0.1
        self.feature_dropout = 0.4
        self.classifier_dropout = 0.3
        
        # Training configuration
        self.weight_decay = 0.01
        self.label_smoothing = 0.1
        self.learning_rate = 1e-3
        self.batch_size = 16
        self.num_epochs = 100  # Augment√© pour permettre √† l'early stopping de fonctionner
        
        # Audio configuration
        self.n_mels = 128
        self.n_fft = 1024
        self.hop_length = 256
        self.sample_rate = 16000
        self.target_length = 16000  # 1 second

config = ModelConfig()
print("‚öôÔ∏è Model Configuration:")
for key, value in config.__dict__.items():
    print(f"   {key}: {value}")

üî• Using device: cuda
üìÅ Checking data...
‚úÖ Data folder found: D:\voice_processing\data\SpeechCommands\speech_commands_v0.02
üéØ Detected classes: 35
üìù Classes: ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward']...
‚öôÔ∏è Model Configuration:
   input_dropout: 0.1
   feature_dropout: 0.4
   classifier_dropout: 0.3
   weight_decay: 0.01
   label_smoothing: 0.1
   learning_rate: 0.001
   batch_size: 16
   num_epochs: 100
   n_mels: 128
   n_fft: 1024
   hop_length: 256
   sample_rate: 16000
   target_length: 16000


In [2]:
# Cell 2: Early Stopping am√©lior√© et Augmentation
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.best_acc = None
        self.counter = 0
        self.early_stop = False
        self.best_model_state = None
        self.best_epoch = 0
        
    def __call__(self, val_loss, val_acc, model, epoch):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_acc = val_acc
            self.best_model_state = model.state_dict().copy()
            self.best_epoch = epoch
            self.counter = 0
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            print(f"üîÑ EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                print("üõë Early stopping triggered!")
        else:
            self.best_loss = val_loss
            self.best_acc = val_acc
            self.best_model_state = model.state_dict().copy()
            self.best_epoch = epoch
            self.counter = 0
            print("‚úÖ New best model saved!")
        
        return self.early_stop
    
    def restore_best_model(self, model):
        if self.restore_best_weights and self.best_model_state is not None:
            model.load_state_dict(self.best_model_state)
            print(f"‚úÖ Best weights restored from epoch {self.best_epoch}!")
            print(f"üèÜ Best validation accuracy: {self.best_acc:.2f}%")

class AudioAugmentation:
    def __init__(self):
        self.time_stretch = T.TimeStretch()
        self.pitch_shift = T.PitchShift(sample_rate=config.sample_rate, n_steps=4)
        
    def __call__(self, waveform):
        # Time stretching
        if random.random() > 0.6:
            rate = random.uniform(0.85, 1.15)
            try:
                waveform = self.time_stretch(waveform, rate)
            except:
                pass
        
        # Pitch shifting
        if random.random() > 0.6:
            try:
                waveform = self.pitch_shift(waveform)
            except:
                pass
        
        # Gaussian noise
        if random.random() > 0.7:
            noise = torch.randn_like(waveform) * 0.005
            waveform = waveform + noise
        
        # Random gain
        if random.random() > 0.5:
            gain = random.uniform(0.8, 1.2)
            waveform = waveform * gain
            
        return waveform

In [3]:
# Cell 3: Dataset et DataLoader (MODIFI√â)
class SpeechCommandsDataset(Dataset):
    def __init__(self, subset='training', apply_augmentation=False):
        # Utilisez votre chemin local existant au lieu de t√©l√©charger
        self.data_path = r"D:\voice_processing\data\SpeechCommands\speech_commands_v0.02"
        
        # Chargez manuellement les fichiers selon le subset
        if subset == 'training':
            # Pour l'entra√Ænement, utilisez tous les fichiers SAUF ceux dans les listes de validation/test
            with open(os.path.join(self.data_path, 'validation_list.txt'), 'r') as f:
                val_files = set(f.read().splitlines())
            with open(os.path.join(self.data_path, 'testing_list.txt'), 'r') as f:
                test_files = set(f.read().splitlines())
            
            all_files = []
            for class_name in os.listdir(self.data_path):
                class_path = os.path.join(self.data_path, class_name)
                if os.path.isdir(class_path) and not class_name.startswith('_'):
                    for file in os.listdir(class_path):
                        if file.endswith('.wav'):
                            rel_path = os.path.join(class_name, file)
                            if rel_path not in val_files and rel_path not in test_files:
                                all_files.append((os.path.join(class_path, file), class_name))
            
            self.samples = all_files
            
        elif subset == 'validation':
            # Pour la validation, utilisez validation_list.txt
            with open(os.path.join(self.data_path, 'validation_list.txt'), 'r') as f:
                val_files = f.read().splitlines()
            
            self.samples = []
            for rel_path in val_files:
                class_name = rel_path.split('/')[0]
                full_path = os.path.join(self.data_path, rel_path)
                if os.path.exists(full_path):
                    self.samples.append((full_path, class_name))
                    
        elif subset == 'testing':
            # Pour le test, utilisez testing_list.txt
            with open(os.path.join(self.data_path, 'testing_list.txt'), 'r') as f:
                test_files = f.read().splitlines()
            
            self.samples = []
            for rel_path in test_files:
                class_name = rel_path.split('/')[0]
                full_path = os.path.join(self.data_path, rel_path)
                if os.path.exists(full_path):
                    self.samples.append((full_path, class_name))
        
        # R√©cup√©rer toutes les classes
        self.labels = sorted(list(set([label for _, label in self.samples])))
        self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)}
        self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}
        self.apply_augmentation = apply_augmentation
        self.augment = AudioAugmentation()
        
        print(f"üìä {subset}: {len(self.samples)} samples, {len(self.labels)} classes")
        
        # Transformations audio (garder le reste du code inchang√©)
        self.transform = T.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            n_mels=config.n_mels,
            f_min=20,
            f_max=8000
        )
        self.to_db = T.AmplitudeToDB()
        
        # SpecAugment
        self.time_mask = T.TimeMasking(time_mask_param=20)
        self.freq_mask = T.FrequencyMasking(freq_mask_param=10)
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        file_path, label = self.samples[idx]
        
        # Charger l'audio
        waveform, sample_rate = torchaudio.load(file_path)
        
        # Resampling
        if sample_rate != config.sample_rate:
            resampler = T.Resample(sample_rate, config.sample_rate)
            waveform = resampler(waveform)
        
        # Le reste du code __getitem__ reste inchang√©...
        # Normalisation
        waveform = waveform / (waveform.abs().max() + 1e-8)
        
        # Augmentation
        if self.apply_augmentation and random.random() > 0.5:
            waveform = self.augment(waveform)
        
        # Padding/truncation
        current_length = waveform.shape[1]
        if current_length < config.target_length:
            waveform = F.pad(waveform, (0, config.target_length - current_length))
        else:
            if self.apply_augmentation:
                start = random.randint(0, current_length - config.target_length)
            else:
                start = (current_length - config.target_length) // 2
            waveform = waveform[:, start:start + config.target_length]
        
        # Mel Spectrogram
        mel_spec = self.transform(waveform)
        mel_spec_db = self.to_db(mel_spec)
        
        # SpecAugment pour l'entra√Ænement
        if self.apply_augmentation:
            mel_spec_db = self.time_mask(mel_spec_db)
            mel_spec_db = self.freq_mask(mel_spec_db)
        
        # Normalisation
        mel_spec_db = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-8)
        
        # 3 channels pour MobileNetV3
        mel_spec_3ch = mel_spec_db.repeat(3, 1, 1)
        
        label_idx = self.label_to_idx[label]
        
        return mel_spec_3ch, label_idx

In [4]:
# Cell 4: Version alternative plus simple
class MobileNetV3AudioClassifier(nn.Module):
    def __init__(self, num_classes=35, freeze_layers=5):
        super().__init__()
        
        # Backbone MobileNetV3 Large pr√©-entra√Æn√©
        self.backbone = timm.create_model(
            'mobilenetv3_large_100',
            pretrained=True,
            in_chans=3,
            num_classes=num_classes  # Utiliser le classificateur original
        )
        
        # Geler les premi√®res couches
        self._freeze_layers(freeze_layers)
        
        print(f"‚úÖ MobileNetV3 Large cr√©√© avec {freeze_layers} couches gel√©es")
        
    def _freeze_layers(self, num_layers):
        """Geler les premi√®res couches du backbone"""
        layers_frozen = 0
        
        for name, param in self.backbone.named_parameters():
            if 'conv' in name or 'blocks' in name:
                if layers_frozen < num_layers:
                    param.requires_grad = False
                    layers_frozen += 1
                else:
                    param.requires_grad = True
        
        print(f"‚ùÑÔ∏è {layers_frozen} couches gel√©es")
        
    def forward(self, x):
        return self.backbone(x)

In [5]:
# Cell 5: Fonctions de perte et entra√Ænement
class LabelSmoothCrossEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
    
    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -log_probs.mean(dim=-1)
        loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
        return loss.mean()

def train_epoch(model, loader, criterion, optimizer, device):
    """Entra√Ænement pour une epoch"""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc='üöÄ Training')
    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        
        accuracy = 100. * correct / total
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Accuracy': f'{accuracy:.2f}%'
        })
    
    return total_loss / len(loader), 100. * correct / total

def validate(model, loader, criterion, device):
    """Validation"""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in tqdm(loader, desc='üìä Validation'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    return total_loss / len(loader), 100. * correct / total



In [6]:
# Cell 6: Chargement des donn√©es et cr√©ation du mod√®le
print("üîß Chargement des datasets...")

# Mode rapide pour les tests
FAST_MODE = True
SUBSET_SIZE = 5000 if FAST_MODE else None

# Chargement des datasets
train_dataset = SpeechCommandsDataset(subset='training', apply_augmentation=True)
val_dataset = SpeechCommandsDataset(subset='validation', apply_augmentation=False)
test_dataset = SpeechCommandsDataset(subset='testing', apply_augmentation=False)

num_classes = len(train_dataset.labels)

# Ajustement de la batch size selon le mode
batch_size = 32 if FAST_MODE else config.batch_size

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

print("‚úÖ Donn√©es charg√©es avec succ√®s!")
print(f"üìä R√©sum√©:")
print(f"   - Mode: {'RAPIDE (Test)' if FAST_MODE else 'COMPLET (Entra√Ænement)'}")
print(f"   - Batch size: {batch_size}")
print(f"   - √âchantillons d'entra√Ænement: {len(train_dataset):,}")
print(f"   - Classes: {num_classes}")

# Cr√©ation du mod√®le avec gel des 5 premi√®res couches
model = MobileNetV3AudioClassifier(num_classes=num_classes, freeze_layers=5).to(device)

# Affichage des informations du mod√®le
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params

print(f"\nüìä Informations du mod√®le:")
print(f"   - Param√®tres totaux: {total_params:,}")
print(f"   - Param√®tres entra√Ænables: {trainable_params:,}")
print(f"   - Param√®tres gel√©s: {frozen_params:,}")
print(f"   - Pourcentage gel√©: {frozen_params/total_params*100:.1f}%")
print(f"   - Taux de dropout: {config.feature_dropout}")


üîß Chargement des datasets...
üìä training: 105829 samples, 35 classes
üìä validation: 9981 samples, 35 classes
üìä testing: 11005 samples, 35 classes
‚úÖ Donn√©es charg√©es avec succ√®s!
üìä R√©sum√©:
   - Mode: RAPIDE (Test)
   - Batch size: 32
   - √âchantillons d'entra√Ænement: 105,829
   - Classes: 35
‚ùÑÔ∏è 5 couches gel√©es
‚úÖ MobileNetV3 Large cr√©√© avec 5 couches gel√©es

üìä Informations du mod√®le:
   - Param√®tres totaux: 4,246,867
   - Param√®tres entra√Ænables: 4,246,003
   - Param√®tres gel√©s: 864
   - Pourcentage gel√©: 0.0%
   - Taux de dropout: 0.4


In [7]:
# Cell 7: ENTRA√éNEMENT AVEC EARLY STOPPING ET SAUVEGARDE
print("üöÄ D√©marrage de l'entra√Ænement avec MobileNetV3...")

# Fonction de perte avec lissage des labels
criterion = LabelSmoothCrossEntropyLoss(smoothing=config.label_smoothing)

# Optimizer avec weight decay (uniquement les param√®tres entra√Ænables)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
    betas=(0.9, 0.999)
)

# Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.num_epochs)

# Early Stopping am√©lior√©
early_stopping = EarlyStopping(patience=10, min_delta=0.002, restore_best_weights=True)

# Historique d'entra√Ænement
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'learning_rates': []
}

best_val_acc = 0
best_epoch = 0

print(f"\nüéØ D√©marrage de l'entra√Ænement pour {config.num_epochs} epochs...")
print(f"‚è∞ Early Stopping: Patience = {early_stopping.patience} epochs")

for epoch in range(config.num_epochs):
    print(f"\nüìç Epoch {epoch+1}/{config.num_epochs}")
    
    # Phase d'entra√Ænement
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Phase de validation
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    # Mise √† jour de l'historique
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['learning_rates'].append(optimizer.param_groups[0]['lr'])
    
    print(f"  Train ‚Üí Loss: {train_loss:.4f} | Accuracy: {train_acc:.2f}%")
    print(f"  Val   ‚Üí Loss: {val_loss:.4f} | Accuracy: {val_acc:.2f}%")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")
    
    # V√©rification Early Stopping
    if early_stopping(val_loss, val_acc, model, epoch+1):
        print(f"\nüõë Early Stopping d√©clench√© √† l'epoch {epoch + 1}!")
        break
    
    # Sauvegarde du meilleur mod√®le (backup suppl√©mentaire)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'history': history
        }, 'best_mobilenetv3_model_complete.pth')
        print(f"  üíæ Mod√®le complet sauvegard√©! (Accuracy: {val_acc:.2f}%)")
    
    # Ajustement du learning rate
    scheduler.step()

# Restauration du meilleur mod√®le
early_stopping.restore_best_model(model)

print(f"\nüéâ Entra√Ænement termin√©!")
print(f"üèÜ Meilleure accuracy de validation: {best_val_acc:.2f}% √† l'epoch {best_epoch}")

# Sauvegarde finale du mod√®le restaur√©
torch.save(model.state_dict(), 'final_mobilenetv3_model.pth')
print("üíæ Mod√®le final sauvegard√©!")


üöÄ D√©marrage de l'entra√Ænement avec MobileNetV3...

üéØ D√©marrage de l'entra√Ænement pour 100 epochs...
‚è∞ Early Stopping: Patience = 10 epochs

üìç Epoch 1/100


üöÄ Training:   0%|          | 9/3308 [00:53<5:27:34,  5.96s/it, Loss=4.7712, Accuracy=3.82%]


KeyboardInterrupt: 

In [None]:
# Cell 8: √âvaluation et m√©triques d√©taill√©es
print("üìä √âvaluation du mod√®le final...")

# √âvaluation sur le test set
test_loss, test_acc = validate(model, test_loader, criterion, device)
print(f"üéØ Performance sur le test set:")
print(f"   - Loss: {test_loss:.4f}")
print(f"   - Accuracy: {test_acc:.2f}%")

# M√©triques d√©taill√©es
model.eval()
all_preds = []
all_targets = []
all_probabilities = []

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        probabilities = F.softmax(output, dim=1)
        preds = output.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(target.cpu().numpy())
        all_probabilities.extend(probabilities.cpu().numpy())

# Calcul des m√©triques
accuracy = accuracy_score(all_targets, all_preds)
f1 = f1_score(all_targets, all_preds, average='macro')
precision = precision_score(all_targets, all_preds, average='macro')
recall = recall_score(all_targets, all_preds, average='macro')

print(f"\nüìà M√©triques d√©taill√©es:")
print(f"   - Accuracy: {accuracy:.4f}")
print(f"   - F1-Score: {f1:.4f}")
print(f"   - Precision: {precision:.4f}")
print(f"   - Recall: {recall:.4f}")

# Rapport de classification d√©taill√©
print(f"\nüìã Rapport de classification:")
print(classification_report(all_targets, all_preds, target_names=train_dataset.labels))


In [None]:
# Cell 9: Export ONNX du meilleur mod√®le
print("\nüì§ Export du mod√®le en format ONNX...")

# Assurer que le mod√®le est en mode √©valuation
model.eval()

# Cr√©er un exemple d'input
dummy_input = torch.randn(1, 3, config.n_mels, config.target_length // config.hop_length + 1).to(device)

# Export ONNX
onnx_path = "mobilenetv3_speech_commands.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    verbose=False
)

print(f"‚úÖ Mod√®le export√© avec succ√®s: {onnx_path}")
print(f"üìä Taille du fichier ONNX: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB")

# V√©rification de l'export ONNX
try:
    import onnx
    import onnxruntime as ort
    
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("‚úÖ Mod√®le ONNX v√©rifi√© avec succ√®s!")
    
    # Test avec ONNX Runtime
    ort_session = ort.InferenceSession(onnx_path)
    
    # Test de pr√©diction
    dummy_np = dummy_input.cpu().numpy()
    ort_inputs = {ort_session.get_inputs()[0].name: dummy_np}
    ort_outs = ort_session.run(None, ort_inputs)
    
    print("‚úÖ ONNX Runtime test r√©ussi!")
    print(f"üìã Informations ONNX:")
    print(f"   - Input shape: {ort_session.get_inputs()[0].shape}")
    print(f"   - Output shape: {ort_session.get_outputs()[0].shape}")
    print(f"   - Opset version: {onnx_model.opset_import[0].version}")
    
except ImportError:
    print("‚ö†Ô∏è ONNX non install√©, impossible de v√©rifier le mod√®le")
    print("üí° Installer avec: pip install onnx onnxruntime")


In [None]:
# Cell 10: Visualisations et sauvegarde finale
print("\nüé® Cr√©ation des visualisations...")

# Graphiques de performance
plt.figure(figsize=(15, 5))

# Loss
plt.subplot(1, 3, 1)
plt.plot(history['train_loss'], label='Train Loss', linewidth=2, alpha=0.8)
plt.plot(history['val_loss'], label='Val Loss', linewidth=2, alpha=0.8)
plt.axvline(x=best_epoch-1, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch})')
plt.title('√âvolution des Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

# Accuracy
plt.subplot(1, 3, 2)
plt.plot(history['train_acc'], label='Train Accuracy', linewidth=2, alpha=0.8)
plt.plot(history['val_acc'], label='Val Accuracy', linewidth=2, alpha=0.8)
plt.axvline(x=best_epoch-1, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch})')
plt.title('√âvolution de l\'Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True, alpha=0.3)

# Learning Rate
plt.subplot(1, 3, 3)
plt.plot(history['learning_rates'], label='Learning Rate', linewidth=2, color='purple', alpha=0.8)
plt.title('√âvolution du Learning Rate')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.yscale('log')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_metrics_mobilenetv3.png', dpi=300, bbox_inches='tight')
plt.show()

# Matrice de confusion
cm = confusion_matrix(all_targets, all_preds)
plt.figure(figsize=(12, 10))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title(f'Matrice de Confusion - MobileNetV3\nAccuracy: {test_acc:.2f}%')
plt.colorbar()

# Afficher seulement quelques labels pour la lisibilit√©
if len(train_dataset.labels) > 20:
    tick_marks = np.arange(0, len(train_dataset.labels), max(1, len(train_dataset.labels)//20))
    plt.xticks(tick_marks, [train_dataset.labels[i] for i in tick_marks], rotation=45)
    plt.yticks(tick_marks, [train_dataset.labels[i] for i in tick_marks])
else:
    tick_marks = np.arange(len(train_dataset.labels))
    plt.xticks(tick_marks, train_dataset.labels, rotation=45)
    plt.yticks(tick_marks, train_dataset.labels)

plt.xlabel('Pr√©diction')
plt.ylabel('V√©rit√© terrain')
plt.tight_layout()
plt.savefig('confusion_matrix_mobilenetv3.png', dpi=300, bbox_inches='tight')
plt.show()

# Sauvegarde de l'historique
import json
history_serializable = {k: [float(x) if isinstance(x, (np.floating, float)) else x for x in v] 
                       for k, v in history.items()}

with open('training_history.json', 'w') as f:
    json.dump(history_serializable, f, indent=2)

print("\n‚úÖ Toutes les op√©rations sont termin√©es!")
print("üìÅ Fichiers cr√©√©s:")
print(f"   - best_mobilenetv3_model_complete.pth (mod√®le complet avec historique)")
print(f"   - final_mobilenetv3_model.pth (mod√®le final)")
print(f"   - mobilenetv3_speech_commands.onnx (mod√®le ONNX)")
print(f"   - training_metrics_mobilenetv3.png (graphiques)")
print(f"   - confusion_matrix_mobilenetv3.png (matrice de confusion)")
print(f"   - training_history.json (historique d'entra√Ænement)")

print(f"\nüéØ R√©sultats finaux:")
print(f"   - Best Validation Accuracy: {best_val_acc:.2f}%")
print(f"   - Test Accuracy: {test_acc:.2f}%")
print(f"   - Best Epoch: {best_epoch}")
print(f"   - Total Epochs: {epoch + 1}")