
# Case Study 5: Bayesian Neural Network

Adapted from the NumPyro example at https://num.pyro.ai/en/stable/examples/bnn.html. We first present the NumPyro implementation and then the SOGA one.


In [1]:
from sogaPreprocessor import *
from producecfg import *
from libSOGA import *
from time import time

# Make float64 the default dtype for newly created floating-point torch tensors.
# NOTE(review): `torch` is not imported explicitly in this cell; presumably it
# is re-exported by one of the star imports above — confirm.
torch.set_default_dtype(torch.float64)

In [2]:
import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as random

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Select the non-interactive Agg backend: the figure produced further down is
# only written to disk with plt.savefig.
matplotlib.use("Agg")  # noqa: E402


# the non-linearity we use in our neural network
def nonlin(x):
    """Apply the ReLU activation element-wise to ``x`` and return the result."""
    activated = jax.nn.relu(x)
    return activated


# Bayesian neural network with two hidden layers. Data flows through
# D_X => D_H => D_H => D_Y, where D_H is the number of hidden units.
# Tensor shapes are asserted after every step.
def model(X, Y, D_H, D_Y=1):
    """NumPyro model of a small fully-connected BNN with ReLU activations.

    Args:
        X: input matrix of shape (N, D_X).
        Y: targets of shape (N, D_Y), or None to have the "Y" site sampled
            instead of observed.
        D_H: number of units in each of the two hidden layers.
        D_Y: output dimension.
    """
    N, D_X = X.shape

    def unit_normal_weights(site, shape):
        # every weight gets an independent Normal(0, 1) prior
        weights = numpyro.sample(
            site, dist.Normal(jnp.zeros(shape), jnp.ones(shape))
        )
        assert weights.shape == shape
        return weights

    # first hidden layer
    w1 = unit_normal_weights("w1", (D_X, D_H))
    hidden1 = nonlin(jnp.matmul(X, w1))
    assert hidden1.shape == (N, D_H)

    # second hidden layer
    w2 = unit_normal_weights("w2", (D_H, D_H))
    hidden2 = nonlin(jnp.matmul(hidden1, w2))
    assert hidden2.shape == (N, D_H)

    # linear output layer: this is the network's prediction
    w3 = unit_normal_weights("w3", (D_H, D_Y))
    output = jnp.matmul(hidden2, w3)
    assert output.shape == (N, D_Y)

    if Y is not None:
        assert output.shape == Y.shape

    # Prior on the observation precision; the noise scale is 1 / sqrt(precision).
    # NOTE(review): the prior was originally Gamma(3.0, 1.0). A Normal(0, 1)
    # prior also produces negative draws, for which the square root below is
    # NaN — confirm this is intended.
    prec_obs = numpyro.sample("prec_obs", dist.Normal(0., 1.0))
    root_prec = jnp.sqrt(prec_obs)
    sigma_obs = 1.0 / root_prec

    # likelihood of the data
    with numpyro.plate("data", N):
        # to_event(1) because each observation has shape (1,)
        likelihood = dist.Normal(output, sigma_obs).to_event(1)
        numpyro.sample("Y", likelihood, obs=Y)


# helper function for HMC inference
def run_inference(model, rng_key, X, Y, D_H):
    """Sample the posterior of ``model`` with NUTS and return the draws.

    Runs one chain with 1000 warm-up iterations and 2000 retained samples,
    prints the MCMC summary and the elapsed wall-clock time, and returns
    the result of ``MCMC.get_samples()``.

    Args:
        model: NumPyro model callable taking ``(X, Y, D_H)``.
        rng_key: JAX PRNG key used to run the chain.
        X, Y, D_H: forwarded unchanged to ``model``.
    """
    t0 = time.time()
    # the progress bar is shown unless NUMPYRO_SPHINXBUILD is set in the environment
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(
        NUTS(model),
        num_warmup=1000,
        num_samples=2000,
        num_chains=1,
        progress_bar=show_progress,
    )
    sampler.run(rng_key, X, Y, D_H)
    sampler.print_summary()
    print("\nMCMC elapsed time:", time.time() - t0)
    return sampler.get_samples()


# helper function for prediction
def predict(model, rng_key, samples, X, D_H):
    """Return the value of the "Y" site for one set of parameter values.

    The model is seeded with ``rng_key``, its sample sites are substituted
    with the values in ``samples``, and it is traced once on ``X``.
    """
    seeded = handlers.seed(model, rng_key)
    conditioned = handlers.substitute(seeded, samples)
    # Y=None is passed, so the "Y" site is sampled inside the model
    trace = handlers.trace(conditioned).get_trace(X=X, Y=None, D_H=D_H)
    y_site = trace["Y"]
    return y_site["value"]


# create artificial regression dataset
def get_data(N=20, D_X=3, sigma_obs=0.05, N_test=500):
    """Build a small synthetic 1-d regression problem with polynomial features.

    The inputs are N evenly spaced points in [-1, 1], expanded into the
    powers x**0 .. x**(D_X - 1). The targets are a random linear function of
    those features plus a sinusoidal term and Gaussian noise of scale
    ``sigma_obs``, standardised to zero mean and unit standard deviation.
    NumPy's global RNG is re-seeded with 0 on every call, so the output is
    deterministic.

    Returns:
        ``(X, Y, X_test)`` with shapes (N, D_X), (N, 1) and (N_test, D_X);
        the test inputs span [-1.3, 1.3].
    """
    D_Y = 1  # create 1d outputs
    np.random.seed(0)
    degrees = jnp.arange(D_X)

    def polynomial_features(grid):
        # column k holds grid ** k
        return jnp.power(grid[:, np.newaxis], degrees)

    X = polynomial_features(jnp.linspace(-1, 1, N))
    W = 0.5 * np.random.randn(D_X)
    x_lin = X[:, 1]  # the raw inputs (degree-1 column)
    signal = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + x_lin, 2.0) * jnp.sin(4.0 * x_lin)
    noisy = signal + sigma_obs * np.random.randn(N)

    # standardise the targets
    Y = noisy[:, np.newaxis]
    Y = Y - jnp.mean(Y)
    Y = Y / jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = polynomial_features(jnp.linspace(-1.3, 1.3, N_test))

    return X, Y, X_test


# Problem size: number of training points, input features and hidden units.
args = [10, 2, 2]
N, D_X, D_H = args
X, Y, X_test = get_data(N=N, D_X=D_X)

# do inference
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
samples = run_inference(model, rng_key, X, Y, D_H)

# predict Y_test at inputs X_test, using one PRNG key per posterior draw.
# The number of draws is read off the returned samples instead of being
# hard-coded, so it stays correct if the sampler settings change.
num_draws = next(iter(samples.values())).shape[0]
vmap_args = (
    samples,
    random.split(rng_key_predict, num_draws),
)
predictions = vmap(
    lambda samples, rng_key: predict(model, rng_key, samples, X_test, D_H)
)(*vmap_args)
# drop the trailing output axis
predictions = predictions[..., 0]

# compute the mean prediction and the 5th/95th percentiles (a 90% interval)
mean_prediction = jnp.mean(predictions, axis=0)
percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

# plot training data (column 1 of X holds the raw inputs)
ax.plot(X[:, 1], Y[:, 0], "kx")
# plot the 90% band of the predictions
ax.fill_between(
    X_test[:, 1], percentiles[0, :], percentiles[1, :], color="lightblue"
)
# plot mean prediction
ax.plot(X_test[:, 1], mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

plt.savefig("bnn_plot.pdf")

sample: 100%|██████████| 3000/3000 [00:04<00:00, 678.96it/s, 15 steps of size 2.31e-01. acc. prob=0.85] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  prec_obs      0.98      0.38      0.94      0.37      1.55    970.20      1.00
   w1[0,0]     -0.11      0.97     -0.13     -1.70      1.49   1633.86      1.00
   w1[0,1]     -0.21      0.96     -0.21     -1.74      1.42   1637.72      1.00
   w1[1,0]     -0.00      0.96      0.00     -1.54      1.56   2012.00      1.00
   w1[1,1]      0.00      0.92      0.01     -1.51      1.50   1611.33      1.00
   w2[0,0]     -0.12      0.99     -0.11     -1.76      1.46   1629.87      1.00
   w2[0,1]     -0.18      0.94     -0.17     -1.90      1.18   1679.31      1.00
   w2[1,0]     -0.10      0.93     -0.10     -1.66      1.39   2034.71      1.00
   w2[1,1]     -0.16      1.01     -0.16     -1.81      1.59   1460.40      1.00
   w3[0,0]     -0.08      0.89     -0.07     -1.59      1.30   1669.51      1.00
   w3[1,0]     -0.06      0.95     -0.06     -1.58      1.57   1544.18      1.00

Number of divergences: 6



In [3]:
# Inspect the raw 1-d inputs (column 1 of X) and their shape.
print(X[:, 1].tolist())
print(X[:, 1].shape)

[-1.0, -0.7777777910232544, -0.5555555820465088, -0.333333283662796, -0.11111113429069519, 0.11111116409301758, 0.3333333730697632, 0.5555555820465088, 0.7777777910232544, 1.0]
(10,)


In [19]:
def optimize(params_dict, loss_function, y, cfg, steps=500, lr=0.01, print_every=1):
    """Fit the SOGA program parameters in place with Adam.

    On every step the program described by ``cfg`` is re-evaluated with the
    current parameters, the loss of the resulting distribution against ``y``
    is computed, and one Adam update is applied.

    Args:
        params_dict: mapping from parameter name to a single-element tensor
            that requires grad; the tensors are updated in place.
        loss_function: callable ``(y, dist) -> scalar tensor``.
        y: observed data, passed through to ``loss_function``.
        cfg: control-flow graph passed to ``start_SOGA``.
        steps: number of optimisation steps.
        lr: Adam learning rate.
        print_every: print the parameters and the loss every this many
            steps; 0 or None disables the per-step output.
    """
    # Imported locally under a private alias: this notebook binds the global
    # name `time` both to the function (`from time import time`) and to the
    # module (`import time`), so the global's meaning depends on which cell
    # ran last.
    import time as _time

    optimizer = torch.optim.Adam(list(params_dict.values()), lr=lr)

    total_start = _time.time()

    for i in range(steps):

        optimizer.zero_grad()  # Reset gradients

        # Re-run SOGA with the current parameters and evaluate the loss
        current_dist = start_SOGA(cfg, params_dict, pruning='ranking')
        loss = loss_function(y, current_dist)

        # Backpropagate.
        # NOTE(review): retain_graph=True is kept from the original; since
        # the graph is rebuilt on every step it may be unnecessary — confirm.
        loss.backward(retain_graph=True)

        optimizer.step()

        # Print progress
        if print_every and i % print_every == 0:
            out = ''.join(
                f"{key}: {param.item()} " for key, param in params_dict.items()
            )
            print(out + f" loss: {loss.item()}")

    total_end = _time.time()

    print('Optimization performed in ', round(total_end - total_start, 3))

In [20]:
def mean_squared_error(y_true, dist):
    """Return the mean squared difference between ``y_true`` and ``dist.gm.mean()``.

    NOTE(review): the subtraction broadcasts, so ``y_true`` should have the
    same shape as the mean vector — verify against the caller.
    """
    residual = y_true - dist.gm.mean()
    return residual.pow(2).mean()

def mean_squared_error_bayes(y_true, dist):
    """Mean squared error against the mean of ``dist.gm`` without its last two entries.

    Only the means are compared, so this says nothing about the variances.
    NOTE(review): presumably the two dropped components are variables that
    are not outputs — confirm against the SOGA program.
    """
    predicted = dist.gm.mean()[:-2]
    residual = y_true - predicted
    return residual.pow(2).mean()

def neg_log_likelihood(y_true, dist):
    """Return the negative log-likelihood of ``y_true`` under ``dist``.

    Observation ``i`` is scored with ``dist.gm.marg_pdf(..., i)``, i.e. the
    marginal density associated with index ``i``, and the negated logs are
    summed over all observations.

    Args:
        y_true: tensor whose first axis indexes the observations.
        dist: object exposing ``gm.marg_pdf(value, index)``.

    Returns:
        The summed negative log-density as a tensor (the plain int 0 if
        ``y_true`` is empty).
    """
    total = 0
    # Iterate over the observations actually supplied instead of assuming
    # there are exactly 10 of them.
    for i, y_i in enumerate(y_true):
        total = total - torch.log(dist.gm.marg_pdf(y_i.unsqueeze(0), i))
    return total

In [21]:
# Inspect the raw 1-d inputs (column 1 of X) and their shape.
print(X[:, 1].tolist())
print(X[:, 1].shape)

[-1.0, -0.7777777910232544, -0.5555555820465088, -0.333333283662796, -0.11111113429069519, 0.11111116409301758, 0.3333333730697632, 0.5555555820465088, 0.7777777910232544, 1.0]
(10,)


In [22]:
# Copy the JAX targets into a fresh NumPy array...
numpy_array = np.array(Y)

# ...and wrap that array as a PyTorch tensor (the dtype is carried over).
Ytorch = torch.from_numpy(numpy_array)
Ytorch.shape

torch.Size([10, 1])

In [29]:
# Display the training targets as a torch tensor.
Ytorch

tensor([[-0.3339],
        [-0.2940],
        [-0.2069],
        [-0.5793],
        [-0.1670],
        [ 0.1788],
        [ 1.2079],
        [ 1.8025],
        [ 0.4927],
        [-2.1008]], dtype=torch.float32)

In [31]:
# Compile the SOGA program for the BNN and build its control-flow graph.
compiledFile = compile2SOGA('../programs/SOGA/Optimization/Case Studies/bnn3.soga')
cfg = produce_cfg(compiledFile)

# One trainable (mu, sigma) pair per weight, initialised to mu=0 and sigma=1.
# Presumably the suffix encodes <layer><row><col>, mirroring w1/w2/w3 of the
# NumPyro model — TODO confirm against bnn3.soga.
weight_ids = ['100', '101', '110', '111', '200', '201', '210', '211', '300', '310']
pars = {}
for wid in weight_ids:
    pars['mu' + wid] = torch.tensor(0., requires_grad=True)
    pars['sigma' + wid] = torch.tensor(1., requires_grad=True)

# Alternative parameter sets that were tried (without the layer-2 entries, or empty):
#pars = {'mu100':0., 'sigma100':1., 'mu101':0., 'sigma101':1.,'mu110':0., 'sigma110':1.,'mu111':0., 'sigma111':1.,'mu300':0., 'sigma300':1.,'mu310':0., 'sigma310':1.}
#pars = {}

# Evaluate the program once with the initial parameters.
output_dist = start_SOGA(cfg, pars, pruning='ranking')

# Fit the parameters by minimising the negative log-likelihood of the targets.
# NOTE(review): in the recorded run the mu2xx/sigma2xx entries stay at 0.0/1.0
# across steps, so they appear to receive no gradient — check that bnn3.soga
# actually uses them.
optimize(pars, neg_log_likelihood, Ytorch, cfg, steps=20)

# NOTE(review): the snippets below refer to `params_dict`, `muw`, `mub` and
# `sigmay`, none of which are defined in this cell — adapt or remove.
#predictive mean
#y_pred = params_dict['muw'].detach().numpy()*X.detach().numpy()+params_dict['mub'].detach().numpy()

#predictive variance
#sigma_y_pred = np.sqrt(params_dict['sigmay'].detach().numpy()**2 + (X.detach().numpy()*params_dict['sigmaw'].detach().numpy())**2 + params_dict['sigmab'].detach().numpy()**2)



mu100: 0.009999999866369806 sigma100: 0.990000000230016 mu101: -0.009998953196849708 sigma101: 0.990000005198233 mu110: -0.009999999894822729 sigma110: 0.9900000001408742 mu111: 0.009999906923708831 sigma111: 1.009999994801767 mu200: 0.0 sigma200: 1.0 mu201: 0.0 sigma201: 1.0 mu210: 0.0 sigma210: 1.0 mu211: 0.0 sigma211: 1.0 mu300: -0.009999881307914446 sigma300: 0.9900000000873542 mu310: 0.009999997675319882 sigma310: 0.9900000000422919  loss: 15.058224666171014
mu100: 0.019993934463126223 sigma100: 0.9800003256727742 mu101: -0.018336183347813017 sigma101: 0.9800161574420118 mu110: -0.019984767893382622 sigma110: 0.9800139698213848 mu111: 0.019622946892534743 sigma111: 1.0199749605541948 mu200: 0.0 sigma200: 1.0 mu201: 0.0 sigma201: 1.0 mu210: 0.0 sigma210: 1.0 mu211: 0.0 sigma211: 1.0 mu300: -0.01826063201058528 sigma300: 0.9800133700020317 mu310: 0.017459741682736948 sigma310: 0.9800062783560246  loss: 14.994924058856814
mu100: 0.029977460021501437 sigma100: 0.9700018601017881 mu101

In [33]:
# Re-run SOGA with the current (optimised) values in `pars` to get the fitted distribution.
output_dist = start_SOGA(cfg, pars, pruning='ranking')

In [34]:
# Show the first 10 entries of the fitted distribution's mean vector
# (presumably the 10 predicted outputs, to compare with Ytorch — TODO confirm).
output_dist.gm.mean()[:10]

tensor([ 0.0003, -0.0004,  0.0003,  0.0013,  0.0022,  0.0030,  0.0037,  0.0041,
         0.0043,  0.0043], grad_fn=<SliceBackward0>)

In [35]:
# Display the training targets again, next to the means shown above.
Ytorch

tensor([[-0.3339],
        [-0.2940],
        [-0.2069],
        [-0.5793],
        [-0.1670],
        [ 0.1788],
        [ 1.2079],
        [ 1.8025],
        [ 0.4927],
        [-2.1008]], dtype=torch.float32)