In [1]:
import math
import torch
import numpy as np
import pickle
from torch import nn

import hess
import hess.net_utils as net_utils
import hess.utils as utils
from hess.nets import MaskedNetLinear, SubNetLinear
# from hess.nets import MaskedLayerLinear, SubLayerLinear

In [2]:
def twospirals(n_points, noise=.5, random_state=920):
    """
     Returns the two spirals dataset.
    """
    n = np.sqrt(np.random.rand(n_points,1)) * 600 * (2*np.pi)/360
    d1x = -1.5*np.cos(n)*n + np.random.randn(n_points,1) * noise
    d1y =  1.5*np.sin(n)*n + np.random.randn(n_points,1) * noise
    return (np.vstack((np.hstack((d1x,d1y)),np.hstack((-d1x,-d1y)))),
            np.hstack((np.zeros(n_points),np.ones(n_points))))

In [3]:
X, Y = twospirals(500, noise=1.3)
train_x = torch.FloatTensor(X)
train_y = torch.FloatTensor(Y).unsqueeze(-1)

###################################
## Set up nets and match weights ##
###################################

n_hidden = 5
width = 15

subnet_model = SubNetLinear(in_dim=2, out_dim=1, n_layers=n_hidden, k=width)
masked_model = MaskedNetLinear(in_dim=2, out_dim=1, n_layers=n_hidden, k=width)

hess.net_utils.set_model_prune_rate(subnet_model, 0.5)
hess.net_utils.freeze_model_weights(subnet_model)

weights = net_utils.get_weights_from_subnet(subnet_model)

net_utils.apply_weights(masked_model, weights)
mask = net_utils.get_mask_from_subnet(subnet_model)
net_utils.apply_mask(masked_model, mask)
mask = utils.flatten(mask)
print(mask)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(3)
    train_x, train_y = train_x.cuda(), train_y.cuda()
    subnet_model = subnet_model.cuda()
    masked_model = masked_model.cuda()

==> Setting prune rate of network to 0.5
==> Setting prune rate of sequential.0 to 0.5
==> Setting prune rate of sequential.2 to 0.5
==> Setting prune rate of sequential.4 to 0.5
==> Setting prune rate of sequential.6 to 0.5
==> Setting prune rate of sequential.8 to 0.5
==> Setting prune rate of sequential.10 to 0.5
==> Setting prune rate of sequential.12 to 0.5
=> Freezing model weights
==> No gradient to sequential.0.weight
==> No gradient to sequential.0.bias
==> No gradient to sequential.2.weight
==> No gradient to sequential.2.bias
==> No gradient to sequential.4.weight
==> No gradient to sequential.4.bias
==> No gradient to sequential.6.weight
==> No gradient to sequential.6.bias
==> No gradient to sequential.8.weight
==> No gradient to sequential.8.bias
==> No gradient to sequential.10.weight
==> No gradient to sequential.10.bias
==> No gradient to sequential.12.weight
==> No gradient to sequential.12.bias
==> Applied Weights
==> Applied Mask
tensor([1., 1., 1.,  ..., 1., 0., 1.

In [4]:
print(sum(p.numel() for p in masked_model.parameters()))

1261


In [5]:
######################
## Train the Subnet ##
######################

optimizer = torch.optim.Adam(subnet_model.parameters(), lr=0.01)
loss_func = torch.nn.BCEWithLogitsLoss()
eigs_every = 10
n_eigs = 100
eigs_out = []

for step in range(1000):
    optimizer.zero_grad()
    outputs = subnet_model(train_x)

    loss=loss_func(outputs,train_y)
    print(loss)
    loss.backward()
    optimizer.step()

    if step % eigs_every == 0:
        mask = net_utils.get_mask_from_subnet(subnet_model)
        net_utils.apply_mask(masked_model, mask)
        mask = utils.flatten(mask)
        print("mask shape = ", mask.numel())

        eigs = utils.get_hessian_eigs(loss_func, masked_model, mask=mask,
                                      n_eigs=n_eigs, train_x=train_x,
                                      train_y=train_y)

        eigs_out.append(eigs)

tensor(0.6934, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
==> Applied Mask
mask shape =  1170
padded rhs shape =  torch.Size([1261, 1])
mask shape =  torch.Size([1170])


IndexError: The shape of the mask [1170] at index 0does not match the shape of the indexed tensor [1261, 1] at index 0

In [None]:
fpath = "./saved-subnet-hessian/"
fname = "subnet_eigs.pkl"

with open(fpath + fname, 'wb') as f:
    pickle.dump(eigs_out, f)

fname = "subnet_model.pt"
torch.save(subnet_model.state_dict(), fpath + fname)

fname = "masked_model.pt"
torch.save(masked_model.state_dict(), fpath + fname)