Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

JAX implementation of Gaussian process classification (GPC) using parallelized elliptical slice sampling (ESS). The algorithm is taken from Iain Murray, Ryan Prescott Adams, and David JC MacKay. "[Elliptical Slice Sampling.](http://proceedings.mlr.press/v9/murray10a.html)" (2010). 

We leverage recent theoretical advances that characterize the function-space prior of an ensemble of infinitely-wide NNs as a Gaussian process, termed the neural network Gaussian process (NNGP). We use the NNGP with a softmax link function to build a probabilistic model for multi-class classification and marginalize over the latent Gaussian outputs to sample from the posterior using ESS. This gives us a better understanding of the implicit prior NNs place on function space and allows a direct comparison of the calibration of the NNGP and its finite-width analogue. See Adlam *et al.* "[Exploring the Uncertainty Properties of Neural Networks' Implicit Priors in the Infinite-Width Limit.](https://arxiv.org/abs/2010.07355)" (2020).

## Imports

In [None]:
import time
import numpy as onp
import tensorflow_datasets as tfds
import jax
from jax import core
from jax import device_put
from jax import devices
from jax import jit
from jax import lax
from jax import numpy as np
from jax import pmap
from jax import random
from jax import vmap
from jax import config as jax_config
from jax.nn import softmax
from jax.nn import log_softmax

In [None]:
!pip install -q git+https://www.github.com/google/neural-tangents

/bin/sh: pip: command not found


In [None]:
import neural_tangents as nt
from neural_tangents import stax

In [None]:
# Turn on 64-bit floating point in JAX so that the kernel and Cholesky
# computations can run in double precision, reducing numerical issues.
jax_config.update('jax_enable_x64', True)

# Upper bound on bracket-shrinking iterations inside a single ESS step; it
# keeps numerical issues from trapping the sampler in its while loop.
MAX_STEPS = 10000.0

## Function Definitions

In [None]:
def make_sda(arrays, devices=None):
  """Manually make a ShardedDeviceArray from a list of arrays.

  Arrays and devices are paired with zip, so each array is placed on the
  corresponding device and any surplus entries in the longer sequence are
  ignored. The resulting per-device buffers are wrapped in a
  ShardedDeviceArray whose leading axis indexes the shards.

  NOTE(review): this uses `jax.interpreters.xla.device_put` and
  `jax.pxla.ShardedDeviceArray`, which are legacy JAX internals and may not
  exist in newer JAX releases.

  Args:
    arrays: A non-empty sequence of arrays that all share one shape and dtype.
    devices: Optional sequence of devices. Defaults to jax.local_devices().

  Returns:
    A ShardedDeviceArray of shape (num_shards,) + arrays[0].shape, where
    num_shards = min(len(arrays), len(devices)).

  Raises:
    ValueError: If `arrays` is empty or the arrays differ in shape or dtype.
  """
  arrays = list(arrays)
  if not arrays:
    raise ValueError('make_sda requires at least one array.')
  x_shape, x_dtype = arrays[0].shape, arrays[0].dtype
  # Validate everything before touching any device so failures are cheap.
  for arr in arrays:
    if arr.shape != x_shape or arr.dtype != x_dtype:
      raise ValueError(
          'All arrays must share shape and dtype; got {} {} and {} {}.'.format(
              x_shape, x_dtype, arr.shape, arr.dtype))
  if devices is None:
    devices = jax.local_devices()
  buffers = [jax.interpreters.xla.device_put(arr, device=dev)
             for arr, dev in zip(arrays, devices)]
  # The leading axis must equal the number of buffers actually created, which
  # can be smaller than len(devices) when fewer arrays are given.
  aval = core.ShapedArray((len(buffers),) + x_shape, x_dtype)
  return jax.pxla.ShardedDeviceArray(aval, buffers)

In [None]:
def es_step(key, f_old, prior_sampler, log_l, max_steps=None):
  """Performs a single step of elliptical slice sampling.

  Following Murray, Adams & MacKay (2010), an auxiliary prior draw nu defines
  the ellipse f(theta) = f_old * cos(theta) + nu * sin(theta) through the
  current state. A slice height is drawn below log_l(f_old), and a bracket on
  theta is shrunk towards 0 until a point on the ellipse lies above the slice.

  Args:
    key: A JAX PRNGKey.
    f_old: The current state of the Markov chain stored as a 1D DeviceArray of
      shape (dim,).
    prior_sampler: A function that takes a single PRNGKey and returns a draw
      from the mean-zero Gaussian prior with the same shape as f_old. (In
      gpc_predict the prior's Cholesky factor is captured by closure.)
    log_l: A function that returns the log-likelihood of a state.
    max_steps: Optional cap on the number of bracket-shrinking iterations.
      Defaults to the module-level MAX_STEPS.

  Returns:
    A tuple (key, f_new), where key is a new PRNGKey and f_new is the new state
    of the Markov chain. The step can fail due to numerical precision. For some
    samples of the randomness, the only acceptable transitions are very close
    to the current state, and the loop shrinks the bracket for many iterations.
    If more than max_steps iterations would be needed, the step is abandoned
    and f_new equals f_old. As long as this happens infrequently, the overall
    sampling is fine.
  """
  if max_steps is None:
    max_steps = MAX_STEPS

  key, subkey = random.split(key)
  nu = prior_sampler(subkey)
  key, subkey = random.split(key)
  # Slice height log(u) + log_l(f_old) with u ~ Uniform[0, 1).
  log_y = log_l(f_old) + np.log(random.uniform(subkey))
  key, subkey = random.split(key)
  # Initial proposal angle, with the bracket [theta - 2*pi, theta] around 0.
  theta = 2 * np.pi * random.uniform(subkey)
  theta_min, theta_max = theta - 2 * np.pi, theta

  def _cond(vals):
    _, f_new, _, _, _, i = vals
    # Keep shrinking while the proposal lies below the slice, up to the cap.
    return np.logical_and(log_l(f_new) < log_y, i <= max_steps)

  def _body(vals):
    """Body function for while loop to shrink the feasible region of theta."""
    key, _, theta, theta_min, theta_max, i = vals
    # The rejected angle becomes the bracket end on its side of 0. A rejected
    # theta is never exactly 0, because theta == 0 gives f_new == f_old, which
    # always lies above the slice (for finite log_l(f_old)).
    theta_min = np.where(theta < 0, theta, theta_min)
    theta_max = np.where(theta > 0, theta, theta_max)
    key, subkey = random.split(key)
    theta = theta_min + (theta_max - theta_min) * random.uniform(subkey)
    f_new = f_old * np.cos(theta) + nu * np.sin(theta)
    return key, f_new, theta, theta_min, theta_max, i + 1

  f_new = f_old * np.cos(theta) + nu * np.sin(theta)
  key, f_new, _, _, _, i = lax.while_loop(
      _cond, _body, (key, f_new, theta, theta_min, theta_max, 0))

  # If the iteration cap was hit, reject the step and keep the current state.
  return key, np.where(i <= max_steps, f_new, f_old)

In [None]:
def es_sample(key, dim, nc, prior_sampler, log_l, num_samples, burn_in,
              trace_tuple=None, eval_tuple=None, logging_fn=print,
              init_state=None):
  """Sample from the posterior using MCMC given by elliptical slice sampling.

  One chain is run per local device (via pmap). Each chain takes
  `num_samples // num_local_devices + burn_in` ESS steps. It keeps a running
  average of softmax(state) over its post-burn-in steps; during burn-in the
  average is replaced by the softmax of the current state.

  Args:
    key: A JAX PRNGKey.
    dim: An int specifying the dimension of the state, since it is 1D.
    nc: An int specifying the number of classes in the classification. The
      state is interpreted as a (dim // nc, nc) array of logits.
    prior_sampler: A function that takes a single PRNGKey and returns a draw of
      shape (dim,) from the mean-zero Gaussian prior.
    log_l: A function that returns the log-likelihood of a state.
    num_samples: An int for the total number of post-burn-in MCMC steps, split
      evenly across local devices.
    burn_in: An int specifying the number of initial steps per chain that are
      excluded from the running average.
    trace_tuple: A tuple (fn, i), where fn is a function to apply to the state
      of each chain and save, and i is an int interval in steps. Note saving
      lots of data frequently can cause OOMs.
    eval_tuple: A tuple (fn, i), where fn(p, step) is called with the current
      per-chain probabilities, and i is an int interval in steps.
    logging_fn: Function used to log progress. Defaults to print.
    init_state: A 1D DeviceArray of shape (dim,) that is the initial state of
      every chain. Defaults to zeros.

  Both callbacks run once per chunk of steps, where the chunk length is the
  smaller of the two intervals (clipped so no chain exceeds its step budget).

  Returns:
    A tuple (p, eval_traces). p is a DeviceArray of shape (dim // nc, nc)
    holding the posterior probabilities averaged over chains. eval_traces is
    an array stacking the outputs of trace_tuple's fn, and is empty if
    trace_tuple is None.
  """
  start_time = time.time()
  loop_time = start_time
  logging_fn('Starting MCMC sampling...\n')

  p_num = jax.local_device_count()
  # Split one key per device across all hosts and keep this host's slice.
  key = np.reshape(
      random.split(key, jax.device_count()),
      [jax.host_count(), p_num, 2])[jax.host_id()]
  total_samples = int(num_samples // p_num + burn_in)
  if init_state is None:
    sample = pmap(lambda x: np.zeros([dim]))(np.arange(p_num))
  else:
    sample = pmap(lambda x: init_state)(np.arange(p_num))
  p = pmap(lambda x: np.zeros([dim // nc, nc]))(np.arange(p_num))

  def accumulate_softmax(x, y, t):
    """Folds softmax(y) into the running mean x; t is the post-burn-in index."""
    probs = softmax(np.reshape(y, [dim // nc, nc]))
    out = ((t - 1.) / t) * x + (1. / t) * probs
    # Guess current state while still in burn in phase (t <= 0).
    out = np.where(t > 0, out, probs)
    return out / np.sum(out, axis=-1, keepdims=True)

  # Set up function to evaluate the current p(y|x).
  eval_print_steps = total_samples
  if eval_tuple is not None:
    eval_print, eval_print_steps = eval_tuple

  # Set up function to apply to current sample and save.
  eval_traces = []
  trace_fn_steps = total_samples
  if trace_tuple is not None:
    trace_fn, trace_fn_steps = trace_tuple
    trace_fn = pmap(trace_fn)

  # Steps per host-side chunk; at least 1 so the range() below is valid.
  epoch = max(1, int(min(eval_print_steps, trace_fn_steps)))

  def body_fun(i, vals):
    key, sample, p = vals
    key, sample = es_step(key, sample, prior_sampler, log_l)
    p = accumulate_softmax(p, sample, i - burn_in + 1.)
    return key, sample, p

  @pmap
  def for_loop_fn(lower, upper, key, sample, p):
    return lax.fori_loop(lower, upper, body_fun, (key, sample, p))

  for start in range(0, total_samples, epoch):
    # Clip the final chunk so each chain runs exactly total_samples steps.
    stop = min(start + epoch, total_samples)
    key, sample, p = for_loop_fn(
        start * np.ones([p_num]), stop * np.ones([p_num]), key, sample, p)

    if eval_tuple is not None:
      eval_print(p, stop)

    if trace_tuple is not None:
      eval_traces.append(trace_fn(sample))

    logging_fn('Completed step {}/{} in {:.3f} mins at {}\n'.format(
        stop, total_samples, (time.time() - loop_time) / 60.,
        time.asctime(time.localtime())))
    loop_time = time.time()

  # Average the per-device chains.
  p = np.mean(p, axis=0)

  logging_fn('Sampling complete in {:.3f} mins.'.format(
      (time.time() - start_time) / 60.))

  return p, np.array(eval_traces)

In [None]:
def gpc_predict(
    k_fn, k_scale, x0, y0, x1, key, num_samples, burn_in, diag_reg=0.,
    ess_dtype=np.float32, trace_tuple=None, eval_tuple=None,
    logging_fn=print):
  """Approximates the train/test posterior given a kernel and data using ESS.

  Args:
    k_fn: A neural_tangents kernel function, called as
      k_fn(np.vstack([x0, x1]), None, 'nngp').
    k_scale: A scalar rescaling the kernel. It is applied through the Cholesky
      factor, so the prior covariance is k_scale * (k + regularizer).
    x0: An array of training points.
    y0: An array of one-hot labels for the training data, of shape (n0, nc).
    x1: An array of test points.
    key: A PRNGKey.
    num_samples: A positive integer specifying the number of steps of MCMC.
    burn_in: A positive integer specifying the burn-in for the MCMC sampling.
    diag_reg: A nonnegative float used to regularize the kernel. It is added
      as diag_reg * I to the train kernel when computing the initial state.
      It is added as diag_reg * mean(diag(k)) * I before the Cholesky.
    ess_dtype: The dtype for the computed Cholesky and the state of the ess.
    trace_tuple: A tuple (fn, num), where fn is a function applied to the
      current state and num is an integer specifying how often the output is
      saved.
    eval_tuple: A tuple (fn, num), where fn is a function applied to the current
      probabilities to log their performance every num steps.
    logging_fn: Function to log progress.

  Returns:
    A tuple (p0, p1, eval_trace). p0 and p1 are the posteriors on the training
    set and the test set, with shapes (n0, nc) and (n1, nc); each row is a
    distribution. eval_trace is the array of saved traces from es_sample.

  Raises:
    ValueError: If k_fn is None.
  """
  start_time = time.time()
  logging_fn('========= jax.device_count(): %s' % jax.device_count())
  logging_fn('========= local_device_count: %s' % jax.local_device_count())
  logging_fn('========= devices: %s' % ', '.join(map(str, jax.devices())))
  logging_fn('========= host_id: %s' % jax.host_id())

  if k_fn is None:
    raise ValueError('k_fn must be specified!')

  n0, nc = y0.shape
  n1 = x1.shape[0]

  logging_fn('Computing kernel matrices...')
  k = k_fn(np.vstack([x0, x1]), None, 'nngp')
  logging_fn('Computed kernel matrices in {:.3f} mins.'.format(
      (time.time() - start_time) / 60.))

  # Initial state: the training block is the one-hot labels themselves, and
  # the test block is the kernel-regression prediction
  # K_10 (K_00 + diag_reg * I)^{-1} y0, which has shape (n1, nc).
  logging_fn('Computing initial state for ESS...')
  y1_hat = k[n0:, :n0] @ np.linalg.solve(
      k[:n0, :n0] + diag_reg * np.eye(n0), y0)
  init_state = np.reshape(np.vstack([y0, y1_hat]), [-1])
  logging_fn('Computed initial state for ESS in {:.3f} mins.'.format(
      (time.time() - start_time) / 60.))

  # Compute the Cholesky once, adding a diagonal regularizer for stability.
  start_time = time.time()
  logging_fn('Computing Cholesky decomposition...')
  l = onp.sqrt(k_scale) * onp.linalg.cholesky(
      k + diag_reg * np.trace(k) / (n0 + n1) * np.eye(n0 + n1))
  # ANY CASTING SHOULD HAPPEN HERE, AFTER THE CHOLESKY COMPUTATION.
  l = jax.device_put(l, jax.devices('cpu')[0])
  l = np.array(l, dtype=ess_dtype)  # Cast Cholesky to desired precision.
  logging_fn('Computed Cholesky decomposition in {:.3f} mins.'.format(
      (time.time() - start_time) / 60.))

  def prior_sampler(key):
    normal_samples = random.normal(key, [n0 + n1, nc])
    normal_samples = np.array(normal_samples, dtype=ess_dtype)
    # The Cholesky of the Kronecker is the Kronecker of the Choleskys, i.e.
    # np.kron(l, np.eye(nc)). Avoid instantiating large matrix.
    # NB. While this is correct, it seems buggy on TPU, i.e. there are different
    # answers for the code below and actually using the Kronecker product.
    return np.reshape(l @ normal_samples, [-1])

  def log_l(f):
    """Log-likelihood: p(data|f_train, f_test) == p(data|f_train)."""
    f_ = np.reshape(f[:n0 * nc], [n0, nc])
    p_ = log_softmax(f_)
    # Assumes y0 contains one-hot labels.
    return np.sum(p_ * y0)

  p, eval_trace = es_sample(
      key, nc * (n0 + n1), nc, prior_sampler, log_l, num_samples, burn_in,
      trace_tuple, eval_tuple, logging_fn=logging_fn, init_state=init_state)
  return p[:n0], p[n0:], eval_trace

In [None]:
# Supported TFDS dataset names mapped to their number of label classes.
DATASETS = dict(mnist=10, cifar10=10, cifar100=100)


def to_one_hot(label, nc):
  """Returns an integer vector of length nc with a 1 at position `label`.

  A label outside [0, nc) produces all zeros. The comparison broadcasts, so a
  column of labels of shape (N, 1) yields an (N, nc) array.
  """
  class_ids = np.arange(nc)
  return np.asarray(class_ids == label, dtype=int)


def get_dataset(train_dataset, test_dataset, n0, n1, classes, valid_set=False,
                flatten=True):
  """Function to load data and preprocess it.

  Loads the 'train' split of `train_dataset` and the 'test' split of
  `test_dataset` from TFDS. Only examples whose label is in `classes` are
  kept. Images are standardized with the scalar mean and std of the first n0
  kept training images. Labels are one-hot encoded over nc classes. When
  `classes` is an explicit collection, its labels are re-indexed to 0..nc-1 in
  increasing order, e.g. classes={3, 5} maps 3 -> 0 and 5 -> 1.

  Note this implementation is slow due to filtering and preprocessing.

  Args:
    train_dataset: TFDS name (a key of DATASETS) for the training data.
    test_dataset: TFDS name (a key of DATASETS) for the test data.
    n0: Number of training points to return.
    n1: Number of test (or validation) points to return.
    classes: None to use every class of train_dataset, an int c to use labels
      0..c-1, or a set/list/tuple of labels to keep.
    valid_set: If True, the evaluation points are the last n1 kept training
      examples instead of examples from the test split.
    flatten: If True, each image is flattened to a vector.

  Returns:
    ((x0, y0), (x1, y1)), where x0 and x1 are the standardized inputs placed
    on the host CPU, and y0 and y1 are one-hot label arrays with nc columns.

  Raises:
    ValueError: On an unknown dataset name, too many requested classes, an
      unsupported `classes` type, or too few training points for valid_set.
  """
  if train_dataset not in DATASETS or test_dataset not in DATASETS:
    raise ValueError('Dataset "{}" or "{}" not recognized! Choose from: '
                     '{}'.format(train_dataset, test_dataset, DATASETS))
  ds_train = tfds.load(name=train_dataset)
  ds_test = tfds.load(name=test_dataset)

  # Maps original labels to one-hot indices; None means use labels as-is.
  label_map = None
  if classes is None:
    nc = DATASETS[train_dataset]
    filter_fn = lambda x: True
  elif isinstance(classes, int):
    if classes > DATASETS[train_dataset]:
      raise ValueError('Requesting {} classes from {}, but it only has {} '
                       'classes!'.format(
                           classes, train_dataset, DATASETS[train_dataset]))
    nc = classes
    classes = set(range(nc))
    filter_fn = lambda x: x['label'] in classes
  elif isinstance(classes, (set, frozenset, list, tuple)):
    classes = set(classes)  # Remove any duplicate classes.
    nc = len(classes)
    # Arbitrary label values must be re-indexed into [0, nc) for one-hot.
    label_map = {c: idx for idx, c in enumerate(sorted(classes))}
    filter_fn = lambda x: x['label'] in classes
  else:
    raise ValueError(
        '"classes" must be None, an int, or a set/list/tuple! Given {}'.format(
            type(classes)))

  def _load_split(ds, split):
    """Returns float64 images and (re-indexed) int labels of kept examples."""
    images, labels = [], []
    for ex in tfds.as_numpy(ds[split]):
      if not filter_fn(ex):
        continue
      # Keep datapoints as plain NumPy arrays on CPU for the Cholesky.
      images.append(onp.asarray(ex['image'], dtype=onp.float64))
      label = int(ex['label'])
      labels.append(label if label_map is None else label_map[label])
    return onp.array(images), onp.array(labels, dtype=onp.int64)

  x0_, labels0 = _load_split(ds_train, 'train')
  x1_, labels1 = _load_split(ds_test, 'test')
  # to_one_hot broadcasts: an (N, 1) label column yields an (N, nc) array.
  y0_ = to_one_hot(labels0[:, None], nc)
  y1_ = to_one_hot(labels1[:, None], nc)
  y0 = y0_[:n0]
  y1 = y1_[:n1]

  if valid_set:
    if x0_.shape[0] < n0 + n1:
      raise ValueError('Validation set is taken from end of training split. '
                       'So n0+n1 cannot exceed total training points in '
                       'requested dataset, but received {} and {}'.format(
                           n0+n1, x0_.shape[0]))
    # Explicit start index: x0_[-0:] would select everything when n1 == 0.
    start = x0_.shape[0] - n1
    x1_ = x0_[start:]
    y1 = y0_[start:]

  mean = onp.mean(x0_[:n0])
  std = onp.std(x0_[:n0])
  if flatten:
    x0_ = onp.reshape(x0_[:n0], (n0, -1))
    x1_ = onp.reshape(x1_[:n1], (n1, -1))
  else:
    x0_ = x0_[:n0]
    x1_ = x1_[:n1]
  x0 = (x0_ - mean) / std
  x1 = (x1_ - mean) / std
  # NOTE: unless jax_enable_x64 is on, device_put casts these to float32.
  x0 = device_put(x0, devices('cpu')[0])
  x1 = device_put(x1, devices('cpu')[0])

  return (x0, y0), (x1, y1)

## Define Parameters

In [None]:
# Data: sizes, number of classes and dataset names.
n0 = 1000  # Training points.
n1 = 1000  # Test (or validation) points.
nc = 10
train_dataset = 'cifar10'
test_dataset = 'cifar10'
valid_set = False

# Kernel: architecture defining the NNGP prior.
kernel_type = 'fc'  # 'fc' or 'cnn'.
activation = 'erf'  # 'relu' or 'erf'.
kernel_batch_size = None  # None computes the kernel in a single jitted call.
W_std = 1.4142135624
b_std = 0.
k_scale = 1.
k_depth = 5
diag_reg = 1e-6

# Sampler settings.
ess_dtype = np.float32
key = random.PRNGKey(0)
mcmc_steps = 100000.0
eval_steps = 10000.0
burn_in = 10000.0
iterations = 1
save_trace = False

## Load Data

In [None]:
# Only the convolutional kernel consumes image-shaped inputs.
flatten = kernel_type != 'cnn'
(x0, y0), (x1, y1) = get_dataset(
    train_dataset, test_dataset, n0, n1, nc,
    valid_set=valid_set, flatten=flatten)

## Define Kernel

In [None]:
import functools

# Nonlinearity used between layers.
if activation == 'relu':
  act = stax.Relu()
elif activation == 'erf':
  act = stax.Erf()
else:
  raise ValueError('Activation {} not recognized! Choose either relu or '
                   'erf.'.format(activation))

if kernel_type == 'fc':
  collect_layers = [stax.Dense(512, W_std=W_std, b_std=b_std), act]*k_depth
  collect_layers += [stax.Dense(1, W_std=W_std, b_std=b_std)]
  _, _, k_fn = stax.serial(*collect_layers)

elif kernel_type == 'cnn':
  conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std,
                           padding='SAME', parameterization='ntk')
  collect_layers = [conv(512, (3, 3)), act]*k_depth
  collect_layers += [stax.Flatten(),
                     stax.Dense(1, W_std, b_std, parameterization='ntk')]
  _, _, k_fn = stax.serial(*collect_layers)

else:
  raise ValueError('Kernel type {} not recognized! Choose either fc or '
                   'cnn.'.format(kernel_type))

if kernel_batch_size is not None:
  # Recommended batch size ~25 for pooling, ~800 for flattening
  num_local_devices = jax.local_device_count()
  # Parenthesized: the total batch across devices must divide n0 + n1.
  if (n0 + n1) % (kernel_batch_size * num_local_devices) != 0:
    raise ValueError('Device count times batch size must divide the training '
                     'set plus test set size! Received {} and {}.'.format(
                         kernel_batch_size * num_local_devices, n0 + n1))
  k_fn = nt.batch(k_fn, batch_size=kernel_batch_size, store_on_device=False)

else:
  k_fn = jit(k_fn, static_argnums=(1, 2))

## Run ESS

In [None]:
@vmap
def acc_(y, y_hat):
  """Per-row check that the predicted class (argmax) matches the label's."""
  label_idx = np.argmax(y)
  pred_idx = np.argmax(y_hat)
  return label_idx == pred_idx


def acc(x, y):
  """Classification accuracy: the mean agreement of row-wise argmaxes."""
  return np.mean(acc_(x, y))

def eval_print(p, t):
  """Prints train and test accuracy for the current posterior probabilities.

  Used as the eval callback of es_sample. `t` (the step count) is accepted for
  that interface but is not used. When `p` has a leading chain axis
  (ndim == 3), per-chain accuracies are printed before those of the average.
  """
  probs = onp.array(p, dtype=onp.float32)
  # Skip arrays whose entries are all below 1/nc (e.g. all zeros); any
  # normalized row over nc classes has some entry >= 1/nc.
  if np.max(probs) < 1. / nc:
    return None
  probs = probs / onp.sum(probs, axis=-1, keepdims=True)

  if probs.ndim == 3:
    for chain, chain_probs in enumerate(probs):
      train_probs, test_probs = chain_probs[:n0], chain_probs[n0:]
      print("Train acc for chain {}: {}".format(chain, acc(y0, train_probs)))
      print("Test acc for chain {}: {}".format(chain, acc(y1, test_probs)))
    probs = np.mean(probs, axis=0)

  train_probs, test_probs = probs[:n0], probs[n0:]
  print("Train acc: {}".format(acc(y0, train_probs)))
  print("Test acc: {}".format(acc(y1, test_probs)))


eval_tuple = (eval_print, eval_steps)

if save_trace:
  # Save trace of a subset of the data: rows 0-4 (training points) and rows
  # n0..n0+4 (the first test points) of the (n0 + n1, nc) latent state.
  def trace_fn(sample):
    """Selects the latent values of 5 train and 5 test points."""
    latents = np.reshape(sample, (-1, nc))
    return np.concatenate([latents[:5], latents[n0:n0 + 5]], axis=0)

  trace_tuple = (trace_fn, 1)
else:
  trace_tuple = None

In [None]:
# Average posterior probabilities over independent sampler runs.
total_p0, total_p1 = 0., 0.
for i in range(iterations):
  key, subkey = random.split(key)
  p0, p1, eval_trace = gpc_predict(
      k_fn, k_scale, x0, y0, x1, subkey,
      num_samples=mcmc_steps, burn_in=burn_in, diag_reg=diag_reg,
      ess_dtype=ess_dtype, trace_tuple=trace_tuple, eval_tuple=eval_tuple)
  # Incremental mean: afterwards total_p* is the mean over i + 1 runs.
  total_p0 = (i * total_p0 + p0) / (i + 1)
  total_p1 = (i * total_p1 + p1) / (i + 1)