# Exploration of the jaxns library: nested sampling of a Hybrid Rosenbrock target, compared against i.i.d. reference samples

In [None]:
import sys, os

import jax

# Enable 64-bit precision first: JAX only honours this flag reliably when it is
# set at startup, before any arrays are created, so it goes ahead of every
# other import that may touch JAX (including the project's model module).
# `jax.config.update` is used instead of `from jax.config import config`, an
# import path that newer JAX releases no longer provide. `config` is kept as an
# alias so any later cell that refers to it keeps working.
jax.config.update("jax_enable_x64", True)
config = jax.config

# Make the project-local `models` package importable (parent of the working directory).
sys.path.append("..")
from models.JAXHRD import hybrid_rosenbrock
import numpy as np

import pylab as plt
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
from jax import vmap

from jaxns import DefaultNestedSampler
from jaxns import Model
from jaxns import Prior
from jaxns import bruteforce_evidence
from jaxns import TerminationCondition
from corner import corner
from jaxns.utils import resample

In [None]:
# Hybrid Rosenbrock target; its dimension is n2 * (n1 - 1) + 1 (10 here).
n2, n1 = 3, 4
DoF = n2 * (n1 - 1) + 1
# Coefficient vector handed to the model: 30 for the first entry, 20 for the rest.
B = np.full(DoF, 20.0)
B[0] = 30
mu = 1
model = hybrid_rosenbrock(n2, n1, mu, B, seed=35)

# Reference samples: draw i.i.d. from the model, then keep only the rows that lie
# strictly between model.lower_bound and model.upper_bound in every coordinate
# (the same box the uniform prior of the nested sampler is defined on).
iid_samples = model.newDrawFromPosterior(50000000)
truth_table = (iid_samples > model.lower_bound) & (iid_samples < model.upper_bound)
idx = np.flatnonzero(truth_table.all(axis=1))
print('%i samples obtained from rejection sampling' % len(idx))
bounded_iid_samples = iid_samples[idx]

In [41]:
# Setup sampler
tfpd = tfp.distributions


def prior_model():
    """jaxns prior model with a single parameter vector named ``x``.

    ``x`` is uniform between ``model.lower_bound`` and ``model.upper_bound``.
    The generator yields the ``Prior`` to jaxns and returns the value it sends back.
    """
    prior_x = Prior(tfpd.Uniform(low=model.lower_bound, high=model.upper_bound), name='x')
    x = yield prior_x
    return x


def log_like(x):
    """Log-likelihood handed to jaxns: the negated minus-log-posterior of the HRD model."""
    return -1 * model.getMinusLogPosterior(x)


jaxns_model = Model(prior_model=prior_model, log_likelihood=log_like)

# Nested sampler with default settings (no tuning); max_samples limits how many
# samples the sampler may collect.
ns = DefaultNestedSampler(model=jaxns_model, max_samples=1e6)

# Wrap the sampler in jax.jit; compilation happens on its first call.
ns_jit = jax.jit(ns)

In [42]:
%%time
# Run the jit-wrapped sampler with a fixed PRNG key. This is the first call of
# ns_jit in the notebook, so the reported time presumably includes compilation.
termination_reason, state = ns_jit(random.PRNGKey(420))
# Convert the raw sampler state into a results object (used for the summary and resampling).
results = ns.to_results(termination_reason=termination_reason, state=state)



CPU times: user 41.9 s, sys: 495 ms, total: 42.4 s
Wall time: 36.3 s


In [44]:
# Print the run summary: termination condition, evidence (logZ), ESS and
# per-parameter posterior statistics.
ns.summary(results)

--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 252209
samples: 16430
phantom samples: 13250.0
likelihood evals / sample: 15.4
phantom fraction (%): 80.6%
--------
logZ=-11.79 +- 0.25
H=870.0
ESS=1908.9117258599542
--------
x[#]: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x[0]: 1.022 +- 0.05 | 0.957 / 1.025 / 1.084 | 1.0 | 1.0
x[1]: 1.02 +- 0.12 | 0.86 / 1.02 / 1.17 | 1.01 | 1.01
x[2]: 1.03 +- 0.23 | 0.73 / 1.02 / 1.35 | 1.0 | 1.0
x[3]: 1.11 +- 0.48 | 0.53 / 1.06 / 1.76 | 1.0 | 1.0
x[4]: 1.057 +- 0.091 | 0.934 / 1.058 / 1.176 | 0.993 | 0.993
x[5]: 1.13 +- 0.17 | 0.91 / 1.13 / 1.39 | 0.98 | 0.98
x[6]: 1.33 +- 0.39 | 0.85 / 1.27 / 1.94 | 0.97 | 0.97
x[7]: 1.08 +- 0.11 | 0.93 / 1.09 / 1.23 | 0.94 | 0.94
x[8]: 1.21 +- 0.22 | 0.9 / 1.22 / 1.51 | 0.91 | 0.91
x[9]: 1.53 +- 0.53 | 0.84 / 1.49 / 2.36 | 0.85 | 0.85
--------


In [None]:
# Compare samples to the i.i.d. reference.
# Turn the weighted nested-sampling output into equally weighted samples.
# `replace=True` matters: jaxns' `resample` defaults to drawing *without*
# replacement, which cannot repeat high-weight samples and therefore
# over-represents low-weight ones, biasing the comparison below.
N_IID_PLOT = 30000  # number of i.i.d. reference samples shown in the corner plot
samples_jaxns = np.array(
    resample(
        random.PRNGKey(42),
        results.samples,
        results.log_dp_mean,
        S=int(results.ESS),
        replace=True,
    )['x']
)
print('Number of samples obtained: %i' % samples_jaxns.shape[0])

param_labels = [f'x[{i}]' for i in range(samples_jaxns.shape[1])]
# i.i.d. reference in corner's default colour, jaxns samples overlaid in red.
fig1 = corner(bounded_iid_samples[:N_IID_PLOT], labels=param_labels, hist_kwargs={'density': True})
# Trailing `;` keeps Jupyter from displaying the returned figure a second time.
corner(samples_jaxns, color='r', fig=fig1, labels=param_labels, hist_kwargs={'density': True});