Skip to content

Commit

Permalink
Switching to stellar spectrum from phoenix, fixing thermal emission p…
Browse files Browse the repository at this point in the history
…arameters
  • Loading branch information
bmorris3 committed Oct 15, 2021
1 parent 14a2320 commit 472fa2b
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 78 deletions.
5 changes: 3 additions & 2 deletions jaxon/hatp7.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def get_planet_params():
a_rp = float(a / rp)
rp_rstar = float(rp / rstar)
eclipse_half_dur = duration / period / 2

mstar = 1.56
mass = 1.84
return (
planet_name, a_rs, a_rp, T_s, rprs, t0, period, eclipse_half_dur, b,
rstar, rho_star, rp_rstar
rstar.value, rho_star, rp_rstar, mstar, mass
)
16 changes: 8 additions & 8 deletions jaxon/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
cadence_duration = 30 * u.min


def get_light_curve(cadence=cadence):
def get_light_curve(quarter=None, cadence=cadence):
"""
Parameters
----------
cadence : str {'long', 'short'}
Kepler cadence mode
"""
(planet_name, a_rs, a_rp, T_s, rprs, t0, period, eclipse_half_dur, b,
rstar, rho_star, rp_rstar) = get_planet_params()
rstar, rho_star, rp_rstar, mstar, mass) = get_planet_params()

lcf = search_lightcurve(
planet_name, mission="Kepler", cadence=cadence
# quarter=10
planet_name, mission="Kepler", cadence=cadence,
quarter=quarter
).download_all()

slc = lcf.stitch()
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_filter():
return filt_wavelength, filt_trans


def eclipse_model(cadence_duration=cadence_duration):
def eclipse_model(quarter=None, cadence_duration=cadence_duration):
"""
Compute the (static) eclipse model
Expand All @@ -110,14 +110,14 @@ def eclipse_model(cadence_duration=cadence_duration):
in-eclipse.
"""
(planet_name, a_rs, a_rp, T_s, rprs, t0, period, eclipse_half_dur, b,
rstar, rho_star, rp_rstar) = get_planet_params()
phase, time, flux_normed, flux_normed_err = get_light_curve()
rstar, rho_star, rp_rstar, mstar, mass) = get_planet_params()
phase, time, flux_normed, flux_normed_err = get_light_curve(quarter=quarter)

with pm.Model():
# Define a Keplerian orbit using `exoplanet`:
orbit = xo.orbits.KeplerianOrbit(
period=period, t0=0, b=b, rho_star=rho_star.to(u.g / u.cm ** 3),
r_star=float(rstar / u.R_sun)
r_star=rstar
)

# Compute the eclipse model (no limb-darkening):
Expand Down
172 changes: 115 additions & 57 deletions jaxon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from .thermal import thermal_phase_curve
from .tp import get_Tarr, polynomial_order, element_number, Parr, dParr
from .spectrum import (
exojax_spectrum, res_vis, nus, wav, bb_star_transformed
# cnu_TiO, indexnu_TiO
exojax_spectrum, res_vis, nus, wav, stellar_spectrum, stellar_spectrum_vis
)
from .hatp7 import (
get_observed_depths, get_planet_params
Expand All @@ -24,37 +23,34 @@

__all__ = [
'model',
'run_mcmc'
'run_mcmc',
'get_model_kwargs'
]

filt_wavelength, filt_trans = get_filter()
phase, time, flux_normed, flux_normed_err = get_light_curve()

(planet_name, a_rs, a_rp, T_s, rprs, t0, period, eclipse_half_dur, b,
rstar, rho_star, rp_rstar) = get_planet_params()

(all_depths, all_depths_errs, all_wavelengths,
kepler_mean_wl) = get_observed_depths()

model_kwargs = dict(
phase=phase.astype(floatX),
time=(time - time.mean()).astype(floatX),
y=flux_normed.astype(floatX),
yerr=flux_normed_err.astype(floatX),
eclipse_numpy=jnp.array(eclipse_model()).astype(floatX),
filt_wavelength=jnp.array(filt_wavelength.astype(floatX)),
filt_trans=jnp.array(filt_trans.astype(floatX)),
a_rs=a_rs, a_rp=a_rp, T_s=T_s,
n_temps=polynomial_order * element_number + 1,
res=res_vis
)

def estimate_ellipsoidal_amplitude(mass, rstar, mstar, period):
ellipsoidal_amplitude_estimate = (
mass / 0.077 * rstar ** 3 * mstar ** -2 * period ** -2
)
return ellipsoidal_amplitude_estimate


def estimate_doppler_amplitude(mass, mstar, period):
doppler_amplitude_estimate = (
mass / 0.37 * mstar**(-2/3) * period**(-1/3)
)
return doppler_amplitude_estimate



def model(
n_temps, phase, time, y, yerr, eclipse_numpy,
filt_wavelength, filt_trans, a_rs, a_rp, T_s,
nus=nus, wav=wav, Parr=Parr, dParr=dParr,
bb_star_transformed=bb_star_transformed,
filt_wavelength, filt_trans, a_rs, a_rp, T_s, rprs,
mstar, mass, period, rstar, nus=nus, wav=wav, Parr=Parr, dParr=dParr,
stellar_spectrum=stellar_spectrum,
res=res_vis,
predict=False
):
Expand Down Expand Up @@ -85,6 +81,16 @@ def model(
Semimajor axis normalized by the planetary radius
T_s : float
Stellar effective temperature
rprs : float
Radius ratio
mstar : float
Stellar mass in solar masses
mass : float
Planet mass in Jupiter masses
period : float
Orbital period [d]
rstar : float
Stellar radius in solar radii
nus : numpy.ndarray
Frequencies sampled in the spectrum
wav : numpy.ndarray
Expand All @@ -93,8 +99,8 @@ def model(
Pressure array at each temperature in the T-P profile
dParr : numpy.ndarray
Delta pressure array at each temperature in the T-P profile
bb_star_transformed : numpy.ndarray
Planck function of the star
stellar_spectrum_vis : numpy.ndarray
Spectrum of the star
res : float
Spectral resolution
predict : bool
Expand Down Expand Up @@ -125,73 +131,97 @@ def model(
reflected_ppm = jnp.interp(phase, phases_grid, reflected_ppm_grid)

numpyro.deterministic('A_g', A_g)
# numpyro.deterministic('q', q)

ellipsoidal_amp_estimate = estimate_ellipsoidal_amplitude(
mass, rstar, mstar, period
)
doppler_amp_estimate = estimate_doppler_amplitude(
mass, mstar, period
)

# Define the ellipsoidal variation parameterization (simple sinusoid)
ellipsoidal_amp = numpyro.sample(
'ellip_amp', dist.Uniform(low=0, high=100)
'ellip_amp',
dist.TwoSidedTruncatedDistribution(
dist.Normal(
loc=ellipsoidal_amp_estimate,
scale=ellipsoidal_amp_estimate/4
), low=0, high=100
)
)
ellipsoidal_model_ppm = - ellipsoidal_amp * jnp.cos(
4 * np.pi * (phase - 0.5)) + ellipsoidal_amp

# Define the doppler variation parameterization (simple sinusoid)
doppler_amp = numpyro.sample('doppler_amp', dist.Uniform(low=0, high=20))
doppler_model_ppm = doppler_amp * jnp.sin(
2 * np.pi * phase)
doppler_amp = numpyro.sample(
'doppler_amp',
dist.TwoSidedTruncatedDistribution(
dist.Normal(
loc=doppler_amp_estimate,
scale=doppler_amp_estimate/4
), low=0, high=10
)
)
doppler_model_ppm = doppler_amp * jnp.sin(2 * np.pi * phase)

# Define the thermal emission model according to description in
# Morris et al. (in prep)
# floatX = 'float32'
n_phi = 75
n_theta = 7
n_phi = 150
n_theta = 10
phi = jnp.linspace(-2 * np.pi, 2 * np.pi, n_phi, dtype=floatX)
theta = jnp.linspace(0, np.pi, n_theta, dtype=floatX)
theta2d, phi2d = jnp.meshgrid(theta, phi)

# ln_C_11_kepler = -2.6
C_11_kepler = numpyro.sample('C_11', dist.Uniform(low=0, high=0.5))
hml_eps = numpyro.sample('epsilon', dist.Uniform(low=0, high=8 / 5))
hml_f = (2 / 3 - hml_eps * 5 / 12) ** 0.25
delta_phi = numpyro.sample('delta_phi', dist.Uniform(low=-np.pi, high=0))

C_11_kepler = 0.35 # numpyro.sample('C_11', dist.Uniform(low=0, high=0.55))
# hml_eps = numpyro.sample('epsilon', dist.Uniform(low=0, high=8 / 5))
hml_f = 0.73 #(2 / 3 - hml_eps * 5 / 12) ** 0.25
delta_phi = 0 #numpyro.sample(
# 'delta_phi',
# dist.TwoSidedTruncatedDistribution(
# dist.Normal(loc=0, scale=0.05),
# low=-np.pi/4, high=np.pi/4
# )
# )
A_B = 0.0

# Compute the thermal phase curve with zero phase offset
thermal_grid, temp_map = thermal_phase_curve(
xi_grid, delta_phi, 4.5, 0.575, C_11_kepler, T_s, a_rs, 1 / a_rp, A_B,
xi_grid, delta_phi, 4.5, 0.6, C_11_kepler, T_s, a_rs, 1 / a_rp, A_B,
theta2d, phi2d, filt_wavelength, filt_trans, hml_f
)

# thermal = interpolate(xi_grid, 1e6 * thermal_grid, xi)
thermal = jnp.interp(xi, xi_grid, 1e6 * thermal_grid)

# epsilon = 8 * nightside**4 / (3 * dayside**4 + 5 * nightside**4)
f = (2 / 3 - hml_eps * 5 / 12) ** 0.25
# f = (2 / 3 - hml_eps * 5 / 12) ** 0.25

numpyro.deterministic('f', f)
# numpyro.deterministic('f', f)
# numpyro.deterministic('epsilon', epsilon)

# Define the composite phase curve model
flux_norm = (eclipse_numpy *
(reflected_ppm + thermal) +
doppler_model_ppm + ellipsoidal_model_ppm
)
(reflected_ppm + thermal) + doppler_model_ppm + ellipsoidal_model_ppm
)

flux_norm -= jnp.mean(flux_norm)

sigma = numpyro.sample(
"sigma", dist.TwoSidedTruncatedDistribution(
dist.Normal(loc=y.ptp(), scale=y.std()), low=0, high=4 * y.ptp()
dist.Normal(loc=y.std(), scale=y.std()/10),
low=0, high=1000 * y.ptp()
)
)
kernel = terms.Matern32Term(sigma=sigma, rho=22)
jitter = 0 # numpyro.sample('jitter', dist.Uniform(low=0, high=y.ptp()))
jitter = numpyro.sample('jitter', dist.Uniform(low=0, high=100))
gp = GaussianProcess(kernel, mean=flux_norm)
gp.compute(time, yerr=jnp.sqrt(yerr ** 2 + jitter ** 2),
check_sorted=False)

if predict:
gp.condition(y)
pred = gp.predict(time)
pred = gp.predict(y)
numpyro.deterministic("therm", thermal)
numpyro.deterministic("ellip", ellipsoidal_model_ppm)
numpyro.deterministic("doppl", doppler_model_ppm)
Expand All @@ -200,18 +230,18 @@ def model(
numpyro.deterministic("resid", y - pred)
numpyro.deterministic("pred", pred)

log_vmr_prod = numpyro.sample('log_vmr_prod',
dist.Uniform(low=-10, high=-4))

mmr_TiO = numpyro.sample("mmr_TiO", dist.Uniform(low=-9, high=-2))
# log_vmr_prod = numpyro.sample('log_vmr_prod',
# dist.Uniform(low=-10, high=-4))
vmr_prod = 1e-6
mmr_TiO = 1e-6 #numpyro.sample("mmr_TiO", dist.Uniform(low=-9, high=-2))

Tarr = get_Tarr(temps, Parr)
Fcgs, _, _ = exojax_spectrum(
temps, jnp.power(10, log_vmr_prod), jnp.power(10, mmr_TiO),
Parr, dParr, nus, wav, res
temps, vmr_prod, mmr_TiO,
Parr, dParr, nus, wav
)

fpfs_spectrum = rprs ** 2 * Fcgs / bb_star_transformed.value
fpfs_spectrum = rprs ** 2 * Fcgs / stellar_spectrum

interp_depths = jnp.interp(
all_wavelengths, wav / 1000, fpfs_spectrum
Expand Down Expand Up @@ -247,7 +277,31 @@ def model(
)


def run_mcmc(run_title='tmp', num_warmup=5, num_samples=10):
def get_model_kwargs(quarter=None):
phase, time, flux_normed, flux_normed_err = get_light_curve(quarter=quarter)

filt_wavelength, filt_trans = get_filter()

(planet_name, a_rs, a_rp, T_s, rprs, t0, period, eclipse_half_dur, b,
rstar, rho_star, rp_rstar, mstar, mass) = get_planet_params()

model_kwargs = dict(
phase=phase.astype(floatX),
time=(time - time.mean()).astype(floatX),
y=flux_normed.astype(floatX),
yerr=flux_normed_err.astype(floatX),
eclipse_numpy=jnp.array(eclipse_model(quarter=quarter)).astype(floatX),
filt_wavelength=jnp.array(filt_wavelength.astype(floatX)),
filt_trans=jnp.array(filt_trans.astype(floatX)),
a_rs=a_rs, a_rp=a_rp, T_s=T_s, period=period,
mass=mass, mstar=mstar, rstar=rstar,
n_temps=polynomial_order * element_number + 1,
res=res_vis, rprs=rprs
)
return model_kwargs


def run_mcmc(run_title='tmp', num_warmup=5, num_samples=10, quarter=10):
"""
Run MCMC with the NUTS via numpyro.
Expand All @@ -259,8 +313,12 @@ def run_mcmc(run_title='tmp', num_warmup=5, num_samples=10):
Number of iterations in the burn-in phase
num_samples : int
Number of iterations of the sampler
quarter : int, list of ints
Kepler quarters to fit
"""
print('Start MCMC')
model_kwargs = get_model_kwargs(quarter=quarter)

print(f'Start MCMC, n chains = {len(jax.devices())}')
mcmc = MCMC(
sampler=NUTS(
model, dense_mass=True,
Expand Down

0 comments on commit 472fa2b

Please sign in to comment.