# In [None]:
import numpy as np
import torch
import torch.nn as nn
import copy
from collections import defaultdict

# ============================================================================
# CALCUL DES SCORES POUR CHAQUE MÉTHODE
# ============================================================================

def calculate_magnitude_ranking(model):
    """
    Score every 2-D convolution kernel by its L1 norm (classic magnitude baseline).

    Kernels with a small L1 norm are considered less important.

    Args:
        model: PyTorch module; every ``nn.Conv2d`` yielded by ``model.modules()``
            is scored, indexed in traversal order.

    Returns:
        Nested dict ``{conv_index: {out_channel: {in_channel: score}}}`` where
        ``score`` is the sum of absolute values of ``weight[oc, ic]``.
    """
    print("Calcul du ranking par magnitude L1...")
    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        # Copy the whole weight tensor to host once per layer instead of
        # once per (oc, ic) kernel (one device->host transfer per kernel).
        weights = conv_layer.weight.detach().cpu().numpy()
        out_channels, in_channels = weights.shape[:2]
        filter_ranking[m_idx] = {
            oc: {ic: np.abs(weights[oc, ic]).sum() for ic in range(in_channels)}
            for oc in range(out_channels)
        }

    return filter_ranking


def calculate_l2_magnitude_ranking(model):
    """
    Score every 2-D convolution kernel by its L2 (Euclidean) norm.

    Variant of magnitude pruning: kernels with a small L2 norm are considered
    less important.

    Args:
        model: PyTorch module; every ``nn.Conv2d`` yielded by ``model.modules()``
            is scored, indexed in traversal order.

    Returns:
        Nested dict ``{conv_index: {out_channel: {in_channel: score}}}`` where
        ``score = sqrt(sum(weight[oc, ic] ** 2))``.
    """
    print("Calcul du ranking par magnitude L2...")
    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        # One device->host copy per layer rather than one per kernel.
        weights = conv_layer.weight.detach().cpu().numpy()
        out_channels, in_channels = weights.shape[:2]
        filter_ranking[m_idx] = {
            oc: {
                ic: np.sqrt(np.sum(weights[oc, ic] ** 2))
                for ic in range(in_channels)
            }
            for oc in range(out_channels)
        }

    return filter_ranking


def calculate_variance_ranking(model):
    """
    Score every 2-D convolution kernel by the variance of its weights.

    Kernels with low variance are considered less informative.

    Args:
        model: PyTorch module; every ``nn.Conv2d`` yielded by ``model.modules()``
            is scored, indexed in traversal order.

    Returns:
        Nested dict ``{conv_index: {out_channel: {in_channel: score}}}`` where
        ``score = np.var(weight[oc, ic])`` (population variance, ddof=0).
    """
    print("Calcul du ranking par variance...")
    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        # One device->host copy per layer rather than one per kernel.
        weights = conv_layer.weight.detach().cpu().numpy()
        out_channels, in_channels = weights.shape[:2]
        filter_ranking[m_idx] = {
            oc: {ic: np.var(weights[oc, ic]) for ic in range(in_channels)}
            for oc in range(out_channels)
        }

    return filter_ranking


def calculate_random_ranking(model, seed=42):
    """
    Assign every 2-D convolution kernel a uniform random score in [0, 1).

    Random baseline, used to check that the other criteria beat chance.

    A private ``np.random.RandomState(seed)`` is used so NumPy's global RNG
    is left untouched. The scores are identical to those obtained by seeding
    the global RNG with ``seed`` and drawing ``np.random.random()`` in
    (layer, out_channel, in_channel) order.

    Args:
        model: PyTorch module; only the weight shapes of its ``nn.Conv2d``
            layers (in ``model.modules()`` order) are used.
        seed: seed for the private random generator.

    Returns:
        Nested dict ``{conv_index: {out_channel: {in_channel: score}}}``.
    """
    print("Calcul du ranking aléatoire...")
    rng = np.random.RandomState(seed)

    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        out_channels, in_channels = conv_layer.weight.shape[:2]
        # Nested comprehension draws in oc-major, ic-minor order, matching
        # the original loop order and hence the original random sequence.
        filter_ranking[m_idx] = {
            oc: {ic: rng.random_sample() for ic in range(in_channels)}
            for oc in range(out_channels)
        }

    return filter_ranking


def calculate_entropy_ranking(model):
    """
    Score every 2-D convolution kernel by the Shannon entropy of its |weights|.

    Each kernel's absolute weights are normalised into a probability
    distribution over its positions, and the entropy (in nats) of that
    distribution is the score. Low entropy means the magnitude is concentrated
    on a few positions. A kernel whose weights all have similar magnitude has
    HIGH entropy (at most ``log(k)`` for ``k`` positions). The heuristic
    treats low-entropy kernels as less informative.

    Args:
        model: PyTorch module; every ``nn.Conv2d`` yielded by ``model.modules()``
            is scored, indexed in traversal order.

    Returns:
        Nested dict ``{conv_index: {out_channel: {in_channel: score}}}``.
        An all-zero kernel scores 0.
    """
    print("Calcul du ranking par entropie...")
    filter_ranking = {}
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]

    for m_idx, conv_layer in enumerate(conv_layers):
        # One device->host copy per layer rather than one per kernel.
        weights = conv_layer.weight.detach().cpu().numpy()
        out_channels, in_channels = weights.shape[:2]
        filter_ranking[m_idx] = {}

        for oc in range(out_channels):
            filter_ranking[m_idx][oc] = {}
            for ic in range(in_channels):
                # Normalise |w| into a distribution; 1e-10 guards an all-zero kernel.
                weight_abs = np.abs(weights[oc, ic].flatten())
                weight_norm = weight_abs / (weight_abs.sum() + 1e-10)

                # Shannon entropy; the 1e-10 avoids log(0) for zero weights.
                entropy = -np.sum(weight_norm * np.log(weight_norm + 1e-10))
                filter_ranking[m_idx][oc][ic] = entropy

    return filter_ranking


# ============================================================================
# FONCTION DE COMPARAISON GLOBALE
# ============================================================================

def compare_pruning_methods(model, trn_dl, val_dl, loss_fn,
                            filter_ranking_topological=None,
                            keep_ratios=None,
                            fine_tune_epochs=5):
    """
    Compare several kernel-ranking criteria for pruning on the same model.

    For every criterion, a ranking is computed from ``model``'s weights, then
    ``test_effective_pruning`` is run on a fresh deep copy of ``model``. The
    accuracy retention it reports is summarised in three printed reports:
    a comparison table, a per-ratio leaderboard and global statistics.

    Args:
        model: trained PyTorch model. Pruning is applied only to deep copies.
        trn_dl, val_dl: training / validation dataloaders.
        loss_fn: loss function (also used for the baseline evaluation).
        filter_ranking_topological: optional precomputed ranking. When given,
            it is compared as the 'Topological (Sobolev)' method.
        keep_ratios: fractions of filters to keep. Defaults to
            ``[0.8, 0.6, 0.4, 0.2]``.
        fine_tune_epochs: requested number of fine-tuning epochs.
            NOTE(review): this value is currently NOT forwarded to
            ``test_effective_pruning`` (its signature is not visible here).
            Confirm the parameter name there and wire it through.

    Returns:
        dict mapping each method name to the list returned by
        ``test_effective_pruning`` (entries read here: 'keep_ratio',
        'accuracy_retention', 'compression_ratio'), or ``None`` if that
        method raised.
    """
    import traceback

    if keep_ratios is None:
        keep_ratios = [0.8, 0.6, 0.4, 0.2]
    keep_ratios = list(keep_ratios)

    # Baseline: performance of the unpruned model
    print("="*80)
    print("ÉVALUATION DU MODÈLE ORIGINAL")
    print("="*80)
    orig_params = sum(p.numel() for p in model.parameters())
    _, orig_acc = evaluate_model(model, val_dl, loss_fn)
    print(f"Accuracy: {orig_acc:.4f}")
    print(f"Paramètres: {orig_params:,}\n")

    methods = {
        'L1 Magnitude': calculate_magnitude_ranking,
        'L2 Magnitude': calculate_l2_magnitude_ranking,
        'Variance': calculate_variance_ranking,
        'Entropy': calculate_entropy_ranking,
        'Random': calculate_random_ranking,
    }
    if filter_ranking_topological is not None:
        # Precomputed ranking: the model argument is ignored.
        methods['Topological (Sobolev)'] = lambda _model: filter_ranking_topological

    all_results = {}
    for method_name, ranking_function in methods.items():
        print("\n" + "="*80)
        print(f"MÉTHODE: {method_name}")
        print("="*80)

        try:
            # The rankers above only read the weights, so no copy is needed
            # for ranking; pruning gets its own fresh deep copy.
            ranking = ranking_function(model)
            all_results[method_name] = test_effective_pruning(
                copy.deepcopy(model),
                ranking,
                trn_dl,
                val_dl,
                loss_fn,
                keep_ratios=keep_ratios,
            )
        except Exception as e:
            # Best effort: one failing method must not abort the comparison,
            # but keep the traceback so the failure can be diagnosed.
            print(f"Erreur avec {method_name}: {e}")
            traceback.print_exc()
            all_results[method_name] = None

    successful = {name: res for name, res in all_results.items() if res is not None}
    _print_comparison_table(successful, keep_ratios)
    _print_ranking_by_ratio(successful, keep_ratios)
    _print_global_stats(successful)

    return all_results


def _find_ratio_result(results, ratio):
    """Return the first entry of *results* whose 'keep_ratio' matches *ratio*, else None."""
    # Tolerant comparison so float round-off in reported ratios cannot hide a result.
    return next((r for r in results if np.isclose(r['keep_ratio'], ratio)), None)


def _print_comparison_table(results_by_method, keep_ratios):
    """Print the method x keep-ratio table of accuracy retention (%)."""
    print("\n" + "="*80)
    print("TABLEAU COMPARATIF - RETENTION D'ACCURACY (%)")
    print("="*80)

    header = "Méthode".ljust(25)
    for ratio in keep_ratios:
        # round(), not int(): int(0.57 * 100) truncates to 56.
        header += f"Keep {round(ratio * 100)}%".rjust(12)
    print(header)
    print("-"*80)

    for method_name, results in results_by_method.items():
        line = method_name.ljust(25)
        for ratio in keep_ratios:
            result = _find_ratio_result(results, ratio)
            if result is None:
                line += "N/A".rjust(12)
            else:
                line += f"{result['accuracy_retention']:10.1f}%".rjust(12)
        print(line)


def _print_ranking_by_ratio(results_by_method, keep_ratios):
    """For each keep ratio, print the methods sorted by accuracy retention (best first)."""
    print("\n" + "="*80)
    print("ANALYSE PAR RATIO DE COMPRESSION")
    print("="*80)

    medals = ["🥇", "🥈", "🥉"]
    for ratio in keep_ratios:
        print(f"\n--- Keep {round(ratio * 100)}% des filtres ---")

        rows = []
        for method_name, results in results_by_method.items():
            result = _find_ratio_result(results, ratio)
            if result is not None:
                rows.append((method_name,
                             result['accuracy_retention'],
                             result['compression_ratio']))
        rows.sort(key=lambda row: row[1], reverse=True)

        for i, (method_name, acc_ret, compression) in enumerate(rows):
            rank = medals[i] if i < len(medals) else f"{i+1}."
            print(f"{rank} {method_name:25s}: {acc_ret:5.1f}% "
                  f"(compression {compression:.1f}x)")


def _print_global_stats(results_by_method):
    """Print mean / min / max / std of accuracy retention for each method."""
    print("\n" + "="*80)
    print("STATISTIQUES GLOBALES")
    print("="*80)

    for method_name, results in results_by_method.items():
        accs = [r['accuracy_retention'] for r in results]
        print(f"\n{method_name}:")
        if not accs:
            # np.min / np.max raise ValueError on an empty sequence.
            print("  N/A")
            continue
        print(f"  Moyenne: {np.mean(accs):.1f}%")
        print(f"  Min: {np.min(accs):.1f}%")
        print(f"  Max: {np.max(accs):.1f}%")
        print(f"  Écart-type: {np.std(accs):.1f}%")


# ============================================================================
# FONCTION HELPER POUR ÉVALUATION
# ============================================================================

def evaluate_model(model, val_dl, loss_fn):
    """
    Evaluate a classifier on a validation dataloader without tracking gradients.

    Args:
        model: PyTorch model. Predictions are taken as ``pred.argmax(1)``, so
            it is assumed to return class scores along dim 1.
        val_dl: iterable of ``(x, y)`` tensor batches. Batches are moved to
            the device of the model's first parameter (CPU if it has none).
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
            Presumably a per-sample mean, since it is re-weighted by batch
            size here.

    Returns:
        ``(mean_loss, accuracy)`` over all samples yielded by ``val_dl``.

    Raises:
        ValueError: if ``val_dl`` yields no samples.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    try:
        with torch.no_grad():
            for x, y in val_dl:
                x, y = x.to(device), y.to(device)
                pred = model(x)
                batch_size = x.size(0)
                total_loss += loss_fn(pred, y).item() * batch_size
                total_correct += (pred.argmax(1) == y).sum().item()
                total_samples += batch_size
    finally:
        # Restore the caller's train/eval mode instead of silently leaving
        # the model in eval mode (it may be deep-copied and trained later).
        model.train(was_training)

    if total_samples == 0:
        raise ValueError("evaluate_model: val_dl yielded no samples")
    return total_loss / total_samples, total_correct / total_samples




# Run the full pruning benchmark.
# NOTE(review): model, train_loader, val_loader and loss_fn are not defined
# in this cell; they presumably come from earlier notebook cells.
results = compare_pruning_methods(
    model,
    train_loader,
    val_loader,
    loss_fn,
    keep_ratios=[0.8, 0.6, 0.4, 0.2],
    fine_tune_epochs=5,
)
