# Use BlackJAX with PyMC
Author: Kaustubh Chaudhari

BlackJAX can sample from any log-probability density function, provided the function is compatible with JAX's JIT compiler. In this notebook we show how to use PyMC as the modeling language and BlackJAX as the inference library.

This example relies on PyMC v4; see the [installation instructions](https://github.com/pymc-devs/pymc#installation) in the PyMC repository.
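In practice, "compatible with JAX's JIT" means the function is built from `jax.numpy` operations and avoids Python control flow that depends on the values of its inputs. The sketch below illustrates this with a hypothetical unit-scale half-normal log density, which is not part of the original notebook. Its support check uses `jnp.where` instead of an `if` statement, so both the density and its gradient can be compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


# Hypothetical example density: half-normal with unit scale, supported on x >= 0.
# The support check uses jnp.where rather than a Python `if`, because under
# jax.jit `x` is a traced value and cannot drive Python control flow.
def half_normal_logpdf(x):
    logp = jnp.log(2.0) + norm.logpdf(x, loc=0.0, scale=1.0)
    return jnp.where(x >= 0, logp, -jnp.inf)


# Compile the density and its gradient; gradient-based samplers such as NUTS need both.
logdensity = jax.jit(half_normal_logpdf)
grad_logdensity = jax.jit(jax.grad(half_normal_logpdf))

print(logdensity(1.0))        # log(2) - 0.5*log(2*pi) - 0.5, about -0.7258
print(grad_logdensity(1.0))   # derivative of -x**2/2 at x = 1, i.e. -1.0
print(logdensity(-1.0))       # outside the support: -inf
```

If the `jnp.where` call were replaced by `if x >= 0:`, calling the jitted function would raise an error, because the comparison produces a traced value rather than a Python boolean.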

```python
import jax
import numpy as np
import pymc as pm
import pymc.sampling_jax  # JAX-based sampling utilities shipped with PyMC v4

import blackjax

print(f"Running on PyMC3 v{pm.__version__}")
```



```
Running on PyMC3 v4.0.0b2
```

## Data

Please refer to the [original TFP example](https://www.tensorflow.org/probability/examples/Eight_Schools) for a description of the problem and the model that is used.

```python
# Data of the Eight Schools model
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
```

## Model


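```python
with pm.Model() as model:

    # Priors on the population mean and scale of the school effects
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", 5.0)

    # Non-centered parametrization: standardized effects, then shift and scale
    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
    theta_1 = mu + tau * theta

    # Likelihood: the observed effects are centered on the school-level means theta_1
    obs = pm.Normal("obs", mu=theta_1, sigma=sigma, shape=J, observed=y)
```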
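For reference, the non-centered Eight Schools model from the TFP example can be written as below. Here $\mathcal{N}(m, s)$ denotes a normal distribution with mean $m$ and standard deviation $s$, matching PyMC's `sigma` argument, and $\text{HalfCauchy}(5)$ has scale 5. In the model cell, `theta` plays the role of $\tilde\theta_j$ and `theta_1` that of $\theta_j$.

$$
\begin{aligned}
\mu &\sim \mathcal{N}(0, 10), \\
\tau &\sim \text{HalfCauchy}(5), \\
\tilde\theta_j &\sim \mathcal{N}(0, 1), \qquad j = 1, \dots, 8, \\
\theta_j &= \mu + \tau\,\tilde\theta_j, \\
y_j &\sim \mathcal{N}(\theta_j, \sigma_j).
\end{aligned}
$$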

## Sampling using the PyMC NUTS sampler

```python
%%time

with model:
    posterior = pm.sample(50_000, chains=1)
```

```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [mu, tau, theta]
Sampling 1 chain for 1_000 tune and 50_000 draw iterations (1_000 + 50_000 draws total) took 30 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks
CPU times: user 30.8 s, sys: 285 ms, total: 31.1 s
Wall time: 32.9 s
```


## Sampling using PyMC's JAX-based NumPyro NUTS sampler

```python
%%time

with model:
    hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(
        50_000, target_accept=0.9, chains=1, progress_bar=False
    )
```

```
Compiling...
Compilation time =  0 days 00:00:00.102013
Sampling...
Sampling time =  0 days 00:00:03.760742
Transforming variables...
Transformation time =  0 days 00:00:00.020388
CPU times: user 3.98 s, sys: 43.6 ms, total: 4.03 s
Wall time: 3.99 s
```


## Sampling using BlackJAX

### Configuring the model for BlackJAX

We first need to compile the PyMC model into a log-probability density function that is compatible with JAX:

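```python
# Names of the model's value variables; transformed variables appear under
# their transformed name (e.g. `tau_log__` for the HalfCauchy `tau`)
rvs = [rv.name for rv in model.value_vars]

# `compute_initial_point` returns a dict keyed by value-variable name; the
# jaxified log density expects a list of values in `model.value_vars` order
init_position_dict = model.compute_initial_point()
init_position = [init_position_dict[rv] for rv in rvs]

# JAX function mapping that list of values to the model's log density
logprob_fn = pm.sampling_jax.get_jaxified_logp(model)
```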
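To make concrete what kind of function `logprob_fn` is, here is a hand-written JAX sketch of the same non-centered log density. It is an illustration, not PyMC's generated code, and it rests on assumptions: that PyMC log-transforms the HalfCauchy `tau` (value variable `tau_log__`) and adds the log-Jacobian, and that `model.value_vars` is ordered `mu`, `tau_log__`, `theta`. It has not been checked to match `get_jaxified_logp` to floating-point precision.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import cauchy, norm

# Eight Schools data, as in the data cell
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])


def eight_schools_logdensity(position):
    """Hand-written log density over the list [mu, log_tau, theta_tilde].

    ASSUMPTION: mirrors a log transform of `tau` with its Jacobian term,
    as PyMC is assumed to apply to the HalfCauchy prior.
    """
    mu, log_tau, theta_tilde = position
    tau = jnp.exp(log_tau)
    theta = mu + tau * theta_tilde  # non-centered school effects

    logp = norm.logpdf(mu, loc=0.0, scale=10.0)
    # HalfCauchy(5) density on tau > 0 is twice the Cauchy(0, 5) density
    logp += jnp.log(2.0) + cauchy.logpdf(tau, loc=0.0, scale=5.0)
    logp += log_tau  # log-Jacobian of tau = exp(log_tau)
    logp += jnp.sum(norm.logpdf(theta_tilde, loc=0.0, scale=1.0))
    logp += jnp.sum(norm.logpdf(y, loc=theta, scale=sigma))
    return logp


# Positions are plain lists (JAX pytrees), like `init_position` in the notebook
position = [jnp.array(0.0), jnp.array(0.0), jnp.zeros(8)]
print(jax.jit(eight_schools_logdensity)(position))

# jax.grad returns a gradient with the same list structure as the position
grads = jax.grad(eight_schools_logdensity)(position)
print([g.shape for g in grads])
```

Because the position is an arbitrary pytree, BlackJAX does not need to know how the parameters are laid out; it only requires a scalar log density it can evaluate and differentiate.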

### Sampling

```python
%%time

seed = jax.random.PRNGKey(1234)
# Use independent keys for warmup and sampling instead of reusing `seed`
warmup_key, sample_key = jax.random.split(seed)

# 1,000 steps of window adaptation to tune the NUTS step size and mass matrix
adapt = blackjax.window_adaptation(blackjax.nuts, logprob_fn, 1000)
last_state, kernel, _ = adapt.run(warmup_key, init_position)


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


# Sample from the posterior distribution
states, infos = inference_loop(sample_key, kernel, last_state, 50_000)
```

```
CPU times: user 5.21 s, sys: 10.9 ms, total: 5.22 s
Wall time: 5.17 s
```
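To see what `inference_loop` does independently of BlackJAX and PyMC, here is a minimal sketch that drives a hand-written random-walk Metropolis kernel through the same `jax.lax.scan` loop, targeting a 2D standard normal. The kernel, `RWState`, and `make_rw_kernel` are hypothetical stand-ins written for this illustration, not BlackJAX's NUTS; they only mimic its `kernel(rng_key, state) -> (state, info)` calling convention.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Hypothetical sampler state: current position and its log density."""
    position: jnp.ndarray
    logdensity: jnp.ndarray


def make_rw_kernel(logdensity_fn, step_size):
    """Hypothetical random-walk Metropolis kernel with a BlackJAX-like signature."""
    def kernel(rng_key, state):
        move_key, accept_key = jax.random.split(rng_key)
        # Gaussian random-walk proposal around the current position
        proposal = state.position + step_size * jax.random.normal(
            move_key, state.position.shape
        )
        proposal_logdensity = logdensity_fn(proposal)
        # Metropolis acceptance test on the log scale
        log_ratio = proposal_logdensity - state.logdensity
        accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
        new_state = RWState(
            jnp.where(accept, proposal, state.position),
            jnp.where(accept, proposal_logdensity, state.logdensity),
        )
        return new_state, accept  # the acceptance flag plays the role of `info`
    return kernel


def inference_loop(rng_key, kernel, initial_state, num_samples):
    # Same loop as in the notebook: one fresh key per step, states stacked by scan
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


def logdensity_fn(x):
    # Unnormalized 2D standard normal
    return -0.5 * jnp.sum(x**2)


kernel = jax.jit(make_rw_kernel(logdensity_fn, step_size=1.0))
init = RWState(jnp.zeros(2), logdensity_fn(jnp.zeros(2)))
states, accepted = inference_loop(jax.random.PRNGKey(0), kernel, init, 20_000)

print(states.position.shape)          # one row per draw
print(states.position.mean(axis=0))   # should be close to [0, 0]
print(accepted.mean())                # empirical acceptance rate
```

The loop relies only on the kernel returning a new state plus some per-step information, and on the state being a JAX pytree; `jax.lax.scan` then stacks every field of the state along a leading axis, which is why `states.position` has one row per draw.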
