# `laplax` tutorial: an introduction to FSP-Laplace regression

This tutorial follows one of the toy-data regression experiments of [FSP-Laplace: Function-Space Priors for the Laplace Approximation in Bayesian Deep Learning](https://arxiv.org/abs/2407.13711) and gives a quick overview of the FSP-Laplace approximation in JAX.
We regress on data generated by
$y = \sin(2\pi x) + \varepsilon, \quad \varepsilon \sim \mathcal N(0, \sigma_n^2), \quad \sigma_n = 0.1,$ with inputs $x$ on the intervals $[-1, -0.5] \cup [0.5, 1]$.

In [None]:
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from helper import DataLoader, get_sinusoid_example, to_float64
from plotting import plot_regression_with_uncertainty, plot_sinusoid_task
from prior import *

import laplax
from laplax import util
from laplax.extra.fsp import *
from laplax.extra.fsp import lanczos_isqrt
from laplax.extra.fsp.fsp import compute_matrix_jacobian_product
from laplax.util.tree import to_dtype

jax.config.update("jax_enable_x64", True)


In [None]:
# Build the toy regression task: 150 inputs on [-1, -0.5] and [0.5, 1] (75 per
# interval) with noisy sinusoidal targets, leaving a gap around the origin.
batch_size = 10
key = jax.random.key(0)
# A JAX PRNG key must not be reused: drawing twice from the same key returns
# the same sample, which would make the held-out noise an exact multiple of
# the training noise. `key` still drives the training noise (so the training
# targets are unchanged); the held-out noise uses a key derived from it.
test_key = jax.random.fold_in(key, 1)

X_train1 = jnp.linspace(-1, -0.5, 75).reshape(-1, 1)
X_train2 = jnp.linspace(0.5, 1, 75).reshape(-1, 1)
X_train = jnp.concatenate([X_train1, X_train2], axis=0)
# Training targets: sin(2*pi*x) plus Gaussian noise with standard deviation 0.1.
y_train = jnp.reshape(
    jnp.sin(X_train * 2 * jnp.pi) + jax.random.normal(key, X_train.shape) * 0.1,
    (-1, 1),
)
# Held-out targets share the training inputs but use independent noise with
# standard deviation 0.2.
X_test = X_train
y_test = jnp.reshape(
    jnp.sin(X_train * 2 * jnp.pi)
    + jax.random.normal(test_key, X_train.shape) * 0.2,
    (-1, 1),
)
train_loader = DataLoader(X_train, y_train, batch_size)
# Full training set in the {"input", "target"} dict layout used further below.
data = {"input": X_train, "target": y_train}

fig = plot_sinusoid_task(X_train, y_train, X_test, y_test)

## Training for the MAP and defining the GP prior

In [None]:
class Model(nnx.Module):
    """Fully connected network: three linear layers with tanh in between."""

    def __init__(self, in_channels, hidden_channels, out_channels, rngs):
        """Create the three linear layers.

        Args:
            in_channels: Input feature dimension of the first layer.
            hidden_channels: Width of both hidden layers.
            out_channels: Output feature dimension of the last layer.
            rngs: `nnx.Rngs` passed to every layer for initialisation.
        """
        # Layers are created in the original order because they share `rngs`.
        self.linear1 = nnx.Linear(in_channels, hidden_channels, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_channels, hidden_channels, rngs=rngs)
        self.linear3 = nnx.Linear(hidden_channels, out_channels, rngs=rngs)

    def __call__(self, x):
        """Return linear3(tanh(linear2(tanh(linear1(x)))))."""
        hidden = nnx.tanh(self.linear1(x))
        hidden = nnx.tanh(self.linear2(hidden))
        return self.linear3(hidden)
    

class MLP(nnx.Module):
    """Wrap a network and attach a learnable scalar `scale` parameter.

    `scale` is stored unconstrained; the training code in this notebook passes
    it through a softplus before handing it to the loss. The forward pass
    ignores `scale` and simply delegates to the wrapped network.
    """

    def __init__(self, model, param=None):
        """Store `model` and initialise `scale`.

        Args:
            model: The wrapped network; called as-is in `__call__`.
            param: Optional initial (pre-softplus) value for `scale`. When
                omitted, log(1 - exp(-0.1)) is used, which is roughly 0.091
                after a softplus.
        """
        self.model = model
        # NOTE(review): the default is not the inverse softplus of 0.1 (that
        # would be log(exp(0.1) - 1)); confirm whether 0.1 was the intent.
        initial_scale = (
            jnp.asarray(param)
            if param is not None
            else jnp.array(jnp.log(1 - jnp.exp(-0.1)))
        )
        self.scale = nnx.Param(initial_scale)

    def __call__(self, x):
        """Evaluate the wrapped network on `x`."""
        return self.model(x)

    
# Build the base network, pass it through the `to_float64` helper, and wrap it
# in `MLP` so that it carries the learnable `scale` parameter.
model = MLP(
    to_float64(
        Model(in_channels=1, hidden_channels=50, out_channels=1, rngs=nnx.Rngs(2))
    )
)

# Split the wrapped model into its graph definition and its parameter state.
graph_def, params = nnx.split(model)

def model_fn(input, params):
    """Evaluate the split model on `input` with the parameter state `params`.

    Uses the module-level `graph_def`. The argument names are kept unchanged
    because callers may pass them by keyword.
    """
    outputs = nnx.call((graph_def, params))(input)
    # Only the first element of the returned tuple is used.
    return outputs[0]

# Hyperparameters handed to `gram` together with `composite_kernel` (see
# `kernel_fn`). Judging by the key names these are presumably the lengthscale
# (`*_ls`), period (`per_p`) and variance (`*_var`) of a periodic, a
# Matern-5/2 and a Matern-1/2 component -- confirm against `prior.py`.
prior_params = {
    # Alternative values kept for reference (currently disabled):
    # "per_ls": 2.947,
    # "per_p": 1.0,
    # "per_var": 6.608,
    # "matern52_ls": 143.478,
    "per_ls": 1.0,
    "per_p": 1.0,
    "per_var": 0.5,
    "matern52_ls": 4.0,
    "matern12_ls": 0.1,
    "matern12_var": 0.0,  # alternative value: 0.25
} 


def kernel_fn(xc):
    """Return `gram(xc, prior_params, composite_kernel)` for the points `xc`.

    The module-level `prior_params` is read at call time, so later changes to
    that dict are picked up.
    """
    gram_matrix = gram(xc, prior_params, composite_kernel)
    return gram_matrix

In [None]:
# Show the shape of the training inputs: two blocks of 75 points stacked into
# a single column, i.e. (150, 1).
X_train.shape

In [None]:
from matplotlib import pyplot as plt

# Draw five functions from the zero-mean prior defined by `kernel_fn` on a grid
# that is wider than the training range, and compare them with the noiseless
# target sin(2*pi*x).
X_plot = jnp.linspace(-3, 3, 200)[:, None]

K = kernel_fn(X_plot)

# Symmetric square root of the symmetrised kernel matrix via an
# eigendecomposition. `jax.scipy.linalg.sqrtm` goes through a general
# (complex) Schur decomposition, which yields a complex-valued result for a
# matrix that is real and symmetric; `eigh` stays real. Eigenvalues that are
# slightly negative due to round-off are clamped to zero before the sqrt.
eigvals, eigvecs = jnp.linalg.eigh((K + K.T) / 2)
K_sqrt = (eigvecs * jnp.sqrt(jnp.maximum(eigvals, 0.0))) @ eigvecs.T

# Each column of `sample` is one draw: K_sqrt @ z with z ~ N(0, I).
sample = K_sqrt @ jax.random.normal(jax.random.key(10), (K.shape[0], 5))

plt.plot(X_plot, sample, color="C0", alpha=0.1)
plt.plot(X_plot, jnp.sin(X_plot * 2 * jnp.pi), color="red", linestyle="--")

In [None]:
@nnx.jit(static_argnames=['loss_fn'])
def train_step(model, data, x_context, loss_fn):
    """Compute the loss and its gradients with respect to `model` for one batch.

    Args:
        model: Module exposing a `scale` parameter; it is differentiated
            through `nnx.value_and_grad`.
        data: Batch forwarded unchanged as the first argument of `loss_fn`.
        x_context: Context points forwarded as the second argument.
        loss_fn: Callable invoked as
            `loss_fn(data, x_context, params, softplus(scale))`; treated as a
            static argument by `nnx.jit`.

    Returns:
        The tuple `(loss, grads)`.
    """

    def objective(module):
        module_def, module_state = nnx.split(module)
        # Rebuild the module from its split parts to read `scale` from it.
        rebuilt = nnx.merge(module_def, module_state)
        positive_scale = jax.nn.softplus(rebuilt.scale.value)
        return loss_fn(data, x_context, module_state, positive_scale)

    value, gradients = nnx.value_and_grad(objective)(model)
    return value, gradients

def train_model(model, n_epochs, lr=1e-2):
    """Train `model` on the FSP objective with Adam and return it.

    Reads the module-level `train_loader`, `X_train` and `kernel_fn`. The
    model is updated in place through the optimizer.

    Args:
        model: Module with a `scale` parameter, as expected by `train_step`.
        n_epochs: Number of passes over `train_loader`.
        lr: Adam learning rate.

    Returns:
        The same `model` object after training.
    """
    optimizer = nnx.Optimizer(model, optax.adam(lr))
    graph_def, _ = nnx.split(model)

    def model_fn(input, params):
        return nnx.call((graph_def, params))(input)[0]

    # NOTE(review): the (200, 1) zeros presumably match the 200 context points
    # built below -- confirm against `create_fsp_objective`.
    loss_fn = create_fsp_objective(model_fn, X_train.shape[0], jnp.zeros((200, 1)), kernel_fn)  # noqa: F405

    # The context grid is identical for every step, so build it once instead
    # of once per batch.
    x_context = jnp.linspace(-2, 2, 200).reshape(-1, 1)

    # Stays None until a step has run, so the reporting below cannot hit an
    # unbound name when `n_epochs` is 0 or the loader yields no batches.
    loss = None
    for epoch in range(n_epochs):
        for x_tr, y_tr in train_loader:
            data = {"input": x_tr, "target": y_tr}
            loss, grads = train_step(model, data, x_context, loss_fn)
            optimizer.update(grads)

        # The reported loss is that of the most recent batch.
        if epoch % 100 == 0 and loss is not None:
            print(f"[epoch {epoch}]: loss: {loss:.4f} Scale: {jax.nn.softplus(model.scale.value):.4f}")

    if loss is not None:
        print(f"Final loss: {loss:.4f}")
    return model

# Run the training loop for 1000 epochs with the default learning rate; the
# returned object replaces `model`.
model = train_model(model, n_epochs=1000)

In [None]:
# Evaluate the trained model on a 200-point grid over [-2, 2], which extends
# beyond the training intervals on both sides.
X_pred = jnp.linspace(-2., 2., 200).reshape(200, 1)
# `jax.vmap` applies the model to each row of `X_pred` separately.
y_pred = jax.vmap(model)(X_pred)

# Plot the data together with the model predictions; the return value is
# discarded.
_ = plot_sinusoid_task(X_train, y_train, X_test, y_test, X_pred, y_pred)

In [None]:

# Split only the inner network (`model.model`), not the `MLP` wrapper, so
# `params` here does not contain the wrapper's `scale` parameter. This rebinds
# the module-level `graph_def` and `params`.
graph_def, params = nnx.split(model.model)

def model_fn(input, params):
    """Evaluate the network described by `graph_def` on `input`.

    `graph_def` is looked up at module level when the function is called, and
    `params` supplies the matching parameter state. The argument names are
    kept unchanged because callers may pass them by keyword.
    """
    result = nnx.call((graph_def, params))(input)
    # Only the first element of the returned tuple is used.
    return result[0]

# Select the context points used by the Laplace approximation.
# NOTE(review): presumably this requests 1000 grid points between -3.0 and 3.0
# (the two lists look like per-dimension bounds) -- confirm the positional
# argument order against `select_context_points`.
context_points = select_context_points(1000, "grid", [3.0], [-3.0], X_train.shape, key=jax.random.key(0))
# Build the predictive from the trained network, the full training set, the
# prior kernel and the context points; it is called on single inputs below.
prob_predictive = fsp_laplace(model_fn, params, data, kernel_fn, context_points)

In [None]:
# 200-point float64 grid over [-2, 2] on which the predictive is evaluated.
X_pred = jnp.linspace(-2, 2, 200, dtype=jnp.float64).reshape(-1, 1)

# `jax.vmap` evaluates the predictive on each row of `X_pred`; the result is
# indexed below with the keys "pred_mean", "pred_var" and "samples".
pred = jax.vmap(prob_predictive)(X_pred)
plot_regression_with_uncertainty(
        X_train=data["input"],
        y_train=data["target"],
        X_pred=X_pred,
        y_pred=pred["pred_mean"][:, 0],
        # Standard deviation is the square root of the predictive variance.
        y_std=jnp.sqrt(pred["pred_var"][:, 0]),
        y_samples=pred["samples"],
    )
