In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
# 1. Data: augmentation, train/validation/test split, loaders
# -----------------------------
SEED = 42
BATCH_SIZE = 128
VAL_SIZE = 5000  # training images held out for early stopping / LR scheduling

# Seed the global torch RNG for reproducible runs (weight init, shuffling).
torch.manual_seed(SEED)

train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.RandomErasing(p=0.2)
])

# Deterministic preprocessing for validation and test images.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# The CIFAR-10 training split is loaded twice. The augmented view feeds training and the
# deterministic view feeds validation, so held-out images are not randomly augmented.
train_full_aug = datasets.CIFAR10(root='./data', train=True, transform=train_transforms, download=True)
train_full_eval = datasets.CIFAR10(root='./data', train=True, transform=test_transforms, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transforms, download=True)

# Early stopping and ReduceLROnPlateau select on val_loader, so it must not be the test set.
# Carve a fixed random validation subset out of the training split instead.
split_perm = torch.randperm(len(train_full_aug), generator=torch.Generator().manual_seed(SEED)).tolist()
train_dataset = torch.utils.data.Subset(train_full_aug, split_perm[VAL_SIZE:])
val_dataset = torch.utils.data.Subset(train_full_eval, split_perm[:VAL_SIZE])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Untouched test set, for reporting final accuracy only.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# -----------------------------
# 2. Model
# -----------------------------
def get_model_improved(device='cpu'):
    """
    Build the CIFAR-10 CNN together with its loss, optimizer and LR scheduler.

    The network is a flat ``nn.Sequential``: three conv blocks, global average pooling
    and a small two-layer classifier. Because it is flat, later code can walk it layer
    by layer.

    Args:
        device: device the model is moved to.

    Returns:
        (model, loss_fn, optimizer, scheduler):
        - model: the network.
        - loss_fn: cross-entropy with label smoothing 0.1.
        - optimizer: Adam (lr=5e-4, weight_decay=1e-4).
        - scheduler: ReduceLROnPlateau (mode='min', factor=0.5, patience=3).
    """
    def conv_relu_bn(c_in, c_out, kernel, pad):
        # Conv -> ReLU -> BatchNorm (the normalization is applied after the activation).
        return [
            nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=pad, bias=True),
            nn.ReLU(),
            nn.BatchNorm2d(c_out),
        ]

    # Note: with kernel_size=4 and padding=1, each 4x4 conv shrinks the map by one pixel.
    layers = [
        # Block 1
        *conv_relu_bn(3, 32, 4, 1),
        *conv_relu_bn(32, 32, 4, 1),
        nn.MaxPool2d(2, 2),
        nn.Dropout2d(0.2),
        # Block 2
        *conv_relu_bn(32, 64, 4, 1),
        *conv_relu_bn(64, 64, 4, 1),
        nn.MaxPool2d(2, 2),
        nn.Dropout2d(0.3),
        # Block 3, ending with a 1x1 bottleneck conv and global average pooling
        *conv_relu_bn(64, 128, 4, 1),
        *conv_relu_bn(128, 128, 1, 0),
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        # Compact classifier
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(128, 10),
    ]
    net = nn.Sequential(*layers).to(device)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    adam = optim.Adam(net.parameters(), lr=5e-4, weight_decay=1e-4)
    plateau = optim.lr_scheduler.ReduceLROnPlateau(adam, mode='min', factor=0.5, patience=3)

    return net, criterion, adam, plateau


# -----------------------------
# 3. Training + early stopping + metric tracking
# -----------------------------
def _plot_training_history(history):
    """Draw train/validation loss and accuracy curves side by side and show the figure."""
    epochs_range = range(1, len(history['train_loss']) + 1)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, history['train_loss'], label='Train Loss')
    plt.plot(epochs_range, history['val_loss'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss évolution')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, history['train_acc'], label='Train Accuracy')
    plt.plot(epochs_range, history['val_acc'], label='Val Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy ')
    plt.title('Accuracy évolution')
    plt.legend()

    plt.tight_layout()
    plt.show()


def train_model(model, loss_fn, optimizer, scheduler, train_loader, val_loader, device='cpu', epochs=30, patience=5,
                min_delta=1e-3, restore_best=False, plot=True):
    """
    Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Args:
        model: network to train in place.
        loss_fn: loss used for both training and validation.
        optimizer: optimizer stepped once per training batch.
        scheduler: stepped once per epoch with the validation loss (e.g. ReduceLROnPlateau).
        train_loader, val_loader: iterables of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        epochs: maximum number of epochs.
        patience: number of consecutive epochs without improvement before stopping.
        min_delta: minimum decrease of the validation loss that counts as an improvement.
        restore_best: if True, reload the weights of the best epoch before returning.
            Otherwise the model keeps the weights of the last epoch run.
        plot: if True, show the loss/accuracy curves at the end.

    Returns:
        dict with lists 'train_loss', 'val_loss', 'train_acc', 'val_acc', holding one
        entry per epoch run. Accuracies are in percent.
    """
    best_val_loss = np.inf
    best_state = None
    patience_counter = 0

    # Metric history for the curves (and for the caller)
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(X)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * X.size(0)
            _, predicted = torch.max(outputs, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

        train_acc = 100 * correct / total
        train_loss = running_loss / total

        # Validation
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                outputs = model(X)
                loss = loss_fn(outputs, y)
                val_loss += loss.item() * X.size(0)
                _, predicted = torch.max(outputs, 1)
                total += y.size(0)
                correct += (predicted == y).sum().item()

        val_acc = 100 * correct / total
        val_loss /= total

        scheduler.step(val_loss)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        print(f"Epoch [{epoch+1}/{epochs}] - "
              f"Train loss: {train_loss:.4f} | Train acc: {train_acc:.2f}% | "
              f"Val loss: {val_loss:.4f} | Val acc: {val_acc:.2f}%")

        # Early stopping. Written as "improved?" so that a NaN loss counts as no improvement.
        if val_loss <= best_val_loss - min_delta:
            best_val_loss = val_loss
            patience_counter = 0
            if restore_best:
                # Snapshot (not references) of the current best weights.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored weights from the best epoch (val loss {best_val_loss:.4f}).")

    if plot:
        _plot_training_history(history)

    return history


# -----------------------------
# 5. Launch: pick the device, build the model, train it
# -----------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model, loss_fn, optimizer, scheduler = get_model_improved(device)
train_model(
    model, loss_fn, optimizer, scheduler,
    train_loader, val_loader, device,
)

Définition de la méthode de pruning structurel

In [None]:
import torch
import torch.nn as nn
import copy
from collections import defaultdict
import numpy as np

def _prune_conv2d(layer, keep_out, keep_in, device):
    """Build a Conv2d keeping only the ``keep_out`` filters and ``keep_in`` input channels of ``layer``."""
    if layer.groups != 1:
        raise NotImplementedError(
            "prune_model_by_filters_removal only supports convolutions with groups=1")
    new_conv = nn.Conv2d(
        in_channels=len(keep_in),
        out_channels=len(keep_out),
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        dilation=layer.dilation,
        padding_mode=layer.padding_mode,
        bias=(layer.bias is not None),
    ).to(device)
    out_idx = torch.as_tensor(keep_out, dtype=torch.long, device=device)
    in_idx = torch.as_tensor(keep_in, dtype=torch.long, device=device)
    with torch.no_grad():
        # Vectorised slice: rows = kept filters, columns = kept input channels.
        new_conv.weight.copy_(layer.weight.index_select(0, out_idx).index_select(1, in_idx))
        if layer.bias is not None:
            new_conv.bias.copy_(layer.bias.index_select(0, out_idx))
    return new_conv


def _prune_batchnorm2d(layer, keep, device):
    """Build a BatchNorm2d restricted to the ``keep`` channels of ``layer`` (affine params and running stats)."""
    new_bn = nn.BatchNorm2d(
        len(keep),
        eps=layer.eps,
        momentum=layer.momentum,
        affine=layer.affine,
        track_running_stats=layer.track_running_stats,
    ).to(device)
    idx = torch.as_tensor(keep, dtype=torch.long, device=device)
    with torch.no_grad():
        if layer.affine:
            new_bn.weight.copy_(layer.weight.index_select(0, idx))
            new_bn.bias.copy_(layer.bias.index_select(0, idx))
        if layer.running_mean is not None:
            new_bn.running_mean.copy_(layer.running_mean.index_select(0, idx))
            new_bn.running_var.copy_(layer.running_var.index_select(0, idx))
            new_bn.num_batches_tracked.copy_(layer.num_batches_tracked)
    return new_bn


def _prune_first_linear(layer, trunk_layers, input_size, kept_channels, orig_channels, device):
    """Rebuild the first Linear so that it reads exactly the features of the kept conv channels."""
    # Size the layer with a dummy forward through the pruned trunk. The trunk is a deep copy
    # in eval mode, so BatchNorm accepts a batch of one and the real layers keep their mode.
    probe = copy.deepcopy(nn.Sequential(*trunk_layers)).eval()
    with torch.no_grad():
        new_in = probe(torch.zeros(1, *input_size, device=device)).shape[-1]

    new_linear = nn.Linear(new_in, layer.out_features, bias=(layer.bias is not None)).to(device)

    n_kept = len(kept_channels)
    per_channel = new_in // n_kept if n_kept else 0
    # Assumes the features come from flattening a (C, H, W) map, where channel c owns the
    # contiguous block [c*H*W, (c+1)*H*W). Checked through the sizes before and after pruning.
    channel_major = (per_channel > 0
                     and new_in == n_kept * per_channel
                     and layer.in_features == orig_channels * per_channel)
    with torch.no_grad():
        if channel_major:
            cols = [c * per_channel + s for c in kept_channels for s in range(per_channel)]
            col_idx = torch.as_tensor(cols, dtype=torch.long, device=device)
            new_linear.weight.copy_(layer.weight.index_select(1, col_idx))
        else:
            # Unrecognised feature layout: best effort, copy the leading columns.
            n_copy = min(new_in, layer.in_features)
            new_linear.weight[:, :n_copy] = layer.weight[:, :n_copy]
        if layer.bias is not None:
            new_linear.bias.copy_(layer.bias)
    return new_linear


def prune_model_by_filters_removal(model, top_indices_list, input_size=(3, 32, 32)):
    """
    Physically remove conv filters to build a smaller copy of a flat ``nn.Sequential`` CNN.

    For each Conv2d, only the output channels listed in ``top_indices_list`` are kept. The
    rest of the network is adjusted to match:
    - the following BatchNorm2d is sliced to those channels;
    - the next Conv2d has its input channels sliced the same way;
    - the first Linear is rebuilt to read only the features of the kept channels.

    The source model is not modified.

    Args:
        model: trained ``nn.Sequential``, iterated layer by layer. The model is assumed to
            live on a single device.
        top_indices_list: iterable of ``(conv_idx, out_ch, in_ch)`` tuples giving the filters
            to KEEP. ``conv_idx`` counts Conv2d layers in order (0 = first conv). ``in_ch`` is
            ignored, since whole output channels are kept or removed. A conv absent from
            the list keeps all of its output channels.
        input_size: (channels, height, width) of one input sample. Used for a dummy
            forward pass that sizes the first Linear layer.

    Returns:
        A new ``nn.Sequential`` on the same device as ``model``, with fewer parameters and
        in the same train/eval mode as ``model``.

    Raises:
        NotImplementedError: for grouped convolutions (``groups != 1``).
    """
    device = next(model.parameters()).device

    # conv index -> sorted output channels to keep
    keep_sets = defaultdict(set)
    for layer_idx, out_ch, _in_ch in top_indices_list:
        keep_sets[layer_idx].add(out_ch)
    filters_to_keep = {idx: sorted(chs) for idx, chs in keep_sets.items()}

    new_layers = []
    conv_idx = 0
    prev_kept_channels = None   # original indices of the channels kept by the latest conv
    prev_orig_channels = None   # out_channels of that conv before pruning
    first_linear_seen = False

    for layer in model:
        if isinstance(layer, nn.Conv2d):
            keep_out = filters_to_keep.get(conv_idx, list(range(layer.out_channels)))
            # The first conv reads the raw input; later convs read the previous conv's kept channels.
            if conv_idx > 0 and prev_kept_channels:
                keep_in = prev_kept_channels
            else:
                keep_in = list(range(layer.in_channels))
            new_layers.append(_prune_conv2d(layer, keep_out, keep_in, device))
            prev_kept_channels, prev_orig_channels = keep_out, layer.out_channels
            conv_idx += 1
        elif isinstance(layer, nn.BatchNorm2d) and prev_kept_channels is not None:
            new_layers.append(_prune_batchnorm2d(layer, prev_kept_channels, device))
        elif isinstance(layer, nn.Linear) and not first_linear_seen:
            first_linear_seen = True
            if prev_kept_channels is None:
                new_layers.append(copy.deepcopy(layer))
            else:
                new_layers.append(_prune_first_linear(
                    layer, new_layers, input_size, prev_kept_channels, prev_orig_channels, device))
        else:
            # ReLU, pooling, Dropout, Flatten, later Linear layers...: copied unchanged.
            new_layers.append(copy.deepcopy(layer))

    pruned = nn.Sequential(*new_layers).to(device)
    # Freshly built layers default to training mode; align every layer with the source model.
    pruned.train(model.training)
    return pruned

def _as_pair(value):
    """Return ``value`` as a (height, width) pair; a scalar is duplicated."""
    return tuple(value) if isinstance(value, (tuple, list)) else (value, value)


def _pooled_size(size, kernel, stride, padding, dilation, ceil_mode):
    """Output length along one spatial dim for conv/pooling layers, using PyTorch's shape formula."""
    numerator = size + 2 * padding - dilation * (kernel - 1) - 1 + (stride - 1 if ceil_mode else 0)
    out = numerator // stride + 1
    # In ceil mode the last window must start inside the (left-padded) input.
    if ceil_mode and (out - 1) * stride >= size + padding:
        out -= 1
    return out


def calculate_linear_input_size(conv_layers, input_size):
    """
    Compute the flattened feature count produced by a sequence of layers.

    The spatial size is propagated analytically, without running a forward pass. Handled
    layers are Conv2d (stride, padding, dilation, and 'same'/'valid' string padding),
    MaxPool2d / AvgPool2d (padding, dilation, ceil_mode) and AdaptiveAvg/MaxPool2d. Other
    layers are assumed to keep both the channel count and the spatial size.

    Args:
        conv_layers: iterable of layers, in forward order.
        input_size: (channels, height, width) of one input sample.

    Returns:
        channels * height * width after the last layer.
    """
    channels, h, w = input_size

    for layer in conv_layers:
        if isinstance(layer, nn.Conv2d):
            if not (isinstance(layer.padding, str) and layer.padding == 'same'):
                kh, kw = _as_pair(layer.kernel_size)
                sh, sw = _as_pair(layer.stride)
                dh, dw = _as_pair(layer.dilation)
                if isinstance(layer.padding, str):  # 'valid'
                    ph, pw = 0, 0
                else:
                    ph, pw = _as_pair(layer.padding)
                h = _pooled_size(h, kh, sh, ph, dh, False)
                w = _pooled_size(w, kw, sw, pw, dw, False)
            channels = layer.out_channels

        elif isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
            kh, kw = _as_pair(layer.kernel_size)
            sh, sw = _as_pair(layer.stride if layer.stride else layer.kernel_size)
            ph, pw = _as_pair(layer.padding)
            dh, dw = _as_pair(getattr(layer, 'dilation', 1))  # AvgPool2d has no dilation
            h = _pooled_size(h, kh, sh, ph, dh, layer.ceil_mode)
            w = _pooled_size(w, kw, sw, pw, dw, layer.ceil_mode)

        elif isinstance(layer, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)):
            oh, ow = _as_pair(layer.output_size)
            # None in output_size means "keep that dimension".
            h = h if oh is None else oh
            w = w if ow is None else ow

    return channels * h * w

def structured_channel_pruning(model, importance_dict, keep_ratio=0.5, higher_is_more_important=True):
    """
    Select whole output channels to keep in each conv layer, ranked by mean importance.

    Args:
        model: unused; kept for call-site compatibility.
        importance_dict: ``{layer_idx: {out_ch: {in_ch: score}}}``.
        keep_ratio: fraction of output channels to keep in each layer (0.5 = 50%). At
            least one channel is kept in every non-empty layer.
        higher_is_more_important: if True (default), keep the channels with the largest
            mean score. This matches the L1/L2/variance rankings, where small values mean
            "less important". Pass False for a criterion where lower means more important.

    Returns:
        list of ``(layer_idx, out_ch, in_ch)`` tuples, one for every input channel of every
        kept output channel, in the format expected by ``prune_model_by_filters_removal``.
    """
    top_indices = []
    for layer_idx, per_out in importance_dict.items():
        if not per_out:
            continue

        # Mean importance of each output channel over all of its input-channel kernels.
        channel_scores = {
            out_ch: np.mean(list(per_in.values()))
            for out_ch, per_in in per_out.items()
        }

        # Most important first (the original sort was ascending, so it kept the least important).
        ranked = sorted(channel_scores, key=channel_scores.get, reverse=higher_is_more_important)

        n_keep = max(1, int(len(ranked) * keep_ratio))
        for out_ch in ranked[:n_keep]:
            for in_ch in per_out[out_ch]:
                top_indices.append((layer_idx, out_ch, in_ch))

    return top_indices

def test_effective_pruning(model, importance_dict, trn_dl, val_dl, loss_fn, keep_ratios=(0.8,),
                           fine_tune_epochs=5, fine_tune_lr=1e-3, input_size=(3, 32, 32)):
    """
    Prune ``model`` at several keep ratios, fine-tune each pruned copy and report the results.

    For every ratio:
    - channels are selected with ``structured_channel_pruning``;
    - they are physically removed with ``prune_model_by_filters_removal``;
    - the pruned copy is evaluated before and after fine-tuning.

    A failure at one ratio is printed and the remaining ratios are still tried.

    Args:
        model: trained model to prune (pruning works on copies).
        importance_dict: ``{conv_idx: {out_ch: {in_ch: score}}}`` ranking.
        trn_dl, val_dl: training / evaluation data loaders.
        loss_fn: loss used for evaluation and fine-tuning.
        keep_ratios: fractions of output channels to keep per conv layer.
        fine_tune_epochs: fine-tuning epochs per pruned model.
        fine_tune_lr: Adam learning rate used for fine-tuning.
        input_size: (channels, height, width) of one sample, used to size the first Linear.

    Returns:
        list of dicts, one per ratio that succeeded, with keys:
        - 'keep_ratio', 'actual_reduction', 'final_accuracy', 'accuracy_retention',
          'compression_ratio', 'pruned_params' (the summary read by
          ``compare_pruning_methods``);
        - 'pruned_model' (the fine-tuned model).
        The list is empty if every ratio failed.
    """
    print("Test de Pruning avec Réduction Effective des Paramètres")
    print("="*60)

    # Baseline
    orig_params = sum(p.numel() for p in model.parameters())
    _, orig_acc = evaluate_model(model, val_dl, loss_fn)

    print(f"Modèle original: {orig_acc:.4f} accuracy, {orig_params:,} paramètres")
    print()

    results = []

    for keep_ratio in keep_ratios:
        print(f"Test avec keep_ratio = {keep_ratio} ({keep_ratio*100:.0f}% des canaux)")
        print("-" * 40)

        try:
            # Structured selection of whole channels
            top_indices = structured_channel_pruning(model, importance_dict, keep_ratio)
            print(f"Filtres sélectionnés: {len(top_indices)}")

            # Physical removal of the filters
            pruned_model = prune_model_by_filters_removal(model, top_indices, input_size=input_size)

            pruned_params = sum(p.numel() for p in pruned_model.parameters())
            actual_reduction = (orig_params - pruned_params) / orig_params * 100

            print(f"Paramètres avant: {orig_params:,}")
            print(f"Paramètres après: {pruned_params:,}")
            print(f"Réduction RÉELLE: {actual_reduction:.1f}%")

            _, val_acc_before = evaluate_model(pruned_model, val_dl, loss_fn)
            print(f"Accuracy avant fine-tuning: {val_acc_before:.4f}")

            print("Fine-tuning...")
            pruned_model = fine_tune_pruned_model(pruned_model, trn_dl, val_dl, loss_fn,
                                                  epochs=fine_tune_epochs, lr=fine_tune_lr)

            _, val_acc_final = evaluate_model(pruned_model, val_dl, loss_fn)

            # A zero-accuracy baseline would otherwise raise ZeroDivisionError.
            accuracy_retention = (val_acc_final / orig_acc) * 100 if orig_acc > 0 else float('nan')
            compression_ratio = orig_params / pruned_params

            print(f"Accuracy finale: {val_acc_final:.4f}")
            print(f"Rétention accuracy: {accuracy_retention:.1f}%")
            print(f"Ratio de compression: {compression_ratio:.1f}x")
            print()

            results.append({
                'keep_ratio': keep_ratio,
                'actual_reduction': actual_reduction,
                'final_accuracy': val_acc_final,
                'accuracy_retention': accuracy_retention,
                'compression_ratio': compression_ratio,
                'pruned_params': pruned_params,
                'pruned_model': pruned_model,
            })

        except Exception as e:
            # Best effort: report the failure for this ratio and move on to the next one.
            print(f"Erreur avec keep_ratio {keep_ratio}: {e}")
            print()

    print("RÉSUMÉ DES RÉSULTATS")
    print("="*40)
    for r in results:
        print(f"Keep {r['keep_ratio']*100:.0f}%: {r['actual_reduction']:.1f}% réduction, "
              f"{r['accuracy_retention']:.1f}% accuracy, {r['compression_ratio']:.1f}x compression")

    # The caller (compare_pruning_methods) consumes this list of per-ratio result dicts.
    return results

# Evaluation helper. NOTE: evaluate_model is defined again in a later cell; keep both copies in sync.
def evaluate_model(model, val_dl, loss_fn):
    """
    Compute the sample-weighted mean loss and the accuracy of ``model`` on ``val_dl``.

    The model is switched to eval mode and evaluated without gradients. Batches are moved
    to the device of the model's first parameter.

    Args:
        model: classifier returning one logit row per sample.
        val_dl: iterable of ``(inputs, targets)`` batches.
        loss_fn: loss returning a batch-mean value.

    Returns:
        (mean_loss, accuracy), with accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``val_dl`` yields no samples.
    """
    model.eval()
    device = next(model.parameters()).device  # loop-invariant
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for x, y in val_dl:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            # Re-weight the batch-mean loss by the batch size so that a short last batch counts correctly.
            total_loss += loss_fn(pred, y).item() * x.size(0)
            total_correct += (pred.argmax(1) == y).sum().item()
            total_samples += x.size(0)

    if total_samples == 0:
        raise ValueError("evaluate_model: val_dl yielded no samples")

    return total_loss / total_samples, total_correct / total_samples

def fine_tune_pruned_model(pruned_model, trn_dl, val_dl, loss_fn, epochs=10, lr=1e-4, log_every=2):
    """
    Fine-tune a pruned model with Adam to recover accuracy.

    Validation accuracy is computed and printed after the first epoch, every
    ``log_every`` epochs and after the last epoch. Other epochs skip the validation pass.

    Args:
        pruned_model: model trained in place.
        trn_dl: training loader of ``(inputs, targets)`` batches.
        val_dl: loader passed to ``evaluate_model`` for the periodic checks.
        loss_fn: training / evaluation loss.
        epochs: number of passes over ``trn_dl``.
        lr: Adam learning rate.
        log_every: validation period in epochs. ``<= 0`` disables the periodic check;
            the first and last epochs are still reported.

    Returns:
        The same ``pruned_model`` object. When ``epochs > 0`` it is left in eval mode by
        the final validation pass.
    """
    from torch.optim import Adam

    device = next(pruned_model.parameters()).device
    optimizer = Adam(pruned_model.parameters(), lr=lr)

    for epoch in range(epochs):
        pruned_model.train()
        for x, y in trn_dl:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(pruned_model(x), y)
            loss.backward()
            optimizer.step()

        periodic = log_every > 0 and (epoch + 1) % log_every == 0
        if epoch == 0 or periodic or epoch == epochs - 1:
            _, val_acc = evaluate_model(pruned_model, val_dl, loss_fn)
            print(f"  Epoch {epoch+1}: Val Acc = {val_acc:.4f}")

    return pruned_model

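# Usage example (importance_dict can come from any ranking function of the next cell,
# e.g. calculate_magnitude_ranking(model)):
# test_effective_pruning(model, importance_dict, train_loader, val_loader, loss_fn, keep_ratios=[0.8])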


Calcul des scores pour chaque méthode standard

In [None]:
import numpy as np
import torch
import torch.nn as nn
import copy
from collections import defaultdict

# ============================================================================
# CALCUL DES SCORES POUR CHAQUE MÉTHODE
# ============================================================================

def calculate_magnitude_ranking(model):
    """
    L1-magnitude ranking of every 2-D conv kernel (classic baseline).

    Kernels with a small L1 norm are considered less important.

    Args:
        model: module whose Conv2d layers are ranked (the weights are only read).

    Returns:
        dict ``{conv_idx: {out_ch: {in_ch: score}}}``. ``conv_idx`` enumerates the Conv2d
        modules in ``model.modules()`` order. ``score`` is the L1 norm (sum of absolute
        values) of ``weight[out_ch, in_ch]``, as a NumPy scalar.
    """
    print("Calcul du ranking par magnitude L1...")
    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        # One reduction over the spatial dims and a single host copy per layer, instead of
        # one device->CPU transfer per (out_ch, in_ch) kernel.
        scores = conv_layer.weight.detach().abs().sum(dim=(2, 3)).cpu().numpy()
        filter_ranking[m_idx] = {
            oc: {ic: scores[oc, ic] for ic in range(scores.shape[1])}
            for oc in range(scores.shape[0])
        }

    return filter_ranking


def calculate_l2_magnitude_ranking(model):
    """
    L2-magnitude (Euclidean norm) ranking of every 2-D conv kernel.

    This is a variant of the L1 magnitude criterion.

    Args:
        model: module whose Conv2d layers are ranked (the weights are only read).

    Returns:
        dict ``{conv_idx: {out_ch: {in_ch: score}}}``. ``conv_idx`` follows
        ``model.modules()`` order. ``score`` is ``sqrt(sum(weight[out_ch, in_ch] ** 2))``,
        as a NumPy scalar.
    """
    print("Calcul du ranking par magnitude L2...")
    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        # Norm of every kernel at once, then a single host copy per layer.
        scores = conv_layer.weight.detach().pow(2).sum(dim=(2, 3)).sqrt().cpu().numpy()
        filter_ranking[m_idx] = {
            oc: {ic: scores[oc, ic] for ic in range(scores.shape[1])}
            for oc in range(scores.shape[0])
        }

    return filter_ranking


def calculate_variance_ranking(model):
    """
    Variance-based ranking of every 2-D conv kernel.

    Kernels whose weights vary little are considered less informative. For 1x1 kernels
    the variance of a single weight is always 0, so every kernel of such a layer gets the
    same score.

    Args:
        model: module whose Conv2d layers are ranked (the weights are only read).

    Returns:
        dict ``{conv_idx: {out_ch: {in_ch: score}}}``. ``conv_idx`` follows
        ``model.modules()`` order. ``score`` is the population variance (ddof=0, like
        ``np.var``) of ``weight[out_ch, in_ch]``.
    """
    print("Calcul du ranking par variance...")
    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        weight = conv_layer.weight.detach()
        # Population variance over the spatial dims of each kernel, for all kernels at once.
        centered = weight - weight.mean(dim=(2, 3), keepdim=True)
        scores = centered.pow(2).mean(dim=(2, 3)).cpu().numpy()
        filter_ranking[m_idx] = {
            oc: {ic: scores[oc, ic] for ic in range(scores.shape[1])}
            for oc in range(scores.shape[0])
        }

    return filter_ranking


def calculate_random_ranking(model, seed=42):
    """
    Random ranking: a baseline that the other criteria should beat.

    Args:
        model: module whose Conv2d layers receive scores (only their shapes are read).
        seed: seed of a private legacy ``RandomState``. It produces the same scores as
            ``np.random.seed(seed)`` followed by ``np.random.random()`` calls, without
            touching NumPy's global RNG state.

    Returns:
        dict ``{conv_idx: {out_ch: {in_ch: score}}}`` with scores uniform in [0, 1). They
        are drawn in conv / out_ch / in_ch order, and ``conv_idx`` follows
        ``model.modules()`` order.
    """
    print("Calcul du ranking aléatoire...")
    rng = np.random.RandomState(seed)

    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        filter_ranking[m_idx] = {}
        out_channels, in_channels = conv_layer.weight.shape[:2]

        for oc in range(out_channels):
            filter_ranking[m_idx][oc] = {}
            for ic in range(in_channels):
                filter_ranking[m_idx][oc][ic] = rng.random_sample()

    return filter_ranking


# ============================================================================
# FONCTION DE COMPARAISON GLOBALE
# ============================================================================

def _result_at_ratio(method_results, ratio):
    """Return the first result dict recorded for ``ratio`` in ``method_results``, or None."""
    return next((r for r in method_results if r['keep_ratio'] == ratio), None)


def _print_retention_table(all_results, keep_ratios):
    """Print the accuracy-retention table, with one row per method and one column per keep ratio."""
    print("\n" + "="*80)
    print("TABLEAU COMPARATIF - RETENTION D'ACCURACY (%)")
    print("="*80)

    header = "Méthode".ljust(25)
    for ratio in keep_ratios:
        header += f"Keep {round(ratio*100)}%".rjust(12)
    print(header)
    print("-"*80)

    for method_name, method_results in all_results.items():
        if method_results is None:
            continue
        line = method_name.ljust(25)
        for ratio in keep_ratios:
            res = _result_at_ratio(method_results, ratio)
            cell = f"{res['accuracy_retention']:10.1f}%" if res is not None else "N/A"
            line += cell.rjust(12)
        print(line)


def _print_ranking_per_ratio(all_results, keep_ratios):
    """For each keep ratio, print the methods sorted by accuracy retention, best first."""
    print("\n" + "="*80)
    print("ANALYSE PAR RATIO DE COMPRESSION")
    print("="*80)

    for ratio in keep_ratios:
        print(f"\n--- Keep {round(ratio*100)}% des filtres ---")

        results_at_ratio = []
        for method_name, method_results in all_results.items():
            if method_results is None:
                continue
            res = _result_at_ratio(method_results, ratio)
            if res is not None:
                results_at_ratio.append({
                    'method': method_name,
                    'acc_retention': res['accuracy_retention'],
                    'compression': res['compression_ratio'],
                })

        results_at_ratio.sort(key=lambda x: x['acc_retention'], reverse=True)

        for i, res in enumerate(results_at_ratio):
            rank = "🥇" if i == 0 else "🥈" if i == 1 else "🥉" if i == 2 else f"{i+1}."
            print(f"{rank} {res['method']:25s}: {res['acc_retention']:5.1f}% "
                  f"(compression {res['compression']:.1f}x)")


def _print_global_statistics(all_results):
    """Print the mean, min, max and std of accuracy retention over all keep ratios, per method."""
    print("\n" + "="*80)
    print("STATISTIQUES GLOBALES")
    print("="*80)

    for method_name, method_results in all_results.items():
        if not method_results:
            # None (the method failed) or [] (every keep ratio failed): nothing to summarise.
            continue

        accs = [r['accuracy_retention'] for r in method_results]
        print(f"\n{method_name}:")
        print(f"  Moyenne: {np.mean(accs):.1f}%")
        print(f"  Min: {np.min(accs):.1f}%")
        print(f"  Max: {np.max(accs):.1f}%")
        print(f"  Écart-type: {np.std(accs):.1f}%")


def compare_pruning_methods(model, trn_dl, val_dl, loss_fn,
                            filter_ranking_topological=None,
                            keep_ratios=(0.8, 0.6, 0.4, 0.2),
                            fine_tune_epochs=5):
    """
    Compare several filter-ranking criteria for structured pruning of the same model.

    Each method produces a ranking, which ``test_effective_pruning`` uses to prune, fine-tune
    and evaluate a deep copy of ``model`` at every keep ratio. The function then prints:
    - a retention table;
    - a per-ratio leaderboard;
    - global statistics.

    Args:
        model: trained PyTorch model (not modified).
        trn_dl, val_dl: training / evaluation data loaders.
        loss_fn: loss function.
        filter_ranking_topological: optional precomputed ranking
            ``{conv_idx: {out_ch: {in_ch: score}}}``, added as the "Topological (Sobolev)" method.
        keep_ratios: fractions of filters to keep.
        fine_tune_epochs: fine-tuning epochs per pruned model. Forwarded only if
            ``test_effective_pruning`` accepts a ``fine_tune_epochs`` parameter.

    Returns:
        dict mapping method name to the value returned by ``test_effective_pruning``. That
        value is a list of per-ratio result dicts, or None if the method failed or did not
        return a list.
    """
    import inspect

    # Baseline: performance of the original model
    print("="*80)
    print("ÉVALUATION DU MODÈLE ORIGINAL")
    print("="*80)
    orig_params = sum(p.numel() for p in model.parameters())
    _, orig_acc = evaluate_model(model, val_dl, loss_fn)
    print(f"Accuracy: {orig_acc:.4f}")
    print(f"Paramètres: {orig_params:,}\n")

    methods = {
        'L1 Magnitude': calculate_magnitude_ranking,
        'L2 Magnitude': calculate_l2_magnitude_ranking,
        'Variance': calculate_variance_ranking,
        'Random': calculate_random_ranking,
    }
    if filter_ranking_topological is not None:
        methods['Topological (Sobolev)'] = lambda m: filter_ranking_topological

    pruning_kwargs = {}
    if 'fine_tune_epochs' in inspect.signature(test_effective_pruning).parameters:
        pruning_kwargs['fine_tune_epochs'] = fine_tune_epochs

    all_results = {}
    for method_name, ranking_function in methods.items():
        print("\n" + "="*80)
        print(f"MÉTHODE: {method_name}")
        print("="*80)

        try:
            # The ranking functions only read the weights, so the model is passed as is.
            ranking = ranking_function(model)
            results = test_effective_pruning(
                copy.deepcopy(model), ranking, trn_dl, val_dl, loss_fn,
                keep_ratios=list(keep_ratios), **pruning_kwargs,
            )
        except Exception as e:
            print(f"Erreur avec {method_name}: {e}")
            results = None

        if results is not None and not isinstance(results, list):
            # The summaries below need the list of per-ratio result dicts.
            print(f"Erreur avec {method_name}: résultat inattendu ({type(results).__name__}), "
                  f"liste de résultats attendue")
            results = None

        all_results[method_name] = results

    _print_retention_table(all_results, keep_ratios)
    _print_ranking_per_ratio(all_results, keep_ratios)
    _print_global_statistics(all_results)

    return all_results


# ============================================================================
# FONCTION HELPER POUR ÉVALUATION
# ============================================================================

def evaluate_model(model, val_dl, loss_fn):
    """
    Compute the sample-weighted mean loss and the accuracy of ``model`` on ``val_dl``.

    The model is switched to eval mode and evaluated without gradients. Batches are moved
    to the device of the model's first parameter.

    Args:
        model: classifier returning one logit row per sample.
        val_dl: iterable of ``(inputs, targets)`` batches.
        loss_fn: loss returning a batch-mean value.

    Returns:
        (mean_loss, accuracy), with accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``val_dl`` yields no samples.
    """
    model.eval()
    device = next(model.parameters()).device
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for x, y in val_dl:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            # Re-weight the batch-mean loss by the batch size so that a short last batch counts correctly.
            total_loss += loss_fn(pred, y).item() * x.size(0)
            total_correct += (pred.argmax(1) == y).sum().item()
            total_samples += x.size(0)

    if total_samples == 0:
        raise ValueError("evaluate_model: val_dl yielded no samples")

    return total_loss / total_samples, total_correct / total_samples


# ============================================================================
# USAGE EXAMPLE
# ============================================================================

# Inside a notebook kernel __name__ == "__main__", so this runs when the cell is executed.
if __name__ == "__main__":
    # Reuse the trained model, loss and data loaders built in the first cell.
    # The topological ranking is optional: it is included only if a `filter_ranking`
    # variable exists in the notebook namespace.
    results = compare_pruning_methods(
        model=model,
        trn_dl=train_loader,
        val_dl=val_loader,
        loss_fn=loss_fn,
        filter_ranking_topological=globals().get("filter_ranking"),
        keep_ratios=[0.8, 0.6, 0.4, 0.2],
        fine_tune_epochs=5,
    )

