In [1]:
import math
import hess
import matplotlib.pyplot as plt
import numpy as np
from hess.nets import SubnetConv, SubnetLinear
import hess.net_utils as net_utils
import torch
from torch import nn

In [2]:
class simple_net(nn.Module):
    """
    Small MLP
    """
    def __init__(self, in_dim, out_dim, k=16,
                 n_layers=5, kernel_size=3,
                activation=nn.ReLU()):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        module = nn.ModuleList()

        module.append(SubnetLinear(in_dim, k))
        module.append(activation)
        
        for ll in range(n_layers-1):
            module.append(SubnetLinear(k, k))
            module.append(activation)
        
        module.append(SubnetLinear(k, k))
        module.append(activation)
        module.append(SubnetLinear(k, out_dim))
        self.sequential = nn.Sequential(*module)        
            
    def forward(self,x):
        return self.sequential(x)


In [3]:
model = simple_net(2, 1, n_layers=10, k=32)

In [4]:
hess.net_utils.set_model_prune_rate(model, 0.5)

==> Setting prune rate of network to 0.5
==> Setting prune rate of sequential.0 to 0.5
==> Setting prune rate of sequential.2 to 0.5
==> Setting prune rate of sequential.4 to 0.5
==> Setting prune rate of sequential.6 to 0.5
==> Setting prune rate of sequential.8 to 0.5
==> Setting prune rate of sequential.10 to 0.5
==> Setting prune rate of sequential.12 to 0.5
==> Setting prune rate of sequential.14 to 0.5
==> Setting prune rate of sequential.16 to 0.5
==> Setting prune rate of sequential.18 to 0.5
==> Setting prune rate of sequential.20 to 0.5
==> Setting prune rate of sequential.22 to 0.5


In [5]:
hess.net_utils.freeze_model_weights(model)

=> Freezing model weights
==> No gradient to sequential.0.weight
==> No gradient to sequential.0.bias
==> No gradient to sequential.2.weight
==> No gradient to sequential.2.bias
==> No gradient to sequential.4.weight
==> No gradient to sequential.4.bias
==> No gradient to sequential.6.weight
==> No gradient to sequential.6.bias
==> No gradient to sequential.8.weight
==> No gradient to sequential.8.bias
==> No gradient to sequential.10.weight
==> No gradient to sequential.10.bias
==> No gradient to sequential.12.weight
==> No gradient to sequential.12.bias
==> No gradient to sequential.14.weight
==> No gradient to sequential.14.bias
==> No gradient to sequential.16.weight
==> No gradient to sequential.16.bias
==> No gradient to sequential.18.weight
==> No gradient to sequential.18.bias
==> No gradient to sequential.20.weight
==> No gradient to sequential.20.bias
==> No gradient to sequential.22.weight
==> No gradient to sequential.22.bias


In [6]:
def twospirals(n_points, noise=.5, random_state=920):
    """
     Returns the two spirals dataset.
    """
    n = np.sqrt(np.random.rand(n_points,1)) * 600 * (2*np.pi)/360
    d1x = -1.5*np.cos(n)*n + np.random.randn(n_points,1) * noise
    d1y =  1.5*np.sin(n)*n + np.random.randn(n_points,1) * noise
    return (np.vstack((np.hstack((d1x,d1y)),np.hstack((-d1x,-d1y)))),
            np.hstack((np.zeros(n_points),np.ones(n_points))))

In [7]:
X, Y = twospirals(500, noise=1.3)
train_x = torch.FloatTensor(X)
train_y = torch.FloatTensor(Y).unsqueeze(-1)

In [8]:
for lyr in model.modules():
    if hasattr(lyr, "scores"):
        print(lyr.scores)

Parameter containing:
tensor([[-0.6805, -0.5308],
        [ 0.4431,  0.0765],
        [ 0.3548,  0.4486],
        [-0.3093,  0.4176],
        [-0.6001, -0.1593],
        [ 0.4924, -0.3110],
        [ 0.2195,  0.4171],
        [-0.3360, -0.5974],
        [ 0.7046, -0.5868],
        [-0.6134, -0.4793],
        [ 0.1346,  0.1888],
        [ 0.6655, -0.2576],
        [-0.3166,  0.4170],
        [-0.6417,  0.5027],
        [-0.5255, -0.6265],
        [ 0.1312, -0.2367],
        [-0.2362,  0.1211],
        [-0.5833, -0.3117],
        [-0.5705,  0.4482],
        [-0.1635,  0.2870],
        [ 0.3672,  0.5534],
        [-0.5860,  0.6640],
        [ 0.3357, -0.5757],
        [-0.1606,  0.3436],
        [-0.2939,  0.2051],
        [ 0.2180,  0.3499],
        [-0.0643, -0.0999],
        [-0.3001,  0.5813],
        [-0.6385, -0.4662],
        [ 0.4752,  0.3502],
        [ 0.1441, -0.2650],
        [-0.0102, -0.0730]], requires_grad=True)
Parameter containing:
tensor([[ 0.0253,  0.1739, -0.0218,  ..

In [9]:
for lyr in model.modules():
    if hasattr(lyr, "weight"):
        print(lyr.weight)

Parameter containing:
tensor([[-0.1049, -0.5091],
        [ 0.0008, -0.2151],
        [-0.2647,  0.3022],
        [ 0.0137, -0.5334],
        [-0.1613,  0.0230],
        [ 0.5045, -0.5883],
        [-0.5227,  0.0291],
        [ 0.4694,  0.3930],
        [ 0.2369, -0.0193],
        [ 0.2917, -0.2766],
        [ 0.2325, -0.1015],
        [ 0.1618, -0.7062],
        [ 0.6663, -0.7034],
        [-0.4291, -0.2518],
        [-0.3533,  0.2434],
        [-0.5068, -0.2230],
        [ 0.1746, -0.5608],
        [-0.3467,  0.5507],
        [-0.5462, -0.0882],
        [-0.0271, -0.0289],
        [ 0.1256,  0.0813],
        [ 0.1422,  0.2504],
        [ 0.3903, -0.0893],
        [-0.2465,  0.6611],
        [-0.6960,  0.4050],
        [-0.1336, -0.6650],
        [ 0.3294,  0.2478],
        [-0.5209,  0.0027],
        [-0.0888,  0.1637],
        [-0.2543, -0.4986],
        [-0.4168, -0.3390],
        [ 0.6975, -0.0053]])
Parameter containing:
tensor([[-0.0084,  0.1076,  0.0093,  ...,  0.0231, -0.1419,

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = torch.nn.BCEWithLogitsLoss()


for step in range(1000):
    optimizer.zero_grad()
    outputs = model(train_x)

    loss=loss_func(outputs,train_y)
    print(loss)
    loss.backward()
    optimizer.step()

tensor(0.6939, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6937, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6935, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6933, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6933, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6932, grad_fn=<BinaryCrossEntropyWithLogitsBac

tensor(0.6598, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6598, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6598, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6596, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6601, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6599, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6602, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6600, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6597, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6597, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6596, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6596, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6599, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6597, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6597, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6600, grad_fn=<BinaryCrossEntropyWithLogitsBac

tensor(0.6588, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6588, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6587, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6587, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6587, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6588, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6587, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBac

tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6581, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBac

tensor(0.6580, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6592, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6591, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6581, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6590, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6581, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBac

tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6588, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6588, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6587, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6581, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6589, grad_fn=<BinaryCrossEntropyWithLogitsBac

tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6581, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6581, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6590, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBac

tensor(0.6588, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6587, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6583, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6582, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6587, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6589, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6585, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6586, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6584, grad_fn=<BinaryCrossEntropyWithLogitsBac

In [11]:
for lyr in model.modules():
    if hasattr(lyr, "scores"):
        print(lyr.scores)

Parameter containing:
tensor([[-1.1446e+01, -1.0757e+01],
        [ 4.4309e-01,  7.6500e-02],
        [ 1.0777e+01,  1.1239e+01],
        [-2.2667e-03,  1.0999e+01],
        [-1.1210e+01,  1.1734e-03],
        [ 9.0338e+00,  2.3746e+00],
        [ 1.0836e+01, -8.3393e-04],
        [-2.3756e+00, -9.7403e+00],
        [ 1.0862e+01, -3.4907e+00],
        [ 3.4701e-03, -1.0821e+01],
        [ 9.4683e-04,  1.0747e+01],
        [ 1.0667e+01,  3.2435e-01],
        [-3.2528e-01,  2.9647e-01],
        [-1.1314e+01,  1.1160e+01],
        [-1.1401e+01,  2.6962e-03],
        [ 1.3121e-01, -2.3672e-01],
        [ 8.1015e-04,  1.0981e+01],
        [-1.1461e+01,  6.0730e-04],
        [-1.1276e+01,  1.1147e+01],
        [-1.1113e+01,  1.1197e+01],
        [-1.0732e-03,  1.1359e+01],
        [-6.9136e-04,  1.2370e+01],
        [ 1.0586e+01,  1.7236e-02],
        [-1.6056e-01,  3.4361e-01],
        [-2.9393e-01,  2.0509e-01],
        [ 2.1801e-01,  3.4993e-01],
        [-6.4326e-02, -9.9929e-02],
      

In [12]:
for lyr in model.modules():
    if hasattr(lyr, "weight"):
        print(lyr.weight)

Parameter containing:
tensor([[-0.1049, -0.5091],
        [ 0.0008, -0.2151],
        [-0.2647,  0.3022],
        [ 0.0137, -0.5334],
        [-0.1613,  0.0230],
        [ 0.5045, -0.5883],
        [-0.5227,  0.0291],
        [ 0.4694,  0.3930],
        [ 0.2369, -0.0193],
        [ 0.2917, -0.2766],
        [ 0.2325, -0.1015],
        [ 0.1618, -0.7062],
        [ 0.6663, -0.7034],
        [-0.4291, -0.2518],
        [-0.3533,  0.2434],
        [-0.5068, -0.2230],
        [ 0.1746, -0.5608],
        [-0.3467,  0.5507],
        [-0.5462, -0.0882],
        [-0.0271, -0.0289],
        [ 0.1256,  0.0813],
        [ 0.1422,  0.2504],
        [ 0.3903, -0.0893],
        [-0.2465,  0.6611],
        [-0.6960,  0.4050],
        [-0.1336, -0.6650],
        [ 0.3294,  0.2478],
        [-0.5209,  0.0027],
        [-0.0888,  0.1637],
        [-0.2543, -0.4986],
        [-0.4168, -0.3390],
        [ 0.6975, -0.0053]])
Parameter containing:
tensor([[-0.0084,  0.1076,  0.0093,  ...,  0.0231, -0.1419,

In [13]:
mask = net_utils.get_mask_from_subnet(model)

In [14]:
mask

[tensor([[1., 1.],
         [0., 0.],
         [1., 1.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 1.],
         [1., 1.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [0., 0.],
         [1., 1.],
         [1., 0.],
         [0., 0.],
         [0., 1.],
         [1., 0.],
         [1., 1.],
         [1., 1.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [1., 0.],
         [1., 1.],
         [1., 1.],
         [0., 0.],
         [0., 0.]], grad_fn=<GetSubnetBackward>),
 tensor([[1., 1., 1.,  ..., 0., 1., 1.],
         [0., 0., 1.,  ..., 1., 1., 0.],
         [1., 0., 0.,  ..., 0., 0., 1.],
         ...,
         [0., 1., 0.,  ..., 1., 1., 0.],
         [1., 0., 0.,  ..., 1., 1., 0.],
         [1., 0., 1.,  ..., 1., 0., 0.]], grad_fn=<GetSubnetBackward>),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 1

In [16]:
weights = net_utils.get_weights_from_subnet(model)

In [17]:
weights

[Parameter containing:
 tensor([[-0.1049, -0.5091],
         [ 0.0008, -0.2151],
         [-0.2647,  0.3022],
         [ 0.0137, -0.5334],
         [-0.1613,  0.0230],
         [ 0.5045, -0.5883],
         [-0.5227,  0.0291],
         [ 0.4694,  0.3930],
         [ 0.2369, -0.0193],
         [ 0.2917, -0.2766],
         [ 0.2325, -0.1015],
         [ 0.1618, -0.7062],
         [ 0.6663, -0.7034],
         [-0.4291, -0.2518],
         [-0.3533,  0.2434],
         [-0.5068, -0.2230],
         [ 0.1746, -0.5608],
         [-0.3467,  0.5507],
         [-0.5462, -0.0882],
         [-0.0271, -0.0289],
         [ 0.1256,  0.0813],
         [ 0.1422,  0.2504],
         [ 0.3903, -0.0893],
         [-0.2465,  0.6611],
         [-0.6960,  0.4050],
         [-0.1336, -0.6650],
         [ 0.3294,  0.2478],
         [-0.5209,  0.0027],
         [-0.0888,  0.1637],
         [-0.2543, -0.4986],
         [-0.4168, -0.3390],
         [ 0.6975, -0.0053]]), Parameter containing:
 tensor([[-0.0084,  0.107