In [2]:
import torch
import torch.nn as nn
import torch_pruning as tp
import matplotlib.pyplot as plt
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
import os
import numpy as np
import copy
import json
import pandas as pd

import torchvision
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_BASE_NAME = "mobilenet_v2"
print(f"Using device: {DEVICE}")


def get_data_loaders(data_dir_path='./data', batch_size=128, val_split=0.1, seed=42, use_augmentation=True):
    """Load CIFAR-10 dataset with train/val/test splits and improved transforms"""
    abs_data_dir = os.path.abspath(data_dir_path)
    print(f"Loading CIFAR-10 from: {abs_data_dir}")

    # Enhanced transforms with data augmentation for training
    if use_augmentation:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomRotation(degrees=10),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            # Add random erasing for regularization
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.33), ratio=(0.3, 3.3))
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # Load CIFAR-10 from local directory
    full_train_dataset = torchvision.datasets.CIFAR10(
        root=abs_data_dir, train=True, download=False, transform=train_transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root=abs_data_dir, train=False, download=False, transform=test_transform
    )

    # Create train/validation split
    val_size = int(len(full_train_dataset) * val_split)
    train_size = len(full_train_dataset) - val_size
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset, [train_size, val_size], generator=generator
    )

    # Apply test transform to validation set
    val_dataset.dataset = copy.deepcopy(full_train_dataset)
    val_dataset.dataset.transform = test_transform

    # Create data loaders
    num_workers = min(4, os.cpu_count() or 2)
    pin_memory = True if DEVICE.type == 'cuda' else False

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    print(f"DataLoaders created - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    return train_loader, val_loader, test_loader


def get_mobilenetv2_model(num_classes=10, use_pretrained=True, pretrained_path='./base/mobilenet_v2-b0353104.pth'):
    """Get MobileNetV2 model adapted for CIFAR-10 with improved classifier"""
    # Always create model without weights first
    model = models.mobilenet_v2(weights=None)

    if use_pretrained and os.path.exists(pretrained_path):
        # Load pre-downloaded weights from local file
        print(f"Loading pre-trained weights from: {pretrained_path}")
        pretrained_state_dict = torch.load(pretrained_path, map_location=DEVICE)

        # Load the weights, ignoring the classifier layer if it doesn't match
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_state_dict.items()
                           if k in model_dict and model_dict[k].shape == v.shape}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict, strict=False)
        print("✅ Loaded MobileNetV2 with pre-downloaded ImageNet weights")
    else:
        if use_pretrained:
            print(f"Warning: Pre-trained weights not found at {pretrained_path}")
        print("✅ Created MobileNetV2 without pretrained weights")

    # Enhanced classifier for better performance
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.4),
        nn.Linear(512, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),
        nn.Linear(256, num_classes)
    )

    print(f"✅ Enhanced classifier created for {num_classes} classes")
    return model


def get_ignored_layers(model):
    """Get layers to ignore during pruning (typically final classifier)"""
    ignored_layers = []
    if hasattr(model, 'classifier'):
        if isinstance(model.classifier, nn.Sequential):
            for layer in model.classifier:
                if isinstance(layer, nn.Linear):
                    ignored_layers.append(layer)
        elif isinstance(model.classifier, nn.Linear):
            ignored_layers.append(model.classifier)
    return ignored_layers


def calculate_macs_params(model, example_input):
    """Calculate MACs and parameters using torch_pruning"""
    model.eval()
    target_device = example_input.device
    model_on_device = model.to(target_device)

    with torch.no_grad():
        macs, params = tp.utils.count_ops_and_params(model_on_device, example_input)

    return macs, params


def save_model(model, save_path, example_input_cpu=None):
    """Save model state dict and optionally ONNX"""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"✅ Model saved to {save_path}")

    if example_input_cpu is not None:
        onnx_path = save_path.replace('.pth', '.onnx')
        try:
            model_cpu = model.to('cpu')
            torch.onnx.export(
                model_cpu, example_input_cpu, onnx_path,
                export_params=True, opset_version=13,
                input_names=['input'], output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
            )
            print(f"✅ ONNX model saved to {onnx_path}")
        except Exception as e:
            print(f"Warning: ONNX export failed: {e}")


def evaluate_model(model, data_loader, example_input, criterion, device):
    """Evaluate model and return comprehensive metrics"""
    model.eval()
    model.to(device)

    # Calculate efficiency metrics
    macs, params = calculate_macs_params(model, example_input.to(device))
    model_size_mb = params * 4 / (1024 * 1024)  # Assuming float32

    # Calculate accuracy and loss
    correct = 0
    total = 0
    total_loss = 0.0

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item() * data.size(0)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    avg_loss = total_loss / total if total > 0 else float('nan')
    accuracy = 100.0 * correct / total if total > 0 else 0.0

    return {
        'accuracy': accuracy,
        'loss': avg_loss,
        'macs': macs,
        'params': params,
        'size_mb': model_size_mb
    }


def prune_model(model, strategy_config, sparsity_ratio, example_input, ignored_layers=None):
    """Apply structured pruning to model"""
    if sparsity_ratio == 0.0:
        print("No pruning needed (sparsity = 0.0)")
        return model

    model.eval()
    pruned_model = copy.deepcopy(model)
    pruned_model.to(example_input.device)

    # Calculate initial MACs
    initial_macs, _ = calculate_macs_params(pruned_model, example_input)
    print(f"Initial MACs: {initial_macs / 1e6:.2f}M")

    ignored_layers = ignored_layers or []

    # Create pruner based on strategy
    pruner = strategy_config['pruner'](
        pruned_model,
        example_input,
        importance=strategy_config['importance'],
        iterative_steps=5,  # Use 5 iterative steps
        ch_sparsity=sparsity_ratio,
        root_module_types=[nn.Conv2d],
        ignored_layers=ignored_layers
    )

    print(f"Applying {strategy_config['importance'].__class__.__name__} pruning at {sparsity_ratio:.1%} sparsity...")

    # Apply pruning
    pruner.step()

    # Calculate final MACs
    final_macs, _ = calculate_macs_params(pruned_model, example_input)
    reduction = (initial_macs - final_macs) / initial_macs * 100 if initial_macs > 0 else 0
    print(f"Final MACs: {final_macs / 1e6:.2f}M (Reduction: {reduction:.1f}%)")

    return pruned_model


def train_model(model, train_loader, criterion, optimizer, device, num_epochs,
                val_loader=None, patience=10, log_prefix="", scheduler=None):
    """Enhanced training function with better optimization and monitoring"""
    model.to(device)

    best_val_loss = float('inf')
    best_val_acc = 0.0
    epochs_no_improve = 0
    best_model_state = None

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'learning_rates': []
    }

    # Label smoothing for better generalization
    smoothing_criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)

            # Use label smoothing for training
            loss = smoothing_criterion(output, target)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            train_total += target.size(0)
            train_correct += (predicted == target).sum().item()

        avg_train_loss = train_loss / len(train_loader)
        train_acc = 100.0 * train_correct / train_total

        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_acc)

        # Record learning rate
        current_lr = optimizer.param_groups[0]['lr']
        history['learning_rates'].append(current_lr)

        log_msg = f"Epoch {epoch + 1:3d}/{num_epochs} ({log_prefix}): Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, LR: {current_lr:.6f}"

        # Validation phase
        if val_loader:
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for data, target in val_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    loss = criterion(output, target)  # Use original criterion for validation

                    val_loss += loss.item()
                    _, predicted = torch.max(output.data, 1)
                    val_total += target.size(0)
                    val_correct += (predicted == target).sum().item()

            avg_val_loss = val_loss / len(val_loader)
            val_acc = 100.0 * val_correct / val_total

            history['val_loss'].append(avg_val_loss)
            history['val_acc'].append(val_acc)

            log_msg += f", Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%"

            # Early stopping check - use validation accuracy as primary metric
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                best_model_state = copy.deepcopy(model.state_dict())
                log_msg += " ✓"
            else:
                epochs_no_improve += 1

            # Update scheduler
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(avg_val_loss)
                else:
                    scheduler.step()

            if epochs_no_improve >= patience:
                print(f"{log_msg}")
                print(f"Early stopping triggered after {epoch + 1} epochs (Best Val Acc: {best_val_acc:.2f}%)")
                break
        else:
            history['val_loss'].append(None)
            history['val_acc'].append(None)

            # Update scheduler for training without validation
            if scheduler is not None and not isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step()

        print(log_msg)

    # Load best model state if available
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model state (Val Acc: {best_val_acc:.2f}%)")

    return model, history


class LabelSmoothingCrossEntropy(nn.Module):
    """Label smoothing cross entropy loss for better generalization"""

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.smoothing = smoothing

    def forward(self, x, target):
        confidence = 1. - self.smoothing
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


def get_optimizer_and_scheduler(model, config, total_steps=None):
    """Get optimized optimizer and learning rate scheduler"""

    # Separate parameters for different learning rates
    backbone_params = []
    classifier_params = []

    for name, param in model.named_parameters():
        if 'classifier' in name:
            classifier_params.append(param)
        else:
            backbone_params.append(param)

    # Use different learning rates for backbone and classifier
    optimizer = optim.AdamW([
        {'params': backbone_params, 'lr': config['learning_rate'] * 0.1, 'weight_decay': config['weight_decay']},
        {'params': classifier_params, 'lr': config['learning_rate'], 'weight_decay': config['weight_decay'] * 0.1}
    ])

    # Use cosine annealing with warm restarts
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config['scheduler_restart_period'],
        T_mult=2,
        eta_min=config['learning_rate'] * 0.001
    )

    return optimizer, scheduler


def save_results_to_files(all_results, output_dir):
    """Save experimental results to JSON and CSV files"""
    os.makedirs(output_dir, exist_ok=True)

    # Save complete results as JSON
    results_json_path = os.path.join(output_dir, 'complete_results.json')
    with open(results_json_path, 'w') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"✅ Complete results saved to {results_json_path}")

    # Create summary DataFrame
    summary_data = []
    for strategy, strategy_results in all_results.items():
        for sparsity, metrics in strategy_results.items():
            row = {
                'strategy': strategy,
                'sparsity_ratio': sparsity,
                'accuracy': metrics['accuracy'],
                'loss': metrics['loss'],
                'macs_millions': metrics['macs'] / 1e6,
                'params_millions': metrics['params'] / 1e6,
                'size_mb': metrics['size_mb']
            }
            summary_data.append(row)

    # Save summary as CSV
    summary_df = pd.DataFrame(summary_data)
    summary_csv_path = os.path.join(output_dir, 'summary_results.csv')
    summary_df.to_csv(summary_csv_path, index=False)
    print(f"✅ Summary results saved to {summary_csv_path}")

    return summary_df


def create_results_plots(summary_df, output_dir):
    """Create visualization plots"""
    os.makedirs(output_dir, exist_ok=True)

    strategies = summary_df['strategy'].unique()

    # Plot 1: Accuracy vs Sparsity
    plt.figure(figsize=(12, 8))
    for strategy in strategies:
        strategy_data = summary_df[summary_df['strategy'] == strategy].sort_values('sparsity_ratio')
        plt.plot(strategy_data['sparsity_ratio'] * 100, strategy_data['accuracy'],
                 'o-', linewidth=3, markersize=10, label=strategy)

    plt.xlabel('Sparsity (%)', fontsize=14)
    plt.ylabel('Accuracy (%)', fontsize=14)
    plt.title('MobileNetV2: Accuracy vs Sparsity', fontsize=16, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_path = os.path.join(output_dir, 'accuracy_vs_sparsity.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Accuracy plot saved to {plot_path}")

    # Plot 2: Efficiency frontier (Accuracy vs MACs)
    plt.figure(figsize=(12, 8))
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    for i, strategy in enumerate(strategies):
        strategy_data = summary_df[summary_df['strategy'] == strategy].sort_values('sparsity_ratio')
        plt.scatter(strategy_data['macs_millions'], strategy_data['accuracy'],
                    s=150, label=strategy, alpha=0.8, color=colors[i % len(colors)])
        plt.plot(strategy_data['macs_millions'], strategy_data['accuracy'],
                 '--', alpha=0.7, linewidth=2, color=colors[i % len(colors)])

    plt.xlabel('MACs (Millions)', fontsize=14)
    plt.ylabel('Accuracy (%)', fontsize=14)
    plt.title('MobileNetV2: Efficiency Frontier (Accuracy vs MACs)', fontsize=16, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_path = os.path.join(output_dir, 'efficiency_frontier.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Efficiency frontier plot saved to {plot_path}")


def print_results_table(summary_df):
    """Print formatted results table"""
    print("\n" + "=" * 90)
    print("EXPERIMENTAL RESULTS SUMMARY")
    print("=" * 90)

    # Baseline results
    baseline_results = summary_df[summary_df['sparsity_ratio'] == 0.0].iloc[0]
    print(f"\nBaseline Performance:")
    print(f"  Accuracy: {baseline_results['accuracy']:.2f}%")
    print(f"  MACs: {baseline_results['macs_millions']:.2f}M")
    print(f"  Parameters: {baseline_results['params_millions']:.2f}M")
    print(f"  Model Size: {baseline_results['size_mb']:.2f}MB")

    # Strategy comparison at 50% sparsity
    print(f"\nStrategy Comparison at 50% Sparsity:")
    sparsity_50_data = summary_df[summary_df['sparsity_ratio'] == 0.5]
    for _, row in sparsity_50_data.iterrows():
        degradation = baseline_results['accuracy'] - row['accuracy']
        retention = (row['accuracy'] / baseline_results['accuracy']) * 100
        print(
            f"  {row['strategy']:>12}: {row['accuracy']:>6.2f}% accuracy ({degradation:>+5.2f}%, {retention:>5.1f}% retention)")

    # Complete results table
    print(f"\nComplete Results Table:")
    print("-" * 90)
    print(f"{'Strategy':<12} {'Sparsity':<8} {'Accuracy':<8} {'MACs(M)':<8} {'Params(M)':<9} {'Size(MB)':<8}")
    print("-" * 90)

    for _, row in summary_df.sort_values(['strategy', 'sparsity_ratio']).iterrows():
        print(f"{row['strategy']:<12} {row['sparsity_ratio'] * 100:>6.0f}% "
              f"{row['accuracy']:>7.2f}% {row['macs_millions']:>7.2f} "
              f"{row['params_millions']:>8.2f} {row['size_mb']:>7.2f}")


def main():
    """Main experimental workflow with improved training"""
    print("Starting Enhanced MobileNetV2 CIFAR-10 Pruning Experiments")
    print("=" * 65)

    # Enhanced Configuration
    config = {
        'strategies': {
            'BNScale': {
                'pruner': tp.pruner.BNScalePruner,
                'importance': tp.importance.BNScaleImportance()
            },
            'MagnitudeL2': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.MagnitudeImportance(p=2)
            },
            'Random': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.RandomImportance()
            },
        },
        'pruning_ratios': [0.0, 0.2, 0.5, 0.7],
        'num_classes': 10,
        'batch_size': 128,
        'learning_rate': 0.001,  # Increased from 0.0001
        'weight_decay': 1e-4,  # Added L2 regularization
        'epochs': 1000,  # Reduced from 1000 for faster iteration
        'patience': 20,  # Increased patience
        'scheduler_restart_period': 10,  # For cosine annealing
        'use_augmentation': True,  # Enable data augmentation
        'output_dir': './results_mobilenetv2_cifar10_enhanced',
        'models_dir': './base',
        'pretrained_path': './base/mobilenet_v2-b0353104.pth'
    }

    # Create output directories
    os.makedirs(config['output_dir'], exist_ok=True)
    os.makedirs(config['models_dir'], exist_ok=True)

    # Load data with augmentation
    print("Loading CIFAR-10 dataset with enhanced preprocessing...")
    train_loader, val_loader, test_loader = get_data_loaders(
        batch_size=config['batch_size'],
        use_augmentation=config['use_augmentation']
    )

    # Prepare inputs and criterion
    example_input_cpu = torch.randn(1, 3, 32, 32)
    example_input_device = example_input_cpu.to(DEVICE)
    criterion = nn.CrossEntropyLoss()

    # Get baseline model and train it
    print("\nCreating enhanced baseline model...")
    model = get_mobilenetv2_model(
        num_classes=config['num_classes'],
        use_pretrained=True,
        pretrained_path=config['pretrained_path']
    )
    model.to(DEVICE)

    # Get enhanced optimizer and scheduler
    optimizer, scheduler = get_optimizer_and_scheduler(model, config)

    print("Training enhanced baseline model...")
    trained_model, training_history = train_model(
        model, train_loader, criterion, optimizer, DEVICE,
        config['epochs'], val_loader, config['patience'],
        "Enhanced Baseline", scheduler
    )

    # Save baseline model
    baseline_model_path = os.path.join(config['models_dir'], 'enhanced_baseline_model.pth')
    save_model(trained_model, baseline_model_path, example_input_cpu)

    # Evaluate baseline
    print("\nEvaluating enhanced baseline model...")
    baseline_metrics = evaluate_model(trained_model, test_loader, example_input_device, criterion, DEVICE)
    print(f"Enhanced Baseline Results: Accuracy={baseline_metrics['accuracy']:.2f}%, "
          f"MACs={baseline_metrics['macs'] / 1e6:.2f}M, "
          f"Params={baseline_metrics['params'] / 1e6:.2f}M")

    # Initialize results storage
    all_results = {}
    for strategy_name in config['strategies'].keys():
        all_results[strategy_name] = {0.0: baseline_metrics}

    # Get ignored layers
    ignored_layers = get_ignored_layers(trained_model)

    # Run pruning experiments
    print("\nStarting enhanced pruning experiments...")
    for strategy_name, strategy_config in config['strategies'].items():
        print(f"\n--- Strategy: {strategy_name} ---")

        for sparsity_ratio in config['pruning_ratios']:
            if sparsity_ratio == 0.0:
                continue  # Skip baseline (already done)

            print(f"\nProcessing {strategy_name} at {sparsity_ratio:.1%} sparsity...")

            # Load fresh copy of trained baseline
            model_copy = get_mobilenetv2_model(
                num_classes=config['num_classes'],
                use_pretrained=False
            )
            model_copy.load_state_dict(torch.load(baseline_model_path, map_location=DEVICE))
            model_copy.to(DEVICE)

            # Apply pruning
            pruned_model = prune_model(
                model_copy, strategy_config, sparsity_ratio,
                example_input_device, ignored_layers
            )

            # Enhanced fine-tuning with reduced learning rate
            print("Fine-tuning pruned model with enhanced settings...")
            ft_config = config.copy()
            ft_config['learning_rate'] = config['learning_rate'] * 0.1  # Reduce LR for fine-tuning
            ft_config['epochs'] = config['epochs']  # Fewer epochs for fine-tuning

            optimizer_ft, scheduler_ft = get_optimizer_and_scheduler(pruned_model, ft_config)

            fine_tuned_model, ft_history = train_model(
                pruned_model, train_loader, criterion, optimizer_ft, DEVICE,
                ft_config['epochs'], val_loader, config['patience'],
                f"{strategy_name}-{sparsity_ratio:.1%}", scheduler_ft
            )

            # Evaluate fine-tuned model
            final_metrics = evaluate_model(fine_tuned_model, test_loader, example_input_device, criterion, DEVICE)
            all_results[strategy_name][sparsity_ratio] = final_metrics

            print(f"Results: Accuracy={final_metrics['accuracy']:.2f}%, "
                  f"MACs={final_metrics['macs'] / 1e6:.2f}M")

            # Save fine-tuned model
            model_filename = f"enhanced_{strategy_name.lower()}_sparsity_{sparsity_ratio:.1f}.pth"
            model_path = os.path.join(config['models_dir'], model_filename)
            save_model(fine_tuned_model, model_path, example_input_cpu)

    # Save and analyze results
    print("\nSaving enhanced results...")
    summary_df = save_results_to_files(all_results, config['output_dir'])

    # Create enhanced plots
    print("Creating enhanced plots...")
    create_results_plots(summary_df, config['output_dir'])

    # Print comprehensive summary
    print_results_table(summary_df)

    print(f"\n🎉 All enhanced experiments completed!")
    print(f"📁 Results saved to: {os.path.abspath(config['output_dir'])}")
    print(f"📁 Models saved to: {os.path.abspath(config['models_dir'])}")

    # Performance expectations
    print(f"\n📊 Expected Performance Improvements:")
    print(f"   • Baseline accuracy should reach 85-92% (vs previous ~70%)")
    print(f"   • Better accuracy retention after pruning")
    print(f"   • More stable training with enhanced optimizations")


if __name__ == "__main__":
    main()

Using device: cuda
Starting Enhanced MobileNetV2 CIFAR-10 Pruning Experiments
Loading CIFAR-10 dataset with enhanced preprocessing...
Loading CIFAR-10 from: /home/muis/thesis/github-repo/master-thesis/cnn/mobile_net_v2/data
DataLoaders created - Train: 45000, Val: 5000, Test: 10000

Creating enhanced baseline model...
Loading pre-trained weights from: ./base/mobilenet_v2-b0353104.pth
✅ Loaded MobileNetV2 with pre-downloaded ImageNet weights
✅ Enhanced classifier created for 10 classes
Training enhanced baseline model...
Epoch   1/1000 (Enhanced Baseline): Train Loss: 1.6498, Train Acc: 47.37%, LR: 0.000100, Val Loss: 1.1409, Val Acc: 60.92% ✓
Epoch   2/1000 (Enhanced Baseline): Train Loss: 1.3706, Train Acc: 61.01%, LR: 0.000098, Val Loss: 0.9460, Val Acc: 68.32% ✓
Epoch   3/1000 (Enhanced Baseline): Train Loss: 1.2741, Train Acc: 65.82%, LR: 0.000091, Val Loss: 0.8256, Val Acc: 73.04% ✓
Epoch   4/1000 (Enhanced Baseline): Train Loss: 1.2158, Train Acc: 68.45%, LR: 0.000080, Val Loss: 



Epoch   1/1000 (BNScale-20.0%): Train Loss: 1.0763, Train Acc: 74.90%, LR: 0.000010, Val Loss: 0.6810, Val Acc: 78.86% ✓
Epoch   2/1000 (BNScale-20.0%): Train Loss: 1.0153, Train Acc: 77.29%, LR: 0.000010, Val Loss: 0.6488, Val Acc: 79.54% ✓
Epoch   3/1000 (BNScale-20.0%): Train Loss: 0.9855, Train Acc: 78.73%, LR: 0.000009, Val Loss: 0.6247, Val Acc: 80.56% ✓
Epoch   4/1000 (BNScale-20.0%): Train Loss: 0.9684, Train Acc: 79.29%, LR: 0.000008, Val Loss: 0.6142, Val Acc: 80.56%
Epoch   5/1000 (BNScale-20.0%): Train Loss: 0.9674, Train Acc: 79.36%, LR: 0.000007, Val Loss: 0.6074, Val Acc: 80.66% ✓
Epoch   6/1000 (BNScale-20.0%): Train Loss: 0.9539, Train Acc: 79.83%, LR: 0.000005, Val Loss: 0.6095, Val Acc: 81.14% ✓
Epoch   7/1000 (BNScale-20.0%): Train Loss: 0.9464, Train Acc: 80.23%, LR: 0.000004, Val Loss: 0.6050, Val Acc: 81.00%
Epoch   8/1000 (BNScale-20.0%): Train Loss: 0.9419, Train Acc: 80.30%, LR: 0.000002, Val Loss: 0.5979, Val Acc: 81.14%
Epoch   9/1000 (BNScale-20.0%): Train 