In [12]:
import torch
import hess
import math
import numpy as np
from hess.nets import Transformer
from hess.eigs import hessian_eigenpairs

In [13]:
# Toy 1-D regression data: nx evenly spaced inputs on [0, 10], stored as a
# column of shape (nx, 1), with targets sin(pi * x) + sin(x).
nx = 500
train_x = torch.linspace(0, 10, steps=nx)[:, None]
train_y = (train_x * math.pi).sin() + train_x.sin()

In [14]:
# Build the model wrapper from hess.nets around the toy data, using Tanh
# activations. n_hidden / hidden_size presumably set the number and width of
# the hidden layers -- confirm against hess.nets.Transformer.
optimus = Transformer(
    train_x,
    train_y,
    n_hidden=2,
    hidden_size=10,
    activation=torch.nn.Tanh(),
)


In [19]:
# Move the model and the training data to the GPU when CUDA is available;
# otherwise everything stays on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    # The original run used GPU 4. torch.cuda.set_device raises on machines
    # that do not have that many devices, so fall back to GPU 0 there.
    gpu_index = 4 if torch.cuda.device_count() > 4 else 0
    torch.cuda.set_device(gpu_index)
    optimus = optimus.cuda()
    train_x, train_y = train_x.cuda(), train_y.cuda()

# Single source of truth for where tensors live; reused as map_location when
# loading the saved parameters and Hessian.
device = train_x.device

In [20]:
# Display the device the training data lives on (set in the CUDA setup cell).
device

device(type='cuda', index=4)

In [21]:
# Load the saved parameters for the toy regression model, remapping any stored
# tensors onto `device` so the checkpoint also loads on a machine whose
# devices differ from the one that saved it.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
trained_pars = torch.load(
    "../hess/saved-models/toy_regression.pt",
    map_location=device,
)

In [32]:
# Copy the loaded parameters into the underlying network; the displayed return
# value reports whether the checkpoint keys matched the network's keys.
optimus.net.load_state_dict(trained_pars)

<All keys matched successfully>

In [33]:
# Load the saved Hessian for the toy problem, remapped onto `device`.
# NOTE: torch.load unpickles the file -- only load files you trust.
hessian = torch.load(
    "../hess/saved-models/toy_hessian.pt",
    map_location=device,
)

In [34]:
# Dense reference eigendecomposition of the saved Hessian.
#
# A Hessian of a smooth scalar loss is symmetric, so use the symmetric solver
# (np.linalg.eigh) instead of the general one (np.linalg.eig). eigh returns
# real eigenvalues and orthonormal eigenvectors, whereas eig can return complex
# values for a matrix that is only symmetric up to rounding, which previously
# had to be discarded with `.real`.
hessian_np = hessian.detach().cpu().numpy()
# Remove any rounding-level asymmetry so the result does not depend on which
# triangle the solver reads. Assumes `hessian` is a square 2-D tensor.
hessian_np = 0.5 * (hessian_np + hessian_np.T)
e_val, e_vec = np.linalg.eigh(hessian_np)

# eigh returns eigenvalues in ascending order; reorder to descending and keep
# the eigenvector columns aligned with their eigenvalues.
idx = e_val.argsort()[::-1]
e_val = torch.FloatTensor(e_val[idx])
e_vec = torch.FloatTensor(e_vec[:, idx])

## Try the Lanczos approach

In [35]:
%pdb

Automatic pdb calling has been turned ON


In [52]:
# Ask hessian_eigenpairs for n_eigs=200 eigenpairs of optimus.net under an MSE
# criterion on the training data, to compare against the dense decomposition
# (e_val, e_vec).
test_vals, test_vecs = hessian_eigenpairs(
    optimus.net,
    n_eigs=200,
    criterion=torch.nn.MSELoss(),
    inputs=train_x,
    targets=train_y,
)

In [56]:
# Show the eigenvalues returned by hessian_eigenpairs, largest first, so they
# line up with the descending order of the dense reference e_val.
sorted_test_vals, _ = test_vals.sort(descending=True)
sorted_test_vals

tensor([1.7385e+04, 1.4113e+03, 1.1012e+02, 8.8872e+01, 6.8777e+01, 3.8682e+01,
        1.7675e+01, 8.9182e+00, 5.0146e+00, 2.3566e+00, 1.8266e+00, 1.2695e+00,
        1.2015e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+

In [51]:
# Display the reference eigenvalues from the dense decomposition (sorted in
# descending order) for comparison with the hessian_eigenpairs output.
e_val

tensor([ 1.7385e+04,  1.4113e+03,  1.1012e+02,  8.8872e+01,  6.8777e+01,
         3.8682e+01,  1.7675e+01,  8.9182e+00,  5.0146e+00,  2.3566e+00,
         1.8266e+00,  1.2695e+00,  1.2015e+00,  8.8116e-01,  5.8618e-01,
         4.6501e-01,  3.4789e-01,  2.4220e-01,  2.2982e-01,  1.4210e-01,
         1.3389e-01,  1.1880e-01,  1.1700e-01,  1.0562e-01,  7.8716e-02,
         5.1237e-02,  4.1035e-02,  3.5055e-02,  3.0159e-02,  2.1911e-02,
         1.9493e-02,  1.7075e-02,  1.5414e-02,  1.2765e-02,  9.7085e-03,
         9.5526e-03,  8.4885e-03,  8.1917e-03,  6.9872e-03,  6.0436e-03,
         5.7240e-03,  5.3047e-03,  4.9998e-03,  4.1278e-03,  3.9492e-03,
         3.2965e-03,  3.2525e-03,  2.7953e-03,  2.4126e-03,  2.0824e-03,
         1.7235e-03,  1.5245e-03,  1.3423e-03,  1.2894e-03,  1.1279e-03,
         1.0919e-03,  9.3152e-04,  6.7363e-04,  6.5668e-04,  5.4868e-04,
         4.4450e-04,  4.0606e-04,  3.8114e-04,  2.4106e-04,  1.9018e-04,
         1.6914e-04,  1.3781e-04,  1.2608e-04,  1.0