In [2]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp
from torchstat import stat

model = resnet18(pretrained=True)

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
current_step = 1
prune_amounts = [x / 64 for x in range(48)]

# 0.015625 -> 1 0.03125 -> 2 0.0625 -> 4 0.125 -> 8 0.25 -> 16 0.5 -> 32
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.75, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)


for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("Pruning step:", current_step, "multiply–accumulate (macs):", macs, "number of parameters", nparams)
    current_step += 1

    # finetune your model here
    # finetune(model)
    # ...
    
model



layer4.0.downsample.0 [92, 1, 138, 407, 154, 301, 479, 232, 328, 371, 116, 61, 364, 110, 146, 266, 352, 498, 497, 68, 495, 419, 458, 489, 486, 179, 327, 476, 281, 283, 441, 173, 437, 210, 212, 163, 290, 225, 133, 9, 357, 445, 385, 48, 6, 402, 100, 12, 177, 34, 285, 21, 117, 252, 373, 494, 455, 492, 264, 40, 511, 303, 111, 404, 83, 331, 446, 279, 36, 473, 427, 122, 67, 129, 10, 69, 422]
layer3.0.downsample.0 [137, 86, 130, 40, 70, 133, 58, 211, 34, 190, 2, 102, 61, 221, 3, 50, 44, 140, 62, 8, 114, 81, 16, 131, 69, 240, 213, 31, 128, 186, 78, 184, 205, 24, 228, 191, 72, 122, 149]
layer2.0.downsample.0 [53, 24, 55, 18, 76, 28, 109, 34, 125, 15, 127, 23, 37, 22, 5, 21, 47, 12, 61, 27]
conv1 [18, 4, 7, 54, 35, 9, 25, 13, 16, 32]
layer1.0.conv1 [44, 35, 33, 57, 55, 21, 22, 32, 16, 61]
layer1.1.conv1 [3, 18, 38, 22, 59, 34, 31, 53, 45, 12]
layer2.0.conv1 [71, 79, 30, 62, 22, 41, 21, 90, 31, 109, 8, 14, 40, 85, 61, 92, 37, 35, 80, 73]
layer2.1.conv1 [126, 17, 39, 2, 10, 23, 84, 113, 95, 7, 97,

ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# # 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
# group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[1] )

# # 3. prune all grouped layers that are coupled with model.conv1 (included).
# if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
#     group.prune()


# prune of channel 1 happens from start of conv.weights 6 groups for each channel
channels_to_prune = [x for x in range(48)]
tp.prune_conv_out_channels( model.conv1, idxs = channels_to_prune )
tp.prune_batchnorm_out_channels( model.bn1, idxs= channels_to_prune )
tp.prune_conv_in_channels(model.layer1[0].conv1, idxs= channels_to_prune )

print("Shape:", model.conv1.weight.shape)

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        tp.prune_conv_out_channels( module, idxs = channels_to_prune )
                


# stat(model, (3, 224, 224))
            



# macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
# print("Pruning step:", current_step, "multiply–accumulate (macs):", macs, "number of parameters", nparams)
# tp.prune_conv_out_channels( model.conv1, idxs=[0] )

# tp.prune_batchnorm_out_channels( model.bn1, idxs=[1] )

# tp.prune_conv_in_channels( model.layer1[0].conv1, idxs=[1] )

# model
# model.conv1.weight.shape
# 1 , 3 , 7, 7
model





[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
Shape: torch.Size([16, 3, 7, 7])


ResNet(
  (conv1): Conv2d(3, -32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
 

In [5]:
model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1,3,224,224)

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = pruning_idxs=[2, 6, 9]
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )

group.prune()

[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
conv1 [2, 6, 9]


In [6]:
model

ResNet(
  (conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = pruning_idxs=[2, 6, 9]
group = DG.get_pruning_group(model.layer1[0].conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )
group.prune()

layer1.0.conv1 [2, 6, 9]


In [10]:
model

ResNet(
  (conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(61, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(61, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [12]:
prune_amounts = [x / 64 for x in range(65)]
prune_amounts

[0.0,
 0.015625,
 0.03125,
 0.046875,
 0.0625,
 0.078125,
 0.09375,
 0.109375,
 0.125,
 0.140625,
 0.15625,
 0.171875,
 0.1875,
 0.203125,
 0.21875,
 0.234375,
 0.25,
 0.265625,
 0.28125,
 0.296875,
 0.3125,
 0.328125,
 0.34375,
 0.359375,
 0.375,
 0.390625,
 0.40625,
 0.421875,
 0.4375,
 0.453125,
 0.46875,
 0.484375,
 0.5,
 0.515625,
 0.53125,
 0.546875,
 0.5625,
 0.578125,
 0.59375,
 0.609375,
 0.625,
 0.640625,
 0.65625,
 0.671875,
 0.6875,
 0.703125,
 0.71875,
 0.734375,
 0.75,
 0.765625,
 0.78125,
 0.796875,
 0.8125,
 0.828125,
 0.84375,
 0.859375,
 0.875,
 0.890625,
 0.90625,
 0.921875,
 0.9375,
 0.953125,
 0.96875,
 0.984375,
 1.0]