In [1]:
import torch

from typing import Optional
import logging
from rigl_torch.utils.rigl_utils import *

In [21]:
# Layer under study: 3 input channels -> 32 output channels, 3x3 kernel,
# unit stride, "same" padding.
conv1 = torch.nn.Conv2d(3, 32, kernel_size=(3, 3), stride=1, padding="same")


In [11]:
# Smallest weight of the layer; presumably inspected to compare against the
# sqrt(k) value this notebook computes from in_channels and kernel_size.
conv1.weight.min()

tensor(-0.1912, grad_fn=<MinBackward1>)

In [6]:
# Mean over every weight in the layer (the recorded output is close to zero).
conv1.weight.mean()

tensor(0.0031, grad_fn=<MeanBackward0>)

In [22]:

import math

# Dense fan-in of one output neuron: input channels times kernel area.
dense_fan_in = conv1.in_channels * math.prod(conv1.kernel_size)
# k = 1 / fan_in; sqrt(k) is the quantity compared against the observed
# weight extremes (per the Conv2d docs it bounds the default uniform init
# when groups == 1 -- NOTE(review): confirm against the installed version).
k = 1 / dense_fan_in
math.sqrt(k)

0.19245008972987526

In [23]:
# Random 50%-sparse mask with the same shape as the conv weight: start from a
# flat vector of ones, zero its second half, then shuffle the positions.
n_weights = conv1.weight.numel()
start_of_zeros = n_weights // 2
mask = torch.ones(n_weights)
mask[start_of_zeros:] = 0
idx = torch.randperm(n_weights)  # unseeded: the mask differs on every run
mask = mask[idx].reshape(conv1.weight.shape)

In [24]:
# Zero the pruned weights of the layer in place. The update is wrapped in
# no_grad() because conv1.weight is a Parameter that requires grad, and
# autograd rejects in-place modification of such a leaf tensor while recording.
with torch.no_grad():
    conv1.weight *= mask

In [26]:
# Count the unmasked (value 1) entries for each slice along the first mask
# axis, i.e. one count per output neuron.
active_weights_per_neuron = []
for neuron_mask in mask:
    active_weights_per_neuron.append(neuron_mask.sum().item())

active_weights_per_neuron

[15.0,
 18.0,
 16.0,
 17.0,
 10.0,
 10.0,
 19.0,
 10.0,
 12.0,
 11.0,
 9.0,
 15.0,
 15.0,
 13.0,
 13.0,
 16.0,
 14.0,
 9.0,
 12.0,
 12.0,
 14.0,
 13.0,
 16.0,
 17.0,
 13.0,
 14.0,
 10.0,
 11.0,
 14.0,
 16.0,
 13.0,
 15.0]

In [27]:
# get_fan_in_tensor is not defined in this notebook; presumably it is provided
# by the `rigl_torch.utils.rigl_utils` star import -- TODO confirm.
# Judging by the recorded output, it yields one integer per output neuron that
# matches the active-weight counts in active_weights_per_neuron.
fan_in_tensor = get_fan_in_tensor(mask)
fan_in_tensor

tensor([15, 18, 16, 17, 10, 10, 19, 10, 12, 11,  9, 15, 15, 13, 13, 16, 14,  9,
        12, 12, 14, 13, 16, 17, 13, 14, 10, 11, 14, 16, 13, 15])

In [20]:
# Overwrite output neuron 1 of the layer, in place, with standard-normal draws.
# NOTE(review): this cell's execution count (20) is lower than that of the cell
# creating conv1 (21), so on a top-to-bottom re-run it will alter the layer.
with torch.no_grad():
    conv1.weight[1].normal_(mean=0, std=1)

In [32]:
# Sparse, per-neuron normal re-initialisation.
#
# Each output neuron i is redrawn from N(0, std_i^2) with
#     std_i = gain / sqrt(fan_in_i)
# where fan_in_i is that neuron's entry in fan_in_tensor (its number of active
# weights) rather than the dense fan-in. The mask is re-applied afterwards so
# pruned positions stay exactly zero.
gain = math.sqrt(2.0)

# Work on a detached copy: conv1.weight itself is left untouched and the
# scratch tensor is not tracked by autograd, so no no_grad() block is needed.
updated_tensor = conv1.weight.detach().clone()

neuron_at_zero = None
for i in range(len(mask)):
    fan_in = fan_in_tensor[i].item()
    if fan_in == 0:
        # A neuron with no active weights has no defined std (gain / sqrt(0)
        # would raise ZeroDivisionError); zero the whole row instead.
        updated_tensor[i].zero_()
    else:
        std = gain / math.sqrt(fan_in)
        # normal_ fills the row view in place; no re-assignment is required.
        updated_tensor[i].normal_(0, std)

    if i == 0:
        # Snapshot a *copy* of neuron 0. A plain view would alias the same
        # storage and make the check below trivially true.
        neuron_at_zero = updated_tensor[0].clone()
    else:
        # Sampling neuron i must not have disturbed neuron 0.
        assert (neuron_at_zero == updated_tensor[0]).all()

# Restore the sparsity pattern: pruned positions go back to zero.
updated_tensor *= mask
    
        

In [35]:
# NOTE: as written this raises IndexError -- `mask` is still a float tensor
# (it was built with torch.ones), and tensors used as indices must be long,
# byte or bool. Convert it first, e.g. mask.to(torch.bool).
conv1.weight[mask]

IndexError: tensors used as indices must be long, byte or bool tensors

In [49]:
# type() only reports the Python class (torch.Tensor); the element type that
# matters for indexing is available as mask.dtype.
type(mask)

torch.Tensor

In [48]:
# Rebind `mask` to a boolean tensor so it can be used for indexing/selection.
mask = mask.to(dtype=torch.bool)

In [50]:
# Gather the weights at positions where the (now boolean) mask is True; the
# recorded output is a flat tensor.
conv1.weight.masked_select(mask)

tensor([ 1.7190e-01, -4.7952e-02, -3.7306e-02, -1.1591e-02,  8.6550e-02,
        -1.5842e-01, -1.4070e-01,  2.6583e-02, -8.3651e-02, -3.8811e-02,
        -8.3743e-02, -1.0348e-01,  1.8720e-02,  1.0755e-02,  1.9022e-01,
        -7.7309e-02, -7.6543e-02,  1.1409e-01,  5.0463e-02, -4.4772e-02,
        -1.5374e-02, -1.7493e-02, -6.9071e-02, -1.8608e-01,  1.2023e-02,
        -1.5188e-01,  5.3804e-02,  1.6484e-02, -1.0767e-01, -8.7878e-02,
        -4.9170e-02,  1.0453e-01,  1.8108e-01,  4.0512e-02, -3.4362e-02,
        -4.7837e-02, -3.3729e-02, -1.8923e-01,  1.2815e-01,  9.9988e-02,
        -2.5742e-03,  1.7486e-01, -1.5750e-01, -1.5425e-01,  1.8690e-01,
         1.0344e-01, -4.9908e-02,  1.5715e-01,  9.6739e-02, -1.4277e-01,
        -1.5022e-01,  1.5513e-01,  1.5798e-01,  3.8821e-02,  7.1917e-02,
        -8.8293e-02,  1.3669e-01, -3.9527e-02, -1.3154e-01, -2.8228e-02,
        -1.2069e-01,  1.2208e-01, -1.8175e-01,  1.5874e-01,  9.6998e-02,
        -7.7406e-02,  4.2566e-02, -1.7303e-01, -6.3

In [None]:
# Unfinished draft cell (never executed): only the condition argument is given.
# TODO: supply the two tensors to choose between, or delete this cell.
torch.where(
    mask,
    
)

In [33]:
# Neuron 0 of the masked layer; pruned positions appear as (signed) zeros.
conv1.weight[0]

tensor([[[ 0.0000,  0.1719, -0.0000],
         [-0.0480, -0.0000,  0.0000],
         [-0.0000, -0.0373, -0.0000]],

        [[-0.0116, -0.0000, -0.0000],
         [ 0.0865, -0.1584, -0.1407],
         [ 0.0266, -0.0837, -0.0388]],

        [[-0.0837,  0.0000, -0.1035],
         [ 0.0187,  0.0000, -0.0000],
         [ 0.0108, -0.0000,  0.1902]]], grad_fn=<SelectBackward0>)

In [34]:
# Neuron 0 after the sparse re-initialisation; in the recorded output it has
# the same zero pattern as conv1.weight[0] but different non-zero values.
updated_tensor[0]

tensor([[[-0.0000, -0.1293,  0.0000],
         [-0.3681,  0.0000, -0.0000],
         [ 0.0000,  0.0626,  0.0000]],

        [[ 0.3123,  0.0000,  0.0000],
         [-0.5981,  0.0905, -0.2842],
         [ 0.2263,  0.4572,  0.0855]],

        [[ 0.4777, -0.0000,  0.3340],
         [ 0.6535, -0.0000,  0.0000],
         [ 0.2555, -0.0000,  0.3668]]], grad_fn=<SelectBackward0>)

In [101]:
# Largest and smallest weight across the whole layer.
conv1.weight.max(), conv1.weight.min()

(tensor(0.1921, grad_fn=<MaxBackward1>),
 tensor(-0.1923, grad_fn=<MinBackward1>))

In [97]:
# Largest and smallest weight of neuron 0 in the layer.
conv1.weight[0].max(), conv1.weight[0].min()

(tensor(0.1888, grad_fn=<MaxBackward1>),
 tensor(-0.1706, grad_fn=<MinBackward1>))

In [98]:
# Same extremes for the re-initialised neuron 0, for comparison with
# conv1.weight[0].
updated_tensor[0].max(), updated_tensor[0].min()

(tensor(0.4831, grad_fn=<MaxBackward1>),
 tensor(-0.6985, grad_fn=<MinBackward1>))

In [106]:
# Target standard deviation used for neuron 0 by the sparse re-initialisation:
# gain / sqrt(fan_in of neuron 0).
gain / math.sqrt(fan_in_tensor[0])

0.37796447300922725

In [108]:
# Set every weight of neuron 0 to zero, in place. Presumably emulates a neuron
# with no active weights -- TODO confirm intent.
updated_tensor[0] = 0

In [110]:
# Re-apply the mask in place: entries where the mask is zero/False become zero.
updated_tensor *= mask

In [111]:
# Display the whole tensor; in the recorded output neuron 0 is entirely zero.
updated_tensor

tensor([[[[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000,  0.1637,  0.6784],
          [ 0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000]],

         [[-0.2891, -0.5376,  0.0000],
          [-0.0258,  0.5696,  0.0085],
          [-0.2021,  0.2465, -0.0000]],

         [[ 0.0604,  0.3006, -0.0000],
          [-0.0381,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.3074]]],


        [[[-0.3855, -0.0000, -0.3334],
          [ 0.0000,  0.4683, -0.3986],
          [-0.0000,  0.2333,  0.0473]],

         [[-0.2454,  0.2809,  0.0000],
          [ 0.4172, -0.0000,  0.0092],
          [-1.0648,  0.0000,  0.0000]],

         [[-0.0492,  0.4232, -0.1303],
     

In [107]:
# math.sqrt(0) is 0.0, so `gain / math.sqrt(fan_in)` divides by zero
# (ZeroDivisionError) for any neuron whose fan-in is 0.
math.sqrt(0)

0.0

In [104]:
# Empirical standard deviation over all entries of neuron 0, for comparison
# with the target gain / sqrt(fan_in).
updated_tensor[0].std()

tensor(0.3611, grad_fn=<StdBackward0>)