# Structured pruning of SSD300-VGG16 with Torch-Pruning

In [1]:
import torch
import torchvision
import torch_pruning as tp
import os, sys
import torch.nn as nn
import random
from torchvision.models.detection import ssd300_vgg16
import warnings 
warnings.filterwarnings('ignore')
from torchinfo import summary

In [2]:
def _collect_ignored_layers(model, model_name):
    """Return the modules that the pruner must leave structurally unchanged.

    Every ``nn.Linear`` with 1000 outputs is kept (presumably an ImageNet
    classifier head, whose width fixes the number of classes). For models
    whose name contains ``'ssd'``, the detection head, ``backbone.extra`` and
    ``backbone.features[-2]`` are kept as well.
    """
    ignored_layers = [
        m for m in model.modules()
        if isinstance(m, nn.Linear) and m.out_features == 1000
    ]
    if 'ssd' in model_name:
        # NOTE(review): features[-2] presumably feeds the detection head (and
        # the backbone's standalone scale parameter) directly, so its output
        # channels must not change -- confirm against torchvision's SSD code.
        ignored_layers += [
            model.head,
            model.backbone.extra,
            model.backbone.features[-2],
        ]
    return ignored_layers


def _run_forward(model, example_inputs, output_transform):
    """Run one no-grad forward pass and return the (optionally transformed) output.

    A ``dict`` of inputs is unpacked as keyword arguments; anything else is
    passed as the single positional argument. ``output_transform`` is applied
    only when it is truthy.
    """
    with torch.no_grad():
        if isinstance(example_inputs, dict):
            out = model(**example_inputs)
        else:
            out = model(example_inputs)
        if output_transform:
            out = output_transform(out)
    return out


def my_prune(model, example_inputs, output_transform, model_name, ratio, model_save_path):
    """Structurally prune ``model`` in place by L1 channel magnitude, then save it.

    Args:
        model: ``nn.Module`` to prune. It is moved to CPU, switched to eval
            mode, and modified in place.
        example_inputs: dummy input used by the pruner to trace the model and
            for a post-pruning sanity forward pass. A ``dict`` is passed as
            keyword arguments, anything else as one positional argument.
        output_transform: optional callable applied to the sanity-pass
            output; pass ``None`` to skip it.
        model_name: used to pick model-specific layers to keep intact (any
            name containing ``'ssd'`` protects the SSD head and extras).
        ratio: pruning ratio handed to the pruner (``0.0`` removes nothing).
        model_save_path: destination for ``torch.save`` (str or path-like).

    Returns:
        ``(params_before, params_after)`` parameter counts.
    """
    ori_size = tp.utils.count_params(model)
    model.cpu().eval()

    # NOTE(review): presumably required so torch_pruning can trace the
    # dependency graph through autograd even if some weights were frozen.
    for p in model.parameters():
        p.requires_grad_(True)

    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs=example_inputs,
        importance=tp.importance.MagnitudeImportance(p=1),  # L1-norm importance
        iterative_steps=1,       # reach the target ratio in a single step
        pruning_ratio=ratio,
        global_pruning=False,    # apply the ratio layer by layer, not globally
        round_to=None,
        unwrapped_parameters=None,
        ignored_layers=_collect_ignored_layers(model, model_name),
        channel_groups={},
    )
    pruner.step()

    # Sanity check: the pruned model must still run end to end.
    _run_forward(model, example_inputs, output_transform)

    params_after_prune = tp.utils.count_params(model)
    print("Params: %s => %s" % (ori_size, params_after_prune))

    # Save the whole module (not just a state_dict) so the pruned
    # architecture is stored together with its weights.
    torch.save(model, model_save_path)
    print(f'model saved to file "{model_save_path}"')

    return ori_size, params_after_prune

---

In [3]:
# Baseline run: a ratio of 0.0 removes nothing, so this records the unpruned
# parameter count and writes the reference checkpoint "<ratio>.pth".
pparaDict = {}

model = torchvision.models.detection.ssd300_vgg16(weights='DEFAULT')

ratio = 0.0
model_save_path = f'{ratio}.pth'
ori_para, after_para = my_prune(
    model,
    torch.randn(1, 3, 300, 300),  # example_inputs: one 300x300 RGB image
    None,                         # output_transform
    'ssd',                        # model_name
    ratio=ratio,
    model_save_path=model_save_path,
)
pparaDict[ratio] = [ori_para, after_para]

Params: 35641826 => 35641826
model saved to file "0.0.pth


In [4]:
# Per-layer output shapes and parameter counts of `model` as left by the
# previous cell (my_prune modifies the model in place).
summary(model, input_size=(1, 3, 300, 300))

Layer (type:depth-idx)                   Output Shape              Param #
SSD                                      [200, 4]                  --
├─GeneralizedRCNNTransform: 1-1          [1, 3, 300, 300]          --
├─SSDFeatureExtractorVGG: 1-2            [1, 256, 1, 1]            512
│    └─Sequential: 2-1                   [1, 512, 38, 38]          --
│    │    └─Conv2d: 3-1                  [1, 54, 300, 300]         1,512
│    │    └─ReLU: 3-2                    [1, 54, 300, 300]         --
│    │    └─Conv2d: 3-3                  [1, 54, 300, 300]         26,298
│    │    └─ReLU: 3-4                    [1, 54, 300, 300]         --
│    │    └─MaxPool2d: 3-5               [1, 54, 150, 150]         --
│    │    └─Conv2d: 3-6                  [1, 108, 150, 150]        52,596
│    │    └─ReLU: 3-7                    [1, 108, 150, 150]        --
│    │    └─Conv2d: 3-8                  [1, 108, 150, 150]        105,084
│    │    └─ReLU: 3-9                    [1, 108, 150, 150]        -