In [2]:
import os
os.chdir("/home/debiansigma/repos/posteriors/tests/laplace")
os.chdir("..")

from functools import partial
import torch
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
from torch.func import functional_call, hessian
from optree import tree_map
from optree.integration.torch import tree_ravel

from posteriors import tree_size, diag_normal_log_prob
from posteriors.laplace import dense_hessian

from scenarios import TestModel


def normal_log_likelihood(y, y_pred):
    return (
        Normal(y_pred, 1, validate_args=False).log_prob(y).sum(dim=-1)
    )  # validate args introduces control flows not yet supported in torch.func.vmap


def log_posterior_n(params, batch, model, n_data):
    y_pred = functional_call(model, params, batch[0])
    return diag_normal_log_prob(params, mean=0.0, sd_diag=1.0) + normal_log_likelihood(
        batch[1], y_pred
    ) * n_data, torch.tensor([])


def test_dense_fisher_vmap():
    torch.manual_seed(42)
    model = TestModel()

    xs = torch.randn(100, 10)
    ys = model(xs)

    batch_size=2

    dataloader = DataLoader(
        TensorDataset(xs, ys),
        batch_size=batch_size,
    )

    def log_posterior(p, b):
        return log_posterior_n(p, b, model, len(xs))[0].mean(), torch.tensor([])

    params = dict(model.named_parameters())

    # Test inplace = False
    transform = dense_hessian.build(log_posterior)
    laplace_state = transform.init(params)
    laplace_state_prec_init = laplace_state.prec
    for batch in dataloader:
        laplace_state = transform.update(laplace_state, batch, rescale=batch_size/xs.size()[0], inplace=False)

    flat_params, params_unravel = tree_ravel(params)
    
    num_params = tree_size(params)
    expected = torch.zeros((num_params, num_params))
    for x, y in zip(xs, ys):
        with torch.no_grad():
            neg_log_p = lambda p: - log_posterior(params_unravel(p), (x,y))[0]
            hess = hessian(neg_log_p)(flat_params)
        expected += hess / xs.size()[0]

    assert torch.allclose(expected, laplace_state.prec, atol=1e-5)
    assert not torch.allclose(laplace_state.prec, laplace_state_prec_init)
    print(expected)
    print(laplace_state.prec)

    # Also check full batch
    laplace_state_fb = transform.init(params)
    laplace_state_fb = transform.update(laplace_state_fb, (xs, ys))

    assert torch.allclose(expected, laplace_state_fb.prec, atol=1e-5)

    # Test inplace = True
    transform = dense_hessian.build(log_posterior)
    laplace_state = transform.init(params)
    laplace_state_prec_diag_init = laplace_state.prec
    for batch in dataloader:
        laplace_state = transform.update(laplace_state, batch, rescale=batch_size/xs.size()[0], inplace=True)

    assert torch.allclose(expected, laplace_state.prec, atol=1e-5)
    assert torch.allclose(laplace_state.prec, laplace_state_prec_diag_init, atol=1e-5)

    # Test sampling
    num_samples = 10000
    laplace_state.prec.data += 0.1 * torch.eye(
        num_params
    )  # regularize to ensure PSD and reduce variance

    mean_copy = tree_map(lambda x: x.clone(), laplace_state.params)
    sd_flat = torch.diag(torch.linalg.inv(laplace_state.prec)).sqrt()

    samples = dense_hessian.sample(laplace_state, (num_samples,))

    samples_mean = tree_map(lambda x: x.mean(dim=0), samples)
    samples_sd = tree_map(lambda x: x.std(dim=0), samples)
    samples_sd_flat = tree_ravel(samples_sd)[0]

    for key in samples_mean:
        assert samples[key].shape[0] == num_samples
        assert samples[key].shape[1:] == samples_mean[key].shape
        assert torch.allclose(samples_mean[key], laplace_state.params[key], atol=1e-1)
        assert torch.allclose(mean_copy[key], laplace_state.params[key])

    assert torch.allclose(sd_flat, samples_sd_flat, atol=1e-1)


test_dense_fisher_vmap()

tensor([[101.0000,   5.3003,   3.6100, -15.0475,   5.2662,  18.4212,   2.4490,
         -12.4884,   4.4733,  -1.8450,   4.2531],
        [  5.3003, 108.0566,  -5.4149,  -5.1864,  -5.6259,  -2.9519,  -1.0276,
           2.4567,  22.9299,  -5.7365,   5.1043],
        [  3.6100,  -5.4149, 117.0197,  -7.5019,  35.3815,  12.1161,   4.0535,
           4.4423,   6.7399,  -4.9319, -19.8263],
        [-15.0475,  -5.1864,  -7.5019,  70.4873,  -0.6540,  -2.1473,  -3.5546,
         -24.9956,  -8.7571,  -2.9979,  -0.4580],
        [  5.2662,  -5.6259,  35.3815,  -0.6540,  97.8591,   3.1908,   4.8563,
           3.0504,  -7.1151,  -3.2673, -19.2185],
        [ 18.4212,  -2.9519,  12.1161,  -2.1473,   3.1908,  84.9926,  18.7322,
           0.2138,   5.1761,   7.1025,  -6.9943],
        [  2.4490,  -1.0276,   4.0535,  -3.5546,   4.8563,  18.7322,  87.4687,
           4.5106,   4.3176,  14.7430,  11.8299],
        [-12.4884,   2.4567,   4.4423, -24.9956,   3.0504,   0.2138,   4.5106,
         104.6368,

In [21]:
from functools import partial
import torch
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
from torch.func import functional_call
from optree import tree_map
from optree.integration.torch import tree_ravel

from posteriors import tree_size, empirical_fisher, diag_normal_log_prob
from posteriors.laplace import dense_fisher

from scenarios import TestModel


def normal_log_likelihood(y, y_pred):
    return (
        Normal(y_pred, 1, validate_args=False).log_prob(y).sum(dim=-1)
    )  # validate args introduces control flows not yet supported in torch.func.vmap


def log_posterior_n(params, batch, model, n_data):
    y_pred = functional_call(model, params, batch[0])
    return diag_normal_log_prob(params, mean=0.0, sd_diag=1.0) + normal_log_likelihood(
        batch[1], y_pred
    ) * n_data, torch.tensor([])


def test_dense_fisher_vmap():
    torch.manual_seed(42)
    model = TestModel()

    xs = torch.randn(100, 10)
    ys = model(xs)

    dataloader = DataLoader(
        TensorDataset(xs, ys),
        batch_size=2,
    )

    def log_posterior(p, b):
        return log_posterior_n(p, b, model, len(xs))[0].mean(), torch.tensor([])

    log_posterior_per_sample = torch.vmap(log_posterior, in_dims=(None, 0))

    params = dict(model.named_parameters())

    # Test inplace = False
    transform = dense_fisher.build(log_posterior)
    laplace_state = transform.init(params)
    laplace_state_prec_init = laplace_state.prec
    for batch in dataloader:
        laplace_state = transform.update(laplace_state, batch, inplace=False)

    num_params = tree_size(params)
    expected = torch.zeros((num_params, num_params))
    for x, y in zip(xs, ys):
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)
        with torch.no_grad():
            fisher = empirical_fisher(
                lambda p: log_posterior_per_sample(p, (x, y)),
                has_aux=True,
                normalize=False,
            )(params)[0]

        expected += fisher

    assert torch.allclose(expected, laplace_state.prec, atol=1e-5)
    assert not torch.allclose(laplace_state.prec, laplace_state_prec_init)

    print("eee\n")
    print(expected)
    print("ppp\n")
    print(laplace_state.prec)

    # Also check full batch
    laplace_state_fb = transform.init(params)
    laplace_state_fb = transform.update(laplace_state_fb, (xs, ys))

    assert torch.allclose(expected, laplace_state_fb.prec, atol=1e-5)

    #  Test per_sample
    log_posterior_per_sample = partial(log_posterior_n, model=model, n_data=len(xs))
    transform_ps = dense_fisher.build(log_posterior_per_sample, per_sample=True)
    laplace_state_ps = transform_ps.init(params)
    for batch in dataloader:
        laplace_state_ps = transform_ps.update(
            laplace_state_ps,
            batch,
        )

    assert torch.allclose(laplace_state_ps.prec, laplace_state_fb.prec, atol=1e-5)

    # Test inplace = True
    transform = dense_fisher.build(log_posterior)
    laplace_state = transform.init(params)
    laplace_state_prec_diag_init = laplace_state.prec
    for batch in dataloader:
        laplace_state = transform.update(laplace_state, batch, inplace=True)

    assert torch.allclose(expected, laplace_state.prec, atol=1e-5)
    assert torch.allclose(laplace_state.prec, laplace_state_prec_diag_init, atol=1e-5)

    # Test sampling
    num_samples = 10000
    laplace_state.prec.data += 0.1 * torch.eye(
        num_params
    )  # regularize to ensure PSD and reduce variance

    mean_copy = tree_map(lambda x: x.clone(), laplace_state.params)
    sd_flat = torch.diag(torch.linalg.inv(laplace_state.prec)).sqrt()

    samples = dense_fisher.sample(laplace_state, (num_samples,))

    samples_mean = tree_map(lambda x: x.mean(dim=0), samples)
    samples_sd = tree_map(lambda x: x.std(dim=0), samples)
    samples_sd_flat = tree_ravel(samples_sd)[0]

    for key in samples_mean:
        assert samples[key].shape[0] == num_samples
        assert samples[key].shape[1:] == samples_mean[key].shape
        assert torch.allclose(samples_mean[key], laplace_state.params[key], atol=1e-1)
        assert torch.allclose(mean_copy[key], laplace_state.params[key])

    assert torch.allclose(sd_flat, samples_sd_flat, atol=1e-1)
test_dense_fisher_vmap()

eee

tensor([[ 7.5550,  6.6453,  7.2144, -2.0363,  7.9845, -1.9044,  1.7539, -4.2317,
          5.1046,  7.6623, -6.3767],
        [ 6.6453,  5.8452,  6.3457, -1.7911,  7.0231, -1.6751,  1.5428, -3.7222,
          4.4900,  6.7397, -5.6089],
        [ 7.2144,  6.3457,  6.8891, -1.9445,  7.6245, -1.8186,  1.6749, -4.0409,
          4.8745,  7.3168, -6.0892],
        [-2.0363, -1.7911, -1.9445,  0.5488, -2.1520,  0.5133, -0.4727,  1.1406,
         -1.3758, -2.0652,  1.7187],
        [ 7.9845,  7.0231,  7.6245, -2.1520,  8.4384, -2.0127,  1.8536, -4.4723,
          5.3948,  8.0979, -6.7392],
        [-1.9044, -1.6751, -1.8186,  0.5133, -2.0127,  0.4801, -0.4421,  1.0667,
         -1.2868, -1.9315,  1.6074],
        [ 1.7539,  1.5428,  1.6749, -0.4727,  1.8536, -0.4421,  0.4072, -0.9824,
          1.1851,  1.7788, -1.4804],
        [-4.2317, -3.7222, -4.0409,  1.1406, -4.4723,  1.0667, -0.9824,  2.3703,
         -2.8592, -4.2918,  3.5717],
        [ 5.1046,  4.4900,  4.8745, -1.3758,  5.394

In [25]:
from functools import partial
import torch
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
from torch.func import functional_call
from optree import tree_map
from optree.integration.torch import tree_ravel

from posteriors.laplace import dense_ggn

from scenarios import TestModel


def normal_log_likelihood(y_pred, batch):
    y = batch[1]
    return (
        Normal(y_pred, 1, validate_args=False).log_prob(y).sum()
    )  # validate args introduces control flows not yet supported in torch.func.vmap


def forward_m(params, b, model):
    y_pred = functional_call(model, params, b[0])
    return y_pred, torch.tensor([])


def test_ggn_vmap():
    torch.manual_seed(42)
    model = TestModel()

    xs = torch.randn(100, 10)
    ys = model(xs)

    dataloader = DataLoader(
        TensorDataset(xs, ys),
        batch_size=20,
    )

    forward = partial(forward_m, model=model)

    params = dict(model.named_parameters())

    # Test inplace = False
    transform = dense_ggn.build(forward, normal_log_likelihood)
    laplace_state = transform.init(params)
    laplace_state_prec_init = laplace_state.prec
    for batch in dataloader:
        laplace_state = transform.update(laplace_state, batch, inplace=False)

    flat_params, unravel_fn = tree_ravel(params)

    expected = torch.zeros((flat_params.shape[0], flat_params.shape[0]))
    for x, y in zip(xs, ys):
        with torch.no_grad():
            z = forward(params, (x, y))[0]
            J = torch.func.jacrev(lambda fp: forward(unravel_fn(fp), (x, y)))(
                flat_params
            )[0]
            H = torch.func.hessian(lambda zt: normal_log_likelihood(zt, (x, y)))(z)
            G = J.T @ H @ J
        expected -= G

    assert torch.allclose(expected, laplace_state.prec, atol=1e-5)
    assert not torch.allclose(laplace_state.prec, laplace_state_prec_init)

    print("eee\n")
    print(expected)
    print("ppp\n")
    print(laplace_state.prec)

    # Also check full batch
    laplace_state_fb = transform.init(params)
    laplace_state_fb = transform.update(laplace_state_fb, (xs, ys))

    assert torch.allclose(expected, laplace_state_fb.prec, atol=1e-5)

    # Test inplace = True
    laplace_state = transform.init(params)
    laplace_state_prec_init = laplace_state.prec
    for batch in dataloader:
        laplace_state = transform.update(laplace_state, batch, inplace=True)

    assert torch.allclose(expected, laplace_state.prec, atol=1e-5)
    assert torch.allclose(laplace_state.prec, laplace_state_prec_init)

    # Test sampling
    num_samples = 10000
    laplace_state.prec.data += 0.1 * torch.eye(
        flat_params.shape[0]
    )  # regularize to ensure PSD and reduce variance

    mean_copy = tree_map(lambda x: x.clone(), laplace_state.params)
    sd_flat = torch.diag(torch.linalg.inv(laplace_state.prec)).sqrt()

    samples = dense_ggn.sample(laplace_state, (num_samples,))

    samples_mean = tree_map(lambda x: x.mean(dim=0), samples)
    samples_sd = tree_map(lambda x: x.std(dim=0), samples)
    samples_sd_flat = tree_ravel(samples_sd)[0]

    for key in samples_mean:
        assert samples[key].shape[0] == num_samples
        assert samples[key].shape[1:] == samples_mean[key].shape
        assert torch.allclose(samples_mean[key], laplace_state.params[key], atol=1e-1)
        assert torch.allclose(mean_copy[key], laplace_state.params[key])

    assert torch.allclose(sd_flat, samples_sd_flat, atol=1e-1)

test_ggn_vmap()

eee

tensor([[100.0000,   5.3003,   3.6100, -15.0475,   5.2662,  18.4212,   2.4490,
         -12.4884,   4.4733,  -1.8450,   4.2531],
        [  5.3003, 107.0566,  -5.4149,  -5.1864,  -5.6259,  -2.9519,  -1.0276,
           2.4567,  22.9299,  -5.7365,   5.1043],
        [  3.6100,  -5.4149, 116.0197,  -7.5019,  35.3815,  12.1161,   4.0535,
           4.4423,   6.7399,  -4.9319, -19.8263],
        [-15.0475,  -5.1864,  -7.5019,  69.4874,  -0.6540,  -2.1473,  -3.5546,
         -24.9956,  -8.7571,  -2.9979,  -0.4580],
        [  5.2662,  -5.6259,  35.3815,  -0.6540,  96.8592,   3.1908,   4.8563,
           3.0504,  -7.1151,  -3.2673, -19.2185],
        [ 18.4212,  -2.9519,  12.1161,  -2.1473,   3.1908,  83.9926,  18.7322,
           0.2138,   5.1761,   7.1025,  -6.9943],
        [  2.4490,  -1.0276,   4.0535,  -3.5546,   4.8563,  18.7322,  86.4687,
           4.5106,   4.3176,  14.7430,  11.8299],
        [-12.4884,   2.4567,   4.4423, -24.9956,   3.0504,   0.2138,   4.5106,
         103.