In [1]:
import jax.numpy as jnp
import jax
from jax.config import config
config.update("jax_enable_x64", True)
from jax import jacfwd, jacrev

from jax import jit
import numpy as np
import jax_cosmo as jc
from likelihood import Likelihood
import matplotlib.pyplot as plt

import emcee


In [2]:
lhood = Likelihood()
fidparam = np.array([0.8159,  0.2589,  0.0486,  0.6774,  0.9667, -1., 2., 1.])
symbols = ['$\sigma_{8}$', '$\Omega_{c}$', '$\Omega_{b}$', '$h$', '$n_{s}$', '$w_{0}$', '$b_lbg$', '$b_int$']



Initialising likelihood
Initialisation Complete


In [3]:
def logl_func(p):
    sig8, o_c, o_b, h, n_s, w_0, b_lbg, b_int = p
    if(sig8 < 0.0 or sig8 > 1.1):
        return -np.inf
    elif(o_c < 0.001 or o_c > 0.99):
        return -np.inf
    elif(o_b < 0.001 or o_b > 0.1):
        return -np.inf
    elif(h < 0.1 or h > 1.1):
        return -np.inf
    elif(n_s < 0.1 or n_s > 1.1):
        return -np.inf
    elif(w_0 < -3.0 or w_0 > -0.3):
        return -np.inf
    elif(b_lbg < 0 or b_lbg > 30):
        return -np.inf
    elif(b_int < 0 or b_int > 30):
        return -np.inf
    else:
        return lhood.logLgauss(p)

In [4]:
logl_func(fidparam)

DeviceArray(-2360.69758583, dtype=float64)

In [5]:
ndim, nwalkers = 8, 200

In [6]:
sig8 = np.random.uniform(0.91, 0.92, nwalkers)
o_c = np.random.uniform(0.31, 0.32, nwalkers)
o_b = np.random.uniform(0.06, 0.07, nwalkers)
h = np.random.uniform(0.5, 0.51, nwalkers)
n_s = np.random.uniform(0.91, 0.92, nwalkers)
w0 = np.random.uniform(-1.6, -1.5, nwalkers)
b_lbg = np.random.uniform(5, 6, nwalkers)
b_int = np.random.uniform(5, 6, nwalkers)

In [7]:
p0 = np.transpose(np.vstack([sig8, o_c, o_b, h, n_s, w0, b_lbg, b_int]))

In [8]:
p0.shape

(200, 8)

In [9]:
filename = "mcmc_data.h5"
backend = emcee.backends.HDFBackend(filename)
backend.reset(nwalkers, ndim)

In [10]:
move = emcee.moves.StretchMove(2)

In [None]:
sampler = emcee.EnsembleSampler(nwalkers, ndim, logl_func, moves=move, backend=backend)
sampler.run_mcmc(p0, 10000)

In [None]:
samples = sampler.get_chain(flat=True)

In [None]:
print(
    "Mean acceptance fraction: {0:.3f}".format(
        np.mean(sampler.acceptance_fraction)
    )
)