## import necessary libraries for pruning

In [1]:
import torch
import torch.nn as nn
import torch_pruning as tp
from torch import optim
from cnn.resNet.resnet_example import get_data_loaders
import torch
from torch import nn

## Structural Pruning

### Seed Network

In [2]:


class InvertedResidual(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, expansion=6):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        self.use_res_connect = self.stride == 1 and in_channels == out_channels

        # Standard PyTorch layers (NO torch_pruning wrappers needed)
        self.conv1 = nn.Conv2d(in_channels, in_channels * expansion, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels * expansion)
        self.relu = nn.ReLU6(inplace=True)

        self.conv2 = nn.Conv2d(
            in_channels * expansion, in_channels * expansion, kernel_size=3,
            stride=stride, padding=1, groups=in_channels * expansion, bias=False
        )
        self.bn2 = nn.BatchNorm2d(in_channels * expansion)

        self.conv3 = nn.Conv2d(in_channels * expansion, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        identity = x if self.use_res_connect else None

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.use_res_connect:
            return identity + out
        else:
            return out

### Mask Network

In [3]:

class MobileNetV2(nn.Module):
    def __init__(self, num_classes=10):
        super(MobileNetV2, self).__init__()
        # --- Remove mask-related parameters ---
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU6(inplace=True)

        # Define blocks (no mask_index or mask)
        self.block1 = InvertedResidual(32, 16, stride=1)
        self.block2 = InvertedResidual(16, 24, stride=2)
        self.block3 = InvertedResidual(24, 32, stride=2)
        self.block4 = InvertedResidual(32, 64, stride=2)
        self.block5 = InvertedResidual(64, 96, stride=1)
        self.block6 = InvertedResidual(96, 160, stride=2)
        self.block7 = InvertedResidual(160, 320, stride=1)

        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1280, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # --- Remove mask-based block skipping ---
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

### Save model as ONNX

In [4]:
def save_model_as_onnx(model, example_input, output_path):
    # Set model to evaluation mode
    model.eval()
    # Export to ONNX
    torch.onnx.export(
        model,
        example_input,
        output_path,
        export_params=True,
        opset_version=13,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    print(f"✅ Model saved as ONNX to {output_path}")

### Calculate MACs

In [5]:
def calculate_macs(model, example_input):
    macs, params = tp.utils.count_ops_and_params(model, example_input)
    return macs

### Magnitude-Pruner as Importance function for Pruning

In [None]:
def prune_model(model, example_input, target_macs):
    # 1. Define pruning ratio based on MACs
    current_macs = calculate_macs(model, example_input)
    macs_sparsity = 1 - (target_macs / current_macs)

    # todo: can Implement different pruner strategies here based on the importance scoring function,
    # todo: ref: main_imagenet.py file of a torch_pruning repo
    #model.eval()
    # 2. Initialize pruner
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_input,
        importance=tp.importance.MagnitudeImportance(p=2),  # L2 norm
        ch_sparsity= 0.5, # todo: I can set different sparsity for each layer,  I will apply sparsity based on the macs and flops
        root_module_types=[nn.Conv2d, nn.Linear],  # Layers to prune
        ignored_layers=[model.fc],  # DO NOT prune the final classifier!
    )

    # 3. Prune
    pruner.step()

    return model

### Define Strategy-to-Pruner Mapping

In [29]:
from torch_pruning import (
    pruner,
    importance
)

STRATEGIES = {
    'magnitude': {
        'pruner': pruner.MagnitudePruner,
        'importance': importance.MagnitudeImportance(p=2),  # L2 norm
    },
    'bn_scale': {
        'pruner': pruner.BNScalePruner,
        'importance': importance.BNScaleImportance(),
    },
    'group_norm': {
        'pruner': pruner.GroupNormPruner,
        'importance': importance.GroupNormImportance(p=2),  # L2 norm
    },
    'random': {
        'pruner': pruner.RandomPruner,
        'importance': importance.RandomImportance(),
    }
}

AttributeError: module 'torch_pruning.pruner.importance' has no attribute 'GroupNormImportance'

### Prune with different strategies

In [30]:
def prune_with_strategies(model, example_input, target_macs, strategies):
    results = {}
    original_state = model.state_dict().copy()

    for strategy_name in strategies:
        # Reset model to original state
        model.load_state_dict(original_state)

        # Clone model for this strategy
        model_copy = tp.utils.clone_model(model)

        # Initialize pruner for this strategy
        strategy = STRATEGIES[strategy_name]
        pruner_class = strategy['pruner']
        importance = strategy['importance']

        pruner = pruner_class(
            model_copy,
            example_input,
            importance=importance,
            ch_sparsity=0.5,  # Target 50% sparsity
            root_module_types=[nn.Conv2d],  # Target Conv2d layers
            ignored_layers=[model_copy.fc],  # Skip classifier
        )
        pruner.step()  # Execute pruning

        # Calculate metrics
        macs = calculate_macs(model_copy, example_input)
        params = sum(p.numel() for p in model_copy.parameters())
        size_mb = params * 4 / 1e6

        results[strategy_name] = {
            'macs': macs,
            'size_mb': size_mb,
            'accuracy': None
        }

        # Save pruned model
        torch.save(model_copy.state_dict(), f"output/mobilenetv2_{strategy_name}.pth")

    return results

### Compare Pruning Strategies results

In [None]:
def compare_results(results, test_loader, device):
    for strategy_name, metrics in results.items():
        # Rebuild model
        model = MobileNetV2(num_classes=10).to(device)
        model.load_state_dict(torch.load(f"output/mobilenetv2_{strategy_name}.pth"))

        # Test accuracy
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in test_loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        metrics['accuracy'] = 100 * correct / total
        print(f"{strategy_name}:")
        print(f"  MACs: {metrics['macs']:.2e}")
        print(f"  Size: {metrics['size_mb']:.2f} MB")
        print(f"  Accuracy: {metrics['accuracy']:.2f}%")
        print("-" * 40)

### Print Model Metrics

In [6]:
def print_model_metrics(model, example_input, label):
    macs = calculate_macs(model, example_input)
    params = sum(p.numel() for p in model.parameters())
    size_mb = params * 4 / 1e6  # 4 bytes per float32
    print(f"{label}: MACs={macs:.2e}, Size={size_mb:.2f} MB")

### Model Training

In [7]:
def train_model(model, train_loader, criterion, optimizer, device, num_epochs=20):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for data in train_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # todo: try to set different optimizer after pruning , early stopping, scheduler, # of epochs for fine tuning, pruning
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f"Epoch {epoch+1}: Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.2f}%")
    return model

### Main function

In [None]:
import pprint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MobileNetV2(num_classes=10).to(device)
example_input = torch.randn(1, 3, 32, 32).to(device)  # CIFAR-10 input shape

DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_input, verbose=True)
all_groups = list(DG.get_all_groups())
#print("Number of Groups: %d"%len(all_groups))
#print("The last Group:", all_groups[-1])

for key in all_groups:
    print(f"{key}")

# tp.utils.draw_dependency_graph(DG, save_as='output/draw_dep_graph.png', title=None)
# tp.utils.draw_groups(DG, save_as='output/draw_groups.png', title=None)
# tp.utils.draw_computational_graph(DG, save_as='output/draw_comp_graph.png', title=None)


In [None]:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MobileNetV2(num_classes=10).to(device)
    example_input = torch.randn(1, 3, 32, 32).to(device)  # CIFAR-10 input shape

    # Save initial model (before pruning)
    torch.save(model.state_dict(), "output/mobilenetv2_before_pruning.pth")
    save_model_as_onnx(model, example_input, "output/onnx/mobilenetv2_before_pruning.onnx")

    # Train and prune
    train_loader, test_loader = get_data_loaders('./data')
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)
    model = train_model(model, train_loader, criterion, optimizer, device, num_epochs=10)

    # Prune
    initial_macs = calculate_macs(model, example_input)
    target_macs = initial_macs // 2
    print_model_metrics(model, example_input, "Before Pruning")
    model = prune_model(model, example_input, target_macs)
    print_model_metrics(model, example_input, "After Pruning")
    # Save pruned model
    torch.save(model.state_dict(), "output/mobilenetv2_after_pruning.pth")
    save_model_as_onnx(model, example_input, "output/mobilenetv2_after_pruning.onnx")

    # Fine-tune and save final model
    model = train_model(model, train_loader, criterion, optimizer, device, num_epochs=10)
    torch.save(model.state_dict(), "output/mobilenetv2_final.pth")
    save_model_as_onnx(model, example_input, "output/mobilenetv2_final.onnx")

if __name__ == "__main__":
    main()

### Visualize the model

In [11]:
import netron
netron.start("output/mobilenetv2_before_pruning.onnx")  # Before pruning
netron.start("output/mobilenetv2_after_pruning.onnx")   # After pruning

Serving 'output/mobilenetv2_before_pruning.onnx' at http://localhost:8080
Serving 'output/mobilenetv2_after_pruning.onnx' at http://localhost:8081


('localhost', 8081)