# Analyzing sampling results

In this tutorial, we show how to analyze the sampling results from `flowMC` using `arviz`. In particular, we look at $\hat{R}$ to check whether the chains have converged, and at the effective sample size (ESS) to check how efficiently the sampler produces independent samples.
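To make the $\hat{R}$ diagnostic concrete, here is a minimal NumPy sketch of the classic Gelman-Rubin statistic, which compares the between-chain and within-chain variance. It is only an illustration: `arviz` defaults to a rank-normalized split variant, so its numbers will generally differ. The function name `gelman_rubin_rhat`, the synthetic arrays `mixed` and `stuck`, and the `(n_chains, n_steps, n_dim)` array layout are assumptions made for this example, not part of `flowMC` or `arviz`.

```python
import numpy as np


def gelman_rubin_rhat(chains):
    """Classic (non-split, non-rank-normalized) R-hat per dimension.

    chains: array of shape (n_chains, n_steps, n_dim) -- an assumed layout.
    """
    chains = np.asarray(chains, dtype=float)
    n_steps = chains.shape[1]
    chain_means = chains.mean(axis=1)        # (n_chains, n_dim)
    chain_vars = chains.var(axis=1, ddof=1)  # (n_chains, n_dim)
    within = chain_vars.mean(axis=0)         # W: mean within-chain variance
    between = n_steps * chain_means.var(axis=0, ddof=1)  # B: between-chain variance
    # Pooled estimate of the posterior variance.
    var_hat = (n_steps - 1) / n_steps * within + between / n_steps
    return np.sqrt(var_hat / within)


# Synthetic data only: 20 chains, 500 steps, 5 dimensions.
rng = np.random.default_rng(0)
mixed = rng.normal(size=(20, 500, 5))                # all chains share one target
stuck = mixed + np.arange(20)[:, None, None]         # each chain offset from the others

rhat_mixed = gelman_rubin_rhat(mixed)
rhat_stuck = gelman_rubin_rhat(stuck)
print("well mixed:", rhat_mixed)
print("stuck     :", rhat_stuck)
```

Values close to 1 suggest that the chains agree with one another, while values well above 1 indicate that individual chains are exploring different regions, which is the failure mode to watch for on a multimodal target.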

In [2]:
from flowMC.nfmodel.rqSpline import RQSpline
from flowMC.sampler.MALA import make_mala_sampler

import jax
import jax.numpy as jnp  # JAX NumPy
from jax.scipy.special import logsumexp
import numpy as np

from flowMC.sampler.Sampler import Sampler
from flowMC.utils.PRNG_keys import initialize_rng_keys

from flowMC.nfmodel.utils import *


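def dual_moon_pe(x):
    """
    Unnormalized log-density of the dual-moon distribution.

    Term 1 concentrates the density on a ring of radius 2. Terms 2 and 3
    split the distribution into separate modes and smear it along the
    first and second dimensions.
    """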
    term1 = 0.5 * ((jnp.linalg.norm(x) - 2) / 0.1) ** 2
    term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2
    term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2
    return -(term1 - logsumexp(term2) - logsumexp(term3))


### Demo config

n_dim = 5
n_chains = 20
n_loop_training = 5
n_loop_production = 5
n_local_steps = 100
n_global_steps = 100
learning_rate = 0.01
momentum = 0.9
num_epochs = 30
batch_size = 10000

print("Preparing RNG keys")
rng_key_set = initialize_rng_keys(n_chains, seed=42)

print("Initializing MCMC model and normalizing flow model.")

initial_position = jax.random.normal(rng_key_set[0], shape=(n_chains, n_dim)) * 1
model = RQSpline(n_dim, 10, [128, 128], 8)
local_sampler_caller = lambda x: make_mala_sampler(x, jit=True)

print("Initializing sampler class")

nf_sampler = Sampler(
    n_dim,
    rng_key_set,
    local_sampler_caller,
    {"dt": 1e-1},  # parameters passed to the local MALA sampler
    dual_moon_pe,
    model,
    n_loop_training=n_loop_training,
    n_loop_production=n_loop_production,
    n_local_steps=n_local_steps,
    n_global_steps=n_global_steps,
    n_chains=n_chains,
    n_epochs=num_epochs,
    learning_rate=learning_rate,
    momentum=momentum,
    batch_size=batch_size,
    use_global=True,
)

print("Sampling")

nf_sampler.sample(initial_position)

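# Retrieve the state recorded during the training loops.
summary = nf_sampler.get_sampler_state(training=True)
# Note: this unpacking relies on the key order of the returned dictionary.
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()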

Preparing RNG keys
Initializing MCMC model and normalizing flow model.
Initializing sampler class
Sampling
No autotune found, use input sampler_params
Training normalizing flow


Sampling Locally: 100%|██████████| 99/99 [00:00<00:00, 104.00it/s]
Training NF, current loss: 7.316: 100%|██████████| 30/30 [00:11<00:00,  2.54it/s] 
Sampling Locally: 100%|██████████| 99/99 [00:00<00:00, 394.74it/s]
Training NF, current loss: 5.818: 100%|██████████| 30/30 [00:01<00:00, 24.31it/s]
Sampling Locally: 100%|██████████| 99/99 [00:00<00:00, 302.29it/s]
Training NF, current loss: 5.393: 100%|██████████| 30/30 [00:00<00:00, 32.58it/s]
Sampling Locally: 100%|██████████| 99/99 [00:00<00:00, 312.81it/s]
Training NF, current loss: 5.615: 100%|██████████| 30/30 [00:00<00:00, 33.93it/s]
Sampling Locally: 100%|██████████| 99/99 [00:00<00:00, 412.45it/s]
Training NF, current loss: 5.049: 100%|██████████| 30/30 [00:00<00:00, 40.89it/s]
Sampling Locally: 100%|██████████| 99/99 [00:00<00:00, 270.95it/s]
Sampling Locally: 100%|██████████| 99/99 [00:00<00:00, 293.42it/s]
Sampling Locally: 100%|██████████| 99/99 [00:00<00:00, 421.38it/s]
Sampling Locally: 100%|██████████| 99/99 [00:00<00:00