In [1]:
import torch

In [2]:
from hess.losses import fisher_trace
from hess.utils import get_hessian

%pdb

Automatic pdb calling has been turned ON


In [3]:
# Seed the global RNG so the randomly initialised weights (and everything
# derived from them below) are the same on every Restart & Run All.
torch.manual_seed(0)

# Small MLP: 3 input features -> 30 hidden units (ReLU) -> 10 output logits.
model = torch.nn.Sequential(
            torch.nn.Linear(3, 30),
            torch.nn.ReLU(),
            torch.nn.Linear(30, 10))

In [4]:
# Synthetic batch: 50 samples with 3 features each and integer class labels
# in [0, 10). A dedicated, seeded generator keeps this cell reproducible no
# matter how often it is re-run or what has consumed the global RNG before.
data_rng = torch.Generator().manual_seed(0)
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells refer to it.
input = torch.randn(50, 3, generator=data_rng)
target = torch.randint(10, (50,), generator=data_rng)

In [5]:
# Sanity check: cross-entropy of the untrained model on the synthetic batch.
torch.nn.functional.cross_entropy(input=model(input), target=target)

tensor(2.4054, grad_fn=<NllLossBackward>)

In [6]:
# Reference Hessian of the cross-entropy loss on the batch, used below as
# ground truth for the diagonal estimate.
# NOTE(review): presumably a square matrix over all model parameters (its
# `.diag()` is compared against a per-parameter estimate) -- confirm in hess.utils.
true_hessian = get_hessian(
    model=model,
    loss=torch.nn.functional.cross_entropy,
    train_x=input,
    train_y=target,
)

In [21]:
# Total number of scalar parameters in the model.
npars = sum(map(torch.numel, model.parameters()))
# One trainable entry per model parameter, stored as an (npars, 1) column
# and initialised to zero; this is the variable the optimizer updates.
diag_sigma = torch.zeros((npars, 1), requires_grad=True)

In [22]:
# SGD with momentum over `diag_sigma` only; the model weights are not optimised.
optimizer = torch.optim.SGD(params=[diag_sigma], lr=0.1, momentum=0.9)

In [23]:
# `beta` is passed to the objective and reused afterwards to turn
# `diag_sigma` into the Hessian-diagonal estimate.
beta = 1e-3

# 5000 SGD steps on the fisher_trace objective with respect to `diag_sigma`.
# NOTE(review): `samples=1` presumably means one Monte Carlo draw per step --
# confirm against hess.losses.fisher_trace.
for _step in range(5000):
    optimizer.zero_grad()
    trace_loss = fisher_trace(input, target, diag_sigma, model, beta=beta, samples=1)
    trace_loss.backward()
    optimizer.step()

In [24]:
# Turn the fitted `diag_sigma` into a flat per-parameter estimate:
# beta / softplus(1e-5 + diag_sigma).
# NOTE(review): the 1e-5 offset sits inside softplus here -- confirm this
# matches the parameterisation used in fisher_trace.
# Computed under no_grad: this is an evaluation-only quantity, so it must not
# stay attached to diag_sigma's autograd graph (otherwise every diagnostic
# below carries a grad_fn and keeps the graph alive).
with torch.no_grad():
    hestimate = beta / torch.nn.functional.softplus(1e-5 + diag_sigma).view(-1)

In [25]:
# L2 norm of the gap between the true Hessian diagonal and the estimate.
(torch.diagonal(true_hessian) - hestimate).norm()

tensor(17.0659, grad_fn=<NormBackward0>)

In [26]:
# Median of the per-parameter error relative to the true diagonal entry.
# NOTE(review): entries of the true diagonal that are exactly zero would make
# this ratio inf/nan -- confirm whether that can occur for this model.
torch.median((true_hessian.diag() - hestimate) / true_hessian.diag())

tensor(-64.5636, grad_fn=<MedianBackward0>)

In [27]:
# Per-parameter signed error: true Hessian diagonal minus the estimate.
torch.sub(true_hessian.diag(), hestimate)

tensor([-2.0282e-02, -1.9565e-01, -9.2431e-02, -2.0047e-01, -5.8985e-01,
        -3.8746e-01, -1.0891e-01, -2.7125e-01, -8.3081e-02, -4.0694e-01,
        -3.9429e-01, -3.8562e-01, -1.2257e-01, -2.8800e-01, -3.0970e-01,
        -1.8582e-01, -4.0348e-01, -2.1122e-01, -6.0114e-01, -7.4624e-01,
        -7.0665e-01, -2.3865e-01, -3.1685e-01, -2.1733e-01, -1.2838e-01,
        -2.9411e-01, -3.1360e-01, -2.7461e-01,  5.1521e-03, -1.3402e-01,
        -6.5897e-02, -1.3666e-01, -5.3392e-02, -4.0137e-01, -2.2367e-01,
        -3.2909e-01, -1.3384e-01, -3.5503e-01, -3.9724e-01, -2.2636e-01,
         6.0273e-03, -8.6071e-02, -2.6249e-02,  1.1729e-03, -1.9427e-01,
        -3.1402e-01, -2.2426e-01, -3.3761e-01, -3.5432e-01, -4.8201e-01,
        -4.7688e-01, -9.9228e-02, -1.0215e-01,  3.2258e-03, -3.2209e-01,
        -5.2675e-02, -1.5388e-01, -8.0212e-02, -1.2655e-02, -3.3897e-01,
        -1.9524e-01, -3.1715e-01, -3.8932e-01, -3.1381e-01, -1.1930e-01,
        -3.9631e-01, -4.5190e-02, -2.0721e-01, -1.6