Skip to content

Commit

Permalink
Use blackjax.util.run_inference.
Browse files Browse the repository at this point in the history
We prefer library-provided functions where possible.

PiperOrigin-RevId: 607999912
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Feb 17, 2024
1 parent f4d97ea commit ae6ec6e
Showing 1 changed file with 10 additions and 21 deletions.
31 changes: 10 additions & 21 deletions bayeux/_src/mcmc/blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def _blackjax_adapt(
return last_state, parameters


# TODO(colcarroll): Use blackjax.util.run_inference_algorithm here.
def _blackjax_inference(
seed,
adapt_state,
Expand All @@ -173,24 +172,14 @@ def _blackjax_inference(
"""

algorithm_kwargs = kwargs[algorithm] | adapt_parameters
kernel = algorithm(**algorithm_kwargs).step

@jax.jit
def inference_loop(rng_key):

def one_step(state, rng_key):
state, info = kernel(rng_key, state)
return state, (state, info)

keys = jax.random.split(rng_key, num_draws)
_, (states, infos) = jax.lax.scan(one_step, adapt_state, keys)

return states, infos

# Functions returned by chees adaptation.
adapt_parameters.pop("next_random_arg_fn", None)
adapt_parameters.pop("integration_steps_fn", None)
return inference_loop(seed), adapt_parameters
inference_algorithm = algorithm(**algorithm_kwargs)
_, states, infos = blackjax.util.run_inference_algorithm(
rng_key=seed,
initial_state_or_position=adapt_state,
inference_algorithm=inference_algorithm,
num_steps=num_draws,
progress_bar=False)
return states, infos


def _blackjax_inference_loop(
Expand All @@ -210,7 +199,7 @@ def _blackjax_inference_loop(
adapt_parameters,
algorithm,
num_draws,
kwargs)
kwargs), adapt_parameters


def _blackjax_stats_to_dict(sample_stats, potential_energy, adapt_parameters):
Expand Down Expand Up @@ -362,7 +351,7 @@ def _sample_blackjax_dynamic(
map_seed = jax.random.split(seed, num_chains)
mapped_sampler = shared.map_fn(chain_method, sampler)

(states, stats), adapt_parameters = mapped_sampler(map_seed, adapt_state)
states, stats = mapped_sampler(map_seed, adapt_state)
draws = transform_fn(states.position)
if extra_parameters["return_pytree"]:
return draws
Expand Down

0 comments on commit ae6ec6e

Please sign in to comment.