In [1]:
from typing import Any
from sklearn.datasets import load_diabetes
import torch

ds: Any = load_diabetes()
X = torch.from_numpy(ds['data']).float()
y = torch.from_numpy(ds['target']).float()

In [2]:
from sngp_torch import RandomFeatureGP
from torch.optim import LBFGS
import torch
from torch.nn import functional as F

gp = RandomFeatureGP(X.shape[1], 1, 'mse', num_rff=128)
opt = LBFGS(gp.parameters(), line_search_fn='strong_wolfe')

def closure():
    loss = F.mse_loss(gp(X).squeeze(), y)
    loss.backward()
    return loss

opt.step(closure)
opt.zero_grad()

with gp.record_covariance():
    gp(X)

gp.posterior(X)

torch.return_types.linalg_eigh(
eigenvalues=tensor([0.3373, 0.9798, 0.9928, 0.9949, 0.9960, 0.9971, 0.9974, 0.9981, 0.9984,
        0.9997, 0.9999, 0.9999, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.00

_LinAlgError: torch.linalg_cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 56 is not positive-definite).

In [None]:
X.shape

torch.Size([442, 10])

In [10]:
# from functorch import hessian
from torch.nn.functional import binary_cross_entropy_with_logits
import torch

features = torch.randn(1000, 256).cuda()
weights = torch.randn(256).cuda()
labels = torch.empty(1000).bernoulli(0.5).cuda()

def loss_fn(w):
    return binary_cross_entropy_with_logits(torch.squeeze(features @ w[..., None]), labels)

# loss = binary_cross_entropy_with_logits(logits, labels)


In [6]:
import torch.autograd.functional as F

%timeit F.hessian(loss_fn, weights, vectorize=True)

474 µs ± 7.31 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
F.hessian(loss_fn, weights, vectorize=True)

tensor([[ 2.5585e-02, -1.0187e-04, -3.9561e-04,  ...,  4.9680e-04,
         -5.4930e-06, -1.9473e-04],
        [-1.0187e-04,  2.7010e-02, -3.9171e-03,  ...,  3.6581e-03,
          8.8218e-04,  1.0315e-03],
        [-3.9561e-04, -3.9171e-03,  2.7941e-02,  ..., -4.0300e-04,
          1.8802e-03,  1.3838e-03],
        ...,
        [ 4.9680e-04,  3.6581e-03, -4.0300e-04,  ...,  2.7453e-02,
          7.9008e-04,  2.1069e-03],
        [-5.4930e-06,  8.8218e-04,  1.8802e-03,  ...,  7.9008e-04,
          2.8246e-02, -3.9255e-04],
        [-1.9473e-04,  1.0315e-03,  1.3838e-03,  ...,  2.1069e-03,
         -3.9255e-04,  2.7641e-02]], device='cuda:0')

In [38]:
foo = torch.rand(512, 512).cuda()
foo = foo @ foo.t()
%timeit torch.linalg.cholesky_ex(foo, upper=False)

515 µs ± 748 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [12]:
def fast_hessian(w):
    logits = features @ w[..., None]
    probs = logits.sigmoid()
    x = torch.sqrt(probs * (1 - probs)) * features
    return x.T @ x

fast_hessian(weights)

tensor([[ 2.5585e+01, -1.0187e-01, -3.9561e-01,  ...,  4.9680e-01,
         -5.4931e-03, -1.9473e-01],
        [-1.0187e-01,  2.7010e+01, -3.9171e+00,  ...,  3.6581e+00,
          8.8218e-01,  1.0315e+00],
        [-3.9561e-01, -3.9171e+00,  2.7941e+01,  ..., -4.0300e-01,
          1.8802e+00,  1.3838e+00],
        ...,
        [ 4.9680e-01,  3.6581e+00, -4.0300e-01,  ...,  2.7453e+01,
          7.9008e-01,  2.1069e+00],
        [-5.4931e-03,  8.8218e-01,  1.8802e+00,  ...,  7.9008e-01,
          2.8246e+01, -3.9255e-01],
        [-1.9473e-01,  1.0315e+00,  1.3838e+00,  ...,  2.1069e+00,
         -3.9255e-01,  2.7641e+01]], device='cuda:0')