In [1]:
import torch
import torch.nn as nn
import torch_pruning as tp
import matplotlib.pyplot as plt
from torch import optim
import os
import numpy as np
import copy
import json
import pandas as pd

import torchvision
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_BASE_NAME = "resnet18"
print(f"Using device: {DEVICE}")


def get_data_loaders(data_dir_path='./data', batch_size=128, val_split=0.1, seed=42):
    """Load CIFAR-10 dataset with train/val/test splits"""
    abs_data_dir = os.path.abspath(data_dir_path)
    print(f"Loading CIFAR-10 from: {abs_data_dir}")

    transform_cifar = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # Load CIFAR-10 from local directory (assuming it's already downloaded)
    full_train_dataset = torchvision.datasets.CIFAR10(
        root=abs_data_dir, train=True, download=False, transform=transform_cifar
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root=abs_data_dir, train=False, download=False, transform=transform_cifar
    )

    # Create train/validation split
    val_size = int(len(full_train_dataset) * val_split)
    train_size = len(full_train_dataset) - val_size
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset, [train_size, val_size], generator=generator
    )

    # Create data loaders
    num_workers = min(4, os.cpu_count() or 2)
    pin_memory = True if DEVICE.type == 'cuda' else False

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    print(f"DataLoaders created - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    return train_loader, val_loader, test_loader


def get_resnet18_model(num_classes=10, use_pretrained=True, pretrained_path='./base/resnet18-f37072fd.pth'):
    """Get ResNet-18 model adapted for CIFAR-10"""
    # Always create model without weights first
    model = models.resnet18(weights=None)
    
    if use_pretrained and os.path.exists(pretrained_path):
        # Load pre-downloaded weights from local file
        print(f"Loading pre-trained weights from: {pretrained_path}")
        pretrained_state_dict = torch.load(pretrained_path, map_location=DEVICE)
        
        # Load the weights, ignoring the fc layer if it doesn't match
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_state_dict.items() 
                          if k in model_dict and model_dict[k].shape == v.shape}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict, strict=False)
        print("✅ Loaded ResNet-18 with pre-downloaded ImageNet weights")
    else:
        if use_pretrained:
            print(f"Warning: Pre-trained weights not found at {pretrained_path}")
        print("✅ Created ResNet-18 without pretrained weights")

    # Adapt final fc layer for CIFAR-10
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    print(f"✅ Adapted classifier for {num_classes} classes")
    
    return model


def get_ignored_layers(model):
    """Get layers to ignore during pruning (typically final classifier)"""
    ignored_layers = []
    # For ResNet, the final classifier is the 'fc' layer
    if hasattr(model, 'fc'):
        ignored_layers.append(model.fc)
    return ignored_layers


def calculate_macs_params(model, example_input):
    """Calculate MACs and parameters using torch_pruning"""
    model.eval()
    target_device = example_input.device
    model_on_device = model.to(target_device)

    with torch.no_grad():
        macs, params = tp.utils.count_ops_and_params(model_on_device, example_input)

    return macs, params


def save_model(model, save_path, example_input_cpu=None):
    """Save model state dict and optionally ONNX"""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"✅ Model saved to {save_path}")

    if example_input_cpu is not None:
        onnx_path = save_path.replace('.pth', '.onnx')
        try:
            model_cpu = model.to('cpu')
            torch.onnx.export(
                model_cpu, example_input_cpu, onnx_path,
                export_params=True, opset_version=13,
                input_names=['input'], output_names=['output'],
                dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
            )
            print(f"✅ ONNX model saved to {onnx_path}")
        except Exception as e:
            print(f"Warning: ONNX export failed: {e}")


def evaluate_model(model, data_loader, example_input, criterion, device):
    """Evaluate model and return comprehensive metrics"""
    model.eval()
    model.to(device)

    # Calculate efficiency metrics
    macs, params = calculate_macs_params(model, example_input.to(device))
    model_size_mb = params * 4 / (1024 * 1024)  # Assuming float32

    # Calculate accuracy and loss
    correct = 0
    total = 0
    total_loss = 0.0

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item() * data.size(0)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    avg_loss = total_loss / total if total > 0 else float('nan')
    accuracy = 100.0 * correct / total if total > 0 else 0.0

    return {
        'accuracy': accuracy,
        'loss': avg_loss,
        'macs': macs,
        'params': params,
        'size_mb': model_size_mb
    }


def prune_model(model, strategy_config, sparsity_ratio, example_input, ignored_layers=None):
    """Apply structured pruning to model"""
    if sparsity_ratio == 0.0:
        print("No pruning needed (sparsity = 0.0)")
        return model

    model.eval()
    pruned_model = copy.deepcopy(model)
    pruned_model.to(example_input.device)

    # Calculate initial MACs
    initial_macs, _ = calculate_macs_params(pruned_model, example_input)
    print(f"Initial MACs: {initial_macs / 1e6:.2f}M")

    ignored_layers = ignored_layers or []

    # Create pruner based on strategy
    pruner = strategy_config['pruner'](
        pruned_model,
        example_input,
        importance=strategy_config['importance'],
        iterative_steps=5,  # Use 5 iterative steps
        ch_sparsity=sparsity_ratio,
        root_module_types=[nn.Conv2d],
        ignored_layers=ignored_layers
    )

    print(f"Applying {strategy_config['importance'].__class__.__name__} pruning at {sparsity_ratio:.1%} sparsity...")

    # Apply pruning
    pruner.step()

    # Calculate final MACs
    final_macs, _ = calculate_macs_params(pruned_model, example_input)
    reduction = (initial_macs - final_macs) / initial_macs * 100 if initial_macs > 0 else 0
    print(f"Final MACs: {final_macs / 1e6:.2f}M (Reduction: {reduction:.1f}%)")

    return pruned_model


def train_model(model, train_loader, criterion, optimizer, device, num_epochs,
                val_loader=None, patience=7, log_prefix=""):
    """Train model with early stopping"""
    model.to(device)

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            train_total += target.size(0)
            train_correct += (predicted == target).sum().item()

        avg_train_loss = train_loss / len(train_loader)
        train_acc = 100.0 * train_correct / train_total

        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_acc)

        log_msg = f"Epoch {epoch + 1}/{num_epochs} ({log_prefix}): Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%"

        # Validation phase
        if val_loader:
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for data, target in val_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    loss = criterion(output, target)

                    val_loss += loss.item()
                    _, predicted = torch.max(output.data, 1)
                    val_total += target.size(0)
                    val_correct += (predicted == target).sum().item()

            avg_val_loss = val_loss / len(val_loader)
            val_acc = 100.0 * val_correct / val_total

            history['val_loss'].append(avg_val_loss)
            history['val_acc'].append(val_acc)

            log_msg += f", Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%"

            # Early stopping check
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_no_improve = 0
                best_model_state = copy.deepcopy(model.state_dict())
                log_msg += " (Best)"
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"{log_msg}")
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break
        else:
            history['val_loss'].append(None)
            history['val_acc'].append(None)

        print(log_msg)

    # Load best model state if available
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("Loaded best model state")

    return model, history


def save_results_to_files(all_results, output_dir):
    """Save experimental results to JSON and CSV files"""
    os.makedirs(output_dir, exist_ok=True)

    # Save complete results as JSON
    results_json_path = os.path.join(output_dir, 'complete_results.json')
    with open(results_json_path, 'w') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"✅ Complete results saved to {results_json_path}")

    # Create summary DataFrame
    summary_data = []
    for strategy, strategy_results in all_results.items():
        for sparsity, metrics in strategy_results.items():
            row = {
                'strategy': strategy,
                'sparsity_ratio': sparsity,
                'accuracy': metrics['accuracy'],
                'loss': metrics['loss'],
                'macs_millions': metrics['macs'] / 1e6,
                'params_millions': metrics['params'] / 1e6,
                'size_mb': metrics['size_mb']
            }
            summary_data.append(row)

    # Save summary as CSV
    summary_df = pd.DataFrame(summary_data)
    summary_csv_path = os.path.join(output_dir, 'summary_results.csv')
    summary_df.to_csv(summary_csv_path, index=False)
    print(f"✅ Summary results saved to {summary_csv_path}")

    return summary_df


def create_results_plots(summary_df, output_dir):
    """Create visualization plots"""
    os.makedirs(output_dir, exist_ok=True)

    strategies = summary_df['strategy'].unique()
    sparsity_levels = sorted(summary_df['sparsity_ratio'].unique())

    # Plot 1: Accuracy vs Sparsity
    plt.figure(figsize=(10, 6))
    for strategy in strategies:
        strategy_data = summary_df[summary_df['strategy'] == strategy].sort_values('sparsity_ratio')
        plt.plot(strategy_data['sparsity_ratio'] * 100, strategy_data['accuracy'],
                 'o-', linewidth=2, markersize=8, label=strategy)

    plt.xlabel('Sparsity (%)', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.title('ResNet-18: Accuracy vs Sparsity', fontsize=14, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_path = os.path.join(output_dir, 'accuracy_vs_sparsity.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Accuracy plot saved to {plot_path}")

    # Plot 2: Efficiency frontier (Accuracy vs MACs)
    plt.figure(figsize=(10, 6))
    for strategy in strategies:
        strategy_data = summary_df[summary_df['strategy'] == strategy].sort_values('sparsity_ratio')
        plt.scatter(strategy_data['macs_millions'], strategy_data['accuracy'],
                    s=100, label=strategy, alpha=0.8)
        plt.plot(strategy_data['macs_millions'], strategy_data['accuracy'],
                 '--', alpha=0.6)

    plt.xlabel('MACs (Millions)', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.title('ResNet-18: Efficiency Frontier (Accuracy vs MACs)', fontsize=14, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_path = os.path.join(output_dir, 'efficiency_frontier.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Efficiency frontier plot saved to {plot_path}")


def print_results_table(summary_df):
    """Print formatted results table"""
    print("\n" + "=" * 80)
    print("EXPERIMENTAL RESULTS SUMMARY")
    print("=" * 80)

    # Baseline results
    baseline_results = summary_df[summary_df['sparsity_ratio'] == 0.0].iloc[0]
    print(f"\nBaseline Performance:")
    print(f"  Accuracy: {baseline_results['accuracy']:.2f}%")
    print(f"  MACs: {baseline_results['macs_millions']:.2f}M")
    print(f"  Parameters: {baseline_results['params_millions']:.2f}M")
    print(f"  Model Size: {baseline_results['size_mb']:.2f}MB")

    # Strategy comparison at 50% sparsity
    print(f"\nStrategy Comparison at 50% Sparsity:")
    sparsity_50_data = summary_df[summary_df['sparsity_ratio'] == 0.5]
    for _, row in sparsity_50_data.iterrows():
        degradation = baseline_results['accuracy'] - row['accuracy']
        retention = (row['accuracy'] / baseline_results['accuracy']) * 100
        print(
            f"  {row['strategy']:>12}: {row['accuracy']:>6.2f}% accuracy ({degradation:>+5.2f}%, {retention:>5.1f}% retention)")

    # Complete results table
    print(f"\nComplete Results Table:")
    print("-" * 80)
    print(f"{'Strategy':<12} {'Sparsity':<8} {'Accuracy':<8} {'MACs(M)':<8} {'Params(M)':<9} {'Size(MB)':<8}")
    print("-" * 80)

    for _, row in summary_df.sort_values(['strategy', 'sparsity_ratio']).iterrows():
        print(f"{row['strategy']:<12} {row['sparsity_ratio'] * 100:>6.0f}% "
              f"{row['accuracy']:>7.2f}% {row['macs_millions']:>7.2f} "
              f"{row['params_millions']:>8.2f} {row['size_mb']:>7.2f}")


def main():
    """Main experimental workflow"""
    print("Starting ResNet-18 CIFAR-10 Pruning Experiments")
    print("=" * 60)

    # Configuration
    config = {
        'strategies': {
            'BNScale': {
                'pruner': tp.pruner.BNScalePruner,
                'importance': tp.importance.BNScaleImportance()
            },
            'MagnitudeL2': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.MagnitudeImportance(p=2)
            },
            'Random': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.RandomImportance()
            },
        },
        'pruning_ratios': [0.0, 0.2, 0.5, 0.7],
        'num_classes': 10,
        'batch_size': 128,
        'learning_rate': 0.0001,
        'epochs': 30,
        'patience': 7,
        'output_dir': './results_resnet18_cifar10',
        'models_dir': './base',
        'pretrained_path': './base/resnet18-f37072fd.pth'  # Path to pre-downloaded ResNet-18 weights
    }

    # Create output directories
    os.makedirs(config['output_dir'], exist_ok=True)
    os.makedirs(config['models_dir'], exist_ok=True)

    # Load data
    print("Loading CIFAR-10 dataset...")
    train_loader, val_loader, test_loader = get_data_loaders(
        batch_size=config['batch_size']
    )

    # Prepare inputs and criterion
    example_input_cpu = torch.randn(1, 3, 32, 32)
    example_input_device = example_input_cpu.to(DEVICE)
    criterion = nn.CrossEntropyLoss()

    # Get baseline model and train it
    print("\nCreating and training baseline model...")
    model = get_resnet18_model(
        num_classes=config['num_classes'], 
        use_pretrained=True,
        pretrained_path=config['pretrained_path']
    )
    model.to(DEVICE)

    # Train baseline model
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    trained_model, training_history = train_model(
        model, train_loader, criterion, optimizer, DEVICE,
        config['epochs'], val_loader, config['patience'], "Baseline Training"
    )

    # Save baseline model
    baseline_model_path = os.path.join(config['models_dir'], 'resnet18_baseline_model.pth')
    save_model(trained_model, baseline_model_path, example_input_cpu)

    # Evaluate baseline
    print("\nEvaluating baseline model...")
    baseline_metrics = evaluate_model(trained_model, test_loader, example_input_device, criterion, DEVICE)
    print(f"Baseline Results: Accuracy={baseline_metrics['accuracy']:.2f}%, "
          f"MACs={baseline_metrics['macs'] / 1e6:.2f}M, "
          f"Params={baseline_metrics['params'] / 1e6:.2f}M")

    # Initialize results storage
    all_results = {}
    for strategy_name in config['strategies'].keys():
        all_results[strategy_name] = {0.0: baseline_metrics}

    # Get ignored layers
    ignored_layers = get_ignored_layers(trained_model)

    # Run pruning experiments
    print("\nStarting pruning experiments...")
    for strategy_name, strategy_config in config['strategies'].items():
        print(f"\n--- Strategy: {strategy_name} ---")

        for sparsity_ratio in config['pruning_ratios']:
            if sparsity_ratio == 0.0:
                continue  # Skip baseline (already done)

            print(f"\nProcessing {strategy_name} at {sparsity_ratio:.1%} sparsity...")

            # Load fresh copy of trained baseline
            model_copy = get_resnet18_model(
                num_classes=config['num_classes'], 
                use_pretrained=False
            )
            model_copy.load_state_dict(torch.load(baseline_model_path, map_location=DEVICE))
            model_copy.to(DEVICE)

            # Apply pruning
            pruned_model = prune_model(
                model_copy, strategy_config, sparsity_ratio,
                example_input_device, ignored_layers
            )

            # Fine-tune pruned model
            print("Fine-tuning pruned model...")
            optimizer_ft = optim.Adam(pruned_model.parameters(), lr=config['learning_rate'])
            fine_tuned_model, ft_history = train_model(
                pruned_model, train_loader, criterion, optimizer_ft, DEVICE,
                config['epochs'], val_loader, config['patience'],
                f"{strategy_name}-{sparsity_ratio:.1%}"
            )

            # Evaluate fine-tuned model
            final_metrics = evaluate_model(fine_tuned_model, test_loader, example_input_device, criterion, DEVICE)
            all_results[strategy_name][sparsity_ratio] = final_metrics

            print(f"Results: Accuracy={final_metrics['accuracy']:.2f}%, "
                  f"MACs={final_metrics['macs'] / 1e6:.2f}M")

            # Save fine-tuned model
            model_filename = f"resnet18_{strategy_name.lower()}_sparsity_{sparsity_ratio:.1f}.pth"
            model_path = os.path.join(config['models_dir'], model_filename)
            save_model(fine_tuned_model, model_path, example_input_cpu)

    # Save and analyze results
    print("\nSaving results...")
    summary_df = save_results_to_files(all_results, config['output_dir'])

    # Create plots
    print("Creating plots...")
    create_results_plots(summary_df, config['output_dir'])

    # Print summary
    print_results_table(summary_df)

    print(f"\n🎉 All experiments completed!")
    print(f"📁 Results saved to: {os.path.abspath(config['output_dir'])}")
    print(f"📁 Models saved to: {os.path.abspath(config['models_dir'])}")


if __name__ == "__main__":
    main()

Using device: cuda
Starting ResNet-18 CIFAR-10 Pruning Experiments
Loading CIFAR-10 dataset...
Loading CIFAR-10 from: /home/muis/thesis/github-repo/master-thesis/cnn/resNet/data
DataLoaders created - Train: 45000, Val: 5000, Test: 10000

Creating and training baseline model...
Loading pre-trained weights from: ./base/resnet18-f37072fd.pth
✅ Loaded ResNet-18 with pre-downloaded ImageNet weights
✅ Adapted classifier for 10 classes
Epoch 1/30 (Baseline Training): Train Loss: 1.0594, Train Acc: 63.55%, Val Loss: 0.7187, Val Acc: 74.36% (Best)
Epoch 2/30 (Baseline Training): Train Loss: 0.5706, Train Acc: 80.46%, Val Loss: 0.6358, Val Acc: 77.68% (Best)
Epoch 3/30 (Baseline Training): Train Loss: 0.3644, Train Acc: 87.60%, Val Loss: 0.6332, Val Acc: 79.24% (Best)
Epoch 4/30 (Baseline Training): Train Loss: 0.2349, Train Acc: 91.96%, Val Loss: 0.6951, Val Acc: 79.60%
Epoch 5/30 (Baseline Training): Train Loss: 0.1441, Train Acc: 95.13%, Val Loss: 0.7408, Val Acc: 79.54%
Epoch 6/30 (Baseline 



Epoch 1/30 (BNScale-20.0%): Train Loss: 0.3898, Train Acc: 86.59%, Val Loss: 0.6327, Val Acc: 79.24% (Best)
Epoch 2/30 (BNScale-20.0%): Train Loss: 0.2144, Train Acc: 92.76%, Val Loss: 0.6929, Val Acc: 79.40%
Epoch 3/30 (BNScale-20.0%): Train Loss: 0.1362, Train Acc: 95.37%, Val Loss: 0.7812, Val Acc: 79.52%
Epoch 4/30 (BNScale-20.0%): Train Loss: 0.0956, Train Acc: 96.84%, Val Loss: 0.8158, Val Acc: 79.36%
Epoch 5/30 (BNScale-20.0%): Train Loss: 0.0761, Train Acc: 97.46%, Val Loss: 0.9063, Val Acc: 78.50%
Epoch 6/30 (BNScale-20.0%): Train Loss: 0.0733, Train Acc: 97.41%, Val Loss: 0.9417, Val Acc: 78.88%
Epoch 7/30 (BNScale-20.0%): Train Loss: 0.0628, Train Acc: 97.87%, Val Loss: 0.9163, Val Acc: 79.66%
Epoch 8/30 (BNScale-20.0%): Train Loss: 0.0497, Train Acc: 98.38%, Val Loss: 0.9737, Val Acc: 79.56%
Early stopping triggered after 8 epochs
Loaded best model state
Results: Accuracy=79.16%, MACs=34.04M
✅ Model saved to ./base/resnet18_bnscale_sparsity_0.2.pth
✅ ONNX model saved to ./b