Skip to content

Commit

Permalink
Adding docs placeholders
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Oct 6, 2021
1 parent b6b0df3 commit f240c0c
Show file tree
Hide file tree
Showing 13 changed files with 154 additions and 83 deletions.
12 changes: 11 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
Expand Down Expand Up @@ -54,7 +55,12 @@
# -- Options for intersphinx extension ---------------------------------------

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'https://docs.python.org/': None}
intersphinx_mapping = {
'https://docs.python.org/': None,
'numpy': ('https://numpy.org/doc/stable/', None),
'celerite2': ('https://celerite2.readthedocs.io/en/latest/', None),
'astropy': ('https://docs.astropy.org/en/stable/', None)
}

# -- Options for HTML output -------------------------------------------------

Expand All @@ -66,3 +72,7 @@
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ['_static']

numpydoc_show_class_members = True
autosummary_generate = True
autosummary_imported_members = True
6 changes: 3 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
jaxon Documentation
-------------------
Documentation
-------------

This is the documentation for jaxon.

.. toctree::
:maxdepth: 2
:caption: Contents:


jaxon/api.rst

Indices and tables
==================
Expand Down
11 changes: 11 additions & 0 deletions docs/jaxon/api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
===
API
===

.. automodapi:: jaxon.continuum
.. automodapi:: jaxon.lightcurve
.. automodapi:: jaxon.model
.. automodapi:: jaxon.reflected
.. automodapi:: jaxon.spectrum
.. automodapi:: jaxon.thermal
.. automodapi:: jaxon.tp
2 changes: 2 additions & 0 deletions jaxon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from .version import __version__ # noqa
from jax.config import config
config.update('jax_enable_x64', True)
3 changes: 3 additions & 0 deletions jaxon/continuum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

CONST_K, CONST_C, CONST_H = 1.380649e-16, 29979245800.0, 6.62607015e-27 # cgs

__all__ = [
'dtauHminusCtm'
]

@jit
def log_hminus_continuum(wavelength_um, temperature, pressure,
Expand Down
13 changes: 1 addition & 12 deletions jaxon/hatp7.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import numpy as np
import astropy.units as u
from astropy.constants import G
from astropy.modeling.models import BlackBody

from .spectrum import wav, wav_vis
__all__ = []

planet_name = "HAT-P-7"

Expand Down Expand Up @@ -48,16 +47,6 @@
all_depths_errs = np.concatenate([depths_err, spitzer_depth_err])
all_wavelengths = np.concatenate([central_wl.value, spitzer_wl])

bb_star = BlackBody(temperature=6300*u.K)

bb_star_transformed = (bb_star(wav*u.nm)).to(
u.erg/u.s/u.cm**2/u.Hz/u.sr, u.spectral_density(wav*u.nm)
) * np.pi

bb_star_transformed_vis = (bb_star(wav_vis*u.nm)).to(
u.erg/u.s/u.cm**2/u.Hz/u.sr, u.spectral_density(wav_vis*u.nm)
) * np.pi

rprs = float(1.431*u.R_jup / (2.00 * u.R_sun))

g = (G * u.M_jup / u.R_jup**2).to(u.cm/u.s**2).value
Expand Down
45 changes: 26 additions & 19 deletions jaxon/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
planet_name, t0, period, eclipse_half_dur, b, rstar, rho_star, rp_rstar
)

__all__ = [
'eclipse_model'
]

lcf = search_lightcurve(
planet_name, mission="Kepler", cadence="long"
# quarter=10
Expand Down Expand Up @@ -66,22 +70,25 @@
filt.bin_down(4) # This speeds up integration by orders of magnitude
filt_wavelength, filt_trans = filt.wavelength.to(u.m).value, filt.transmittance

with pm.Model() as model:
# Define a Keplerian orbit using `exoplanet`:
orbit = xo.orbits.KeplerianOrbit(
period=period, t0=0, b=b, rho_star=rho_star.to(u.g / u.cm ** 3),
r_star=float(rstar / u.R_sun)
)

# Compute the eclipse model (no limb-darkening):
eclipse_light_curves = xo.LimbDarkLightCurve([0, 0]).get_light_curve(
orbit=orbit._flip(rp_rstar), r=orbit.r_star,
t=phase * period,
texp=(30 * u.min).to(u.d).value
)

# Normalize the eclipse model to unity out of eclipse and
# zero in-eclipse
eclipse = 1 + pm.math.sum(eclipse_light_curves, axis=-1)

eclipse_numpy = pmx.eval_in_model(eclipse)

def eclipse_model():
with pm.Model() as model:
# Define a Keplerian orbit using `exoplanet`:
orbit = xo.orbits.KeplerianOrbit(
period=period, t0=0, b=b, rho_star=rho_star.to(u.g / u.cm ** 3),
r_star=float(rstar / u.R_sun)
)

# Compute the eclipse model (no limb-darkening):
eclipse_light_curves = xo.LimbDarkLightCurve([0, 0]).get_light_curve(
orbit=orbit._flip(rp_rstar), r=orbit.r_star,
t=phase * period,
texp=(30 * u.min).to(u.d).value
)

# Normalize the eclipse model to unity out of eclipse and
# zero in-eclipse
eclipse = 1 + pm.math.sum(eclipse_light_curves, axis=-1)

eclipse_numpy = pmx.eval_in_model(eclipse)
return eclipse_numpy
22 changes: 15 additions & 7 deletions jaxon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,30 @@
from .reflected import reflected_phase_curve
from .thermal import thermal_phase_curve
from .tp import get_Tarr, polynomial_order, element_number, Parr, dParr
from .spectrum import exojax_spectrum, res_vis, nus, wav, cnu_TiO, indexnu_TiO
from .spectrum import (
exojax_spectrum, res_vis, nus, wav, bb_star_transformed
# cnu_TiO, indexnu_TiO
)
from .hatp7 import (
rprs, all_depths, all_depths_errs, all_wavelengths, kepler_mean_wl,
a_rs, a_rp, T_s, bb_star_transformed
a_rs, a_rp, T_s
)
from .lightcurve import (
phase, time, flux_normed, flux_normed_err, eclipse_numpy, filt_wavelength,
phase, time, flux_normed, flux_normed_err, eclipse_model, filt_wavelength,
filt_trans
)

__all__ = [
'model',
'run_mcmc'
]

model_kwargs = dict(
phase=phase.astype(floatX),
time=(time - time.mean()).astype(floatX),
y=flux_normed.astype(floatX),
yerr=flux_normed_err.astype(floatX),
eclipse_numpy=jnp.array(eclipse_numpy).astype(floatX),
eclipse_numpy=jnp.array(eclipse_model()).astype(floatX),
filt_wavelength=jnp.array(filt_wavelength.astype(floatX)),
filt_trans=jnp.array(filt_trans.astype(floatX)),
a_rs=a_rs, a_rp=a_rp, T_s=T_s, n_temps=polynomial_order * element_number + 1,
Expand All @@ -39,8 +47,8 @@ def model(
filt_wavelength, filt_trans, a_rs, a_rp, T_s,
nus=nus, wav=wav, Parr=Parr, dParr=dParr,
bb_star_transformed=bb_star_transformed,
res=res_vis, cnu_TiO=cnu_TiO,
indexnu_TiO=indexnu_TiO,
res=res_vis, #cnu_TiO=cnu_TiO,
# indexnu_TiO=indexnu_TiO,
predict=False
):
temps = numpyro.sample(
Expand Down Expand Up @@ -151,7 +159,7 @@ def model(
Tarr = get_Tarr(temps, Parr)
Fcgs, _, _ = exojax_spectrum(
temps, jnp.power(10, log_vmr_prod), jnp.power(10, mmr_TiO),
Parr, dParr, nus, wav, res, cnu_TiO, indexnu_TiO
Parr, dParr, nus, wav, res#, cnu_TiO, indexnu_TiO
)

fpfs_spectrum = rprs ** 2 * Fcgs / bb_star_transformed.value
Expand Down
110 changes: 69 additions & 41 deletions jaxon/spectrum.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,73 @@
import numpy as np
from jax import numpy as jnp, jit
from astropy.modeling.models import BlackBody
import astropy.units as u

from exojax.spec import (
moldb, contdb, initspec, molinfo, planck
)
from exojax.spec.modit import setdgm_exomol, exomol, xsmatrix
# from exojax.spec.modit import setdgm_exomol, exomol, xsmatrix
from exojax.spec.rtransfer import nugrid, dtauM, dtauCIA, rtrun

from .hatp7 import g
from .tp import get_Tarr, Parr
from .continuum import dtauHminusCtm

__all__ = [
'exojax_spectrum'
]


# The next block of code must occur before relative imports to avoid circular
# import errors
mmw = 2.33 # mean molecular weight
mmrH2 = 0.74

nus_kepler, wav_kepler, res_kepler = nugrid(348, 970, 10, "nm", xsmode="modit")
nus_wfc3, wav_wfc3, res_wfc3 = nugrid(1120, 1650, 10, "nm", xsmode="modit")
nus_spitzer, wav_spitzer, res_spitzer = nugrid(3000, 5500, 10, "nm", xsmode="modit")
nus_kepler, wav_kepler, res_kepler = nugrid(
348, 970, 10, "nm", xsmode="modit"
)
nus_wfc3, wav_wfc3, res_wfc3 = nugrid(
1120, 1650, 10, "nm", xsmode="modit"
)
nus_spitzer, wav_spitzer, res_spitzer = nugrid(
3000, 5500, 10, "nm", xsmode="modit"
)

nus = jnp.concatenate([nus_spitzer, nus_wfc3, nus_kepler])
wav = jnp.concatenate([wav_kepler, wav_wfc3, wav_spitzer])

nus_vis, wav_vis, res_vis = nugrid(300, 6000, 100, "nm", xsmode="modit")

bb_star = BlackBody(temperature=6300*u.K)

bb_star_transformed = (bb_star(wav*u.nm)).to(
u.erg/u.s/u.cm**2/u.Hz/u.sr, u.spectral_density(wav*u.nm)
) * np.pi

bb_star_transformed_vis = (bb_star(wav_vis*u.nm)).to(
u.erg/u.s/u.cm**2/u.Hz/u.sr, u.spectral_density(wav_vis*u.nm)
) * np.pi

from .hatp7 import g
from .tp import get_Tarr, Parr
from .continuum import dtauHminusCtm

cdbH2H2 = contdb.CdbCIA(
'.database/H2-H2_2011.cia', [nus_vis.min(), nus_vis.max()]
'/Users/brettmorris/git/exojax/.database/H2-H2_2011.cia',
[nus_vis.min(), nus_vis.max()]
)
cdbH2He = contdb.CdbCIA(
'.database/H2-He_2011.cia', [nus_vis.min(), nus_vis.max()]
)
mdbTiO = moldb.MdbExomol(
'.database/TiO/48Ti-16O/Toto/',
[nus_kepler.min(), nus_kepler.max()],
crit=1e-18
'/Users/brettmorris/git/exojax/.database/H2-He_2011.cia',
[nus_vis.min(), nus_vis.max()]
)
# mdbTiO = moldb.MdbExomol(
# '/Users/brettmorris/git/exojax/.database/TiO/48Ti-16O/Toto/',
# [nus_kepler.min(), nus_kepler.max()],
# crit=1e-18
# )

cnu_TiO, indexnu_TiO, R_TiO, pmarray_TiO = initspec.init_modit(
mdbTiO.nu_lines, nus
)
cnu_TiO_vis, indexnu_TiO_vis, R_TiO_vis, pmarray_TiO_vis = initspec.init_modit(
mdbTiO.nu_lines, nus_vis
)
# cnu_TiO, indexnu_TiO, R_TiO, pmarray_TiO = initspec.init_modit(
# mdbTiO.nu_lines, nus
# )
# cnu_TiO_vis, indexnu_TiO_vis, R_TiO_vis, pmarray_TiO_vis = initspec.init_modit(
# mdbTiO.nu_lines, nus_vis
# )


Pref = 1 # bar
Expand All @@ -53,15 +80,17 @@ def fT(T0, alpha):
T0_test = np.array([1000.0, 1700.0, 1000.0, 1700.0])
alpha_test = np.array([0.15, 0.15, 0.05, 0.05])
res = 0.2
dgm_ngammaL_TiO = setdgm_exomol(
mdbTiO, fT, Parr, R_TiO, mdbTiO.molmass,
res, T0_test, alpha_test
)
# dgm_ngammaL_TiO = setdgm_exomol(
# mdbTiO, fT, Parr, R_TiO, mdbTiO.molmass,
# res, T0_test, alpha_test
# )


@jit
def exojax_spectrum(temperatures, vmr_prod, mmr_TiO, Parr, dParr, nus, wav,
res, cnu_TiO, indexnu_TiO):
def exojax_spectrum(
temperatures, vmr_prod, mmr_TiO, Parr, dParr, nus, wav,
res, #cnu_TiO, indexnu_TiO
):
Tarr = get_Tarr(temperatures, Parr)

molmassH2 = molinfo.molmass("H2")
Expand All @@ -81,20 +110,19 @@ def exojax_spectrum(temperatures, vmr_prod, mmr_TiO, Parr, dParr, nus, wav,
nus, Tarr, Parr, dParr, vmr_prod, mmw, g
)

SijM_TiO, ngammaLM_TiO, nsigmaDl_CO = exomol(
mdbTiO, Tarr, Parr, res_kepler, mdbTiO.molmass
)
xsmdit3D = xsmatrix(
cnu_TiO, indexnu_TiO, R_TiO, pmarray_TiO,
nsigmaDl_CO, ngammaLM_TiO, SijM_TiO, nus,
dgm_ngammaL_TiO
)

dtaum_TiO = dtauM(
dParr, xsmdit3D, mmr_TiO * jnp.ones_like(Parr), mdbTiO.molmass, g
)

dtau = dtau_hminus + dtaucH2H2 + dtaucHeH2 + dtaum_TiO
# SijM_TiO, ngammaLM_TiO, nsigmaDl_CO = exomol(
# mdbTiO, Tarr, Parr, res_kepler, mdbTiO.molmass
# )
# xsmdit3D = xsmatrix(
# cnu_TiO, indexnu_TiO, R_TiO, pmarray_TiO,
# nsigmaDl_CO, ngammaLM_TiO, SijM_TiO, nus,
# dgm_ngammaL_TiO
# )
# dtaum_TiO = dtauM(
# dParr, xsmdit3D, mmr_TiO * jnp.ones_like(Parr), mdbTiO.molmass, g
# )

dtau = dtau_hminus + dtaucH2H2 + dtaucHeH2 # + dtaum_TiO
sourcef = planck.piBarr(Tarr, nus)
F0 = rtrun(dtau, sourcef)

Expand Down
5 changes: 5 additions & 0 deletions jaxon/thermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

from .utils import cos_2d, sinsq_2d, trapz2d, floatX

__all__ = [
'thermal_phase_curve'
]


pi = np.cast[floatX](pi64)

h = np.cast[floatX](6.62607015e-34) # J s
Expand Down
6 changes: 6 additions & 0 deletions jaxon/tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@

quadrature_nodes = [gl_0, gl_1, gl_2, gl_3, gl_4, gl_5, gl_6]

__all__ = [
'Element',
'PiecewisePolynomial',
'piecewise_poly',
'get_Tarr'
]

class Element(object):
def __init__(self, edges, order):
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ install_requires =
pymc3
pymc3_ext
exoplanet
lightkurve

[options.extras_require]
all =
Expand Down

0 comments on commit f240c0c

Please sign in to comment.