In [1]:
"""
Bayesian Neural Network
=======================

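We demonstrate how to use NUTS to do inference on a simple (small)
Bayesian neural network with two hidden layers, and then track how the
effective dimensionality of the posterior covariance, in parameter space
and in function space, changes as the number of training points grows.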
"""

import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as onp

from jax import vmap
import jax.numpy as np
import jax.random as random
from jax.nn import elu, relu

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# The non-interactive Agg backend suits headless script runs, but inside
# Jupyter it suppresses the inline figures drawn further down. Only switch
# backends when not running under IPython.
try:
    get_ipython  # noqa: F821 -- injected into builtins by IPython/Jupyter
except NameError:
    matplotlib.use('Agg')


In [2]:
# Run JAX on the CPU with a single host device.
# NOTE(review): NumPyro documents set_host_device_count as effective only when
# called before JAX initializes its backend, so keep this cell right after the
# imports and before any JAX computation -- confirm for the installed version.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(1)

In [3]:


# the non-linearity we use in our neural network
def nonlin(x):
    """Element-wise activation applied after each hidden layer.

    Currently ELU; ``np.tanh`` is a drop-in alternative if a saturating
    non-linearity is wanted.
    """
    activation = elu
    return activation(x)


# a Bayesian neural network with two hidden layers and computational flow
# given by D_X => D_H => D_H => D_Y, where D_H is the number of
# hidden units. (note we indicate tensor dimensions in the comments)
def model(X, Y, D_H):
    """Bayesian MLP regression model with flow D_X -> D_H -> D_H -> 1.

    Every weight matrix gets an independent unit-normal prior and the
    observation-noise precision gets a Gamma(3, 1) prior.

    Args:
        X: design matrix; its second axis is the input dimension D_X.
        Y: observed targets, or None to leave the "Y" site unobserved so it
            is sampled from the likelihood (used for prediction).
        D_H: number of units in each hidden layer.

    Returns:
        The value of the "Y" sample site (equal to Y when it is observed).
    """
    n_inputs, n_outputs = X.shape[1], 1

    def unit_normal(shape):
        # independent N(0, 1) prior on every entry of a weight matrix
        return dist.Normal(np.zeros(shape), np.ones(shape))

    # Site names are fixed: posterior samples are keyed by them and
    # substituted back into the model by name at prediction time.
    w1 = numpyro.sample("w1", unit_normal((n_inputs, D_H)))   # D_X D_H
    hidden = nonlin(np.matmul(X, w1))                          # N D_H

    w2 = numpyro.sample("w2", unit_normal((D_H, D_H)))         # D_H D_H
    hidden = nonlin(np.matmul(hidden, w2))                     # N D_H

    w3 = numpyro.sample("w3", unit_normal((D_H, n_outputs)))   # D_H D_Y
    mean = np.matmul(hidden, w3)                               # N D_Y

    # prior on the noise precision; the likelihood needs a standard deviation
    precision = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    noise_scale = 1.0 / np.sqrt(precision)

    likelihood = dist.Normal(mean, noise_scale)
    return numpyro.sample("Y", likelihood, obs=Y)




In [4]:
# create artificial regression dataset
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500, seed=0):
    """Create a small, deterministic 1-D regression dataset.

    Inputs are polynomial features ``[1, x, x**2, ...]`` of ``x`` on an even
    grid in [-1, 1]. Targets are a random linear function of those features
    plus a smooth non-linear term in ``x`` and Gaussian noise, then
    standardized to zero mean and unit standard deviation.

    Args:
        N: number of training points.
        D_X: number of polynomial features, including the constant column.
            Must be at least 2 because the target uses the linear feature.
        sigma_obs: standard deviation of the additive Gaussian noise.
        N_test: number of test inputs, spread over the wider grid [-1.3, 1.3].
        seed: seed of the private NumPy RNG. The default 0 reproduces the
            draws of the legacy global ``onp.random.seed(0)``.

    Returns:
        ``(X, Y, X_test)`` with shapes ``(N, D_X)``, ``(N, 1)`` and
        ``(N_test, D_X)``.

    Raises:
        ValueError: if ``D_X < 2``.
    """
    if D_X < 2:
        raise ValueError("D_X must be >= 2: the target uses the linear feature X[:, 1]")
    D_Y = 1  # create 1d outputs

    # A private RandomState yields the same stream as the legacy global seed
    # but does not reset NumPy's global RNG for the rest of the notebook.
    rng = onp.random.RandomState(seed)

    x_grid = np.linspace(-1, 1, N)
    X = np.power(x_grid[:, onp.newaxis], np.arange(D_X))
    W = 0.5 * rng.randn(D_X)
    Y = np.dot(X, W) + 0.5 * np.power(0.5 + X[:, 1], 2.0) * np.sin(4.0 * X[:, 1])
    Y = Y + sigma_obs * rng.randn(N)
    Y = Y[:, onp.newaxis]
    # standardize the targets (population std, as np.std computes)
    Y = Y - np.mean(Y)
    Y = Y / np.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    x_test_grid = np.linspace(-1.3, 1.3, N_test)
    X_test = np.power(x_test_grid[:, onp.newaxis], np.arange(D_X))

    return X, Y, X_test

In [5]:


# helper function for HMC inference
def run_inference(model, args, rng_key, X, Y, D_H):
    """Run NUTS on ``model`` and return the posterior samples.

    Args:
        model: NumPyro model, called as ``model(X, Y, D_H)``.
        args: dict with integer entries 'num_warmup', 'num_samples' and
            'num_chains'.
        rng_key: JAX PRNG key for the sampler.
        X, Y, D_H: forwarded positionally to ``model``.

    Returns:
        Dict of site name -> samples, as returned by ``MCMC.get_samples()``.
    """
    start = time.perf_counter()
    kernel = NUTS(model)
    # Pass the sample counts by keyword: newer NumPyro releases make
    # num_warmup/num_samples keyword-only, so the positional form fails there.
    mcmc = MCMC(kernel, num_warmup=args['num_warmup'], num_samples=args['num_samples'],
                num_chains=args['num_chains'], progress_bar=True)
    mcmc.run(rng_key, X, Y, D_H)
    mcmc.print_summary()
    # wall-clock time of the whole run: warmup, sampling and summary
    print('\nMCMC elapsed time:', time.perf_counter() - start)
    return mcmc.get_samples()


# helper function for prediction
def predict(model, rng_key, samples, X, D_H):
    """Sample the "Y" site at inputs ``X`` with the latent sites held fixed.

    Args:
        model: NumPyro model accepting ``X``, ``Y`` and ``D_H`` keywords.
        rng_key: PRNG key used for the sites that are actually sampled.
        samples: dict of site name -> value substituted into the model
            (when vmapped over posterior samples, one draw per call).
        X: inputs at which to predict.
        D_H: hidden-layer width the model was fit with.

    Returns:
        The sampled value of the "Y" site.
    """
    conditioned = handlers.substitute(handlers.seed(model, rng_key), samples)
    # Y=None leaves the "Y" site unobserved, so the trace samples it.
    trace = handlers.trace(conditioned).get_trace(X=X, Y=None, D_H=D_H)
    y_site = trace['Y']
    return y_site['value']




In [6]:
args = {'num_samples': 2000, 'num_chains': 1, 'num_warmup': 2000}
N, D_X, D_H = 20, 3, 25

def run_inference_and_compute_eigs(N, D_X=3, D_H=15):
    X, Y, X_test = get_data(N=N, D_X=D_X)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y, D_H)
    
    # predict Y_test at inputs X_test
    vmap_args = (samples, random.split(rng_key_predict, args['num_samples'] * args['num_chains']))
#     test_predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, X_test, D_H))(*vmap_args)
#     test_predictions = test_predictions[..., 0]

    train_predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, X, D_H))(*vmap_args)
    train_predictions = train_predictions[..., 0]

    # compute mean prediction and confidence interval around median
    #mean_prediction = np.mean(test_predictions, axis=0)
    #percentiles = onp.percentile(test_predictions, [5.0, 95.0], axis=0)
    
    pars_list = []
    for keys in samples.keys():
        items = samples[keys]
        pars_list.append(items.reshape(args['num_samples'],-1))
    pars = np.hstack(pars_list)
    
    pars_eigs = np.linalg.svd(pars)[1]**2 / (args['num_samples']-1)
    
    function_eigs = np.linalg.svd(np.cov(train_predictions.T))[1]**2 / (args['num_samples'] - 1)
    
    return pars_eigs, function_eigs

In [9]:
# training-set sizes for the effective-dimensionality sweep (one NUTS fit each)
n_list = [
    20, 100, 200, 500, 800,
    1000, 1300, 1600, 1800, 2000,
    2500, 3000,
]

In [10]:
# sanity check: number of training-set sizes in the sweep
num_sizes = len(n_list)
num_sizes

12

In [8]:
# One NUTS fit per training-set size (each takes minutes). Results are
# appended as each fit finishes, so interrupting this cell keeps the completed
# runs in `all_eig_list` instead of discarding them all.
all_eig_list = []
for n_train in n_list:
    all_eig_list.append(run_inference_and_compute_eigs(n_train, D_H=20))

sample: 100%|██████████| 4000/4000 [00:44<00:00, 89.80it/s, 511 steps of size 6.61e-03. acc. prob=0.91]  



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  prec_obs      7.60      2.40      7.32      4.20     11.68   1774.97      1.00
   w1[0,0]     -0.05      0.97     -0.05     -1.65      1.46   1229.34      1.00
   w1[0,1]     -0.07      0.99     -0.06     -1.68      1.47   1214.09      1.00
   w1[0,2]     -0.10      0.98     -0.12     -1.66      1.54   1308.52      1.00
   w1[0,3]     -0.07      0.94     -0.07     -1.75      1.36   1301.97      1.00
   w1[0,4]     -0.03      0.97     -0.00     -1.67      1.54   1462.12      1.00
   w1[0,5]     -0.03      0.99     -0.03     -1.52      1.62   1358.92      1.00
   w1[0,6]     -0.10      0.97     -0.07     -1.71      1.44   1277.40      1.01
   w1[0,7]     -0.06      0.99     -0.09     -1.72      1.52   1431.58      1.00
   w1[0,8]     -0.07      0.98     -0.08     -1.70      1.47   1412.89      1.00
   w1[0,9]     -0.08      0.95     -0.06     -1.68      1.46   1290.31      1.00
  w1[0,10]      0.00      0

(2000, 286)


  0%|          | 0/4000 [00:07<?, ?it/s]


KeyboardInterrupt: 

In [None]:


# # make plots
# fig, ax = plt.subplots(1, 1)

# # plot training data
# ax.plot(X[:, 1], Y[:, 0], 'kx')
# # plot 90% confidence level of predictions
# ax.fill_between(X_test[:, 1], percentiles[0, :], percentiles[1, :], color='lightblue')
# # plot mean prediction
# ax.plot(X_test[:, 1], mean_prediction, 'blue', ls='solid', lw=2.0)
# ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")
# plt.tight_layout()



In [None]:
# One spectrum per training-set size. zip stops at the shorter list, so a
# partially completed sweep still plots.
fig, ax = plt.subplots()
for n_train, (pars_eigs, _) in zip(n_list, all_eig_list):
    ax.semilogy(pars_eigs, label=str(n_train))
ax.legend(title='N')
ax.set(title='Parameter Space Eigenvalues',
       xlabel='Eigenvalue index (largest first)', ylabel='Eigenvalue');

In [None]:
# One spectrum per training-set size. zip stops at the shorter list, so a
# partially completed sweep still plots.
fig, ax = plt.subplots()
for n_train, (_, function_eigs) in zip(n_list, all_eig_list):
    ax.semilogy(function_eigs, label=str(n_train))
ax.legend(title='N')
ax.set(title='Function Space Eigenvalues',
       xlabel='Eigenvalue index (largest first)', ylabel='Eigenvalue');

In [None]:
def eff_dim(x, s=1.):
    """Effective dimensionality of a spectrum: ``sum_i x_i / (x_i + s)``.

    Each eigenvalue contributes close to 1 when it is much larger than the
    regularizer ``s`` and close to 0 when it is much smaller.
    """
    ratios = x / (x + s)
    return np.sum(ratios)

In [None]:
# Effective dimension of every posterior spectrum. Function-space spectra use a
# much smaller regularizer (s=1e-4) than parameter-space spectra (default s=1).
eff_dims_parameters = []
eff_dims_fns = []
for idx in range(len(n_list)):
    eigs_pair = all_eig_list[idx]
    eff_dims_parameters.append(eff_dim(eigs_pair[0]))
    eff_dims_fns.append(eff_dim(eigs_pair[1], s=1e-4))


In [None]:
# Marker at N = number of latent parameters of the swept network
# (D_X*D_H + D_H*D_H + D_H weights plus the noise precision). The previously
# hard-coded 726 is this count for D_H=25, but the sweep calls
# run_inference_and_compute_eigs with D_H=20.
sweep_d_h = 20
num_params = D_X * sweep_d_h + sweep_d_h ** 2 + sweep_d_h + 1

fig, ax = plt.subplots()
ax.scatter(n_list, eff_dims_parameters, label='Covariance(w)')
ax.scatter(n_list, eff_dims_fns, label='Covariance(f)')
ax.axvline(num_params, color='k', ls='--', label=f'# parameters ({num_params})')
# keep the data-driven upper limit instead of a hard-coded one that can clip points
ax.set_ylim(bottom=0)
ax.legend()
ax.grid()
ax.set(xlabel='N', ylabel='Effective dimension',
       title='Effective Dimensionality: Bayesian Neural Net');

In [None]:
import pickle

In [None]:
save_path = '../../saved-experiments/bnn_regression_evals.pkl'
# create the output directory on a fresh checkout instead of failing with FileNotFoundError
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# store plain NumPy arrays so loading the file does not depend on JAX array pickling
eigs_to_save = [tuple(onp.asarray(eigs) for eigs in pair) for pair in all_eig_list]
with open(save_path, 'wb') as handle:
    pickle.dump(eigs_to_save, handle)