In [1]:
import math
import hess
import matplotlib.pyplot as plt
import numpy as np
from hess.nets import MaskedNetLinear, SubNetLinear

import hess.net_utils as net_utils
import torch
from torch import nn

In [2]:
torch.random.manual_seed(10)

<torch._C.Generator at 0x122f30750>

In [3]:
n_hidden = 5
width = 1024

In [4]:
model = SubNetLinear(2, 1, n_layers=n_hidden, k=width)

In [5]:
hess.net_utils.set_model_prune_rate(model, 0.6)

==> Setting prune rate of network to 0.6
==> Setting prune rate of sequential.0 to 0.6
==> Setting prune rate of sequential.2 to 0.6
==> Setting prune rate of sequential.4 to 0.6
==> Setting prune rate of sequential.6 to 0.6
==> Setting prune rate of sequential.8 to 0.6
==> Setting prune rate of sequential.10 to 0.6
==> Setting prune rate of sequential.12 to 0.6


In [6]:
hess.net_utils.freeze_model_weights(model)

=> Freezing model weights
==> No gradient to sequential.0.weight
==> No gradient to sequential.0.bias
==> No gradient to sequential.2.weight
==> No gradient to sequential.2.bias
==> No gradient to sequential.4.weight
==> No gradient to sequential.4.bias
==> No gradient to sequential.6.weight
==> No gradient to sequential.6.bias
==> No gradient to sequential.8.weight
==> No gradient to sequential.8.bias
==> No gradient to sequential.10.weight
==> No gradient to sequential.10.bias
==> No gradient to sequential.12.weight
==> No gradient to sequential.12.bias


In [7]:
def twospirals(n_points, noise=.5, random_state=920):
    """
     Returns the two spirals dataset.
    """
    n = np.sqrt(np.random.rand(n_points,1)) * 600 * (2*np.pi)/360
    d1x = -1.5*np.cos(n)*n + np.random.randn(n_points,1) * noise
    d1y =  1.5*np.sin(n)*n + np.random.randn(n_points,1) * noise
    return (np.vstack((np.hstack((d1x,d1y)),np.hstack((-d1x,-d1y)))),
            np.hstack((np.zeros(n_points),np.ones(n_points))))

In [8]:
X, Y = twospirals(500, noise=1.3)
train_x = torch.FloatTensor(X)
train_y = torch.FloatTensor(Y).unsqueeze(-1)

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = torch.nn.BCEWithLogitsLoss()
losses = []

for step in range(2000):
    optimizer.zero_grad()
    outputs = model(train_x)

    loss=loss_func(outputs,train_y)
    print(loss)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()

tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6866, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6599, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6137, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6405, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6087, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6021, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6104, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6109, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6024, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.5925, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.5950, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.5983, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.5906, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.5882, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.5899, grad_fn=<BinaryCrossEntropyWithLogitsBac

KeyboardInterrupt: 

In [None]:
# losses = [l() for l in losses]

In [None]:
plt.figure(figsize=(9, 5))
plt.plot(losses[:500])
plt.ylim(0., 0.75)

In [None]:
buffer = 0.3
h = 0.1
x_min, x_max = train_x[:, 0].min() - buffer, train_x[:, 0].max() + buffer
y_min, y_max = train_x[:, 1].min() - buffer, train_x[:, 1].max() + buffer

xx,yy=np.meshgrid(np.arange(x_min.cpu(), x_max.cpu(), h), 
                  np.arange(y_min.cpu(), y_max.cpu(), h))
in_grid = torch.FloatTensor([xx.ravel(), yy.ravel()]).t()

In [None]:
pred = torch.sigmoid(model(in_grid).squeeze().cpu()).reshape(xx.shape)

In [None]:
plt.figure(figsize=(15, 10))
plt.contourf(xx, yy, pred.detach(), alpha=0.5)
# plt.title("Temp Scaled", fontsize=24)
plt.title("Classifier", fontsize=24)
plt.colorbar()
plt.scatter(train_x[:, 0].cpu(), train_x[:, 1].cpu(), c=train_y[:, 0].cpu(), cmap=plt.cm.binary)
plt.savefig("./two-spiral-classifier.pdf", bbox_inches="tight")

In [None]:
pred = torch.sigmoid(model(in_grid).squeeze().cpu()).reshape(xx.shape)