In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F # Added for ResNet (F.relu)
import torch_pruning as tp
import matplotlib.pyplot as plt
from torch import optim
from torch_pruning.pruner.algorithms.scheduler import linear_scheduler
# from torchsummary import summary # Not used directly, can be kept or removed

# For data loading (if not using a custom one)
import torchvision
import torchvision.transforms as transforms
import os # Already implicitly used by some functions, made explicit
from cnn.resNet.resnet_example import get_data_loaders

### ResNet Blocks and Model Definition

In [2]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # Adapted for CIFAR-10: kernel_size 3, stride 1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # The final FC layer name is 'linear'
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride_val in strides:
            layers.append(block(self.in_planes, planes, stride_val))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # For CIFAR-10, output of layer4 will be 4x4 if input is 32x32
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def ResNet34(num_classes=10):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)

def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
def save_model_as_onnx(model, example_input, output_path):
    # Set model to evaluation mode
    model.eval()
    # Export to ONNX
    torch.onnx.export(
        model,
        example_input,
        output_path,
        export_params=True,
        opset_version=13,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    print(f"✅ Model saved as ONNX to {output_path}")

In [5]:
def calculate_macs(model, example_input):
    macs, params = tp.utils.count_ops_and_params(model, example_input)
    return macs, params

### compare results of different pruning strategies

In [6]:
def compare_results(results):
    print("\n=== Pruning Strategy Comparison ===")
    print(f"{'Strategy':<12} | {'MACs':<12} | {'Size (MB)':<10} | {'Accuracy (%)':<12}")
    print("-" * 55)
    for strategy, metrics in results.items():
        print(f"{strategy:<12} | {metrics['macs']:.2e} | {metrics['size_mb']:>9.2f} | {metrics['accuracy']:>12.2f}")

### compare and plot results

In [7]:
def compare_results_and_plot(results, output_dir='output'):
    """
    Print a comparison table and generate bar charts for MACs, model size, and accuracy
    for each pruning strategy, including the initial model.
    """
    # Print comparison table
    print("\n=== Pruning Strategy Comparison ===")
    print(f"{'Strategy':<12} | {'MACs':<12} | {'Size (MB)':<10} | {'Accuracy (%)':<12}")
    print("-" * 55)
    for strategy, metrics in results.items():
        print(f"{strategy:<12} | {metrics['macs']:.2e} | {metrics['size_mb']:>9.2f} | {metrics['accuracy']:>12.2f}")

    # Generate bar charts
    os.makedirs(output_dir, exist_ok=True)
    strategies = ['initial'] + [s for s in results if s != 'initial']
    metrics_to_plot = ['macs', 'size_mb', 'accuracy'] # Renamed to avoid conflict
    titles = {
        'macs': 'MACs Comparison',
        'size_mb': 'Model Size (MB) Comparison',
        'accuracy': 'Accuracy (%) Comparison'
    }
    y_labels = {
        'macs': 'MACs (Millions)',
        'size_mb': 'Size (MB)',
        'accuracy': 'Accuracy (%)'
    }

    colors = plt.cm.tab10(range(len(strategies)))
    for metric_name in metrics_to_plot:
        values = [results[strategy][metric_name] / 1e6 if metric_name == 'macs' else results[strategy][metric_name]
                  for strategy in strategies]  # Convert MACs to millions
        plt.figure(figsize=(12, 6))
        bars = plt.bar(strategies, values, color=colors)
        plt.xlabel('Strategy')
        plt.ylabel(y_labels[metric_name])
        plt.title(titles[metric_name])
        plt.xticks(rotation=45, ha='right')

        # Add value labels
        for bar in bars:
            yval = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., yval, f'{yval:.2f}', ha='center', va='bottom')

        # Add initial model reference line
        if 'initial' in results: # Ensure initial results exist
            initial_value = results['initial'][metric_name] / 1e6 if metric_name == 'macs' else results['initial'][metric_name]
            plt.axhline(y=initial_value, color='r', linestyle='--', label='Initial')
            plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'{metric_name}_comparison.png'))
        plt.close()

### Load Model

In [8]:
def load_model(model_class, num_classes, path, device): # model_class and num_classes added for flexibility
    model = model_class(num_classes=num_classes).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    return model

### Utility function to save the model

In [9]:
def save_model(model, path, example_input=None):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
    if example_input is not None:
        onnx_path = path.replace('.pth', '.onnx')
        save_model_as_onnx(model, example_input, onnx_path)

### Evaluate the model

In [10]:
def evaluate_model(model, test_loader, example_input, device):
    model.eval()
    # Calculate metrics
    macs, _ = calculate_macs(model, example_input.to(device)) # ensure example_input on correct device
    params = sum(p.numel() for p in model.parameters() if p.requires_grad) # Count only trainable params
    size_mb = params * 4 / 1e6 # Approximation using 4 bytes per param (float32)

    # Calculate accuracy
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = [d.to(device) for d in data]
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy_val = 0
    if total > 0:
        accuracy_val = 100 * correct / total


    return {
        'macs': macs,
        'size_mb': size_mb,
        'accuracy': accuracy_val
    }

### Prune the model


In [11]:
def prune_model(model, example_input, target_macs, strategy, iterative_steps=5):
    if isinstance(strategy['importance'], tp.importance.TaylorImportance):
        pruning_ratio = 0.1
    else:
        pruning_ratio = 0.5

    # IMPORTANT: Update ignored_layers for ResNet's classifier
    ignored_layers_list = [model.linear] if hasattr(model, 'linear') else []
    if hasattr(model, 'fc') and model.fc not in ignored_layers_list: # for general models
        ignored_layers_list.append(model.fc)


    pruner = strategy['pruner'](
        model,
        example_input,
        importance=strategy['importance'],
        iterative_steps=iterative_steps,
        ch_sparsity=pruning_ratio,
        root_module_types=[nn.Conv2d],
        ignored_layers=ignored_layers_list, # Use updated list
    )

    current_macs, base_nparams = calculate_macs(model, example_input)
    initial_macs = current_macs # Store initial MACs for logging

    for i in range(iterative_steps):
            if isinstance(strategy['importance'], tp.importance.TaylorImportance):
                # Ensure model is in training mode for grad calculation if needed
                is_training = model.training
                model.train()
                loss = model(example_input).sum() # a dummy loss for TaylorImportance
                loss.backward()
                if not is_training: model.eval() # Revert to original mode

            pruning_occurred_this_step = False
            for g in pruner.step(interactive=True):
                g.prune()
                pruning_occurred_this_step = True

            if not pruning_occurred_this_step:
                print(f"  Iter %d/%d, No prunable groups found. Stopping." % (i + 1, iterative_steps))
                break


            macs, nparams = tp.utils.count_ops_and_params(model, example_input)
            print(
                "  Iter %d/%d, Params: %.2f M => %.2f M"
                % (i + 1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
            )
            print(
                "  Iter %d/%d, MACs: %.2f G => %.2f G (Target: %.2f G)"
                % (i + 1, iterative_steps, initial_macs / 1e9, macs / 1e9, target_macs / 1e9)
            )
            initial_macs = macs # Update for next iteration's "before" MACs
            base_nparams = nparams

            # Optional: Check if target_macs is met, though the original loop structure doesn't use it to break
            if macs <= target_macs:
                print(f"  Target MACs ({target_macs/1e9:.2f}G) reached or passed.")
                # break # uncomment if you want to stop early

    return model

### Train the model

In [16]:
def train_model(model, train_loader, criterion, optimizer, device, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, data in enumerate(train_loader):
            inputs, labels = [d.to(device) for d in data]
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

            if batch_idx % 100 == 0: # Print progress every 100 batches
                 print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}: "
                       f"Loss={loss.item():.4f}")


        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f"Epoch {epoch+1}/{num_epochs}: Avg Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.2f}%")
    return model

### Main workflow

In [14]:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CURRENT_MODEL_NAME = "resnet18" # Change to "resnet34", "resnet50" as needed
    MODEL_CLASS = ResNet18 # Change to ResNet34, ResNet50
    NUM_CLASSES = 10


    config = {
        'strategies': {
            'magnitude': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.MagnitudeImportance(p=2)},
            'bn_scale': {'pruner': tp.pruner.BNScalePruner, 'importance': tp.importance.BNScaleImportance()},
            'group_norm': {'pruner': tp.pruner.GroupNormPruner, 'importance': tp.importance.GroupMagnitudeImportance(p=1)},
            'random': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.RandomImportance()},
            'Taylor': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.TaylorImportance()},
            'Hessian': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.GroupHessianImportance()},
            'lamp': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.LAMPImportance(p=2)},
            'geometry': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.FPGMImportance()}
        },
        'target_macs_sparsity': 0.5, # Target 50% reduction in MACs from initial
        'train_epochs': 10,       # Epochs for initial training
        'fine_tune_epochs': 20, # Epochs for fine-tuning after pruning
        'data_dir': './data', # CIFAR-10 data directory
        'output_dir': f'./output_resnet/{CURRENT_MODEL_NAME}/strategies',
        'iterative_steps': 5,     # Iterative steps for pruning methods
        'learning_rate_initial': 0.001,
        'learning_rate_finetune': 0.0005, # Potentially smaller LR for fine-tuning
    }
    os.makedirs(config['output_dir'], exist_ok=True)


    # Initialize model and data
    model = MODEL_CLASS(num_classes=NUM_CLASSES).to(device)
    train_loader, test_loader = get_data_loaders(config['data_dir']) # Using the defined CIFAR-10 loader
    example_input = torch.randn(1, 3, 32, 32).to(device) # CIFAR-10 example input

    initial_model_path = os.path.join(config['output_dir'], f"{CURRENT_MODEL_NAME}_initial.pth")

    if not os.path.exists(initial_model_path):
        print(f"--- Initial Training for {CURRENT_MODEL_NAME} ---")
        model = train_model(
            model=model, train_loader=train_loader,
            criterion=nn.CrossEntropyLoss().to(device),
            optimizer=optim.Adam(model.parameters(), lr=config['learning_rate_initial']),
            device=device, num_epochs=config['train_epochs']
        )
        save_model(model, initial_model_path, example_input)
        print(f"Initial model saved to {initial_model_path}")
    else:
        print(f"Loading existing initial model from {initial_model_path}")
        # model = MODEL_CLASS(num_classes=NUM_CLASSES).to(device) # Create instance first
        # model.load_state_dict(torch.load(initial_model_path))
        model = load_model(MODEL_CLASS, NUM_CLASSES, initial_model_path, device)


    results = {}
    print(f"\n--- Evaluating Initial {CURRENT_MODEL_NAME} ---")
    initial_metrics = evaluate_model(model, test_loader, example_input, device)
    results['initial'] = initial_metrics
    print(f"Initial Metrics: {initial_metrics}")

    initial_macs, _ = calculate_macs(model, example_input)
    # target_macs_absolute = initial_macs * config['target_macs_sparsity']
    target_macs_for_prune_model = initial_macs * config['target_macs_sparsity'] # For `prune_model`
    target_macs_for_gr = initial_macs * 0.3 # Example: Tighter target for GR threshold method (e.g., 70% reduction)
    target_size_mb_for_gr = results['initial']['size_mb'] * 0.3 # Example: Target 30% of original size
    target_params_for_gem = sum(p.numel() for p in model.parameters()) * 0.3 # Example: Target 30% of original params


    # Select which pruning method to use from the main function options:
    # 1. prune_model (original simple iterative pruning)
    # 2. gr_prune_model_with_threshold
    # 3. gem_prune_model_by_threshold
    CHOSEN_PRUNING_FUNCTION = "gr" # "simple", "gr", or "gem"


    for strategy_name, strategy_details in config['strategies'].items():
        print(f"\n--- Pruning {CURRENT_MODEL_NAME} with Strategy: {strategy_name} ---")
        # model_copy = MODEL_CLASS(num_classes=NUM_CLASSES).to(device)
        # model_copy.load_state_dict(torch.load(initial_model_path)) # Load fresh copy
        model_copy = load_model(MODEL_CLASS, NUM_CLASSES, initial_model_path, device)


        if CHOSEN_PRUNING_FUNCTION == "simple":
            pruned_model = prune_model(
                model=model_copy, example_input=example_input,
                target_macs=target_macs_for_prune_model, # Pass the absolute target MACs
                strategy=strategy_details,
                iterative_steps=config['iterative_steps'],
            )
        elif CHOSEN_PRUNING_FUNCTION == "gr":
            pruned_model = gr_prune_model_with_threshold(
                model=model_copy, example_input=example_input,
                target_macs=target_macs_for_gr, target_size_mb=target_size_mb_for_gr,
                strategy=strategy_details, iterative_steps=config['iterative_steps']
            )
        elif CHOSEN_PRUNING_FUNCTION == "gem":
             pruned_model = gem_prune_model_by_threshold(
                model=model_copy, example_input=example_input,
                target_macs=target_macs_for_gr, # Use appropriate target
                target_params=target_params_for_gem,
                strategy=strategy_details,
                max_iterations=20, # Adjust as needed for GEM
                step_ch_sparsity=0.2 # Tune this for GEM
            )
        else:
            raise ValueError(f"Unknown CHOSEN_PRUNING_FUNCTION: {CHOSEN_PRUNING_FUNCTION}")


        pruned_path = os.path.join(config['output_dir'], f"{CURRENT_MODEL_NAME}_{strategy_name}_pruned.pth")
        save_model(pruned_model, pruned_path, example_input)
        print(f"Pruned model saved to {pruned_path}")

        print(f"\n--- Fine-tuning {CURRENT_MODEL_NAME} after {strategy_name} pruning ---")
        fine_tuned_model = train_model(
            model=pruned_model, train_loader=train_loader,
            criterion=nn.CrossEntropyLoss().to(device),
            optimizer=optim.Adam(pruned_model.parameters(), lr=config['learning_rate_finetune']),
            device=device, num_epochs=config['fine_tune_epochs']
        )

        print(f"\n--- Evaluating fine-tuned {CURRENT_MODEL_NAME} ({strategy_name}) ---")
        results[strategy_name] = evaluate_model(
            model=fine_tuned_model, test_loader=test_loader,
            example_input=example_input, device=device
        )
        print(f"Metrics for {strategy_name}: {results[strategy_name]}")


        final_path = os.path.join(config['output_dir'], f"{CURRENT_MODEL_NAME}_{strategy_name}_final.pth")
        save_model(fine_tuned_model, final_path, example_input)
        print(f"Final fine-tuned model saved to {final_path}")


    compare_results_and_plot(results, output_dir=config['output_dir'])
    print("ResNet pruning workflow completed successfully!")

### GR Pruning Model

In [12]:
def gr_prune_model_with_threshold(model, example_input, target_macs, target_size_mb, strategy, iterative_steps=5):
    current_macs, nparams = tp.utils.count_ops_and_params(model, example_input)
    current_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
    print(f"Initial MACs: {current_macs / 1e9:.3f} G, Size: {current_size_mb:.2f} MB")

    pruning_ratio = 0.5
    ignored_layers_list = [model.linear] if hasattr(model, 'linear') else []
    if hasattr(model, 'fc') and model.fc not in ignored_layers_list:
        ignored_layers_list.append(model.fc)


    if isinstance(strategy['importance'], tp.importance.TaylorImportance):
        pruner = strategy['pruner'](
            model, example_input, importance=strategy['importance'],
            iterative_steps=1, ch_sparsity=pruning_ratio / iterative_steps if iterative_steps > 0 else pruning_ratio,
            root_module_types=[nn.Conv2d], ignored_layers=ignored_layers_list
        )
        for i in range(iterative_steps):
            is_training = model.training
            model.train()
            model.zero_grad()
            loss = model(example_input).sum()
            loss.backward()
            if not is_training: model.eval()

            pruned_something = False
            for group in pruner.step(interactive=True):
                group.prune()
                pruned_something = True

            if not pruned_something and i > 0 : # Don't break on first iter if nothing found
                print(f"Step {i+1}/{iterative_steps}: No more prunable elements found.")
                break


            current_macs, _ = tp.utils.count_ops_and_params(model, example_input)
            current_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
            print(f"Step {i+1}/{iterative_steps}: MACs {current_macs / 1e9:.3f} G, Size {current_size_mb:.2f} MB")
            if current_macs <= target_macs and current_size_mb <= target_size_mb:
                print(f"Targets reached at step {i+1}")
                break
            model.zero_grad() # important for Taylor if used in loop without re-instantiating pruner
    else:
        pruner = strategy['pruner'](
            model, example_input, importance=strategy['importance'],
            iterative_steps=iterative_steps, ch_sparsity=pruning_ratio,
            root_module_types=[nn.Conv2d], ignored_layers=ignored_layers_list
        )
        # pruner.step() # This typically means run all iterative steps internally.
        # If you want fine-grained control like Taylor above, loop pruner.step()
        for i in range(iterative_steps):
            pruned_something = False
            for group in pruner.step(interactive=True): # Use interactive for explicit pruning
                group.prune()
                pruned_something = True

            if not pruned_something and i > 0:
                 print(f"Step {i+1}/{iterative_steps}: No more prunable elements found for non-Taylor strategy.")
                 break

            current_macs_step, _ = tp.utils.count_ops_and_params(model, example_input)
            current_size_mb_step = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
            print(f"Step {i+1}/{iterative_steps} (Non-Taylor): MACs {current_macs_step / 1e9:.3f} G, Size {current_size_mb_step:.2f} MB")
            if current_macs_step <= target_macs and current_size_mb_step <= target_size_mb:
                print(f"Targets reached at step {i+1} for non-Taylor strategy.")
                break


    final_macs, _ = tp.utils.count_ops_and_params(model, example_input)
    final_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
    print(f"After pruning: MACs {final_macs / 1e9:.3f} G, Size {final_size_mb:.2f} MB")

    if final_macs <= target_macs and final_size_mb <= target_size_mb:
        print("Pruning targets achieved.")
    else:
        print("Warning: Pruning targets not fully achieved.")
    return model

### gem pruning model

In [None]:
def gem_prune_model_by_threshold(model, example_input, target_macs, target_params, strategy, max_iterations=100, step_ch_sparsity=0.1):
    device = example_input.device
    model.to(device)

    print(f"--- Starting Pruning (Strategy: {strategy['importance'].__class__.__name__}) ---")
    print(f"Target MACs: {target_macs:,.0f}, Target Params: {target_params:,.0f}")

    ignored_layers_list = [model.linear] if hasattr(model, 'linear') else []
    if hasattr(model, 'fc') and model.fc not in ignored_layers_list:
        ignored_layers_list.append(model.fc)

    pruner = strategy['pruner'](
        model, example_input, importance=strategy['importance'],
        ch_sparsity=step_ch_sparsity, root_module_types=[nn.Conv2d],
        ignored_layers=ignored_layers_list,
        # round_to=8, # Optional
    )

    current_macs, current_params = calculate_macs(model, example_input)
    initial_macs, initial_params = current_macs, current_params
    print(f"Initial | MACs: {current_macs:,.0f}, Params: {current_params:,.0f}")

    iteration = 0
    # model.eval() # Moved inside loop for Taylor/Hessian

    while (current_macs > target_macs or current_params > target_params) and iteration < max_iterations:
        iteration += 1
        macs_before_step = current_macs
        params_before_step = current_params

        original_mode_is_train = model.training # Store original mode

        if isinstance(strategy['importance'], (tp.importance.TaylorImportance, tp.importance.GroupHessianImportance)):
             model.train()
             for param in model.parameters():
                 param.requires_grad_(True)
             loss = model(example_input).mean()
             try:
                 model.zero_grad() # Zero gradients before backward
                 loss.backward()
             except Exception as e:
                 print(f"Error during backward pass for importance calc (Iter {iteration}): {e}")
                 if not original_mode_is_train: model.eval() # Restore mode
                 break
        else: # Ensure eval mode for other strategies
            model.eval()


        try:
            pruning_groups = list(pruner.step(interactive=True))
        except Exception as e:
             print(f"Error during pruner.step() (Iter {iteration}): {e}")
             if not original_mode_is_train and isinstance(strategy['importance'], (tp.importance.TaylorImportance, tp.importance.GroupHessianImportance)):
                 model.eval() # Restore eval mode
             elif original_mode_is_train and not isinstance(strategy['importance'], (tp.importance.TaylorImportance, tp.importance.GroupHessianImportance)):
                 model.train() # Restore train mode if it was initially training
             break

        if not pruning_groups:
            print(f"Iteration {iteration}: Pruner found no more candidates. Stopping.")
            break

        for group in pruning_groups:
            group.prune()

        if isinstance(strategy['importance'], (tp.importance.TaylorImportance, tp.importance.GroupHessianImportance)):
            model.zero_grad() # Clean up gradients
            if not original_mode_is_train: # If model was originally in eval, set it back
                 model.eval()
        elif original_mode_is_train: # If it was originally training and not Taylor/Hessian
            model.train()


        current_macs, current_params = calculate_macs(model, example_input)
        print(
            f"Iter {iteration: >3}/{max_iterations} | "
            f"MACs: {macs_before_step:,.0f} -> {current_macs:,.0f} "
            f"({(macs_before_step-current_macs)/macs_before_step*100 if macs_before_step > 0 else 0:+.1f}%) | "
            f"Params: {params_before_step:,.0f} -> {current_params:,.0f} "
            f"({(params_before_step-current_params)/params_before_step*100 if params_before_step > 0 else 0:+.1f}%)"
        )

        if current_macs >= macs_before_step and current_params >= params_before_step and iteration > 1:
            print(f"Iteration {iteration}: No reduction. Stopping.")
            break

        if not original_mode_is_train : model.eval() # Ensure eval mode if it was initially

    print(f"--- Finished Pruning (Strategy: {strategy['importance'].__class__.__name__}) ---")
    if iteration >= max_iterations:
        print(f"Warning: Reached maximum pruning iterations ({max_iterations}).")

    final_macs, final_params = calculate_macs(model, example_input)
    print(f"Initial | MACs: {initial_macs:,.0f}, Params: {initial_params:,.0f}")
    print(f"Final   | MACs: {final_macs:,.0f}, Params: {final_params:,.0f}")
    print(f"Target  | MACs: {target_macs:,.0f}, Params: {target_params:,.0f}")

    macs_reduction = (initial_macs - final_macs) / initial_macs * 100 if initial_macs > 0 else 0
    params_reduction = (initial_params - final_params) / initial_params * 100 if initial_params > 0 else 0
    print(f"Reduction | MACs: {macs_reduction:.2f}%, Params: {params_reduction:.2f}%")

    if final_macs > target_macs or final_params > target_params:
         print("Warning: Pruning finished, but target threshold(s) were not fully met.")

    model.eval() # Ensure final model state is eval
    return model


In [17]:
if __name__ == "__main__":
    main()

Using dataset directory: /home/muis/thesis/github-repo/master-thesis/cnn/resNet/data
--- Initial Training for resnet18 ---
Epoch 1/10, Batch 0/782: Loss=2.4236
Epoch 1/10, Batch 100/782: Loss=1.4629
Epoch 1/10, Batch 200/782: Loss=1.2940
Epoch 1/10, Batch 300/782: Loss=1.3454
Epoch 1/10, Batch 400/782: Loss=1.1538
Epoch 1/10, Batch 500/782: Loss=1.0824
Epoch 1/10, Batch 600/782: Loss=1.4907
Epoch 1/10, Batch 700/782: Loss=1.0082
Epoch 1/10: Avg Loss=1.3176, Accuracy=52.05%
Epoch 2/10, Batch 0/782: Loss=1.1416
Epoch 2/10, Batch 100/782: Loss=1.0295
Epoch 2/10, Batch 200/782: Loss=0.9159
Epoch 2/10, Batch 300/782: Loss=0.8815
Epoch 2/10, Batch 400/782: Loss=0.7155
Epoch 2/10, Batch 500/782: Loss=0.6321
Epoch 2/10, Batch 600/782: Loss=0.8202
Epoch 2/10, Batch 700/782: Loss=0.6910
Epoch 2/10: Avg Loss=0.8107, Accuracy=71.51%
Epoch 3/10, Batch 0/782: Loss=0.5486
Epoch 3/10, Batch 100/782: Loss=0.5575
Epoch 3/10, Batch 200/782: Loss=0.6189
Epoch 3/10, Batch 300/782: Loss=0.4937
Epoch 3/10, B



Step 2/5 (Non-Taylor): MACs 0.354 G, Size 27.17 MB
Step 3/5 (Non-Taylor): MACs 0.269 G, Size 20.83 MB
Step 4/5 (Non-Taylor): MACs 0.198 G, Size 15.29 MB
Step 5/5 (Non-Taylor): MACs 0.140 G, Size 10.67 MB
Targets reached at step 5 for non-Taylor strategy.
After pruning: MACs 0.140 G, Size 10.67 MB
Pruning targets achieved.
✅ Model saved as ONNX to ./output_resnet/resnet18/strategies/resnet18_magnitude_pruned.onnx
Pruned model saved to ./output_resnet/resnet18/strategies/resnet18_magnitude_pruned.pth

--- Fine-tuning resnet18 after magnitude pruning ---
Epoch 1/20, Batch 0/782: Loss=2.0543
Epoch 1/20, Batch 100/782: Loss=0.4324
Epoch 1/20, Batch 200/782: Loss=0.4775
Epoch 1/20, Batch 300/782: Loss=0.1718
Epoch 1/20, Batch 400/782: Loss=0.1591
Epoch 1/20, Batch 500/782: Loss=0.1989
Epoch 1/20, Batch 600/782: Loss=0.0997
Epoch 1/20, Batch 700/782: Loss=0.1946
Epoch 1/20: Avg Loss=0.2654, Accuracy=91.14%
Epoch 2/20, Batch 0/782: Loss=0.0930
Epoch 2/20, Batch 100/782: Loss=0.0716
Epoch 2/20,



✅ Model saved as ONNX to ./output_resnet/resnet18/strategies/resnet18_Taylor_pruned.onnx
Pruned model saved to ./output_resnet/resnet18/strategies/resnet18_Taylor_pruned.pth

--- Fine-tuning resnet18 after Taylor pruning ---
Epoch 1/20, Batch 0/782: Loss=0.0542
Epoch 1/20, Batch 100/782: Loss=0.0340
Epoch 1/20, Batch 200/782: Loss=0.0478
Epoch 1/20, Batch 300/782: Loss=0.0250
Epoch 1/20, Batch 400/782: Loss=0.0634
Epoch 1/20, Batch 500/782: Loss=0.0072
Epoch 1/20, Batch 600/782: Loss=0.0255
Epoch 1/20, Batch 700/782: Loss=0.0442
Epoch 1/20: Avg Loss=0.0426, Accuracy=98.55%
Epoch 2/20, Batch 0/782: Loss=0.0318
Epoch 2/20, Batch 100/782: Loss=0.0122
Epoch 2/20, Batch 200/782: Loss=0.0130
Epoch 2/20, Batch 300/782: Loss=0.0107
Epoch 2/20, Batch 400/782: Loss=0.0149
Epoch 2/20, Batch 500/782: Loss=0.0125
Epoch 2/20, Batch 600/782: Loss=0.0124
Epoch 2/20, Batch 700/782: Loss=0.0012
Epoch 2/20: Avg Loss=0.0194, Accuracy=99.37%
Epoch 3/20, Batch 0/782: Loss=0.0004
Epoch 3/20, Batch 100/782: L