This is a Jupyter demo that illustrates some details of the manuscript "A Dynamic Pruning Method on Multiple Sparse Structures in Deep Neural Networks".

In [1]:
import torch
import copy

class DPU(torch.autograd.Function):
    """
    Dynamic Pruning Unit: masks dense weights in the forward pass while
    letting the gradient flow to every weight in the backward pass.

    forward:  out = input * mask
    backward: the gradient w.r.t. ``input`` is ``grad_output`` itself (the
              mask is ignored), so masked-out weights keep receiving
              gradients. ``mask`` gets no gradient.
    """
    @staticmethod
    def forward(ctx, input, mask):
        # mask indicates which channels in weight tensor are important
        # if one channel is important, it will be multiplied by 1, otherwise 0.
        # Remember the input shape: if ``mask`` broadcasts ``input`` to a
        # larger shape, backward must reduce the gradient back to this shape.
        ctx.input_shape = input.shape
        return input * mask

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient directly through (the mask is deliberately not
        # applied here).
        if grad_output.shape == ctx.input_shape:
            grad_input = grad_output.clone()
        else:
            # ``input`` was broadcast in forward; sum over the broadcast
            # dimensions so the gradient matches the shape of ``input``.
            grad_input = grad_output.sum_to_size(ctx.input_shape)
        # Second return value is the gradient for ``mask``: none.
        return grad_input, None


class Conv2d_with_DPU(torch.nn.Conv2d):
    """
    Convolutional layer with filter-wise DPU.

    ``calculate_mask`` ranks the output channels (filters) by the l1-norm of
    their dense weights and stores a 0/1 mask in ``self.mask``. Once a mask
    exists, ``forward`` convolves with the masked weight (and masked bias,
    if the layer has one) instead of the dense one.
    """
    def calculate_mask(self, pruning_rate):
        """
        Compute and store the filter-wise mask in ``self.mask``.

        pruning_rate: fraction of output channels to mask out, with
            0 <= pruning_rate < 1. The number of masked channels is
            ``int(out_channels * pruning_rate)``.
        """
        # pruning_rate should be between 0 to 1
        assert 0 <= pruning_rate < 1, "pruning_rate must satisfy 0 <= pruning_rate < 1"

        # the following implementation is output channel pruning and use l1-norm
        num_channel_pruned = int(self.out_channels * pruning_rate)
        # Scores are only used for ranking, so compute them on a detached
        # weight to avoid building an autograd graph.
        scores = self.weight.detach().abs().sum(dim=(1, 2, 3))
        _, index_channel_pruned = torch.topk(scores, num_channel_pruned, largest=False)
        # Build the mask from ``scores`` so that it lives on the same device
        # and has the same dtype as the weight.
        mask = torch.ones_like(scores)
        mask[index_channel_pruned] = 0
        self.mask = mask

    def forward(self, input):
        """Convolve ``input`` using the masked weight/bias when a mask is set."""
        # calculate the new weight according mask
        weight = self.weight
        bias = self.bias

        if hasattr(self, 'mask'):
            # ``self.mask`` is a plain attribute, so it does not follow the
            # module through ``.to()`` / ``.cuda()`` / ``.half()``; align it
            # with the weight here (a no-op when it already matches).
            mask = self.mask.to(device=weight.device, dtype=weight.dtype)
            weight = DPU.apply(weight, mask.reshape(-1, 1, 1, 1))
            if bias is not None:
                bias = DPU.apply(bias, mask)

        # Note that the redundant channels always output 0, so their gradient
        # would disappear when their outputs go through the ReLU activation
        # function. To prevent this problem, we add a tiny positive number to
        # the outputs of all channels, which has almost no effect on the
        # outputs but yields a non-zero gradient for the redundant channels.
        return self._conv_forward(input, weight, bias) + 1e-6

In [2]:
# Build a small filter-wise-DPU convolution: one input channel, two filters,
# 3x3 kernels, stride 1, padding 1 and no bias.
module = Conv2d_with_DPU(1, 2, 3, stride=1, padding=1, bias=False)

# A random input of shape [1, 1, 4, 4].
input = torch.randn(1, 1, 4, 4)

# Show the dense weights before any mask is computed.
print(module.weight)
# Mask out half of the output channels, then show the resulting mask.
module.calculate_mask(pruning_rate=0.5)
print(module.mask)
# Per-output-channel l1 scores, for comparison with the mask above.
print(torch.sum(module.weight.abs(), dim=(1, 2, 3)))


Parameter containing:
tensor([[[[ 0.1589,  0.3250, -0.2882],
          [ 0.2477, -0.0303, -0.2452],
          [-0.0641,  0.0733, -0.0732]]],


        [[[-0.1772,  0.2492, -0.2567],
          [ 0.1221,  0.1858, -0.2877],
          [-0.1682, -0.3068, -0.0706]]]], requires_grad=True)
tensor([0., 1.])
tensor([1.5058, 1.8242], grad_fn=<SumBackward1>)


In [3]:
# Inspect W': one of the output channels of the weight has been zeroed out.
DPU.apply(module.weight, module.mask[:, None, None, None])

tensor([[[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[-0.1772,  0.2492, -0.2567],
          [ 0.1221,  0.1858, -0.2877],
          [-0.1682, -0.3068, -0.0706]]]], grad_fn=<DPUBackward>)

In [4]:
# Run the masked convolution, apply ReLU, and back-propagate the summed
# activations. The printed gradients of the redundant channel are non-zero,
# which indicates that those weights can still be updated.
output1 = module(input)
output2 = torch.relu(output1)
torch.sum(output2).backward()
print(module.weight.grad)

tensor([[[[ 0.6681,  2.7298,  4.3273],
          [ 2.2807,  3.8396,  4.5354],
          [ 2.8895,  4.5643,  4.7680]]],


        [[[ 0.9566,  3.4354, -1.4009],
          [ 1.1768,  5.1519, -0.2187],
          [-0.7392, -1.3219,  1.3150]]]])


The following comparative experiment demonstrates that, with the DPU, the gradients of the redundant weights are not zero.

In [22]:
import torch
import copy
import torch.nn.utils.prune as prune

# Fresh layer for the comparison: one input channel, two filters, 3x3
# kernels, stride 1, padding 1 and no bias.
module = Conv2d_with_DPU(1, 2, 3, stride=1, padding=1, bias=False)

# A random input of shape [1, 1, 4, 4], shared by both experiments below.
input = torch.randn(1, 1, 4, 4)

In [23]:
# Baseline: PyTorch's built-in structured pruning applied to a copy of the
# layer (l1-norm, along dim 0, amount 0.5).
module_normal_pruning = prune.ln_structured(
    copy.deepcopy(module), name='weight', amount=0.5, n=1, dim=0)
output1 = module_normal_pruning(input)
output2 = torch.relu(output1)
torch.sum(output2).backward()
# After pruning, the trainable parameter is `weight_orig`.
print("module_normal_pruning's grad:\n",
      module_normal_pruning.weight_orig.grad)

module_normal_pruning's grad:
 tensor([[[[ 5.7206,  3.4495, -3.1273],
          [ 2.4458,  0.5445, -4.4313],
          [-3.3191,  2.3399, -1.3938]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]]]])


In [24]:
# Same experiment on another copy of the layer, this time masked by the DPU.
module_DPU = copy.deepcopy(module)
module_DPU.calculate_mask(pruning_rate=0.5)
output3 = module_DPU(input)
output4 = torch.relu(output3)
torch.sum(output4).backward()
print()
print("module_DPU's grad:\n", module_DPU.weight.grad)


module_DPU's grad:
 tensor([[[[ 5.7206,  3.4495, -3.1273],
          [ 2.4458,  0.5445, -4.4313],
          [-3.3191,  2.3399, -1.3938]]],


        [[[ 1.2246,  1.5253, -1.9536],
          [ 1.5667,  1.3428, -2.7680],
          [-0.7879, -1.7737, -5.8281]]]])


We can see that normal pruning with the PyTorch implementation yields zero gradients for the redundant filter, while the DPU yields non-zero gradients, and the gradients of the important channels are equal between the two methods.
