In [2]:
import torch.nn.utils.prune as prune
import torch
from torchvision.models import resnet18
import torch_pruning as pruning
from torchsummary import summary
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = torch.load('resnet/resnet18.pth').to(device)
# build layer dependency for resnet18
model.eval()
# prune.random_unstructured(model.conv1, name="weight", amount=0.3)
summary(model, (3, 64, 64))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           9,408
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
         MaxPool2d-4           [-1, 64, 16, 16]               0
            Conv2d-5           [-1, 64, 16, 16]          36,864
       BatchNorm2d-6           [-1, 64, 16, 16]             128
              ReLU-7           [-1, 64, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          36,864
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
       BasicBlock-11           [-1, 64, 16, 16]               0
           Conv2d-12           [-1, 64, 16, 16]          36,864
      BatchNorm2d-13           [-1, 64, 16, 16]             128
             ReLU-14           [-1, 64,

In [3]:
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.layer1[0].conv1, 'weight'),
    (model.layer1[0].conv2, 'weight'),
    (model.layer1[1].conv1, 'weight'),
    (model.layer1[1].conv2, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
summary(model, (3, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           9,408
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
         MaxPool2d-4           [-1, 64, 16, 16]               0
            Conv2d-5           [-1, 64, 16, 16]          36,864
       BatchNorm2d-6           [-1, 64, 16, 16]             128
              ReLU-7           [-1, 64, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          36,864
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
       BasicBlock-11           [-1, 64, 16, 16]               0
           Conv2d-12           [-1, 64, 16, 16]          36,864
      BatchNorm2d-13           [-1, 64, 16, 16]             128
             ReLU-14           [-1, 64,

In [6]:
print(
    "Sparsity in layer1[0].conv1: {:.2f}%".format(
        100. * float(torch.sum(model.layer1[0].conv1.weight == 0))
        / float(model.layer1[0].conv1.weight.nelement())
    )
)
print(
    "Sparsity in layer1[0].conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.layer1[0].conv2.weight == 0))
        / float(model.layer1[0].conv2.weight.nelement())
    )
)
# print(
#     "Sparsity in fc1.weight: {:.2f}%".format(
#         100. * float(torch.sum(model.fc1.weight == 0))
#         / float(model.fc1.weight.nelement())
#     )
# )
# print(
#     "Sparsity in fc2.weight: {:.2f}%".format(
#         100. * float(torch.sum(model.fc2.weight == 0))
#         / float(model.fc2.weight.nelement())
#     )
# )
# print(
#     "Sparsity in fc3.weight: {:.2f}%".format(
#         100. * float(torch.sum(model.fc3.weight == 0))
#         / float(model.fc3.weight.nelement())
#     )
# )
# print(
#     "Global sparsity: {:.2f}%".format(
#         100. * float(
#             torch.sum(model.conv1.weight == 0)
#             + torch.sum(model.conv2.weight == 0)
#             + torch.sum(model.fc1.weight == 0)
#             + torch.sum(model.fc2.weight == 0)
#             + torch.sum(model.fc3.weight == 0)
#         )
#         / float(
#             model.conv1.weight.nelement()
#             + model.conv2.weight.nelement()
#             + model.fc1.weight.nelement()
#             + model.fc2.weight.nelement()
#             + model.fc3.weight.nelement()
#         )
#     )
# )

Sparsity in layer1[0].conv1: 27.68%
Sparsity in layer1[0].conv2.weight: 17.69%


In [None]:
class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask
