In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
import matplotlib.pyplot as plt
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau # Added LR Scheduler

import torchvision
import torchvision.transforms as transforms
import os
import copy # For deepcopying model state
from torch.utils.data import DataLoader, SubsetRandomSampler # Ensure SubsetRandomSampler is imported
import os
import numpy as np # For shuffling indices if not using torch.random.shuffle directly
#from cnn.resNet.resnet_example import get_data_loaders

### Data loaders

In [2]:
def get_data_loaders(data_dir, batch_size=64, validation_split=0.1, num_workers=4, pin_memory=True):
    """
    Build CIFAR-10 train / validation / test data loaders from a pre-downloaded dataset.

    When ``0 < validation_split < 1`` the official training set is split into a
    training part (with augmentation) and a validation part (without augmentation),
    using a fixed seed so the split is reproducible. Otherwise the full training set
    (with augmentation) is used for training and the test set doubles as the
    validation set.

    Args:
        data_dir: Directory containing the extracted ``cifar-10-batches-py`` folder.
        batch_size: Batch size for all three loaders.
        validation_split: Fraction of the training set held out for validation.
        num_workers: Worker processes per DataLoader.
        pin_memory: Passed through to every DataLoader.

    Returns:
        ``(train_loader, val_loader, test_loader)``, or ``(None, None, None)`` when the
        CIFAR-10 files are not found under ``data_dir``.
    """
    abs_data_dir = os.path.abspath(data_dir)
    print(f"Using dataset directory: {abs_data_dir}")

    # Check before touching torchvision: with download=False the CIFAR10 constructor
    # raises its own error first, so a check placed after it could never report anything.
    if not os.path.exists(os.path.join(abs_data_dir, 'cifar-10-batches-py')):
        print(f"ERROR: CIFAR-10 data not found in {abs_data_dir}. "
              "Please ensure it's pre-downloaded and extracted correctly.")
        return None, None, None

    # A single decision drives both the dataset transforms and the loader layout, so an
    # out-of-range ratio (e.g. 1.5) can no longer silently drop training augmentation.
    use_split = 0.0 < validation_split < 1.0
    if not use_split and validation_split != 0:
        print(f"Warning: Invalid validation_split value ({validation_split}). Using full training set for training and test set for validation.")

    # Standard CIFAR-10 channel statistics; augmentation only for training.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    val_test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    # Augmented view of the training images, used for training in both modes.
    train_dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=False, transform=train_transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=False, transform=val_test_transform
    )

    if use_split:
        num_train = len(train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(validation_split * num_train))

        # Local generator: produces the same permutation as np.random.seed(42) followed
        # by np.random.shuffle, without reseeding NumPy's global RNG for the notebook.
        np.random.RandomState(42).shuffle(indices)
        train_idx, val_idx = indices[split:], indices[:split]

        # A sampler only selects indices; the transform comes from the dataset it reads.
        # The validation indices therefore read from a second, non-augmented instance.
        val_dataset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=False, transform=val_test_transform
        )

        train_loader = DataLoader(
            dataset=train_dataset, sampler=SubsetRandomSampler(train_idx), **loader_kwargs
        )
        val_loader = DataLoader(
            dataset=val_dataset, sampler=SubsetRandomSampler(val_idx), **loader_kwargs
        )
        print(f"Splitting training data: {len(train_idx)} for training (with augmentation), {len(val_idx)} for validation.")
    else:
        train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
        # Test set as validation: fine for quick experiments, not for model selection.
        val_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)
        print("Using full training set for training. Using test set as validation set (not recommended for final evaluation).")

    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader

### ResNet Blocks and Model Definition

In [3]:
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block (used by ResNet-18/34).

    ``expansion`` is the ratio of output channels to ``planes`` (1 here). When the
    stride or channel count changes, a 1x1 conv + BatchNorm projection replaces the
    identity shortcut so the residual sum has matching shapes.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + self.shortcut(x))


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (used by ResNet-50).

    The final 1x1 conv widens the output to ``expansion * planes`` (4x) channels, and
    the 3x3 conv carries the stride. When the stride or channel count changes, a
    1x1 conv + BatchNorm projection replaces the identity shortcut.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes

        # reduce -> spatial (strided) -> expand
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + self.shortcut(x))


class ResNet(nn.Module):
    """
    CIFAR-style ResNet.

    Layout: 3x3 stem (stride 1, no max-pool), then four residual stages with
    64/128/256/512 base channels, then global average pooling and a linear classifier.

    Args:
        block: Residual block class. It must accept ``(in_planes, planes, stride)`` and
            expose an ``expansion`` attribute (output channels = ``expansion * planes``).
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Size of the classifier output.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # Adapted for CIFAR-10: kernel_size 3, stride 1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # The final FC layer name is 'linear' (prune_model looks it up by this name to
        # exclude it from pruning).
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage. Only its first block uses ``stride``; advances ``self.in_planes``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride_val in strides:
            layers.append(block(self.in_planes, planes, stride_val))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs layer4 yields 4x4 maps, so this equals
        # the former avg_pool2d(out, 4); unlike that fixed kernel, it also works for
        # other input resolutions instead of producing a mismatched feature size.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(num_classes=10):
    """ResNet-18: four BasicBlock stages with 2, 2, 2, 2 blocks."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)


def ResNet34(num_classes=10):
    """ResNet-34: four BasicBlock stages with 3, 4, 6, 3 blocks."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)


def ResNet50(num_classes=10):
    """ResNet-50: four Bottleneck stages with 3, 4, 6, 3 blocks."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)


In [4]:
# Notebook-wide default device: the current CUDA device when available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
def save_model_as_onnx(model, example_input, output_path):
    """
    Export ``model`` to an ONNX file whose batch dimension is dynamic.

    The export runs in eval mode. The model's previous train/eval mode is restored
    afterwards, even if the export raises, so callers that keep training the same
    module are not silently left in eval mode.

    Args:
        model: Module to export.
        example_input: Sample input used to trace the graph; only dimension 0 is dynamic.
        output_path: Destination ``.onnx`` file.
    """
    was_training = model.training
    model.eval()
    try:
        torch.onnx.export(
            model,
            example_input,
            output_path,
            export_params=True,
            opset_version=13,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
    finally:
        model.train(was_training)
    print(f"✅ Model saved as ONNX to {output_path}")

In [6]:
def calculate_macs(model, example_input):
    """Count ``(macs, params)`` of ``model`` for ``example_input`` using torch_pruning's counter."""
    op_count, param_count = tp.utils.count_ops_and_params(model, example_input)
    return op_count, param_count

### Compare results of different pruning strategies

In [7]:
def compare_results(results):
    """
    Print a fixed-width table of MACs, model size and accuracy for each pruning strategy.

    Args:
        results: Mapping ``strategy name -> {'macs', 'size_mb', 'accuracy'}``.
    """
    print("\n=== Pruning Strategy Comparison ===")
    # Header and value cells share the same widths (12/12/10/12), so columns line up
    # and every row is exactly as wide as the 55-character separator.
    print(f"{'Strategy':<12} | {'MACs':>12} | {'Size (MB)':>10} | {'Accuracy (%)':>12}")
    print("-" * 55)
    for strategy, metrics in results.items():
        print(f"{strategy:<12} | {metrics['macs']:>12.2e} | {metrics['size_mb']:>10.2f} | {metrics['accuracy']:>12.2f}")

### Compare and plot results

In [8]:
def compare_results_and_plot(results, output_dir='output'):
    """
    Print the comparison table and save one bar chart per metric (MACs, size, accuracy).

    Args:
        results: Mapping ``strategy name -> {'macs', 'size_mb', 'accuracy'}``. If an
            ``'initial'`` entry is present, its bar is drawn first and its value is
            also shown as a dashed reference line.
        output_dir: Directory for the ``<metric>_comparison.png`` files (created if missing).
    """
    # Same table as compare_results; reuse it instead of duplicating the formatting.
    compare_results(results)

    if not results:
        return

    os.makedirs(output_dir, exist_ok=True)
    # Baseline first when present; unconditionally prepending 'initial' raised KeyError
    # for result dicts that do not contain it.
    strategies = (['initial'] if 'initial' in results else []) + [s for s in results if s != 'initial']
    titles = {
        'macs': 'MACs Comparison',
        'size_mb': 'Model Size (MB) Comparison',
        'accuracy': 'Accuracy (%) Comparison'
    }
    y_labels = {
        'macs': 'MACs (Millions)',
        'size_mb': 'Size (MB)',
        'accuracy': 'Accuracy (%)'
    }

    def _plot_value(strategy, metric_name):
        # MACs are plotted in millions to keep the axis readable.
        value = results[strategy][metric_name]
        return value / 1e6 if metric_name == 'macs' else value

    colors = plt.cm.tab10(range(len(strategies)))
    for metric_name in ('macs', 'size_mb', 'accuracy'):
        values = [_plot_value(strategy, metric_name) for strategy in strategies]
        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(strategies, values, color=colors)
        ax.set_xlabel('Strategy')
        ax.set_ylabel(y_labels[metric_name])
        ax.set_title(titles[metric_name])
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        # Value labels on top of each bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height, f'{height:.2f}', ha='center', va='bottom')

        if 'initial' in results:
            ax.axhline(y=_plot_value('initial', metric_name), color='r', linestyle='--', label='Initial')
            ax.legend()

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, f'{metric_name}_comparison.png'))
        # Close each figure explicitly so repeated calls don't accumulate open figures.
        plt.close(fig)

### Load Model

In [9]:
def load_model_state(model_class, num_classes, path, device_to_load):
    """
    Build ``model_class(num_classes=num_classes)`` on ``device_to_load`` and load the
    state dict stored at ``path`` into it (tensors are mapped onto ``device_to_load``).

    The saved state dict must match the freshly constructed architecture.
    """
    net = model_class(num_classes=num_classes)
    net = net.to(device_to_load)
    weights = torch.load(path, map_location=device_to_load)
    net.load_state_dict(weights)
    return net

### Utility function to save the model

In [10]:
def save_model_state(model, path, example_input=None):
    """
    Save ``model.state_dict()`` to ``path`` and optionally export ONNX alongside it.

    Args:
        model: Module whose parameters and buffers are saved.
        path: Destination file (conventionally ``.pth``). Parent directories are created
            when ``path`` has a directory component.
        example_input: If given, the model is also exported to ONNX at ``path`` with its
            final extension replaced by ``.onnx``.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:  # os.makedirs('') raises, so skip bare filenames
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    if example_input is not None:
        # Replace only the final extension. str.replace('.pth', ...) also rewrote '.pth'
        # inside directory names, and for any other extension it returned `path` itself,
        # so the ONNX export overwrote the state dict just saved.
        onnx_path = os.path.splitext(path)[0] + '.onnx'
        save_model_as_onnx(model, example_input, onnx_path)

### Evaluate the model

In [11]:
def evaluate_model(model, data_loader, example_input, device_to_eval, eval_mode_str="Test"):
    """
    Measure compute cost, parameter footprint and classification accuracy of ``model``.

    Args:
        model: Classifier to evaluate. It is switched to (and left in) eval mode.
        data_loader: Yields ``(inputs, labels)`` batches.
        example_input: Dummy input used only for MAC counting.
        device_to_eval: Device that inputs, labels and ``example_input`` are moved to.
        eval_mode_str: Prefix for the printed accuracy line.

    Returns:
        dict with these keys:
            ``'macs'``: from ``calculate_macs``.
            ``'size_mb'``: parameter storage in MiB, using each tensor's actual element size.
            ``'accuracy'``: top-1 accuracy in percent; 0 for an empty loader.
    """
    model.eval()
    macs, _ = calculate_macs(model, example_input.to(device_to_eval))
    # Real byte count (not a hard-coded 4 bytes/param), in MiB over all parameters.
    # This matches how gr_prune_model_with_threshold measures size, which matters because
    # main() derives that function's size target from this value.
    size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device_to_eval), labels.to(device_to_eval)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy_val = 100 * correct / total if total > 0 else 0
    print(f"{eval_mode_str} Accuracy: {accuracy_val:.2f}%")
    return {'macs': macs, 'size_mb': size_mb, 'accuracy': accuracy_val}

### Prune the model


In [12]:
def prune_model(model, example_input, target_macs_unused, strategy, iterative_steps=5):
    """
    Iteratively prune ``model`` in place with a torch_pruning pruner and return it.

    The channel sparsity passed to the pruner is fixed (0.1 for Taylor importance,
    0.5 otherwise) and spread over ``iterative_steps`` pruner steps. Pruning stops
    early if a step finds nothing to prune.

    Args:
        model: Network to prune (modified in place). Its ``linear`` (or ``fc``) layer is
            excluded from pruning.
        example_input: Input used by the pruner and for MAC counting; for Taylor
            importance it is also used to compute gradients.
        target_macs_unused: Accepted for interface compatibility; ignored.
        strategy: dict with ``'pruner'`` (pruner class) and ``'importance'`` (instance).
        iterative_steps: Number of pruning steps.

    Returns:
        The same ``model`` object, now pruned.
    """
    importance = strategy['importance']
    uses_taylor = isinstance(importance, tp.importance.TaylorImportance)
    pruning_ratio = 0.1 if uses_taylor else 0.5
    # Keep the classifier's output width (number of classes) intact.
    ignored_layers_list = [model.linear] if hasattr(model, 'linear') else ([model.fc] if hasattr(model, 'fc') else [])
    pruner = strategy['pruner'](
        model, example_input, importance=importance, iterative_steps=iterative_steps,
        ch_sparsity=pruning_ratio, root_module_types=[nn.Conv2d], ignored_layers=ignored_layers_list,
    )
    prev_macs, prev_nparams = calculate_macs(model, example_input)
    for i in range(iterative_steps):
        if uses_taylor:
            # Taylor importance reads .grad. Clear it first so gradients from earlier
            # iterations (computed before the last prune) do not accumulate into this
            # step's scores.
            # NOTE(review): this forward runs in train mode, so any BatchNorm running
            # statistics get updated with example_input — confirm that is acceptable.
            was_training = model.training
            model.train()
            model.zero_grad()
            loss = model(example_input).sum()
            loss.backward()
            if not was_training:
                model.eval()

        pruning_occurred = False
        for group in pruner.step(interactive=True):
            group.prune()
            pruning_occurred = True
        if not pruning_occurred:
            print(f"  Iter {i+1}/{iterative_steps}, No prunable groups. Stopping.")
            break

        macs, nparams = calculate_macs(model, example_input)
        print(f"  Iter {i+1}/{iterative_steps}, Params: {prev_nparams/1e6:.2f}M->{nparams/1e6:.2f}M, MACs: {prev_macs/1e9:.2f}G->{macs/1e9:.2f}G")
        prev_macs, prev_nparams = macs, nparams

    if uses_taylor:
        # Don't hand pruning-time gradients to whatever trains this model next.
        model.zero_grad()
    return model

### Train the model

In [13]:
import copy # Ensure copy is imported

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                device, num_epochs, model_name="Model", patience=10):
    """
    Train ``model`` with per-epoch validation, early stopping and best-weights restore.

    Args:
        model: Module to train (updated in place).
        train_loader: Yields ``(inputs, labels)`` training batches.
        val_loader: Yields ``(inputs, labels)`` validation batches.
        criterion: Loss function, e.g. ``nn.CrossEntropyLoss``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per epoch. A
            ``ReduceLROnPlateau`` in ``'max'`` mode receives the validation accuracy;
            in any other mode it receives the validation loss. Other schedulers are
            stepped without a metric.
        device: Device that batches are moved to.
        num_epochs: Maximum number of epochs.
        model_name: Label used in log messages.
        patience: Stop after this many consecutive epochs without a new best
            validation accuracy.

    Returns:
        ``model``, loaded with the weights from its best-validation-accuracy epoch when
        any epoch exceeded 0%; otherwise it keeps its final weights.
    """
    best_val_acc = 0.0
    best_model_state = None
    epochs_no_improve = 0  # consecutive epochs without a new best validation accuracy

    print(f"Starting training for {model_name} with patience={patience}")

    for epoch in range(num_epochs):
        # --- training phase ---
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            predicted = outputs.argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()
            running_loss += loss.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        epoch_train_loss = running_loss / max(len(train_loader), 1)
        epoch_train_acc = 100 * correct_train / max(total_train, 1)

        # --- validation phase ---
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                predicted = outputs.argmax(dim=1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        epoch_val_loss = val_loss / max(len(val_loader), 1)
        epoch_val_acc = 100 * correct_val / max(total_val, 1)

        print(f"{model_name} - Epoch {epoch+1}/{num_epochs}: Train Loss={epoch_train_loss:.4f}, Train Acc={epoch_train_acc:.2f}% | "
              f"Val Loss={epoch_val_loss:.4f}, Val Acc={epoch_val_acc:.2f}% | LR: {optimizer.param_groups[0]['lr']:.1e}")

        if scheduler is not None:
            if isinstance(scheduler, ReduceLROnPlateau):
                # Feed the metric the scheduler was configured for. mode='max' means
                # "higher is better" (accuracy); feeding it the loss would cut the LR
                # whenever the loss kept improving.
                plateau_metric = epoch_val_acc if scheduler.mode == 'max' else epoch_val_loss
                scheduler.step(plateau_metric)
            else:
                scheduler.step()

        # --- early stopping on validation accuracy ---
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_model_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
            print(f"****** New best validation accuracy for {model_name}: {best_val_acc:.2f}% (Epoch {epoch+1}) ******")
        else:
            epochs_no_improve += 1
            print(f"Validation accuracy did not improve for {epochs_no_improve} epoch(s). Best: {best_val_acc:.2f}%")

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered for {model_name} after {patience} epochs without improvement.")
            break

    if best_model_state is not None:
        print(f"Finished training {model_name}. Best validation accuracy achieved: {best_val_acc:.2f}%")
        model.load_state_dict(best_model_state)
    else:
        print(f"Warning: {model_name} - No best model state recorded. This might happen if training stops very early or val_acc never improves.")

    return model

### Main workflow

In [14]:
def main():
    """
    Run the full ResNet pruning experiment on CIFAR-10.

    1. Build the data loaders.
    2. Train a baseline, or load it from ``<output_dir>/<model>_initial.pth``, and
       evaluate it on the test set.
    3. For every strategy in ``config['strategies']``:
       a. Prune a fresh copy of the baseline with the selected pruning function.
       b. Fine-tune it and evaluate it on the test set.
       c. Save the pruned and fine-tuned checkpoints (state dict + ONNX).
    4. Print and plot a comparison of all results.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CURRENT_MODEL_NAME = "resnet18"
    MODEL_CLASS = ResNet18
    NUM_CLASSES = 10

    config = {
        'strategies': {
            'magnitude': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.MagnitudeImportance(p=2)},
            'bn_scale': {'pruner': tp.pruner.BNScalePruner, 'importance': tp.importance.BNScaleImportance()},
            'group_norm': {'pruner': tp.pruner.GroupNormPruner, 'importance': tp.importance.GroupMagnitudeImportance(p=1)},
            'random': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.RandomImportance()},
            'Taylor': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.TaylorImportance()},
            # 'Hessian': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.GroupHessianImportance()}, # Can be slow
            'lamp': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.LAMPImportance(p=2)},
            'geometry': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.FPGMImportance()}
        },
        'target_macs_sparsity': 0.5,
        'train_epochs': 30,  # Increased initial training epochs
        'fine_tune_epochs': 50, # Significantly increased fine-tune epochs
        'data_dir': './data',
        'output_dir': f'./output_resnet_finetuned/{CURRENT_MODEL_NAME}/strategies',
        'iterative_steps': 15,
        'learning_rate_initial': 0.001,
        'learning_rate_finetune': 0.0005, # Initial LR for fine-tuning, scheduler will adjust
        'validation_split_ratio': 0.1, # Use 10% of training data for validation
        'early_stopping_patience': 10,
    }
    os.makedirs(config['output_dir'], exist_ok=True)

    model = MODEL_CLASS(num_classes=NUM_CLASSES).to(device)
    train_loader, val_loader, test_loader = get_data_loaders(
        config['data_dir'], validation_split=config['validation_split_ratio']
    )
    if train_loader is None: # In case get_data_loaders returns None on error
        print("Failed to load data. Exiting.")
        return
    example_input = torch.randn(1, 3, 32, 32).to(device)

    initial_model_path = os.path.join(config['output_dir'], f"{CURRENT_MODEL_NAME}_initial.pth")

    if not os.path.exists(initial_model_path):
        print(f"--- Initial Training for {CURRENT_MODEL_NAME} ---")
        optimizer_initial = optim.Adam(model.parameters(), lr=config['learning_rate_initial'])
        # mode='min': plateau detection on the validation loss (lower is better), which is
        # the metric train_model passes to a mode='min' ReduceLROnPlateau. `verbose` is
        # omitted (deprecated in newer PyTorch); train_model already logs the LR every epoch.
        scheduler_initial = ReduceLROnPlateau(optimizer_initial, mode='min', factor=0.1, patience=5)
        model = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=nn.CrossEntropyLoss().to(device),
            optimizer=optimizer_initial,
            scheduler=scheduler_initial,
            device=device,
            num_epochs=config['train_epochs'],
            patience=config['early_stopping_patience'],
            model_name=f"{CURRENT_MODEL_NAME} Initial"
        )
        save_model_state(model, initial_model_path, example_input)
    else:
        print(f"Loading existing initial model from {initial_model_path}")
        model = load_model_state(MODEL_CLASS, NUM_CLASSES, initial_model_path, device)

    results = {}
    print(f"\n--- Evaluating Initial {CURRENT_MODEL_NAME} on Test Set ---")
    # Evaluate initial model on the actual test set
    initial_metrics = evaluate_model(model, test_loader, example_input, device, eval_mode_str=f"{CURRENT_MODEL_NAME} Initial Test")
    results['initial'] = initial_metrics
    print(f"Initial Test Metrics: {initial_metrics}")

    # evaluate_model already counted MACs for this model on the same example input.
    initial_macs = initial_metrics['macs']
    target_macs_for_prune = initial_macs * config['target_macs_sparsity']

    # Pruning targets for GR and GEM (can be made more dynamic)
    target_macs_gr_gem = initial_macs * 0.3
    target_size_mb_gr = results['initial']['size_mb'] * 0.3
    target_params_gem = sum(p.numel() for p in model.parameters()) * 0.3
    CHOSEN_PRUNING_FUNCTION = "gr" # "simple", "gr", or "gem"

    for strategy_name, strategy_details in config['strategies'].items():
        print(f"\n--- Pruning {CURRENT_MODEL_NAME} with Strategy: {strategy_name} ---")
        # Every strategy starts from a fresh copy of the trained baseline.
        model_copy = load_model_state(MODEL_CLASS, NUM_CLASSES, initial_model_path, device)

        if CHOSEN_PRUNING_FUNCTION == "simple":
            pruned_model = prune_model(
                model=model_copy,
                example_input=example_input,
                target_macs_unused=target_macs_for_prune,
                strategy=strategy_details,
                iterative_steps=config['iterative_steps'],
            )
        elif CHOSEN_PRUNING_FUNCTION == "gr":
            pruned_model = gr_prune_model_with_threshold(
                model=model_copy,
                example_input=example_input,
                target_macs=target_macs_gr_gem,
                target_size_mb=target_size_mb_gr,
                strategy=strategy_details,
                iterative_steps=config['iterative_steps']
            )
        elif CHOSEN_PRUNING_FUNCTION == "gem":
            pruned_model = gem_prune_model_by_threshold(
                model=model_copy,
                example_input=example_input,
                target_macs=target_macs_gr_gem,
                target_params=target_params_gem,
                strategy=strategy_details,
                max_iterations=20,
                step_ch_sparsity=0.2
            )
        else:
            raise ValueError(f"Unknown CHOSEN_PRUNING_FUNCTION: {CHOSEN_PRUNING_FUNCTION}")

        pruned_path = os.path.join(config['output_dir'], f"{CURRENT_MODEL_NAME}_{strategy_name}_pruned.pth")
        save_model_state(pruned_model, pruned_path, example_input)

        print(f"\n--- Fine-tuning {CURRENT_MODEL_NAME} after {strategy_name} pruning ---")
        # Fresh optimizer and scheduler per pruned model; mode='min' as for initial training.
        optimizer_ft = optim.Adam(pruned_model.parameters(), lr=config['learning_rate_finetune'])
        scheduler_ft = ReduceLROnPlateau(optimizer_ft, mode='min', factor=0.2, patience=5, min_lr=1e-6)

        fine_tuned_model = train_model(
            model=pruned_model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=nn.CrossEntropyLoss().to(device),
            optimizer=optimizer_ft,
            scheduler=scheduler_ft,
            device=device,
            num_epochs=config['fine_tune_epochs'],
            patience=config['early_stopping_patience'],
            model_name=f"{CURRENT_MODEL_NAME} Pruned-{strategy_name}"
        )

        print(f"\n--- Evaluating fine-tuned {CURRENT_MODEL_NAME} ({strategy_name}) on Test Set ---")
        # The fine-tuned model already holds its best-validation weights.
        results[strategy_name] = evaluate_model(
            model=fine_tuned_model,
            data_loader=test_loader,
            example_input=example_input,
            device_to_eval=device,
            eval_mode_str=f"{CURRENT_MODEL_NAME} FineTuned-{strategy_name} Test"
        )
        print(f"Test Metrics for {strategy_name}: {results[strategy_name]}")

        final_path = os.path.join(config['output_dir'], f"{CURRENT_MODEL_NAME}_{strategy_name}_final_best_val.pth")
        save_model_state(fine_tuned_model, final_path, example_input)

    compare_results_and_plot(results, output_dir=config['output_dir'])
    print("ResNet pruning and enhanced fine-tuning workflow completed successfully!")

### GR Pruning Model

In [15]:
def gr_prune_model_with_threshold(model, example_input, target_macs, target_size_mb, strategy, iterative_steps=5, pruning_ratio=0.5):
    """Prune ``model`` in place on a fixed schedule, stopping early once both budgets are met.

    One pruner is built from ``strategy['pruner']`` with ``iterative_steps`` steps and an
    overall channel-sparsity target of ``pruning_ratio``; every loop iteration advances that
    schedule by exactly one ``pruner.step``. Conv2d layers are the pruning roots, and
    ``model.linear`` / ``model.fc`` (when present) are passed as ignored layers.

    For ``tp.importance.TaylorImportance`` the gradients of ``model(example_input).sum()``
    are recomputed (in train mode) before every step, because the Taylor score is taken from
    ``param.grad`` of the current, already partially pruned, network.

    Args:
        model: network to prune; modified in place.
        example_input: tensor used for dependency tracing, MAC counting and the Taylor
            gradient pass.
        target_macs: MAC budget, compared with ``tp.utils.count_ops_and_params``.
        target_size_mb: parameter-memory budget in MiB (sum of ``numel * element_size``).
        strategy: dict with a ``'pruner'`` class and an ``'importance'`` instance.
        iterative_steps: number of schedule steps; values < 1 skip pruning entirely.
        pruning_ratio: channel sparsity reached after the last step (default 0.5).

    Returns:
        The same ``model`` object after pruning.
    """
    def _size_mb(net):
        # Memory occupied by the parameters, in MiB.
        return sum(p.numel() * p.element_size() for p in net.parameters()) / (1024 ** 2)

    current_macs, _ = tp.utils.count_ops_and_params(model, example_input)
    print(f"Initial MACs: {current_macs / 1e9:.3f} G, Size: {_size_mb(model):.2f} MB")

    # Exclude the classifier head (`linear` or `fc`) from pruning.
    ignored_layers_list = [model.linear] if hasattr(model, 'linear') else []
    if hasattr(model, 'fc') and model.fc not in ignored_layers_list:
        ignored_layers_list.append(model.fc)

    is_taylor = isinstance(strategy['importance'], tp.importance.TaylorImportance)

    if iterative_steps > 0:
        # A torch_pruning pruner only prunes during its first `iterative_steps` step() calls,
        # so it must be built with the full schedule for every loop iteration to prune.
        pruner = strategy['pruner'](
            model, example_input, importance=strategy['importance'],
            iterative_steps=iterative_steps, ch_sparsity=pruning_ratio,
            root_module_types=[nn.Conv2d], ignored_layers=ignored_layers_list,
        )
        was_training = model.training
        try:
            for i in range(iterative_steps):
                step = i + 1
                if is_taylor:
                    # Taylor importance reads param.grad: refresh it on the current architecture.
                    model.train()
                    model.zero_grad()
                    model(example_input).sum().backward()
                    model.train(was_training)

                # Prune each group as soon as it is yielded (groups are computed lazily).
                pruned_something = False
                for group in pruner.step(interactive=True):
                    group.prune()
                    pruned_something = True

                # Don't stop on the first step even if it pruned nothing.
                if not pruned_something and i > 0:
                    if is_taylor:
                        print(f"Step {step}/{iterative_steps}: No more prunable elements found.")
                    else:
                        print(f"Step {step}/{iterative_steps}: No more prunable elements found for non-Taylor strategy.")
                    break

                step_macs, _ = tp.utils.count_ops_and_params(model, example_input)
                step_size_mb = _size_mb(model)
                tag = "" if is_taylor else " (Non-Taylor)"
                print(f"Step {step}/{iterative_steps}{tag}: MACs {step_macs / 1e9:.3f} G, Size {step_size_mb:.2f} MB")
                if step_macs <= target_macs and step_size_mb <= target_size_mb:
                    if is_taylor:
                        print(f"Targets reached at step {step}")
                    else:
                        print(f"Targets reached at step {step} for non-Taylor strategy.")
                    break
        finally:
            if is_taylor:
                # Drop gradients left over from the importance pass, whichever way the loop ended.
                model.zero_grad()
            model.train(was_training)

    final_macs, _ = tp.utils.count_ops_and_params(model, example_input)
    final_size_mb = _size_mb(model)
    print(f"After pruning: MACs {final_macs / 1e9:.3f} G, Size {final_size_mb:.2f} MB")

    if final_macs <= target_macs and final_size_mb <= target_size_mb:
        print("Pruning targets achieved.")
    else:
        print("Warning: Pruning targets not fully achieved.")
    return model

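### GEM Pruning Model

Repeatedly applies pruning steps of `step_ch_sparsity` until both the MAC and parameter-count targets are met or `max_iterations` is reached.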

In [16]:
def gem_prune_model_by_threshold(model, example_input, target_macs, target_params, strategy, max_iterations=100, step_ch_sparsity=0.1):
    """Repeatedly prune ``model`` in small steps until both MAC and parameter budgets are met.

    Each iteration builds a fresh pruner from ``strategy['pruner']`` with
    ``ch_sparsity=step_ch_sparsity`` on the current, already pruned, network and runs one
    ``pruner.step``. A new pruner is needed per iteration because a torch_pruning pruner only
    prunes during its configured ``iterative_steps`` (1 by default). Reusing one pruner makes
    every later ``step()`` yield nothing.

    The loop stops when any of the following happens:
        * both targets are met;
        * ``max_iterations`` is reached;
        * the pruner yields no groups;
        * an iteration gives no MAC/parameter reduction;
        * the backward pass or the pruning step raises (the error is printed).

    Gradient-based importances (``TaylorImportance``, ``GroupHessianImportance``) get the
    gradients of ``model(example_input).mean()`` computed in train mode before each step.

    Args:
        model: network to prune; moved to ``example_input.device`` and modified in place.
        example_input: tensor used for tracing, MAC counting and the gradient pass.
        target_macs: MAC budget, compared with ``calculate_macs``.
        target_params: parameter-count budget, compared with ``calculate_macs``.
        strategy: dict with a ``'pruner'`` class and an ``'importance'`` instance.
        max_iterations: upper bound on pruning iterations.
        step_ch_sparsity: fraction of the remaining channels removed per iteration.

    Returns:
        The same ``model`` object, pruned in place and left in eval mode.
    """
    device = example_input.device
    model.to(device)

    importance = strategy['importance']
    importance_name = importance.__class__.__name__
    needs_grad = isinstance(importance, (tp.importance.TaylorImportance, tp.importance.GroupHessianImportance))

    print(f"--- Starting Pruning (Strategy: {importance_name}) ---")
    print(f"Target MACs: {target_macs:,.0f}, Target Params: {target_params:,.0f}")

    # Exclude the classifier head (`linear` or `fc`) from pruning.
    ignored_layers_list = [model.linear] if hasattr(model, 'linear') else []
    if hasattr(model, 'fc') and model.fc not in ignored_layers_list:
        ignored_layers_list.append(model.fc)

    current_macs, current_params = calculate_macs(model, example_input)
    initial_macs, initial_params = current_macs, current_params
    print(f"Initial | MACs: {current_macs:,.0f}, Params: {current_params:,.0f}")

    was_training = model.training
    iteration = 0
    try:
        while (current_macs > target_macs or current_params > target_params) and iteration < max_iterations:
            iteration += 1
            macs_before_step, params_before_step = current_macs, current_params

            # Fresh pruner on the current model: every iteration removes `step_ch_sparsity`
            # of the channels still present. Built before the gradient pass so that tracing
            # the model cannot interfere with the freshly computed gradients.
            pruner = strategy['pruner'](
                model, example_input, importance=importance,
                ch_sparsity=step_ch_sparsity, root_module_types=[nn.Conv2d],
                ignored_layers=ignored_layers_list,
            )

            if needs_grad:
                # NOTE(review): torch_pruning's GroupHessianImportance is designed to be fed via
                # importance.accumulate_grad(model) after per-sample backward passes; confirm
                # whether the single aggregated backward below is intended for it.
                model.train()
                for param in model.parameters():
                    param.requires_grad_(True)
                loss = model(example_input).mean()
                try:
                    model.zero_grad()
                    loss.backward()
                except Exception as e:
                    print(f"Error during backward pass for importance calc (Iter {iteration}): {e}")
                    break
            else:
                model.eval()

            # Groups must be pruned one at a time as the generator yields them; materialising
            # them first with list() is unsupported by torch_pruning's interactive mode.
            pruned_groups = 0
            try:
                for group in pruner.step(interactive=True):
                    group.prune()
                    pruned_groups += 1
            except Exception as e:
                print(f"Error during pruner.step() (Iter {iteration}): {e}")
                break

            if pruned_groups == 0:
                print(f"Iteration {iteration}: Pruner found no more candidates. Stopping.")
                break

            # Measure in the train/eval mode the model had on entry.
            model.train(was_training)
            current_macs, current_params = calculate_macs(model, example_input)
            print(
                f"Iter {iteration: >3}/{max_iterations} | "
                f"MACs: {macs_before_step:,.0f} -> {current_macs:,.0f} "
                f"({(macs_before_step-current_macs)/macs_before_step*100 if macs_before_step > 0 else 0:+.1f}%) | "
                f"Params: {params_before_step:,.0f} -> {current_params:,.0f} "
                f"({(params_before_step-current_params)/params_before_step*100 if params_before_step > 0 else 0:+.1f}%)"
            )

            # A step that shrank nothing will not shrink anything when repeated.
            if current_macs >= macs_before_step and current_params >= params_before_step:
                print(f"Iteration {iteration}: No reduction. Stopping.")
                break
    finally:
        # Clean up on every exit path: stale gradients and the temporary train/eval switch.
        if needs_grad:
            model.zero_grad()
        model.train(was_training)

    print(f"--- Finished Pruning (Strategy: {importance_name}) ---")
    if iteration >= max_iterations:
        print(f"Warning: Reached maximum pruning iterations ({max_iterations}).")

    final_macs, final_params = calculate_macs(model, example_input)
    print(f"Initial | MACs: {initial_macs:,.0f}, Params: {initial_params:,.0f}")
    print(f"Final   | MACs: {final_macs:,.0f}, Params: {final_params:,.0f}")
    print(f"Target  | MACs: {target_macs:,.0f}, Params: {target_params:,.0f}")

    macs_reduction = (initial_macs - final_macs) / initial_macs * 100 if initial_macs > 0 else 0
    params_reduction = (initial_params - final_params) / initial_params * 100 if initial_params > 0 else 0
    print(f"Reduction | MACs: {macs_reduction:.2f}%, Params: {params_reduction:.2f}%")

    if final_macs > target_macs or final_params > target_params:
         print("Warning: Pruning finished, but target threshold(s) were not fully met.")

    model.eval()  # The returned model is always in eval mode.
    return model


In [17]:
if __name__ == "__main__":
    # Run the whole workflow defined in main(). In a Jupyter kernel __name__ is "__main__",
    # so executing this cell starts the run immediately (see the captured output below).
    main()

Using dataset directory: /home/muis/thesis/github-repo/master-thesis/cnn/resNet/data
Splitting training data: 45000 for training (with augmentation), 5000 for validation.
--- Initial Training for resnet18 ---
Starting training for resnet18 Initial with patience=10




resnet18 Initial - Epoch 1/30: Train Loss=1.5045, Train Acc=44.74% | Val Loss=1.2184, Val Acc=55.32% | LR: 1.0e-03
****** New best validation accuracy for resnet18 Initial: 55.32% (Epoch 1) ******
resnet18 Initial - Epoch 2/30: Train Loss=1.0273, Train Acc=63.27% | Val Loss=0.9467, Val Acc=66.22% | LR: 1.0e-03
****** New best validation accuracy for resnet18 Initial: 66.22% (Epoch 2) ******
resnet18 Initial - Epoch 3/30: Train Loss=0.8240, Train Acc=70.76% | Val Loss=0.7191, Val Acc=76.20% | LR: 1.0e-03
****** New best validation accuracy for resnet18 Initial: 76.20% (Epoch 3) ******
resnet18 Initial - Epoch 4/30: Train Loss=0.6754, Train Acc=76.54% | Val Loss=0.7447, Val Acc=75.12% | LR: 1.0e-03
Validation accuracy did not improve for 1 epoch(s). Best: 76.20%
resnet18 Initial - Epoch 5/30: Train Loss=0.5830, Train Acc=79.77% | Val Loss=0.5466, Val Acc=81.22% | LR: 1.0e-03
****** New best validation accuracy for resnet18 Initial: 81.22% (Epoch 5) ******
resnet18 Initial - Epoch 6/30: T



Step 2/15 (Non-Taylor): MACs 0.479 G, Size 36.94 MB
Step 3/15 (Non-Taylor): MACs 0.447 G, Size 34.40 MB
Step 4/15 (Non-Taylor): MACs 0.414 G, Size 31.85 MB
Step 5/15 (Non-Taylor): MACs 0.384 G, Size 29.50 MB
Step 6/15 (Non-Taylor): MACs 0.354 G, Size 27.17 MB
Step 7/15 (Non-Taylor): MACs 0.327 G, Size 25.00 MB
Step 8/15 (Non-Taylor): MACs 0.294 G, Size 22.81 MB
Step 9/15 (Non-Taylor): MACs 0.269 G, Size 20.83 MB
Step 10/15 (Non-Taylor): MACs 0.245 G, Size 18.88 MB
Step 11/15 (Non-Taylor): MACs 0.222 G, Size 17.08 MB
Step 12/15 (Non-Taylor): MACs 0.198 G, Size 15.29 MB
Step 13/15 (Non-Taylor): MACs 0.178 G, Size 13.68 MB
Step 14/15 (Non-Taylor): MACs 0.158 G, Size 12.11 MB
Targets reached at step 14 for non-Taylor strategy.
After pruning: MACs 0.158 G, Size 12.11 MB
Pruning targets achieved.
✅ Model saved as ONNX to ./output_resnet_finetuned/resnet18/strategies/resnet18_magnitude_pruned.onnx

--- Fine-tuning resnet18 after magnitude pruning ---
Starting training for resnet18 Pruned-magn



✅ Model saved as ONNX to ./output_resnet_finetuned/resnet18/strategies/resnet18_Taylor_pruned.onnx

--- Fine-tuning resnet18 after Taylor pruning ---
Starting training for resnet18 Pruned-Taylor with patience=10
resnet18 Pruned-Taylor - Epoch 1/50: Train Loss=0.2950, Train Acc=89.78% | Val Loss=0.3682, Val Acc=87.52% | LR: 5.0e-04
****** New best validation accuracy for resnet18 Pruned-Taylor: 87.52% (Epoch 1) ******
resnet18 Pruned-Taylor - Epoch 2/50: Train Loss=0.2817, Train Acc=90.17% | Val Loss=0.3832, Val Acc=87.20% | LR: 5.0e-04
Validation accuracy did not improve for 1 epoch(s). Best: 87.52%
resnet18 Pruned-Taylor - Epoch 3/50: Train Loss=0.2598, Train Acc=90.90% | Val Loss=0.3390, Val Acc=88.68% | LR: 5.0e-04
****** New best validation accuracy for resnet18 Pruned-Taylor: 88.68% (Epoch 3) ******
resnet18 Pruned-Taylor - Epoch 4/50: Train Loss=0.2442, Train Acc=91.58% | Val Loss=0.3223, Val Acc=89.24% | LR: 5.0e-04
****** New best validation accuracy for resnet18 Pruned-Taylor: