# Quick Start

This section walks through basic, self-contained examples that cover the main features of Torch-Pruning.

In [2]:
import warnings
warnings.filterwarnings('ignore')
import sys, os
sys.path.append(os.path.abspath("../"))

import torch
from torchvision.models import resnet18
import torch_pruning as tp

### Method 1. Pruning with DepGraph

``DependencyGraph`` is the cornerstone of Torch-Pruning: it automatically identifies and groups all layers with inter-dependencies. In structural pruning, two layers that depend on each other must be pruned simultaneously, so pruning a complicated model requires handling those layers carefully. The following example shows the pipeline for pruning a single layer in a ResNet-18.
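
To see why coupled layers must be pruned together, here is a minimal sketch in plain PyTorch. It does not use Torch-Pruning; the two toy layers `conv_a` and `conv_b` are made up for illustration. Removing output channels from one convolution breaks the next one until its input channels are trimmed to match.

```python
import torch
import torch.nn as nn

# Two toy layers (hypothetical): conv_b consumes the output of conv_a.
conv_a = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
conv_b = nn.Conv2d(16, 8, kernel_size=3, padding=1, bias=False)

pruned = {2, 6, 9}
keep = [i for i in range(conv_a.out_channels) if i not in pruned]

# Remove three filters from conv_a (dim 0 of the weight is out_channels).
conv_a.weight = nn.Parameter(conv_a.weight.detach()[keep].clone())
conv_a.out_channels = len(keep)

x = torch.randn(1, 3, 32, 32)
try:
    conv_b(conv_a(x))
    mismatch = False
except RuntimeError:
    mismatch = True  # conv_b still expects 16 input channels

# Fix the dependency: drop the same channels from conv_b's input (dim 1).
conv_b.weight = nn.Parameter(conv_b.weight.detach()[:, keep].clone())
conv_b.in_channels = len(keep)

out = conv_b(conv_a(x))
print(mismatch, tuple(out.shape))
```

In a real network the set of coupled layers also includes batch normalization and residual branches, which is the bookkeeping that ``DependencyGraph`` automates.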

In [17]:
# 0. prepare your model and example inputs
model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1,3,224,224)

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
# Arguments: the layer to prune (conv1), the pruning function, and the channel indices.
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs)

# 3. prune all grouped layers that are coupled with model.conv1
if DG.check_pruning_group(pruning_group):
    pruning_group.prune()

After invoking the ``.prune()`` method, the model is pruned in place. Printing the model shows that several layers, such as "model.conv1", "model.bn1", and "model.layer1[0].conv1", have been pruned by Torch-Pruning.

In [18]:
print("After pruning:")
print(model)

After pruning:
ResNet(
  (conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(i

Let's inspect the pruning group. The results will show how a pruning operation triggers (=>) another one.

In [19]:
print(pruning_group)


--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), len(idxs)=3
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), len(idxs)=3
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), len(idxs)=3
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), len(idxs)=3
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWis
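
The "triggers" relation in this printout can be pictured as a graph traversal. The sketch below is a toy illustration only, not Torch-Pruning's implementation: the `triggers` table is hand-written and its node names are made up, loosely following the ResNet stem shown above.

```python
from collections import deque

def collect_group(triggers, root):
    """Breadth-first walk: gather every pruning op reachable from `root`."""
    group, seen, queue = [root], {root}, deque([root])
    while queue:
        node = queue.popleft()
        for nxt in triggers.get(node, []):
            if nxt not in seen:
                seen.add(nxt)
                group.append(nxt)
                queue.append(nxt)
    return group

# Hypothetical trigger table: "pruning A forces pruning B".
triggers = {
    "conv1:out": ["bn1:out"],
    "bn1:out": ["layer1.0.conv1:in", "layer1.0.add"],
    "layer1.0.add": ["layer1.0.bn2:out", "layer1.1.conv1:in"],
    "layer1.0.bn2:out": ["layer1.0.conv2:out"],
}

group = collect_group(triggers, "conv1:out")
for i, op in enumerate(group):
    print(f"[{i}] {op}")
```

The residual connection is what pulls ``layer1.0.conv2`` into the same group as ``conv1``, even though the two layers are not adjacent.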

You can also get all groups from ``DependencyGraph`` using the ``get_all_groups`` method. 

In [20]:
all_groups = list(DG.get_all_groups())
print("Number of Groups: %d"%len(all_groups))
print(f"The first Group: {all_groups[0]}")
print("The last Group:", all_groups[-1])

Number of Groups: 13
The first Group: 
--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on fc (Linear(in_features=512, out_features=1000, bias=True)) => prune_out_channels on fc (Linear(in_features=512, out_features=1000, bias=True)), len(idxs)=1000
--------------------------------

The last Group: 
--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on layer4.1.conv1 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => prune_out_channels on layer4.1.conv1 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), len(idxs)=512
[1] prune_out_channels on layer4.1.conv1 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => prune_out_channels on layer4.1.bn1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), len(idxs)=512
[2] prune_out_channels on 

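How to scan all groups (advanced):
A model may contain multiple groups. We can use ``DG.get_all_groups(ignored_layers, root_module_types)`` to scan all prunable groups in order. Each group starts from a layer whose type matches one of the ``nn.Module`` classes in ``root_module_types``. The ``ignored_layers`` argument skips layers that should not be pruned. For example, we can skip the first convolutional layer of a ResNet model.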

In [25]:
for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[torch.nn.Conv2d, torch.nn.Linear]):
    # Process the groups in order
    idxs = [2, 4, 6]
    group.prune(idxs=idxs)
    print(group)



--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on fc (Linear(in_features=-12, out_features=-9, bias=True)) => prune_out_channels on fc (Linear(in_features=-12, out_features=-9, bias=True)), len(idxs)=0
--------------------------------


--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on layer4.0.downsample.0 (Conv2d(-12, -15, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.0 (Conv2d(-12, -15, kernel_size=(1, 1), stride=(2, 2), bias=False)), len(idxs)=0
--------------------------------


--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on layer3.0.downsample.0 (Conv2d(-12, -15, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer3.0.downsample.0 (Conv2d(-12, -15, kernel_size=(1, 1), stride=(2, 2), bias=False)), len(idxs)=0
-
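
As a rough picture of how ``root_module_types`` and ``ignored_layers`` narrow down the starting layers, the following plain-PyTorch sketch filters modules by type and identity. The helper `candidate_roots` is hypothetical and only covers the selection step; it does not build groups the way ``get_all_groups`` does.

```python
import torch.nn as nn

def candidate_roots(model, root_module_types, ignored_layers=()):
    """Return (name, module) pairs that could start a pruning group."""
    ignored_ids = {id(m) for m in ignored_layers}
    types = tuple(root_module_types)
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, types) and id(module) not in ignored_ids
    ]

# A small toy network (not ResNet) to keep the example fast.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3),
    nn.Flatten(),
    nn.Linear(16, 10),
)

roots = candidate_roots(net, [nn.Conv2d, nn.Linear], ignored_layers=[net[0]])
names = [name for name, _ in roots]
print(names)
```

The first convolution is skipped because it is listed in `ignored_layers`; the batch normalization layer is skipped because its type is not a root type.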

---

### 2. High-level Pruners

Pruning a neural network directly with the ``DependencyGraph`` can still be complicated, especially for models with many layers. Therefore, Torch-Pruning also offers high-level pruners to simplify this process. For example, you can prune a ResNet-18 model with a simple magnitude-based pruner. This method removes the channels whose weights have small magnitude, resulting in a smaller and faster model without losing too much accuracy.

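Building on DepGraph, we developed several high-level pruners to simplify the pruning process. Given a target channel pruning ratio, a pruner scans all prunable groups, estimates weight importance, and performs the pruning. You can then fine-tune the remaining weights with your own training code. For details on this process, see the tutorial that shows how to implement a Network Slimming (ICCV 2017) pruner from scratch. More practical examples for pruning ViT and ConvNext are available in VainF/Isomorphic-Pruning.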
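
To make "estimating weight importance" concrete, here is a minimal sketch of an L2 magnitude criterion for a single convolution in plain PyTorch. The helper names are hypothetical, and the sketch scores one layer in isolation, whereas a group-level criterion would combine the scores of all coupled layers.

```python
import torch
import torch.nn as nn

def l2_channel_importance(conv):
    """One score per output channel: L2 norm over (in_channels, kH, kW)."""
    return conv.weight.detach().flatten(1).norm(p=2, dim=1)

def least_important(conv, ratio):
    """Indices of the output channels with the smallest L2 norm."""
    scores = l2_channel_importance(conv)
    n_prune = int(conv.out_channels * ratio)
    return torch.argsort(scores)[:n_prune].tolist()

# Toy layer with hand-set weights so the ranking is easy to follow.
conv = nn.Conv2d(1, 4, kernel_size=1, bias=False)
with torch.no_grad():
    conv.weight.copy_(torch.tensor([3.0, 0.5, 2.0, 0.1]).view(4, 1, 1, 1))

idxs = sorted(least_important(conv, ratio=0.5))
print(idxs)  # channels 1 and 3 have the smallest norms
```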

In [2]:
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 0. importance criterion for parameter selections
# Here we use the L2 norm of the grouped weights as the importance score.
imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")

# 1. ignore some layers that should not be pruned, e.g., the final classifier layer.
ignored_layers = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and module.out_features == 1000:
        ignored_layers.append(module)  # DO NOT prune the final classifier!
        print(name)

# 2. Pruner initialization
iterative_steps = 5  # You can prune your model to the target pruning ratio iteratively.
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    isomorphic=True,
    global_pruning=True,  # If False, a uniform ratio will be assigned to different layers.
    importance=imp,  # importance criterion for parameter selection
    iterative_steps=iterative_steps,  # the number of iterations to achieve target ratio
    pruning_ratio=0.5,  # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
    pruning_ratio_dict={
        model.conv1: 0.2,
    },
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    # 3. the pruner.step will remove some channels from the model with least importance
    # tp.utils.print_tool.before_pruning(model)
    pruner.step()
    # tp.utils.print_tool.after_pruning(model)

    # 4. Do whatever you like here, such as fine-tuning
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # print(model)
    print(model(example_inputs).shape)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M"
        % (i + 1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
    )
    print(
        "  Iter %d/%d, MACs: %.2f G => %.2f G"
        % (i + 1, iterative_steps, base_macs / 1e9, macs / 1e9)
    )
    # finetune your model here
    # finetune(model)
    # ...

print(model)

fc
torch.Size([1, 1000])
  Iter 1/5, Params: 11.69 M => 9.98 M
  Iter 1/5, MACs: 1.82 G => 1.37 G
torch.Size([1, 1000])
  Iter 2/5, Params: 11.69 M => 8.32 M
  Iter 2/5, MACs: 1.82 G => 1.03 G
torch.Size([1, 1000])
  Iter 3/5, Params: 11.69 M => 6.72 M
  Iter 3/5, MACs: 1.82 G => 0.76 G
torch.Size([1, 1000])
  Iter 4/5, Params: 11.69 M => 5.22 M
  Iter 4/5, MACs: 1.82 G => 0.54 G
torch.Size([1, 1000])
  Iter 5/5, Params: 11.69 M => 3.85 M
  Iter 5/5, MACs: 1.82 G => 0.38 G
ResNet(
  (conv1): Conv2d(3, 51, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(51, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_sta
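
How a target ratio is spread over ``iterative_steps`` can be sketched as follows. This assumes a linear schedule and a uniform per-layer ratio (that is, not global pruning); both are assumptions for illustration, and the helper name is made up. Under these assumptions a 64-channel layer loses 7 channels in the first step and a 512-channel layer loses 52, which agrees with the first-step index counts in the pruning history printed at the end of this tutorial.

```python
def channels_after_each_step(n_channels, pruning_ratio, iterative_steps):
    """Remaining channels after every step under an assumed linear schedule."""
    remaining = []
    for step in range(1, iterative_steps + 1):
        ratio = pruning_ratio * step / iterative_steps  # cumulative target so far
        remaining.append(int(n_channels * (1 - ratio)))
    return remaining

schedule = {
    width: channels_after_each_step(width, pruning_ratio=0.5, iterative_steps=5)
    for width in (64, 128, 256, 512)
}
for width, remaining in schedule.items():
    print(width, remaining)
```

With global pruning, as in the cell above, the per-layer counts are decided by ranking channels across the whole network, so individual layers can deviate strongly from this schedule.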

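### Interactive Pruning

All high-level pruners support interactive pruning. Call ``pruner.step(interactive=True)`` to obtain all groups, then prune each one interactively by calling ``group.prune()``. This is especially useful when you need to control or monitor the pruning process.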

In [23]:
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 0. importance criterion for parameter selections
# Here we use the L2 norm of the grouped weights as the importance score.
imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")

# 1. ignore some layers that should not be pruned, e.g., the final classifier layer.
ignored_layers = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and module.out_features == 1000:
        ignored_layers.append(module)  # DO NOT prune the final classifier!
        print(name)

# 2. Pruner initialization
iterative_steps = 5  # You can prune your model to the target pruning ratio iteratively.
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    isomorphic=True,
    global_pruning=True,  # If False, a uniform ratio will be assigned to different layers.
    importance=imp,  # importance criterion for parameter selection
    iterative_steps=iterative_steps,  # the number of iterations to achieve target ratio
    pruning_ratio=0.5,  # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
    pruning_ratio_dict={
        model.conv1: 0.2,
    },
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    for group in pruner.step(interactive=True):
        print(group)
        # get the first dependency of the group and its pruning indices
        dep, idxs = group[0]
        # get the root module
        target_model = dep.target.module 
        # get the pruning function
        pruning_fn = dep.handler
        
        print(dep)
        print(idxs)
        print(target_model)
        print(pruning_fn)
        
        group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    
    # 4. Do whatever you like here, such as fine-tuning
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # print(model)
    print(model(example_inputs).shape)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M"
        % (i + 1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
    )
    print(
        "  Iter %d/%d, MACs: %.2f G => %.2f G"
        % (i + 1, iterative_steps, base_macs / 1e9, macs / 1e9)
    )
    # finetune your model here
    # finetune(model)
    # ...

# print(model)

fc

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)), len(idxs)=52
[1] prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), len(idxs)=52
[2] prune_out_channels on layer4.0.downsample.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_6(AddBackward0), len(idxs)=52
[3] prune_out_channels on _ElementWiseOp_6(AddBackward0) => prune_out_channels on layer4.0.bn2 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), len(idxs)=52
[4] prune_out_

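### Low-level Pruning Functions

Torch-Pruning also provides a set of low-level pruning functions that prune only a single layer or module. To prune ``model.conv1`` of a ResNet-18 manually, the pipeline looks like this: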

In [None]:
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 0. importance criterion for parameter selections
# Here we use the L2 norm of the grouped weights as the importance score.
imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")

# 1. ignore some layers that should not be pruned, e.g., the final classifier layer.
ignored_layers = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and module.out_features == 1000:
        ignored_layers.append(module)  # DO NOT prune the final classifier!
        print(name)

# 2. Pruner initialization
iterative_steps = 5  # You can prune your model to the target pruning ratio iteratively.
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    isomorphic=True,
    global_pruning=True,  # If False, a uniform ratio will be assigned to different layers.
    importance=imp,  # importance criterion for parameter selection
    iterative_steps=iterative_steps,  # the number of iterations to achieve target ratio
    pruning_ratio=0.5,  # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
    pruning_ratio_dict={
        model.conv1: 0.2,
    },
)

tp.prune_conv_out_channels(model.conv1, idxs=[2, 6, 9])

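# Manually fix the channels of the layers that follow conv1.
# Only the first two coupled layers are handled here. The pruning group printed earlier
# shows that more layers (e.g. layer1[0].conv2 and layer1[0].bn2) are coupled through the
# residual connection, so the forward pass below raises a shape error until they are pruned too.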
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer1[0].conv1, idxs=[2,6,9] )
# tp.prune_conv_out_channels( model.layer1[0].conv1, idxs=[2,6,9] )
# tp.prune_batchnorm_out_channels( model.layer1[0].bn1, idxs=[2,6,9] )
# tp.prune_conv_in_channels( model.layer1[0].conv2, idxs=[2,6,9] )

print(model)
print(model(example_inputs).shape)
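
For intuition about what a per-layer pruning function has to touch, here is a sketch of channel removal for a batch normalization layer in plain PyTorch. The helper `prune_bn_channels` is hypothetical and is not the library's implementation; it only shows that the affine parameters and the running statistics must all be sliced with the same indices.

```python
import torch
import torch.nn as nn

def prune_bn_channels(bn, pruned):
    """Return a new BatchNorm2d without the channels listed in `pruned`."""
    pruned = set(pruned)
    keep = [i for i in range(bn.num_features) if i not in pruned]
    new_bn = nn.BatchNorm2d(len(keep), eps=bn.eps, momentum=bn.momentum)
    with torch.no_grad():
        new_bn.weight.copy_(bn.weight[keep])
        new_bn.bias.copy_(bn.bias[keep])
        new_bn.running_mean.copy_(bn.running_mean[keep])
        new_bn.running_var.copy_(bn.running_var[keep])
    return new_bn

bn = nn.BatchNorm2d(16)
# Tag each channel with its index so we can check which ones survive.
bn.running_mean.copy_(torch.arange(16, dtype=torch.float32))

new_bn = prune_bn_channels(bn, [2, 6, 9]).eval()
out = new_bn(torch.randn(1, 13, 4, 4))
print(new_bn.num_features, new_bn.running_mean.tolist())
```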

### Loading and Saving
Here we introduce a way to store and restore a pruned model with ``pruning_history``, which works much like the ``state_dict`` used by PyTorch.

In [4]:
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# print(model)
# Global metrics

imp = tp.importance.MagnitudeImportance(p=2)

ignored_layers = []
# DO NOT prune the final classifier!
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)

iterative_steps = 5
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, 
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # print(model)
    # print(model(example_inputs).shape)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M"
        % (i + 1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
    )
    print(
        "  Iter %d/%d, MACs: %.2f G => %.2f G"
        % (i + 1, iterative_steps, base_macs / 1e9, macs / 1e9)
    )

state_dict = {
    "model": model.state_dict(),
    "pruning": pruner.pruning_history(),
}

print(pruner.pruning_history())

torch.save(state_dict, "pruned_model.pth")


# Create a new model and pruner
model = resnet18(pretrained=True)

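# Build the dependency graph for the new model
DG = tp.DependencyGraph().build_dependency(model, example_inputs)
# Read the checkpoint
state_dict = torch.load("pruned_model.pth")
# Replay the pruning history so the layer shapes match the pruned weights
DG.load_pruning_history(state_dict["pruning"])
# Load the weights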
model.load_state_dict(state_dict["model"])
# print(model)

  Iter 1/5, Params: 11.69 M => 9.48 M
  Iter 1/5, MACs: 1.82 G => 1.47 G
  Iter 2/5, Params: 11.69 M => 7.53 M
  Iter 2/5, MACs: 1.82 G => 1.18 G
  Iter 3/5, Params: 11.69 M => 5.82 M
  Iter 3/5, MACs: 1.82 G => 0.91 G
  Iter 4/5, Params: 11.69 M => 4.32 M
  Iter 4/5, MACs: 1.82 G => 0.68 G
  Iter 5/5, Params: 11.69 M => 3.06 M
  Iter 5/5, MACs: 1.82 G => 0.49 G
[['layer4.0.downsample.0', True, [5, 11, 19, 39, 40, 41, 64, 66, 72, 87, 90, 94, 105, 123, 128, 140, 154, 175, 203, 210, 211, 224, 238, 243, 281, 299, 302, 308, 313, 314, 320, 325, 342, 344, 347, 348, 351, 355, 376, 386, 396, 408, 416, 423, 449, 454, 482, 483, 496, 501, 503, 505]], ['layer3.0.downsample.0', True, [2, 30, 34, 38, 43, 44, 46, 76, 81, 125, 140, 157, 171, 172, 176, 185, 193, 195, 208, 231, 236, 237, 238, 241, 249, 255]], ['layer2.0.downsample.0', True, [1, 3, 26, 42, 43, 46, 56, 65, 68, 78, 93, 97, 103]], ['conv1', True, [2, 4, 9, 13, 18, 35, 36]], ['layer1.0.conv1', True, [11, 17, 29, 30, 44, 56, 62]], ['layer1.1.

<All keys matched successfully>
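
The reason the pruning history has to be replayed before ``load_state_dict`` can be shown with a small plain-PyTorch sketch. Everything here is a stand-in: the two-layer toy model, the `replay` helper, and the `[layer_index, pruned_units]` history format are made up for illustration and differ from the format printed above.

```python
import io
import torch
import torch.nn as nn

def make_model():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

def replay(model, history):
    """Re-apply recorded pruning steps (hypothetical history format)."""
    for layer_idx, pruned in history:
        fc, nxt = model[layer_idx], model[layer_idx + 2]
        keep = [i for i in range(fc.out_features) if i not in pruned]
        fc.weight = nn.Parameter(fc.weight.detach()[keep].clone())
        fc.bias = nn.Parameter(fc.bias.detach()[keep].clone())
        fc.out_features = len(keep)
        # The next layer consumes fc's output, so trim its input side too.
        nxt.weight = nn.Parameter(nxt.weight.detach()[:, keep].clone())
        nxt.in_features = len(keep)

history = [[0, [1, 5]]]
model = make_model()
replay(model, history)

# Save weights and history together, using an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "pruning": history}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

fresh = make_model()
try:
    fresh.load_state_dict(checkpoint["model"])
    shape_error = False
except RuntimeError:
    shape_error = True  # unpruned shapes do not match the pruned weights

replay(fresh, checkpoint["pruning"])         # restore the pruned architecture first
fresh.load_state_dict(checkpoint["model"])   # now the shapes line up

x = torch.randn(3, 4)
print(shape_error, torch.equal(fresh(x), model(x)))
```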