# Magnitude pruning of Faster RCNN ResNet50

This code loads a Faster R-CNN ResNet-50 model, prunes it using magnitude-based pruning, saves the pruned model, and evaluates the pruned model. It also calculates and displays the size of the pruned model file. Detailed explanations for each variable and functionality are provided below.

1.   List item
2.   List item



In [1]:
# @title Torch-Pruning
!pip install --upgrade torch_pruning




[notice] A new release of pip is available: 23.1.2 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import torch
import torch_pruning as tp

### 1. Choose a model to prune

In [3]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the Faster R-CNN ResNet-50 model trained on COCO
# base_model = fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
# base_model.load_state_dict(torch.load("/content/drive/Othercomputers/My Laptop/F_CNN/model/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth"))
model  = fasterrcnn_resnet50_fpn(weights='FasterRCNN_ResNet50_FPN_Weights.COCO_V1', weights_backbone=None).to(device)
model1 = model

### 2. Prepare a pruner

Note: When we prune a model like ConvNext and ViT, torch-pruning will show a warning about unwrapped parameters. This warning is caused by `nn.Parameter` that does not belong to any standard layers such as `nn.Conv2d`, `nn.BatchNorm`. By default, Torch-Pruning will automatically prune the last non-singleton dim of these parameters. If you want to customize this behaviour, please provide an `unwrapped_parameters` list as the following example.

In [5]:
# @title Pruning using Magnitude Pruner

"""Functionality: This code  applies magnitude-based pruning to the model.
It then saves the pruned model's state dictionary and evaluates the pruned model.
Finally, it calculates and displays the size of the pruned model file.

Variables:
- example_inputs: Example inputs for magnitude-based pruning.
- imp: Importance metric for pruning.
- ignored_layers: List of final classifier layers to be ignored during pruning.
- pruner: MagnitudePruner object for pruning the model.
- records: List to store pruning records for each layer.
- g: Pruning group for each step in the pruning process.
- dep: Dependency object containing layer information.
- idxs: Indices of pruned channels in the layer.
- layer: Layer being pruned.
- pruning_fn: Pruning function applied to the layer.
- file_name: Path to the saved pruned model file.
- file_stats: Statistics of the saved model file.
"""


example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)
ignored_layers = []

# DO NOT prune the final classifier!
for m in model1.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)

pruner = tp.pruner.MagnitudePruner(
    model1,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

records = []
for g in pruner.step(interactive=True):
    dep, idxs = g[0]
    layer = dep.layer
    pruning_fn = dep.pruning_fn
    records.append((layer, idxs, pruning_fn))
    g.prune()

for rec in records:
    print(rec)
    print("")



(Linear(in_features=512, out_features=182, bias=True), [0, 1, 2, 3, 29, 31, 40, 41, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 83, 87, 91, 92, 93, 94, 95, 97, 99, 101, 103, 104, 105, 106, 107, 116, 117, 118, 119, 120, 121, 122, 123, 129, 131, 136, 137, 138, 139, 140, 144, 145, 146, 147, 148, 149, 150, 151, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 166, 172, 173, 175, 180, 181, 182, 183, 184, 208, 212, 213, 214, 215, 216, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 264, 265, 266, 267, 272, 273, 274, 275, 276, 277, 278, 279, 280, 282, 284, 285, 286, 287, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 326, 328, 330, 331, 332, 333, 334, 335, 340, 341, 342, 343, 344, 348, 349, 350, 351, 356, 357, 358, 359, 360, 361, 362, 363], <bound method LinearPruner.prune_out_channels of <torch_pruning.pruner.

In [7]:
model.zero_grad() # We don't want to store gradient information
torch.save(model.state_dict(),'Mg_fasterrcnn_resnet50_pruned_model_2.pth')
# model.load_state_dict(torch.load('/content/fasterrcnn_resnet50_pruned_model_2.pth'))
model.load_state_dict(torch.load('Mg_fasterrcnn_resnet50_pruned_model_2.pth'), strict=False)
model.eval()

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [8]:
# get file size in python
import os

file_name = "Mg_fasterrcnn_resnet50_pruned_model_2.pth"

file_stats = os.stat(file_name)

print(file_stats)
print(f'File Size in Bytes is {file_stats.st_size}')
print(f'File Size in MegaBytes is {file_stats.st_size / (1024 * 1024)}')

os.stat_result(st_mode=33206, st_ino=5348024558172762, st_dev=2318610902, st_nlink=2, st_uid=0, st_gid=0, st_size=44939950, st_atime=1707398929, st_mtime=1707398917, st_ctime=1707398917)
File Size in Bytes is 44939950
File Size in MegaBytes is 42.858076095581055
