## pytorch Pruner Demo
Purpose:
- Demonstrate the pruners.BasePruner module with modules.MaskedModule

In [5]:
import sys
sys.path.insert(0,'../')
from pytorchpruner.modules import MaskedModule
from pytorchpruner.pruners import BasePruner

## a Toy CNN
Let's define a toy CNN

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5)
        self.conv2 = nn.Conv2d(8,16, kernel_size=5)
        self.fc1 = nn.Linear(256, 64)
        self.fc2 = nn.Linear(64, 10)
        self.nonlins = {'conv1':('max_relu',(2,2)),'conv2':('max_relu',(2,2)),'fc1':'relu','fc2':'log_softmax'}

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 256)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x,dim=1)

def weight_init(m):
    if isinstance(m,(torch.nn.Conv2d,torch.nn.Linear)):
        nn.init.xavier_uniform(m.weight)

## Masked module and Pruner
Masked module is needed Pruner to work on the network. This can be done explicitly or implicitly.

In [22]:
model = MaskedModule(Net())
# model = Net() #implicit MaskedModel(model) is called during initilization
pruner = BasePruner(model)
#dummy batch sample
x=Variable(torch.Tensor(32,1,28,28)) #mnist batch
y=Variable((torch.ones(32)).long())

In [23]:
#Before and after pruning
print(model(x)[0])
print(pruner.masked_model.module.conv1.weight[0])
pruner.prune(0.5)
print(model(x)[0])
print(model.module.conv1.weight[0])
## Printing the mask tensor of the first conv layer
print(model._mask_dict[model.module.conv1][0])


Variable containing:
-2.3978
-2.3298
-2.1738
-2.3781
-2.2667
-2.4077
-2.2205
-2.1774
-2.3512
-2.3596
[torch.FloatTensor of size 10]

Variable containing:
(0 ,.,.) = 
 -0.1216  0.0017  0.0447  0.0141 -0.1600
  0.0304 -0.1724  0.0179  0.1675 -0.0229
 -0.1923  0.1447 -0.1297  0.0304  0.0879
 -0.1815  0.0625  0.1408  0.0132 -0.0177
  0.1821  0.0707  0.1053 -0.0330  0.0306
[torch.FloatTensor of size 1x5x5]

Variable containing:
-2.3958
-2.3270
-2.1862
-2.3798
-2.2650
-2.3988
-2.2199
-2.1820
-2.3522
-2.3523
[torch.FloatTensor of size 10]

Variable containing:
(0 ,.,.) = 
 -0.1216  0.0000  0.0000  0.0000 -0.1600
  0.0000 -0.1724  0.0000  0.1675  0.0000
 -0.1923  0.1447 -0.1297  0.0000  0.0000
 -0.1815  0.0000  0.1408  0.0000  0.0000
  0.1821  0.0000  0.1053  0.0000  0.0000
[torch.FloatTensor of size 1x5x5]


(0 ,.,.) = 
  1  0  0  0  1
  0  1  0  1  0
  1  1  1  0  0
  1  0  1  0  0
  1  0  1  0  0
[torch.ByteTensor of size 1x5x5]



In [24]:
#applying mask on gradients
output = model(x)
loss = F.nll_loss(output, y)
loss.backward()
print(model.module.conv1.weight.grad[0])
model.apply_mask_on_gradients()
print(model.module.conv1.weight.grad[0])

Variable containing:
(0 ,.,.) = 
  0  0  0  0  0
  0  0  0  0  0
  0  0  0  0  0
  0  0  0  0  0
  0  0  0  0  0
[torch.FloatTensor of size 1x5x5]

Variable containing:
(0 ,.,.) = 
  0  0  0  0  0
  0  0  0  0  0
  0  0  0  0  0
  0  0  0  0  0
  0  0  0  0  0
[torch.FloatTensor of size 1x5x5]



## Saving and loading maskedModule


In [10]:
torch.save(model,'test.mod')
model2 = torch.load('test.mod')
print(model2)
print(model2._mask_dict[model2.module.conv1][0])

MaskedModule(
  (module): Net(
    (conv1): Conv2d (1, 8, kernel_size=(5, 5), stride=(1, 1))
    (conv2): Conv2d (8, 16, kernel_size=(5, 5), stride=(1, 1))
    (fc1): Linear(in_features=256, out_features=64)
    (fc2): Linear(in_features=64, out_features=10)
  )
)

(0 ,.,.) = 
  0  1  1  1  0
  1  1  0  0  1
  1  1  1  0  1
  0  1  1  1  0
  1  0  0  0  0
[torch.ByteTensor of size 1x5x5]



  "type " + obj.__name__ + ". It won't be checked "
