In [None]:
%load_ext autoreload
%autoreload
import os
import sys
import jax
import jax.numpy as jnp
from flax import nnx
import math
    
sys.path.insert(0, os.path.abspath(os.path.join('../')))
from geometric_bayesian.models import MLP
from geometric_bayesian.operators import DenseOperator, SymOperator, PSDOperator, DiagOperator
from geometric_bayesian.densities import Normal, MultivariateNormal
from geometric_bayesian.utils import DataLoader, get_sinusoid_example, contour_plot, ggn, array_to_pytree, pytree_to_array

# sys.path.insert(0, os.path.abspath(os.path.join('/Users/balint/Projects/laplax/')))
sys.path.insert(0, os.path.abspath(os.path.join('/home/bernardo/repos/laplax/')))
from laplax.curv import create_ggn_mv
from laplax.util.loader import input_target_split

### Test PSD Operator

In [None]:
size = 10
rng_key = jax.random.key(0)
# Derive one independent sub-key per draw. Feeding the same key to several
# samplers makes their outputs statistically dependent, which defeats the
# purpose of random test fixtures.
sca_key, vec_key, mat_key = jax.random.split(rng_key, 3)
sca = jax.random.uniform(sca_key, (1, ))
vec = jax.random.uniform(vec_key, (size, ))
mat = jax.random.uniform(mat_key, (size, size))

In [None]:
# Symmetrise the random matrix and add `size` to every diagonal entry.
# NOTE: this rebinds `mat` from its previous value, so running the cell twice
# applies the transformation twice.
mat = mat + mat.T + size * jnp.eye(mat.shape[0])

# Operator constructed from the explicit dense matrix.
op = PSDOperator(op=mat, op_type='raw')


def _mat_vec(v):
    """Return the product of `mat` with `v` without exposing `mat` itself."""
    return mat @ v


# Operator that only knows its matrix-vector product and its size.
op_mv = PSDOperator(op=_mat_vec, op_size=size)

In [None]:
# Print the matrix-vector product obtained from the dense matrix and from both
# operators for side-by-side comparison. Thunks keep each computation lazy so
# every label is printed before its value is evaluated.
for label, compute in (
    ("dense", lambda: jnp.matmul(mat, vec)),
    ("op chol", lambda: op(vec)),
    ("op mv", lambda: op_mv(vec)),
):
    print(label)
    print(compute().tolist())

In [None]:
# Print the solution of the linear system with right-hand side `vec` obtained
# from the dense solver and from both operators. Thunks keep each computation
# lazy so every label is printed before its value is evaluated.
for label, compute in (
    ("dense", lambda: jnp.linalg.solve(mat, vec)),
    ("op chol", lambda: op.solve(vec)),
    ("op mv", lambda: op_mv.solve(vec)),
):
    print(label)
    print(compute().tolist())

In [None]:
# Print the eigenvalues from the dense symmetric eigensolver next to the output
# of each operator's `diagonalize` (10 iterations, same key for both).
for label, compute in (
    ("dense", lambda: jnp.linalg.eigh(mat)[0]),
    ("op chol", lambda: op.diagonalize(num_iterations=10, rng_key=rng_key)),
    ("op mv", lambda: op_mv.diagonalize(num_iterations=10, rng_key=rng_key)),
):
    print(label)
    print(compute().tolist())

### Test Normal

In [None]:
# Scalar parameters for the univariate density: both arguments are 1.0.
mean = jnp.array(1.0)
cov = jnp.array(1.0)
pdf = Normal(mean, cov)
# 100 evenly spaced evaluation points covering [-2, 4].
x = jnp.linspace(-2, 4, 100)

In [None]:
import matplotlib.pyplot as plt
# Plot the values obtained by calling `pdf` on the grid `x`.
plt.plot(x, pdf(x))

### Test Multivariate Normal

In [None]:
mean = jnp.array([0.0, 0.0])
# Assemble the covariance as u^T s u from a diagonal factor `s` and a mixing
# matrix `u` (its rows are not normalised).
s = jnp.diag(jnp.array([1.0, 0.5]))
u = jnp.array([[1.0, 1.0], [-1.0, 1.0]])
cov = PSDOperator(op=u.T @ (s @ u), op_type='raw')
pdf = MultivariateNormal(mean, cov)

In [None]:
# Hand the density to the project's `contour_plot` helper; binding the result
# to `_` keeps the notebook from echoing the returned object.
_ = contour_plot(pdf)

### GGN

In [None]:
# Network used for the GGN comparisons below.
# NOTE(review): `layers=[3, 64, 2]` presumably means 3 inputs, one hidden layer
# of width 64 and 2 outputs — confirm against the MLP implementation.
model = MLP(
    layers=[3, 64, 2],
    rngs=nnx.Rngs(params=0),
    prob_out=False,
)

In [None]:
# Draw inputs and targets from separate sub-keys. Sampling both directly from
# `rng_key` would make them statistically dependent on each other and on every
# other draw that uses that key.
samples_key, targets_key = jax.random.split(jax.random.fold_in(rng_key, 1))
train_samples = jax.random.uniform(samples_key, (10, model.shape[0]))
train_targets = jax.random.uniform(targets_key, (10, model.shape[1]))

In [None]:
# Build the project's GGN callable for `model` on the training pair, passing
# the MultivariateNormal class as likelihood and a DiagOperator built from the
# scalar 1.0 and the model's output size as `cov`.
ggn_mv = ggn(
    model=model,
    train_data=(train_samples, train_targets),
    likelihood_density=MultivariateNormal, 
    cov=DiagOperator(jnp.array(1.0), model.shape[1])
)

In [None]:
# Separate the model into its static graph definition and its parameter state.
graph_def, params = nnx.split(model)
# Total number of scalar entries across all leaves of the parameter pytree.
num_params = sum(leaf.size for leaf in jax.tree.leaves(params))


def model_fn(input, params):
    """Evaluate the split model on `input` using the given `params` state.

    `nnx.call` returns a tuple; only its first element is returned here.
    """
    outputs = nnx.call((graph_def, params))(input)
    return outputs[0]

In [None]:
# Reference GGN matrix-vector product from laplax, built on the same model
# function, parameters and data, with the "mse" loss.
# NOTE(review): `num_total_samples` is set to half of `num_curv_samples`;
# presumably this rescales the result to match the other implementation's
# convention — confirm against laplax's definition of these arguments.
ggn_mv_test = create_ggn_mv(
    model_fn,
    params,
    {
        'input': train_samples,
        'target': train_targets
    },
    loss_fn="mse",
    num_curv_samples=train_samples.shape[0],
    num_total_samples=train_samples.shape[0] // 2,
)

In [None]:
# Apply the project GGN callable, using the parameter pytree itself as the
# input vector.
ggn_mv(params)

In [None]:
# Apply the laplax GGN callable to the same input for comparison.
ggn_mv_test(params)

In [None]:
from laplax.curv.cov import create_posterior_fn, create_full_curvature
from laplax.util.flatten import (
    create_pytree_flattener,
    wrap_factory,
    wrap_function,
)
from laplax.util.mv import diagonal, to_dense
from laplax.util.tree import get_size, eye_like

# Build a flatten/unflatten pair for the parameter pytree.
flatten, unflatten = create_pytree_flattener(params)
# Wrap `ggn_mv` so that its input passes through `unflatten` and its output
# through `flatten`.
mv_wrapped = wrap_function(ggn_mv, input_fn=unflatten, output_fn=flatten)
# Materialise the wrapped matrix-vector product as a dense array whose layout
# is given by the size of `params`.
to_dense(mv_wrapped, layout=get_size(params))

In [None]:
# Count the scalar entries of the parameter pytree.
num_params = sum(leaf.size for leaf in jax.tree.leaves(params))
# Convert an identity matrix of that size into a pytree shaped like `params`.
identity = jnp.eye(num_params)
eye_pytree = array_to_pytree(identity, params)
# Map `ggn_mv` over the leading axis of the identity pytree and convert the
# results back into a single array.
ggn_rows = jax.lax.map(ggn_mv, eye_pytree, batch_size=None)
pytree_to_array(ggn_rows, axis=0)

In [None]:
from functools import partial

# Likelihood factory: maps a mean `f` to a MultivariateNormal whose covariance
# is DiagOperator(1.0, model.shape[1]).
p = lambda f: MultivariateNormal(f, cov=DiagOperator(1.0, model.shape[1]))

@partial(jax.custom_jvp, nondiff_argnums=(0,1))
def neg_logll(p, y, f):
    """Sum over the leading axis of `-p(f_i)._logpdf(y_i)`.

    `p` and `y` (positions 0 and 1) are declared non-differentiable, so only
    `f` carries derivatives.
    """
    return jax.vmap(lambda y, f: -p(f)._logpdf(y), in_axes=(0,0))(y, f).sum()

@partial(jax.custom_jvp, nondiff_argnums=(0,1,3))
@neg_logll.defjvp
def neg_logll_jvp(p, y, primals, tangents):
    """Custom JVP rule for `neg_logll`, itself wrapped in `custom_jvp`.

    `primals` and `tangents` are indexed with `[0]`, i.e. they are expected to
    be one-element containers holding `f` and its tangent. Returns the pair
    (neg_logll value, summed `-p(f_i)._logpdf_jvp_mean(y_i, v_i)`).
    For the outer `custom_jvp`, positions 0, 1 and 3 (`p`, `y`, `tangents`)
    are non-differentiable, leaving `primals` as the differentiable argument.
    """
    f = primals[0]
    v = tangents[0]
    return neg_logll(p, y, f), jax.vmap(lambda y, f, v: -p(f)._logpdf_jvp_mean(y, v), in_axes=(0,0,0))(y, f, v).sum()

@neg_logll_jvp.defjvp
def neg_logll_hvp(p, y, tangents, primals_new, tangents_new):
    """Custom JVP rule registered on `neg_logll_jvp`.

    Takes five positional arguments: the three non-differentiable arguments of
    `neg_logll_jvp` (`p`, `y`, `tangents`) followed by `primals_new` and
    `tangents_new`. `tangents` is not read in the body. Returns the pair
    (neg_logll value, summed `-p(f_i)._logpdf_hvp_mean(y_i, v_i)`).

    NOTE(review): `neg_logll_jvp` returns a pair while the first output here is
    a single value — confirm this is the primal output JAX expects.
    NOTE(review): `defjvp` hands back the function it decorates, so this name
    appears to refer to the plain five-argument function when called directly
    — confirm against the JAX documentation.
    """
    f = primals_new[0]
    v = tangents_new[0]
    return neg_logll(p, y, f), jax.vmap(lambda y, f, v: -p(f)._logpdf_hvp_mean(y, v), in_axes=(0,0,0))(y, f, v).sum()

In [None]:
# Draw the tangent from its own sub-key: `rng_key` is already consumed by
# earlier cells, and reusing a key with the same shape reproduces exactly the
# same numbers as the earlier draw.
v = jax.random.uniform(jax.random.fold_in(rng_key, 2), (10, model.shape[1]))

In [None]:
# Forward-mode derivative of `neg_logll` with respect to the model output,
# evaluated at model(train_samples) in direction `v`.
jax.jvp(lambda f: neg_logll(p, train_targets, f), (model(train_samples),), (v,))

In [None]:
# Display the model output on the training inputs.
model(train_samples)

In [None]:
# Differentiate `neg_logll_jvp` with respect to its third argument.
# NOTE(review): the model output and `v` are passed here as bare arrays, while
# `neg_logll_jvp` indexes `primals[0]` / `tangents[0]` and another cell passes
# them wrapped in 1-tuples; `neg_logll_jvp` also returns a pair rather than a
# scalar — confirm this call is intended to run as written.
jax.grad(neg_logll_jvp, 2)(p, train_targets, model(train_samples), v)

In [None]:
# Second derivative of `neg_logll` with respect to the model output, computed
# as forward-mode over reverse-mode differentiation of argument 2.
jax.jacfwd(jax.jacrev(neg_logll,2),2)(p, train_targets, model(train_samples))

In [None]:
# Gradient of `neg_logll` with respect to the model output (argument 2).
jax.grad(neg_logll, 2)(p, train_targets, model(train_samples))

In [None]:
# Draw the tangent from its own sub-key: `rng_key` is already consumed by
# earlier cells, and reusing a key with the same shape reproduces exactly the
# same numbers as the earlier draw.
v = jax.random.uniform(jax.random.fold_in(rng_key, 2), (10,model.shape[1]))
# Call the JVP rule directly, with the primal and tangent wrapped in 1-tuples
# as its `primals[0]` / `tangents[0]` indexing expects.
neg_logll_jvp(p, train_targets, (model(train_samples),), (v,))

In [None]:
# `neg_logll_hvp` is declared with five positional parameters
# (p, y, tangents, primals_new, tangents_new), so all five must be supplied;
# with only four the call fails with a TypeError. The function body never reads
# `tangents`, so `(v,)` is passed there purely to fill the slot.
neg_logll_hvp(p, train_targets, (v,), (model(train_samples),), (v,))

In [None]:
graph_def, map_params = nnx.split(model)


def model_fn(params):
    """Evaluate the split model on `train_samples` for the given `params`.

    NOTE: this rebinds the notebook-level name `model_fn`, which an earlier
    cell defines with the signature `(input, params)`.
    """
    return nnx.call((graph_def, params))(train_samples)[0]


def ggn(vec):
    """Apply the hand-assembled GGN product J^T (H (J vec)) at `map_params`.

    `vec` is a pytree with the structure of `map_params`.

    NOTE: this rebinds `ggn`, which the first cell imports from
    `geometric_bayesian.utils`; re-running earlier cells afterwards will pick
    up this function instead of the imported one.
    """
    # Linearise once at the MAP parameters and reuse the primal output instead
    # of evaluating the model a second time (and instead of depending on the
    # `params` variable created in a different cell).
    f_map, jvp = jax.linearize(model_fn, map_params)
    Jv = jvp(vec)
    # `neg_logll_hvp` takes five positional arguments
    # (p, y, tangents, primals_new, tangents_new); its body does not read
    # `tangents`, so `(Jv,)` is passed there only to fill the slot.
    HJv = neg_logll_hvp(p, train_targets, (Jv,), (f_map,), (Jv,))[1]
    # NOTE(review): the transpose below needs `HJv` to have the shape of the
    # model output — confirm what `neg_logll_hvp` returns in its second slot.
    return jax.linear_transpose(jvp, vec)(HJv)[0]
    # return HJv

In [None]:
# Apply the hand-written `ggn` defined in this notebook, using the parameter
# pytree itself as the input vector.
ggn(map_params)