# A quick introduction to Blackjax

In [32]:
import numpy as np
import functools as ft
import blackjax.hmc as hmc
import jax.numpy as jnp
import jax.scipy.stats as stats
import jax

## Generate some data to work with

In [33]:
# True parameters of the Normal distribution the synthetic data is drawn from.
loc, scale = 10, 20
# Draw 100 observations using the named parameters instead of repeating the
# literals, so the data always agrees with `loc` and `scale`.
# NOTE: NumPy's global RNG is not seeded here, so `observed` differs per run.
observed = np.random.normal(loc, scale, size=100)

## Create a potential function

In [34]:
def potential_fn(loc, scale, observed=observed):
    """Negative log-likelihood of `observed` under a Normal(loc, scale) model.

    Evaluates the Normal log-density at every entry of `observed`, sums the
    values and returns the negated sum. `observed` defaults to the
    module-level data that is bound when this function is defined.
    """
    log_likelihood = jnp.sum(stats.norm.logpdf(observed, loc, scale))
    return -log_likelihood

def potential(x):
    """Evaluate `potential_fn` on a position given as a dict.

    `x` is unpacked with `**`, so its keys must match parameter names of
    `potential_fn` (for example "loc" and "scale").
    """
    return potential_fn(**x)


## Set an initial state

In [35]:
# Starting position of the chain: one entry per keyword argument that
# `potential` forwards to `potential_fn`.
initial_position = {
    "loc": 1.0,
    "scale": 2.0,
}
initial_state = hmc.new_state(initial_position, potential)
# Bare expression so the notebook displays the resulting state.
initial_state

HMCState(position={'loc': 1.0, 'scale': 2.0}, potential_energy=DeviceArray(6243.8994, dtype=float32), potential_energy_grad={'loc': DeviceArray(-245.63074, dtype=float32), 'scale': DeviceArray(-6032.691, dtype=float32)})

## Set some sampler parameters

In [36]:
# Length-2 vector of ones used as the inverse mass matrix
# (presumably one entry per sampled parameter, "loc" and "scale" -- confirm
# against the Blackjax documentation).
inv_mass_matrix = np.ones(2)
# Bare expression so the notebook displays the array.
inv_mass_matrix

array([1., 1.])

In [37]:
# Sampler settings, one keyword per line so each value is easy to tweak.
params = hmc.HMCParameters(
    step_size=1e-3,
    num_integration_steps=90,
    inv_mass_matrix=inv_mass_matrix,
)

## Combine both into a kernel

In [38]:
# Build the sampling kernel from the potential function and the sampler
# parameters. It is later called as `kernel(rng_key, state)` and its result is
# unpacked as a pair whose first element is the next state.
kernel = hmc.kernel(potential, params)

## Create an inference loop

In [39]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Apply `kernel` `num_samples` times and collect every visited state.

    Args:
        rng_key: JAX PRNG key; it is split into one sub-key per step.
        kernel: Callable invoked as `kernel(key, state)` that returns a pair
            whose first element is the next state; the second is ignored.
        initial_state: State passed to the first kernel call.
        num_samples: Number of kernel applications to perform.

    Returns:
        The states produced at each step, stacked by `jax.lax.scan`.
    """

    def _advance(current_state, step_key):
        next_state, _info = kernel(step_key, current_state)
        # scan expects (carry, per-step output); the new state serves as both.
        return next_state, next_state

    step_keys = jax.random.split(rng_key, num_samples)
    _final_state, trajectory = jax.lax.scan(_advance, initial_state, step_keys)
    return trajectory

## Run sampling

In [40]:
rng_key = jax.random.PRNGKey(19)

# Total number of draws, and how many of the first draws to discard.
num_samples = 20_000
num_warmup = 5_000
states = inference_loop(rng_key, kernel, initial_state, num_samples)

# Drop the first `num_warmup` draws of each parameter; a single named constant
# keeps both slices in agreement.
loc_samples = states.position["loc"][num_warmup:]
scale_samples = states.position["scale"][num_warmup:]

In [41]:
# Mean of the retained "loc" draws (displayed by the notebook).
loc_samples.mean()

DeviceArray(10.872661, dtype=float32)

In [42]:
# Mean of the retained "scale" draws (displayed by the notebook).
scale_samples.mean()

DeviceArray(20.161758, dtype=float32)