## Torch-Pruning(DepGraph)
1. pruner.function: https://github.com/VainF/Torch-Pruning/blob/master/torch_pruning/pruner/function.py
2. torch_pruning.dependency: https://github.com/VainF/Torch-Pruning/blob/master/torch_pruning/dependency.py


In [1]:
import torch
import torch_pruning as tp
from torchsummary import summary
import torch.backends.cudnn as cudnn

OSError: [WinError 127] 找不到指定的程序。 Error loading "c:\Users\user\anaconda3\envs\YOLOv8_env\lib\site-packages\torch\lib\c10_cuda.dll" or one of its dependencies.

In [2]:
if torch.cuda.is_available():
    cudnn.benchmark = True
    device = "cuda"
    print(torch.cuda.get_device_name())
else:
    device = "cpu"
    print("Use CPU")

Quadro RTX 3000 with Max-Q Design


In [None]:
# Load pytorch weights
PATH = r'my_weights/Resnet18_e20_b5_t70_v30.pth'
model = torch.load(PATH).to(device)
print(summary(model, input_size=(3, 32, 32)))

### A Minimal Example of DepGraph

In [4]:
# 1. Build dependency graph for a resnet18. This requires a dummy input for forwarding
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,32,32).to(device))

In [5]:
# 2. To prune the output channels of model.conv1, we need to find the corresponding group with a pruning function and pruning indices.
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
print(group.details())


--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs (3) =[2, 6, 9]  (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs (3) =[2, 6, 9] 
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs (3) =[2, 6, 9] 
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs (3) =[2, 6, 9] 
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithInd

In [6]:
# 3. Do the pruning
if DG.check_pruning_group(group): # avoid over-pruning, i.e., channels=0.
    print(DG.check_pruning_group(group))
    group.prune()

print(summary(model, input_size=(3, 32, 32)))

True
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 61, 16, 16]           8,967
       BatchNorm2d-2           [-1, 61, 16, 16]             122
              ReLU-3           [-1, 61, 16, 16]               0
         MaxPool2d-4             [-1, 61, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          35,136
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 61, 8, 8]          35,136
       BatchNorm2d-9             [-1, 61, 8, 8]             122
             ReLU-10             [-1, 61, 8, 8]               0
       BasicBlock-11             [-1, 61, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          35,136
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [

In [7]:
# 4. Save & Load
model.zero_grad() # clear gradients to avoid a large file size
torch.save(model, PATH.replace(".pth", "(DG).pth")) # !! no .state_dict here since the structure has been changed after pruning
# model = torch.load(PATH)

### GroupTaylorImportance + MetaPruner
1. Importance : https://github.com/VainF/Torch-Pruning/blob/adf1b075aa4f53043937d29e1953516ef477fc81/torch_pruning/pruner/importance.py#L37
2. Pruner : https://github.com/VainF/Torch-Pruning/tree/master/torch_pruning/pruner/algorithms

In [18]:
PATH = r'my_weights/Resnet18_e20_b5_t70_v30.pth'
model = torch.load(PATH).to(device)
summary(model, input_size=(3, 32, 32))

  model = torch.load(PATH).to(device)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [None]:
# 1. Importance criterion
imp = tp.importance.GroupTaylorImportance()

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 10:
        ignored_layers.append(m) # DO NOT prune the final classifier!
print(ignored_layers)

example_inputs = torch.randn(1, 3, 32, 32).to(device)
pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs = example_inputs,
    importance = imp,
    pruning_ratio = 0.5,
    ignored_layers = ignored_layers,
)

# 3. Prune & finetune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
if isinstance(imp, tp.importance.GroupTaylorImportance):
    loss = model(example_inputs).sum() 
    loss.backward() # before pruner.step()

pruner.step()

[]


In [16]:
summary(model, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 16, 16]           4,704
       BatchNorm2d-2           [-1, 32, 16, 16]              64
              ReLU-3           [-1, 32, 16, 16]               0
         MaxPool2d-4             [-1, 32, 8, 8]               0
            Conv2d-5             [-1, 32, 8, 8]           9,216
       BatchNorm2d-6             [-1, 32, 8, 8]              64
              ReLU-7             [-1, 32, 8, 8]               0
            Conv2d-8             [-1, 32, 8, 8]           9,216
       BatchNorm2d-9             [-1, 32, 8, 8]              64
             ReLU-10             [-1, 32, 8, 8]               0
       BasicBlock-11             [-1, 32, 8, 8]               0
           Conv2d-12             [-1, 32, 8, 8]           9,216
      BatchNorm2d-13             [-1, 32, 8, 8]              64
             ReLU-14             [-1, 3

In [None]:
model_path = PATH.replace('.pth', '(GroupTaylorImportance_MetaPruner).pth')
torch.save(model, model_path)