### Bayesian Cubic Regression with BlackJAX

In [None]:
import jax
import distrax
import blackjax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt

from functools import partial
import functools
from jax.flatten_util import ravel_pytree
import jax.tree_util as jtu

import numpy as np
from typing import Any, Callable, Sequence, Optional
import sympy
from sympy import Matrix

from NN_arch import PiNet

In [None]:
# JAX computes in 32-bit precision unless the "jax_enable_x64" flag is set.
# The flag is set through `jax.config` because the `from jax.config import
# config` import path is not available in recent JAX releases; the `config`
# name is kept bound so existing references to it keep working.
config = jax.config
config.update("jax_enable_x64", True)

### Cubic Regression

### Define True Model Function and Sample Data

$y = 1 + t + 2t^2 + 4t^3$

In [None]:
# Evenly spaced sampling grid for the synthetic data set.
ndata = 200  # number of known data points

t0, t1 = -1.25, 1.25  # first and last grid point
t = jnp.linspace(t0, t1, num=ndata)

def true_fun(t):
    """Evaluate the noise-free cubic 1 + t + 2*t**2 + 4*t**3.

    The polynomial value is wrapped in a one-element list before being
    converted to an array, so the result carries an extra leading axis of
    length 1 in front of the shape of ``t``.
    """
    cubic = 1 + t + 2*t**2 + 4*t**3
    return jnp.array([cubic])

In [None]:
stdev = 3  # standard deviation of the Gaussian noise added to the targets

seed = 989
np.random.seed(seed)  # seed NumPy's global RNG so the noise is reproducible

# Clean targets on the grid, then one Gaussian perturbation per data point.
true_y = true_fun(t).squeeze()
noise = np.random.normal(scale=stdev, size=true_y.shape)
true_y = true_y + noise

In [None]:
# Show the noise-free curve together with the noisy observations.
data_fig, data_ax = plt.subplots(figsize=(10, 7))
data_ax.plot(t, true_fun(t).squeeze(), color='r', label="True values")
data_ax.scatter(t, true_y, label="Noise corrupted values")
data_ax.set_title("Real function along with noisy targets")
data_ax.set_xlabel("Features")
data_ax.set_ylabel("Labels")
data_ax.legend();

In [None]:
# Insert a new axis at position 1 so inputs and targets become column arrays
# (one row per data point).
t = jnp.expand_dims(t, axis=1)
true_y = jnp.expand_dims(true_y, axis=1)

### Setup BNN functions

In [None]:
def bnn_log_joint(params, X, known, model):
    """Return the unnormalised log posterior of the network parameters.

    The value is the sum of a Gaussian log prior over all parameters and a
    Gaussian log likelihood of the predictions around the observed targets.

    Args:
        params: Pytree of model parameters; passed to ``model.apply`` as
            ``{'params': params}``.
        X: Input samples; ``model.apply`` is mapped over the leading axis.
        known: Observed targets, broadcast against the predictions.
        model: Object exposing ``apply(variables, x)``.

    Returns:
        Scalar log joint density ``log_prior + log_likelihood``.
    """
    # Evaluate the network once per row of X with shared parameters.
    ypred = jax.vmap(model.apply, (None, 0))({'params': params}, X)

    # Very wide zero-mean Gaussian prior on every (flattened) parameter.
    flat_params, _ = ravel_pytree(params)
    log_prior = distrax.Normal(0.0, 100000.0).log_prob(flat_params).sum()

    # The noise level is estimated from the residuals. The mean squared
    # residual is a *variance*; `distrax.Normal` takes a standard deviation
    # as its scale, so the square root (the RMSE) is what must be passed.
    residuals = ypred - known
    noise_scale = jnp.sqrt(jnp.mean(jnp.square(residuals)))
    log_likelihood = distrax.Normal(known, noise_scale).log_prob(ypred).sum()

    return log_prior + log_likelihood


def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run ``kernel`` for ``num_samples`` steps and return every visited state.

    Args:
        rng_key: PRNG key; split into one sub-key per step.
        kernel: Callable ``kernel(step_key, state) -> (new_state, info)``.
            The second element of its return value is discarded.
        initial_state: State handed to the first kernel call.
        num_samples: Number of steps to take.

    Returns:
        The states produced by the successive steps, stacked along a new
        leading axis by ``jax.lax.scan``.
    """
    def _advance(carry, step_key):
        next_state, _info = kernel(step_key, carry)
        # The new state is both the carry for the next step and the output.
        return next_state, next_state

    step_keys = jax.random.split(rng_key, num_samples)
    _last, trajectory = jax.lax.scan(jax.jit(_advance), initial_state, step_keys)
    return trajectory

### Train BNN

In [None]:
# Derive four sub-keys from one root key: sampling, initialisation, warm-up,
# and a remaining key that replaces the root.
_subkeys = jax.random.split(jax.random.PRNGKey(314), num=4)
key_samples, key_init, key_warmup, key = _subkeys

In [None]:
num_warmup = 1000  # value handed to the window adaptation
num_steps = 1000   # number of states collected by the inference loop

# 1. Model instance
model = PiNet()

# 2. Initialise the model parameters from a dummy input of shape (1,).
# NOTE(review): `model.init` is seeded with `key`; neither `init_key` (split
# off just below) nor the earlier `key_init` is used - confirm this is intended.
key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
variables = model.init(key, jnp.ones([1]))
params = variables['params']

# 3. Log joint density with data and model bound, leaving `params` free.
potential = functools.partial(bnn_log_joint, X=t, known=true_y, model=model)

# 4. Warm-up with NUTS. The HMC alternative would be:
#    adapt = blackjax.window_adaptation(blackjax.hmc, potential, num_warmup, num_integration_steps=1)
adapt = blackjax.window_adaptation(blackjax.nuts, potential, num_warmup, progress_bar=True)

adaptation_result, _ = adapt.run(key_warmup, params)
final_state, parameters = adaptation_result
print('warmup done')

# 5. Build the sampling kernel from the adapted parameters and draw samples.
kernel = blackjax.nuts(potential, **parameters).step
states = inference_loop(key_samples, kernel, final_state, num_steps)

sampled_params = states.position
print('inference loop done')

In [None]:
# Use this if desired to take more samples without repeating the warmup
#num_steps = 1000
#states = inference_loop(key_samples, kernel, final_state, num_steps)
#sampled_params = states.position

In [None]:
# Show the shape of every leaf of the sampled-parameter pytree.
jtu.tree_map(jnp.shape, sampled_params)

In [None]:
n_samples = num_steps
dim = 1

# Evaluate the network on the whole grid for every posterior draw in one
# batched call: the inner vmap runs over the rows of `t`, the outer vmap over
# the leading (sample) axis of every leaf in `sampled_params`.
def _predict_on_grid(sample):
    return jax.vmap(functools.partial(model.apply, {'params': sample}))(t)

ys = np.asarray(jax.vmap(_predict_on_grid)(sampled_params), dtype=np.float64)

# Point-wise mean and spread across the posterior draws.
ys_mean = np.mean(ys, 0)
ys_stdev  = np.std(ys, 0)

plt.figure(figsize=(10,8))
plt.rcParams.update({'font.size': 14})
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze() , color='g', label="True Model")
plt.plot(t, ys_mean , color='r', label="Mean Model")
for i in range(0, dim):
    mean_i = ys_mean[:, i]
    std_i = ys_stdev[:, i]
    # Inner band: mean +/- 2 std. Outer bands: between 2 and 3 std on each
    # side. Each interval gets exactly one legend entry, so the legend no
    # longer lists '99.7% CI' twice for two different colours.
    plt.fill_between(t.squeeze(), mean_i-2.0*std_i, mean_i+2.0*std_i, alpha=0.3, color='royalblue', label='95% CI')
    plt.fill_between(t.squeeze(), mean_i+2.0*std_i, mean_i+3.0*std_i, alpha=0.3, color='aqua', label='99.7% CI')
    plt.fill_between(t.squeeze(), mean_i-3.0*std_i, mean_i-2.0*std_i, alpha=0.3, color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.minorticks_on()
plt.ylabel('y')
#plt.savefig('Uncertainty_Figure.svg')
#plt.savefig('Uncertainty_Figure.pdf')

In [None]:
import seaborn as sns
plt.rcParams.update({'font.size': 14})
fig, axs = plt.subplots(3, 3, figsize=(12, 12), sharey=False)

keys = list(sampled_params.keys())

# One kernel density estimate per parameter entry, filled row by row.
for idx, ax in enumerate(axs.flatten()):
    row, col = divmod(idx, 3)

    samples = sampled_params[keys[idx]]
    if idx == 5:
        # Only the first column of this entry is plotted.
        samples = samples[:, 0].squeeze()
    elif idx != 0:
        # The first entry is passed as-is; all others are squeezed.
        samples = samples.squeeze()

    sns.kdeplot(ax=ax, data=samples, label=None, legend=False)

    # Only the left column carries a y label, only the bottom row an x label.
    ax.set_ylabel('Kernel Density Estimate' if col == 0 else None)
    if row == 2:
        ax.set_xlabel('Parameter Value')

    ax.minorticks_on()
    ax.set_ylim([0,1])

plt.tight_layout()
plt.savefig('Cubic_Posteriors_nuts_w&b.svg')
plt.savefig('Cubic_Posteriors_nuts_w&b.pdf')

### Expanding out the Parameters with Monte Carlo

In [None]:
n_samples = num_steps

# Expand the symbolic equation of every posterior draw into the coefficients
# of a polynomial in x, stored in ascending order of degree
# (constant term first).
x_sym = sympy.symbols('x')

expanded = []
for i in range(0, n_samples):
    if (i + 1) % 100 == 0:
        print(i + 1)  # progress report every 100 draws
    sample_params = jtu.tree_map(lambda leaf: leaf[i], sampled_params)
    equation = model.get_equation(sample_params, ['x'])
    # `all_coeffs()` lists every coefficient from the highest degree down,
    # including zeros, so the column order is explicit and does not depend on
    # dictionary iteration order or on dropped zero terms. Reverse it to get
    # ascending degree.
    coeffs = sympy.Poly(equation[0], x_sym).all_coeffs()[::-1]
    sample_expanded = np.array([float(c) for c in coeffs], np.float64)
    expanded.append(sample_expanded)
expanded = np.array(expanded, np.float64)

In [None]:
# Persist the expanded coefficient samples to disk.
np.save(file='nuts_cubic_expanded.npy', arr=expanded)

In [None]:
import seaborn as sns

# One density curve per column of `expanded`, drawn in column order; each
# legend label is the matching term of the true model.
for column, term_label in enumerate(('$1$', '$x$', '$2x^2$', '$4x^3$')):
    sns.kdeplot(expanded[:, column], label=term_label)

plt.xlabel('Parameter Value')
plt.ylabel('Kernel Density Estimate')
plt.minorticks_on()
plt.legend()

plt.savefig('CubicRegression_nuts_kde.svg')
plt.savefig('CubicRegression_nuts_kde.pdf')

In [None]:
dim = 1

def function2(param_flat, x):
    """Evaluate a cubic polynomial with coefficients ``param_flat`` at ``x``.

    Args:
        param_flat: The four coefficients in ascending order of degree
            (constant, linear, quadratic, cubic).
        x: Evaluation points, either 1-D or a single column.

    Returns:
        One polynomial value per evaluation point.
    """
    # Use the `x` argument rather than the global grid so the function
    # evaluates at whatever points the caller passes in.
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

params_samples = expanded

# Evaluate the expanded polynomial for every coefficient sample on the shared
# grid, then append a trailing axis so the output is indexed like the
# (sample, point, dim) arrays used for plotting.
predict_all = jax.vmap(function2, in_axes=(0, None))
ys = predict_all(params_samples, t)[:, :, None]

ys_mean = np.mean(ys, 0)
ys_stdev = np.std(ys, 0)

plt.figure(figsize=(10,8))
plt.rcParams.update({'font.size': 14})
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze(), color='g', label="True Model")
plt.plot(t, ys_mean, color='r', label="Mean Model")
for i in range(dim):
    grid = t.squeeze()
    centre = ys_mean[:, i]
    spread = ys_stdev[:, i]
    # Inner band at two standard deviations, outer bands from two to three.
    plt.fill_between(grid, centre-2.0*spread, centre+2.0*spread, alpha=0.3, color='royalblue', label='95% confidence interval')
    plt.fill_between(grid, centre+2.0*spread, centre+3.0*spread, alpha=0.3, color='aqua', label='99.7% confidence interval')
    plt.fill_between(grid, centre-2.0*spread, centre-3.0*spread, alpha=0.3, color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.ylabel('y')
plt.minorticks_on()
plt.savefig('CubicRegression_nuts_Uncertainty_Figure.svg')
plt.savefig('CubicRegression_nuts_Uncertainty_Figure.pdf')

### Analyze MCMC Diagnostics

In [None]:
fig, axs = plt.subplots(2,2, figsize=(12, 8), sharey=False, sharex=True)

# One trace per expanded coefficient, filling the grid row by row.
panel_titles = (r'$\beta_0$', r'$\beta_1$', r'$\beta_2$', r'$\beta_3$')
for column, (panel, panel_title) in enumerate(zip(axs.flat, panel_titles)):
    panel.plot(expanded[:, column])
    panel.set_title(panel_title)

# x labels on the bottom row only, y labels on the left column only.
for panel in axs[1, :]:
    panel.set_xlabel('Chain Iteration Number')
for panel in axs[:, 0]:
    panel.set_ylabel('Parameter Value')

plt.suptitle('Trace Plots for the Expanded Polynomial Coefficients')

plt.savefig('CubicRegression_mcmc_trace_plot.svg')
plt.savefig('CubicRegression_mcmc_trace_plot.pdf')

In [None]:
# Compare the per-coefficient mean of the first 100 draws with that of the
# draws from index 500 onward, scaled by the combined standard error.
early_draws = expanded[0:100]
late_draws = expanded[500:]

n = jnp.shape(early_draws)[0]
m1 = jnp.mean(early_draws, axis=0)
s1 = jnp.var(early_draws, axis=0, ddof=1)  # sample variance (ddof=1)

m = jnp.shape(late_draws)[0]
m2 = jnp.mean(late_draws, axis=0)
s2 = jnp.var(late_draws, axis=0, ddof=1)

standard_error = jnp.sqrt(s1/n + s2/m)
T = (m1 - m2) / standard_error
T