# Pruning Demo


In [58]:
import sys 
sys.path.append('../src/')

from models import LeNetFC 

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F


In [59]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNetFC().to(device)

## Inspect a module

In [60]:
module = model.fc1
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[-0.0060, -0.0218, -0.0119,  ..., -0.0187,  0.0355,  0.0146],
        [ 0.0223, -0.0075, -0.0187,  ...,  0.0239,  0.0068,  0.0028],
        [-0.0183,  0.0175,  0.0241,  ...,  0.0157,  0.0232,  0.0326],
        ...,
        [ 0.0110,  0.0016,  0.0083,  ..., -0.0178,  0.0097,  0.0284],
        [-0.0129,  0.0260,  0.0183,  ..., -0.0120, -0.0122, -0.0218],
        [ 0.0119, -0.0349,  0.0135,  ..., -0.0229,  0.0346,  0.0061]],
       requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0231, -0.0195,  0.0143, -0.0247,  0.0069,  0.0042, -0.0328, -0.0053,
        -0.0342, -0.0012, -0.0094, -0.0017, -0.0039, -0.0275,  0.0151, -0.0187,
        -0.0278,  0.0331,  0.0274,  0.0206, -0.0245,  0.0108,  0.0072, -0.0184,
        -0.0230, -0.0148,  0.0305,  0.0239,  0.0032,  0.0075,  0.0340, -0.0310,
         0.0081,  0.0151, -0.0034,  0.0098,  0.0079,  0.0126,  0.0023, -0.0291,
         0.0154, -0.0120, -0.0121, -0.0089,  0.0281, -0.0163,  0.0135,  

In [61]:
# prune the weights (not biases here)
p = 0.3
pruned_tensor = prune.random_unstructured(module, name='weight', amount=p)
print(pruned_tensor)

Linear(in_features=784, out_features=300, bias=True)


In [62]:
# notice 'weight_orig' is stored, 'weight' now contains pruned params
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.0231, -0.0195,  0.0143, -0.0247,  0.0069,  0.0042, -0.0328, -0.0053,
        -0.0342, -0.0012, -0.0094, -0.0017, -0.0039, -0.0275,  0.0151, -0.0187,
        -0.0278,  0.0331,  0.0274,  0.0206, -0.0245,  0.0108,  0.0072, -0.0184,
        -0.0230, -0.0148,  0.0305,  0.0239,  0.0032,  0.0075,  0.0340, -0.0310,
         0.0081,  0.0151, -0.0034,  0.0098,  0.0079,  0.0126,  0.0023, -0.0291,
         0.0154, -0.0120, -0.0121, -0.0089,  0.0281, -0.0163,  0.0135,  0.0300,
         0.0181,  0.0158,  0.0140,  0.0206, -0.0189,  0.0099, -0.0235, -0.0305,
         0.0293,  0.0342,  0.0357,  0.0195,  0.0295, -0.0278,  0.0254,  0.0221,
         0.0008,  0.0105,  0.0349,  0.0135, -0.0243,  0.0130, -0.0126,  0.0348,
        -0.0270,  0.0047, -0.0182, -0.0011,  0.0065,  0.0042,  0.0334,  0.0322,
         0.0129,  0.0168,  0.0151,  0.0281,  0.0256, -0.0064,  0.0021,  0.0329,
        -0.0173, -0.0172,  0.0208,  0.0148,  0.0126, -0.0187, -0.0219, -0.0340,
        

In [63]:
# see mask
print(module.weight_mask[:2])

tensor([[0., 1., 0.,  ..., 0., 1., 1.],
        [1., 1., 0.,  ..., 0., 0., 0.]])


In [64]:
# see new pruned params (compare to weight_orig printed above, see how mask is applied?)
print(module.weight[:2])

tensor([[-0.0000, -0.0218, -0.0000,  ..., -0.0000,  0.0355,  0.0146],
        [ 0.0223, -0.0075, -0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<SliceBackward>)


In [65]:
# remove 3 smallest bias params according to L1 norm
#prune.l1_unstructured(module, name="bias", amount=3)

In [66]:
# note bias_orig
#print(list(module.named_parameters()))

## Iterative Pruning

In [67]:
# total magnitude of weights
torch.sum(abs(module.weight))

tensor(2940.2097, grad_fn=<SumBackward0>)

In [43]:
prune.ln_structured(module, name='weight', amount=.5, n=2, dim=0)

Linear(in_features=784, out_features=300, bias=True)

In [44]:
# can see half the weights are now zeroed out (pruned/remove)
torch.sum(abs(module.weight))

tensor(1510.6038, grad_fn=<SumBackward0>)

In [45]:
# history of pruning applied to weight param
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == 'weight':
        break

for h in hook:
    print(h)

<torch.nn.utils.prune.RandomUnstructured object at 0x7fe910115a50>
<torch.nn.utils.prune.LnStructured object at 0x7fe8f02c2990>


In [46]:
# serialized and retrievable 
print(model.state_dict().keys())

odict_keys(['fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


## Make Pruning Permanent
Remove pruning re-parametrization.  

In [47]:
# Before: with weight_orig in the named_paramaters
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([-2.6465e-03,  2.5319e-02,  1.3529e-02, -1.3225e-02, -3.9753e-05,
         5.4984e-03, -3.2310e-02, -9.5201e-03,  1.2895e-02,  1.1898e-02,
         2.1479e-02,  1.6147e-03,  2.4142e-02,  1.6380e-02, -1.8390e-02,
         1.9494e-02,  4.0276e-03, -2.1081e-02, -9.9492e-03, -2.3118e-02,
         2.6920e-02,  1.7556e-02, -2.7953e-02, -6.7913e-03, -5.4557e-03,
        -3.0452e-02, -2.0725e-03, -3.1632e-02, -8.0891e-03, -1.9213e-02,
         4.3721e-03, -3.5544e-02,  8.8698e-03,  4.7285e-03, -2.8924e-02,
         1.2181e-02, -1.5976e-02,  1.2288e-02, -1.2595e-02, -1.5588e-02,
         2.6274e-02, -1.6819e-02, -2.7067e-03, -9.3004e-03, -3.4223e-02,
         2.3675e-02, -3.5529e-03,  2.3711e-02,  3.0173e-02,  1.7235e-02,
        -5.9229e-03,  9.7212e-03,  3.5452e-02, -3.1832e-02,  3.5085e-02,
         8.7124e-04, -1.8860e-02,  1.2506e-02, -8.5054e-03,  1.7609e-02,
        -1.2864e-02, -5.9707e-03, -1.6760e-02,  7.2212e-03,  2.5267e-02,
         1.0027e-02

In [48]:
# Before: 
print(list(module.named_buffers()))

[('weight_mask', tensor([[0., 1., 0.,  ..., 1., 0., 1.],
        [1., 1., 1.,  ..., 0., 0., 1.],
        [1., 0., 1.,  ..., 1., 1., 0.],
        ...,
        [1., 1., 1.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]))]


In [49]:
# Before
print(module.weight)

tensor([[ 0.0000,  0.0059, -0.0000,  ...,  0.0253,  0.0000, -0.0355],
        [ 0.0097, -0.0302, -0.0115,  ...,  0.0000, -0.0000,  0.0041],
        [-0.0179, -0.0000, -0.0082,  ...,  0.0196, -0.0244, -0.0000],
        ...,
        [-0.0266,  0.0056, -0.0123,  ..., -0.0000,  0.0345,  0.0000],
        [-0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000, -0.0000],
        [-0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<MulBackward0>)


In [50]:
# note how its just 'weight' in the parameters now, not weight_orig
prune.remove(module, 'weight')
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([-2.6465e-03,  2.5319e-02,  1.3529e-02, -1.3225e-02, -3.9753e-05,
         5.4984e-03, -3.2310e-02, -9.5201e-03,  1.2895e-02,  1.1898e-02,
         2.1479e-02,  1.6147e-03,  2.4142e-02,  1.6380e-02, -1.8390e-02,
         1.9494e-02,  4.0276e-03, -2.1081e-02, -9.9492e-03, -2.3118e-02,
         2.6920e-02,  1.7556e-02, -2.7953e-02, -6.7913e-03, -5.4557e-03,
        -3.0452e-02, -2.0725e-03, -3.1632e-02, -8.0891e-03, -1.9213e-02,
         4.3721e-03, -3.5544e-02,  8.8698e-03,  4.7285e-03, -2.8924e-02,
         1.2181e-02, -1.5976e-02,  1.2288e-02, -1.2595e-02, -1.5588e-02,
         2.6274e-02, -1.6819e-02, -2.7067e-03, -9.3004e-03, -3.4223e-02,
         2.3675e-02, -3.5529e-03,  2.3711e-02,  3.0173e-02,  1.7235e-02,
        -5.9229e-03,  9.7212e-03,  3.5452e-02, -3.1832e-02,  3.5085e-02,
         8.7124e-04, -1.8860e-02,  1.2506e-02, -8.5054e-03,  1.7609e-02,
        -1.2864e-02, -5.9707e-03, -1.6760e-02,  7.2212e-03,  2.5267e-02,
         1.0027e-02

In [51]:
# no weight_mask anymore 
print(list(module.named_buffers()))

[]


## Prune Layers

In [52]:
# look at layers
for idx, (name, module) in enumerate(model.named_modules()):
    if idx > 0 and idx < 4:
        print(name, module)
        

fc1 Linear(in_features=784, out_features=300, bias=True)
fc2 Linear(in_features=300, out_features=100, bias=True)
fc3 Linear(in_features=100, out_features=10, bias=True)


In [53]:
for idx, (name, module) in enumerate(model.named_modules()):
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=p)
        prune.l1_unstructured(module, name='bias', amount=len(module.bias)//2)

In [54]:
print(dict(model.named_buffers()).keys())  # to verify that all masks exist

dict_keys(['fc1.weight_mask', 'fc1.bias_mask', 'fc2.weight_mask', 'fc2.bias_mask', 'fc3.weight_mask', 'fc3.bias_mask'])


## Pruning Container for Iterative Pruning

Load FC2 layer module

In [75]:
fc2_module = model.fc2
print(fc2_module)

Linear(in_features=300, out_features=100, bias=True)


In [76]:
list(fc2_module.named_parameters())

[('weight',
  Parameter containing:
  tensor([[ 0.0571, -0.0313,  0.0509,  ..., -0.0313, -0.0230, -0.0480],
          [-0.0507,  0.0363,  0.0012,  ..., -0.0116, -0.0488,  0.0446],
          [ 0.0556,  0.0529, -0.0282,  ...,  0.0096, -0.0392,  0.0090],
          ...,
          [ 0.0095, -0.0503,  0.0200,  ..., -0.0385, -0.0335, -0.0384],
          [-0.0177, -0.0288, -0.0486,  ...,  0.0447, -0.0326, -0.0449],
          [ 0.0238, -0.0117,  0.0345,  ..., -0.0086, -0.0463,  0.0420]],
         requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([ 0.0165, -0.0475,  0.0276,  0.0074,  0.0144, -0.0508,  0.0158, -0.0143,
           0.0308, -0.0459, -0.0136, -0.0099, -0.0464, -0.0151,  0.0182, -0.0118,
           0.0114, -0.0267, -0.0393, -0.0225,  0.0266, -0.0448,  0.0501,  0.0438,
           0.0325, -0.0140,  0.0118,  0.0167, -0.0183, -0.0093,  0.0045, -0.0577,
           0.0123,  0.0384, -0.0037,  0.0513,  0.0098, -0.0497,  0.0486,  0.0246,
           0.0140, -0.0477, -0.0268, -0.0

In [78]:
prune.l1_unstructured(fc2_module, name='weight', amount=.6)

Linear(in_features=300, out_features=100, bias=True)

In [81]:
# named_parameters has weight_orig now
print(fc2_module.weight)
print(fc2_module.weight_mask)

tensor([[ 0.0571, -0.0000,  0.0509,  ..., -0.0000, -0.0000, -0.0480],
        [-0.0507,  0.0363,  0.0000,  ..., -0.0000, -0.0488,  0.0446],
        [ 0.0556,  0.0529, -0.0000,  ...,  0.0000, -0.0392,  0.0000],
        ...,
        [ 0.0000, -0.0503,  0.0000,  ..., -0.0385, -0.0000, -0.0384],
        [-0.0000, -0.0000, -0.0486,  ...,  0.0447, -0.0000, -0.0449],
        [ 0.0000, -0.0000,  0.0345,  ..., -0.0000, -0.0463,  0.0420]],
       grad_fn=<MulBackward0>)
tensor([[1., 0., 1.,  ..., 0., 0., 1.],
        [1., 1., 0.,  ..., 0., 1., 1.],
        [1., 1., 0.,  ..., 0., 1., 0.],
        ...,
        [0., 1., 0.,  ..., 1., 0., 1.],
        [0., 0., 1.,  ..., 1., 0., 1.],
        [0., 0., 1.,  ..., 0., 1., 1.]])


In [89]:
# value of all original weights (weight_orig)
torch.sum(torch.abs(list(fc2_module.named_parameters())[1][1]))

tensor(863.4540, grad_fn=<SumBackward0>)

In [82]:
# roughly 60%
torch.sum(abs(fc2_module.weight))

tensor(552.8780, grad_fn=<SumBackward0>)

In [91]:
# make first prune permanent
prune.remove(fc2_module, 'weight')

Linear(in_features=300, out_features=100, bias=True)

In [95]:
# should be 60% now or roughly 552
torch.sum(torch.abs(list(fc2_module.named_parameters())[1][1]))

tensor(552.8780, grad_fn=<SumBackward0>)

In [96]:
# next pruning iteration 
prune.l1_unstructured(fc2_module, name='weight', amount=.4)

Linear(in_features=300, out_features=100, bias=True)

In [98]:
# now weight_orig is not the weights before pruning, but the weight before this specific pruning step
list(fc2_module.named_parameters())

[('bias',
  Parameter containing:
  tensor([ 0.0165, -0.0475,  0.0276,  0.0074,  0.0144, -0.0508,  0.0158, -0.0143,
           0.0308, -0.0459, -0.0136, -0.0099, -0.0464, -0.0151,  0.0182, -0.0118,
           0.0114, -0.0267, -0.0393, -0.0225,  0.0266, -0.0448,  0.0501,  0.0438,
           0.0325, -0.0140,  0.0118,  0.0167, -0.0183, -0.0093,  0.0045, -0.0577,
           0.0123,  0.0384, -0.0037,  0.0513,  0.0098, -0.0497,  0.0486,  0.0246,
           0.0140, -0.0477, -0.0268, -0.0076, -0.0155, -0.0374,  0.0231,  0.0372,
          -0.0025,  0.0414,  0.0163,  0.0543,  0.0188,  0.0540,  0.0557, -0.0388,
          -0.0089,  0.0378,  0.0262,  0.0457,  0.0396,  0.0044, -0.0520,  0.0098,
           0.0133, -0.0374, -0.0359, -0.0026, -0.0449,  0.0391, -0.0342,  0.0143,
           0.0031,  0.0542,  0.0185, -0.0498,  0.0264,  0.0254, -0.0041,  0.0414,
           0.0310,  0.0224,  0.0452, -0.0045,  0.0172, -0.0448, -0.0475,  0.0338,
           0.0434,  0.0115,  0.0103, -0.0211,  0.0127,  0.0208, 

In [99]:
torch.sum(torch.abs(list(fc2_module.named_parameters())[1][1]))

tensor(552.8780, grad_fn=<SumBackward0>)

In [100]:
# should be roughly 40% of 552, but its still 552! -- HERES THE PROBLEM 
torch.sum(abs(fc2_module.weight))

tensor(552.8780, grad_fn=<SumBackward0>)

### Try to "remove" after both pruning steps have been applied

In [124]:
# reload FC model since fc2 module has been changed 
model = LeNetFC().to(device)
fc2_module = model.fc2

In [125]:
list(fc2_module.named_parameters())

[('weight',
  Parameter containing:
  tensor([[-0.0545, -0.0436,  0.0258,  ...,  0.0039, -0.0443, -0.0524],
          [-0.0548,  0.0324,  0.0165,  ...,  0.0058, -0.0258, -0.0042],
          [-0.0028,  0.0349,  0.0082,  ...,  0.0084,  0.0015, -0.0075],
          ...,
          [ 0.0412, -0.0450,  0.0222,  ...,  0.0240,  0.0115,  0.0080],
          [ 0.0036,  0.0430, -0.0047,  ...,  0.0107, -0.0437,  0.0300],
          [-0.0123,  0.0268,  0.0109,  ...,  0.0131, -0.0397,  0.0510]],
         requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([-0.0323, -0.0114,  0.0135,  0.0502,  0.0495, -0.0440,  0.0510, -0.0386,
           0.0079,  0.0013, -0.0199, -0.0119, -0.0153, -0.0424, -0.0486,  0.0399,
          -0.0167, -0.0399,  0.0238,  0.0025, -0.0201, -0.0229, -0.0247,  0.0334,
          -0.0066,  0.0481,  0.0128,  0.0372,  0.0081, -0.0062,  0.0534,  0.0568,
          -0.0334, -0.0539, -0.0522,  0.0269, -0.0117, -0.0299,  0.0328, -0.0458,
          -0.0125,  0.0021,  0.0016,  0.0

In [126]:
prune.l1_unstructured(fc2_module, name='weight', amount=.6)
prune.l1_unstructured(fc2_module, name='weight', amount=.4)

Linear(in_features=300, out_features=100, bias=True)

In [128]:
# history of pruning applied to weight param
for hook in fc2_module._forward_pre_hooks.values():
    if hook._tensor_name == 'weight':
        break

for h in hook:
    print(h, h.amount)

<torch.nn.utils.prune.L1Unstructured object at 0x7fe910ba8b90> 0.6
<torch.nn.utils.prune.L1Unstructured object at 0x7fe910ba8210> 0.4


In [130]:
torch.sum(torch.abs(list(fc2_module.named_parameters())[1][1]))

tensor(867.2637, grad_fn=<SumBackward0>)

In [131]:
# BAD: pytorch prunes .4 of original weights! instead it should prune .4 of the .6 remaining weights!
torch.sum(abs(fc2_module.weight))

tensor(365.2234, grad_fn=<SumBackward0>)

In [134]:
# should be this much left
(.6 * 867)*.4

208.07999999999998

In [135]:
# but this is how much is left 
.4*867

346.8

In [136]:
prune.remove(fc2_module, 'weight')

Linear(in_features=300, out_features=100, bias=True)

In [137]:
torch.sum(torch.abs(list(fc2_module.named_parameters())[1][1]))

tensor(365.2234, grad_fn=<SumBackward0>)

### Try to do on my own because pytorch is stupid

In [148]:
# reload FC model since fc2 module has been changed 
model = LeNetFC().to(device)
fc2_module = model.fc2

In [149]:
list(fc2_module.named_parameters())

[('weight',
  Parameter containing:
  tensor([[ 0.0474,  0.0063, -0.0233,  ..., -0.0230, -0.0287,  0.0006],
          [-0.0418,  0.0276, -0.0396,  ...,  0.0406,  0.0070,  0.0261],
          [ 0.0473, -0.0533, -0.0564,  ..., -0.0090, -0.0297, -0.0120],
          ...,
          [-0.0546,  0.0388,  0.0545,  ..., -0.0157,  0.0194, -0.0388],
          [ 0.0576,  0.0390, -0.0176,  ..., -0.0220, -0.0543, -0.0045],
          [ 0.0191, -0.0024, -0.0301,  ..., -0.0111,  0.0191, -0.0450]],
         requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([-2.9264e-02,  4.1564e-02, -4.4723e-03, -4.1308e-02,  2.0183e-02,
          -3.9997e-02,  1.3494e-02,  4.3906e-02,  4.8878e-02,  5.4643e-02,
           2.4308e-02, -8.1346e-03, -1.7897e-02,  1.3112e-02,  4.4896e-02,
          -1.9509e-02,  4.8706e-02,  5.0557e-03, -4.9757e-02, -1.4422e-02,
           7.4990e-03,  3.1629e-02, -5.1888e-02, -5.3309e-02, -3.4581e-02,
          -6.1646e-03,  1.6914e-02,  4.8195e-02, -5.8696e-03, -2.8868e-02,
 

In [150]:
# careful with indexes! need to check if 'weight' or 'weight_orig'
torch.sum(torch.abs(list(fc2_module.named_parameters())[0][1]))

tensor(864.4819, grad_fn=<SumBackward0>)

In [151]:
pruned_module = prune.l1_unstructured(fc2_module, name='weight', amount=.6)
pruned_module

Linear(in_features=300, out_features=100, bias=True)

In [152]:
pruned_module.weight

tensor([[ 0.0474,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [-0.0418,  0.0000, -0.0396,  ...,  0.0406,  0.0000,  0.0000],
        [ 0.0473, -0.0533, -0.0564,  ..., -0.0000, -0.0000, -0.0000],
        ...,
        [-0.0546,  0.0388,  0.0545,  ..., -0.0000,  0.0000, -0.0388],
        [ 0.0576,  0.0390, -0.0000,  ..., -0.0000, -0.0543, -0.0000],
        [ 0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000, -0.0450]],
       grad_fn=<MulBackward0>)

In [153]:
pruned_module.weight_mask

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 1.,  ..., 1., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 0., 0., 1.],
        [1., 1., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [155]:
torch.sum(torch.abs(list(fc2_module.named_parameters())[1][1]))

tensor(864.4819, grad_fn=<SumBackward0>)

In [156]:
torch.sum(torch.abs(pruned_module.weight))

tensor(554.0620, grad_fn=<SumBackward0>)

In [162]:
# could be dangerous to change these class variables directly, but there is no function to do so..
fc2_module._parameters['weight_orig'] = pruned_module.weight

In [163]:
fc2_module._parameters['weight_orig']

tensor([[ 0.0474,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [-0.0418,  0.0000, -0.0396,  ...,  0.0406,  0.0000,  0.0000],
        [ 0.0473, -0.0533, -0.0564,  ..., -0.0000, -0.0000, -0.0000],
        ...,
        [-0.0546,  0.0388,  0.0545,  ..., -0.0000,  0.0000, -0.0388],
        [ 0.0576,  0.0390, -0.0000,  ..., -0.0000, -0.0543, -0.0000],
        [ 0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000, -0.0450]],
       grad_fn=<MulBackward0>)

In [164]:
# ok this is good
torch.sum(torch.abs(list(fc2_module.named_parameters())[1][1]))

tensor(554.0620, grad_fn=<SumBackward0>)

In [174]:
pruned_module = prune.l1_unstructured(fc2_module, name='weight', amount=.4)
pruned_module

Linear(in_features=300, out_features=100, bias=True)

In [166]:
pruned_module.weight

tensor([[ 0.0474,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [-0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0473, -0.0533, -0.0564,  ..., -0.0000, -0.0000, -0.0000],
        ...,
        [-0.0546,  0.0000,  0.0545,  ..., -0.0000,  0.0000, -0.0000],
        [ 0.0576,  0.0000, -0.0000,  ..., -0.0000, -0.0543, -0.0000],
        [ 0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000, -0.0450]],
       grad_fn=<MulBackward0>)

In [167]:
torch.sum(torch.abs(pruned_module.weight))

tensor(365.7526, grad_fn=<SumBackward0>)

In [168]:
fc2_module._parameters['weight_orig'] = pruned_module.weight

In [169]:
fc2_module._parameters['weight_orig']

tensor([[ 0.0474,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [-0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0473, -0.0533, -0.0564,  ..., -0.0000, -0.0000, -0.0000],
        ...,
        [-0.0546,  0.0000,  0.0545,  ..., -0.0000,  0.0000, -0.0000],
        [ 0.0576,  0.0000, -0.0000,  ..., -0.0000, -0.0543, -0.0000],
        [ 0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000, -0.0450]],
       grad_fn=<MulBackward0>)

In [170]:
# still doesn't work, just does .4 * 864 
torch.sum(torch.abs(list(fc2_module.named_parameters())[1][1]))

tensor(365.7526, grad_fn=<SumBackward0>)

In [171]:
(.6 * 864.4819) * .4

207.475656

### Try to prune completely without pytorch pruning

In [179]:
import numpy as np

In [175]:
# reload FC model since fc2 module has been changed 
model = LeNetFC().to(device)
fc2_module = model.fc2

In [176]:
torch.sum(torch.abs(list(fc2_module.named_parameters())[0][1]))

tensor(860.0298, grad_fn=<SumBackward0>)

In [191]:
nparams_toprune = int(round(0.6 * fc2_module.weight.nelement()))
topk = torch.topk(torch.abs(fc2_module.weight).view(-1), k=nparams_toprune, largest=False)

orig = fc2_module.weight
mask = torch.ones_like(orig)
mask.view(-1)[topk.indices] = 0

mask

tensor([[1., 0., 0.,  ..., 0., 1., 0.],
        [1., 1., 0.,  ..., 1., 1., 0.],
        [0., 0., 1.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 1.,  ..., 0., 1., 0.],
        [1., 1., 0.,  ..., 0., 1., 0.],
        [1., 0., 1.,  ..., 0., 1., 0.]])

In [192]:
fc2_module.weight

Parameter containing:
tensor([[-0.0572, -0.0097, -0.0082,  ...,  0.0041, -0.0560, -0.0169],
        [ 0.0416,  0.0504,  0.0343,  ..., -0.0385, -0.0532, -0.0116],
        [ 0.0148, -0.0199, -0.0506,  ..., -0.0227,  0.0029, -0.0556],
        ...,
        [-0.0180,  0.0263,  0.0502,  ...,  0.0262,  0.0571,  0.0259],
        [-0.0543, -0.0424,  0.0031,  ..., -0.0163, -0.0508, -0.0111],
        [-0.0448, -0.0203, -0.0421,  ..., -0.0070,  0.0426, -0.0152]],
       requires_grad=True)

In [194]:
pruned_module = mask * fc2_module.weight
pruned_module

tensor([[-0.0572, -0.0000, -0.0000,  ...,  0.0000, -0.0560, -0.0000],
        [ 0.0416,  0.0504,  0.0000,  ..., -0.0385, -0.0532, -0.0000],
        [ 0.0000, -0.0000, -0.0506,  ..., -0.0000,  0.0000, -0.0556],
        ...,
        [-0.0000,  0.0000,  0.0502,  ...,  0.0000,  0.0571,  0.0000],
        [-0.0543, -0.0424,  0.0000,  ..., -0.0000, -0.0508, -0.0000],
        [-0.0448, -0.0000, -0.0421,  ..., -0.0000,  0.0426, -0.0000]],
       grad_fn=<MulBackward0>)

In [195]:
torch.sum(torch.abs(pruned_module))

tensor(552.9429, grad_fn=<SumBackward0>)

In [196]:
fc2_module._parameters['weight'] = pruned_module

In [197]:
# ok that worked
torch.sum(torch.abs(list(fc2_module.named_parameters())[0][1]))

tensor(552.9429, grad_fn=<SumBackward0>)

In [200]:
# weight is now updated 
torch.sum(torch.abs(fc2_module.weight))
fc2_module.weight

tensor([[-0.0572, -0.0000, -0.0000,  ...,  0.0000, -0.0560, -0.0000],
        [ 0.0416,  0.0504,  0.0000,  ..., -0.0385, -0.0532, -0.0000],
        [ 0.0000, -0.0000, -0.0506,  ..., -0.0000,  0.0000, -0.0556],
        ...,
        [-0.0000,  0.0000,  0.0502,  ...,  0.0000,  0.0571,  0.0000],
        [-0.0543, -0.0424,  0.0000,  ..., -0.0000, -0.0508, -0.0000],
        [-0.0448, -0.0000, -0.0421,  ..., -0.0000,  0.0426, -0.0000]],
       grad_fn=<MulBackward0>)

Next iterations must be different, only prune non-zero values 

In [211]:
fc2_module.weight != 0

tensor([[ True, False, False,  ..., False,  True, False],
        [ True,  True, False,  ...,  True,  True, False],
        [False, False,  True,  ..., False, False,  True],
        ...,
        [False, False,  True,  ..., False,  True, False],
        [ True,  True, False,  ..., False,  True, False],
        [ True, False,  True,  ..., False,  True, False]])

In [225]:
nparams_toprune = int(round(0.4 * ((1-0.6) * fc2_module.weight.nelement())))
nparams_toprune

4800

In [228]:
topk = torch.topk(torch.abs(fc2_module.weight).view(-1), k=nparams_toprune, largest=False)
topk

torch.return_types.topk(
values=tensor([0., 0., 0.,  ..., 0., 0., 0.], grad_fn=<TopkBackward>),
indices=tensor([15000,     1,     2,  ...,  4797,  4798,  4799]))

In [229]:
orig = fc2_module.weight
orig

tensor([[-0.0572, -0.0000, -0.0000,  ...,  0.0000, -0.0560, -0.0000],
        [ 0.0416,  0.0504,  0.0000,  ..., -0.0385, -0.0532, -0.0000],
        [ 0.0000, -0.0000, -0.0506,  ..., -0.0000,  0.0000, -0.0556],
        ...,
        [-0.0000,  0.0000,  0.0502,  ...,  0.0000,  0.0571,  0.0000],
        [-0.0543, -0.0424,  0.0000,  ..., -0.0000, -0.0508, -0.0000],
        [-0.0448, -0.0000, -0.0421,  ..., -0.0000,  0.0426, -0.0000]],
       grad_fn=<MulBackward0>)

In [230]:
mask = torch.ones_like(orig)
mask

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

In [231]:
mask.view(-1)[topk.indices] = 0

mask

tensor([[1., 0., 0.,  ..., 0., 1., 0.],
        [1., 1., 0.,  ..., 1., 1., 0.],
        [0., 0., 1.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 1.,  ..., 0., 1., 0.],
        [1., 1., 0.,  ..., 0., 1., 0.],
        [1., 0., 1.,  ..., 0., 1., 1.]])

In [233]:
torch.sum(torch.abs(fc2_module.weight * mask))

tensor(552.9429, grad_fn=<SumBackward0>)

516.0

## A new attempt at pruning 5/31 (without pytorch pruning)

In [248]:
import sys 
sys.path.append('../src/')

from models import LeNetFC 

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

import numpy as np

In [249]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNetFC().to(device)

In [250]:
masks = {}
p = .6  # prune 60% (so should have 40% left)

In [251]:
module = model.fc1 
#https://pytorch.org/docs/master/generated/torch.topk.html
topk = torch.topk(torch.abs(module.weight).view(-1), k=10, largest=True)

In [252]:
# total params 
len(module.weight.view(-1))

235200

In [253]:
num_to_prune = int(p * len(module.weight.view(-1)))
num_to_prune

141120

In [254]:
# find k smallest weights
topk = torch.topk(torch.abs(module.weight).view(-1), k=num_to_prune, largest=False)

In [255]:
topk

torch.return_types.topk(
values=tensor([1.7136e-07, 2.4214e-07, 2.6822e-07,  ..., 2.1443e-02, 2.1443e-02,
        2.1444e-02], grad_fn=<TopkBackward>),
indices=tensor([222719,  63452, 223659,  ..., 194661, 186227, 129055]))

In [256]:
mask = torch.ones_like(module.weight)
mask.size()

torch.Size([300, 784])

In [257]:
mask

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

In [258]:
mask.view(-1)[topk.indices] = 0

In [259]:
mask

tensor([[0., 1., 0.,  ..., 1., 0., 1.],
        [0., 0., 1.,  ..., 1., 0., 1.],
        [1., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 1.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 1.,  ..., 0., 1., 0.]])

In [260]:
module.weight

Parameter containing:
tensor([[-1.0884e-02, -3.0770e-02,  1.7654e-02,  ...,  2.2255e-02,
          1.8245e-02,  2.4897e-02],
        [-4.7620e-03,  1.9240e-02,  2.6717e-02,  ..., -2.7123e-02,
         -1.4065e-02,  2.2950e-02],
        [ 2.8371e-02,  6.5669e-03, -2.1588e-02,  ..., -1.8915e-03,
          1.4815e-04,  1.3375e-02],
        ...,
        [-3.4601e-03,  1.4871e-05, -1.5159e-02,  ...,  3.0557e-02,
         -1.0168e-02, -2.9870e-02],
        [ 1.1464e-02, -1.4459e-02, -1.3789e-03,  ...,  1.1140e-02,
         -2.1868e-02,  7.8533e-03],
        [ 1.6093e-02,  1.0116e-02, -3.0530e-02,  ..., -1.3417e-02,
         -2.9252e-02, -1.1651e-02]], requires_grad=True)

In [261]:
mask*module.weight

tensor([[-0.0000, -0.0308,  0.0000,  ...,  0.0223,  0.0000,  0.0249],
        [-0.0000,  0.0000,  0.0267,  ..., -0.0271, -0.0000,  0.0230],
        [ 0.0284,  0.0000, -0.0216,  ..., -0.0000,  0.0000,  0.0000],
        ...,
        [-0.0000,  0.0000, -0.0000,  ...,  0.0306, -0.0000, -0.0299],
        [ 0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0219,  0.0000],
        [ 0.0000,  0.0000, -0.0305,  ..., -0.0000, -0.0293, -0.0000]],
       grad_fn=<MulBackward0>)

In [262]:
# put it all together and prune 
for name, param in model.named_parameters():
    if 'weight' in name:
        # find smallest magnitude weights
        num_to_prune = int(p * len(param.view(-1)))
        topk = torch.topk(torch.abs(param).view(-1), k=num_to_prune, largest=False)
        
        # create mask
        mask = torch.ones_like(param)
        mask.view(-1)[topk.indices] = 0
        
        masks[name] = mask
        
        print(name)
        print(param)
        print(mask)
        print()

fc1.weight
Parameter containing:
tensor([[-1.0884e-02, -3.0770e-02,  1.7654e-02,  ...,  2.2255e-02,
          1.8245e-02,  2.4897e-02],
        [-4.7620e-03,  1.9240e-02,  2.6717e-02,  ..., -2.7123e-02,
         -1.4065e-02,  2.2950e-02],
        [ 2.8371e-02,  6.5669e-03, -2.1588e-02,  ..., -1.8915e-03,
          1.4815e-04,  1.3375e-02],
        ...,
        [-3.4601e-03,  1.4871e-05, -1.5159e-02,  ...,  3.0557e-02,
         -1.0168e-02, -2.9870e-02],
        [ 1.1464e-02, -1.4459e-02, -1.3789e-03,  ...,  1.1140e-02,
         -2.1868e-02,  7.8533e-03],
        [ 1.6093e-02,  1.0116e-02, -3.0530e-02,  ..., -1.3417e-02,
         -2.9252e-02, -1.1651e-02]], requires_grad=True)
tensor([[0., 1., 0.,  ..., 1., 0., 1.],
        [0., 0., 1.,  ..., 1., 0., 1.],
        [1., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 1.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 1.,  ..., 0., 1., 0.]])

fc2.weight
Parameter containing:
tensor([[ 0.0519, -0.0166, 

In [263]:
masks

{'fc1.weight': tensor([[0., 1., 0.,  ..., 1., 0., 1.],
         [0., 0., 1.,  ..., 1., 0., 1.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 1.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 1.,  ..., 0., 1., 0.]]),
 'fc2.weight': tensor([[1., 0., 0.,  ..., 0., 1., 0.],
         [1., 0., 1.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         ...,
         [1., 1., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 1., 0.],
         [1., 1., 0.,  ..., 1., 0., 0.]]),
 'fc3.weight': tensor([[1., 1., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0.,
          0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 1., 0., 1.,
          0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 1., 0.

In [264]:
# apply the mask and view
for name, param in model.named_parameters():
    if name in masks.keys():
        param.requires_grad_(requires_grad=False)
        param.mul_(masks[name])
        param.requires_grad_(requires_grad=True)
        print(name, param, param.size())

fc1.weight Parameter containing:
tensor([[-0.0000, -0.0308,  0.0000,  ...,  0.0223,  0.0000,  0.0249],
        [-0.0000,  0.0000,  0.0267,  ..., -0.0271, -0.0000,  0.0230],
        [ 0.0284,  0.0000, -0.0216,  ..., -0.0000,  0.0000,  0.0000],
        ...,
        [-0.0000,  0.0000, -0.0000,  ...,  0.0306, -0.0000, -0.0299],
        [ 0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0219,  0.0000],
        [ 0.0000,  0.0000, -0.0305,  ..., -0.0000, -0.0293, -0.0000]],
       requires_grad=True) torch.Size([300, 784])
fc2.weight Parameter containing:
tensor([[ 0.0519, -0.0000,  0.0000,  ...,  0.0000, -0.0521,  0.0000],
        [-0.0546,  0.0000, -0.0548,  ...,  0.0000,  0.0415, -0.0000],
        [-0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000,  0.0449],
        ...,
        [ 0.0415,  0.0441,  0.0000,  ...,  0.0000, -0.0000, -0.0000],
        [ 0.0499, -0.0000, -0.0000,  ...,  0.0000,  0.0562, -0.0000],
        [-0.0410,  0.0365, -0.0000,  ..., -0.0520,  0.0000,  0.0000]],
       requires_

In [265]:
for v in masks.values():
    print(v.size())

torch.Size([300, 784])
torch.Size([100, 300])
torch.Size([10, 100])


### Need to reverse
focus on what to keep, not what to prune

### Try 2nd iteration of pruning

In [266]:
masks2 = {}
p = .4

In [267]:
# note module.weight has been changed and prune (in place op)
module.weight

Parameter containing:
tensor([[-0.0000, -0.0308,  0.0000,  ...,  0.0223,  0.0000,  0.0249],
        [-0.0000,  0.0000,  0.0267,  ..., -0.0271, -0.0000,  0.0230],
        [ 0.0284,  0.0000, -0.0216,  ..., -0.0000,  0.0000,  0.0000],
        ...,
        [-0.0000,  0.0000, -0.0000,  ...,  0.0306, -0.0000, -0.0299],
        [ 0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0219,  0.0000],
        [ 0.0000,  0.0000, -0.0305,  ..., -0.0000, -0.0293, -0.0000]],
       requires_grad=True)

In [268]:
# of size 94080
unpruned_weights = module.weight[module.weight != 0]

In [269]:
num_to_keep = int((1-p) * len(unpruned_weights))
num_to_keep

56448

In [270]:
topk = torch.topk(torch.abs(module.weight).view(-1), k=num_to_keep, largest=True)
mask = torch.zeros_like(module.weight)
mask.view(-1)[topk.indices] = 1
mask

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 1., 0.]])

In [271]:
unpruned_weights = mask[mask != 0]
print(len(unpruned_weights))

56448


In [272]:
# should have this many nonzero params after 2 iterations
(1-.6) * (len(module.weight.view(-1))) * (1-.4)

56448.0

In [273]:
iter2_pruned = mask*module.weight
iter2_pruned

tensor([[-0.0000, -0.0308,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [ 0.0284,  0.0000, -0.0000,  ..., -0.0000,  0.0000,  0.0000],
        ...,
        [-0.0000,  0.0000, -0.0000,  ...,  0.0306, -0.0000, -0.0299],
        [ 0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0305,  ..., -0.0000, -0.0293, -0.0000]],
       grad_fn=<MulBackward0>)

In [274]:
unpruned_weights = iter2_pruned[iter2_pruned != 0]
len(unpruned_weights)

56448

### Calculate masks iteration 2

In [275]:
# put it all together and prune 
for name, param in model.named_parameters():
    if 'weight' in name:
        # find LARGEST magnitude weights
        unpruned_weights = param[param != 0]
        num_to_keep = int((1-p) * len(unpruned_weights))
        topk = torch.topk(torch.abs(param).view(-1), k=num_to_keep, largest=True)
        
        # create mask
        mask = torch.zeros_like(param)
        mask.view(-1)[topk.indices] = 1
        
        masks2[name] = mask
        
        print(name)
        print(param)
        print(mask)
        print()

fc1.weight
Parameter containing:
tensor([[-0.0000, -0.0308,  0.0000,  ...,  0.0223,  0.0000,  0.0249],
        [-0.0000,  0.0000,  0.0267,  ..., -0.0271, -0.0000,  0.0230],
        [ 0.0284,  0.0000, -0.0216,  ..., -0.0000,  0.0000,  0.0000],
        ...,
        [-0.0000,  0.0000, -0.0000,  ...,  0.0306, -0.0000, -0.0299],
        [ 0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0219,  0.0000],
        [ 0.0000,  0.0000, -0.0305,  ..., -0.0000, -0.0293, -0.0000]],
       requires_grad=True)
tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 1., 0.]])

fc2.weight
Parameter containing:
tensor([[ 0.0519, -0.0000,  0.0000,  ...,  0.0000, -0.0521,  0.0000],
        [-0.0546,  0.0000, -0.0548,  ...,  0.0000,  0.0415, -0.0000],
        [-0.0000, -0.0000, -0.0000,  ..., -0.0000,  0.0000,  0.0449],
        

### Apply masks iteration 2

In [276]:
# note these are the params before applying the masks ^ 
# apply the mask and view
for name, param in model.named_parameters():
    if name in masks2.keys():
        param.requires_grad_(requires_grad=False)
        param.mul_(masks2[name])
        param.requires_grad_(requires_grad=True)
        print(name)
        print(param)
        print(mask)
        print()

fc1.weight
Parameter containing:
tensor([[-0.0000, -0.0308,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [ 0.0284,  0.0000, -0.0000,  ..., -0.0000,  0.0000,  0.0000],
        ...,
        [-0.0000,  0.0000, -0.0000,  ...,  0.0306, -0.0000, -0.0299],
        [ 0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0305,  ..., -0.0000, -0.0293, -0.0000]],
       requires_grad=True)
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1.,
         0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0., 1., 1.],
        [0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0.

In [279]:
# check correct amount has been pruned
for name, param in model.named_parameters():
    if name in masks2.keys():
        # remain percent iter 1, remain percent iter 2
        theoretical_unpruned = (1-.6) * (1-.4) * (len(param.view(-1)))
        actual_unpruned_param = len(param[param != 0])
        actual_nonzero_mask = torch.sum(masks2[name])
                                    
        assert(theoretical_unpruned == actual_unpruned_param == actual_nonzero_mask)

In [280]:
# it works!
masks2

{'fc1.weight': tensor([[0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 1., 0.]]),
 'fc2.weight': tensor([[1., 0., 0.,  ..., 0., 1., 0.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         ...,
         [0., 1., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.]]),
 'fc3.weight': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
          0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1.,
          0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0., 0.