## Using Laplace Approximation on Cubic Regression Problem

In [None]:
import scipy
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

from IPython.display import display, clear_output

import jax
import jax.numpy as jnp
from jax import random
from jax import make_jaxpr
from jax.config import config
from jax import value_and_grad
from jax import grad, vmap, pmap, jit
import jax.tree_util as jtu

import optax
from flax import linen as nn
from flax.training import train_state
import flax

from typing import Any, Callable, Sequence, Optional
import sympy

from sympy import Matrix

from dataclasses import dataclass

In [None]:
# Switch JAX to 64-bit floating point. The flag is set through the top-level
# ``jax.config`` object: importing ``config`` from the ``jax.config`` submodule
# is deprecated and no longer available in recent JAX releases.
jax.config.update("jax_enable_x64", True)

### Define True Model Function and Sample Data

$ y = 1 + x + 2x^2 + 4x^3$


In [None]:
# Input grid: `ndata` evenly spaced points between t0 and t1.
ndata = 200
t0, t1 = -1.25, 1.25
t = jnp.linspace(t0, t1, num=ndata)

def true_fun(t):
    """Return the noise-free cubic 1 + t + 2 t^2 + 4 t^3.

    The polynomial values are wrapped in a one-element list before being
    converted to an array, so the result carries one extra leading axis of
    length 1 compared with ``t``.
    """
    cubic_values = 1 + t + 2*t**2 + 4*t**3
    return jnp.array([cubic_values])

In [None]:
# Noise model: Gaussian with standard deviation `stdev`, drawn from NumPy's
# global generator after seeding it with `seed`.
seed = 989
stdev = 3
np.random.seed(seed)

clean_y = true_fun(t).squeeze()  # drop the length-1 axes
noise = np.random.normal(scale=stdev, size=clean_y.shape)
true_y = clean_y + noise  # despite its name, this holds the noise-corrupted targets

In [None]:
# Show the synthetic dataset: noisy samples as a scatter, the noise-free
# cubic as a green line.
plt.figure(figsize=(10, 7))
plt.rcParams['font.size'] = 14

noisy_targets = true_y.squeeze()
clean_curve = true_fun(t).squeeze()
plt.scatter(t, noisy_targets, label='Noise Corrupted Data')
plt.plot(t, clean_curve, color='g', label="True Model")

plt.xlabel('x')
plt.ylabel('y')
plt.legend();

In [None]:
# Peak-to-peak spread of the noisy targets.
y_hi, y_lo = true_y.max(), true_y.min()
yscale = jnp.abs(y_hi - y_lo)
yscale

In [None]:
# Turn `t` into a column vector of shape (n, 1). reshape is used instead of
# `t[:, None]` so that re-running this cell does not keep appending axes.
t = t.reshape(-1, 1)

In [None]:
# Turn `true_y` into a column vector of shape (n, 1). reshape is used instead
# of `true_y[:, None]` so that re-running this cell does not keep appending axes.
true_y = true_y.reshape(-1, 1)

### Laplace Approximation Results 

In [None]:
# Load the saved Laplace-approximation samples. Columns 0-3 are read below as
# the four cubic coefficients (presumably shape (n_samples, 4) — TODO confirm).
expanded = np.load('Laplace_cubic_expanded.npy')

In [None]:
import seaborn as sns

# One kernel density estimate per sampled coefficient (columns 0-3 of
# `expanded`), labelled by the polynomial term it belongs to.
term_labels = ('$1$', '$x$', '$2x^2$', '$4x^3$')
for column, term in enumerate(term_labels):
    sns.kdeplot(expanded[:, column], label=term)

plt.xlabel('Parameter Value')
plt.legend()
plt.minorticks_on()
plt.ylabel('Kernel Density Estimate')


In [None]:
n_samples = 100  # NOTE(review): not referenced by the visible code in this cell
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector, then summarise the
# resulting curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

plt.figure(figsize=(10, 8))
plt.rcParams['font.size'] = 14
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
plt.plot(t, ys_mean, color='r', label="Mean Model")

# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    plt.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% confidence interval')
    plt.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% confidence interval')
    plt.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.minorticks_on()
plt.ylabel('y')

### MCMC NUTS Results 

In [None]:
# Load the saved MCMC NUTS samples. Columns 0-3 are read below as the four
# cubic coefficients (presumably shape (n_samples, 4) — TODO confirm).
expanded = np.load('nuts_cubic_expanded.npy')

In [None]:
import seaborn as sns

# One kernel density estimate per sampled coefficient (columns 0-3 of
# `expanded`), labelled by the polynomial term it belongs to.
term_labels = ('$1$', '$x$', '$2x^2$', '$4x^3$')
for column, term in enumerate(term_labels):
    sns.kdeplot(expanded[:, column], label=term)

plt.xlabel('Parameter Value')
plt.legend()
plt.minorticks_on()
plt.ylabel('Kernel Density Estimate')

In [None]:
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector, then summarise the
# resulting curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

plt.figure(figsize=(10, 8))
plt.rcParams['font.size'] = 14
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
plt.plot(t, ys_mean, color='r', label="Mean Model")

# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    plt.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% confidence interval')
    plt.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% confidence interval')
    plt.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.minorticks_on()
plt.ylabel('y')

### Variational Inference Results 

In [None]:
# Load the saved variational-inference samples. Columns 0-3 are read below as
# the four cubic coefficients (presumably shape (n_samples, 4) — TODO confirm).
expanded = np.load('VI_cubic_expanded.npy')

In [None]:
import seaborn as sns

# One kernel density estimate per sampled coefficient (columns 0-3 of
# `expanded`), labelled by the polynomial term it belongs to.
term_labels = ('$1$', '$x$', '$2x^2$', '$4x^3$')
for column, term in enumerate(term_labels):
    sns.kdeplot(expanded[:, column], label=term)

plt.xlabel('Parameter Value')
plt.legend()
plt.minorticks_on()
plt.ylabel('Kernel Density Estimate')

In [None]:
n_samples = 100  # NOTE(review): not referenced by the visible code in this cell
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector, then summarise the
# resulting curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

plt.figure(figsize=(10, 8))
plt.rcParams['font.size'] = 14
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
plt.plot(t, ys_mean, color='r', label="Mean Model")

# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    plt.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% confidence interval')
    plt.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% confidence interval')
    plt.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.minorticks_on()
plt.ylabel('y')

### Bayesian Linear Regression

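The least-squares estimate of the coefficients is

$\hat{\beta} = (X^T X)^{-1} X^T y$

and the prediction for a feature vector $x$ is

$ y = x^T \hat{\beta} $

The Fisher information of $\beta$ under Gaussian noise with variance $\sigma^2$ is

$I(\beta) = \frac{X^T X}{\sigma^2}$

so the coefficient covariance used below is its inverse, $I(\beta)^{-1} = \sigma^2 (X^T X)^{-1}$.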

In [None]:
# Design matrix for the cubic: one column per monomial 1, t, t^2, t^3.
monomials = (jnp.ones_like(t), t, t**2, t**3)
X = jnp.column_stack(monomials)
X.shape

In [None]:
# Ordinary least-squares estimate from the normal equations
# (X^T X) beta = X^T y. Solving the linear system directly avoids forming the
# explicit inverse of X^T X, which is cheaper and numerically more stable.
beta = jnp.linalg.solve(X.T @ X, X.T @ true_y)
beta

In [None]:
# Coefficient covariance: inverse of X^T X / sigma^2, with sigma^2 taken as the
# sample variance (ddof=1) of the fitted residuals.
# NOTE(review): the usual unbiased noise estimate divides by n - p (p = number
# of columns of X) rather than n - 1 — confirm that ddof=1 is intended.
residuals = X @ beta - true_y
sigma2_hat = jnp.var(residuals, ddof=1)
gram = X.T @ X
cov = jnp.linalg.inv(gram / sigma2_hat)
cov

In [None]:
import seaborn as sns

# Draw coefficient vectors from the Gaussian with mean `beta` and covariance
# `cov`, then plot one kernel density estimate per coefficient.
expanded = np.random.multivariate_normal(beta.squeeze(), cov, size=500000)

term_labels = ('$1$', '$x$', '$2x^2$', '$4x^3$')
for column, term in enumerate(term_labels):
    sns.kdeplot(expanded[:, column], label=term)

plt.xlabel('Parameter Value')
plt.legend()
plt.minorticks_on()
plt.ylabel('Kernel Density Estimate')

In [None]:
# Display the shape of the sample array for a quick sanity check.
expanded.shape

In [None]:
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector, then summarise the
# resulting curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

plt.figure(figsize=(10, 8))
plt.rcParams['font.size'] = 14
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
plt.plot(t, ys_mean, color='r', label="Mean Model")

# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    plt.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% confidence interval')
    plt.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% confidence interval')
    plt.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.minorticks_on()
plt.ylabel('y')

### Let's combine all of the figures together:

In [None]:
# 4x2 summary figure: coefficient KDEs in the left column, regression fits in
# the right column.
plt.rcParams['font.size'] = 12
fig, axs = plt.subplots(nrows=4, ncols=2, figsize=(12, 15))

# ---- Row 0: Laplace approximation ----
expanded = np.load('Laplace_cubic_expanded.npy')

# Panel [0,0]: one KDE per coefficient, labelled with the true coefficient value.
kde_ax = axs[0, 0]
beta_labels = (r'$\beta_0=1$', r'$\beta_1=1$', r'$\beta_2=2$', r'$\beta_3=4$')
for column, beta_label in enumerate(beta_labels):
    sns.kdeplot(ax=kde_ax, data=expanded[:, column], label=beta_label)
kde_ax.set_ylabel('Kernel Density Estimate')
kde_ax.legend()



# Panel [0,1]: regression fit.
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector and summarise the
# curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

fit_ax = axs[0, 1]
fit_ax.scatter(t, true_y, label='Training Data')
fit_ax.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
fit_ax.plot(t, ys_mean, color='r', label="Mean Model")
# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% CI')
    fit_ax.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% CI')
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')
fit_ax.legend()


# ---- Row 1: MCMC NUTS ----
expanded = np.load('nuts_cubic_expanded.npy')

# Panel [1,0]: one KDE per coefficient (no legend is drawn on this panel).
kde_ax = axs[1, 0]
for column, term in enumerate(('$1$', '$x$', '$2x^2$', '$4x^3$')):
    sns.kdeplot(ax=kde_ax, data=expanded[:, column], label=term)
kde_ax.set_ylabel('Kernel Density Estimate')


# Panel [1,1]: regression fit.
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector and summarise the
# curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

# No legend is drawn on this panel.
fit_ax = axs[1, 1]
fit_ax.scatter(t, true_y, label='Training Data')
fit_ax.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
fit_ax.plot(t, ys_mean, color='r', label="Mean Model")
# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% CI')
    fit_ax.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% CI')
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')


# ---- Row 2: Variational Inference ----
expanded = np.load('VI_cubic_expanded.npy')

# Panel [2,0]: one KDE per coefficient (no legend is drawn on this panel).
kde_ax = axs[2, 0]
for column, term in enumerate(('$1$', '$x$', '$2x^2$', '$4x^3$')):
    sns.kdeplot(ax=kde_ax, data=expanded[:, column], label=term)
kde_ax.set_ylabel('Kernel Density Estimate')

# Panel [2,1]: regression fit.
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector and summarise the
# curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

# No legend is drawn on this panel.
fit_ax = axs[2, 1]
fit_ax.scatter(t, true_y, label='Training Data')
fit_ax.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
fit_ax.plot(t, ys_mean, color='r', label="Mean Model")
# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% CI')
    fit_ax.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% CI')
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')


# ---- Row 3: Bayesian Linear Regression ----
# Draw coefficient vectors from the Gaussian with mean `beta` and covariance `cov`.
expanded = np.random.multivariate_normal(beta.squeeze(), cov, size=500000)

# Panel [3,0]: one KDE per coefficient (no legend is drawn on this panel).
kde_ax = axs[3, 0]
for column, term in enumerate(('$1$', '$x$', '$2x^2$', '$4x^3$')):
    sns.kdeplot(ax=kde_ax, data=expanded[:, column], label=term)
kde_ax.set_ylabel('Kernel Density Estimate')


# Panel [3,1]: regression fit (the previous comment said [2,1]).
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector and summarise the
# curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

# No legend is drawn on this panel.
fit_ax = axs[3, 1]
fit_ax.scatter(t, true_y, label='Training Data')
fit_ax.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
fit_ax.plot(t, ys_mean, color='r', label="Mean Model")
# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% CI')
    fit_ax.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% CI')
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')


# ---- Shared formatting for the whole 4x2 grid ----
# Left column (KDE panels): identical axis limits so the rows are comparable.
for kde_panel in axs[:, 0]:
    kde_panel.set_xlim([-2, 7])
    kde_panel.set_ylim([0, 1.85])

# Right column (fit panels): y-axis label on every row.
for fit_panel in axs[:, 1]:
    fit_panel.set_ylabel('y')

# x-axis labels only on the bottom row.
axs[3, 0].set_xlabel('Parameter Value')
axs[3, 1].set_xlabel('x')

for ax in axs.flatten():
    ax.minorticks_on()

# Row titles are attached to the left-hand panel of each row.
row_titles = (
    'a.) Laplace Approximation',
    'b.) Markov Chain Monte Carlo',
    'c.) Variational Inference',
    'd.) Bayesian Linear Regression',
)
for kde_panel, row_title in zip(axs[:, 0], row_titles):
    kde_panel.set_title(row_title, loc='left', pad=10, fontsize=15)

plt.tight_layout()

# Write the figure in both vector formats.
for out_path in ('Cubic_Regression_Comparision.svg', 'Cubic_Regression_Comparision.pdf'):
    plt.savefig(out_path)

### Plot showing results for different MCMC sample sizes (500, 1000, and 10000 samples)

In [None]:
# 3x2 figure: coefficient KDEs in the left column, regression fits in the
# right column.
plt.rcParams['font.size'] = 12
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(12, 11.25))

# ---- Row 0: MCMC NUTS with 500 samples ----
expanded = np.load('nuts_cubic_expanded_500.npy')

# Panel [0,0]: one KDE per coefficient, labelled with the true coefficient value.
kde_ax = axs[0, 0]
beta_labels = (r'$\beta_0=1$', r'$\beta_1=1$', r'$\beta_2=2$', r'$\beta_3=4$')
for column, beta_label in enumerate(beta_labels):
    sns.kdeplot(ax=kde_ax, data=expanded[:, column], label=beta_label)
kde_ax.set_ylabel('Kernel Density Estimate')
kde_ax.legend()



# Panel [0,1]: regression fit.
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector and summarise the
# curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

fit_ax = axs[0, 1]
fit_ax.scatter(t, true_y, label='Training Data')
fit_ax.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
fit_ax.plot(t, ys_mean, color='r', label="Mean Model")
# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% CI')
    fit_ax.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% CI')
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')
fit_ax.legend()


# ---- Row 1: MCMC NUTS with 1000 samples ----
expanded = np.load('nuts_cubic_expanded_1000.npy')

# Panel [1,0]: one KDE per coefficient (no legend is drawn on this panel).
kde_ax = axs[1, 0]
for column, term in enumerate(('$1$', '$x$', '$2x^2$', '$4x^3$')):
    sns.kdeplot(ax=kde_ax, data=expanded[:, column], label=term)
kde_ax.set_ylabel('Kernel Density Estimate')


# Panel [1,1]: regression fit.
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector and summarise the
# curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

# No legend is drawn on this panel.
fit_ax = axs[1, 1]
fit_ax.scatter(t, true_y, label='Training Data')
fit_ax.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
fit_ax.plot(t, ys_mean, color='r', label="Mean Model")
# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% CI')
    fit_ax.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% CI')
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')


# ---- Row 2: MCMC NUTS with 10000 samples ----
expanded = np.load('nuts_cubic_expanded_10000.npy')

# Panel [2,0]: one KDE per coefficient (no legend is drawn on this panel).
kde_ax = axs[2, 0]
for column, term in enumerate(('$1$', '$x$', '$2x^2$', '$4x^3$')):
    sns.kdeplot(ax=kde_ax, data=expanded[:, column], label=term)
kde_ax.set_ylabel('Kernel Density Estimate')

# Panel [2,1]: regression fit.
dim = 1  # number of output columns the plotting loop below iterates over


def function2(param_flat, x):
    """Evaluate the cubic model at ``x`` for a single coefficient vector.

    Args:
        param_flat: Four coefficients ordered (constant, x, x^2, x^3).
        x: Input locations (1-D array or column vector).

    Returns:
        The product of the design matrix ``[1, x, x^2, x^3]`` with
        ``param_flat``: one prediction per input location.
    """
    # Use the argument ``x``; the earlier version ignored it and read the
    # global ``t``, so the function only worked for that one grid.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the model for every sampled coefficient vector and summarise the
# curves pointwise by their mean and standard deviation.
params_samples = expanded
batched_model = jax.vmap(function2, in_axes=(0, None))
ys = batched_model(params_samples, t)[:, :, None]  # add a trailing length-1 axis

ys_mean = np.mean(ys, axis=0)
ys_stdev = np.std(ys, axis=0)

# No legend is drawn on this panel.
fit_ax = axs[2, 1]
fit_ax.scatter(t, true_y, label='Training Data')
fit_ax.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
fit_ax.plot(t, ys_mean, color='r', label="Mean Model")
# Shade mean +/- 2 std (blue) plus the strips from 2 to 3 std on either side (aqua).
x_flat = t.squeeze()
for i in range(dim):
    mu, sigma = ys_mean[:, i], ys_stdev[:, i]
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu + 2.0*sigma, alpha=0.3, color='royalblue', label='95% CI')
    fit_ax.fill_between(x_flat, mu + 2.0*sigma, mu + 3.0*sigma, alpha=0.3, color='aqua', label='99.7% CI')
    fit_ax.fill_between(x_flat, mu - 2.0*sigma, mu - 3.0*sigma, alpha=0.3, color='aqua')


# ---- Shared formatting for the whole 3x2 grid ----
# Left column (KDE panels): identical axis limits so the rows are comparable.
for kde_panel in axs[:, 0]:
    kde_panel.set_xlim([-2, 7])
    kde_panel.set_ylim([0, 1.85])

# Right column (fit panels): y-axis label on every row.
for fit_panel in axs[:, 1]:
    fit_panel.set_ylabel('y')

# x-axis labels only on the bottom row.
axs[2, 0].set_xlabel('Parameter Value')
axs[2, 1].set_xlabel('x')

for ax in axs.flatten():
    ax.minorticks_on()

# Row titles are attached to the left-hand panel of each row.
row_titles = (
    'a.) Markov Chain Monte Carlo (500 samples)',
    'b.) Markov Chain Monte Carlo (1000 samples)',
    'c.) Markov Chain Monte Carlo (10000 samples)',
)
for kde_panel, row_title in zip(axs[:, 0], row_titles):
    kde_panel.set_title(row_title, loc='left', pad=10, fontsize=15)

plt.tight_layout()

# Write the figure in both vector formats.
for out_path in ('Cubic_Regression_mcmc_Comparision.svg', 'Cubic_Regression_mcmc_Comparision.pdf'):
    plt.savefig(out_path)