## Import the libraries needed for pruning

In [47]:
import torch, copy
import torch.nn as nn
import torch_pruning as tp
import matplotlib.pyplot as plt
from torch import optim
from torch_pruning.pruner.algorithms.scheduler import linear_scheduler
from torchsummary import summary
import numpy as np
from cnn.resNet.resnet_example import get_data_loaders
import torch
from torch import nn

### Seed Network

In [19]:
class InvertedResidual(nn.Module):
    """MobileNetV2-style inverted residual block built from plain PyTorch layers.

    Layout: 1x1 expansion conv -> BN -> ReLU6 -> 3x3 depthwise conv (carries the
    stride) -> BN -> ReLU6 -> 1x1 projection conv -> BN. The block input is added
    to the result only when stride is 1 and the channel count is unchanged.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the projection conv.
        stride: stride of the depthwise conv.
        expansion: multiplier giving the hidden width (in_channels * expansion).
    """

    def __init__(self, in_channels, out_channels, stride=1, expansion=6):
        super().__init__()
        hidden = in_channels * expansion
        self.stride = stride
        # The skip connection is only shape-compatible without striding or
        # a change of channel count.
        self.use_res_connect = stride == 1 and in_channels == out_channels

        # Attribute names become state_dict keys of saved checkpoints; keep the
        # names and their creation order unchanged.
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.relu = nn.ReLU6(inplace=True)

        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride,
                               padding=1, groups=hidden, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)

        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        expanded = self.relu(self.bn1(self.conv1(x)))
        filtered = self.relu(self.bn2(self.conv2(expanded)))
        projected = self.bn3(self.conv3(filtered))
        return x + projected if self.use_res_connect else projected

### Mask Network

In [20]:
class MobileNetV2(nn.Module):
    """Compact MobileNetV2-like classifier made of seven InvertedResidual blocks.

    Stem conv (stride 2) -> block1..block7 -> 1x1 conv to 1280 channels -> global
    average pool -> linear classifier. Five stride-2 stages reduce the spatial
    size by 32 overall; the adaptive pool makes the head size-independent.

    Args:
        num_classes: number of output logits.
    """

    # (in_channels, out_channels, stride) for block1 ... block7, in order.
    _BLOCK_SPECS = (
        (32, 16, 1),
        (16, 24, 2),
        (24, 32, 2),
        (32, 64, 2),
        (64, 96, 1),
        (96, 160, 2),
        (160, 320, 1),
    )

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU6(inplace=True)

        # Registered as block1..block7 so the state_dict keys match checkpoints
        # saved from the explicit-attribute version of this class.
        for idx, (cin, cout, stride) in enumerate(self._BLOCK_SPECS, start=1):
            setattr(self, f"block{idx}", InvertedResidual(cin, cout, stride=stride))

        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1280, num_classes)

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        for idx in range(1, len(self._BLOCK_SPECS) + 1):
            features = getattr(self, f"block{idx}")(features)
        features = self.relu(self.bn2(self.conv2(features)))
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)

In [21]:
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [22]:
def save_model_as_onnx(model, example_input, output_path):
    """Export *model* to an ONNX file at *output_path*.

    The model is switched to eval mode (and left in it) and traced with
    *example_input*. Weights are embedded, opset 13 is used, and dimension 0 of
    the 'input'/'output' tensors is exported as a dynamic 'batch_size' axis.
    """
    model.eval()
    io_names = ('input', 'output')
    dynamic_batch = {name: {0: 'batch_size'} for name in io_names}
    torch.onnx.export(
        model,
        example_input,
        output_path,
        export_params=True,
        opset_version=13,
        input_names=[io_names[0]],
        output_names=[io_names[1]],
        dynamic_axes=dynamic_batch,
    )
    print(f"✅ Model saved as ONNX to {output_path}")

In [23]:
def calculate_macs(model, example_input):
    """Return ``(macs, params)`` as counted by torch_pruning on *example_input*."""
    counted_ops, counted_params = tp.utils.count_ops_and_params(model, example_input)
    return counted_ops, counted_params

### Compare the results of the different pruning strategies

In [24]:
def compare_results(results):
    """Print a fixed-width table comparing pruning strategies.

    Args:
        results: mapping strategy name -> metrics dict with 'macs', 'accuracy'
            and a model-size entry. The size is read from 'size_mib' (the key
            evaluate_model returns) and falls back to the legacy 'size_mb'.
    """
    print("\n=== Pruning Strategy Comparison ===")
    print(f"{'Strategy':<12} | {'MACs':<12} | {'Size (MB)':<10} | {'Accuracy (%)':<12}")
    print("-" * 55)
    for strategy, metrics in results.items():
        size = metrics['size_mib'] if 'size_mib' in metrics else metrics['size_mb']
        # Pad every value to its header's column width so the rows line up.
        print(f"{strategy:<12} | {metrics['macs']:<12.2e} | {size:>10.2f} | {metrics['accuracy']:>12.2f}")

### Compare and plot the results

In [25]:
def compare_results_and_plot(results_dict, strategies_config, output_dir='output'):
    """Print a comparison table of pruning strategies and save one bar chart per metric.

    Args:
        results_dict: mapping result name -> metrics dict with the keys 'macs',
            'params', 'size_mib' and 'accuracy'. Recognised names are 'initial'
            (baseline), '<strategy>' (final model) and
            '<strategy>_pruned_not_finetuned'; missing entries are skipped.
        strategies_config: mapping whose keys are the strategy names to report
            (only the keys are used).
        output_dir: directory for the '<metric>_comparison_final.png' plots;
            created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    metrics_data = results_dict

    def format_row(name, metrics):
        # Column widths match the header so the table stays aligned.
        return (f"{name:<35} | {metrics['macs']:<12.2e} | {metrics['params']:<12.2e} | "
                f"{metrics['size_mib']:>10.2f} | {metrics['accuracy']:>12.2f}")

    print("\n=== Pruning Strategy Comparison ===")
    header = f"{'Strategy':<35} | {'MACs':<12} | {'Params':<12} | {'Size (MiB)':<10} | {'Accuracy (%)':<12}"
    print(header)
    print("-" * len(header))

    if 'initial' in metrics_data:
        print(format_row('initial', metrics_data['initial']))

    # 'initial' is always reported and plotted first, so never list it twice.
    sorted_strategy_keys = sorted(k for k in strategies_config if k != 'initial')

    for strategy_key in sorted_strategy_keys:
        for name in (f"{strategy_key}_pruned_not_finetuned", strategy_key):
            if name in metrics_data:
                print(format_row(name, metrics_data[name]))

    plot_names = ['initial'] + sorted_strategy_keys
    titles = {
        'macs': 'MACs Comparison (Final Models)',
        'params': 'Parameters Comparison (Final Models)',
        'size_mib': 'Model Size (MiB) Comparison (Final Models)',
        'accuracy': 'Accuracy (%) Comparison (Final Models)'
    }
    y_labels = {
        'macs': 'MACs',
        'params': 'Parameters',
        'size_mib': 'Size (MiB)',
        'accuracy': 'Accuracy (%)'
    }
    colors_cmap = plt.colormaps.get_cmap('tab10')
    initial_metrics = metrics_data.get('initial', {})

    for metric_key in ('macs', 'params', 'size_mib', 'accuracy'):
        names = [n for n in plot_names if metric_key in metrics_data.get(n, {})]
        if not names:
            print(f"Skipping plot for {metric_key} as no data was found.")
            continue
        values = [metrics_data[n][metric_key] for n in names]
        # Counts in scientific notation; sizes and accuracies with two decimals.
        value_fmt = '.2e' if metric_key in ('macs', 'params') else '.2f'

        plt.figure(figsize=(max(12, int(1.5 * len(names))), 7))
        bar_colors = [colors_cmap(i % colors_cmap.N) for i in range(len(names))]
        bars = plt.bar(names, values, color=bar_colors)
        plt.xlabel('Strategy')
        plt.ylabel(y_labels[metric_key])
        plt.title(titles[metric_key])
        plt.xticks(rotation=45, ha='right')

        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2., height, format(height, value_fmt),
                     ha='center', va='bottom')

        if metric_key in initial_metrics:
            initial_value = initial_metrics[metric_key]
            plt.axhline(y=initial_value, color='r', linestyle='--',
                        label=f"Initial ({format(initial_value, value_fmt)})")
            # The baseline line is the only labelled artist; without it a
            # legend() call would have nothing to show and only warn.
            plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'{metric_key}_comparison_final.png'))
        plt.close()
    print(f"✅ Comparison plots saved to {output_dir}")

### Load Model

In [26]:
def load_model(model, path, map_location="cpu"):
    """Load the state_dict stored at *path* into *model* in place and return *model*.

    Args:
        model: module whose parameters/buffers receive the checkpoint values.
        path: file written by ``torch.save(model.state_dict(), path)``.
        map_location: where the checkpoint tensors are deserialized (CPU by
            default). ``load_state_dict`` then copies the values into the model's
            existing tensors, so the model keeps its current device. This lets a
            GPU-saved checkpoint be loaded on a CPU-only machine.
    """
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

### Utility functions to save and load the model

In [27]:
import os

def save_model(model, path, example_input=None):
    """Save ``model.state_dict()`` to *path* and optionally an ONNX export beside it.

    Args:
        model: module whose state_dict is saved.
        path: destination for torch.save. Missing parent directories are
            created; a bare filename saves into the current directory.
        example_input: if given, the model is also exported with
            save_model_as_onnx to *path* with its extension replaced by '.onnx'.
    """
    directory = os.path.dirname(path)
    if directory:  # os.makedirs('') raises FileNotFoundError for bare filenames
        os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), path)
    if example_input is not None:
        # splitext only touches the final extension, so a '.pth' inside a
        # directory name is left alone. The guard below prevents the ONNX file
        # from overwriting the state dict that was just written.
        onnx_path = os.path.splitext(path)[0] + '.onnx'
        if onnx_path == path:
            onnx_path = path + '.onnx'
        save_model_as_onnx(model, example_input, onnx_path)

def load_model_state(model_class, path, device, *args, **kwargs):
    """Instantiate ``model_class(*args, **kwargs)``, load *path* into it and move it to *device*.

    Extra positional/keyword arguments go to the model constructor (e.g.
    ``num_classes``). Checkpoint tensors are mapped onto *device* while loading.
    Prints a confirmation line and returns the model.
    """
    net = model_class(*args, **kwargs)
    checkpoint = torch.load(path, map_location=device)
    net.load_state_dict(checkpoint)
    net.to(device)
    print(f"✅ Model loaded from {path} to {device}")
    return net

### Evaluate the model

In [28]:
def evaluate_model(model, test_loader, example_input_device, criterion_eval, device_eval):
    """Evaluate *model* on *test_loader* and report cost and quality metrics.

    Returns a dict with:
        'macs', 'params': counts from calculate_macs on *example_input_device*;
        'size_mib': params * 4 bytes in MiB (i.e. assumes 4-byte weights);
        'accuracy': top-1 accuracy in percent (0 for an empty loader);
        'loss': sample-weighted mean of *criterion_eval* (NaN for an empty loader).
    """
    model.eval()
    model.to(device_eval)
    macs, num_params = calculate_macs(model, example_input_device)
    bytes_per_param = 4
    size_mib = num_params * bytes_per_param / (1024 * 1024)

    n_correct = 0
    n_seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = [t.to(device_eval) for t in batch]
            logits = model(inputs)
            batch_loss = criterion_eval(logits, labels)
            # Weight by batch size so a short final batch is not over-counted.
            loss_sum += batch_loss.item() * inputs.size(0)
            _, predicted = torch.max(logits.data, 1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    if n_seen > 0:
        avg_loss = loss_sum / n_seen
        accuracy = 100 * n_correct / n_seen
    else:
        avg_loss = float('nan')
        accuracy = 0
    return {'macs': macs, 'params': num_params, 'size_mib': size_mib, 'accuracy': accuracy, 'loss': avg_loss}


### Prune the model

In [29]:
def prune_model(model, example_input, target_macs, strategy, iterative_steps=5):
    """Prune *model* in place with a torch_pruning pruner and return it.

    Args:
        model: network to prune; modified in place. ``model.fc`` is never pruned.
        example_input: input used to trace the dependency graph and count MACs.
        target_macs: currently unused (the MACs-targeted loop was disabled);
            kept so existing callers keep working.
        strategy: dict with 'pruner' (a torch_pruning pruner class) and
            'importance' (an importance-criterion instance).
        iterative_steps: number of rounds the channel sparsity is spread over.

    The channel sparsity is fixed at 0.1 for TaylorImportance and 0.5 otherwise.
    Params and MACs are printed after every round.
    """
    importance = strategy['importance']
    is_taylor = isinstance(importance, tp.importance.TaylorImportance)
    pruning_ratio = 0.1 if is_taylor else 0.5

    pruner = strategy['pruner'](
        model,
        example_input,
        importance=importance,
        iterative_steps=iterative_steps,
        ch_sparsity=pruning_ratio,  # total sparsity reached after all steps
        root_module_types=[nn.Conv2d],
        ignored_layers=[model.fc],
    )

    base_macs, base_nparams = calculate_macs(model, example_input)

    for i in range(iterative_steps):
        if is_taylor:
            # Taylor importance reads .grad. Clear gradients from previous rounds
            # so each round scores the *current* (already pruned) network instead
            # of an accumulation of stale gradients.
            model.zero_grad()
            loss = model(example_input).sum()  # dummy loss; only gradients are needed
            loss.backward()
        for group in pruner.step(interactive=True):
            group.prune()
        macs, nparams = calculate_macs(model, example_input)
        print(
            "  Iter %d/%d, Params: %.2f M => %.2f M"
            % (i + 1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
        )
        print(
            "  Iter %d/%d, MACs: %.2f G => %.2f G"
            % (i + 1, iterative_steps, base_macs / 1e9, macs / 1e9)
        )

    return model

### Train the model

In [30]:
def _train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass; return (mean per-batch loss, accuracy in %)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for data in train_loader:
        inputs, labels = [d.to(device) for d in data]
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    num_batches = len(train_loader)
    epoch_loss = running_loss / num_batches if num_batches > 0 else 0
    epoch_acc = 100 * correct / total if total > 0 else 0
    return epoch_loss, epoch_acc


def _validate_one_epoch(model, val_loader, criterion, device):
    """Evaluate without gradients; return (mean per-batch loss, accuracy in %)."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            inputs, labels = [d.to(device) for d in data]
            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    num_batches = len(val_loader)
    val_loss = running_loss / num_batches if num_batches > 0 else 0
    val_acc = 100 * correct / total if total > 0 else 0
    return val_loss, val_acc


def train_model(model,
                train_loader,
                criterion,
                optimizer,
                device,
                num_epochs,
                val_loader=None, strategy_name="",
                early_stopping_patience=None,
                early_stopping_metric='val_loss',
                load_best_model_on_stop=True
                ):
    """Train *model* for up to *num_epochs* epochs with optional early stopping.

    Args:
        model, train_loader, criterion, optimizer, device: usual training objects;
            the loaders yield (inputs, labels) pairs.
        num_epochs: maximum number of epochs.
        val_loader: optional validation loader; required for early stopping.
        strategy_name: label used in the log lines.
        early_stopping_patience: stop after this many epochs without improvement
            of *early_stopping_metric* (None disables early stopping).
        early_stopping_metric: 'val_loss' (lower is better) or 'val_acc'
            (higher is better); anything else falls back to 'val_loss'.
        load_best_model_on_stop: keep a copy of the best weights and restore them
            at the end if the final weights are worse.

    Returns:
        (model, history). history holds the lists 'train_loss', 'train_acc',
        'val_loss', 'val_acc' with one entry per completed epoch; the val lists
        contain None when no val_loader is given.
    """
    # Validate the metric *before* initialising best_metric_score. Falling back
    # inside the loop kept the -inf start value meant for 'val_acc', so no
    # val_loss could ever "improve" and training always stopped early.
    if early_stopping_metric not in ('val_loss', 'val_acc'):
        print(f"Warning: Unsupported early_stopping_metric: {early_stopping_metric}. Defaulting to 'val_loss'.")
        early_stopping_metric = 'val_loss'
    minimize = early_stopping_metric == 'val_loss'

    if early_stopping_patience is not None and not val_loader:
        print(f"Warning: Early stopping for '{strategy_name}' configured (patience: {early_stopping_patience}), but no validation loader provided. Early stopping will be inactive.")

    model.to(device)
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    best_metric_score = float('inf') if minimize else float('-inf')
    epochs_no_improve = 0
    best_model_state_dict = None
    stopped_early = False

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        log_msg = f"Strategy: {strategy_name} - Epoch {epoch+1}/{num_epochs}: Train Loss={epoch_loss:.4f}, Train Acc={epoch_acc:.2f}%"

        if val_loader:
            epoch_val_loss, epoch_val_acc = _validate_one_epoch(model, val_loader, criterion, device)
            history['val_loss'].append(epoch_val_loss)
            history['val_acc'].append(epoch_val_acc)
            log_msg += f", Val Loss={epoch_val_loss:.4f}, Val Acc={epoch_val_acc:.2f}%"

            if early_stopping_patience is not None:
                current = epoch_val_loss if minimize else epoch_val_acc
                improved = current < best_metric_score if minimize else current > best_metric_score
                if improved:
                    best_metric_score = current
                    epochs_no_improve = 0
                    if load_best_model_on_stop:
                        # deepcopy: state_dict() returns references to the live tensors.
                        best_model_state_dict = copy.deepcopy(model.state_dict())
                    log_msg += f" (New best {early_stopping_metric})"
                else:
                    epochs_no_improve += 1
        else:
            history['val_loss'].append(None)
            history['val_acc'].append(None)

        print(log_msg)

        if early_stopping_patience is not None and val_loader and epochs_no_improve >= early_stopping_patience:
            print(f"Early stopping triggered for '{strategy_name}' after {epoch+1} epochs. No improvement in '{early_stopping_metric}' for {early_stopping_patience} epochs.")
            if load_best_model_on_stop and best_model_state_dict is not None:
                print(f"Loading best model weights from epoch {epoch + 1 - epochs_no_improve} with {early_stopping_metric}: {best_metric_score:.4f}")
                model.load_state_dict(best_model_state_dict)
            stopped_early = True
            break

    # Training ran to completion: the last epoch is not necessarily the best one.
    # (After an early stop the best weights were already restored above.)
    if not stopped_early and load_best_model_on_stop and best_model_state_dict is not None:
        # A stored best state implies a val_loader, so the last entry is a number.
        last_val_metric = history[early_stopping_metric][-1]
        last_is_best = (last_val_metric <= best_metric_score if minimize
                        else last_val_metric >= best_metric_score)
        if not last_is_best:
            print(f"Training for '{strategy_name}' finished. Loading best recorded model state with {early_stopping_metric}: {best_metric_score:.4f}")
            model.load_state_dict(best_model_state_dict)

    return model, history

### Plotting fine-tuning curves

In [48]:
import numpy as np
def plot_finetuning_curves(history, plot_title_suffix, output_dir_plots, model_macs_val):
    """Save a two-panel (loss | accuracy) plot of a training history.

    Args:
        history: dict with 'train_loss' and 'train_acc' lists, optionally
            'val_loss'/'val_acc' (plotted only if they hold any non-None value).
        plot_title_suffix: text appended to the panel titles and, sanitised,
            to the output filename 'finetune_curves_<suffix>.png'.
        output_dir_plots: directory to write into (created if missing).
        model_macs_val: MACs value shown in both titles.
    """
    os.makedirs(output_dir_plots, exist_ok=True)
    n_epochs = len(history['train_loss'])
    xs = range(1, n_epochs + 1)
    # (train key, train label, val key, val label, title word, y-axis label)
    panels = (
        ('train_loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss', 'Loss'),
        ('train_acc', 'Train Acc', 'val_acc', 'Val Acc', 'Accuracy', 'Acc %'),
    )

    plt.figure(figsize=(12, 5))
    for position, (train_key, train_label, val_key, val_label, title_word, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(xs, history[train_key][:n_epochs], 'bo-', label=train_label)
        val_series = history.get(val_key)
        if val_series and any(v is not None for v in val_series):
            plt.plot(xs, val_series[:n_epochs], 'ro-', label=val_label)
        plt.title(f'{title_word} ({plot_title_suffix})\nMACs: {model_macs_val:.2e}')
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()

    # Spaces, slashes and colons are not filename-friendly.
    safe_suffix = plot_title_suffix.translate(str.maketrans({' ': '_', '/': '_', ':': '_'}))
    plt.savefig(os.path.join(output_dir_plots, f'finetune_curves_{safe_suffix}.png'))
    plt.close()
    print(f"✅ Fine-tuning curves for {plot_title_suffix} saved.")


def _grouped_bar_chart(results_data_plot, ratios_tested_plot, strategies_plot, metric_key,
                       y_label, title, file_path, log_y, grid_kwargs):
    """Draw and save one grouped bar chart: a group per ratio, a bar per strategy."""
    n_ratios = len(ratios_tested_plot)
    n_strategies = len(strategies_plot)
    tick_labels = [f"{r:.1f}" for r in ratios_tested_plot]
    width = 0.8 / n_strategies
    group_x = np.arange(n_ratios)
    cmap = plt.colormaps.get_cmap('tab10')

    plt.figure(figsize=(max(12, int(1.8 * n_ratios * n_strategies / 3)), 7))
    for offset, name in enumerate(strategies_plot):
        # Missing (strategy, ratio) results are drawn as NaN, i.e. no bar.
        heights = [results_data_plot[name].get(r, {}).get(metric_key, np.nan) for r in ratios_tested_plot]
        plt.bar(group_x + offset * width, heights, width, label=name, color=cmap(offset % cmap.N))
    plt.xlabel('Pruning Ratio (ch_sparsity)')
    plt.ylabel(y_label)
    plt.title(title)
    # Centre each tick under its group of bars.
    plt.xticks(group_x + width * (n_strategies - 1) / 2, tick_labels)
    if log_y:
        plt.yscale('log')
    plt.legend(title="Strategy", bbox_to_anchor=(1.02, 1), loc='upper left')
    plt.grid(True, **grid_kwargs)
    plt.tight_layout(rect=[0, 0, 0.88, 1])
    plt.savefig(file_path)
    plt.close()


def plot_metrics_vs_ratio_all_strategies(results_data_plot, ratios_tested_plot, output_dir_main_plots):
    """Save MACs-vs-ratio (log scale) and loss-vs-ratio grouped bar charts.

    Args:
        results_data_plot: mapping strategy -> {ratio: metrics dict with 'macs'
            and 'loss'}.
        ratios_tested_plot: ratios to show on the x-axis, in order.
        output_dir_main_plots: directory for the PNG files (created if missing).
    """
    os.makedirs(output_dir_main_plots, exist_ok=True)
    strategies_plot = list(results_data_plot.keys())
    if not strategies_plot:
        print("No strategies in results for ratio plots.")
        return

    _grouped_bar_chart(
        results_data_plot, ratios_tested_plot, strategies_plot, 'macs',
        'MACs (Log Scale)', 'MACs vs. Pruning Ratio (Fine-tuned Models)',
        os.path.join(output_dir_main_plots, 'MACs_vs_Ratio_by_Strategy.png'),
        log_y=True, grid_kwargs={'which': "both", 'ls': "-", 'alpha': 0.3},
    )
    print(f"✅ MACs vs. Ratio by Strategy plot saved to {output_dir_main_plots}")

    _grouped_bar_chart(
        results_data_plot, ratios_tested_plot, strategies_plot, 'loss',
        'Avg. Test Loss', 'Avg. Test Loss vs. Pruning Ratio (Fine-tuned Models)',
        os.path.join(output_dir_main_plots, 'Loss_vs_Ratio_by_Strategy.png'),
        log_y=False, grid_kwargs={'alpha': 0.3},
    )
    print(f"✅ Loss vs. Ratio by Strategy plot saved to {output_dir_main_plots}")

### Main workflow

In [36]:
def main():
    """Run the ratio-based pruning experiment for every configured strategy.

    Workflow:
      1. Train a dense MobileNetV2, or reuse the checkpoint already saved at
         '<output_dir_base>/mobilenetv2_initial_dense_trained.pth'.
      2. Evaluate it on the test set and store the result as every strategy's
         ratio-0.0 baseline.
      3. For each strategy and each non-zero ratio: reload the dense
         checkpoint, prune it with gr_prune_model_by_ratio, evaluate it, then
         fine-tune it with early stopping on the validation split. Plot the
         fine-tuning curves, evaluate on the test set and save the model
         (.pth + .onnx).
      4. Plot MACs and test loss against the pruning ratio for all strategies.

    NOTE(review): gr_prune_model_by_ratio is not defined in the visible part of
    this notebook (prune_model has a different signature). Confirm it is
    defined in another cell before running.
    """
    # Local device; shadows the module-level `device` defined earlier.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configuration
    cfg = {
        'strategies': {
            'magnitude': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.MagnitudeImportance(p=2),
            },
            'bn_scale': {
                'pruner': tp.pruner.BNScalePruner,
                'importance': tp.importance.BNScaleImportance(),
            },
            # todo: check the examples for the following strategies, why it is giving error
            'group_norm': {
                'pruner': tp.pruner.GroupNormPruner,
                'importance': tp.importance.GroupMagnitudeImportance(p=1),
            },
            'random': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.RandomImportance(),
            },
            'Taylor': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.TaylorImportance()
            },
            'Hessian': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.GroupHessianImportance()
            },
            'lamp': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.LAMPImportance(p=2)
            },
            'geometry': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.FPGMImportance()
            },
            #todo: implement growing reg pruning
        },
        #todo: different types of schedulers can be added
        'pruning_ratios_to_test': [0.0, 0.2, 0.5, 0.7], # 0.0 for baseline
        'iterative_steps_pruner_general': 5, # For non-Taylor pruners inside gr_prune_model_by_ratio
        'iterative_steps_taylor_pruning': 5, # Number of backward passes for Taylor method in gr_prune_model_by_ratio
        'val_split_for_loader': 0.1, # Corresponds to your get_data_loaders val_split
        #'iterative_pruning_ratio_scheduler': linear_scheduler(),
         'initial_train_epochs': 30,
        'fine_tune_epochs': 50,
        'early_stopping_patience': 15,
        'early_stopping_metric': 'val_loss', # 'val_loss' or 'val_acc'
        'load_best_model_on_early_stop': True,

        'data_dir': './data',
        'output_dir_base': './output_final_ratio_experiment',
        'num_classes': 10, 'batch_size': 128,
        'learning_rate_initial': 0.001, 'learning_rate_finetune': 0.0001,
    }

    os.makedirs(cfg['output_dir_base'], exist_ok=True)

    # --- Updated Data Loader Call ---
    print("Initializing DataLoaders...")
    # Now get_data_loaders returns train, val, test
    train_loader, val_loader, test_loader = get_data_loaders(
        data_dir=cfg['data_dir'],
        batch_size=cfg['batch_size'],
        val_split=cfg['val_split_for_loader'] # Pass the val_split
    )
    print(f"Train loader: {len(train_loader.dataset)} samples, Val loader: {len(val_loader.dataset)} samples, Test loader: {len(test_loader.dataset)} samples")
    # --- End of Updated Data Loader Call ---

    # Dummy 1x3x32x32 input, used for MACs counting and for ONNX export.
    example_input_cpu_onnx = torch.randn(1, 3, 32, 32)
    example_input_dev = example_input_cpu_onnx.to(device)

    initial_model_trained_path = os.path.join(cfg['output_dir_base'], "mobilenetv2_initial_dense_trained.pth")
    criterion_train_eval = nn.CrossEntropyLoss().to(device) # Define criterion once

    # The dense model is trained only once; later runs reuse the saved checkpoint.
    if not os.path.exists(initial_model_trained_path):
        print("--- Training Initial Dense MobileNetV2 ---")
        dense_model_instance = MobileNetV2(num_classes=cfg['num_classes']).to(device)
        trained_dense_model, _ = train_model(
            model=dense_model_instance,
            train_loader=train_loader, # Use train_loader from get_data_loaders
            criterion=criterion_train_eval,
            optimizer=optim.Adam(dense_model_instance.parameters(), lr=cfg['learning_rate_initial']),
            device=device,
            num_epochs=cfg['initial_train_epochs'],
            val_loader=val_loader, # <<<< PASS THE DEDICATED val_loader HERE
            strategy_name="InitialDenseModel",
            early_stopping_patience=cfg['early_stopping_patience'],
            early_stopping_metric=cfg['early_stopping_metric'],
            load_best_model_on_stop=cfg['load_best_model_on_early_stop']
        )
        save_model(trained_dense_model, initial_model_trained_path, example_input=example_input_cpu_onnx.to(device))
    else:
        print(f"--- Using Pre-trained Dense Model from {initial_model_trained_path} ---")

    # strategy name -> {pruning ratio -> test-set metrics dict}
    results_all_ratios_strategies = {s_name: {} for s_name in cfg['strategies'].keys()}

    print("\n--- Evaluating Baseline Model (Ratio 0.0) on TEST SET---")
    baseline_model_eval = load_model_state(MobileNetV2, initial_model_trained_path, device, num_classes=cfg['num_classes'])
    # Final evaluation of baseline model on the TEST set
    baseline_metrics_eval = evaluate_model(baseline_model_eval, test_loader, example_input_dev, criterion_train_eval, device)
    print(f"Baseline Metrics (Ratio 0.0, Evaluated on Test Set): {baseline_metrics_eval}")
    # Every strategy's 0.0 entry refers to the same baseline metrics dict object.
    for s_name_key in cfg['strategies'].keys():
        results_all_ratios_strategies[s_name_key][0.0] = baseline_metrics_eval # Store test set metrics

    for strategy_name_key, strategy_details_dict in cfg['strategies'].items():
        strategy_specific_output_dir = os.path.join(cfg['output_dir_base'], strategy_name_key)
        os.makedirs(strategy_specific_output_dir, exist_ok=True)

        for current_ratio in cfg['pruning_ratios_to_test']:
            if current_ratio == 0.0: continue  # baseline was already evaluated above

            print(f"\n\n--- Processing: Strategy '{strategy_name_key}', Ratio: {current_ratio:.2f} ---")
            # Each ratio starts again from the dense checkpoint, not from the
            # model pruned at the previous ratio.
            model_for_pruning_current = load_model_state(MobileNetV2, initial_model_trained_path, device, num_classes=cfg['num_classes'])

            print(f"--- Pruning model with {strategy_name_key} to ratio {current_ratio:.2f} ---")
            pruned_model_current = gr_prune_model_by_ratio( # Your adapted pruning function
                model_for_pruning_current, example_input_dev, strategy_details_dict, current_ratio,
                iterative_steps_config=cfg['iterative_steps_pruner_general'],
                iterative_steps_taylor=cfg['iterative_steps_taylor_pruning']
            )

            # Optional: Evaluate on test set *before* fine-tuning
            metrics_pruned_bf_ft_current = evaluate_model(pruned_model_current, test_loader, example_input_dev, criterion_train_eval, device)
            print(f"Metrics for '{strategy_name_key}' @ Ratio {current_ratio:.2f} (Pruned, Before FT, on Test Set): {metrics_pruned_bf_ft_current}")

            print(f"--- Fine-tuning pruned model ({strategy_name_key} @ Ratio {current_ratio:.2f}) ---")
            fine_tuned_model_current, ft_history_current = train_model(
                model=pruned_model_current,
                train_loader=train_loader, # Use train_loader
                criterion=criterion_train_eval,
                optimizer=optim.Adam(pruned_model_current.parameters(), lr=cfg['learning_rate_finetune']),
                device=device,
                num_epochs=cfg['fine_tune_epochs'],
                val_loader=val_loader, # <<<< PASS THE DEDICATED val_loader HERE
                strategy_name=f"{strategy_name_key}_R{current_ratio:.1f}",
                early_stopping_patience=cfg['early_stopping_patience'],
                early_stopping_metric=cfg['early_stopping_metric'],
                load_best_model_on_stop=cfg['load_best_model_on_early_stop']
            )

            # MACs for the plot title only; evaluate_model below recomputes them.
            macs_ft_current, _ = calculate_macs(fine_tuned_model_current, example_input_dev)
            plot_finetuning_curves(ft_history_current, f"{strategy_name_key}_R{current_ratio:.1f}",
                                   strategy_specific_output_dir, macs_ft_current)

            # Final evaluation of fine-tuned model on the TEST set
            print(f"--- Evaluating Fine-tuned Model ({strategy_name_key} @ Ratio {current_ratio:.2f}) on TEST SET ---")
            final_metrics_current = evaluate_model(fine_tuned_model_current, test_loader, example_input_dev, criterion_train_eval, device)
            results_all_ratios_strategies[strategy_name_key][current_ratio] = final_metrics_current
            print(f"Metrics for '{strategy_name_key}' @ Ratio {current_ratio:.2f} (Fine-tuned, on Test Set): {final_metrics_current}")
            save_model(fine_tuned_model_current,
                       os.path.join(strategy_specific_output_dir, f"model_R{current_ratio:.1f}_final.pth"),
                       example_input=example_input_cpu_onnx.to(device))

    plot_metrics_vs_ratio_all_strategies(results_all_ratios_strategies, cfg['pruning_ratios_to_test'], cfg['output_dir_base'])
    print("\nAll ratio-based pruning experiments completed successfully!")

In [49]:
if __name__ == "__main__":
    # Run the full experiment: dense training/loading, pruning, fine-tuning, plots.
    main()

Initializing DataLoaders...
Using dataset directory: /home/muis/thesis/github-repo/master-thesis/cnn/mobile_net_v2/data
Train loader: 45000 samples, Val loader: 5000 samples, Test loader: 10000 samples
--- Using Pre-trained Dense Model from ./output_final_ratio_experiment/mobilenetv2_initial_dense_trained.pth ---

--- Evaluating Baseline Model (Ratio 0.0) on TEST SET---
✅ Model loaded from ./output_final_ratio_experiment/mobilenetv2_initial_dense_trained.pth to cuda
Baseline Metrics (Ratio 0.0, Evaluated on Test Set): {'macs': 6059786.0, 'params': 1169642, 'size_mib': 4.461830139160156, 'accuracy': 63.29, 'loss': 1.032341127204895}


--- Processing: Strategy 'magnitude', Ratio: 0.20 ---
✅ Model loaded from ./output_final_ratio_experiment/mobilenetv2_initial_dense_trained.pth to cuda
--- Pruning model with magnitude to ratio 0.20 ---
Initial MACs before pruning for ratio 0.20: 0.006 G
Applying MagnitudeImportance with target ratio: 0.20 using pruner's 5 iterative_steps.
After pruning fo

### Load the saved ONNX model and convert it to a PyTorch model

In [12]:
import torch
import onnx
import onnxruntime
from onnx2torch import convert

def _load_checked_onnx(path):
    """Load the ONNX graph stored at *path* and validate it with the ONNX checker."""
    proto = onnx.load(path)
    onnx.checker.check_model(proto)  # raises if the graph is not a valid ONNX model
    print("✅ ONNX model loaded and verified.")
    return proto


# Step 1: load and verify the exported ONNX graph
onnx_model_path = './output/strategies/mobilenetv2_bn_scale_final.onnx'
onnx_model = _load_checked_onnx(onnx_model_path)

# Step 2: rebuild the graph as a torch.nn.Module via onnx2torch, then print
# torch_pruning's before_pruning report for it
torch_model = convert(onnx_model)
print("✅ ONNX model converted to PyTorch.")
tp.utils.print_tool.before_pruning(torch_model)


✅ ONNX model loaded and verified.
✅ ONNX model converted to PyTorch.


In [None]:
from torchinfo import summary
# NOTE: the import on the previous line rebinds `summary` to torchinfo's version,
# shadowing the torchsummary.summary imported at the top of the notebook.
print(summary(torch_model))

In [None]:

from torchinfo import summary
# NOTE(review): this cell repeats the torchinfo summary printed by the previous cell.
print(summary(torch_model))
# Disabled: torch_pruning print_tool after-pruning report.
#tp.utils.print_tool.after_pruning(torch_model)


### Fine-tune the converted PyTorch model

In [None]:
# NOTE(review): main() reports separate train/val/test loaders; confirm that this
# get_data_loaders (imported from cnn.resNet.resnet_example) returns exactly two.
train_loader, test_loader = get_data_loaders('./data')

# Run on the GPU when one is present, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
torch_model = torch_model.to(device)

tp.utils.print_tool.before_pruning(torch_model)

# Fine-tuning setup: cross-entropy loss, Adam optimizer, fixed epoch budget
num_epochs = 50
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

fine_tuned_model = train_model(torch_model, train_loader, criterion, optimizer, device, num_epochs)

### Plot parameter-count results

In [15]:
import matplotlib.pyplot as plt
import numpy as np

import os

# Data: total parameter count of each model, transcribed from earlier pruning runs.
# NOTE(review): the baseline evaluation output reports 1,169,642 params for the dense
# model, not 1,162,530 — confirm which run these numbers come from.
strategies = ['Initial', 'Magnitude', 'BN Scale', 'Group Norm', 'Random', 'Taylor', 'Hessian', 'Lamp']
params = [
    1162530,  # Initial model (from your query)
    305350,   # Magnitude (from output: 0.31M after pruning, assuming final tuned model)
    301558,   # BN Scale (from your query and output: 0.31M after pruning)
    305350,   # Group Norm (from output: 0.31M after pruning, assuming similar to magnitude)
    305350,   # Random (from output: 0.31M after pruning, assuming similar to magnitude)
    949650,   # Taylor (from output: 0.95M after pruning)
    1162530,  # Hessian (from output: 1.17M, no reduction observed)
    305350    # Lamp (from output: 0.31M after pruning)
]
initial_params = params[0]

# Percentage reduction of each strategy relative to the initial model
reductions = [(initial_params - p) / initial_params * 100 for p in params]

# One color per strategy (same order as `strategies`)
colors = ['gray', 'blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink']

# Create bar chart
plt.figure(figsize=(14, 8))
bars = plt.bar(strategies, params, color=colors)

# Label each bar with its exact count, taken from the data rather than the bar geometry.
# Every bar except the baseline also shows its percentage reduction.
for i, (bar, count) in enumerate(zip(bars, params)):
    label = f'{count:,}' if i == 0 else f'{count:,}\n({reductions[i]:.1f}%)'
    plt.text(bar.get_x() + bar.get_width() / 2., bar.get_height(), label,
             ha='center', va='bottom', fontsize=9)

# Reference line at the initial parameter count
plt.axhline(y=initial_params, color='gray', linestyle='--', label='Initial Parameters')

# Customize the chart
plt.xlabel('Pruning Strategy')
plt.ylabel('Total Parameters')
plt.title('Comparison of Total Parameters Across Pruning Strategies')
plt.xticks(rotation=45, ha='right')
plt.legend()
plt.tight_layout()

# Save the chart; create the output directory first so savefig does not fail
# with FileNotFoundError on a fresh checkout.
os.makedirs('output', exist_ok=True)
plt.savefig('output/pruning_comparison.png')

### Comparison of MACs, parameters, and model size (MB)

In [16]:
import matplotlib.pyplot as plt
import numpy as np

import os

# Data: metrics for each model, transcribed from earlier pruning runs
strategies = ['Initial', 'Magnitude', 'BN Scale', 'Group Norm', 'Random', 'Taylor', 'Hessian', 'Lamp']
metrics = {
    'params': [1162530, 305350, 301558, 305350, 305350, 949650, 1162530, 305350],  # Total parameters
    'macs': [6060000, 1840000, 1840000, 1840000, 1840000, 4920000, 6060000, 1840000],  # MACs
    'size_mb': [4.68, 1.22, 1.22, 1.22, 1.22, 3.80, 4.68, 1.22]  # Model size in MB
}
# Human-readable axis/title text per metric. 'size_mb' is megabytes; 4.68 MB appears
# to match the 4.46 MiB reported for the baseline model.
metric_labels = {
    'params': 'Total Parameters',
    'macs': 'MACs',
    'size_mb': 'Model Size (MB)',
}
initial_metrics = {k: v[0] for k, v in metrics.items()}

# Percentage reduction of every strategy relative to the initial model, per metric
reductions = {k: [(initial_metrics[k] - v) / initial_metrics[k] * 100 for v in values]
              for k, values in metrics.items()}

# One color per strategy (same order as `strategies`)
colors = ['gray', 'blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink']

# One subplot row per metric; squeeze=False keeps `axes` 2-D even for a single metric
fig, axes = plt.subplots(len(metrics), 1, figsize=(14, 6 * len(metrics)), squeeze=False)

for i, (metric_name, values) in enumerate(metrics.items()):
    ax = axes[i][0]
    bars = ax.bar(strategies, values, color=colors)
    pretty_name = metric_labels.get(metric_name, metric_name.replace('_', ' ').title())

    # Label each bar from the data (sizes as 2-decimal floats, counts with thousands
    # separators). Every bar except the baseline also shows its percentage reduction.
    for j, (bar, value) in enumerate(zip(bars, values)):
        label = f'{value:.2f}' if metric_name == 'size_mb' else f'{value:,}'
        if j > 0:
            label += f'\n({reductions[metric_name][j]:.1f}%)'
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(), label,
                ha='center', va='bottom', fontsize=9)

    # Reference line at the initial model's value
    ax.axhline(y=initial_metrics[metric_name], color='gray', linestyle='--', label='Initial')

    ax.set_xlabel('Pruning Strategy')
    ax.set_ylabel(pretty_name)
    ax.set_title(f'Comparison of {pretty_name} Across Pruning Strategies')
    ax.set_xticks(range(len(strategies)))
    ax.set_xticklabels(strategies, rotation=45, ha='right')
    ax.legend()

# Adjust layout and save; create the output directory so savefig cannot fail on it.
plt.tight_layout()
os.makedirs('output', exist_ok=True)
plt.savefig('output/pruning_metrics_comparison.png')

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Per-epoch accuracies (%) for the train, validation and test splits
epochs = list(range(1, 11))
train_acc = [18.48, 21.53, 21.21, 23.94, 26.32, 29.75, 31.91, 34.05, 36.34, 36.77]
valid_acc = [21.93, 20.94, 22.23, 25.33, 27.88, 30.41, 33.33, 35.23, 37.02, 38.57]
test_acc = [21.90, 21.01, 22.10, 24.86, 27.84, 31.08, 33.47, 35.38, 37.53, 38.45]

# Grouped bar chart: three bars per epoch, shifted by -1, 0 and +1 bar widths
plt.figure(figsize=(10, 6))
bar_width = 0.25
x = np.arange(len(epochs))

for offset, accs, series_label, series_color in (
    (-1, train_acc, 'Train Accuracy', 'skyblue'),
    (0, valid_acc, 'Validation Accuracy', 'lightgreen'),
    (1, test_acc, 'Test Accuracy', 'salmon'),
):
    plt.bar(x + offset * bar_width, accs, bar_width, label=series_label, color=series_color)

# Axes, ticks, legend and a light horizontal grid
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Model Accuracy per Epoch')
plt.xticks(x, epochs)
plt.legend()
plt.grid(True, axis='y', linestyle='--', alpha=0.7)

# Write the figure to disk
plt.savefig('darts_accuracy_bar_plot.png')

### Gr: prune by ratio

In [44]:
def gr_prune_model_by_ratio(model_to_prune, example_input_device, strategy_config, target_pruning_ratio,
                            iterative_steps_config=5,  # Pruning steps for non-Taylor importances
                            iterative_steps_taylor=5  # Gradient + pruning steps for Taylor importance
                            ):
    """Return a pruned deep copy of ``model_to_prune`` at a fixed channel-pruning ratio.

    The input model is not modified; all pruning happens on ``copy.deepcopy(model_to_prune)``.

    A single torch_pruning pruner is built with ``ch_sparsity=target_pruning_ratio`` and
    ``iterative_steps=N``, and ``pruner.step()`` is then called N times. With
    torch_pruning's default linear scheduler, step k prunes up to k/N of the target, so
    calling ``step()`` fewer than N times only reaches a fraction of the requested ratio.
    For ``TaylorImportance``, a backward pass on a dummy loss runs before every step so
    the importance scores use gradients of the current, partially pruned model.

    Args:
        model_to_prune: Model to prune (left untouched).
        example_input_device: Example input tensor; the copy is moved to its device.
        strategy_config: Dict with ``'pruner'`` (pruner class) and ``'importance'``
            (importance instance).
        target_pruning_ratio: Overall channel sparsity to reach. ``<= 0`` returns the
            unpruned copy.
        iterative_steps_config: Number of pruning steps for non-Taylor importances.
        iterative_steps_taylor: Number of gradient + pruning steps for Taylor importance.

    Returns:
        The pruned copy. When pruning is applied it is left in eval mode.

    Raises:
        ValueError: If the applicable number of iterative steps is smaller than 1.
    """
    model = copy.deepcopy(model_to_prune)
    model.to(example_input_device.device)
    initial_macs, _ = calculate_macs(model, example_input_device)
    print(f"Initial MACs before pruning for ratio {target_pruning_ratio:.2f}: {initial_macs / 1e9:.3f} G")

    # Keep the classifier head out of pruning so its output width is preserved.
    ignored_layers_list = []
    if hasattr(model, 'fc'):
        ignored_layers_list = [model.fc]
    else:
        print("Warning: model.fc not found. Pruning might affect classifier.")

    if target_pruning_ratio <= 0.0:
        print("Target pruning ratio is 0.0. No pruning applied.")
        return model

    importance = strategy_config['importance']
    is_taylor = isinstance(importance, tp.importance.TaylorImportance)
    num_steps = iterative_steps_taylor if is_taylor else iterative_steps_config
    if num_steps < 1:
        raise ValueError(f"Number of iterative pruning steps must be >= 1, got {num_steps}")

    if is_taylor:
        print(f"Applying TaylorImportance with target ratio: {target_pruning_ratio:.2f} using manual backward passes in eval mode.")
    else:
        print(f"Applying {importance.__class__.__name__} with target ratio: {target_pruning_ratio:.2f} "
              f"using pruner's {num_steps} iterative_steps.")

    # Eval mode for dependency tracing and the dummy Taylor passes. BatchNorm then uses
    # (and does not update) its running statistics, which also avoids BN errors in train mode.
    model.eval()

    # ONE pruner for the whole schedule: ch_sparsity is the overall target and
    # iterative_steps is how many step() calls it takes to get there.
    pruner = strategy_config['pruner'](
        model,
        example_input_device,
        importance=importance,
        iterative_steps=num_steps,
        ch_sparsity=target_pruning_ratio,
        root_module_types=[nn.Conv2d],
        ignored_layers=ignored_layers_list
    )

    for i in range(num_steps):
        if is_taylor:
            # Taylor importance reads param.grad, so refresh the gradients on the
            # current model before every pruning step.
            model.zero_grad()
            for param in model.parameters():
                if not param.requires_grad:
                    param.requires_grad_(True)
            loss = model(example_input_device.clone()).mean()
            try:
                loss.backward()
            except Exception as e:
                print(f"Error during loss.backward() at Taylor step {i+1} (model in eval mode): {e}")
                raise
            print(f"  Taylor Manual Step {i+1}/{num_steps} (target overall ratio: {target_pruning_ratio:.2f})")

        # Advances the scheduler one step and prunes toward (i + 1) / num_steps of the target.
        pruner.step()

        if is_taylor:
            macs_after_step, _ = calculate_macs(model, example_input_device)
            print(f"  MACs after Taylor step {i+1}: {macs_after_step / 1e9:.3f} G")

    if is_taylor:
        # Drop the dummy-loss gradients so they cannot leak into later fine-tuning.
        model.zero_grad()

    final_macs, _ = calculate_macs(model, example_input_device)
    reduction = (initial_macs - final_macs) / initial_macs * 100 if initial_macs > 0 else 0
    print(f"After pruning for ratio {target_pruning_ratio:.2f}: MACs {final_macs / 1e9:.3f} G (Reduction: {reduction:.2f}%)")
    return model

### Gem: Prune with threshold

In [1]:
import time # Optional: for adding time limits or tracking

# Assume calculate_macs is defined as before:
# def calculate_macs(model, example_input):
#     macs, params = tp.utils.count_ops_and_params(model, example_input)
#     return macs, params

def _gem_uses_gradients(importance):
    """Return True when *importance* scores channels from parameter gradients."""
    return isinstance(importance, (tp.importance.TaylorImportance, tp.importance.GroupHessianImportance))


def _gem_compute_gradients(model, example_input, importance):
    """Run dummy forward/backward passes so gradient-based importance has data to use.

    The model is kept in eval mode so these passes do not overwrite the BatchNorm
    running statistics of the model being pruned. Hessian importance objects that
    expose ``zero_grad``/``accumulate_grad`` get one backward pass per sample, with
    squared gradients accumulated as in torch_pruning's documented usage. Every other
    importance gets a single backward pass on the mean of the model output.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad_(True)
    model.zero_grad()
    output = model(example_input)

    accumulates = (isinstance(importance, tp.importance.GroupHessianImportance)
                   and hasattr(importance, 'zero_grad')
                   and hasattr(importance, 'accumulate_grad'))
    if not accumulates:
        output.mean().backward()
        return

    importance.zero_grad()
    per_sample_losses = output.reshape(output.shape[0], -1).mean(dim=1)
    for sample_loss in per_sample_losses:
        model.zero_grad()
        sample_loss.backward(retain_graph=True)  # graph is reused for every sample
        importance.accumulate_grad(model)


def _gem_percent_change(before, after):
    """Return the signed percentage change from *before* to *after* (0.0 if *before* is 0)."""
    return (after - before) / before * 100 if before else 0.0


def gem_prune_model_by_threshold(model, example_input, target_macs, target_params, strategy, max_iterations=100, step_ch_sparsity=0.1):
    """
    Prunes the model iteratively until both MACs and parameter count are at or below
    the specified thresholds, or max_iterations is reached.

    The model is pruned IN PLACE and also returned. Each iteration builds a fresh pruner
    that removes roughly ``step_ch_sparsity`` of the channels that are currently left in
    every prunable layer, so pruning proceeds geometrically until the targets are met.
    A single reused pruner would not work: torch_pruning measures the ratio against the
    widths seen at construction time, and a pruner with ``iterative_steps=1`` yields
    nothing after its first ``step()``.

    Args:
        model: The PyTorch model to prune (modified in place).
        example_input: Example input tensor for MACs calculation and pruner tracing.
        target_macs: The desired maximum MAC count.
        target_params: The desired maximum parameter count.
        strategy: Dictionary containing 'pruner' (pruner class) and 'importance' (instance).
        max_iterations (int): Safety limit for the number of pruning steps.
        step_ch_sparsity (float): Fraction of the remaining channels targeted per step.
                                  Smaller values give finer, more numerous steps.

    Returns:
        The pruned model, in eval mode.
    """
    device = example_input.device
    model.to(device)  # Ensure model is on the correct device
    importance = strategy['importance']
    strategy_name = importance.__class__.__name__
    uses_gradients = _gem_uses_gradients(importance)

    # Keep the classifier head out of pruning so its output width is preserved.
    if hasattr(model, 'fc'):
        ignored_layers = [model.fc]
    else:
        ignored_layers = []
        print("Warning: model.fc not found. Pruning might affect classifier.")

    print(f"--- Starting Pruning (Strategy: {strategy_name}) ---")
    print(f"Target MACs: {target_macs:,.0f}, Target Params: {target_params:,.0f}")

    current_macs, current_params = calculate_macs(model, example_input)
    initial_macs, initial_params = current_macs, current_params  # Keep for logging
    print(f"Initial | MACs: {current_macs:,.0f}, Params: {current_params:,.0f}")

    iteration = 0
    model.eval()  # Eval mode for tracing; BN running stats must not change while pruning

    while (current_macs > target_macs or current_params > target_params) and iteration < max_iterations:
        iteration += 1
        macs_before_step, params_before_step = current_macs, current_params

        # Fresh pruner per iteration (see docstring): ratios are relative to the
        # current layer widths, and iterative_steps=1 allows exactly one step().
        try:
            pruner = strategy['pruner'](
                model,
                example_input,
                importance=importance,
                iterative_steps=1,
                ch_sparsity=step_ch_sparsity,
                root_module_types=[nn.Conv2d],  # Focus pruning on Conv layers
                ignored_layers=ignored_layers,
            )
        except Exception as e:
            print(f"Error while building the pruner (Iter {iteration}): {e}")
            break

        # Gradient-based importances need fresh gradients of the current model.
        if uses_gradients:
            try:
                _gem_compute_gradients(model, example_input, importance)
            except Exception as e:
                print(f"Error during backward pass for importance calc (Iter {iteration}): {e}")
                break  # Safer to stop if backward fails

        # Collect this step's candidates before modifying the model.
        try:
            pruning_groups = list(pruner.step(interactive=True))
        except Exception as e:
            print(f"Error during pruner.step() (Iter {iteration}): {e}")
            break  # Stop if dependency analysis or importance scoring fails

        if not pruning_groups:
            print(f"Iteration {iteration}: Pruner found no more candidates. Stopping.")
            break

        for group in pruning_groups:
            group.prune()

        if uses_gradients:
            model.zero_grad()  # Don't let dummy-loss gradients leak into fine-tuning

        # Recalculate AFTER pruning is applied
        current_macs, current_params = calculate_macs(model, example_input)

        # Negative percentages mean a reduction.
        print(
            f"Iter {iteration: >3}/{max_iterations} | "
            f"MACs: {macs_before_step:,.0f} -> {current_macs:,.0f} "
            f"({_gem_percent_change(macs_before_step, current_macs):+.1f}%) | "
            f"Params: {params_before_step:,.0f} -> {current_params:,.0f} "
            f"({_gem_percent_change(params_before_step, current_params):+.1f}%)"
        )

        if current_macs >= macs_before_step and current_params >= params_before_step:
            print(f"Iteration {iteration}: No reduction in MACs or Params this step. Stopping to prevent loop.")
            break

    if uses_gradients:
        model.zero_grad()  # Covers exits that happen after gradients were computed

    final_macs, final_params = calculate_macs(model, example_input)
    targets_met = final_macs <= target_macs and final_params <= target_params

    print(f"--- Finished Pruning (Strategy: {strategy_name}) ---")
    if iteration >= max_iterations and not targets_met:
        print(f"Warning: Reached maximum pruning iterations ({max_iterations}).")

    print(f"Initial | MACs: {initial_macs:,.0f}, Params: {initial_params:,.0f}")
    print(f"Final   | MACs: {final_macs:,.0f}, Params: {final_params:,.0f}")
    print(f"Target  | MACs: {target_macs:,.0f}, Params: {target_params:,.0f}")

    macs_reduction = (initial_macs - final_macs) / initial_macs * 100 if initial_macs > 0 else 0
    params_reduction = (initial_params - final_params) / initial_params * 100 if initial_params > 0 else 0
    print(f"Reduction | MACs: {macs_reduction:.2f}%, Params: {params_reduction:.2f}%")

    if not targets_met:
        print("Warning: Pruning finished, but target threshold(s) were not fully met.")

    model.eval()  # Ensure model is in eval mode finally
    return model