# Use BlackJAX with PyMC3
Author: Kaustubh Chaudhari

BlackJAX can sample from any log-probability function that JAX can JIT-compile. In this notebook we show how to use PyMC3 as the modeling language and BlackJAX as the inference library.
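To make "a log-probability function that JAX can JIT-compile" concrete, here is a minimal sketch of the Eight Schools log-density (the model used below) written directly with `jax.numpy`, with no PyMC3 involved. It assumes the non-centered parameterization and samples `tau` on the log scale, adding the log-Jacobian by hand. The names `logprob`, `half_cauchy_logpdf` and `position` are illustrative. This is a hand-written stand-in, not the function PyMC3 generates.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Eight Schools data (the same values appear in the data cell below)
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])


def half_cauchy_logpdf(x, scale):
    # Log density of HalfCauchy(scale), valid for x >= 0
    return jnp.log(2.0 / jnp.pi) - jnp.log(scale) - jnp.log1p((x / scale) ** 2)


def logprob(position):
    """Log joint density of the non-centered Eight Schools model.

    `position` is a tuple (mu, log_tau, theta_tilde). Because tau is sampled on
    the log scale, the log-Jacobian of tau = exp(log_tau), i.e. log_tau, is added.
    """
    mu, log_tau, theta_tilde = position
    tau = jnp.exp(log_tau)
    lp = norm.logpdf(mu, 0.0, 10.0)                    # mu ~ Normal(0, 10)
    lp += half_cauchy_logpdf(tau, 5.0) + log_tau       # tau ~ HalfCauchy(5), plus Jacobian
    lp += jnp.sum(norm.logpdf(theta_tilde, 0.0, 1.0))  # theta_tilde ~ Normal(0, 1)
    theta = mu + tau * theta_tilde                     # non-centered school effects
    lp += jnp.sum(norm.logpdf(y, theta, sigma))        # y_j ~ Normal(theta_j, sigma_j)
    return lp


# JIT-compile the log-density and differentiate it with respect to the whole position
logprob_jit = jax.jit(logprob)
position = (jnp.array(0.0), jnp.array(0.0), jnp.zeros(8))
value = logprob_jit(position)
grads = jax.grad(logprob)(position)  # a tuple with the same structure as `position`
```

Any function of this shape (a pytree position in, a scalar out, built only from JAX operations) is something BlackJAX can work with. The rest of the notebook obtains such a function from a PyMC3 model instead of writing it by hand.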

For this notebook to run you will need to install PyMC3:

```bash
pip install pymc3
```

In [1]:
```python
# Theano's JAX backend disables JAX's "omnistaging" mode. Later JAX releases
# removed the option to disable it, which causes errors with Theano, so we pin JAX.
!pip install jax==0.2.10
```


In [2]:
```python
import jax
import numpy as np
import pymc3 as pm
import pymc3.sampling_jax

import blackjax.nuts as nuts
import blackjax.stan_warmup as stan_warmup

print(f"Running on PyMC3 v{pm.__version__}")
```

Running on PyMC3 v3.11.2




## Data

Please refer to the [original TFP example](https://www.tensorflow.org/probability/examples/Eight_Schools) for a description of the problem and the model that is used.

In [3]:
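```python
# Data of the Eight Schools model: J schools, their observed treatment
# effects y, and the known standard errors sigma of those effects
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
```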

## Model


In [5]:
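```python
with pm.Model() as model:
    # Priors on the population mean and scale of the school effects
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", 5.0)

    # Non-centered parameterization: sample standardized effects, then shift and scale
    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
    theta_1 = mu + tau * theta

    # Likelihood: each school's observed effect, with known standard error sigma
    obs = pm.Normal("obs", mu=theta_1, sigma=sigma, shape=J, observed=y)
```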
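Written out, the model cell is the non-centered parameterization of Eight Schools, where the variable `theta` plays the role of $\tilde\theta$ and `theta_1` the role of $\theta$:

$$
\begin{aligned}
\mu &\sim \mathcal{N}(0, 10^2), \qquad \tau \sim \text{HalfCauchy}(5), \qquad \tilde\theta_j \sim \mathcal{N}(0, 1), \\
\theta_j &= \mu + \tau\,\tilde\theta_j, \\
y_j &\sim \mathcal{N}(\theta_j, \sigma_j^2), \qquad j = 1, \dots, 8.
\end{aligned}
$$

Because $\tau$ must be positive, PyMC3 samples $\ell = \log\tau$ instead (the free variable is named `tau_log__`) and adds the log-Jacobian of the transform, $\log p(\ell) = \log p_\tau(e^{\ell}) + \ell$. This is why `tau_log__`, not `tau`, appears among `model.free_RVs` when the model is converted for BlackJAX.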

## Sampling using the PyMC3 NUTS sampler

In [6]:
```python
%%time

with model:
    posterior = pm.sample(50_000, chains=1)
```

  
```text
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [theta, tau, mu]
Sampling 1 chain for 1_000 tune and 50_000 draw iterations (1_000 + 50_000 draws total) took 54 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks

CPU times: user 58.6 s, sys: 560 ms, total: 59.2 s
Wall time: 1min
```


## Sampling using PyMC3's JAX-based NumPyro NUTS sampler

In [None]:
```python
%%time

with model:
    hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(
        50_000, target_accept=0.9, chains=1
    )
```

```text
Compiling...
sample: 100%|██████████| 51000/51000 [02:57<00:00, 286.59it/s]
```


## Sampling using BlackJAX

### Configuring the model for BlackJAX


In [7]:
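```python
from theano.graph.fg import FunctionGraph
from theano.link.jax.jax_dispatch import jax_funcify

seed = jax.random.PRNGKey(1234)
chains = 1

# Build a Theano FunctionGraph whose inputs are the model's free (unobserved,
# possibly transformed) random variables and whose output is the model log-density.
fgraph = FunctionGraph(model.free_RVs, [model.logpt])

# jax_funcify converts the FunctionGraph into a list of JAX functions;
# the first one computes the log-density.
fns = jax_funcify(fgraph)
logp_fn_jax = fns[0]

# Build the initial position from the model's test point, in the same order
# as the inputs of the FunctionGraph.
rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]

# Copy of the initial position with a leading `chains` axis, for running several
# chains. It is not used below, where a single chain starts from `init_state`.
init_state_batched = jax.tree_map(
    lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
)
```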
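To see what the batching in `init_state_batched` does, here is a NumPy-only sketch on a stand-in position with the same structure as `init_state` (scalar `mu`, scalar `tau_log__`, 8-vector `theta`). The values and `chains = 4` are illustrative, not the model's test point. For a flat list of arrays, a list comprehension does the same as `jax.tree_map`.

```python
import numpy as np

# Stand-in initial position: scalar mu, scalar tau_log__, 8-vector theta (illustrative values)
init_state = [np.array(0.0), np.array(1.0), np.zeros(8)]
chains = 4

# x[None, ...] adds a leading axis of length 1; np.repeat tiles it `chains` times,
# so every chain starts from the same position.
init_state_batched = [np.repeat(x[None, ...], chains, axis=0) for x in init_state]

print([x.shape for x in init_state_batched])  # [(4,), (4,), (4, 8)]
```

Each variable gains a leading axis of length `chains`, which is the layout a vectorized multi-chain sampler would expect.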



In [8]:
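```python
# Wrap the log-density so it takes the whole position (a list of arrays, one per
# free variable) as a single argument, which is how BlackJAX calls it, and
# build the initial NUTS state from the model's test point.
logprob = lambda x: logp_fn_jax(*x)
initial_position = init_state
initial_state = nuts.new_state(initial_position, logprob)
```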
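The wrapper matters because NUTS differentiates the log-density with respect to the whole position. A minimal JAX sketch with a hypothetical `toy_logp` (a standard-normal stand-in for `logp_fn_jax`, taking one argument per free variable) shows that `jax.grad` of the wrapped function returns gradients with the same list structure as the position:

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for `logp_fn_jax`: one positional argument per free variable
def toy_logp(mu, log_tau, theta):
    return -0.5 * mu**2 - 0.5 * log_tau**2 - 0.5 * jnp.sum(theta**2)


# Same wrapping as in the notebook: a single position argument (a list of arrays)
# is unpacked into the positional arguments of the underlying function.
toy_logprob = lambda x: toy_logp(*x)

position = [jnp.array(1.0), jnp.array(0.5), jnp.ones(8)]
grads = jax.grad(toy_logprob)(position)

# The gradient is a list with the same structure as `position`
print([g.shape for g in grads])  # [(), (), (8,)]
```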

### Sampling

In [10]:
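```python
%%time

# Use independent PRNG keys for warmup and sampling instead of reusing `seed`
warmup_key, sample_key = jax.random.split(seed)

# Build a NUTS kernel for a given step size and inverse mass matrix
kernel_factory = lambda step_size, inverse_mass_matrix: nuts.kernel(
    logprob, step_size, inverse_mass_matrix
)

# Stan-style window adaptation: 1,000 warmup steps to tune the step size
# and the inverse mass matrix
last_state, (step_size, inverse_mass_matrix), _ = stan_warmup.run(
    warmup_key, kernel_factory, initial_state, 1000
)


def inference_loop(rng_key, kernel, initial_state, num_samples):
    # One NUTS transition per key; lax.scan stacks the returned states and infos
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


# Build the kernel using the step size and inverse mass matrix returned from the window adaptation
kernel = kernel_factory(step_size, inverse_mass_matrix)

# Draw 50,000 samples from the posterior distribution
states, infos = inference_loop(sample_key, kernel, last_state, 50_000)
```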

```text
CPU times: user 18.3 s, sys: 126 ms, total: 18.5 s
Wall time: 18.5 s
```
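`inference_loop` relies on `jax.lax.scan`, which threads the sampler state through the kernel once per key and stacks the per-step outputs. A minimal sketch with a hypothetical `toy_kernel` (a Gaussian random walk standing in for the NUTS kernel) checks that the scan gives the same draws as an explicit Python loop over the same keys; the scan version is traced and compiled once instead of dispatching every step from Python.

```python
import jax
import jax.numpy as jnp


def toy_kernel(rng_key, state):
    # Hypothetical stand-in for a NUTS kernel: one Gaussian random-walk step
    new_state = state + jax.random.normal(rng_key)
    info = new_state > state  # dummy per-step diagnostic
    return new_state, info


def inference_loop(rng_key, kernel, initial_state, num_samples):
    # Same structure as the notebook's inference loop
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


def python_loop(rng_key, kernel, initial_state, num_samples):
    # Equivalent computation with an explicit Python loop over the same keys
    state, states = initial_state, []
    for key in jax.random.split(rng_key, num_samples):
        state, _ = kernel(key, state)
        states.append(state)
    return jnp.stack(states)


key = jax.random.PRNGKey(0)
states_scan, infos_scan = inference_loop(key, toy_kernel, jnp.array(0.0), 5)
states_loop = python_loop(key, toy_kernel, jnp.array(0.0), 5)
print(jnp.allclose(states_scan, states_loop))  # True
```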
