# Benchmark NumPyro on a large dataset

This notebook uses `numpyro` to replicate the experiments in reference [1], which evaluates the performance of NUTS in various frameworks. The benchmark is run with CUDA 10.1 on an NVIDIA RTX 2070.

```python
#!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
```

```python
import time

import numpy as np

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.infer import HMC, MCMC, NUTS
assert numpyro.__version__.startswith('0.5.0')

# NB: replace "gpu" with "cpu" to run this notebook on CPU
# numpyro.set_platform("cpu")
numpyro.set_platform("gpu")
```

We apply the same preprocessing steps as the [source code](https://github.com/google-research/google-research/blob/master/simple_probabilistic_programming/no_u_turn_sampler/logistic_regression.py) of reference [1]:

```python
_, fetch = load_dataset(COVTYPE, shuffle=False)
features, labels = fetch()

# normalize features and add intercept
features = (features - features.mean(0)) / features.std(0)
features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])

# make binary labels, as in the reference code
_, counts = np.unique(labels, return_counts=True)
specific_category = jnp.argmax(counts)
labels = (labels == specific_category)

N, dim = features.shape
print("Data shape:", features.shape)
print("Label distribution: {} has label 1, {} has label 0"
      .format(labels.sum(), N - labels.sum()))
```

```
Data shape: (581012, 55)
Label distribution: 211840 has label 1, 369172 has label 0
```

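The preprocessing steps can be checked on a small synthetic array. The NumPy sketch below is illustrative only: the helper `standardize_with_intercept` and the toy labels are hypothetical and do not come from [1].

It shows two things:

+ After standardization, every feature column has mean 0 and standard deviation 1, and the appended last column is the intercept.
+ In the binarization step, `np.argmax(counts)` returns a position in the sorted array of unique labels. That position equals the label value itself only when the labels are consecutive integers starting at 0.

```python
import numpy as np

def standardize_with_intercept(features):
    # zero mean / unit variance per column, then a column of ones for the intercept
    features = (features - features.mean(0)) / features.std(0)
    return np.hstack([features, np.ones((features.shape[0], 1))])

rng = np.random.default_rng(0)
X = standardize_with_intercept(rng.normal(3.0, 2.0, size=(100, 4)))
print(X.shape)                                # (100, 5)
print(np.allclose(X[:, :-1].mean(0), 0.0))    # True
print(np.allclose(X[:, :-1].std(0), 1.0))     # True

# hypothetical labels taking the values 1, 2, 3 (not 0-based)
y = np.array([1] * 30 + [2] * 50 + [3] * 20)
values, counts = np.unique(y, return_counts=True)
idx = np.argmax(counts)                       # position of the largest count
print(idx, values[idx])                       # 1 2
# comparing against the position vs. against the label value
print((y == idx).sum(), (y == values[idx]).sum())  # 30 50
```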

Now, we construct a logistic regression model with independent standard normal priors on the coefficients:

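```python
def model(data, labels):
    # independent Normal(0, 1) prior on each coefficient (including the intercept)
    coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    logits = jnp.dot(data, coefs)
    # Bernoulli likelihood parameterized by logits
    return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)
```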

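HMC and NUTS spend most of their time evaluating the potential energy (the negative log joint density) and its gradient. For this model, both have a closed form.

The NumPy sketch below is for illustration only. NumPyro derives the same quantities automatically from `model` using JAX autodiff, and the function name `potential_and_grad` is hypothetical. The gradient is checked against central finite differences on random toy data.

```python
import numpy as np

def potential_and_grad(w, X, y):
    """Negative log joint of the logistic regression model and its gradient.

    Prior: w_j ~ Normal(0, 1) independently; likelihood: y ~ Bernoulli(logits=X @ w).
    """
    z = X @ w
    # log-likelihood: sum_i y_i * z_i - log(1 + exp(z_i)), computed stably
    log_lik = np.sum(y * z - np.logaddexp(0.0, z))
    log_prior = -0.5 * w @ w - 0.5 * w.size * np.log(2 * np.pi)
    u = -(log_lik + log_prior)
    # sigmoid(z) = exp(-log(1 + exp(-z))), again computed stably
    sigmoid = np.exp(-np.logaddexp(0.0, -z))
    grad = w - X.T @ (y - sigmoid)
    return u, grad

# toy data to check the gradient against central finite differences
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = (rng.uniform(size=200) < 0.5).astype(float)
w = rng.normal(size=3)
u, g = potential_and_grad(w, X, y)
eps = 1e-6
fd = np.array([(potential_and_grad(w + eps * e, X, y)[0]
                - potential_and_grad(w - eps * e, X, y)[0]) / (2 * eps)
               for e in np.eye(3)])
print(np.allclose(g, fd, atol=1e-4))  # True
```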
## Benchmark HMC

```python
step_size = jnp.sqrt(0.5 / N)
# fixed trajectory length of 10 step sizes, i.e. 10 leapfrog steps per trajectory
kernel = HMC(model, step_size=step_size, trajectory_length=(10 * step_size), adapt_step_size=False)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, progress_bar=False)
mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=('num_steps',))
# access a warmup result, intended to ensure warmup has finished before timing starts
mcmc.get_extra_fields()['num_steps'].sum().copy()
tic = time.time()
mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=['num_steps'])
num_leapfrogs = mcmc.get_extra_fields()['num_steps'].sum().copy()
toc = time.time()
print("number of leapfrog steps:", num_leapfrogs)
print("avg. time for each step :", (toc - tic) / num_leapfrogs)
mcmc.print_summary()
```

```
number of leapfrog steps: 5000
avg. time for each step : 0.003021224784851074

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  coefs[0]      1.98      0.00      1.98      1.98      1.98      3.98      1.58
  coefs[1]     -0.03      0.00     -0.03     -0.03     -0.03      3.62      1.70
  coefs[2]     -0.12      0.00     -0.12     -0.12     -0.12      5.73      1.07
  coefs[3]     -0.30      0.00     -0.30     -0.30     -0.30      3.45      1.68
  coefs[4]     -0.10      0.00     -0.10     -0.10     -0.10      5.43      1.02
  coefs[5]     -0.15      0.00     -0.15     -0.16     -0.15      2.59      3.18
  coefs[6]     -0.04      0.00     -0.04     -0.04     -0.04      2.64      2.74
  coefs[7]     -0.49      0.00     -0.49     -0.49     -0.49      4.99      1.43
  coefs[8]      0.25      0.00      0.25      0.24      0.25      4.07      1.70
  coefs[9]     -0.02      0.00     -0.02     -0.02     -0.02      6.43      1.48
 coefs[10]     -0.23      0.00
```

On CPU, we get `avg. time for each step : 0.02782863507270813`.

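With `trajectory_length=(10 * step_size)`, every HMC trajectory takes the same number of leapfrog steps (10). This is why the run above reports 500 × 10 = 5000 steps.

The NumPy sketch below shows one such fixed-length HMC transition with an identity mass matrix, under these assumptions:

+ `potential_fn` and `grad_fn` are hypothetical callables for U(q) and its gradient.
+ The standard normal target is a toy example.
+ Rounding `trajectory_length / step_size` is a choice made here, not necessarily NumPyro's exact rule.

NumPyro's `HMC` additionally handles warmup adaptation.

```python
import numpy as np

def leapfrog(q, p, grad_fn, step_size, num_steps):
    """Integrate Hamiltonian dynamics with an identity mass matrix."""
    g = grad_fn(q)
    for _ in range(num_steps):
        p = p - 0.5 * step_size * g   # half step for the momentum
        q = q + step_size * p         # full step for the position
        g = grad_fn(q)                # one gradient evaluation per leapfrog step
        p = p - 0.5 * step_size * g   # half step for the momentum
    return q, p

def hmc_step(q, rng, potential_fn, grad_fn, step_size, trajectory_length):
    # fixed number of leapfrog steps per trajectory (rounding is an assumption)
    num_steps = max(1, int(round(trajectory_length / step_size)))
    p0 = rng.normal(size=q.shape)
    q_new, p_new = leapfrog(q, p0, grad_fn, step_size, num_steps)
    # Metropolis correction on the total energy H = U(q) + |p|^2 / 2
    h0 = potential_fn(q) + 0.5 * p0 @ p0
    h1 = potential_fn(q_new) + 0.5 * p_new @ p_new
    accept = np.log(rng.uniform()) < h0 - h1
    return (q_new if accept else q), num_steps

# toy target: 2-D standard normal, U(q) = |q|^2 / 2
rng = np.random.default_rng(0)
step_size = 0.1
q = np.zeros(2)
total_steps = 0
samples = []
for _ in range(500):
    q, n = hmc_step(q, rng, lambda x: 0.5 * x @ x, lambda x: x,
                    step_size, trajectory_length=10 * step_size)
    total_steps += n
    samples.append(q)
print(total_steps)  # 5000
```

Because the step count is fixed, every transition costs the same number of gradient evaluations. NUTS instead keeps extending the trajectory until it detects a U-turn.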
## Benchmark NUTS

```python
# NUTS chooses the number of leapfrog steps of each trajectory adaptively
mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50, progress_bar=False)
mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=('num_steps',))
# access a warmup result, intended to ensure warmup has finished before timing starts
mcmc.get_extra_fields()['num_steps'].sum().copy()
tic = time.time()
mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=['num_steps'])
num_leapfrogs = mcmc.get_extra_fields()['num_steps'].sum().copy()
toc = time.time()
print("number of leapfrog steps:", num_leapfrogs)
print("avg. time for each step :", (toc - tic) / num_leapfrogs)
mcmc.print_summary()
```

```
number of leapfrog steps: 49262
avg. time for each step : 0.004623858861213508

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  coefs[0]      1.97      0.01      1.97      1.96      1.99     33.46      1.01
  coefs[1]     -0.04      0.00     -0.04     -0.05     -0.03     38.63      1.01
  coefs[2]     -0.06      0.01     -0.06     -0.08     -0.04     41.08      1.03
  coefs[3]     -0.30      0.00     -0.30     -0.31     -0.30     93.01      1.00
  coefs[4]     -0.09      0.00     -0.09     -0.10     -0.08    209.27      0.98
  coefs[5]     -0.14      0.00     -0.15     -0.15     -0.14     34.54      1.01
  coefs[6]      0.26      0.04      0.26      0.22      0.33     61.87      0.98
  coefs[7]     -0.67      0.02     -0.67     -0.70     -0.64     59.25      0.98
  coefs[8]      0.60      0.04      0.59      0.55      0.67     73.34      0.98
  coefs[9]     -0.01      0.00     -0.01     -0.02     -0.01     38.63      0.99
 coefs[10]      0.22      0.4
```

On CPU, we get `avg. time for each step : 0.028006251705287415`.

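NUTS does not fix the number of leapfrog steps. It keeps doubling the trajectory until the two ends start moving back toward each other. In the run above this averages 49262 / 50 ≈ 985 leapfrog steps per sample, versus exactly 10 for HMC.

The function below is a NumPy sketch of the original no-U-turn condition for an identity mass matrix. NumPyro's iterative implementation uses a generalized form of this test, so treat this as an illustration only. The name `is_u_turn` is hypothetical.

```python
import numpy as np

def is_u_turn(q_minus, q_plus, p_minus, p_plus):
    """Original no-U-turn check (identity mass matrix).

    Stop extending once continuing at either end would start to shrink
    the distance between the trajectory's two endpoints.
    """
    dq = q_plus - q_minus
    return bool(dq @ p_minus < 0 or dq @ p_plus < 0)

# both ends still moving apart -> keep extending the trajectory
print(is_u_turn(np.array([0.0]), np.array([1.0]),
                np.array([1.0]), np.array([1.0])))   # False
# the forward end has turned around -> stop
print(is_u_turn(np.array([0.0]), np.array([1.0]),
                np.array([1.0]), np.array([-1.0])))  # True
```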
## Compare to other frameworks

Average time per leapfrog step:

| Framework     |    HMC    |    NUTS   |
| ------------- |----------:|----------:|
| Edward2 (CPU) |           |  56.1 ms  |
| Edward2 (GPU) |           |   9.4 ms  |
| Pyro (CPU)    |  35.4 ms  |  35.3 ms  |
| Pyro (GPU)    |   3.5 ms  |   4.2 ms  |
| NumPyro (CPU) |  27.8 ms  |  28.0 ms  |
| NumPyro (GPU) |   1.6 ms  |   2.2 ms  |

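The CPU-to-GPU speedup for each framework and sampler follows directly from the table. In the small Python sketch below, the dictionary copies the table's milliseconds; the Edward2 HMC cells are empty in the table and are therefore left out.

```python
# (CPU ms, GPU ms) per leapfrog step, copied from the table above
timings = {
    ("Edward2", "NUTS"): (56.1, 9.4),
    ("Pyro", "HMC"): (35.4, 3.5),
    ("Pyro", "NUTS"): (35.3, 4.2),
    ("NumPyro", "HMC"): (27.8, 1.6),
    ("NumPyro", "NUTS"): (28.0, 2.2),
}

# speedup = CPU time / GPU time for the same framework and sampler
speedups = {key: cpu / gpu for key, (cpu, gpu) in timings.items()}
for (framework, sampler), s in speedups.items():
    print(f"{framework:8s} {sampler:4s} GPU speedup: {s:5.1f}x")
```

On these numbers, NumPyro with HMC shows the largest speedup (about 17x), and Edward2 with NUTS the smallest (about 6x).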
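Note that in some situations (e.g. Pyro on CPU), HMC takes slightly more time per leapfrog step than NUTS. The reason is that the number of leapfrog steps in each HMC trajectory is fixed to $10$. In NUTS it is not fixed and is much larger here (about 985 steps per sample in the run above), so any per-trajectory overhead is spread over far fewer steps in HMC.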

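The amortization argument can be illustrated with a toy cost model. Assume, hypothetically, that:

+ each leapfrog step costs a fixed time `t_grad` for the potential and its gradient;
+ each trajectory adds a fixed overhead `t_traj`.

The average time per leapfrog step is then `t_grad + t_traj / steps_per_trajectory`. The constants below are made up for illustration and are not measured. Only the steps per trajectory come from the runs above.

```python
def avg_time_per_step(t_grad, t_traj, steps_per_trajectory):
    """Toy model: per-step cost plus per-trajectory overhead amortized over the steps."""
    return t_grad + t_traj / steps_per_trajectory

# hypothetical costs in ms (illustrative only, not measured)
t_grad, t_traj = 1.0, 2.0
hmc_steps = 5000 / 500      # 10 leapfrog steps per HMC trajectory (HMC run above)
nuts_steps = 49262 / 50     # about 985 per NUTS trajectory (NUTS run above)

print(avg_time_per_step(t_grad, t_traj, hmc_steps))
print(avg_time_per_step(t_grad, t_traj, nuts_steps))
```

Under this model the overhead adds `t_traj / 10` per step for HMC but only about `t_traj / 985` for NUTS. A real NUTS trajectory also carries tree-building costs that this sketch ignores.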
**Some takeaways:**
+ The overhead of iterative NUTS is small, so most of the computation time is indeed spent evaluating the potential function and its gradient.
+ GPU outperforms CPU by a large margin. The dataset is large (581012 × 55), so evaluating the potential function is much faster on GPU than on CPU.

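The claim that the data is large can be put in numbers. Each evaluation of the potential and its gradient for this model needs two matrix-vector products over the full 581012 × 55 design matrix: `X @ w` and then `X.T @ residual`.

The rough count below makes these assumptions:

+ about 2 floating-point operations (a multiply and an add) per matrix entry for each product;
+ the O(N) elementwise work is ignored;
+ the matrix is stored as float32.

```python
n_rows, n_cols = 581012, 55               # data shape printed above
entries = n_rows * n_cols                 # entries of the design matrix
flops_per_matvec = 2 * entries            # one multiply and one add per entry
flops_per_eval = 2 * flops_per_matvec     # X @ w, then X.T @ residual
bytes_float32 = 4 * entries               # assuming float32 storage

print(f"{entries:,} entries")                                                  # 31,955,660 entries
print(f"~{flops_per_eval / 1e6:.0f} MFLOP per potential+gradient evaluation")  # ~128 MFLOP per potential+gradient evaluation
print(f"~{bytes_float32 / 1e6:.0f} MB for the float32 design matrix")          # ~128 MB for the float32 design matrix
```

Since each leapfrog step evaluates the gradient roughly once, this dense linear algebra dominates the cost of a step. That is the kind of workload where a GPU pays off.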
## References

1. Dustin Tran, Matthew D. Hoffman, Dave Moore, Christopher Suter, Srinivas Vasudevan, Alexey Radul, Matthew Johnson, Rif A. Saurous. *Simple, Distributed, and Accelerated Probabilistic Programming*. [arXiv](https://arxiv.org/abs/1811.02091)