### Bayesian Cubic Regression with Variational Inference

In [None]:
import scipy
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

from IPython.display import display, clear_output

import jax
import jax.numpy as jnp
from jax import random
from jax import make_jaxpr
from jax.config import config
from jax import value_and_grad
from jax import grad, vmap, pmap, jit
import jax.tree_util as jtu

import optax
from flax import linen as nn
from flax.training import train_state
import flax

from typing import Any, Callable, Sequence, Optional
import sympy

import distrax

from sympy import Matrix

from dataclasses import dataclass

import functools

from NN_arch import PiNet

In [None]:
from jax.config import config
# Enable float64 everywhere. The flag is reached through the top-level `jax`
# namespace so this line does not rely on the deprecated `jax.config`
# submodule import.
jax.config.update("jax_enable_x64", True)

In [None]:
import warnings
# Silence warnings whose message starts with one of these two prefixes.
warnings.filterwarnings("ignore", message="is_categorical_dtype")
warnings.filterwarnings("ignore", message="use_inf_as_na")

### Cubic Regression

### Define True Model Function and Sample Data

$y = 1 + t + 2t^2 + 4t^3$

In [None]:
# Number of (t, y) observations in the synthetic dataset.
ndata = 200

# The inputs lie on a uniform grid over the closed interval [t0, t1].
t0, t1 = -1.25, 1.25
t = jnp.linspace(t0, t1, ndata)

def true_fun(t):
    """Evaluate the noise-free cubic ``1 + t + 2 t^2 + 4 t^3``.

    The values are wrapped in a list before conversion, so the result carries
    an extra leading axis of length 1 in front of the shape of ``t``.
    """
    cubic = 1 + t + 2*t**2 + 4*t**3
    return jnp.array([cubic])

In [None]:
# Standard deviation of the Gaussian noise added to the targets.
stdev = 3

# Fix NumPy's global RNG so the noisy dataset is reproducible.
seed = 989
np.random.seed(seed)

# Start from the exact cubic (leading length-1 axis removed), then perturb it.
true_y = true_fun(t).squeeze()
noise = np.random.normal(scale=stdev, size=true_y.shape)
true_y = true_y + noise

In [None]:
# Show the noise-free cubic together with the noisy observations.
plt.figure(figsize=(10,7))
clean_curve = true_fun(t).squeeze()
plt.plot(t, clean_curve, color='r', label="True values")
plt.scatter(t, true_y, label="Noise corrupted values")
plt.title("Real function along with noisy targets")
plt.xlabel("Features")
plt.ylabel("Labels")
plt.legend();

In [None]:
# Turn inputs and targets into column vectors of shape (ndata, 1).
# `reshape(-1, 1)` (instead of `[:, None]`) keeps the cell idempotent:
# re-running it no longer appends another trailing axis each time.
t = t.reshape(-1, 1)
true_y = true_y.reshape(-1, 1)

### Pre-train the network to obtain initial guesses for the VI parameters

This speeds up the method significantly and improves the accuracy, but it is not strictly necessary.

In [None]:
# Instantiate the network.
model = PiNet()

# Initialise its parameters with a dummy length-1 input; the input size must
# match the dimension of the data fed to the network.
# NOTE(review): `init_key` is produced by the split but `key` is what gets
# passed to `model.init` -- confirm this is intended.
key, init_key = random.split(random.PRNGKey(0))
dummy_input = jnp.ones([1])
params = model.init(key, dummy_input)['params']

In [None]:
@jax.jit
def loss_fn(pred, known, params):
    """Return the sum of squared errors between `known` and `pred`.

    `params` is accepted for signature compatibility but is not used.
    """
    residual = known - pred
    return jnp.sum(jnp.power(residual, 2))

In [None]:
@jax.jit
def calculate_loss(params, t, y_known):
    """Apply the global `model` with `params` to `t` and score it with `loss_fn`.

    Returns whatever `loss_fn` returns for the prediction and `y_known`.
    """
    prediction = model.apply({'params': params}, t)
    return loss_fn(prediction, y_known, params)


In [None]:
@jax.jit
def calculate_value_loss_grad(params, t, y_known):
    """Return the prediction, the loss and the loss gradient w.r.t. `params`.

    Returns:
        A tuple `(y_pred, loss, grads)` where `y_pred` is the output of the
        global `model`, `loss` is the value of `calculate_loss`, and `grads`
        has the same tree structure as `params`.
    """
    # Loss and its gradient with respect to the first argument (the parameters).
    loss, grads = jax.value_and_grad(calculate_loss, argnums=0)(params, t, y_known)

    # The prediction itself is returned as well so callers can plot it.
    y_pred = model.apply({'params': params}, t)

    return y_pred, loss, grads

In [None]:
# Build the network and the optimizer state used for pre-training.
def create_train_state(key, lr=5e-2):
    """Create a fresh `PiNet` together with its Flax `TrainState`.

    Args:
        key: PRNG key used to initialise the model parameters.
        lr: Constant learning rate passed to the Adam optimizer.

    Returns:
        A `(model, state)` tuple: the `PiNet` instance and a `TrainState`
        bundling `model.apply`, the initial parameters and the optimizer.
    """
    network = PiNet()

    # Dummy length-1 input, used only to initialise the parameters; its size
    # must match the dimension of the input data.
    variables = network.init(key, jnp.ones([1]))

    adam = optax.adam(learning_rate=lr)
    state = train_state.TrainState.create(
        apply_fn=network.apply,
        params=variables['params'],
        tx=adam,
    )
    return network, state

In [None]:
# Seed the PRNG, then build the network and its optimizer state.
key, init_key = random.split(random.PRNGKey(0))
model, state = create_train_state(init_key)

In [None]:
@jax.jit
def train_step_gradient_descent(state, t, y_known):
    """Run one full-batch optimizer step on the sum-of-squares loss.

    Args:
        state: `TrainState` holding the current parameters and optimizer.
        t: Inputs; mapped over their leading axis.
        y_known: Targets; mapped over their leading axis alongside `t`.

    Returns:
        A tuple `(y_pred, loss, state)`: the per-sample predictions and the
        summed loss, both computed with the parameters *before* the update,
        and the updated train state.
    """
    # Per-sample prediction, loss and gradient: the parameters are shared
    # (in_axes=None), inputs and targets are mapped over axis 0. The targets
    # come from the `y_known` argument, not from a notebook global.
    y_pred, loss, grads = jax.vmap(calculate_value_loss_grad, in_axes=(None, 0, 0))(
        state.params, t, y_known
    )

    # Accumulate the per-sample losses and gradients over the batch.
    loss = jnp.sum(loss, 0)
    grads = jtu.tree_map(lambda g: jnp.sum(g, 0), grads)

    # The learning rate is owned by the optimizer stored in `state`.
    state = state.apply_gradients(grads=grads)

    return y_pred, loss, state

In [None]:
EPOCHS = 1000
test_freq = 500  # refresh the diagnostic plots every `test_freq` iterations
key = random.PRNGKey(0)

key, init_key = random.split(key)
model, state = create_train_state(init_key)

# Loss recorded after every optimizer step.
training_loss = []

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,5))
plt.rcParams.update({'font.size': 12})

# Training
for itr in range(EPOCHS):
    y_pred, loss, state = train_step_gradient_descent(state, t, true_y)

    # Stop early once the loss is numerically zero.
    if loss < 1e-17:
        break

    training_loss.append(loss)

    if itr % test_freq == 0 or itr == EPOCHS-1:
        print('Iter {:04d} | Total Loss {:e}'.format(itr, training_loss[-1]))

        # Loss history
        ax1.cla()
        ax1.semilogy(training_loss)
        ax1.set_ylabel('training Loss')
        ax1.set_xlabel('Epochs')
        ax1.minorticks_on()

        # Data and network prediction. The artists carry their own labels:
        # `legend('NN', 'data')` would treat the two strings as
        # (handles, labels) and draw no usable legend.
        ax2.cla()
        ax2.plot(t, y_pred, '-', label='NN')
        ax2.scatter(t, true_y, label='data')
        ax2.set_xlabel('t')
        ax2.set_ylabel('y')
        ax2.minorticks_on()
        ax2.legend()

        fig.tight_layout()
        display(fig)

        print(model.get_equation(state.params, ['x']))

        clear_output(wait=True)

In [None]:
# Keep the pre-trained point estimate; it is used as the variational means below.
params_pretrained = state.params

### Setup BNN functions

In [None]:
# Network whose parameters are treated as random variables below.
model = PiNet()

key = jax.random.PRNGKey(314)
key, init_key = jax.random.split(key)

# Variational posterior q: means start at the pre-trained point estimate and
# every standard deviation starts at 5. To start from a random initialisation
# instead, use `model.init(key, jnp.ones([1]))['params']` for the means.
params_mu = params_pretrained
params_stdev = jtu.tree_map(lambda leaf: jnp.ones_like(leaf) * 5, params_mu)
params = flax.core.frozen_dict.freeze({'mu': params_mu, 'stdev': params_stdev})

# Prior p: zero mean and a standard deviation of 1e5 for every parameter.
prior_mu = jtu.tree_map(jnp.zeros_like, params_mu)
prior_stdev = jtu.tree_map(lambda leaf: jnp.ones_like(leaf) * 100000, params_mu)
prior_params = flax.core.frozen_dict.freeze({'mu': prior_mu, 'stdev': prior_stdev})

In [None]:
def sample_params(params, key, Nsamples=100):
    """Draw reparameterised samples from the variational distribution q.

    Args:
        params: Mapping with entries 'mu' and 'stdev', two pytrees of
            identical structure holding the mean and standard deviation of
            every network parameter.
        key: JAX PRNG key. It is split so that each leaf gets its own
            independent noise stream.
        Nsamples: Number of samples to draw.

    Returns:
        A pytree with the structure of `params['mu']` in which every leaf has
        shape `(Nsamples,) + leaf.shape` and equals `mu + |stdev| * eps` with
        `eps` standard-normal noise.
    """
    mu_leaves, treedef = jtu.tree_flatten(params['mu'])

    # One sub-key per leaf: reusing a single key for all leaves would give
    # leaves of equal shape exactly the same noise.
    leaf_keys = jax.random.split(key, len(mu_leaves))
    eps_leaves = [
        jax.random.normal(leaf_key, shape=(Nsamples,) + leaf.shape)
        for leaf_key, leaf in zip(leaf_keys, mu_leaves)
    ]
    eps = jtu.tree_unflatten(treedef, eps_leaves)

    # |stdev| keeps the scale non-negative even if the optimizer drives the
    # raw parameter below zero.
    return jtu.tree_map(
        lambda mu, sd, noise: mu + jnp.abs(sd) * noise,
        params['mu'], params['stdev'], eps,
    )

In [None]:
# Inspect the per-leaf shapes of a batch of samples drawn from the prior.
jtu.tree_map(jnp.shape, sample_params(prior_params, key))

In [None]:
@jax.jit
def KLD_cost(q_params, p_params):
    """Return KL(q || p) for two diagonal Gaussians stored as pytrees.

    Args:
        q_params: Mapping with 'mu' and 'stdev' pytrees describing q.
        p_params: Mapping with 'mu' and 'stdev' pytrees describing p, with
            the same tree structure as `q_params`.

    Returns:
        Scalar
        `0.5 * (log|S_p| - log|S_q| - k + dmu^T S_p^-1 dmu + tr(S_p^-1 S_q))`
        where `S` is the diagonal covariance `stdev**2` and `k` is the total
        number of parameters.
    """
    q_mu, q_stdev = q_params['mu'], q_params['stdev']
    p_mu, p_stdev = p_params['mu'], p_params['stdev']

    def _total(tree):
        # Sum of every entry of every leaf.
        return sum(jnp.sum(leaf) for leaf in jtu.tree_leaves(tree))

    def _logdet(stdev_tree):
        # log-determinant of the diagonal covariance, floored at -700 as a
        # guard against vanishing standard deviations. `jnp.maximum` is used
        # because the `a_min` keyword of `jnp.clip` is deprecated.
        log_var = jtu.tree_map(lambda s: jnp.log(s**2), stdev_tree)
        return jnp.maximum(_total(log_var), -700)

    logdet_ratio = _logdet(p_stdev) - _logdet(q_stdev)

    # k = number of parameters
    k = sum(leaf.size for leaf in jtu.tree_leaves(q_mu))

    # Squared mean difference scaled by the variance of p.
    dmu_sigma_inv_dmu = _total(
        jtu.tree_map(lambda mq, mp, sp: (mq - mp)**2 / sp**2, q_mu, p_mu, p_stdev)
    )

    # tr(S_p^-1 S_q)
    trace_term = _total(
        jtu.tree_map(lambda sq, sp: (sq**2) / (sp**2), q_stdev, p_stdev)
    )

    return 0.5 * (logdet_ratio - k + dmu_sigma_inv_dmu + trace_term)

In [None]:
# KL divergence between the initial variational distribution and the prior.
KLD_cost(params, prior_params)

In [None]:
# Sanity check: the KL divergence of a distribution with itself should be zero.
KLD_cost(prior_params, prior_params)

In [None]:
@jax.jit
def loss_fn(pred, known, params):
    """Return the element-wise squared error between `pred` and `known`.

    No reduction is applied; `params` is accepted but not used.
    """
    residual = pred - known
    return jnp.power(residual, 2)

In [None]:
@functools.partial(jax.jit, static_argnames=("Nsamples",))
def ELBO(params, X, known, key, Nsamples=1000):
    """Return the Monte Carlo variational objective to be minimised.

    The value is the data-misfit term plus the KL term, i.e. the sign-flipped
    form of `-(misfit) - KL`.

    Args:
        params: Variational parameters with entries 'mu' and 'stdev'.
        X: Inputs; the model is mapped over their leading axis.
        known: Targets compared against every sampled prediction.
        key: PRNG key used to draw the parameter samples.
        Nsamples: Number of parameter samples. It determines array shapes,
            so it is marked static to allow passing it explicitly under `jit`.

    Returns:
        Scalar: squared error averaged over the parameter samples and summed
        over the data, plus `KLD_cost(params, prior_params)`. `prior_params`
        is read from the enclosing (global) scope.
    """
    sampled_params = sample_params(params, key, Nsamples=Nsamples)

    # Outer vmap runs over parameter samples, inner vmap over the inputs.
    ypred = jax.vmap(jax.vmap(model.apply, (None, 0)), (0, None))({'params': sampled_params}, X)

    # Data misfit: average over the samples (axis 0), then sum what remains.
    ll_term = jnp.mean(loss_fn(ypred, known, sampled_params), 0).sum()

    KL_term = KLD_cost(params, prior_params)

    return ll_term + KL_term

In [None]:
# Smoke test: evaluate the objective once at the initial variational parameters.
ELBO(params, t, true_y, key)

In [None]:
# Build the network, the variational train state and the prior.
def create_train_state(key, init_params, lr=1e-3):
    """Create the model, the variational `TrainState` and the prior.

    Args:
        key: PRNG key. It is not used here because the means are taken from
            `init_params` rather than from a random initialisation.
        init_params: Pytree of network parameters used as variational means.
        lr: Constant learning rate passed to the Adam optimizer.

    Returns:
        A `(model, state, prior_params)` tuple. `state.params` is a frozen
        dict with entries 'mu' (`init_params`) and 'stdev' (0.5 everywhere);
        `prior_params` has 'mu' equal to zero and 'stdev' equal to 1e5.
    """
    network = PiNet()

    def _constant_like(tree, value):
        # Pytree shaped like `tree` whose leaves are filled with `value`.
        return jtu.tree_map(lambda leaf: jnp.ones_like(leaf) * value, tree)

    variational = flax.core.frozen_dict.freeze(
        {'mu': init_params, 'stdev': _constant_like(init_params, 0.5)}
    )
    prior = flax.core.frozen_dict.freeze(
        {
            'mu': jtu.tree_map(lambda leaf: jnp.zeros_like(leaf), init_params),
            'stdev': _constant_like(init_params, 100000),
        }
    )

    adam = optax.adam(learning_rate=lr)
    state = train_state.TrainState.create(apply_fn=network.apply, params=variational, tx=adam)
    return network, state, prior

In [None]:
# Seed the PRNG and build the variational train state. The means start from
# the pre-trained parameters; a random start would instead use
# `model.init(key, jnp.ones([1]))['params']`.
key, init_key = jax.random.split(jax.random.PRNGKey(0))
init_params = params_pretrained
model, state, prior_params = create_train_state(init_key, init_params)

In [None]:
@jax.jit
def train_step_gradient_descent(state, t, y_known, key):
    """Run one optimizer step on the variational objective.

    Returns:
        A tuple `(ypred, loss, state)`: predictions for 1000 parameter samples
        drawn with `key` from the parameters *before* the update, the value
        of `ELBO` at those parameters, and the updated train state.
    """
    variational_params = state.params

    # Objective value and its gradient w.r.t. the variational parameters.
    loss, grads = jax.value_and_grad(ELBO)(variational_params, t, y_known, key)

    # Sampled predictions, returned for plotting: the outer vmap runs over
    # parameter samples, the inner one over the inputs.
    batched_apply = jax.vmap(jax.vmap(model.apply, (None, 0)), (0,None))
    draws = sample_params(variational_params, key, Nsamples=1000)
    ypred = batched_apply({'params': draws}, t)

    state = state.apply_gradients(grads=grads)

    return ypred, loss, state

In [None]:
# Smoke test: run a single optimisation step.
ypred, loss, state = train_step_gradient_descent(state, t, true_y, key)

In [None]:
# Inspect the shape of the sampled predictions returned by the step.
ypred.shape

In [None]:
# Inspect the variational parameters after the single update.
state.params

### Train BNN

In [None]:
EPOCHS = 2000
test_freq = 500  # refresh the diagnostic plots every `test_freq` iterations
key = random.PRNGKey(0)

key, init_key = random.split(key)
model, state, prior_params = create_train_state(init_key, init_params)

# Objective value recorded after every optimizer step.
training_loss = []

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,5))
plt.rcParams.update({'font.size': 12})

# Training
for itr in range(EPOCHS):
    # Use a dedicated sub-key for this step; `key` itself is only ever split,
    # so no key is both consumed for sampling and split again.
    key, step_key = jax.random.split(key)
    ypred, loss, state = train_step_gradient_descent(state, t, true_y, step_key)

    training_loss.append(loss)

    if itr % test_freq == 0 or itr == EPOCHS-1:
        print('Iter {:04d} | Total Loss {:e}'.format(itr, training_loss[-1]))

        # Loss history
        ax1.cla()
        ax1.semilogy(training_loss)
        ax1.set_ylabel('training Loss')
        ax1.set_xlabel('Epochs')
        ax1.minorticks_on()

        # Mean and spread of the sampled predictions (axis 0 = samples).
        pred_mean = jnp.mean(ypred, 0)
        pred_std = jnp.var(ypred, 0).squeeze()**0.5

        # Data, mean prediction and a +/- 3 sigma band. The artists carry
        # their own labels: `legend('NN', 'data')` would treat the two
        # strings as (handles, labels) and draw no usable legend.
        ax2.cla()
        ax2.plot(t, pred_mean, '-', label='NN')
        ax2.scatter(t, true_y, label='data')
        ax2.fill_between(
            t.squeeze(),
            pred_mean.squeeze() - 3.0*pred_std,
            pred_mean.squeeze() + 3.0*pred_std,
            alpha=0.3, color='aqua',
        )
        ax2.set_xlabel('t')
        ax2.set_ylabel('y')
        ax2.minorticks_on()
        ax2.legend()

        fig.tight_layout()
        display(fig)

        clear_output(wait=True)

In [None]:
# Draw parameter samples from the trained variational distribution.
n_samples = 1000
sampled_params = sample_params(state.params, key, Nsamples=n_samples)

In [None]:
dim = 1  # number of model outputs

# Push every parameter sample through the network in one batched call (outer
# vmap over samples, inner vmap over inputs) instead of a Python loop.
ys = np.asarray(
    jax.vmap(jax.vmap(model.apply, (None, 0)), (0, None))({'params': sampled_params}, t),
    dtype=np.float64,
).reshape(n_samples, ndata, dim)

ys_mean = np.mean(ys, 0)
ys_stdev  = np.std(ys, 0)

plt.figure(figsize=(10,8))
plt.rcParams.update({'font.size': 14})
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze() , color='g', label="True Model")
plt.plot(t, ys_mean , color='r', label="Mean Model")
for i in range(0, dim):
    # Inner band: mean +/- 2 sigma. Outer bands: between 2 and 3 sigma, so
    # each legend label matches the region it is attached to.
    plt.fill_between(t.squeeze(), ys_mean[:,i]-2.0*ys_stdev[:,i], ys_mean[:,i]+2.0*ys_stdev[:,i], alpha=0.3,color='royalblue',label='95% CI')
    plt.fill_between(t.squeeze(), ys_mean[:,i]-2.0*ys_stdev[:,i], ys_mean[:,i]-3.0*ys_stdev[:,i], alpha=0.3,color='aqua',label='99.7% CI')
    plt.fill_between(t.squeeze(), ys_mean[:,i]+2.0*ys_stdev[:,i], ys_mean[:,i]+3.0*ys_stdev[:,i], alpha=0.3,color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.minorticks_on()
plt.ylabel('y')
#plt.savefig('Uncertainty_Figure.svg')
#plt.savefig('Uncertainty_Figure.pdf')

In [None]:
import seaborn as sns
plt.rcParams.update({'font.size': 14})
fig, axs = plt.subplots(3, 3, figsize=(12, 12), sharey=False)

keys = list(sampled_params.keys())

# One panel per parameter group, filled row by row.
for panel, ax in enumerate(axs.flatten()):
    row, col = divmod(panel, 3)
    name = keys[panel]

    samples = sampled_params[name]
    if panel == 5:
        # For this group only the first column is plotted.
        samples = samples[:,0].squeeze()
    elif panel != 0:
        # Every group except the first is squeezed before plotting.
        samples = samples.squeeze()

    sns.kdeplot(ax = ax, data=samples, label=None, legend=False)
    ax.set_title(name)
    # Only the left-most column carries a y-label.
    ax.set_ylabel('Kernel Density Estimate' if col == 0 else None)
    # Only the bottom row carries an x-label.
    if row == 2:
        ax.set_xlabel('Parameter Value')

for ax in axs.flatten():
    ax.minorticks_on()
    ax.set_ylim([0,1])

plt.tight_layout()
plt.savefig('Cubic_Posteriors_VI_w&b.svg')
plt.savefig('Cubic_Posteriors_VI_w&b.pdf')

### Expanding out the Parameters with Monte Carlo

In [None]:
n_samples = 500
sampled_params = sample_params(state.params, key, Nsamples=n_samples)

x_sym = sympy.symbols('x')
expanded = []
for i in range(n_samples):
    # Progress report every 100 samples.
    if (i + 1) % 100 == 0:
        print(i + 1)

    # Parameters of the i-th sample. The name must differ from
    # `sample_params`, otherwise the sampler function is overwritten and any
    # later call to it fails.
    params_i = jtu.tree_map(lambda leaf: leaf[i], sampled_params)
    equation = model.get_equation(params_i, ['x'])

    # Coefficients in ascending powers of x. `all_coeffs` also returns zero
    # coefficients, so every row has the same length.
    coefficients = sympy.Poly(equation[0], x_sym).all_coeffs()[::-1]
    expanded.append(np.array(coefficients, np.float64))
expanded = np.array(expanded, np.float64)

In [None]:
#np.save('VI_cubic_expanded.npy', expanded)

In [None]:
import seaborn as sns
# One density per column of `expanded`, labelled with the matching term of
# the true model y = 1 + x + 2x^2 + 4x^3.
for column, term_label in enumerate(('$1$', '$x$', '$2x^2$', '$4x^3$')):
    sns.kdeplot(expanded[:,column], label=term_label)

plt.xlabel('Parameter Value')
plt.ylabel('Kernel Density Estimate')
plt.minorticks_on()
plt.legend()

plt.savefig('CubicRegression_VI_kde.svg')
plt.savefig('CubicRegression_VI_kde.pdf')

In [None]:
# Number of output columns looped over when drawing the uncertainty bands below.
dim = 1

def function2(param_flat, x):
    """Evaluate a cubic polynomial with coefficients `param_flat` at `x`.

    Args:
        param_flat: Four coefficients ordered as constant, x, x^2, x^3.
        x: Inputs at which the polynomial is evaluated (the argument is now
            actually used instead of the notebook-global `t`).

    Returns:
        The product of the design matrix `[1, x, x^2, x^3]` with `param_flat`.
    """
    design = jnp.column_stack((jnp.ones_like(x), x, x**2, x**3))
    return jnp.dot(design, param_flat)

# Evaluate the expanded polynomial for every coefficient sample and append a
# trailing axis so the result can be indexed per output column.
params_samples = expanded
ys = jax.vmap(function2, (0,None))(params_samples, t)[:,:,None]

# Statistics over the sample axis.
ys_mean = np.mean(ys, 0)
ys_stdev  = np.std(ys, 0)

plt.figure(figsize=(10,8))
plt.rcParams.update({'font.size': 14})
plt.scatter(t, true_y, label='Noise Corrupted Training Data')
plt.plot(t, true_fun(t).squeeze() , color='g', label="True Model")
plt.plot(t, ys_mean , color='r', label="Mean Model")
for i in range(0, dim):
    centre, spread = ys_mean[:,i], ys_stdev[:,i]
    lower2, upper2 = centre-2.0*spread, centre+2.0*spread
    lower3, upper3 = centre-3.0*spread, centre+3.0*spread
    # Inner band: +/- 2 sigma. Outer bands: between 2 and 3 sigma.
    plt.fill_between(t.squeeze(), lower2, upper2, alpha=0.3,color='royalblue',label='95% confidence interval')
    plt.fill_between(t.squeeze(), upper2, upper3, alpha=0.3,color='aqua',label='99.7% confidence interval')
    plt.fill_between(t.squeeze(), lower2, lower3, alpha=0.3,color='aqua')

plt.legend(fontsize=12)
plt.xlabel('x')
plt.minorticks_on()
plt.ylabel('y')
plt.savefig('CubicRegression_VI_Uncertainty_Figure.svg')
plt.savefig('CubicRegression_VI_Uncertainty_Figure.pdf')