In [3]:
import os
import pickle

import torch
import torch_pruning as tp
from torchinfo import summary
from torchvision.models import resnet18, resnet50

# Channel-pruning ratios to sweep: multiples of 1/16 from 0 (unpruned baseline) to 0.75.
sparsities = [0, 0.0625, 0.125, 0.1875, 0.25, 0.3125, 0.375, 0.4375, 0.5, 0.5625, 0.625, 0.6875, 0.75]

# Where pruned checkpoints and per-model summaries are written. Create it up
# front so torch.save / open below don't fail when the directory is missing.
output_dir = "./resnet50"
os.makedirs(output_dir, exist_ok=True)

for sparsity in sparsities:
    # Fresh pretrained model per sparsity level so runs are independent.
    model = resnet50(pretrained=True)
    # Eval mode: the forward passes below (Taylor dummy loss, op counting) run on
    # random inputs. In train mode each pass would fold the statistics of that
    # noise into the pretrained BatchNorm running mean/var. Gradients still
    # flow in eval mode, so Taylor importance is unaffected.
    model.eval()

    # Importance criterion: first-order Taylor expansion (needs gradients, see loop below).
    example_inputs = torch.randn(1, 3, 224, 224)
    imp = tp.importance.TaylorImportance()

    # DO NOT prune the final classifier (the Linear layer with 1000 outputs).
    ignored_layers = [
        m for m in model.modules()
        if isinstance(m, torch.nn.Linear) and m.out_features == 1000
    ]

    iterative_steps = 5  # progressive pruning: reach `sparsity` over this many steps

    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=iterative_steps,
        ch_sparsity=sparsity,  # target fraction of channels removed after the last step
        ignored_layers=ignored_layers,
    )

    base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)

    print("Pruning sparsity:", sparsity)
    for current_step in range(1, iterative_steps + 1):
        if isinstance(imp, tp.importance.TaylorImportance):
            # Taylor importance is estimated from gradients. Clear the previous
            # step's gradients so they don't accumulate into this step's estimate.
            model.zero_grad()
            loss = model(example_inputs).sum()  # dummy loss for TaylorImportance
            loss.backward()  # must run before pruner.step()
        pruner.step()
        macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
        print("Pruning step:", current_step, "multiply–accumulate (macs):", macs, "number of parameters", nparams)

    # torch-pruning's state dict of the pruned model (restore with tp.load_state_dict).
    state_dict = tp.state_dict(model)
    torch.save(state_dict, os.path.join(output_dir, f"{sparsity}_pruned.pth"))

    model_statistics = summary(
        model,
        (1, 3, 224, 224),
        depth=3,
        col_names=["kernel_size", "input_size", "output_size", "num_params", "mult_adds"],
    )
    model_statistics_str = str(model_statistics)

    # NOTE: despite the .txt extension, this file holds a *pickled* str. Read it
    # back with pickle.load, not as plain text. Format kept for existing readers.
    with open(os.path.join(output_dir, f"{sparsity}_statistics.txt"), "wb") as f:
        pickle.dump(model_statistics_str, f)

    print(model)

    # TODO: validate the pruned model (e.g. accuracy on a held-out set).



Pruning sparsity: 0
Pruning step: 1 multiply–accumulate (macs): 4121925096.0 number of parameters 25557032
Pruning step: 2 multiply–accumulate (macs): 4121925096.0 number of parameters 25557032
Pruning step: 3 multiply–accumulate (macs): 4121925096.0 number of parameters 25557032
Pruning step: 4 multiply–accumulate (macs): 4121925096.0 number of parameters 25557032
Pruning step: 5 multiply–accumulate (macs): 4121925096.0 number of parameters 25557032


  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 