In [19]:
import torch
from torchvision.models import resnet18
import torch_pruning as pruning
model = torch.load('resnet/resnet18.pth')
# build layer dependency for resnet18
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [20]:
DG = pruning.DependencyGraph(model, fake_input=torch.randn(1,3,224,224) )
# get a pruning plan according to the dependency graph. idxs is the indices of pruned filters.
pruning_plan = DG.get_pruning_plan(model.conv1, pruning.prune_conv, idxs=[2, 6, 9] )
print(pruning_plan)
# execute this plan (prune the model)
pruning_plan.exec()


-------------
[ prune_conv on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), Index=[2, 6, 9], NumPruned=441]
[ prune_batchnorm on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[2, 6, 9], NumPruned=6]
[ prune_related_conv on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[2, 6, 9], NumPruned=1728]
[ _prune_elementwise_op on elementwise (_ElementWiseOp()), Index=[2, 6, 9], NumPruned=0]
[ prune_batchnorm on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), Index=[2, 6, 9], NumPruned=6]
[ prune_conv on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[2, 6, 9], NumPruned=1728]
[ prune_related_conv on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), Index=[2, 6, 9], NumPruned=1728]
[ _prune_elementwise_op on

11211

In [3]:
from torchvision.models import alexnet
import torch_pruning as pruning
import numpy as np 
import torch
import torch.nn.functional as F

# model = alexnet(pretrained=True)
print("Before pruning: ")
print(model.features[:4])
print(model.features[0].weight.shape)
print(model.features[3].weight.shape)

pruning.prune_conv(model.features[0], idxs=[0,1,3,4])
pruning.prune_related_conv( model.features[3], idxs=[0,1,3,4] )

print("\nAfter pruning: ")
print(model.features[:4])
print(model.features[0].weight.shape)
print(model.features[3].weight.shape)

mask1 = np.random.randint(low=0, high=2, size=model.features[0].weight.shape)
pruning.mask_weight( model.features[0],mask1 )
print("add mask1, masking %d weights"%( (mask1!=0).sum() ))

mask2 = np.random.randint(low=0, high=2, size=model.features[0].weight.shape)
pruning.mask_weight( model.features[0], mask2)
print("add mask2, masking %d weights"%( (mask2!=0).sum() ))

print("%d weights were actually masked"%( (model.features[0].weight_mask.numpy()!=0).sum() ))
print( "mask1 | mask2 == weight_mask: ", np.alltrue( np.logical_or(mask1, mask2) == model.features[0].weight_mask.numpy() ) )

random_inputs = torch.randn((1,3,224,224))
output = model(random_inputs)

conv1_output = model.features[0](random_inputs) 
masked_weight = torch.tensor( np.logical_or(mask1, mask2) ) *  model.features[0].weight
conv1_output_target = F.conv2d(random_inputs, masked_weight, bias=model.features[0].bias, stride=4, padding=2) 
print( "Correct output from conv1:", torch.all( conv1_output == conv1_output_target ) )

Before pruning: 
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
torch.Size([64, 3, 3, 3])
torch.Size([128, 64, 3, 3])

After pruning: 
Sequential(
  (0): Conv2d(3, 60, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(60, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
torch.Size([60, 3, 3, 3])
torch.Size([128, 60, 3, 3])
add mask1, masking 840 weights
add mask2, masking 797 weights
1213 weights were actually masked
mask1 | mask2 == weight_mask:  True


RuntimeError: The size of tensor a (224) must match the size of tensor b (57) at non-singleton dimension 3