# Using Different Moves

hemcee lets you choose between Hamiltonian and vanilla (non-Hamiltonian) moves, found in `hemcee.moves.hamiltonian` and `hemcee.moves.vanilla`, respectively.

In our experiments, the walk move (`hmc_walk_move` in the Hamiltonian setting, `walk_move` in the vanilla setting) gives the best performance, so both samplers use it by default.
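If you have not used ensemble moves before, here is a minimal NumPy sketch of the classic vanilla walk move of Goodman & Weare (2010). It is not hemcee's implementation. The function name `walk_move_sweep`, the `subset_size` parameter, and its default of 3 are illustrative assumptions. The ensemble is split into two halves. Each walker in one half proposes a step built from a random subset `S` of the other half, `x' = x + sum_j z_j (x_j - mean_S)` with `z_j ~ N(0, 1)`, and this symmetric proposal is accepted with probability `min(1, p(x') / p(x))`.

```python
import numpy as np


def walk_move_sweep(rng, walkers, log_prob, subset_size=3):
    """One sweep of a Goodman & Weare style walk move (illustrative sketch).

    walkers:     (n_walkers, dim) array; first and second halves form the two groups.
    log_prob:    maps a (dim,) array to a scalar log-density.
    subset_size: hypothetical knob, number of complementary walkers per proposal (>= 2).
    """
    walkers = walkers.copy()
    n = walkers.shape[0]
    half = n // 2
    group_a, group_b = np.arange(half), np.arange(half, n)
    n_accept = 0
    for active, complement in ((group_a, group_b), (group_b, group_a)):
        comp = walkers[complement]  # complementary group is held fixed for this half-sweep
        for k in active:
            idx = rng.choice(len(comp), size=subset_size, replace=False)
            centred = comp[idx] - comp[idx].mean(axis=0)
            # Step = sum_j z_j (x_j - mean_S) with z_j ~ N(0, 1): it follows the subset's spread.
            proposal = walkers[k] + rng.standard_normal(subset_size) @ centred
            # The proposal is symmetric, so a plain Metropolis ratio is enough.
            log_ratio = log_prob(proposal) - log_prob(walkers[k])
            if np.log(rng.uniform()) < log_ratio:
                walkers[k] = proposal
                n_accept += 1
    return walkers, n_accept / n
```

Proposals are built from differences between walkers, so they adapt to the scale and orientation of the target without hand tuning. This is why walk-type moves cope well with ill-conditioned targets.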

```python
import hemcee
import jax
import jax.numpy as jnp
```

As a test problem, we build a highly skewed (ill-conditioned) Gaussian in `dim = 10` dimensions, with `cond_number = 1000`.

```python
from hemcee.tests.distribution import make_gaussian_skewed

key = jax.random.PRNGKey(0)
dim = 10
cond_number = 1000

log_prob = make_gaussian_skewed(key, dim, cond_number)
```
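The internals of `make_gaussian_skewed` are not shown here. The sketch below is a hypothetical stand-in that assumes "skewed" means a zero-mean Gaussian whose covariance has condition number `cond_number`. It builds the covariance by rotating log-spaced eigenvalues with a random orthogonal matrix. The name `make_ill_conditioned_gaussian` is illustrative, and hemcee's helper may be constructed differently.

```python
import numpy as np


def make_ill_conditioned_gaussian(rng, dim, cond_number):
    """Hypothetical stand-in: zero-mean Gaussian with cond(cov) == cond_number."""
    # Random rotation from the QR decomposition of a Gaussian matrix.
    q, r = np.linalg.qr(rng.standard_normal((dim, dim)))
    q = q * np.sign(np.diag(r))  # sign fix so the rotation is uniformly distributed
    # Covariance eigenvalues log-spaced between 1 and cond_number.
    eigvals = np.logspace(0.0, np.log10(cond_number), dim)
    cov = (q * eigvals) @ q.T          # Q diag(eigvals) Q^T
    precision = (q / eigvals) @ q.T    # Q diag(1 / eigvals) Q^T

    def log_prob(x):
        # Unnormalised log-density for a point (dim,) or a batch (..., dim).
        return -0.5 * np.einsum("...i,ij,...j->...", x, precision, x)

    return log_prob, cov
```

With a condition number of 1000, the widest direction has a standard deviation about sqrt(1000) ≈ 32 times that of the narrowest. An isotropic random-walk proposal therefore has to choose between tiny steps and frequent rejections.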

Here are the available Hamiltonian moves and how to swap between them:

```python
from hemcee.moves.hamiltonian.hmc_walk import hmc_walk_move
from hemcee.moves.hamiltonian.hmc_side import hmc_side_move
```
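In this notebook, `hmc_walk_move` and `hmc_side_move` are used as black boxes. To show what a Hamiltonian ensemble update can involve, here is a minimal NumPy sketch of a generic Metropolis-adjusted HMC step whose inverse mass matrix is the sample covariance of the complementary group. This design, the names `leapfrog` and `ensemble_hmc_step`, and the `step_size`/`n_steps` defaults are assumptions for illustration. They do not describe how hemcee's moves work.

```python
import numpy as np


def leapfrog(x, p, grad_log_prob, inv_mass, step_size, n_steps):
    """Leapfrog integration of H(x, p) = -log_prob(x) + 0.5 * p @ inv_mass @ p."""
    p = p + 0.5 * step_size * grad_log_prob(x)        # initial half step in momentum
    for step in range(n_steps):
        x = x + step_size * inv_mass @ p              # full step in position
        if step < n_steps - 1:
            p = p + step_size * grad_log_prob(x)      # full step in momentum
    p = p + 0.5 * step_size * grad_log_prob(x)        # final half step in momentum
    return x, p


def ensemble_hmc_step(rng, x, log_prob, grad_log_prob, complement,
                      step_size=0.3, n_steps=10, jitter=1e-9):
    """One Metropolis-adjusted HMC update of walker `x` (hypothetical design).

    The complementary group's sample covariance serves as the inverse mass matrix,
    so the dynamics follow the target's scale and orientation.
    """
    dim = x.shape[0]
    inv_mass = np.cov(complement, rowvar=False) + jitter * np.eye(dim)
    chol = np.linalg.cholesky(inv_mass)               # inv_mass = chol @ chol.T
    # Momentum p ~ N(0, inv_mass^{-1}), drawn as p = chol^{-T} z with z ~ N(0, I).
    p0 = np.linalg.solve(chol.T, rng.standard_normal(dim))
    x_new, p_new = leapfrog(x, p0, grad_log_prob, inv_mass, step_size, n_steps)

    def hamiltonian(pos, mom):
        return -log_prob(pos) + 0.5 * mom @ inv_mass @ mom

    # Metropolis correction for the leapfrog discretisation error.
    if np.log(rng.uniform()) < hamiltonian(x, p0) - hamiltonian(x_new, p_new):
        return x_new, True
    return x, False
```

The complementary group stays fixed while a walker is updated. This makes each update an ordinary HMC step with a fixed metric, which is the same partial-update idea the vanilla ensemble moves rely on.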

```python
total_chains = dim * 4

sampler = hemcee.HamiltonianEnsembleSampler(
    total_chains=total_chains,
    dim=dim,
    log_prob=log_prob,
    move=hmc_walk_move,  # <- Plug and play different moves here!
)

keys = jax.random.split(key, 2)
initial_states = jax.random.normal(keys[0], shape=(total_chains, dim))

samples = sampler.run_mcmc(
    key=keys[1],
    initial_state=initial_states,
    num_samples=10**5,
    warmup=10**5,
)

# Compare the performance of different moves
# by computing the integrated autocorrelation time.
tau = hemcee.autocorr.integrated_time(samples)
print('Integrated autocorrelation time:')
print(tau)
```

```
Using 40 total chains: Group 1 (20), Group 2 (20)
Integrated autocorrelation time:
[1.3700383 1.3753891 1.3692183 1.3724241 1.3790314 1.364594  1.3714349
 1.3645306 1.3705964 1.3662345]
```
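A lower integrated autocorrelation time τ means less correlated samples. Roughly, N correlated draws carry about as much information as N/τ independent ones, so τ ≈ 1.37 indicates nearly independent draws. The internals and expected input layout of `hemcee.autocorr.integrated_time` are not shown here. The sketch below follows the FFT-plus-Sokal-window estimator popularised by emcee and assumes samples of shape `(n_steps, n_walkers, dim)`. The function names are illustrative.

```python
import numpy as np


def autocorr_1d(x):
    """Normalised autocorrelation function of a 1-D series, computed via FFT."""
    n = len(x)
    size = 2 * n  # zero-pad to avoid circular wrap-around
    f = np.fft.rfft(x - x.mean(), n=size)
    acf = np.fft.irfft(f * np.conj(f), n=size)[:n]
    return acf / acf[0]


def integrated_time(samples, c=5.0):
    """Integrated autocorrelation time per parameter (emcee-style estimator).

    samples: array of shape (n_steps, n_walkers, dim), an assumed layout.
    Uses Sokal's automatic window: the smallest lag M with M >= c * tau(M).
    """
    n_steps, n_walkers, dim = samples.shape
    tau = np.empty(dim)
    for d in range(dim):
        # Average the autocorrelation function over walkers.
        rho = np.mean([autocorr_1d(samples[:, w, d]) for w in range(n_walkers)], axis=0)
        taus = 2.0 * np.cumsum(rho) - 1.0  # tau(M) = 1 + 2 * sum_{t=1}^{M} rho(t)
        window = np.arange(len(taus)) < c * taus
        m = len(taus) - 1 if window.all() else np.argmin(window)  # first lag outside window
        tau[d] = taus[m]
    return tau
```

An AR(1) series with coefficient φ has an exact τ of (1 + φ)/(1 - φ), for example 3 when φ = 0.5. This makes it a convenient sanity check for any τ estimator.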


The same syntax works for vanilla moves:

```python
from hemcee.moves.vanilla.walk import walk_move
from hemcee.moves.vanilla.side import side_move
from hemcee.moves.vanilla.stretch import stretch_move
```
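For comparison with the walk move, here is a sketch of the Goodman & Weare (2010) stretch move. A walker moves along the line through a randomly chosen walker of the other group, with a stretch factor `z` drawn from a density proportional to 1/sqrt(z) on [1/a, a]. The scale `a = 2` is emcee's conventional default. The function name and the two-group layout are illustrative assumptions, and hemcee's `stretch_move` may differ in its details.

```python
import numpy as np


def stretch_move_sweep(rng, walkers, log_prob, a=2.0):
    """One sweep of the affine-invariant stretch move (illustrative sketch)."""
    walkers = walkers.copy()
    n, dim = walkers.shape
    half = n // 2
    group_a, group_b = np.arange(half), np.arange(half, n)
    n_accept = 0
    for active, complement in ((group_a, group_b), (group_b, group_a)):
        comp = walkers[complement]  # complementary group is held fixed for this half-sweep
        for k in active:
            partner = comp[rng.integers(len(comp))]
            # Inverse-CDF draw of z with density proportional to 1/sqrt(z) on [1/a, a].
            z = ((a - 1.0) * rng.uniform() + 1.0) ** 2 / a
            proposal = partner + z * (walkers[k] - partner)
            # z**(dim - 1) compensates for the volume change of the stretch.
            log_ratio = (dim - 1) * np.log(z) + log_prob(proposal) - log_prob(walkers[k])
            if np.log(rng.uniform()) < log_ratio:
                walkers[k] = proposal
                n_accept += 1
    return walkers, n_accept / n
```

Without the `z**(dim - 1)` factor, the move would not satisfy detailed balance. As dimension grows, the factor also makes acceptance increasingly sensitive to `z`.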

```python
sampler = hemcee.EnsembleSampler(
    total_chains=total_chains,
    dim=dim,
    log_prob=log_prob,
    move=walk_move,  # <- Plug and play different moves here!
)

keys = jax.random.split(key, 2)
initial_states = jax.random.normal(keys[0], shape=(total_chains, dim))

samples = sampler.run_mcmc(
    key=keys[1],
    initial_state=initial_states,
    num_samples=10**5,
    warmup=10**5,
)

tau = hemcee.autocorr.integrated_time(samples)
print('Integrated autocorrelation time:')
print(tau)
```


```
Using 40 total chains: Group 1 (20), Group 2 (20)
Integrated autocorrelation time:
[36.823833 37.407433 37.187393 36.25884  36.523792 38.95756  37.184566
 38.007595 35.820908 39.53417 ]
```
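The two printouts can be compared directly. The arithmetic below uses only the τ values reported above. Since roughly N/τ of N correlated draws are effectively independent, the Hamiltonian walk move yields about 27 times more effective samples per iteration than the vanilla walk move on this problem. This is a per-iteration comparison. A Hamiltonian step evaluates the gradient of `log_prob`, typically several times per trajectory, so wall-clock efficiency also depends on the cost of each step.

```python
import numpy as np

# Integrated autocorrelation times copied from the two runs above.
tau_hmc_walk = np.array([1.3700383, 1.3753891, 1.3692183, 1.3724241, 1.3790314,
                         1.364594, 1.3714349, 1.3645306, 1.3705964, 1.3662345])
tau_vanilla_walk = np.array([36.823833, 37.407433, 37.187393, 36.25884, 36.523792,
                             38.95756, 37.184566, 38.007595, 35.820908, 39.53417])

# Effective sample size scales as N / tau, so the ratio of mean taus is the
# ratio of effective samples per iteration (both runs used the same num_samples).
ratio = tau_vanilla_walk.mean() / tau_hmc_walk.mean()
print(f"mean tau (HMC walk):     {tau_hmc_walk.mean():.2f}")
print(f"mean tau (vanilla walk): {tau_vanilla_walk.mean():.2f}")
print(f"effective-sample ratio:  {ratio:.1f}")
```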
