In [1]:
# =============================================================================
# CELL 1: Configuration Optimis√©e
# =============================================================================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchaudio
import torchaudio.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import random
from sklearn.metrics import confusion_matrix, classification_report, f1_score
from sklearn.utils.class_weight import compute_class_weight
import timm
import warnings
warnings.filterwarnings('ignore')

# Configuration GPU optimis√©e
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
print(f"üî• Using device: {device}")

# Seed pour reproductibilit√©
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# Configuration optimis√©e
class ModelConfig:
    def __init__(self):
        # Dropout configuration - OPTIMIS√â
        self.input_dropout = 0.15
        self.feature_dropout = 0.35  # R√©duit pour √©viter l'underfitting
        self.classifier_dropout = 0.4
        
        # Training configuration - OPTIMIS√â
        self.weight_decay = 0.005  # R√©duit pour plus de flexibilit√©
        self.label_smoothing = 0.1
        self.learning_rate = 5e-4  # Learning rate initial r√©duit
        self.batch_size = 32  # Augment√© pour stabilit√©
        self.num_epochs = 100
        self.warmup_epochs = 5  # NOUVEAU: Warm-up du LR
        
        # Audio configuration - OPTIMIS√â
        self.n_mels = 128
        self.n_fft = 2048  # Augment√© pour meilleure r√©solution
        self.hop_length = 512  # Ajust√© proportionnellement
        self.sample_rate = 16000
        self.target_length = 16000
        
        # Mixed Precision Training
        self.use_amp = True  # NOUVEAU: Automatic Mixed Precision
        
        # Data augmentation strength
        self.aug_strength = 0.6  # NOUVEAU: Contr√¥le de l'augmentation

config = ModelConfig()

print("\n‚öôÔ∏è Configuration Optimis√©e:")
for key, value in config.__dict__.items():
    print(f"   {key}: {value}")


üî• Using device: cuda

‚öôÔ∏è Configuration Optimis√©e:
   input_dropout: 0.15
   feature_dropout: 0.35
   classifier_dropout: 0.4
   weight_decay: 0.005
   label_smoothing: 0.1
   learning_rate: 0.0005
   batch_size: 32
   num_epochs: 100
   warmup_epochs: 5
   n_mels: 128
   n_fft: 2048
   hop_length: 512
   sample_rate: 16000
   target_length: 16000
   use_amp: True
   aug_strength: 0.6


In [2]:
# =============================================================================
# CELL 2: Early Stopping et Augmentation Am√©lior√©s
# =============================================================================
class EarlyStopping:
    def __init__(self, patience=15, min_delta=0.0005, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.best_acc = None
        self.counter = 0
        self.early_stop = False
        self.best_model_state = None
        self.best_epoch = 0
        
    def __call__(self, val_loss, val_acc, model, epoch):
        # Crit√®re combin√©: am√©lioration de la loss OU de l'accuracy
        improved = False
        
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_acc = val_acc
            self.best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            self.best_epoch = epoch
            improved = True
        else:
            # Am√©lioration si la loss diminue OU l'accuracy augmente significativement
            if val_loss < self.best_loss - self.min_delta or val_acc > self.best_acc + 0.5:
                self.best_loss = val_loss
                self.best_acc = val_acc
                self.best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                self.best_epoch = epoch
                self.counter = 0
                improved = True
            else:
                self.counter += 1
        
        if improved:
            print("‚úÖ New best model saved!")
        else:
            print(f"üìà EarlyStopping counter: {self.counter}/{self.patience}")
            
        if self.counter >= self.patience:
            self.early_stop = True
            print("üõë Early stopping triggered!")
            
        return self.early_stop
    
    def restore_best_model(self, model):
        if self.restore_best_weights and self.best_model_state is not None:
            model.load_state_dict(self.best_model_state)
            print(f"‚úÖ Best weights restored from epoch {self.best_epoch}!")
            print(f"üèÜ Best validation accuracy: {self.best_acc:.2f}%")

class AudioAugmentation:
    """Augmentation audio plus contr√¥l√©e"""
    def __init__(self, strength=0.6):
        self.strength = strength
        self.sample_rate = config.sample_rate
        
    def __call__(self, waveform):
        if random.random() > 1 - self.strength:
            # Time stretching (plus subtil)
            if random.random() > 0.5:
                rate = random.uniform(0.9, 1.1)
                try:
                    waveform_stretched = F.interpolate(
                        waveform.unsqueeze(0),
                        size=int(waveform.shape[-1] * rate),
                        mode='linear',
                        align_corners=False
                    ).squeeze(0)
                    
                    # Recadrage/padding
                    if waveform_stretched.shape[-1] > waveform.shape[-1]:
                        waveform = waveform_stretched[..., :waveform.shape[-1]]
                    else:
                        waveform = F.pad(waveform_stretched, 
                                       (0, waveform.shape[-1] - waveform_stretched.shape[-1]))
                except:
                    pass
            
            # Gaussian noise (plus l√©ger)
            if random.random() > 0.6:
                noise_level = random.uniform(0.001, 0.005)
                noise = torch.randn_like(waveform) * noise_level
                waveform = waveform + noise
            
            # Random gain (plus subtil)
            if random.random() > 0.4:
                gain = random.uniform(0.85, 1.15)
                waveform = waveform * gain
            
            # Polarity inversion
            if random.random() > 0.8:
                waveform = -waveform
                
        return waveform


In [3]:
# =============================================================================
# CELL 3: Dataset avec Weighted Sampling
# =============================================================================
class SpeechCommandsDataset(Dataset):
    def __init__(self, subset='training', apply_augmentation=False):
        self.data_path = r"D:\voice_processing\data\SpeechCommands\speech_commands_v0.02"
        
        # Chargement des fichiers selon le subset
        if subset == 'training':
            with open(os.path.join(self.data_path, 'validation_list.txt'), 'r') as f:
                val_files = set(f.read().splitlines())
            with open(os.path.join(self.data_path, 'testing_list.txt'), 'r') as f:
                test_files = set(f.read().splitlines())
            
            all_files = []
            for class_name in os.listdir(self.data_path):
                class_path = os.path.join(self.data_path, class_name)
                if os.path.isdir(class_path) and not class_name.startswith('_'):
                    for file in os.listdir(class_path):
                        if file.endswith('.wav'):
                            rel_path = os.path.join(class_name, file)
                            if rel_path not in val_files and rel_path not in test_files:
                                all_files.append((os.path.join(class_path, file), class_name))
            
            self.samples = all_files
            
        elif subset == 'validation':
            with open(os.path.join(self.data_path, 'validation_list.txt'), 'r') as f:
                val_files = f.read().splitlines()
            
            self.samples = []
            for rel_path in val_files:
                class_name = rel_path.split('/')[0]
                full_path = os.path.join(self.data_path, rel_path)
                if os.path.exists(full_path):
                    self.samples.append((full_path, class_name))
                    
        elif subset == 'testing':
            with open(os.path.join(self.data_path, 'testing_list.txt'), 'r') as f:
                test_files = f.read().splitlines()
            
            self.samples = []
            for rel_path in test_files:
                class_name = rel_path.split('/')[0]
                full_path = os.path.join(self.data_path, rel_path)
                if os.path.exists(full_path):
                    self.samples.append((full_path, class_name))
        
        # Classes et mapping
        self.labels = sorted(list(set([label for _, label in self.samples])))
        self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)}
        self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}
        self.apply_augmentation = apply_augmentation
        self.augment = AudioAugmentation(strength=config.aug_strength)
        
        print(f"üìä {subset}: {len(self.samples)} samples, {len(self.labels)} classes")
        
        # Transformations audio
        self.transform = T.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            n_mels=config.n_mels,
            f_min=20,
            f_max=8000,
            power=2.0
        )
        self.to_db = T.AmplitudeToDB(stype='power', top_db=80)
        
        # SpecAugment avec param√®tres optimis√©s
        self.time_mask = T.TimeMasking(time_mask_param=15)
        self.freq_mask = T.FrequencyMasking(freq_mask_param=8)
    
    def get_class_weights(self):
        """Calcule les poids des classes pour l'√©quilibrage"""
        labels = [self.label_to_idx[label] for _, label in self.samples]
        class_weights = compute_class_weight(
            class_weight='balanced',
            classes=np.unique(labels),
            y=labels
        )
        return torch.FloatTensor(class_weights)
    
    def get_sample_weights(self):
        """Retourne les poids pour chaque √©chantillon"""
        class_weights = self.get_class_weights()
        sample_weights = []
        for _, label in self.samples:
            class_idx = self.label_to_idx[label]
            sample_weights.append(class_weights[class_idx])
        return sample_weights
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        file_path, label = self.samples[idx]
        
        # Charger l'audio
        waveform, sample_rate = torchaudio.load(file_path)
        
        # Resampling si n√©cessaire
        if sample_rate != config.sample_rate:
            resampler = T.Resample(sample_rate, config.sample_rate)
            waveform = resampler(waveform)
        
        # Mono si st√©r√©o
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        
        # Normalisation robuste
        waveform = waveform / (waveform.abs().max() + 1e-8)
        
        # Augmentation
        if self.apply_augmentation and random.random() > 0.3:
            waveform = self.augment(waveform)
        
        # Padding/truncation
        current_length = waveform.shape[1]
        if current_length < config.target_length:
            waveform = F.pad(waveform, (0, config.target_length - current_length))
        else:
            if self.apply_augmentation:
                start = random.randint(0, current_length - config.target_length)
            else:
                start = (current_length - config.target_length) // 2
            waveform = waveform[:, start:start + config.target_length]
        
        # Mel Spectrogram
        mel_spec = self.transform(waveform)
        mel_spec_db = self.to_db(mel_spec)
        
        # SpecAugment pour l'entra√Ænement
        if self.apply_augmentation and random.random() > 0.5:
            mel_spec_db = self.time_mask(mel_spec_db)
            if random.random() > 0.5:
                mel_spec_db = self.freq_mask(mel_spec_db)
        
        # Normalisation par instance
        mel_spec_db = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-8)
        
        # 3 channels pour MobileNetV3
        mel_spec_3ch = mel_spec_db.repeat(3, 1, 1)
        
        label_idx = self.label_to_idx[label]
        
        return mel_spec_3ch, label_idx

In [4]:
# =============================================================================
# CELL 4: Mod√®le Optimis√© avec Attention
# =============================================================================
class SqueezeExcitation(nn.Module):
    """Module SE pour am√©liorer les features"""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)
        
    def forward(self, x):
        b, c, _, _ = x.size()
        y = F.adaptive_avg_pool2d(x, 1).view(b, c)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(b, c, 1, 1)
        return x * y

class MobileNetV3AudioClassifier(nn.Module):
    def __init__(self, num_classes=35, freeze_layers=3):
        super().__init__()
        
        # Backbone MobileNetV3 Large pr√©-entra√Æn√©
        self.backbone = timm.create_model(
            'mobilenetv3_large_100',
            pretrained=True,
            in_chans=3,
            num_classes=0,  # Pas de classifieur
            global_pool=''
        )
        
        # Obtenir la dimension des features
        with torch.no_grad():
            dummy = torch.randn(1, 3, 128, 32)
            features = self.backbone(dummy)
            feature_dim = features.shape[1]
        
        # Geler les premi√®res couches
        self._freeze_layers(freeze_layers)
        
        # Attention suppl√©mentaire
        self.se = SqueezeExcitation(feature_dim, reduction=16)
        
        # Pooling adaptatif
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Classifieur optimis√©
        self.classifier = nn.Sequential(
            nn.Dropout(config.classifier_dropout),
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        
        print(f"‚úÖ MobileNetV3 cr√©√© avec {freeze_layers} couches gel√©es")
        
    def _freeze_layers(self, num_layers):
        """Geler les premi√®res couches du backbone"""
        layers_frozen = 0
        
        for name, param in self.backbone.named_parameters():
            if 'conv_stem' in name or 'bn1' in name or 'blocks.0' in name or 'blocks.1' in name:
                if layers_frozen < num_layers:
                    param.requires_grad = False
                    layers_frozen += 1
        
        print(f"‚ùÑÔ∏è {layers_frozen} couches initiales gel√©es")
        
    def forward(self, x):
        # Extraction des features
        x = self.backbone(x)
        
        # Attention
        x = self.se(x)
        
        # Pooling et classification
        x = self.global_pool(x)
        x = x.flatten(1)
        x = self.classifier(x)
        
        return x


In [5]:
# =============================================================================
# CELL 5: Training avec Mixed Precision et Warm-up
# =============================================================================
class LabelSmoothCrossEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.1, weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.weight = weight
    
    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        
        if self.weight is not None:
            nll_loss = nll_loss * self.weight[target]
        
        smooth_loss = -log_probs.mean(dim=-1)
        loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
        return loss.mean()

def get_lr_scheduler_with_warmup(optimizer, num_warmup_epochs, num_epochs):
    """Learning rate scheduler avec warm-up"""
    def lr_lambda(epoch):
        if epoch < num_warmup_epochs:
            # Warm-up lin√©aire
            return (epoch + 1) / num_warmup_epochs
        else:
            # Cosine annealing apr√®s warm-up
            progress = (epoch - num_warmup_epochs) / (num_epochs - num_warmup_epochs)
            return 0.5 * (1 + np.cos(np.pi * progress))
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def train_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Entra√Ænement pour une epoch avec AMP"""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc='üöÄ Training')
    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # Mixed Precision Training
        if scaler is not None:
            with torch.cuda.amp.autocast():
                output = model(data)
                loss = criterion(output, target)
            
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        
        accuracy = 100. * correct / total
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Accuracy': f'{accuracy:.2f}%'
        })
    
    return total_loss / len(loader), 100. * correct / total

def validate(model, loader, criterion, device):
    """Validation"""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in tqdm(loader, desc='üìä Validation'):
            data, target = data.to(device), target.to(device)
            
            with torch.cuda.amp.autocast():
                output = model(data)
                loss = criterion(output, target)
            
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    return total_loss / len(loader), 100. * correct / total

print("‚úÖ Toutes les fonctions optimis√©es sont charg√©es!")
print("\nüí° Am√©liorations principales:")
print("   - Weighted sampling pour √©quilibrage des classes")
print("   - Mixed Precision Training (AMP)")
print("   - Learning rate warm-up")
print("   - Attention mechanism (Squeeze-Excitation)")
print("   - Augmentation audio contr√¥l√©e")
print("   - Early stopping am√©lior√©")
print("   - Batch size augment√© pour stabilit√©")

‚úÖ Toutes les fonctions optimis√©es sont charg√©es!

üí° Am√©liorations principales:
   - Weighted sampling pour √©quilibrage des classes
   - Mixed Precision Training (AMP)
   - Learning rate warm-up
   - Attention mechanism (Squeeze-Excitation)
   - Augmentation audio contr√¥l√©e
   - Early stopping am√©lior√©
   - Batch size augment√© pour stabilit√©


In [6]:
# =============================================================================
# CELL 6: Chargement des Donn√©es avec Weighted Sampling
# =============================================================================
print("üîß Chargement des datasets optimis√©s...")

# Chargement des datasets
train_dataset = SpeechCommandsDataset(subset='training', apply_augmentation=True)
val_dataset = SpeechCommandsDataset(subset='validation', apply_augmentation=False)
test_dataset = SpeechCommandsDataset(subset='testing', apply_augmentation=False)

num_classes = len(train_dataset.labels)

# Calcul des poids de classe
class_weights = train_dataset.get_class_weights().to(device)
print(f"\n‚öñÔ∏è Poids des classes calcul√©s pour l'√©quilibrage")

# Weighted Random Sampler pour √©quilibrer les classes
sample_weights = train_dataset.get_sample_weights()
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# DataLoaders avec sampler √©quilibr√©
train_loader = DataLoader(
    train_dataset, 
    batch_size=config.batch_size,
    sampler=sampler,  # Utilisation du sampler au lieu de shuffle
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)

print("‚úÖ Donn√©es charg√©es avec sampler √©quilibr√©!")
print(f"üìä R√©sum√©:")
print(f"   - Batch size: {config.batch_size}")
print(f"   - √âchantillons d'entra√Ænement: {len(train_dataset):,}")
print(f"   - √âchantillons de validation: {len(val_dataset):,}")
print(f"   - √âchantillons de test: {len(test_dataset):,}")
print(f"   - Classes: {num_classes}")

# Cr√©ation du mod√®le optimis√©
model = MobileNetV3AudioClassifier(num_classes=num_classes, freeze_layers=3).to(device)

# Affichage des informations du mod√®le
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_params = total_params - trainable_params

print(f"\nüìä Informations du mod√®le:")
print(f"   - Param√®tres totaux: {total_params:,}")
print(f"   - Param√®tres entra√Ænables: {trainable_params:,}")
print(f"   - Param√®tres gel√©s: {frozen_params:,}")
print(f"   - Pourcentage gel√©: {frozen_params/total_params*100:.1f}%")


üîß Chargement des datasets optimis√©s...
üìä training: 105829 samples, 35 classes
üìä validation: 9981 samples, 35 classes
üìä testing: 11005 samples, 35 classes

‚öñÔ∏è Poids des classes calcul√©s pour l'√©quilibrage
‚úÖ Donn√©es charg√©es avec sampler √©quilibr√©!
üìä R√©sum√©:
   - Batch size: 32
   - √âchantillons d'entra√Ænement: 105,829
   - √âchantillons de validation: 9,981
   - √âchantillons de test: 11,005
   - Classes: 35
‚ùÑÔ∏è 3 couches initiales gel√©es
‚úÖ MobileNetV3 cr√©√© avec 3 couches gel√©es

üìä Informations du mod√®le:
   - Param√®tres totaux: 5,083,043
   - Param√®tres entra√Ænables: 5,082,579
   - Param√®tres gel√©s: 464
   - Pourcentage gel√©: 0.0%


In [None]:
# =============================================================================
# CELL 7: ENTRA√éNEMENT OPTIMIS√â
# =============================================================================
print("\nüöÄ D√©marrage de l'entra√Ænement optimis√©...\n")

# Fonction de perte avec poids de classe
criterion = LabelSmoothCrossEntropyLoss(
    smoothing=config.label_smoothing,
    weight=class_weights
)

# Optimizer avec param√®tres group√©s
param_groups = [
    {'params': model.backbone.parameters(), 'lr': config.learning_rate * 0.1},  # LR plus faible pour backbone
    {'params': model.se.parameters(), 'lr': config.learning_rate},
    {'params': model.classifier.parameters(), 'lr': config.learning_rate}
]

optimizer = torch.optim.AdamW(
    param_groups,
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
    betas=(0.9, 0.999)
)

# Scheduler avec warm-up
scheduler = get_lr_scheduler_with_warmup(
    optimizer,
    num_warmup_epochs=config.warmup_epochs,
    num_epochs=config.num_epochs
)

# Mixed Precision Scaler
scaler = torch.cuda.amp.GradScaler() if config.use_amp else None

# Early Stopping optimis√©
early_stopping = EarlyStopping(patience=15, min_delta=0.0005, restore_best_weights=True)

# Historique d'entra√Ænement
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'learning_rates': []
}

best_val_acc = 0
best_epoch = 0

print(f"üéØ Configuration d'entra√Ænement:")
print(f"   - Epochs: {config.num_epochs}")
print(f"   - Warm-up epochs: {config.warmup_epochs}")
print(f"   - Early Stopping patience: {early_stopping.patience}")
print(f"   - Mixed Precision: {'Activ√©' if config.use_amp else 'D√©sactiv√©'}")
print(f"   - Weighted Sampling: Activ√©")
print(f"   - Class Weighting: Activ√©")
print()

for epoch in range(config.num_epochs):
    print(f"\n{'='*70}")
    print(f"üìÖ Epoch {epoch+1}/{config.num_epochs}")
    print(f"{'='*70}")
    
    # Learning rates actuels
    current_lrs = [param_group['lr'] for param_group in optimizer.param_groups]
    print(f"üìà Learning Rates: Backbone={current_lrs[0]:.2e}, Head={current_lrs[2]:.2e}")
    
    # Phase d'entra√Ænement
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, scaler)
    
    # Phase de validation
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    # Mise √† jour de l'historique
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['learning_rates'].append(current_lrs[2])  # LR du head
    
    # Affichage des r√©sultats
    print(f"\nüìä R√©sultats Epoch {epoch+1}:")
    print(f"   Train ‚Üí Loss: {train_loss:.4f} | Accuracy: {train_acc:.2f}%")
    print(f"   Val   ‚Üí Loss: {val_loss:.4f} | Accuracy: {val_acc:.2f}%")
    
    # Calcul de l'am√©lioration
    if len(history['val_acc']) > 1:
        acc_diff = val_acc - history['val_acc'][-2]
        loss_diff = val_loss - history['val_loss'][-2]
        print(f"   Œî Accuracy: {acc_diff:+.2f}% | Œî Loss: {loss_diff:+.4f}")
    
    # V√©rification Early Stopping
    if early_stopping(val_loss, val_acc, model, epoch+1):
        print(f"\nüõë Early Stopping d√©clench√© √† l'epoch {epoch + 1}!")
        break
    
    # Sauvegarde du meilleur mod√®le
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'config': config.__dict__,
            'history': history,
            'class_weights': class_weights.cpu()
        }, 'best_mobilenetv3_optimized.pth')
        print(f"   üíæ Meilleur mod√®le sauvegard√©! (Accuracy: {val_acc:.2f}%)")
    
    # Ajustement du learning rate
    scheduler.step()
    
    # D√©gel progressif (apr√®s warm-up)
    if epoch == config.warmup_epochs * 2:
        print("\nüîì D√©gel de couches suppl√©mentaires du backbone...")
        for name, param in model.backbone.named_parameters():
            if 'blocks.2' in name or 'blocks.3' in name:
                param.requires_grad = True
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"   Nouveaux param√®tres entra√Ænables: {trainable:,}")

# Restauration du meilleur mod√®le
early_stopping.restore_best_model(model)

print(f"\n{'='*70}")
print(f"üéâ ENTRA√éNEMENT TERMIN√â!")
print(f"{'='*70}")
print(f"üèÜ Meilleure accuracy de validation: {best_val_acc:.2f}%")
print(f"üìÖ Meilleure epoch: {best_epoch}")
print(f"üìà Epochs totales: {epoch + 1}")

# Sauvegarde finale
torch.save(model.state_dict(), 'final_mobilenetv3_optimized.pth')
print("üíæ Mod√®le final sauvegard√©!")



üöÄ D√©marrage de l'entra√Ænement optimis√©...

üéØ Configuration d'entra√Ænement:
   - Epochs: 100
   - Warm-up epochs: 5
   - Early Stopping patience: 15
   - Mixed Precision: Activ√©
   - Weighted Sampling: Activ√©
   - Class Weighting: Activ√©


üìÖ Epoch 1/100
üìà Learning Rates: Backbone=1.00e-05, Head=1.00e-04


üöÄ Training:   0%|          | 0/3308 [00:00<?, ?it/s]

In [None]:
# =============================================================================
# CELL 8: √âVALUATION D√âTAILL√âE
# =============================================================================
print("\n" + "="*70)
print("üìä √âVALUATION FINALE DU MOD√àLE")
print("="*70 + "\n")

# √âvaluation sur le test set
test_loss, test_acc = validate(model, test_loader, criterion, device)
print(f"üéØ Performance sur le test set:")
print(f"   - Loss: {test_loss:.4f}")
print(f"   - Accuracy: {test_acc:.2f}%")

# M√©triques d√©taill√©es par classe
model.eval()
all_preds = []
all_targets = []
all_probs = []

with torch.no_grad():
    for data, target in tqdm(test_loader, desc='üîç Analyse d√©taill√©e'):
        data, target = data.to(device), target.to(device)
        
        with torch.cuda.amp.autocast():
            output = model(data)
        
        probs = F.softmax(output, dim=1)
        preds = output.argmax(dim=1)
        
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(target.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Calcul des m√©triques
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

accuracy = accuracy_score(all_targets, all_preds)
precision, recall, f1, support = precision_recall_fscore_support(
    all_targets, all_preds, average=None
)

print(f"\nüìà M√©triques Globales:")
print(f"   - Accuracy: {accuracy*100:.2f}%")
print(f"   - F1-Score (macro): {np.mean(f1):.4f}")
print(f"   - Precision (macro): {np.mean(precision):.4f}")
print(f"   - Recall (macro): {np.mean(recall):.4f}")

# Top-5 et Bottom-5 classes
class_f1 = list(zip(train_dataset.labels, f1, support))
class_f1_sorted = sorted(class_f1, key=lambda x: x[1], reverse=True)

print(f"\nüèÜ Top-5 Classes (meilleur F1-Score):")
for i, (label, score, supp) in enumerate(class_f1_sorted[:5], 1):
    print(f"   {i}. {label:15s} ‚Üí F1: {score:.4f} (n={supp})")

print(f"\n‚ö†Ô∏è Bottom-5 Classes (F1-Score le plus faible):")
for i, (label, score, supp) in enumerate(class_f1_sorted[-5:], 1):
    print(f"   {i}. {label:15s} ‚Üí F1: {score:.4f} (n={supp})")

# Analyse de l'√©quilibre
print(f"\n‚öñÔ∏è Analyse de l'√©quilibre des performances:")
f1_std = np.std(f1)
f1_range = np.max(f1) - np.min(f1)
print(f"   - √âcart-type F1: {f1_std:.4f}")
print(f"   - Range F1: {f1_range:.4f}")
print(f"   - Classes avec F1 > 0.90: {sum(f1 > 0.90)}/{len(f1)}")
print(f"   - Classes avec F1 < 0.70: {sum(f1 < 0.70)}/{len(f1)}")

# Rapport de classification d√©taill√©
print(f"\nüìã Rapport de Classification Complet:")
print(classification_report(
    all_targets, all_preds,
    target_names=train_dataset.labels,
    digits=3
))

print("\n‚úÖ √âvaluation termin√©e!")

In [None]:
# Cell 9: Export ONNX du meilleur mod√®le
print("\nüì§ Export du mod√®le en format ONNX...")

# Assurer que le mod√®le est en mode √©valuation
model.eval()

# Cr√©er un exemple d'input
dummy_input = torch.randn(1, 3, config.n_mels, config.target_length // config.hop_length + 1).to(device)

# Export ONNX
onnx_path = "mobilenetv3_speech_commands.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    verbose=False
)

print(f"‚úÖ Mod√®le export√© avec succ√®s: {onnx_path}")
print(f"üìä Taille du fichier ONNX: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB")

# V√©rification de l'export ONNX
try:
    import onnx
    import onnxruntime as ort
    
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("‚úÖ Mod√®le ONNX v√©rifi√© avec succ√®s!")
    
    # Test avec ONNX Runtime
    ort_session = ort.InferenceSession(onnx_path)
    
    # Test de pr√©diction
    dummy_np = dummy_input.cpu().numpy()
    ort_inputs = {ort_session.get_inputs()[0].name: dummy_np}
    ort_outs = ort_session.run(None, ort_inputs)
    
    print("‚úÖ ONNX Runtime test r√©ussi!")
    print(f"üìã Informations ONNX:")
    print(f"   - Input shape: {ort_session.get_inputs()[0].shape}")
    print(f"   - Output shape: {ort_session.get_outputs()[0].shape}")
    print(f"   - Opset version: {onnx_model.opset_import[0].version}")
    
except ImportError:
    print("‚ö†Ô∏è ONNX non install√©, impossible de v√©rifier le mod√®le")
    print("üí° Installer avec: pip install onnx onnxruntime")


In [None]:
# Cell 10: Visualisations et sauvegarde finale
print("\nüé® Cr√©ation des visualisations...")

# Graphiques de performance
plt.figure(figsize=(15, 5))

# Loss
plt.subplot(1, 3, 1)
plt.plot(history['train_loss'], label='Train Loss', linewidth=2, alpha=0.8)
plt.plot(history['val_loss'], label='Val Loss', linewidth=2, alpha=0.8)
plt.axvline(x=best_epoch-1, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch})')
plt.title('√âvolution des Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

# Accuracy
plt.subplot(1, 3, 2)
plt.plot(history['train_acc'], label='Train Accuracy', linewidth=2, alpha=0.8)
plt.plot(history['val_acc'], label='Val Accuracy', linewidth=2, alpha=0.8)
plt.axvline(x=best_epoch-1, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch})')
plt.title('√âvolution de l\'Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True, alpha=0.3)

# Learning Rate
plt.subplot(1, 3, 3)
plt.plot(history['learning_rates'], label='Learning Rate', linewidth=2, color='purple', alpha=0.8)
plt.title('√âvolution du Learning Rate')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.yscale('log')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_metrics_mobilenetv3.png', dpi=300, bbox_inches='tight')
plt.show()

# Matrice de confusion
cm = confusion_matrix(all_targets, all_preds)
plt.figure(figsize=(12, 10))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title(f'Matrice de Confusion - MobileNetV3\nAccuracy: {test_acc:.2f}%')
plt.colorbar()

# Afficher seulement quelques labels pour la lisibilit√©
if len(train_dataset.labels) > 20:
    tick_marks = np.arange(0, len(train_dataset.labels), max(1, len(train_dataset.labels)//20))
    plt.xticks(tick_marks, [train_dataset.labels[i] for i in tick_marks], rotation=45)
    plt.yticks(tick_marks, [train_dataset.labels[i] for i in tick_marks])
else:
    tick_marks = np.arange(len(train_dataset.labels))
    plt.xticks(tick_marks, train_dataset.labels, rotation=45)
    plt.yticks(tick_marks, train_dataset.labels)

plt.xlabel('Pr√©diction')
plt.ylabel('V√©rit√© terrain')
plt.tight_layout()
plt.savefig('confusion_matrix_mobilenetv3.png', dpi=300, bbox_inches='tight')
plt.show()

# Sauvegarde de l'historique
import json
history_serializable = {k: [float(x) if isinstance(x, (np.floating, float)) else x for x in v] 
                       for k, v in history.items()}

with open('training_history.json', 'w') as f:
    json.dump(history_serializable, f, indent=2)

print("\n‚úÖ Toutes les op√©rations sont termin√©es!")
print("üìÅ Fichiers cr√©√©s:")
print(f"   - best_mobilenetv3_model_complete.pth (mod√®le complet avec historique)")
print(f"   - final_mobilenetv3_model.pth (mod√®le final)")
print(f"   - mobilenetv3_speech_commands.onnx (mod√®le ONNX)")
print(f"   - training_metrics_mobilenetv3.png (graphiques)")
print(f"   - confusion_matrix_mobilenetv3.png (matrice de confusion)")
print(f"   - training_history.json (historique d'entra√Ænement)")

print(f"\nüéØ R√©sultats finaux:")
print(f"   - Best Validation Accuracy: {best_val_acc:.2f}%")
print(f"   - Test Accuracy: {test_acc:.2f}%")
print(f"   - Best Epoch: {best_epoch}")
print(f"   - Total Epochs: {epoch + 1}")