## Using the Laplace Approximation on a Cubic Regression Problem

In [None]:
import scipy
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

from IPython.display import display, clear_output

import jax
import jax.numpy as jnp
from jax import random
from jax import make_jaxpr
from jax.config import config
from jax import value_and_grad
from jax import grad, vmap, pmap, jit
import jax.tree_util as jtu

import optax
from flax import linen as nn
from flax.training import train_state
import flax

from typing import Any, Callable, Sequence, Optional
import sympy

from sympy import Matrix

from dataclasses import dataclass

from NN_arch import PiNet

In [None]:
# Enable float64 (required): without this flag JAX keeps arrays in 32-bit precision.
# The flag is set through the `jax.config` attribute so this cell does not have to
# re-import the deprecated `jax.config` submodule.
jax.config.update("jax_enable_x64", True)

In [None]:
import warnings
warnings.filterwarnings("ignore", "is_categorical_dtype")
warnings.filterwarnings("ignore", "use_inf_as_na")

### Define True Model Function and Sample Data

$y = 1 + x + 2x^2 + 4x^3$


In [None]:
# Input grid: `ndata` evenly spaced points on the closed interval [t0, t1].
ndata = 200  # number of known data points
t0, t1 = -1.25, 1.25
t = jnp.linspace(t0, t1, num=ndata)

def true_fun(t):
    """Evaluate the ground-truth cubic ``1 + t + 2*t**2 + 4*t**3``.

    The values are wrapped in a one-element list before conversion, so the
    returned array carries an extra leading axis of length 1.
    """
    cubic_values = 1 + t + 2*t**2 + 4*t**3
    return jnp.array([cubic_values])

In [None]:
# Build the training targets: noiseless cubic plus zero-mean Gaussian noise.
stdev = 3   # standard deviation of the additive noise
seed = 989  # seed for NumPy's global RNG, so the noise draw is reproducible

np.random.seed(seed)

noiseless_y = true_fun(t).squeeze()
true_y = noiseless_y + np.random.normal(scale=stdev, size=noiseless_y.shape)

In [None]:
# Show the noisy samples together with the noiseless cubic they were drawn from.
plt.figure(figsize=(10, 7))
plt.rcParams.update({"font.size": 14})

plt.scatter(t, true_y.squeeze(), label="Noise Corrupted Data")
plt.plot(t, true_fun(t).squeeze(), color="g", label="True Model")

plt.ylabel("y")
plt.xlabel("x")
plt.legend();  # trailing semicolon keeps the notebook from echoing the Legend object

In [None]:
# Turn `t` into a column of shape (n, 1). `reshape` is used instead of
# `t[:, None]` so that re-running this cell does not keep appending axes.
t = t.reshape(-1, 1)

In [None]:
# Turn `true_y` into a column of shape (n, 1). `reshape` is used instead of
# `true_y[:, None]` so that re-running this cell does not keep appending axes.
true_y = true_y.reshape(-1, 1)

### Model

In [None]:
# 1. Model instance
model = PiNet()

# 2. Initialize the parameters of the model.
# The parent key is split and the fresh subkey is the one consumed by `init`;
# a key that has already been split should not be used again.
key = random.PRNGKey(0)
key, init_key = random.split(key)
# The dummy input only fixes the input shape (here a single scalar feature).
params = model.init(init_key, jnp.ones([1]))['params']

In [None]:
@jax.jit
def loss_fn(pred, known, params):
    """Return the sum of squared residuals between `known` and `pred`.

    `params` is accepted but does not enter the computation.
    """
    residual = known - pred
    squared_residual = jnp.power(residual, 2)
    return jnp.sum(squared_residual)

In [None]:
@jax.jit
def calculate_loss(params, t, y_known):
    """Apply the global `model` to `t` with `params` and score the result with `loss_fn`.

    Returns:
        The value of `loss_fn(prediction, y_known, params)`.
    """
    prediction = model.apply({'params': params}, t)
    return loss_fn(prediction, y_known, params)


In [None]:
@jax.jit
def calculate_value_loss_grad(params, t, y_known):
    """Return the prediction, the loss, and the loss gradient w.r.t. `params`.

    Returns:
        A tuple `(y_pred, loss, grads)`: the output of the global `model` on
        `t`, the `loss_fn` value against `y_known`, and the gradient of
        `calculate_loss` with respect to its first argument (`params`).
    """
    # Differentiate only with respect to the parameters (argument 0).
    loss_gradient = jax.grad(calculate_loss, argnums=0)

    prediction = model.apply({'params': params}, t)
    return (
        prediction,
        loss_fn(prediction, y_known, params),
        loss_gradient(params, t, y_known),
    )

In [None]:
# F. Initial train state including parameters initialization
def create_train_state(key, lr=5e-2):
    """Build a fresh `PiNet` and the initial `TrainState` used to optimise it.

    Args:
        key: PRNG key used to initialize the model parameters.
        lr: Constant learning rate handed to the Adam optimizer.

    Returns:
        A `(model, state)` tuple: the `PiNet` instance and a
        `train_state.TrainState` created from the model's `apply` method,
        its freshly initialized parameters and the Adam optimizer.
    """
    net = PiNet()

    # The dummy input of shape [1] is only used to initialize the parameters.
    initial_params = net.init(key, jnp.ones([1]))['params']

    return net, train_state.TrainState.create(
        apply_fn=net.apply,
        params=initial_params,
        tx=optax.adam(learning_rate=lr),
    )

In [None]:
# Seed the PRNG, derive a subkey from it, and build the model with its initial train state.
key, init_key = random.split(random.PRNGKey(0))
model, state = create_train_state(init_key)

In [None]:
@jax.jit
def train_step_gradient_descent(state, t, y_known):
    """Run one optimizer step using gradients summed over all data points.

    Args:
        state: Train state holding the current parameters and optimizer.
        t: Inputs; the leading axis indexes data points.
        y_known: Targets aligned with `t` along the leading axis.

    Returns:
        A tuple `(y_pred, loss, state)`: the per-point predictions made with
        the parameters *before* the update, the loss summed over data points,
        and the updated train state.
    """
    # Per-point prediction, loss and gradient; the parameters are shared
    # (in_axes=None) while inputs and targets are mapped over axis 0.
    # The targets come from the `y_known` argument rather than a global, so
    # the jitted step always uses the data it was called with.
    y_pred, loss, grads = jax.vmap(calculate_value_loss_grad, in_axes=(None, 0, 0))(state.params, t, y_known)

    # Accumulate loss and gradients over the data axis.
    loss = jnp.sum(loss, 0)
    grads = jtu.tree_map(lambda x: jnp.sum(x, 0), grads)

    # Apply the update; the step size is defined by the optimizer stored in `state`.
    state = state.apply_gradients(grads=grads)

    return y_pred, loss, state

In [None]:
# Run a single training step on the full data set; `state` is replaced by the updated state.
y_pred, loss, state = train_step_gradient_descent(state, t, true_y)

In [None]:
# Display the current value of `loss` as the cell output.
loss

### Model Training

In [None]:
EPOCHS = 1000
test_freq = 500  # print and redraw the figure every `test_freq` iterations
key = random.PRNGKey(0)

# Start from a freshly initialized model and optimizer state.
key, init_key = random.split(key)
model, state = create_train_state(init_key)

# Loss recorded for each completed iteration
training_loss = []

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,5))
plt.rcParams.update({'font.size': 12})

# Training
for itr in range(EPOCHS):
    y_pred, loss, state = train_step_gradient_descent(state, t, true_y)

    # Stop early once the loss is numerically zero.
    if loss < 1e-17:
        break

    training_loss.append(loss)

    if itr % test_freq == 0 or itr == EPOCHS-1:
        print('Iter {:04d} | Total Loss {:e}'.format(itr, training_loss[-1]))

        #loss graph -- don't change
        ax1.cla()
        ax1.semilogy(training_loss)
        ax1.set_ylabel('training Loss')
        ax1.set_xlabel('Epochs')
        ax1.minorticks_on()

        #data and NN prediction
        ax2.cla()
        ax2.plot(t, y_pred, '-', label='NN')
        ax2.scatter(t, true_y, label='data')
        ax2.set_xlabel('t')
        ax2.set_ylabel('y')
        ax2.minorticks_on()
        # The artists carry their own labels: `legend('NN', 'data')` would be
        # parsed as (handles, labels) and not produce the intended entries.
        ax2.legend()

        fig.tight_layout()
        display(fig)

        # Replace the previous figure instead of stacking outputs.
        clear_output(wait=True)


    #print(f"loss: {training_loss[-1]:.3f}")

### Calculation with gradient

$\mathcal{I}(\theta) = \operatorname{E} \left[\left. \left(\frac{\partial}{\partial\theta} \log f(X;\theta)\right)^2\right|\theta \right]$

In [None]:
# Per-data-point predictions, losses and loss gradients at the current `state.params`:
# the parameters are shared (in_axes=None), inputs and targets are mapped over axis 0.
per_point_value_loss_grad = jax.vmap(calculate_value_loss_grad, in_axes=(None, 0, 0))
ypred, loss, grads = per_point_value_loss_grad(state.params, t, true_y)
#jtu.tree_map(lambda x: x.shape, grads)

In [None]:
# Import the submodule explicitly: `import jax` alone is not guaranteed to
# make `jax.flatten_util` available as an attribute.
from jax.flatten_util import ravel_pytree

# Flatten the parameters into one 1-D vector and keep the function that
# restores the original pytree structure from such a vector.
mean_matrix, unravel_fn = ravel_pytree(state.params)
# Flatten every per-data-point gradient pytree into one row (mapped over axis 0).
grads_matrix = jax.vmap(lambda g: ravel_pytree(g)[0])(grads)

In [None]:
prior_std = 1e30  # very wide prior, so its precision term is tiny

# Residual variance along the data axis (unbiased estimate, ddof=1).
residual_var = jnp.var(true_y - ypred, 0, ddof=1)
# The per-point gradients are of the raw squared error; dividing their outer
# products by (2*variance)**2 rescales them to gradients of a Gaussian
# log-likelihood. The 1e-30 guards against division by zero.
gradient_scale = (2*residual_var + 1e-30)**2

n_params = grads_matrix.shape[-1]
data_precision = grads_matrix.transpose() @ grads_matrix / gradient_scale
prior_precision = jnp.eye(n_params) * 1/prior_std**2

# Covariance = pseudo-inverse of the total precision matrix.
cov = jnp.linalg.pinv(data_precision + prior_precision)

### Expanding out the Parameters with Monte Carlo

In [None]:
n_samples = 1000
# Draw flat parameter vectors from N(mean_matrix, cov) and restore the
# parameter pytree structure for every draw.
samples = np.random.multivariate_normal(mean_matrix, cov, n_samples)
sample_params = jax.vmap(unravel_fn)(samples)

# The polynomial variable is the same for every sample, so create it once.
x_symbol = sympy.symbols('x')

expanded = []
for it in range(n_samples):
    if it % 100 == 0:
        print(it)  # progress indicator
    sample_param_i = jtu.tree_map(lambda x: x[it], sample_params)
    equation = model.get_equation(sample_param_i, ['x'])
    # `all_coeffs()` lists every coefficient from the highest degree down,
    # zeros included; reversing it puts the coefficient of x**k in column k.
    # This avoids relying on dict ordering and on no term being dropped.
    coefficients = sympy.Poly(equation[0], x_symbol).all_coeffs()[::-1]
    expanded.append(np.array(coefficients, np.float64))
expanded = np.array(expanded, np.float64)

In [None]:
#np.save('Laplace_cubic_expanded.npy', expanded)

In [None]:
import seaborn as sns

# One kernel density estimate per column of `expanded`; each curve is labelled
# with the corresponding term of the true model.
for column, term_label in enumerate(('$1$', '$x$', '$2x^2$', '$4x^3$')):
    sns.kdeplot(expanded[:, column], label=term_label)
    #plt.axvline(true_coefficient, 0)

plt.xlabel('Parameter Value')
plt.legend()
plt.minorticks_on()
plt.ylabel('Kernel Density Estimate')

#plt.savefig('CubicRegression_kde.svg')
#plt.savefig('CubicRegression_kde.pdf')

In [None]:
n_samples = 100  # NOTE(review): not referenced again in this cell - confirm whether it is still needed
dim = 1  # number of output columns for which uncertainty bands are drawn in this cell

def function2(param_flat, x):
    """Evaluate a cubic polynomial with coefficients `param_flat` at `x`.

    Args:
        param_flat: Four coefficients in ascending order of power, i.e. the
            weights of 1, x, x**2 and x**3.
        x: Inputs, given as a column of shape (n, 1) or a 1-D array of length n.

    Returns:
        Array of shape (n,) holding the polynomial values.
    """
    # The design matrix is built from the `x` argument, so the function
    # evaluates the inputs it is given rather than the global `t`.
    design_matrix = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design_matrix, param_flat)

params_samples = expanded

# Evaluate the polynomial for every sampled coefficient vector (mapped over
# axis 0) on the shared inputs `t`, then add a trailing output axis.
evaluate_all_samples = jax.vmap(function2, (0, None))
ys = evaluate_all_samples(params_samples, t)[:, :, None]

# Pointwise mean and standard deviation across the sampled curves.
ys_mean = np.mean(ys, 0)
ys_stdev = np.std(ys, 0)

plt.figure(figsize=(10, 8))
plt.rcParams.update({'font.size': 14})
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
plt.plot(t, ys_mean, color='r', label="Mean Model")

x_axis = t.squeeze()
for i in range(0, dim):
    centre = ys_mean[:, i]
    spread = ys_stdev[:, i]
    # Inner band: mean +/- 2 standard deviations.
    plt.fill_between(x_axis, centre-2.0*spread, centre+2.0*spread, alpha=0.3, color='royalblue', label='95% confidence interval')
    # Outer bands: between 2 and 3 standard deviations on either side
    # (only the upper one is labelled, so the legend has a single entry).
    plt.fill_between(x_axis, centre+2.0*spread, centre+3.0*spread, alpha=0.3, color='aqua', label='99.7% confidence interval')
    plt.fill_between(x_axis, centre-2.0*spread, centre-3.0*spread, alpha=0.3, color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.minorticks_on()
plt.ylabel('y')
#plt.savefig('CubicRegression_Laplace_Uncertainty_Figure.svg')
#plt.savefig('CubicRegression_Laplace_Uncertainty_Figure.pdf')