## import necessary libraries for pruning

In [None]:
import torch
import torch.nn as nn
import torch_pruning as tp
from torch import optim
from cnn.resNet.resnet_example import get_data_loaders
import torch
from torch import nn

#### asuming model size

#### Training model

### plotting function

#### Main fucntion

## Structural Pruning

### Seed Network

In [5]:


class InvertedResidual(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, expansion=6):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        self.use_res_connect = self.stride == 1 and in_channels == out_channels

        # Standard PyTorch layers (NO torch_pruning wrappers needed)
        self.conv1 = nn.Conv2d(in_channels, in_channels * expansion, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels * expansion)
        self.relu = nn.ReLU6(inplace=True)

        self.conv2 = nn.Conv2d(
            in_channels * expansion, in_channels * expansion, kernel_size=3,
            stride=stride, padding=1, groups=in_channels * expansion, bias=False
        )
        self.bn2 = nn.BatchNorm2d(in_channels * expansion)

        self.conv3 = nn.Conv2d(in_channels * expansion, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        identity = x if self.use_res_connect else None

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.use_res_connect:
            return identity + out
        else:
            return out

### Mask Network

In [2]:

class MobileNetV2(nn.Module):
    def __init__(self, num_classes=10):
        super(MobileNetV2, self).__init__()
        # --- Remove mask-related parameters ---
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU6(inplace=True)

        # Define blocks (no mask_index or mask)
        self.block1 = InvertedResidual(32, 16, stride=1)
        self.block2 = InvertedResidual(16, 24, stride=2)
        self.block3 = InvertedResidual(24, 32, stride=2)
        self.block4 = InvertedResidual(32, 64, stride=2)
        self.block5 = InvertedResidual(64, 96, stride=1)
        self.block6 = InvertedResidual(96, 160, stride=2)
        self.block7 = InvertedResidual(160, 320, stride=1)

        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1280, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # --- Remove mask-based block skipping ---
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

### Save model as ONNX

In [9]:
def save_model_as_onnx(model, example_input, output_path):
    # Set model to evaluation mode
    model.eval()
    # Export to ONNX
    torch.onnx.export(
        model,
        example_input,
        output_path,
        export_params=True,
        opset_version=13,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    print(f"✅ Model saved as ONNX to {output_path}")

### Calculate MACs

In [None]:
def calculate_macs(model, example_input):
    macs, params = tp.utils.count_ops_and_params(model, example_input)
    return macs

### Main function for Pruning

In [None]:
def prune_model(model, example_input, target_macs):
    # 1. Define pruning ratio based on MACs
    current_macs = calculate_macs(model, example_input)
    macs_sparsity = 1 - (target_macs / current_macs)

    # todo: can Implement different pruner strategies here based on the importance scoring function,
    # todo: ref: main_imagenet.py file of a torch_pruning repo
    # 2. Initialize pruner
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_input,
        importance=tp.importance.MagnitudeImportance(p=2),  # L2 norm
        ch_sparsity= 0.5, # todo: I can set different sparsity for each layer,  I will apply sparsity based on the macs and flops
        root_module_types=[nn.Conv2d, nn.Linear],  # Layers to prune
        ignored_layers=[model.fc],  # DO NOT prune the final classifier!
    )

    # 3. Prune
    pruner.step()

    return model

### Print Model Metrics

In [None]:
def print_model_metrics(model, example_input, label):
    macs = calculate_macs(model, example_input)
    params = sum(p.numel() for p in model.parameters())
    size_mb = params * 4 / 1e6  # 4 bytes per float32
    print(f"{label}: MACs={macs:.2e}, Size={size_mb:.2f} MB")

### Model Training

In [None]:
def train_model(model, train_loader, criterion, optimizer, device, num_epochs=20):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for data in train_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # todo: try to set different optimizer after pruning , early stopping, scheduler, # of epochs for fine tuning, pruning
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f"Epoch {epoch+1}: Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.2f}%")
    return model

### Main function

In [10]:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MobileNetV2(num_classes=10).to(device)
    example_input = torch.randn(1, 3, 32, 32).to(device)  # CIFAR-10 input shape

    # Save initial model (before pruning)
    torch.save(model.state_dict(), "output/mobilenetv2_before_pruning.pth")
    save_model_as_onnx(model, example_input, "output/onnx/mobilenetv2_before_pruning.onnx")

    # Train and prune
    train_loader, test_loader = get_data_loaders('./data')
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)
    model = train_model(model, train_loader, criterion, optimizer, device, num_epochs=10)

    # Prune
    initial_macs = calculate_macs(model, example_input)
    target_macs = initial_macs // 2
    print_model_metrics(model, example_input, "Before Pruning")
    model = prune_model(model, example_input, target_macs)
    print_model_metrics(model, example_input, "After Pruning")
    # Save pruned model
    torch.save(model.state_dict(), "output/mobilenetv2_after_pruning.pth")
    save_model_as_onnx(model, example_input, "output/mobilenetv2_after_pruning.onnx")

    # Fine-tune and save final model
    model = train_model(model, train_loader, criterion, optimizer, device, num_epochs=10)
    torch.save(model.state_dict(), "output/mobilenetv2_final.pth")
    save_model_as_onnx(model, example_input, "output/mobilenetv2_final.onnx")

if __name__ == "__main__":
    main()

✅ Model saved as ONNX to output/mobilenetv2_before_pruning.onnx
Using dataset directory: /home/muis/thesis/github-repo/master-thesis/cnn/mobile_net_v2/data
Epoch 1: Loss=1.6700, Accuracy=38.68%
Epoch 2: Loss=1.2926, Accuracy=53.77%
Epoch 3: Loss=1.1241, Accuracy=60.28%
Epoch 4: Loss=0.9963, Accuracy=65.26%
Epoch 5: Loss=0.9004, Accuracy=68.39%
Epoch 6: Loss=0.8265, Accuracy=71.05%
Epoch 7: Loss=0.7562, Accuracy=73.67%
Epoch 8: Loss=0.7054, Accuracy=75.24%
Epoch 9: Loss=0.6582, Accuracy=77.14%
Epoch 10: Loss=0.6139, Accuracy=78.47%
Before Pruning: MACs=6.06e+06, Size=4.68 MB
After Pruning: MACs=1.84e+06, Size=1.22 MB
✅ Model saved as ONNX to output/mobilenetv2_after_pruning.onnx




Epoch 1: Loss=4.2010, Accuracy=20.21%
Epoch 2: Loss=4.1631, Accuracy=20.47%
Epoch 3: Loss=4.1648, Accuracy=20.54%
Epoch 4: Loss=4.1589, Accuracy=20.54%
Epoch 5: Loss=4.1633, Accuracy=20.63%
Epoch 6: Loss=4.1515, Accuracy=20.66%
Epoch 7: Loss=4.1579, Accuracy=20.66%
Epoch 8: Loss=4.1605, Accuracy=20.71%
Epoch 9: Loss=4.1630, Accuracy=20.70%
Epoch 10: Loss=4.1623, Accuracy=20.57%
✅ Model saved as ONNX to output/mobilenetv2_final.onnx


### Visualize the model

In [11]:
import netron
netron.start("output/mobilenetv2_before_pruning.onnx")  # Before pruning
netron.start("output/mobilenetv2_after_pruning.onnx")   # After pruning

Serving 'output/mobilenetv2_before_pruning.onnx' at http://localhost:8080
Serving 'output/mobilenetv2_after_pruning.onnx' at http://localhost:8081


('localhost', 8081)