In [None]:
# Notebook setup: live-reload of edited local modules, imports, and global numeric/plotting configuration.
%load_ext autoreload
%autoreload 2

# NOTE(review): time, glob, gc, grad, jit, vmap, random, mpl and make_axes_locatable
# appear unused in the cells shown here — consider removing them.
import os
import sys
import time
import glob
import gc

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
# Enable 64-bit floats globally; the model below is built with float64 parameters.
jax.config.update('jax_enable_x64', True)


# Interactive (ipympl) figures.
%matplotlib widget
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Render all figure text with LaTeX (requires a working TeX installation).
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'Times'

# Put the parent directory on the import path so the local `geometric_bayesian`
# package can be imported (presumably the repository root — confirm).
sys.path.insert(0, os.path.abspath(os.path.join('../')))

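## Find the MAP estimate

Note: the log-prior term in `loss_fn` is currently commented out, so the optimisation below actually yields a maximum-likelihood estimate.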

In [None]:
from geometric_bayesian.utils import DataLoader

# Two isotropic Gaussian blobs (variance 0.2) centred at (-1, -1) and (1, 1).
# Each class gets its own PRNG key: drawing both with the same key would make the
# second blob an exact translated copy of the first (identical noise realisation).
key_c1, key_c2 = jax.random.split(jax.random.key(0))

num_samples_1, num_samples_2 = 100, 100
samples_c1 = jax.random.multivariate_normal(key_c1, jnp.array([-1., -1.]), 0.2*jnp.eye(2), shape=(num_samples_1,))
targets_c1 = jnp.zeros(num_samples_1)
samples_c2 = jax.random.multivariate_normal(key_c2, jnp.array([1., 1.]), 0.2*jnp.eye(2), shape=(num_samples_2,))
targets_c2 = jnp.ones(num_samples_2)

# Stack both classes: label 0 for the first blob, label 1 for the second.
samples = jnp.concatenate((samples_c1, samples_c2), axis=0)
targets = jnp.concatenate((targets_c1, targets_c2))
train_loader = DataLoader(samples, targets, batch_size=2, shuffle=True)

In [None]:
from flax import nnx
from geometric_bayesian.models import MLP

# MLP with layer widths [2, 3, 2, 3, 1], tanh nonlinearity, no bias terms and
# float64 parameters (consistent with `jax_enable_x64`).
model = MLP(
    layers=[2, 3, 2, 3, 1],
    nl=nnx.tanh,
    use_bias=False,
    param_dtype=jnp.float64,
)
# Parameter count (presumably total number of weights) — sizes the prior and GGN operators below.
num_params = model.size

In [None]:
from geometric_bayesian.densities import Bernoulli, MultivariateNormal
from geometric_bayesian.functions.likelihood import neg_logll
from geometric_bayesian.operators import DiagOperator


def p_ll(f):
    """Bernoulli likelihood whose parameter is the network output ``f`` interpreted as logits."""
    return Bernoulli(f, logits=True)


# Gaussian prior over the weights with a diagonal covariance operator of size
# num_params and scalar diagonal value 10 (presumably broadcast to every entry).
prior_var = DiagOperator(diag=jnp.array(10.), dim=num_params)
p_prior = MultivariateNormal(cov=prior_var)

In [None]:
import optax
from flax import nnx

# Optimisation hyper-parameters.
n_epochs = 100
step_size = 1e-3
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=step_size))


def loss_fn(model, x, y):
    """Negative log-likelihood of targets ``y`` given the model's logits at inputs ``x``.

    NOTE: the log-prior term ``- p_prior(model.params) / y.shape[0]`` is disabled,
    so minimising this loss gives a maximum-likelihood rather than a MAP estimate.
    """
    logits = model(x)
    return neg_logll(p_ll, y, logits)


@nnx.jit
def train_step(model, optimizer, x, y):
    """Apply one optimiser step on the mini-batch ``(x, y)`` and return the loss before the update."""
    loss_and_grad = nnx.value_and_grad(loss_fn)
    batch_loss, grads = loss_and_grad(model, x, y)
    optimizer.update(grads)
    return batch_loss

In [None]:
# Train the network. Every mini-batch loss is kept in `losses`, but progress is
# reported as the per-epoch mean: with batch_size=2 a single batch loss is far
# too noisy to tell whether training is converging.
losses = []
epoch_mean_losses = []
for epoch in range(n_epochs):
    epoch_losses = [train_step(model, optimizer, x_tr, y_tr) for x_tr, y_tr in train_loader]
    losses.extend(epoch_losses)
    epoch_mean_losses.append(float(jnp.mean(jnp.stack(epoch_losses))))

    if epoch % 10 == 0:
        print(f"[epoch {epoch}]: mean loss: {epoch_mean_losses[-1]:.4f}")

print(f'{optimizer.step.value = }')
print(f"Final epoch mean loss: {epoch_mean_losses[-1]:.4f}")

In [None]:
from geometric_bayesian.utils.plot import contour_plot, surf_plot

# Predictive probability sigma(f(x)) of the trained network over the input plane,
# with the training data overlaid. The label is a raw string: '\s' in an ordinary
# string literal is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
fig = contour_plot(lambda x: jax.nn.sigmoid(model(x)), min=[-3, -3], max=[3, 3], res=100, iso=False, alpha=0.5, zorder=-1, label=r'\sigma(f(X))')
ax = fig.axes[0]
ax.scatter(samples_c1[:,0], samples_c1[:,1], label='class 1', color='green', alpha=0.5)
ax.scatter(samples_c2[:,0], samples_c2[:,1], label='class 2', color='orange', alpha=0.5)
ax.set_xlim([-3,3])
ax.set_ylim([-3,3])
ax.set_aspect('equal', 'box')
ax.legend()

## Laplace Approximation

In [None]:
from geometric_bayesian.curv.ggn import ggn
from geometric_bayesian.utils.helper import wrap_pytree_function
from geometric_bayesian.operators.psd_operator import PSDOperator

# Generalised Gauss-Newton (GGN) of the likelihood on the full training set,
# wrapped against the model's state pytree (presumably so it acts on flat
# parameter vectors — confirm in wrap_pytree_function).
num_train = float(samples.shape[0])
ggn_fn = wrap_pytree_function(
    ggn(p=p_ll, f=model, X=samples, y=targets, scaling=num_train),
    nnx.state(model),
)

# GGN as a PSD operator evaluated at the trained weights, factorised with
# num_modes == num_params (i.e. no truncation).
ggn_lr = PSDOperator(lambda v: ggn_fn(model.params, v), op_size=num_params).lowrank(num_modes=num_params)

# Laplace approximation: covariance (GGN + prior precision)^{-1}, mean at the trained weights.
cov_op = (ggn_lr + p_prior._cov.inverse()).inverse()
posterior = MultivariateNormal(cov=cov_op, mean=model.params)

In [None]:
from geometric_bayesian.approx.mc import pred_posterior_mean, pred_posterior_std, pred_posterior

# Monte-Carlo predictive functions built from 10 weight samples of the Laplace posterior.
params_samples = posterior.sample(size=10)
mean_fn, std_fn, pred_posterior_fn = (
    pred_posterior_mean(model, params_samples),
    pred_posterior_std(model, params_samples),
    pred_posterior(model, params_samples, p_ll),
)

In [None]:
from geometric_bayesian.utils.plot import contour_plot, surf_plot

# Monte-Carlo predictive mean under the Laplace posterior, with the training data.
# The title distinguishes this figure from the otherwise identical std plot.
fig = contour_plot(mean_fn, min=[-3, -3], max=[3, 3], res=100, iso=False, alpha=0.5, zorder=-1)
ax = fig.axes[0]
ax.scatter(samples_c1[:,0], samples_c1[:,1], label='class 1', color='green', alpha=0.5)
ax.scatter(samples_c2[:,0], samples_c2[:,1], label='class 2', color='orange', alpha=0.5)
ax.set_xlim([-3,3])
ax.set_ylim([-3,3])
ax.set_aspect('equal', 'box')
ax.set_title('Laplace predictive mean (MC)')
ax.legend()

In [None]:
from geometric_bayesian.utils.plot import contour_plot, surf_plot

# Monte-Carlo predictive standard deviation under the Laplace posterior, with the
# training data. The title distinguishes this figure from the otherwise identical mean plot.
fig = contour_plot(std_fn, min=[-3, -3], max=[3, 3], res=100, iso=False, alpha=0.5, zorder=-1)
ax = fig.axes[0]
ax.scatter(samples_c1[:,0], samples_c1[:,1], label='class 1', color='green', alpha=0.5)
ax.scatter(samples_c2[:,0], samples_c2[:,1], label='class 2', color='orange', alpha=0.5)
ax.set_xlim([-3,3])
ax.set_ylim([-3,3])
ax.set_aspect('equal', 'box')
ax.set_title('Laplace predictive standard deviation (MC)')
ax.legend()

## Feature-space Gaussian mixture

In [None]:
from geometric_bayesian.utils.helper import wrap_pytree_function
from geometric_bayesian.operators.psd_operator import PSDOperator


def _fit_gaussian(feats):
    """Fit a Gaussian to the rows of ``feats`` using the sample mean and a jittered sample covariance."""
    # The 1e-3 diagonal jitter keeps the covariance well conditioned for near-degenerate clouds.
    cov = jnp.cov(feats.T) + 1e-3 * jnp.eye(feats.shape[1])
    return MultivariateNormal(cov=PSDOperator(op=cov, op_type='raw'), mean=jnp.mean(feats, axis=0))


# Features of all training points from model.features(x, 2), computed once and
# then split by class (label 0 -> f1, label 1 -> f2).
features_all = model.features(samples, 2)
f1 = features_all[targets == 0]
f2 = features_all[targets == 1]

p_z1 = _fit_gaussian(f1)
p_z2 = _fit_gaussian(f2)
# Equal-weight mixture *density* (not log-density): p_z1 / p_z2 return
# log-densities, hence the exp.
p_z = lambda z: 0.5*jnp.exp(p_z1(z)) + 0.5*jnp.exp(p_z2(z))

In [None]:
from geometric_bayesian.utils.plot import contour_plot

# Class-conditional feature densities p(z|c), with the training features overlaid.
# Features are computed once; the plotting window is mean +/- 3 std per feature axis.
train_features = model.features(samples, 2)
mean, std = jnp.mean(train_features, axis=0), jnp.std(train_features, axis=0)
plot_min, plot_max = mean - 3*std, mean + 3*std

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,4))
# p_z1 / p_z2 evaluate one feature vector at a time, so vmap over the grid points
# (not over coordinates); they return log-densities, hence the exp.
fig = contour_plot(lambda z: jnp.exp(jax.vmap(p_z1)(z)), min=plot_min, max=plot_max, res=100, iso=False, alpha=0.5, zorder=-1, label=r'p(z|c=1)', fig_ax=(fig, axs[0]))
fig = contour_plot(lambda z: jnp.exp(jax.vmap(p_z2)(z)), min=plot_min, max=plot_max, res=100, iso=False, alpha=0.5, zorder=-1, label=r'p(z|c=2)', fig_ax=(fig, axs[1]))
for ax in axs:
    ax.scatter(f1[:,0], f1[:,1], label='class 1', color='green', alpha=0.5)
    ax.scatter(f2[:,0], f2[:,1], label='class 2', color='orange', alpha=0.5)
    ax.set_aspect('equal', 'box')
    ax.legend()
fig.tight_layout()

In [None]:
def p_y(x):
    """Log posterior probability log p(c=1 | x) of the feature-space mixture classifier.

    ``p_z1`` / ``p_z2`` return log-densities, so with equal class priors of 1/2
    Bayes' rule gives
        log p(c=1|z) = log p(z|c=1) - log(p(z|c=1) + p(z|c=2)),
    evaluated stably in log space with ``logaddexp``. The mixture helper ``p_z``
    returns a *density*, not a log-density, so it must not be subtracted from
    log-densities directly.
    """
    z = model.features(x, 2)
    log_p1, log_p2 = p_z1(z), p_z2(z)
    # The equal 1/2 priors cancel between numerator and evidence.
    return log_p1 - jnp.logaddexp(log_p1, log_p2)


fig = contour_plot(jax.vmap(p_y), min=[-3, -3], max=[3, 3], res=100, iso=False, alpha=0.5, zorder=-1)
ax = fig.axes[0]
ax.scatter(samples_c1[:,0], samples_c1[:,1], label='class 1', color='green', alpha=0.5)
ax.scatter(samples_c2[:,0], samples_c2[:,1], label='class 2', color='orange', alpha=0.5)
ax.set_xlim([-3,3])
ax.set_ylim([-3,3])
ax.set_aspect('equal', 'box')
ax.set_title('Feature-mixture log posterior of class 1')
ax.legend()