# jax-bayes CIFAR10 Example --- Bayesian MCMC Approach

## Set Up the Environment

In [3]:
#see https://github.com/google/jax#pip-installation
!pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl
!pip install --upgrade jax
!pip install git+https://github.com/deepmind/dm-haiku
!pip install git+https://github.com/jamesvuc/jax-bayes

Collecting jaxlib==0.1.51
  Using cached https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.51-cp36-none-manylinux2010_x86_64.whl
Installing collected packages: jaxlib
  Found existing installation: jaxlib 0.1.51
    Uninstalling jaxlib-0.1.51:
      Successfully uninstalled jaxlib-0.1.51
Successfully installed jaxlib-0.1.51
Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.75)
Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-yf2rs_hb
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-yf2rs_hb
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
  Created wheel for dm-haiku: filename=dm_haiku-0.0.2-cp36-none-any.whl size=289739 sha256=3b8458b694f0318292ff7f1ef1f8a08f8166e3141772ec2afc50f2464a55d1b0
  Stored in directory: /tmp/pip-ephem-wheel-cache-nu9w8nn8/wheels/97/0f/e9/17f34e

In [4]:
import haiku as hk

import jax.numpy as jnp
from jax.experimental import optimizers
import jax

import jax_bayes

import sys, os, math, time
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import tensorflow_datasets as tfds

## Build the dataset loader and CNN

In [5]:
def load_dataset(split, is_training, batch_size, repeat=True, seed=0):
  if repeat:
    ds = tfds.load('cifar10', split=split).cache().repeat()
  else:
    ds = tfds.load('cifar10', split=split).cache()
  if is_training:
    ds = ds.shuffle(10 * batch_size, seed=seed)
  ds = ds.batch(batch_size)
  return tfds.as_numpy(ds)

# build a 32-32-64-32 CNN with max-pooling 
# followed by a 128-10-n_classes MLP
class Net(hk.Module):
  def __init__(self, dropout=0.1, n_classes=10):
    super(Net, self).__init__()
    self.conv_stage = hk.Sequential([
      #block 1
      hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME'), 
      jax.nn.relu, 
      hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),
      # block 2
      hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME'), 
      jax.nn.relu, 
      hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),
      # block 3
      hk.Conv2D(64, kernel_shape=3, stride=1, padding='SAME'), 
      jax.nn.relu, 
      hk.MaxPool(window_shape=(1,2,2,1), strides=(1,1,1,1), padding='VALID'),
      # block 4
      hk.Conv2D(32, kernel_shape=3, stride=1, padding='SAME')
    ])

    self.mlp_stage = hk.Sequential([
      hk.Flatten(),
      hk.Linear(128), 
      jax.nn.relu, 
      hk.Linear(n_classes)
    ])

    self.p_dropout = dropout

  def __call__(self, x, use_dropout=True):
    x = self.conv_stage(x)
    
    dropout_rate = self.p_dropout if use_dropout else 0.0
    x = hk.dropout(hk.next_rng_key(), dropout_rate, x)

    return self.mlp_stage(x)

# standard normalization constants
mean_norm = jnp.array([[0.4914, 0.4822, 0.4465]])
std_norm = jnp.array([[0.247, 0.243, 0.261]])

#define the net-function 
def net_fn(batch, use_dropout):
  net = Net(dropout=0.0)
  x = batch['image']/255.0
  x = (x - mean_norm) / std_norm
  return net(x, use_dropout)

## Build the Loss, Metrics, and MCMC step

In [13]:
# hyperparameters
# lr = 1e-2
lr_initial = 1e-2
lr_final = 1e-3
decay_start = 100
decay_steps = 100
decay_schedule = jax.experimental.optimizers.polynomial_decay(lr_initial, decay_steps, lr_final, power=1.0)
lr = lambda t: jax.lax.cond(t < decay_start,
                            lambda s: lr_initial,
                            lambda s: decay_schedule(s - decay_start),
                            t)


reg = 1e-4
num_samples = 5
#for this example, we're going to use the jax initializers to sample the initial 
# distribution, so we will use init_stddev = 0.0
init_stddev = 0.0 

# instantiate the network
net = hk.transform(net_fn)

# build the sampler
key = jax.random.PRNGKey(0)
sampler_init, sampler_propose, sampler_update, sampler_get_params = \
  jax_bayes.mcmc.rms_langevin_fns(key, num_samples=-1, step_size=lr, 
                                  init_stddev=init_stddev)

# standard regularized crossentropy loss function, which is the 
# negative unnormalized log-posterior 
def loss(params, rng, batch):
    logits = net.apply(params, rng, batch, use_dropout=True)
    labels = jax.nn.one_hot(batch['label'], 10)

    l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) 
                        for p in jax.tree_leaves(params))
    softmax_crossent = - jnp.mean(labels * jax.nn.log_softmax(logits))

    return softmax_crossent + reg * l2_loss

logprob = lambda p,k,b : - loss(p, k, b)

@jax.jit
def accuracy(params, batch):
  pred_fn = lambda p:net.apply(p, jax.random.PRNGKey(101), batch, use_dropout=False)
  pred_fn = jax.vmap(pred_fn)
  preds = jnp.mean(pred_fn(params), axis=0)
  return jnp.mean(jnp.argmax(preds, axis=-1) == batch['label'])

# the data loss will help us monitor the Markov chain's progress without worrying
# about the effects of regularization.
def data_loss(params, batch):
    logits = net.apply(params, jax.random.PRNGKey(0), batch, use_dropout=False)
    labels = jax.nn.one_hot(batch['label'], 10)
    softmax_crossent = - jnp.mean(labels * jax.nn.log_softmax(logits))
    return softmax_crossent
data_loss = jax.vmap(data_loss, in_axes=(0, None))

@jax.jit
def mcmc_step(i, sampler_state, sampler_keys, rng, batch):
  params = sampler_get_params(sampler_state)
  logp = lambda p,k: logprob(p, k, batch)
  fx, dx = jax.vmap(jax.value_and_grad(logp))(params, rng)

  sampler_prop_state, new_keys = sampler_propose(i, dx, sampler_state, 
                                        sampler_keys)

  fx_prop, dx_prop = fx, dx

  sampler_state, new_keys = sampler_update(i, fx, fx_prop, 
                        dx, sampler_state, 
                        dx_prop, sampler_prop_state, 
                        new_keys)
  
  return jnp.mean(fx), sampler_state, new_keys

## Load Batch iterators & Do the MCMC inference

In [10]:
init_batches = load_dataset("train", is_training=True, batch_size=512)
val_batches = load_dataset("train", is_training=False, batch_size=2_000)
test_batches = load_dataset("test", is_training=False, batch_size=2_000)

In [14]:
#Use the vmap-over-keys trick to sample a highly anisotropic initial distribution
init_batch = next(init_batches)
keys = jax.random.split(jax.random.PRNGKey(1), num_samples)
init_param_samples = jax.vmap(lambda k:net.init(k, init_batch, use_dropout=True))(keys)
sampler_state, sampler_keys = sampler_init(init_param_samples)

# generate RNGs for the dropout
rngs = jax.random.split(jax.random.PRNGKey(2), num_samples)

for epoch in range(250):
  #generate a shuffled epoch of training data
  train_batches = load_dataset("train", is_training=True,
                              batch_size=128, repeat=False, seed=epoch)
  
  start = time.time()
  for batch in train_batches:
    # run an MCMC step
    train_logprob, sampler_state, sampler_keys = \
      mcmc_step(epoch, sampler_state, sampler_keys, rngs, batch)
    
    # make more rngs for the dropout
    rngs = jax.random.split(rngs[0], num_samples)
  epoch_time = time.time() - start

  if epoch % 5 == 0:
    # compute val and test accuracy, and the sampler-average data loss
    params = sampler_get_params(sampler_state)
    val_acc = accuracy(params, next(val_batches))
    test_acc = accuracy(params, next(test_batches))
    _data_loss = jnp.mean(data_loss(params, next(val_batches)))
    print(f"epoch = {epoch}"
        f" | time per epoch {epoch_time:.3f}"
        f" | data loss = {_data_loss:.3e}"
        f" | val acc = {val_acc:.3f}"
        f" | test acc = {test_acc:.3f}")

epoch = 0 | time per epoch 43.130 | data loss = 1.126e+17 | val acc = 0.189 | test acc = 0.193
epoch = 5 | time per epoch 35.306 | data loss = 1.354e+15 | val acc = 0.333 | test acc = 0.346
epoch = 10 | time per epoch 35.097 | data loss = 3.986e+14 | val acc = 0.351 | test acc = 0.350
epoch = 15 | time per epoch 34.981 | data loss = 1.929e+14 | val acc = 0.388 | test acc = 0.368
epoch = 20 | time per epoch 34.972 | data loss = 1.142e+14 | val acc = 0.392 | test acc = 0.399
epoch = 25 | time per epoch 34.980 | data loss = 7.258e+13 | val acc = 0.412 | test acc = 0.413
epoch = 30 | time per epoch 34.928 | data loss = 5.112e+13 | val acc = 0.438 | test acc = 0.399
epoch = 35 | time per epoch 34.901 | data loss = 3.738e+13 | val acc = 0.442 | test acc = 0.416
epoch = 40 | time per epoch 34.915 | data loss = 3.082e+13 | val acc = 0.455 | test acc = 0.404
epoch = 45 | time per epoch 34.886 | data loss = 2.601e+13 | val acc = 0.458 | test acc = 0.432
epoch = 50 | time per epoch 34.899 | data 

**Note**: This example highlights how Bayesian ML and regular ML are very different. 

- We know a lot less about efficient inference than we do optimization.
- Accuracy of around 45% (vs 70% for the optimization approach) is only a bit worse than current SoTA algorithms for this architecture (see e.g. [This paper](https://arxiv.org/pdf/1709.01180.pdf)). More hyperparameter tuning could probably close this gap.
- In fact many MCMC papers do not evaluate on CIFAR10 (preferring to use MNIST, where we can easily achieve >96%)
- There are several factors that contribute to MCMC's increased difficulty:
  - stochastic gradients
  - dependence on hyperparameters
  - regularization techniques
  - probabilistic algorithms are generally more subtle

In [20]:
def posterior_predictive(params, batch):
  """computes the posterior_predictive P(class = c | inputs, params) using a histogram
  """
  pred_fn = lambda p:net.apply(p, jax.random.PRNGKey(0), batch, use_dropout=False) 
  pred_fn = jax.vmap(pred_fn)

  logit_samples = pred_fn(params) # n_samples x batch_size x n_classes
  pred_samples = jnp.argmax(logit_samples, axis=-1) #n_samples x batch_size

  n_classes = logit_samples.shape[-1]
  batch_size = logit_samples.shape[1]
  probs = np.zeros((batch_size, n_classes))
  for c in range(n_classes):
    idxs = pred_samples == c
    probs[:,c] = idxs.sum(axis=0)

  return probs / probs.sum(axis=1, keepdims=True)

params = sampler_get_params(sampler_state)
print('Final predictive entropy', jnp.mean(jax_bayes.utils.entropy(posterior_predictive(params, next(test_batches)))))

Final predictive entropy 1.3844115
