# Use BlackJAX with PyMC v4
Author: Martin Ingram, based on the notebook for PyMC3 by Kaustubh Chaudhari

In [1]:
import jax
import pymc as pm
from pymc.sampling_jax import get_jaxified_logp
import numpy as np
import blackjax.nuts as nuts
import blackjax.stan_warmup as stan_warmup

# Hide all CUDA devices so that JAX falls back to the CPU backend
# (a local preference of the author; remove to let JAX use a GPU).
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# The package is PyMC (v4), not PyMC3, so label the version accordingly.
print(f"Running on PyMC v{pm.__version__}")

You are running the v4 development version of PyMC which currently still lacks key features. You probably want to use the stable v3 instead which you can either install via conda or find on the v3 GitHub branch: https://github.com/pymc-devs/pymc/tree/v3


Running on PyMC3 v4.0.0




## Data

Please refer to the [original TFP example](https://www.tensorflow.org/probability/examples/Eight_Schools) for a description of the problem and the model that is used.

In [2]:
# Eight Schools data: one observed value and one known standard
# deviation per school, plus the number of schools.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])  # observed values
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])  # observation std. devs.
J = 8  # number of schools

# Model


In [3]:
with pm.Model() as model:
    # Population-level parameters: overall mean and between-school scale.
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", 5.0)

    # Non-centered parameterisation: `theta` holds standard-normal offsets,
    # and the per-school effects are obtained by shifting and scaling them.
    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
    theta_1 = mu + tau * theta

    # The likelihood must be centred on the rescaled effects `theta_1`.
    # Using the raw offsets `theta` here would leave `mu` and `tau`
    # disconnected from the data, so they would only be sampled from
    # their priors.
    obs = pm.Normal("obs", mu=theta_1, sigma=sigma, shape=J, observed=y)

In [4]:
# Sampler settings shared by every sampler benchmarked below.
n_samples = 50_000  # draws kept per chain after warmup
n_warmup = 1_000  # tuning / warmup iterations per chain
chains = 4  # number of chains

# Sampling using PyMC NUTS Sampler

In [5]:
%%time

# Baseline: PyMC's default sampler, using the shared chain / warmup /
# draw settings defined earlier.
with model:
    posterior = pm.sample(
        chains=chains,
        tune=n_warmup,
        draws=n_samples,
    )

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, tau, theta]


Sampling 4 chains for 1_000 tune and 50_000 draw iterations (4_000 + 200_000 draws total) took 50 seconds.


CPU times: user 37.5 s, sys: 2.28 s, total: 39.7 s
Wall time: 59.6 s


# Sampling using PyMC JAX NumPyro NUTS sampler

In [6]:
%%time

# Same model, sampled with the NumPyro NUTS backend exposed by
# pymc.sampling_jax.
with model:
    hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(
        draws=n_samples,
        tune=n_warmup,
        chains=chains,
        target_accept=0.9,
    )

Compiling...


  sub_prof = optimizer.optimize(fgraph)


Compilation time =  0 days 00:00:00.336274
Sampling...


2022-01-11 13:00:57.547356: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


  0%|          | 0/51000 [00:00<?, ?it/s]

  0%|          | 0/51000 [00:00<?, ?it/s]

  0%|          | 0/51000 [00:00<?, ?it/s]

  0%|          | 0/51000 [00:00<?, ?it/s]

Sampling time =  0 days 00:00:03.088836
Transforming variables...
Transformation time =  0 days 00:00:00.024455
CPU times: user 4.45 s, sys: 188 ms, total: 4.64 s
Wall time: 3.56 s


  sub_prof = optimizer.optimize(fgraph)


# Sampling using BlackJAX

## Configuring the model for BlackJAX


In [7]:
# Base PRNG key from which the per-chain keys are derived.
seed = jax.random.PRNGKey(1234)

# JAX version of the model's log-probability function.
logp_fn_jax = get_jaxified_logp(model)

# Build the starting position as a list of arrays ordered like
# `model.value_vars`, which is the order used to index the initial point.
rv_names = [rv.name for rv in model.value_vars]
initial_point = model.recompute_initial_point()
init_state = [initial_point[rv_name] for rv_name in rv_names]

# Add a leading axis of length `chains` to every array so that each chain
# starts from its own copy of the initial point.
# `jax.tree_util.tree_map` is used instead of the deprecated top-level
# alias `jax.tree_map`, which is no longer available in recent JAX releases.
init_state_batched = jax.tree_util.tree_map(
    lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
)

# Create one NUTS state per chain: map over the chain axis of the positions
# while sharing the same log-probability function.
initial_state = jax.vmap(nuts.new_state, in_axes=(0, None))(init_state_batched, logp_fn_jax)

In [9]:
%%time

# Warmup
def kernel_factory(step_size, inverse_mass_matrix):
    """Return a NUTS kernel for `logp_fn_jax` with the given tuning parameters."""
    return nuts.kernel(logp_fn_jax, step_size, inverse_mass_matrix)


def warmup_fn(seed, init_state):
    """Run `stan_warmup` for `n_warmup` steps on a single chain."""
    return stan_warmup.run(seed, kernel_factory, init_state, n_warmup)


# One key per chain. The count must match the leading axis of
# `initial_state` (built with `chains` copies), so use `chains` rather
# than a hard-coded 4.
init_seeds = jax.random.split(seed, chains)

# Adapt all chains in parallel; each chain gets its own step size and
# inverse mass matrix.
last_state, (step_size, inverse_mass_matrix), _ = jax.vmap(warmup_fn)(init_seeds, initial_state)

CPU times: user 17.3 s, sys: 18.7 ms, total: 17.3 s
Wall time: 17.3 s


In [10]:
%%time

# Sampling
def inner_inference_loop(
    rng_key, step_size, inverse_mass_matrix, initial_state, num_samples, kernel_factory
):
    """Run `num_samples` kernel steps for a single chain.

    Parameters
    ----------
    rng_key
        PRNG key for this chain; it is split into one key per step.
    step_size, inverse_mass_matrix
        Tuning parameters forwarded to `kernel_factory`.
    initial_state
        State the chain starts from.
    num_samples
        Number of steps to take.
    kernel_factory
        Callable `(step_size, inverse_mass_matrix) -> kernel`, where
        `kernel(rng_key, state)` returns `(new_state, info)`.

    Returns
    -------
    tuple
        `(states, infos)` as stacked by `jax.lax.scan`, with one entry per
        step along the leading axis.
    """
    kernel = kernel_factory(step_size, inverse_mass_matrix)

    def one_step(state, step_key):
        # `step_key` is already a single, unique key for this step (see the
        # split below). Pass it to the kernel as is: splitting it again and
        # handing over the resulting stack of keys would give the kernel an
        # array of keys where it expects exactly one.
        state, info = kernel(step_key, state)
        return state, (state, info)

    # One independent key per step.
    keys = jax.random.split(rng_key, num_samples)

    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

from functools import partial

# Fix the arguments shared by all chains so that only the per-chain
# arguments (key, step size, inverse mass matrix, state) are mapped over.
to_vmap = partial(inner_inference_loop, num_samples=n_samples, kernel_factory=kernel_factory)

# Derive a fresh key for the sampling phase. Splitting `seed` directly
# would reproduce exactly the keys already consumed by the warmup, so the
# sampling phase would reuse the warmup's random stream.
sampling_seed = jax.random.fold_in(seed, 1)

# One key per chain; use `chains` rather than a hard-coded 4 so the count
# always matches the chain axis of the warmup outputs.
start_keys = jax.random.split(sampling_seed, chains)

results = jax.vmap(to_vmap)(start_keys, step_size, inverse_mass_matrix, last_state)

CPU times: user 7.55 s, sys: 91.8 ms, total: 7.64 s
Wall time: 7.45 s
