In [1]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp
from torchstat import stat
from torchinfo import summary
import os


def save_checkpoint(state, filepath, name):
    torch.save(state, os.path.join(filepath, name+'checkpoint.pth'))

model = resnet18(pretrained=True)

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
current_step = 1
prune_amounts = [x / 64 for x in range(48)]

# 0.015625 -> 1 0.03125 -> 2 0.0625 -> 4 0.125 -> 8 0.25 -> 16 0.5 -> 32
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)


for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("Pruning step:", current_step, "multiply–accumulate (macs):", macs, "number of parameters", nparams)
    current_step += 1

    # finetune your model here
    # finetune(model)
    # ...
# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')

# create a new model, e.g. resnet18
new_model = resnet18().eval()

# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
# print(new_model) # This will be a pruned model.

summary(new_model, (3, 224, 224))





layer4.0.downsample.0 [407, 375, 495, 441, 476, 6, 497, 187, 266, 100, 290, 35, 163, 154, 116, 210, 92, 328, 327, 486, 225, 458, 61, 475, 385, 422, 414, 357, 352, 179, 445, 456, 354, 9, 68, 419, 48, 117, 212, 40, 111, 177, 264, 12, 21, 446, 41, 34, 285, 252, 390, 213]
layer3.0.downsample.0 [40, 70, 130, 137, 211, 133, 86, 2, 58, 190, 3, 61, 81, 34, 50, 221, 44, 8, 140, 102, 114, 16, 149, 141, 240, 184]
layer2.0.downsample.0 [53, 24, 18, 55, 15, 28, 5, 76, 125, 21, 22, 47, 109]
conv1 [18, 4, 7, 54, 35, 9, 25]
layer1.0.conv1 [4, 35, 33, 61, 57, 55, 44]
layer1.1.conv1 [3, 38, 18, 22, 34, 12, 57]
layer2.0.conv1 [71, 79, 22, 30, 0, 92, 41, 59, 62, 11, 8, 109, 70]
layer2.1.conv1 [76, 126, 10, 39, 17, 84, 94, 95, 23, 2, 97, 89, 117]
layer3.0.conv1 [64, 172, 174, 157, 168, 89, 150, 147, 248, 145, 143, 203, 201, 178, 179, 197, 107, 108, 110, 134, 187, 21, 230, 40, 24, 25]
layer3.1.conv1 [71, 2, 230, 121, 131, 242, 186, 38, 187, 243, 60, 13, 158, 19, 44, 174, 172, 169, 113, 143, 232, 117, 155, 2

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Conv2d: 1]

In [31]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print (name, list(param.data.shape))

conv1.weight [32, 3, 7, 7]
bn1.weight [32]
bn1.bias [32]
layer1.0.conv1.weight [32, 32, 3, 3]
layer1.0.bn1.weight [32]
layer1.0.bn1.bias [32]
layer1.0.conv2.weight [32, 32, 3, 3]
layer1.0.bn2.weight [32]
layer1.0.bn2.bias [32]
layer1.1.conv1.weight [32, 32, 3, 3]
layer1.1.bn1.weight [32]
layer1.1.bn1.bias [32]
layer1.1.conv2.weight [32, 32, 3, 3]
layer1.1.bn2.weight [32]
layer1.1.bn2.bias [32]
layer2.0.conv1.weight [64, 32, 3, 3]
layer2.0.bn1.weight [64]
layer2.0.bn1.bias [64]
layer2.0.conv2.weight [64, 64, 3, 3]
layer2.0.bn2.weight [64]
layer2.0.bn2.bias [64]
layer2.0.downsample.0.weight [64, 32, 1, 1]
layer2.0.downsample.1.weight [64]
layer2.0.downsample.1.bias [64]
layer2.1.conv1.weight [64, 64, 3, 3]
layer2.1.bn1.weight [64]
layer2.1.bn1.bias [64]
layer2.1.conv2.weight [64, 64, 3, 3]
layer2.1.bn2.weight [64]
layer2.1.bn2.bias [64]
layer3.0.conv1.weight [128, 64, 3, 3]
layer3.0.bn1.weight [128]
layer3.0.bn1.bias [128]
layer3.0.conv2.weight [128, 128, 3, 3]
layer3.0.bn2.weight [128]


In [38]:
# from torchsummary import summary

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = model.cpu()

# summary(model, (3, 224, 224))

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [40]:
# model.zero_grad() # We don't want to store gradient information
# torch.save(model, 'model.pth') # without .state_dict
# model = torch.load('model.pth') # load the pruned model

# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')

# create a new model, e.g. resnet18
new_model = resnet18().eval()

# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
print(new_model) # This will be a pruned model.

AttributeError: Can't pickle local object 'summary.<locals>.register_hook.<locals>.hook'

In [21]:
from torchstat import stat
import torchvision.models as models

# model = models.resnet50()
stat(model, input_size=example_inputs)

AssertionError: 

In [4]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# # 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
# group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[1] )

# # 3. prune all grouped layers that are coupled with model.conv1 (included).
# if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
#     group.prune()


# prune of channel 1 happens from start of conv.weights 6 groups for each channel
channels_to_prune = [x for x in range(48)]
tp.prune_conv_out_channels( model.conv1, idxs = channels_to_prune )
tp.prune_batchnorm_out_channels( model.bn1, idxs= channels_to_prune )
tp.prune_conv_in_channels(model.layer1[0].conv1, idxs= channels_to_prune )

print("Shape:", model.conv1.weight.shape)

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        tp.prune_conv_out_channels( module, idxs = channels_to_prune )
                


# stat(model, (3, 224, 224))
            



# macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
# print("Pruning step:", current_step, "multiply–accumulate (macs):", macs, "number of parameters", nparams)
# tp.prune_conv_out_channels( model.conv1, idxs=[0] )

# tp.prune_batchnorm_out_channels( model.bn1, idxs=[1] )

# tp.prune_conv_in_channels( model.layer1[0].conv1, idxs=[1] )

# model
# model.conv1.weight.shape
# 1 , 3 , 7, 7
model





[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
Shape: torch.Size([16, 3, 7, 7])


ResNet(
  (conv1): Conv2d(3, -32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
 

In [5]:
model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1,3,224,224)

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = pruning_idxs=[2, 6, 9]
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )

group.prune()

[MAdd]: AdaptiveAvgPool2d is not supported!
[Flops]: AdaptiveAvgPool2d is not supported!
[Memory]: AdaptiveAvgPool2d is not supported!
conv1 [2, 6, 9]


In [6]:
model

ResNet(
  (conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = pruning_idxs=[2, 6, 9]
group = DG.get_pruning_group(model.layer1[0].conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )
group.prune()

layer1.0.conv1 [2, 6, 9]


In [10]:
model

ResNet(
  (conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(61, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(61, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [12]:
prune_amounts = [x / 64 for x in range(65)]
prune_amounts

[0.0,
 0.015625,
 0.03125,
 0.046875,
 0.0625,
 0.078125,
 0.09375,
 0.109375,
 0.125,
 0.140625,
 0.15625,
 0.171875,
 0.1875,
 0.203125,
 0.21875,
 0.234375,
 0.25,
 0.265625,
 0.28125,
 0.296875,
 0.3125,
 0.328125,
 0.34375,
 0.359375,
 0.375,
 0.390625,
 0.40625,
 0.421875,
 0.4375,
 0.453125,
 0.46875,
 0.484375,
 0.5,
 0.515625,
 0.53125,
 0.546875,
 0.5625,
 0.578125,
 0.59375,
 0.609375,
 0.625,
 0.640625,
 0.65625,
 0.671875,
 0.6875,
 0.703125,
 0.71875,
 0.734375,
 0.75,
 0.765625,
 0.78125,
 0.796875,
 0.8125,
 0.828125,
 0.84375,
 0.859375,
 0.875,
 0.890625,
 0.90625,
 0.921875,
 0.9375,
 0.953125,
 0.96875,
 0.984375,
 1.0]