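Always run this cell first: `numpyro.set_host_device_count` must be called before JAX is initialized, so this cell has to execute before anything else in the notebook: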

In [None]:
# numpyro is imported here, at the very top, even though sampling happens at
# the end: the host device count below must be set before JAX is initialized.
import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS

# Number of CPU cores on this machine; JAX exposes this many host devices so
# that MCMC chains can run in parallel. Adjust to match your hardware.
cpu_cores = 4
numpyro.set_host_device_count(cpu_cores)

This cell sets up the temperature–pressure–wavelength ($T$–$p$–$\lambda$) grid:

In [None]:
import sys
sys.path.insert(0, '../')

import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u
from frei import Planet, Grid, load_example_opacity
from frei.core import F_TOA
from frei.opacity import kappa
from jax import numpy as jnp
from frei.twostream import emit
import numpy as np
from jax.scipy.optimize import minimize
from jax import jit
from functools import partial
from corner import corner

# Define planetary system parameters
planet = Planet.from_hot_jupiter()

# Wavelengths drawn uniformly at random between 1 and 10 micron, then sorted.
# A seeded generator keeps the wavelength grid identical across
# Restart & Run All (the unseeded global RNG changed it on every run).
n_wavelengths = 10_000
wavelength_seed = 42
lam = np.sort(
    np.random.default_rng(wavelength_seed).uniform(1e-6, 10e-6, n_wavelengths)
) * u.m

# Define a grid in wavelength, pressure, and temperature; set the reference
# temperature
grid = Grid(
    planet,
    lam=lam,
    n_layers=10,       # number of pressure layers
    T_ref=2400 * u.K,  # reference temperature at 0.1 bar (~T_eff)
)

# Load synthetic opacities, for demonstration purposes only
grid.load_opacities(
    opacities=load_example_opacity(grid)
);

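The next cell precomputes the opacity $\kappa$ (units as returned by `frei.opacity.kappa`) on an offline grid and stores it in `offline_opacities`, with dimensions $(p, T, \lambda)$.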

In [None]:
n_layers = len(grid.pressures)
n_wavelengths = len(grid.lam)
# NOTE(review): fluxes_down / fluxes_up do not appear to be used by the cells
# below; confirm nothing else relies on them before removing.
fluxes_down = jnp.zeros((n_layers, n_wavelengths))
fluxes_up = jnp.zeros((n_layers, n_wavelengths))

# Flux at the top of the atmosphere (which matters when iterating for
# radiative equilibrium); all zeros here.
F_toa = np.zeros_like(grid.lam.value)

# Temperatures at which the offline opacity grid is tabulated. To use a
# coarser grid instead, substitute e.g. 10 points from jnp.linspace over
# [min, max] of grid.init_temperatures.si.value.
opacity_grid_temperatures = grid.init_temperatures.si.value

# Precompute the opacities once, ahead of fitting; emit_opt receives them via
# its `offline_opacities` argument.
# NOTE(review): pressures are passed here in reversed order ([::-1]), whereas
# emit_opt passes grid.pressures unreversed -- confirm kappa expects this.
offline_opacities = kappa(
    grid.opacities,
    opacity_grid_temperatures,
    grid.pressures.si.value[::-1],
    grid.lam.si.value,
    grid.planet.m_bar.si.value
)

In [None]:
# Grid of pressures in units of bar, used to build the power-law T-p curve
# with a "sensible" alpha (T = T_ref at 0.1 bar):
pressure_bar = jnp.array(grid.pressures.to(u.bar).value)

# NOTE(review): the array-valued defaults are not hashable, so passing them
# positionally would fail under static_argnums; only `p` is passed in this
# notebook, in which case the defaults act as baked-in constants.
@partial(jit, static_argnums=np.arange(1, 10))
def emit_opt(
    p,
    pressures=grid.pressures.si.value,
    lam=grid.lam.to(u.um).value,
    F_TOA=F_toa,
    g=grid.planet.g.si.value,
    m_bar=grid.planet.m_bar.si.value,
    alpha=grid.planet.alpha,
    presure_bar=pressure_bar,  # (sic) name kept for backward compatibility
    offline_opacities=offline_opacities,
    opacity_grid_temperatures=opacity_grid_temperatures
):
    """Emitted spectrum for a power-law temperature-pressure profile.

    Parameters
    ----------
    p : sequence of two scalars
        ``(T_ref, alpha)``: temperature at 0.1 bar [K] and the power-law
        index of ``T = T_ref * (P / 0.1 bar) ** alpha``.
    pressures, lam, F_TOA, g, m_bar, offline_opacities,
    opacity_grid_temperatures
        Forwarded unchanged to ``frei.twostream.emit``.
    alpha
        NOTE(review): this default is never used -- it is shadowed by the
        ``alpha`` unpacked from ``p``, which is what is forwarded to
        ``emit``. Confirm which alpha ``emit`` expects.
    presure_bar : array
        Layer pressures in bar used to evaluate the T-p profile. Previously
        ignored in favour of the global ``pressure_bar``; now honoured.

    Returns
    -------
    ``emit(...)[0][-1]``: last entry of emit's first output; used as a
    spectrum over ``lam`` (presumably the top layer -- confirm).
    """
    T_ref, alpha = p

    # Power-law T-p profile anchored at T_ref at 0.1 bar
    temps = T_ref * jnp.power(presure_bar / 0.1, alpha)

    return emit(
        offline_opacities=offline_opacities,
        temperatures=temps,
        pressures=pressures,
        lam=lam,
        F_TOA=F_TOA,
        g=g,
        m_bar=m_bar,
        alpha=alpha,
        opacity_grid_temperatures=opacity_grid_temperatures,
    )[0][-1]

In [None]:
# Compare spectra for a few (T_ref, alpha) pairs on log-log axes
lam_um = grid.lam.to(u.um).value
fig, ax = plt.subplots()
for T_ref_demo, alpha_demo in [(2400., 0.2), (2400., 0.10), (3000., 0.30)]:
    ax.loglog(
        lam_um, emit_opt([T_ref_demo, alpha_demo]), '.',
        label=rf'$T_\mathrm{{ref}}$ = {T_ref_demo:.0f} K, $\alpha$ = {alpha_demo:g}',
    )
ax.set(xlabel=r'Wavelength [$\mu$m]', ylabel='Flux')
ax.legend();

In [None]:
%%timeit
# JAX dispatches work asynchronously; block_until_ready() makes the timing
# include the actual computation rather than just the dispatch.
emit_opt([2400., 0.105]).block_until_ready()

Here we define the synthetic example spectrum to fit (i.e., the "data"), generated at $T_\mathrm{ref} = 2345$ K and $\alpha = 0.1$:

In [None]:
# Noise-free synthetic "observations" generated at T_ref = 2345 K, alpha = 0.1
# (the values used as `truths` in the corner plot).
y = jnp.array(emit_opt([2345.0, 0.1]))

# Per-point uncertainty scaling as sqrt(flux); the 5e4 prefactor sets the
# noise level. No noise is added to `y` itself.
yerr = 5e4 * jnp.sqrt(y)

# Zero, negative or NaN fluxes would give an invalid scale for the Normal
# likelihood in numpyro_model; fail early with a clear message instead.
assert bool(jnp.all(jnp.isfinite(yerr) & (yerr > 0))), "yerr must be finite and positive"

In [None]:
def numpyro_model(y=y, yerr=yerr):
    """Bayesian model for the (T_ref, alpha) power-law T-p profile.

    Priors are truncated normals: ``alpha`` ~ N(0.1, 0.05) on [0.08, 0.2] and
    ``T_ref`` ~ N(2400, 100) on [0, 3000]. The likelihood is an independent
    Gaussian per point with mean ``emit_opt([T_ref, alpha])`` and standard
    deviation ``yerr``, conditioned on ``y``.

    Note: the defaults bind the global ``y``/``yerr`` when this cell runs, so
    re-run it after regenerating the data.
    """
    alpha_prior = dist.TwoSidedTruncatedDistribution(
        dist.Normal(loc=0.1, scale=0.05),
        low=0.08, high=0.2,
    )
    alpha = numpyro.sample('alpha', alpha_prior)

    T_ref_prior = dist.TwoSidedTruncatedDistribution(
        dist.Normal(loc=2400, scale=100),
        low=0.0, high=3000,
    )
    T_ref = numpyro.sample('T_ref', T_ref_prior)

    model_flux = emit_opt([T_ref, alpha])
    likelihood = dist.Normal(loc=model_flux, scale=yerr)
    numpyro.sample('obs', likelihood, obs=y)

In [None]:
# Random numbers in jax are generated from explicit PRNG keys:
from jax.random import PRNGKey, split

rng_seed = 42

# One chain per host device (see set_host_device_count) and exactly one PRNG
# key per chain. Previously the keys were split by cpu_cores while the chain
# count was hard-coded to 4, which disagree whenever cpu_cores != 4.
num_chains = cpu_cores
rng_keys = split(
    PRNGKey(rng_seed),
    num_chains
)

# Define a sampler, using here the No U-Turn Sampler (NUTS)
# with a dense mass matrix:
sampler = NUTS(
    numpyro_model,
    dense_mass=True
)

# Monte Carlo sampling for a number of steps and parallel chains:
mcmc = MCMC(
    sampler,
    num_warmup=100,
    num_samples=500,
    num_chains=num_chains
)

# Run the MCMC
mcmc.run(rng_keys)

In [None]:
# True parameters used to generate the synthetic data: (T_ref [K], alpha)
truths = [2345, 0.1]

# Stack posterior samples in an explicit column order matching `truths` and
# the labels, instead of relying on the key order of the samples dict.
param_names = ['T_ref', 'alpha']
posterior = mcmc.get_samples()

# make a corner plot
corner(
    np.column_stack([np.asarray(posterior[name]) for name in param_names]),
    labels=[r'$T_\mathrm{ref}$ [K]', r'$\alpha$'],
    truths=truths,
    quiet=True,
);