Skip to content

Commit

Permalink
first code
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Oct 6, 2021
1 parent 1058a92 commit 0bac7c1
Show file tree
Hide file tree
Showing 10 changed files with 1,762 additions and 39 deletions.
50 changes: 11 additions & 39 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,44 +1,16 @@
Phase curve models in JAX
-------------------------

License
-------

This project is Copyright (c) Brett M. Morris and licensed under
the terms of the GNU GPL v3+ license. This package is based upon
the `Openastronomy packaging guide <https://github.com/OpenAstronomy/packaging-guide>`_
which is licensed under the BSD 3-clause licence. See the licenses folder for
more information.

jaxon
=====

Contributing
------------
.. image:: https://readthedocs.org/projects/jaxon/badge/?version=latest
:target: https://jaxon.readthedocs.io/en/latest/?badge=latest
:alt: Documentation Status

We love contributions! jaxon is open source,
built on open source, and we'd love to have you hang out in our community.
.. image:: http://img.shields.io/badge/powered%20by-AstroPy-orange.svg?style=flat
:target: http://www.astropy.org
:alt: Powered by Astropy Badge

**Imposter syndrome disclaimer**: We want your help. No, really.
.. image:: https://github.com/bmorris3/jaxon/workflows/CI%20Tests/badge.svg
:target: https://github.com/bmorris3/jaxon/actions

There may be a little voice inside your head that is telling you that you're not
ready to be an open source contributor; that your skills aren't nearly good
enough to contribute. What could you possibly offer a project like this one?

We assure you - the little voice in your head is wrong. If you can write code at
all, you can contribute code to open source. Contributing to open source
projects is a fantastic way to advance one's coding skills. Writing perfect code
isn't the measure of a good developer (that would disqualify all of us!); it's
trying to create something, making mistakes, and learning from those
mistakes. That's how we all improve, and we are happy to help others learn.

Being an open source contributor doesn't just mean writing code, either. You can
help out by writing documentation, tests, or even giving feedback about the
project (and yes - that includes giving feedback about the contribution
process). Some of these contributions may be the most valuable to the project as
a whole, because you're coming to the project with fresh eyes, so you can see
the errors and assumptions that seasoned contributors have glossed over.
Phase curve models in JAX

Note: This disclaimer was originally written by
`Adrienne Lowe <https://github.com/adriennefriend>`_ for a
`PyCon talk <https://www.youtube.com/watch?v=6Uj746j9Heo>`_, and was adapted by
jaxon based on its use in the README file for the
`MetPy project <https://github.com/Unidata/MetPy>`_.
157 changes: 157 additions & 0 deletions jaxon/continuum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from jax import jit, numpy as jnp, lax

CONST_K, CONST_C, CONST_H = 1.380649e-16, 29979245800.0, 6.62607015e-27 # cgs

@jit
def log_hminus_continuum(wavelength_um, temperature, pressure,
volume_mixing_ratio_product, truncation_value=-100):

# first, compute the cross sections (in cm4/dyne)
kappa_bf = bound_free_absorption(wavelength_um, temperature)
kappa_ff = free_free_absorption(wavelength_um, temperature)

absorption_coeff = (
(kappa_bf + kappa_ff) * volume_mixing_ratio_product *
jnp.atleast_2d(pressure).T
)
truncate_small = jnp.where(
absorption_coeff != 0,
jnp.log10(absorption_coeff),
truncation_value
)

return truncate_small


@jit
def bound_free_absorption(wavelength_um, temperature):
# Note: alpha has a value of 1.439e4 micron-1 K-1, the value stated in John (1988) is wrong
# here, we express alpha using physical constants
alpha = CONST_C * CONST_H / CONST_K * 10000.0
lambda_0 = 1.6419 # photo-detachment threshold

# //tabulated constant from John (1988)
def f(wavelength_um):
C_n = jnp.vstack(
[jnp.arange(1, 7),
[152.519, 49.534, -118.858, 92.536, -34.194, 4.982]]
).T

def body_fun(val, x):
i, C_n_i = x
base = jnp.where(1.0 / wavelength_um - 1.0 / lambda_0 > 0,
1.0 / wavelength_um - 1.0 / lambda_0, 0)
return val, val + C_n_i * jnp.power(base, (i - 1) / 2.0)

return lax.scan(body_fun, jnp.zeros_like(wavelength_um), C_n)[-1].sum(
0)

# first, we calculate the photo-detachment cross-section (in cm2)
kappa_bf = (1e-18 * wavelength_um ** 3 *
jnp.power(
jnp.clip(1.0 / wavelength_um - 1.0 / lambda_0, a_min=0,
a_max=None), 1.5) * f(wavelength_um)
)

kappa_bf = jnp.where(
(wavelength_um <= lambda_0) & (wavelength_um > 0.125),
(0.750 * jnp.power(temperature, -2.5) * jnp.exp(
alpha / lambda_0 / temperature) *
(1.0 - jnp.exp(-alpha / wavelength_um / temperature)) * kappa_bf),
0
)
return kappa_bf


@jit
def free_free_absorption(wavelength_um, temperature):
# coefficients from John (1988)
# to follow his notation (which starts at an index of 1), the 0-index components are 0
# for wavelengths larger than 0.3645 micron
A_n1 = jnp.array([2483.3460, -3449.8890, 2200.0400, -696.2710, 88.2830])
B_n1 = jnp.array([285.8270, -1158.3820, 2427.7190, -1841.4000, 444.5170])
C_n1 = jnp.array(
[-2054.2910, 8746.5230, -13651.1050, 8624.9700, -1863.8650])
D_n1 = jnp.array(
[2827.7760, -11485.6320, 16755.5240, -10051.5300, 2095.2880])
E_n1 = jnp.array([-1341.5370, 5303.6090, -7510.4940, 4400.0670, -901.7880])
F_n1 = jnp.array([208.9520, -812.9390, 1132.7380, -655.0200, 132.9850])

# for wavelengths between 0.1823 micron and 0.3645 micron
A_n2 = jnp.array([518.1021, 473.2636, -482.2089, 115.5291, 0.0, 0.0])
B_n2 = jnp.array([-734.8666, 1443.4137, -737.1616, 169.6374, 0.0, 0.0])
C_n2 = jnp.array([1021.1775, -1977.3395, 1096.8827, -245.6490, 0.0, 0.0])
D_n2 = jnp.array([-479.0721, 922.3575, -521.1341, 114.2430, 0.0, 0.0])
E_n2 = jnp.array([93.1373, -178.9275, 101.7963, -21.9972, 0.0, 0.0])
F_n2 = jnp.array([-6.4285, 12.3600, -7.0571, 1.5097, 0.0, 0.0])

coeffs1 = jnp.vstack([
jnp.arange(2, 7), A_n1, B_n1, C_n1, D_n1, E_n1, F_n1
]).T

coeffs2 = jnp.vstack([
jnp.arange(1, 7), A_n2, B_n2, C_n2, D_n2, E_n2, F_n2
]).T

def body_fun(val, x):
i, A_n_i, B_n_i, C_n_i, D_n_i, E_n_i, F_n_i = x
return val, val + (jnp.power(5040.0 / temperature, (i + 1) / 2.0) *
(wavelength_um ** 2 * A_n_i + B_n_i + C_n_i /
wavelength_um + D_n_i / wavelength_um ** 2 +
E_n_i / wavelength_um ** 3 + F_n_i /
wavelength_um ** 4))

kappa_ff = jnp.where(
wavelength_um > 0.3645,
lax.scan(body_fun, jnp.zeros_like(wavelength_um), coeffs1)[-1].sum(
0) * 1e-29,
0
) + jnp.where(
(wavelength_um >= 0.1823) & (wavelength_um <= 0.3645),
lax.scan(body_fun, jnp.zeros_like(wavelength_um), coeffs2)[-1].sum(
0) * 1e-29,
0
)

return kappa_ff


@jit
def dtauHminusCtm(nus, Tarr, Parr, dParr, volume_mixing_ratio_product, mmw, g):
"""dtau of the H- continuum
Args:
nus: wavenumber matrix (cm-1)
Tarr: temperature array (K)
Parr: temperature array (bar)
dParr: delta temperature array (bar)
volume_mixing_ratio_product: number density for e- times number density for H [N_layer]
mmw: mean molecular weight of atmosphere
g: gravity (cm2/s)
nucia: wavenumber array for CIA
tcia: temperature array for CIA
logac: log10(absorption coefficient of CIA)
Returns:
optical depth matrix [N_layer, N_nus]
Note:
logm_ucgs=np.log10(m_u*1.e3) where m_u = scipy.constants.m_u.
"""
kB = 1.380649e-16
logm_ucgs = -23.779750909492115

narr = (Parr * 1.e6) / (kB * Tarr)
lognarr1 = jnp.log10(narr) # log number density

logkb = jnp.log10(kB)
logg = jnp.log10(g)
ddParr = dParr / Parr
wavelength_um = 1e4 / nus[::-1]
dtauctm = (10 ** (log_hminus_continuum(wavelength_um, Tarr[:, None], Parr,
volume_mixing_ratio_product)
+ lognarr1[:, None] + logkb - logg - logm_ucgs)
* Tarr[:, None] / mmw * ddParr[:, None])

return dtauctm
78 changes: 78 additions & 0 deletions jaxon/hatp7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import astropy.units as u
from astropy.constants import G
from astropy.modeling.models import BlackBody

from .spectrum import wav, wav_vis

planet_name = "HAT-P-7"

# Mansfield 2018
lines = """1.120–1.158 0.0334±0.0037
1.158–1.196 0.0413±0.0038
1.196–1.234 0.0404±0.0037
1.234–1.271 0.0501±0.0037
1.271–1.309 0.0503±0.0038
1.309–1.347 0.0498±0.0037
1.347–1.385 0.0530±0.0037
1.385–1.423 0.0510±0.0037
1.423–1.461 0.0547±0.0039
1.461–1.499 0.0621±0.0041
1.499–1.536 0.0607±0.0042
1.536–1.574 0.0593±0.0044
1.574–1.612 0.0594±0.0046
1.612–1.650 0.0593±0.0045""".splitlines()
central_wl = ([np.array(list(map(float, line.split(' ')[0].split('–')))).mean()
for line in lines]) * u.um
depths = np.array([float(line.split(' ')[1].split('±')[0][:-1])
for line in lines]) * 1e-2
depths_err = np.array([float(line.split(' ')[1].split('±')[1][1:])
for line in lines]) * 1e-2

spitzer_wl = [3.6, 4.5]
spitzer_depth = 1e-2 * np.array([0.161, 0.186])
spitzer_depth_err = 1e-2 * np.array([0.014, 0.008])

# plt.plot(
# central_wl, depths, 'o'
# )
# plt.plot(
# spitzer_wl, spitzer_depth, 'o'
# )

kepler_mean_wl = [0.641] # um
kepler_depth = [19e-6] # eyeballed
kepler_depth_err = [10e-6]

all_depths = np.concatenate([depths, spitzer_depth])
all_depths_errs = np.concatenate([depths_err, spitzer_depth_err])
all_wavelengths = np.concatenate([central_wl.value, spitzer_wl])

bb_star = BlackBody(temperature=6300*u.K)

bb_star_transformed = (bb_star(wav*u.nm)).to(
u.erg/u.s/u.cm**2/u.Hz/u.sr, u.spectral_density(wav*u.nm)
) * np.pi

bb_star_transformed_vis = (bb_star(wav_vis*u.nm)).to(
u.erg/u.s/u.cm**2/u.Hz/u.sr, u.spectral_density(wav_vis*u.nm)
) * np.pi

rprs = float(1.431*u.R_jup / (2.00 * u.R_sun))

g = (G * u.M_jup / u.R_jup**2).to(u.cm/u.s**2).value

t0 = 2454954.357462 # Bonomo 2017
period = 2.204740 # Stassun 2017
rp = 16.9 * u.R_earth # Stassun 2017
rstar = 1.991 * u.R_sun # Berger 2017
a = 4.13 * rstar # Stassun 2017
duration = 4.0398 / 24 # Holczer 2016
b = 0.4960 # Esteves 2015
rho_star = 0.27 * u.g / u.cm ** 3 # Stassun 2017
T_s = 6449 # Berger 2018

a_rs = float(a / rstar)
a_rp = float(a / rp)
rp_rstar = float(rp / rstar)
eclipse_half_dur = duration / period / 2
83 changes: 83 additions & 0 deletions jaxon/lightcurve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import numpy as np
from lightkurve import search_lightcurve
from astropy.stats import sigma_clip, mad_std
import astropy.units as u
from kelp import Filter
import exoplanet as xo
import pymc3 as pm
import pymc3_ext as pmx

from .utils import floatX
from .hatp7 import (
planet_name, t0, period, eclipse_half_dur, b, rstar, rho_star, rp_rstar
)

lcf = search_lightcurve(
planet_name, mission="Kepler", cadence="long"#, quarter=10 #[10, 11, 12]
).download_all()

slc = lcf.stitch()

phases = ((slc.time.jd - t0) % period) / period
in_eclipse = np.abs(phases - 0.5) < 1.5 * eclipse_half_dur
in_transit = (phases < 1.5 * eclipse_half_dur) | (
phases > 1 - 1.5 * eclipse_half_dur)
out_of_transit = np.logical_not(in_transit)# | in_eclipse)

slc = slc.flatten(
polyorder=3, break_tolerance=10, window_length=1001, mask=~out_of_transit
).remove_nans()

phases = ((slc.time.jd - t0) % period) / period
in_eclipse = np.abs(phases - 0.5) < 1.5 * eclipse_half_dur
in_transit = (phases < 1.5 * eclipse_half_dur) | (
phases > 1 - 1.5 * eclipse_half_dur)
out_of_transit = np.logical_not(in_transit)# | in_eclipse)

sc = sigma_clip(
np.ascontiguousarray(slc.flux[out_of_transit], dtype=floatX),
maxiters=100, sigma=8, stdfunc=mad_std
)

phase = np.ascontiguousarray(
phases[out_of_transit][~sc.mask], dtype=floatX
)
time = np.ascontiguousarray(
slc.time.jd[out_of_transit][~sc.mask], dtype=floatX
)

bin_in_eclipse = np.abs(phase - 0.5) < eclipse_half_dur
unbinned_flux_mean = np.mean(sc[~sc.mask].data) # .mean()

unbinned_flux_mean_ppm = 1e6 * (unbinned_flux_mean - 1)
flux_normed = np.ascontiguousarray(
1e6 * (sc[~sc.mask].data / unbinned_flux_mean - 1.0), dtype=floatX
)
flux_normed_err = np.ascontiguousarray(
1e6 * slc.flux_err[out_of_transit][~sc.mask].value, dtype=floatX
)

filt = Filter.from_name("Kepler")
filt.bin_down(4) # This speeds up integration by orders of magnitude
filt_wavelength, filt_trans = filt.wavelength.to(u.m).value, filt.transmittance

with pm.Model() as model:
# Define a Keplerian orbit using `exoplanet`:
orbit = xo.orbits.KeplerianOrbit(
period=period, t0=0, b=b, rho_star=rho_star.to(u.g / u.cm ** 3),
r_star=float(rstar / u.R_sun)
)

# Compute the eclipse model (no limb-darkening):
eclipse_light_curves = xo.LimbDarkLightCurve([0, 0]).get_light_curve(
orbit=orbit._flip(rp_rstar), r=orbit.r_star,
t=phase * period,
texp=(30 * u.min).to(u.d).value
)

# Normalize the eclipse model to unity out of eclipse and
# zero in-eclipse
eclipse = 1 + pm.math.sum(eclipse_light_curves, axis=-1)

eclipse_numpy = pmx.eval_in_model(eclipse)

0 comments on commit 0bac7c1

Please sign in to comment.