In [36]:
import os
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

from minatar import Environment

from my_dqn import Conv_QNet

In [37]:
game = "breakout"
env = Environment(game)

proj_dir = os.path.abspath(".")
default_save_folder = os.path.join(proj_dir, "checkpoints", game)
file_name = os.path.join(default_save_folder, game + "_model")

state_shape = env.state_shape()

in_features = (state_shape[2], state_shape[0], state_shape[1])
in_channels = in_features[0]
num_actions = env.num_actions()

model = Conv_QNet(in_features, in_channels, num_actions)
# checkpoint = torch.load(file_name)
# model.load_state_dict(checkpoint["policy_model_state_dict"])


In [38]:
in_features

(4, 10, 10)

In [4]:
module = model.features[0]
print(list(module.named_parameters())) # contains weights and biases
# print(list(module.named_buffers()))

[('weight', Parameter containing:
tensor([[[[ 0.0659, -0.1499, -0.0302],
          [ 0.1369, -0.0479,  0.1248],
          [ 0.0132, -0.0921,  0.1108]],

         [[-0.1398,  0.0166,  0.1013],
          [-0.0023, -0.0939,  0.1633],
          [ 0.1184,  0.0621,  0.1172]],

         [[ 0.1647, -0.0522, -0.1057],
          [ 0.1419, -0.0343, -0.1584],
          [ 0.1015,  0.0611,  0.1528]],

         [[-0.1534,  0.0868,  0.1099],
          [-0.0140, -0.0388,  0.0737],
          [ 0.0507,  0.1321,  0.0193]]],


        [[[ 0.1071, -0.1104,  0.1132],
          [ 0.1522,  0.0296,  0.0933],
          [ 0.1311, -0.1033,  0.1515]],

         [[-0.0742, -0.0330,  0.0952],
          [ 0.1079, -0.0233, -0.0577],
          [-0.0330,  0.0717, -0.0036]],

         [[ 0.0943, -0.1078,  0.0957],
          [ 0.1087,  0.0862,  0.0176],
          [-0.1640, -0.0969, -0.0323]],

         [[-0.1384, -0.1427, -0.0439],
          [-0.1355,  0.1070, -0.1497],
          [ 0.0363,  0.1412,  0.1173]]],


        [[

In [5]:
print(list(module.named_modules()))

[('', Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1)))]


In [6]:
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1))

In [7]:
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.1410, -0.0647,  0.0736,  0.0949,  0.0799, -0.1207, -0.0145,  0.0371,
         0.1489,  0.0041,  0.0048, -0.0505, -0.0566,  0.0342, -0.1200,  0.1447],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0659, -0.1499, -0.0302],
          [ 0.1369, -0.0479,  0.1248],
          [ 0.0132, -0.0921,  0.1108]],

         [[-0.1398,  0.0166,  0.1013],
          [-0.0023, -0.0939,  0.1633],
          [ 0.1184,  0.0621,  0.1172]],

         [[ 0.1647, -0.0522, -0.1057],
          [ 0.1419, -0.0343, -0.1584],
          [ 0.1015,  0.0611,  0.1528]],

         [[-0.1534,  0.0868,  0.1099],
          [-0.0140, -0.0388,  0.0737],
          [ 0.0507,  0.1321,  0.0193]]],


        [[[ 0.1071, -0.1104,  0.1132],
          [ 0.1522,  0.0296,  0.0933],
          [ 0.1311, -0.1033,  0.1515]],

         [[-0.0742, -0.0330,  0.0952],
          [ 0.1079, -0.0233, -0.0577],
          [-0.0330,  0.0717, -0.0036]],

         [[ 0.0943, -0.1

In [8]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 0.],
          [0., 0., 1.]],

         [[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 1.],
          [1., 0., 0.]],

         [[1., 0., 0.],
          [0., 0., 0.],
          [1., 1., 0.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [0., 1., 1.]],

         [[0., 0., 0.],
          [1., 0., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 0.]],

         [[1., 1., 0.],
          [0., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 0., 1.],
          [1., 1., 1.]],

         [[1., 

In [9]:
print(module.weight)

tensor([[[[ 0.0659, -0.1499, -0.0302],
          [ 0.1369, -0.0479,  0.0000],
          [ 0.0000, -0.0000,  0.1108]],

         [[-0.1398,  0.0000,  0.1013],
          [-0.0023, -0.0939,  0.1633],
          [ 0.1184,  0.0621,  0.1172]],

         [[ 0.1647, -0.0522, -0.1057],
          [ 0.0000, -0.0343, -0.1584],
          [ 0.1015,  0.0000,  0.0000]],

         [[-0.1534,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0507,  0.1321,  0.0000]]],


        [[[ 0.1071, -0.0000,  0.0000],
          [ 0.1522,  0.0296,  0.0933],
          [ 0.0000, -0.1033,  0.1515]],

         [[-0.0000, -0.0000,  0.0000],
          [ 0.1079, -0.0000, -0.0577],
          [-0.0330,  0.0717, -0.0000]],

         [[ 0.0943, -0.1078,  0.0957],
          [ 0.1087,  0.0862,  0.0176],
          [-0.0000, -0.0969, -0.0000]],

         [[-0.1384, -0.1427, -0.0000],
          [-0.0000,  0.0000, -0.1497],
          [ 0.0000,  0.1412,  0.1173]]],


        [[[-0.0163,  0.0721, -0.1167],
     

In [10]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x0000029FD6F265C0>)])


In [11]:
prune.l1_unstructured(module, name="bias", amount=3)

Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1))

In [12]:
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[ 0.0659, -0.1499, -0.0302],
          [ 0.1369, -0.0479,  0.1248],
          [ 0.0132, -0.0921,  0.1108]],

         [[-0.1398,  0.0166,  0.1013],
          [-0.0023, -0.0939,  0.1633],
          [ 0.1184,  0.0621,  0.1172]],

         [[ 0.1647, -0.0522, -0.1057],
          [ 0.1419, -0.0343, -0.1584],
          [ 0.1015,  0.0611,  0.1528]],

         [[-0.1534,  0.0868,  0.1099],
          [-0.0140, -0.0388,  0.0737],
          [ 0.0507,  0.1321,  0.0193]]],


        [[[ 0.1071, -0.1104,  0.1132],
          [ 0.1522,  0.0296,  0.0933],
          [ 0.1311, -0.1033,  0.1515]],

         [[-0.0742, -0.0330,  0.0952],
          [ 0.1079, -0.0233, -0.0577],
          [-0.0330,  0.0717, -0.0036]],

         [[ 0.0943, -0.1078,  0.0957],
          [ 0.1087,  0.0862,  0.0176],
          [-0.1640, -0.0969, -0.0323]],

         [[-0.1384, -0.1427, -0.0439],
          [-0.1355,  0.1070, -0.1497],
          [ 0.0363,  0.1412,  0.1173]]],


     

In [13]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 0.],
          [0., 0., 1.]],

         [[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 1.],
          [1., 0., 0.]],

         [[1., 0., 0.],
          [0., 0., 0.],
          [1., 1., 0.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [0., 1., 1.]],

         [[0., 0., 0.],
          [1., 0., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 0.]],

         [[1., 1., 0.],
          [0., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 0., 1.],
          [1., 1., 1.]],

         [[1., 

In [14]:
print(module.bias)

tensor([ 0.1410, -0.0647,  0.0736,  0.0949,  0.0799, -0.1207, -0.0000,  0.0371,
         0.1489,  0.0000,  0.0000, -0.0505, -0.0566,  0.0342, -0.1200,  0.1447],
       grad_fn=<MulBackward0>)


In [15]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x0000029FD6F265C0>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x0000029FD6F26560>)])


### Iterative Pruning

In [16]:
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)

tensor([[[[ 0.0659, -0.1499, -0.0302],
          [ 0.1369, -0.0479,  0.0000],
          [ 0.0000, -0.0000,  0.1108]],

         [[-0.1398,  0.0000,  0.1013],
          [-0.0023, -0.0939,  0.1633],
          [ 0.1184,  0.0621,  0.1172]],

         [[ 0.1647, -0.0522, -0.1057],
          [ 0.0000, -0.0343, -0.1584],
          [ 0.1015,  0.0000,  0.0000]],

         [[-0.1534,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0507,  0.1321,  0.0000]]],


        [[[ 0.1071, -0.0000,  0.0000],
          [ 0.1522,  0.0296,  0.0933],
          [ 0.0000, -0.1033,  0.1515]],

         [[-0.0000, -0.0000,  0.0000],
          [ 0.1079, -0.0000, -0.0577],
          [-0.0330,  0.0717, -0.0000]],

         [[ 0.0943, -0.1078,  0.0957],
          [ 0.1087,  0.0862,  0.0176],
          [-0.0000, -0.0969, -0.0000]],

         [[-0.1384, -0.1427, -0.0000],
          [-0.0000,  0.0000, -0.1497],
          [ 0.0000,  0.1412,  0.1173]]],


        [[[-0.0163,  0.0721, -0.1167],
     

In [20]:
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook)) 
# print(hook)

[<torch.nn.utils.prune.RandomUnstructured object at 0x0000029FD6F265C0>, <torch.nn.utils.prune.LnStructured object at 0x0000029FD6F25C60>]


In [21]:
print(model.state_dict().keys())

odict_keys(['features.0.weight_orig', 'features.0.bias_orig', 'features.0.weight_mask', 'features.0.bias_mask', 'fc.0.weight', 'fc.0.bias', 'fc.2.weight', 'fc.2.bias'])


In [22]:
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[ 0.0659, -0.1499, -0.0302],
          [ 0.1369, -0.0479,  0.1248],
          [ 0.0132, -0.0921,  0.1108]],

         [[-0.1398,  0.0166,  0.1013],
          [-0.0023, -0.0939,  0.1633],
          [ 0.1184,  0.0621,  0.1172]],

         [[ 0.1647, -0.0522, -0.1057],
          [ 0.1419, -0.0343, -0.1584],
          [ 0.1015,  0.0611,  0.1528]],

         [[-0.1534,  0.0868,  0.1099],
          [-0.0140, -0.0388,  0.0737],
          [ 0.0507,  0.1321,  0.0193]]],


        [[[ 0.1071, -0.1104,  0.1132],
          [ 0.1522,  0.0296,  0.0933],
          [ 0.1311, -0.1033,  0.1515]],

         [[-0.0742, -0.0330,  0.0952],
          [ 0.1079, -0.0233, -0.0577],
          [-0.0330,  0.0717, -0.0036]],

         [[ 0.0943, -0.1078,  0.0957],
          [ 0.1087,  0.0862,  0.0176],
          [-0.1640, -0.0969, -0.0323]],

         [[-0.1384, -0.1427, -0.0439],
          [-0.1355,  0.1070, -0.1497],
          [ 0.0363,  0.1412,  0.1173]]],


     

In [23]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 0.],
          [0., 0., 1.]],

         [[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 1.],
          [1., 0., 0.]],

         [[1., 0., 0.],
          [0., 0., 0.],
          [1., 1., 0.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [0., 1., 1.]],

         [[0., 0., 0.],
          [1., 0., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 0.]],

         [[1., 1., 0.],
          [0., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 

In [24]:
print(module.weight)

tensor([[[[ 0.0659, -0.1499, -0.0302],
          [ 0.1369, -0.0479,  0.0000],
          [ 0.0000, -0.0000,  0.1108]],

         [[-0.1398,  0.0000,  0.1013],
          [-0.0023, -0.0939,  0.1633],
          [ 0.1184,  0.0621,  0.1172]],

         [[ 0.1647, -0.0522, -0.1057],
          [ 0.0000, -0.0343, -0.1584],
          [ 0.1015,  0.0000,  0.0000]],

         [[-0.1534,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0507,  0.1321,  0.0000]]],


        [[[ 0.1071, -0.0000,  0.0000],
          [ 0.1522,  0.0296,  0.0933],
          [ 0.0000, -0.1033,  0.1515]],

         [[-0.0000, -0.0000,  0.0000],
          [ 0.1079, -0.0000, -0.0577],
          [-0.0330,  0.0717, -0.0000]],

         [[ 0.0943, -0.1078,  0.0957],
          [ 0.1087,  0.0862,  0.0176],
          [-0.0000, -0.0969, -0.0000]],

         [[-0.1384, -0.1427, -0.0000],
          [-0.0000,  0.0000, -0.1497],
          [ 0.0000,  0.1412,  0.1173]]],


        [[[-0.0163,  0.0721, -0.1167],
     

In [25]:
prune.remove(module, 'weight')
print(list(module.named_parameters()))

[('bias_orig', Parameter containing:
tensor([ 0.1410, -0.0647,  0.0736,  0.0949,  0.0799, -0.1207, -0.0145,  0.0371,
         0.1489,  0.0041,  0.0048, -0.0505, -0.0566,  0.0342, -0.1200,  0.1447],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0659, -0.1499, -0.0302],
          [ 0.1369, -0.0479,  0.0000],
          [ 0.0000, -0.0000,  0.1108]],

         [[-0.1398,  0.0000,  0.1013],
          [-0.0023, -0.0939,  0.1633],
          [ 0.1184,  0.0621,  0.1172]],

         [[ 0.1647, -0.0522, -0.1057],
          [ 0.0000, -0.0343, -0.1584],
          [ 0.1015,  0.0000,  0.0000]],

         [[-0.1534,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [ 0.0507,  0.1321,  0.0000]]],


        [[[ 0.1071, -0.0000,  0.0000],
          [ 0.1522,  0.0296,  0.0933],
          [ 0.0000, -0.1033,  0.1515]],

         [[-0.0000, -0.0000,  0.0000],
          [ 0.1079, -0.0000, -0.0577],
          [-0.0330,  0.0717, -0.0000]],

         [[ 0.0943, -0.1

In [26]:
print(list(module.named_buffers()))

[('bias_mask', tensor([1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1., 1., 1.]))]


In [27]:
in_features = (state_shape[2], state_shape[0], state_shape[1])
in_channels = in_features[0]
num_actions = env.num_actions()

model = Conv_QNet(in_features, in_channels, num_actions)

for name, module in model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(model.named_buffers()).keys())  # to verify that all masks exist

dict_keys(['features.0.bias_mask', 'features.0.weight_mask', 'fc.0.weight_mask', 'fc.2.weight_mask'])


### Global pruning

In [30]:
in_features = (state_shape[2], state_shape[0], state_shape[1])
in_channels = in_features[0]
num_actions = env.num_actions()

model = Conv_QNet(in_features, in_channels, num_actions)

parameters_to_prune = (
    (model.features[0], 'weight'),
    (model.fc[0], 'weight'),
    (model.fc[2], 'weight'),

)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [31]:
print(
    "Sparsity in features[0].weight: {:.2f}%".format(
        100. * float(torch.sum(model.features[0].weight == 0))
        / float(model.features[0].weight.nelement())
    )
)
print(
    "Sparsity in fc[0].weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc[0].weight == 0))
        / float(model.fc[0].weight.nelement())
    )
)
print(
    "Sparsity in fc[2].weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc[2].weight == 0))
        / float(model.fc[2].weight.nelement())
    )
)

Sparsity in features[0].weight: 3.65%
Sparsity in fc[0].weight: 20.15%
Sparsity in fc[2].weight: 6.90%


In [33]:
class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

In [35]:
in_features = (state_shape[2], state_shape[0], state_shape[1])
in_channels = in_features[0]
num_actions = env.num_actions()

model = Conv_QNet(in_features, in_channels, num_actions)

foobar_unstructured(model.features[0], name='bias')

print(model.features[0].bias_mask)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])


In [129]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.autograd as autograd

import torch.nn.utils.prune as prune

from minatar import Environment

import seaborn as sns
import matplotlib.pyplot as plt

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"


class Min_Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.channels = 4
        
        # conv layers
        self.conv_l = nn.Conv2d(self.channels, 4, kernel_size=3, stride=1)

        self.fc = nn.Linear(self.channels * 6 * 6, 64)

    def forward(self, x):
        x = x.float()
        x = self.conv_l(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x


In [138]:

ex_net = Min_Net()
ex_net

ex_in = torch.rand(1, 4, 8, 8)
ex_net(ex_in)

module = ex_net.conv_l
print(list(module.named_parameters()))
# print(list(module.named_modules()))

[('weight', Parameter containing:
tensor([[[[-0.1359,  0.1507,  0.0952],
          [ 0.0433,  0.1151,  0.0142],
          [ 0.0966,  0.1308,  0.1424]],

         [[-0.1265, -0.1062, -0.1450],
          [ 0.1567,  0.1410,  0.0960],
          [ 0.1469,  0.1657, -0.0208]],

         [[-0.0699, -0.1558, -0.0226],
          [ 0.0231, -0.0439, -0.1445],
          [-0.1067,  0.0490,  0.1484]],

         [[-0.0908, -0.0513, -0.1067],
          [-0.0910,  0.0658, -0.0123],
          [-0.0605, -0.0792,  0.0628]]],


        [[[-0.1022,  0.0258, -0.0934],
          [ 0.0703,  0.0788,  0.1502],
          [ 0.1577,  0.0669, -0.1007]],

         [[-0.1082, -0.0067,  0.0499],
          [ 0.0172,  0.0131,  0.0789],
          [ 0.1048,  0.0077,  0.0466]],

         [[-0.0791,  0.1320, -0.0260],
          [ 0.0056,  0.1209, -0.1407],
          [-0.1059, -0.1074, -0.0478]],

         [[ 0.1194, -0.0973, -0.1364],
          [ 0.1022,  0.1560,  0.0664],
          [-0.1407,  0.0970,  0.0751]]],


        [[

In [145]:
# prune.random_unstructured(module, name="weight", amount=0.25)
# prune.l1_unstructured(module, name="weight", amount=0.25)
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=1)

print(list(module.named_parameters()))


[('bias', Parameter containing:
tensor([-0.0241, -0.1004,  0.0716,  0.0496], requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.1359,  0.1507,  0.0952],
          [ 0.0433,  0.1151,  0.0142],
          [ 0.0966,  0.1308,  0.1424]],

         [[-0.1265, -0.1062, -0.1450],
          [ 0.1567,  0.1410,  0.0960],
          [ 0.1469,  0.1657, -0.0208]],

         [[-0.0699, -0.1558, -0.0226],
          [ 0.0231, -0.0439, -0.1445],
          [-0.1067,  0.0490,  0.1484]],

         [[-0.0908, -0.0513, -0.1067],
          [-0.0910,  0.0658, -0.0123],
          [-0.0605, -0.0792,  0.0628]]],


        [[[-0.1022,  0.0258, -0.0934],
          [ 0.0703,  0.0788,  0.1502],
          [ 0.1577,  0.0669, -0.1007]],

         [[-0.1082, -0.0067,  0.0499],
          [ 0.0172,  0.0131,  0.0789],
          [ 0.1048,  0.0077,  0.0466]],

         [[-0.0791,  0.1320, -0.0260],
          [ 0.0056,  0.1209, -0.1407],
          [-0.1059, -0.1074, -0.0478]],

         [[ 0.1194, -0.0973,

In [146]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]],

         [[0., 

In [141]:
pruned_mask = list(module.named_buffers())[0][1]
pruned_mask.size()

torch.Size([4, 4, 3, 3])

In [144]:
pruned_mask[1]

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.]]])

In [170]:
game = "breakout"
env = Environment(game)

proj_dir = os.path.abspath(".")
default_save_folder = os.path.join(proj_dir, "checkpoints", game)
file_name = os.path.join(default_save_folder, game + "_model")

state_shape = env.state_shape()

in_features = (state_shape[2], state_shape[0], state_shape[1])
in_channels = in_features[0]
num_actions = env.num_actions()

model = Conv_QNet(in_features, in_channels, num_actions)

for name, module in model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.ln_structured(module, name="weight", amount=0.1, n=2, dim=1)
        
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.1)
        


<generator object Module.parameters at 0x0000029FDEF73220>
<generator object Module.parameters at 0x0000029FDEF73220>
<generator object Module.parameters at 0x0000029FDEF73220>


In [163]:
for name, module in model.named_modules():
    
    print(isinstance(module, torch.nn.Sequential))

False
False
False
False
False
False
False
False


In [164]:
list(model.named_modules())

[('',
  Conv_QNet(
    (features): Sequential(
      (0): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (fc): Sequential(
      (0): Linear(in_features=1024, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=6, bias=True)
    )
  )),
 ('features',
  Sequential(
    (0): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )),
 ('features.0', Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1))),
 ('features.1', ReLU()),
 ('fc',
  Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=6, bias=True)
  )),
 ('fc.0', Linear(in_features=1024, out_features=128, bias=True)),
 ('fc.1', ReLU()),
 ('fc.2', Linear(in_features=128, out_features=6, bias=True))]

In [168]:
for name, module in model.named_modules():
    print(name, type(module))

 <class 'my_dqn.Conv_QNet'>
features <class 'torch.nn.modules.container.Sequential'>
features.0 <class 'torch.nn.modules.conv.Conv2d'>
features.1 <class 'torch.nn.modules.activation.ReLU'>
fc <class 'torch.nn.modules.container.Sequential'>
fc.0 <class 'torch.nn.modules.linear.Linear'>
fc.1 <class 'torch.nn.modules.activation.ReLU'>
fc.2 <class 'torch.nn.modules.linear.Linear'>
