In [None]:
%load_ext autoreload
%autoreload 2
    
import os
import sys
import time
import glob
import gc

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
jax.config.update('jax_enable_x64', True)


%matplotlib widget
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Typeset all figure text with LaTeX, using the Times font family.
plt.rcParams.update({'text.usetex': True, 'font.family': 'Times'})

# Make the parent directory importable ahead of everything else on sys.path
# (presumably so the local `geometric_bayesian` package resolves -- confirm).
sys.path.insert(0, os.path.abspath('../'))

## Find MAP

In [None]:
from geometric_bayesian.utils import DataLoader

# Two-class toy dataset: isotropic Gaussian blobs (covariance 0.2*I) centred at
# (-1, -1) for class 0 and (1, 1) for class 1.
num_samples_1, num_samples_2 = 1, 0

# Each cluster gets its own PRNG key. Re-using a single key for both draws
# makes the class-2 noise an exact copy of the class-1 noise, so the two
# clusters would just be shifted copies of each other. Class 1 keeps key(0),
# so its samples are unchanged.
key_c1 = jax.random.key(0)
key_c2 = jax.random.key(1)

samples_c1 = jax.random.multivariate_normal(key_c1, jnp.array([-1,-1]), 0.2*jnp.eye(2), shape=(num_samples_1,))
targets_c1 = jnp.zeros(num_samples_1)
samples_c2 = jax.random.multivariate_normal(key_c2, jnp.array([1,1]), 0.2*jnp.eye(2), shape=(num_samples_2,))
targets_c2 = jnp.ones(num_samples_2)

# Stack both classes into one design matrix / target vector.
samples = jnp.concatenate((samples_c1,samples_c2), axis=0)
targets = jnp.concatenate((targets_c1,targets_c2))
train_loader = DataLoader(samples, targets, batch_size=2, shuffle=True)

In [None]:
from flax import nnx
from geometric_bayesian.models import MLP

# Bias-free MLP with layer widths 2 -> 1 -> 1 and float64 parameters.
# `nl=None` is passed as the non-linearity argument (alternative: nnx.tanh).
model = MLP(layers=[2, 1, 1], nl=None, use_bias=False, param_dtype=jnp.float64)

# `num_params` is reused below to size the prior and the curvature operators.
num_params = model.size
print(model.params)

In [None]:
from geometric_bayesian.densities import Bernoulli, MultivariateNormal
from geometric_bayesian.functions.likelihood import neg_logll
from geometric_bayesian.operators import DiagOperator

def p_ll(f):
    """Return the Bernoulli likelihood built from ``f`` with ``logits=True``."""
    return Bernoulli(f, logits=True)


# Gaussian prior over the parameters whose covariance is a diagonal operator
# with value 10 and dimension `num_params`.
prior_var = DiagOperator(diag=jnp.array(10.), dim=num_params)
p_prior = MultivariateNormal(cov=prior_var)

In [None]:
import optax

# Training hyper-parameters.
n_epochs = 1000
step_size = 1e-2

# Adam optimiser wrapped by NNX so it updates `model` in place.
# NOTE(review): newer Flax releases may expect an explicit `wrt=` argument here
# -- confirm against the installed version.
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=step_size))

def loss_fn(model, x, y):
    """Training objective for one batch.

    Evaluates ``neg_logll(p_ll, y, model(x))`` and subtracts the prior term
    ``p_prior(model.params)`` divided by the batch size ``y.shape[0]``.
    (``p_prior(...)`` is presumably a log-density -- confirm.)

    Args:
        model: callable model exposing a ``params`` attribute.
        x: batch of inputs passed to ``model``.
        y: batch of targets; ``y.shape[0]`` is used as the batch size.

    Returns:
        The scalar objective value.
    """
    logits = model(x)
    data_term = neg_logll(p_ll, y, logits)
    prior_term = p_prior(model.params) / y.shape[0]
    return data_term - prior_term

@nnx.jit
def train_step(model, optimizer, x, y):
    """Run one optimisation step on the batch ``(x, y)`` and return the loss.

    The loss and its gradients come from ``nnx.value_and_grad(loss_fn)``; the
    gradients are then handed to ``optimizer.update``.
    """
    value_and_grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = value_and_grad_fn(model, x, y)
    optimizer.update(grads)
    return loss

In [None]:
# Optimise the objective; `losses` collects one value per batch.
losses = []
for epoch in range(n_epochs):
    losses.extend(
        train_step(model, optimizer, x_tr, y_tr) for x_tr, y_tr in train_loader
    )

    # Report the most recent batch loss every 100 epochs.
    if not epoch % 100:
        print(f"[epoch {epoch}]: loss: {losses[-1]:.4f}")

print(f'{optimizer.step.value = }')
print(f"Final loss: {losses[-1]:.4f}")
print(model.params)

In [None]:
from geometric_bayesian.utils.plot import contour_plot, surf_plot
# Contour of the predictive probability sigmoid(f(x)) over [-3, 3]^2 with the
# training points of both classes overlaid.
fig = contour_plot(
    lambda x: jax.nn.sigmoid(model(x)),
    # Raw string: '\s' is not a valid escape sequence and triggers a
    # SyntaxWarning on recent Python versions; the resulting text is unchanged.
    min=[-3, -3], max=[3, 3], res=100, iso=False, alpha=0.5, zorder=-1, label=r'\sigma(f(X))')
ax = fig.axes[0]
ax.scatter(samples_c1[:,0], samples_c1[:,1], label='class 1', color='green', alpha=0.5)
ax.scatter(samples_c2[:,0], samples_c2[:,1], label='class 2', color='orange', alpha=0.5)
ax.set_xlim([-3,3])
ax.set_ylim([-3,3])
ax.set_aspect('equal', 'box')
ax.legend()

In [None]:
from itertools import combinations
from geometric_bayesian.utils.plot import contour_nd_plot
def loss(params):
    """Full-dataset objective as a function of a flat parameter vector.

    Evaluates the model through ``model.fwd_params()`` on the global
    ``samples`` / ``targets`` and subtracts ``p_prior(params)`` divided by
    the number of targets.
    """
    forward = model.fwd_params()
    nll = neg_logll(p_ll, targets, forward(samples, params))
    return nll - p_prior(params) / targets.shape[0]
# Pairwise contour plots of the objective in a +/-5 box around the current
# parameters; each row of `ranges` is (lower, upper) for one parameter.
fig = contour_nd_plot(
    fn=jax.vmap(loss),
    ranges=jnp.stack((model.params - 5, model.params + 5), axis=1),
    resolution=100,
    reduce_fn=jnp.mean,
    cbar=True,
    pos="left",
    size="8%",
    pad=0.2,
    label=r'\mathcal{L}(f,\theta)',
    labelpad=-30
)

# Mark the current parameter values on the i-th axes for the i-th index pair.
pairs = list(combinations(range(num_params), 2))
for i, (first, second) in enumerate(pairs):
    fig.axes[i].scatter(model.params[first], model.params[second], s=10)

In [None]:
from geometric_bayesian.curv.ggn import ggn
from geometric_bayesian.utils.helper import wrap_pytree_function
from geometric_bayesian.operators import PSDOperator, LowRankOperator

# GGN callable built from the likelihood, the model and the full dataset, then
# wrapped so it can be called as `ggn_fn(params, v)` (that is how it is used
# below).
ggn_fn = wrap_pytree_function(
    ggn(
        p = p_ll,
        f = model,
        X = samples,
        y = targets,
        scaling = float(samples.shape[0])
    ),
    nnx.state(model)
)

# Operator at the current parameters. It is bound to `ggn_op` rather than `ggn`
# so the imported `ggn` factory is not shadowed by an operator instance.
ggn_op = PSDOperator(lambda v : ggn_fn(model.params, v), op_size=num_params)

# Densify by applying the operator to the identity, then inspect its spectrum.
D, U = jnp.linalg.eigh(ggn_op @ jnp.eye(num_params))
print(D)

In [None]:
def ggn_subspace(ggn, subspace, tol=1e-10):
    """Build a map ``x -> low-rank operator`` restricted to one GGN subspace.

    Args:
        ggn: callable ``ggn(x, v)``; it is wrapped in a ``PSDOperator`` as the
            operator's action on ``v`` at the point ``x``.
        subspace: ``'kernel'`` selects the entries whose low-rank ``diag`` value
            is ``<= tol``; any other value selects the remaining entries.
        tol: threshold on ``diag`` separating the two subspaces.

    Returns:
        A function ``g_fn(x)`` returning the modified low-rank operator.
    """
    def g_fn(x):
        # Use the callable passed in instead of reaching for a global, so the
        # function honours its own argument.
        op = PSDOperator(lambda v : ggn(x, v), op_size=num_params)
        op_lr = op.lowrank()
        kernel_idx = op_lr.diag <= tol
        if subspace=='kernel':
            # Kernel directions get the prior's diagonal as their scale; the
            # non-kernel entries of `right` are zeroed.
            op_lr.diag = p_prior._cov.diag
            op_lr.right = jnp.where(kernel_idx, op_lr.right, 0.0)
        else:
            # Keep only the entries above the threshold.
            # NOTE(review): the masked `diag` entries become exactly 0 --
            # confirm downstream solves/inverses tolerate that.
            op_lr.diag = jnp.where(~kernel_idx, op_lr.diag, 0.0)
            op_lr.right = jnp.where(~kernel_idx, op_lr.right, 0.0)
        # Both factors share the same (masked) array.
        op_lr.left = op_lr.right
        return op_lr
    return g_fn

In [None]:
from geometric_bayesian.integrate import integrate, ef, ef_s, ode23, ode45
from geometric_bayesian.geom.metric import christoffel_sk, christoffel_fk
# Integration settings shared by the simulations below.
dt = 0.001            # step size passed to `integrate`
T = 1.0               # final time passed to `integrate`
x0 = model.params     # every trajectory starts at the trained parameters
tau = 10.0            # enters the noise scale as sqrt(2 * tau * dt)

def brownian(g, tau, dt, space):
    """Build the stochastic step function for the Brownian-motion simulation.

    Args:
        g: callable ``g(x, v)`` forwarded to ``ggn_subspace`` as ``ggn``.
        tau: scalar entering the noise scale ``sqrt(2 * tau * dt)``.
        dt: integration step size.
        space: subspace selector forwarded to ``ggn_subspace``.

    Returns:
        ``fn(t, x, u, key)`` returning ``(None, noise_term)``. The first entry
        is presumably the drift slot expected by the integrator -- confirm.
    """
    # Forward the caller's operator instead of reading a global.
    g_fn = ggn_subspace(ggn=g, subspace=space)

    def fn(t,x,u,key):
        g_lr = g_fn(x)
        # Draw noise with the shape of the state rather than a hard-coded (3,).
        noise = jax.random.normal(key, x.shape)
        return (None, jnp.sqrt(2*tau*dt)*g_lr.squareroot().solve(noise))
    return fn

def brownian_ito(g, tau, dt, space):
    """Build the stochastic step function with a Christoffel-based drift term.

    Args:
        g: callable ``g(x, v)`` forwarded to ``ggn_subspace`` as ``ggn``.
        tau: scalar entering the noise scale ``sqrt(2 * tau * dt)``.
        dt: integration step size.
        space: subspace selector forwarded to ``ggn_subspace``.

    Returns:
        ``fn(t, x, u, key)`` returning ``(drift, noise_term)``, where ``drift``
        sums ``csk(x, c, c)`` over the columns ``c`` of the dense inverse
        square root of the operator at ``x``.
    """
    # Forward the caller's operator instead of reading a global.
    g_fn = ggn_subspace(ggn=g, subspace=space)

    def fn(t,x,u,key):
        g_lr = g_fn(x)
        g_inv_sqrt = g_lr.inverse().squareroot().dense()
        csk = christoffel_sk(lambda x, v: g_fn(x)(v), lambda x,v: g_fn(x).solve(v))
        # vmap over matching columns (axis 1) of the inverse square root.
        drift = jax.vmap(lambda u,v: csk(x,u,v), in_axes=(1,1))(g_inv_sqrt,g_inv_sqrt).sum(axis=0)
        # Draw noise with the shape of the state rather than a hard-coded (3,).
        noise = jax.random.normal(key, x.shape)
        return (drift, jnp.sqrt(2*tau*dt)*g_lr.squareroot().solve(noise))
    return fn

def geodesic(g, space):
    """Build the first-order right-hand side used for geodesic integration.

    The state ``x`` is split into two equal halves ``(x1, x2)``; the returned
    function yields ``x2`` followed by ``-csk(x1, x2, x2)``.

    Args:
        g: callable ``g(x, v)`` forwarded to ``ggn_subspace`` as ``ggn``.
        space: subspace selector forwarded to ``ggn_subspace``.

    Returns:
        ``fn(t, x, u)`` producing an array of the same length as ``x``.
    """
    # Forward the caller's operator instead of reading a global.
    g_fn = ggn_subspace(ggn=g, subspace=space)

    def fn(t,x,u):
        csk = christoffel_sk(lambda x, v: g_fn(x)(v), lambda x,v: g_fn(x).solve(v))
        x1, x2 = jnp.split(x, 2)
        return jnp.append(x2, -csk(x1,x2,x2))
    return fn


# Backward-compatible alias for the original (misspelled) name.
geodeisic = geodesic

In [None]:
# Simulate Brownian motion from x0 in the 'image' subspace.
# NOTE(review): both runs use jax.random.key(0) -- presumably so they share the
# same noise; confirm that is intended.
step = integrate(
    f = brownian(g=ggn_fn, tau=tau, dt=dt, space='image'),
    dt = dt,
    T = T,
    integrator = ef_s,
    key = jax.random.key(0)
)
start = time.perf_counter()
# JAX dispatches asynchronously; wait for the result so the printed time covers
# the computation (including compilation on first call), not just the dispatch.
brownian_image = jax.block_until_ready(step(x0)[0])
print(time.perf_counter() - start)

# Same simulation in the 'kernel' subspace.
step = integrate(
    f = brownian(g=ggn_fn, tau=tau, dt=dt, space='kernel'),
    dt = dt,
    T = T,
    integrator = ef_s,
    key = jax.random.key(0)
)
start = time.perf_counter()
brownian_kernel = jax.block_until_ready(step(x0)[0])
print(time.perf_counter() - start)

In [None]:
# Drop image-subspace samples with any coordinate above 5 before plotting.
# NOTE(review): only large positive values are filtered -- confirm that large
# negative / non-finite rows are meant to stay.
mask = (brownian_image > 5.).any(axis=1)
brownian_image = brownian_image[jnp.logical_not(mask)]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d', computed_zorder=False)
for cloud, colour in ((brownian_image, 'blue'), (brownian_kernel, 'orange')):
    ax.scatter(cloud[:,0], cloud[:,1], cloud[:,2], color=colour)
# Starting point.
ax.scatter(x0[0],x0[1],x0[2], color='red')

In [None]:
# Simulate the drift-corrected (Ito) Brownian motion from x0 in the 'image'
# subspace.
# NOTE(review): both runs use jax.random.key(0) -- presumably so they share the
# same noise; confirm that is intended.
step = integrate(
    f = brownian_ito(g=ggn_fn, tau=tau, dt=dt, space='image'),
    dt = dt,
    T = T,
    integrator = ef_s,
    key = jax.random.key(0)
)
start = time.perf_counter()
# JAX dispatches asynchronously; wait for the result so the printed time covers
# the computation (including compilation on first call), not just the dispatch.
brownian_ito_image = jax.block_until_ready(step(x0)[0])
print(time.perf_counter() - start)

# Same simulation in the 'kernel' subspace.
step = integrate(
    f = brownian_ito(g=ggn_fn, tau=tau, dt=dt, space='kernel'),
    dt = dt,
    T = T,
    integrator = ef_s,
    key = jax.random.key(0)
)
start = time.perf_counter()
brownian_ito_kernel = jax.block_until_ready(step(x0)[0])
print(time.perf_counter() - start)

In [None]:
# Drop image-subspace samples with any coordinate above 50 before plotting.
# NOTE(review): only large positive values are filtered -- confirm that large
# negative / non-finite rows are meant to stay.
mask = (brownian_ito_image > 50.).any(axis=1)
brownian_ito_image = brownian_ito_image[jnp.logical_not(mask)]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d', computed_zorder=False)
for cloud, colour in ((brownian_ito_image, 'blue'), (brownian_ito_kernel, 'orange')):
    ax.scatter(cloud[:,0], cloud[:,1], cloud[:,2], color=colour)
# Starting point.
ax.scatter(x0[0],x0[1],x0[2], color='red')

In [None]:
# Three initial states for the geodesic ODE: each row is x0 followed by a
# random velocity drawn uniformly from [-0.5, 0.5).
xv0 = jnp.concatenate(
    [
        jnp.tile(x0, (3, 1)),
        jax.random.uniform(jax.random.key(0), (3, x0.shape[0])) - 0.5,
    ],
    axis=1,
)

# Integrate the geodesic ODE in the 'image' subspace for every row of xv0.
step = integrate(
    f = geodeisic(g=ggn_fn, space='image'),
    dt = dt,
    T = T,
    integrator = ode45,
)
start = time.perf_counter()
# JAX dispatches asynchronously; wait for the result so the printed time covers
# the computation (including compilation on first call), not just the dispatch.
geodesic_image = jax.block_until_ready(jax.vmap(step)(xv0)[0])
print(time.perf_counter() - start)

# Same integration in the 'kernel' subspace.
step = integrate(
    f = geodeisic(g=ggn_fn, space='kernel'),
    dt = dt,
    T = T,
    integrator = ode45,
)
start = time.perf_counter()
geodesic_kernel = jax.block_until_ready(jax.vmap(step)(xv0)[0])
print(time.perf_counter() - start)

In [None]:
# 3-D plot of the geodesic trajectories: image subspace in blue, kernel
# subspace in orange, starting point in red. Only the first three state
# coordinates are drawn.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d', computed_zorder=False)

for bundle, colour in ((geodesic_image, "blue"), (geodesic_kernel, "orange")):
    for i in range(bundle.shape[0]):
        traj = bundle[i]
        ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], color=colour)

ax.scatter(x0[0],x0[1],x0[2], color='red')