## Import necessary libraries for pruning

In [1]:

import torch
import torch.nn as nn
import torch_pruning as tp
from torch import optim
from cnn.resNet.resnet_example import get_data_loaders
import torch
from torch import nn

### Seed Network

In [9]:
class InvertedResidual(nn.Module):
    """Inverted residual block built from plain PyTorch layers.

    The input is widened by a 1x1 convolution, filtered by a 3x3 depthwise
    convolution (``groups`` equals the widened channel count) and projected
    to ``out_channels`` by a second 1x1 convolution. Every convolution is
    followed by batch normalisation; ReLU6 follows the first two only.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the block.
        stride: Stride of the depthwise convolution.
        expansion: Factor by which ``in_channels`` is widened internally.
    """

    def __init__(self, in_channels, out_channels, stride=1, expansion=6):
        super().__init__()
        hidden_channels = in_channels * expansion

        self.stride = stride
        # The skip connection is only added when the block keeps both the
        # stride at 1 and the channel count unchanged.
        self.use_res_connect = stride == 1 and in_channels == out_channels

        # 1x1 expansion
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.relu = nn.ReLU6(inplace=True)

        # 3x3 depthwise convolution: one filter group per hidden channel
        self.conv2 = nn.Conv2d(
            hidden_channels,
            hidden_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=hidden_channels,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(hidden_channels)

        # 1x1 projection; no activation is applied after it in forward()
        self.conv3 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply expand -> depthwise -> project, adding ``x`` back when allowed."""
        expanded = self.relu(self.bn1(self.conv1(x)))
        filtered = self.relu(self.bn2(self.conv2(expanded)))
        projected = self.bn3(self.conv3(filtered))

        if not self.use_res_connect:
            return projected
        return x + projected

### MobileNetV2 network (mask logic removed)

In [10]:
class MobileNetV2(nn.Module):
    """MobileNetV2-style classifier without any mask-based block skipping.

    Layout: a strided 3x3 stem convolution, seven ``InvertedResidual``
    blocks, a 1x1 convolution to 1280 channels, global average pooling and
    a linear classifier.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Stem: 3 input channels -> 32 channels at stride 2
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU6(inplace=True)

        # Inverted residual blocks; each one consumes the previous block's
        # output channel count.
        self.block1 = InvertedResidual(32, 16, stride=1)
        self.block2 = InvertedResidual(16, 24, stride=2)
        self.block3 = InvertedResidual(24, 32, stride=2)
        self.block4 = InvertedResidual(32, 64, stride=2)
        self.block5 = InvertedResidual(64, 96, stride=1)
        self.block6 = InvertedResidual(96, 160, stride=2)
        self.block7 = InvertedResidual(160, 320, stride=1)

        # Head: 1x1 convolution, global pooling, classifier
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1280, num_classes)

    def forward(self, x):
        """Return the classifier outputs for a batch of 3-channel inputs."""
        x = self.relu(self.bn1(self.conv1(x)))

        # Every block is always executed, in order.
        blocks = (
            self.block1,
            self.block2,
            self.block3,
            self.block4,
            self.block5,
            self.block6,
            self.block7,
        )
        for block in blocks:
            x = block(x)

        x = self.relu(self.bn2(self.conv2(x)))
        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [11]:
def save_model_as_onnx(model, example_input, output_path):
    """Export ``model`` to an ONNX file at ``output_path``.

    The model is switched to evaluation mode before the export and is left
    in that mode. The export uses opset 13, names the graph input ``input``
    and the output ``output``, and marks axis 0 of both as a dynamic
    ``batch_size`` axis.

    Args:
        model: Module to export.
        example_input: Input passed to ``torch.onnx.export`` for the model.
        output_path: Destination of the ONNX file.
    """
    model.eval()

    export_options = {
        'export_params': True,
        'opset_version': 13,
        'input_names': ['input'],
        'output_names': ['output'],
        'dynamic_axes': {
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    }
    torch.onnx.export(model, example_input, output_path, **export_options)

    print(f"✅ Model saved as ONNX to {output_path}")

In [13]:
def calculate_macs(model, example_input):
    """Return the operation count torch_pruning reports for ``model``.

    ``tp.utils.count_ops_and_params`` yields an (ops, params) pair for the
    given example input; only the first element is returned.
    """
    macs, _param_count = tp.utils.count_ops_and_params(model, example_input)
    return macs

### Compare results of different pruning strategies

In [5]:
def compare_results(results):
    """Print a fixed-width table comparing the metrics of each strategy.

    Every row uses the same column widths as the header (12 / 12 / 10 / 12
    characters), so the ``|`` separators line up.

    Args:
        results: Mapping of strategy name to a dict holding the numeric
            entries ``'macs'``, ``'size_mb'`` and ``'accuracy'``.
    """
    print("\n=== Pruning Strategy Comparison ===")
    header = f"{'Strategy':<12} | {'MACs':<12} | {'Size (MB)':<10} | {'Accuracy (%)':<12}"
    print(header)
    print("-" * len(header))
    for strategy, metrics in results.items():
        # Field widths mirror the header; numbers are right-aligned except
        # the MAC count, which sits under its left-aligned heading.
        print(
            f"{strategy:<12} | {metrics['macs']:<12.2e} | "
            f"{metrics['size_mb']:>10.2f} | {metrics['accuracy']:>12.2f}"
        )

### Load Model

In [6]:
def load_model(model, path):
    """Load the state dict stored at ``path`` into ``model`` and return it.

    The checkpoint tensors are mapped onto the device the model's parameters
    already live on (CPU if the model has no parameters). This lets a
    checkpoint written on a CUDA machine be restored on a CPU-only one.

    Args:
        model: Module whose parameters are overwritten in place.
        path: File produced by ``torch.save(model.state_dict(), path)``.

    Returns:
        The same ``model`` object, with the loaded weights.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    state_dict = torch.load(path, map_location=target_device)
    model.load_state_dict(state_dict)
    return model

### Utility function to save the model

In [8]:
import os

def save_model(model, path, example_input=None):
    """Save ``model``'s state dict to ``path`` and optionally export ONNX.

    Args:
        model: Module whose ``state_dict()`` is written.
        path: Destination file. Missing parent directories are created; a
            bare filename (no directory part) is written to the current
            working directory.
        example_input: When given, the model is also exported to ONNX next
            to the checkpoint, using the same file name with the extension
            replaced by ``.onnx``.
    """
    parent_dir = os.path.dirname(path)
    # os.makedirs('') raises, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)

    if example_input is not None:
        # Swap only the extension; a plain str.replace would also rewrite
        # '.pth' inside directory names and would leave an extension-less
        # path unchanged, making the ONNX export overwrite the checkpoint.
        onnx_path = os.path.splitext(path)[0] + '.onnx'
        save_model_as_onnx(model, example_input, onnx_path)

### Evaluate the model

In [None]:
def evaluate_model(model, test_loader, example_input, device):
    """Measure MACs, parameter size and classification accuracy of ``model``.

    The model is put into evaluation mode and left there.

    Args:
        model: Module to evaluate; it is not moved to ``device`` here.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        example_input: Input handed to ``calculate_macs``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Dict with ``'macs'``, ``'size_mb'`` (parameter count * 4 bytes, in
        units of 1e6 bytes) and ``'accuracy'`` in percent. Accuracy is
        ``0.0`` when ``test_loader`` yields no samples.
    """
    model.eval()

    macs = calculate_macs(model, example_input)
    params = sum(p.numel() for p in model.parameters())
    size_mb = params * 4 / 1e6  # 4 bytes per parameter, i.e. float32 is assumed

    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = (d.to(device) for d in data)
            # Index of the highest output per sample is the predicted class.
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100 * correct / total if total else 0.0

    return {
        'macs': macs,
        'size_mb': size_mb,
        'accuracy': accuracy,
    }

### Prune the model

In [None]:
def prune_model(model, example_input, target_macs, strategy):
    """Prune ``model`` in place until its MAC count reaches ``target_macs``.

    Args:
        model: Module to prune; its ``fc`` layer is excluded from pruning.
        example_input: Input used to build the pruner and to count MACs.
        target_macs: Pruning stops once the MAC count is at or below this.
        strategy: Dict with a ``'pruner'`` callable (a torch_pruning pruner
            class) and the ``'importance'`` object passed to it.

    Returns:
        The same ``model`` object after pruning. If the pruner stops making
        progress before the target is reached, the model is returned as it
        is and a message is printed.
    """
    pruner = strategy['pruner'](
        model,
        example_input,
        importance=strategy['importance'],
        ch_sparsity=0.5,  # Initial sparsity
        root_module_types=[nn.Conv2d],
        ignored_layers=[model.fc],
    )

    current_macs = calculate_macs(model, example_input)
    while current_macs > target_macs:
        pruner.step()
        new_macs = calculate_macs(model, example_input)
        if new_macs >= current_macs:
            # A step that removes nothing (e.g. the pruner has used up its
            # configured steps) would otherwise make this loop spin forever.
            print(
                f"Pruning stopped at {new_macs:.2e} MACs; "
                f"target {target_macs:.2e} was not reached."
            )
            break
        current_macs = new_macs

    return model

### Train the model

In [None]:
def train_model(model, train_loader, criterion, optimizer, device, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs and print per-epoch metrics.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It
            does not need to support ``len()``.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer already bound to the model's parameters.
        device: Device each batch is moved to before the forward pass.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object after training. An epoch in which the
        loader yields nothing is reported with loss and accuracy of zero.
    """
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0  # counted here so loaders without __len__ work too

        for data in train_loader:
            inputs, labels = (d.to(device) for d in data)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Index of the highest output per sample is the predicted class.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
            num_batches += 1

        # Avoid dividing by zero when the loader produced no batches.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_acc = 100 * correct / total if total else 0.0
        print(f"Epoch {epoch+1}/{num_epochs}: Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.2f}%")
    return model

### Main workflow

In [None]:
def main():
    """Train a baseline MobileNetV2, then prune, fine-tune and compare.

    Steps:
      1. Train the baseline and save it, unless a saved baseline already
         exists under the output directory.
      2. For every configured strategy: reload the baseline, prune it to
         the target MAC count, save it, fine-tune it, evaluate it and save
         the final weights.
      3. Print a comparison table of all strategies.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configuration. Every 'importance' entry is an importance *instance*;
    # the 'pruner' entry is the pruner class that prune_model instantiates.
    config = {
        'strategies': {
            'magnitude': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.MagnitudeImportance(p=2),
            },
            'bn_scale': {
                'pruner': tp.pruner.BNScalePruner,
                'importance': tp.importance.BNScaleImportance(),
            },
            'group_norm': {
                'pruner': tp.pruner.GroupNormPruner,
                'importance': tp.importance.GroupNormImportance(p=2),
            },
            'random': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.RandomImportance(),
            },
            'Taylor': {
                'pruner': tp.pruner.MagnitudePruner,
                'importance': tp.importance.TaylorImportance(),
            },
        },
        'target_macs_sparsity': 0.5,  # 50% MACs reduction
        'train_epochs': 10,
        'fine_tune_epochs': 10,
        'data_dir': './data',
        'output_dir': './output/strategies',
    }

    # Initialize model and data
    model = MobileNetV2(num_classes=10).to(device)
    train_loader, test_loader = get_data_loaders(config['data_dir'])
    example_input = torch.randn(1, 3, 32, 32).to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    initial_model_path = os.path.join(config['output_dir'], "mobilenetv2_initial.pth")

    if not os.path.exists(initial_model_path):
        # 1. Initial training (skipped when a saved baseline is found)
        model = train_model(
            model=model,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optim.Adam(model.parameters(), lr=0.001),
            device=device,
            num_epochs=config['train_epochs'],
        )
        save_model(model, initial_model_path, example_input)

    # 2. Pruning and evaluation workflow
    results = {}
    # The MAC count depends on the architecture only, so the in-memory model
    # is a valid reference even when the baseline weights came from disk.
    initial_macs = calculate_macs(model, example_input)
    target_macs = initial_macs * config['target_macs_sparsity']

    for strategy_name, strategy in config['strategies'].items():
        # 2a. Prepare fresh model for each strategy
        model_copy = load_model(MobileNetV2(num_classes=10).to(device), initial_model_path)

        if isinstance(strategy['importance'], tp.importance.TaylorImportance):
            # Taylor importance is estimated from weight gradients, so run
            # one backward pass on a training batch to populate `.grad`
            # before the pruner takes its step.
            inputs, labels = (d.to(device) for d in next(iter(train_loader)))
            criterion(model_copy(inputs), labels).backward()

        # 2b. Perform pruning
        pruned_model = prune_model(
            model=model_copy,
            example_input=example_input,
            target_macs=target_macs,
            strategy=strategy,
        )

        # 2c. Save pruned model
        # NOTE(review): pruning changes layer shapes, so this state dict
        # presumably cannot be loaded back into a fresh MobileNetV2 with
        # load_model - confirm whether the whole module should be saved.
        pruned_path = os.path.join(config['output_dir'], f"mobilenetv2_{strategy_name}_pruned.pth")
        save_model(pruned_model, pruned_path, example_input)

        # 2d. Fine-tune
        fine_tuned_model = train_model(
            model=pruned_model,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optim.Adam(pruned_model.parameters(), lr=0.001),
            device=device,
            num_epochs=config['fine_tune_epochs'],
        )

        # 2e. Evaluate
        results[strategy_name] = evaluate_model(
            model=fine_tuned_model,
            test_loader=test_loader,
            example_input=example_input,
            device=device,
        )

        # 2f. Save final model
        final_path = os.path.join(config['output_dir'], f"mobilenetv2_{strategy_name}_final.pth")
        save_model(fine_tuned_model, final_path, example_input)

    # 3. Compare results
    compare_results(results)
    print("Workflow completed successfully!")