Skip to content

Commit

Permalink
fixing up flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Oct 6, 2021
1 parent 0bac7c1 commit b6b0df3
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 57 deletions.
12 changes: 8 additions & 4 deletions jaxon/continuum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

CONST_K, CONST_C, CONST_H = 1.380649e-16, 29979245800.0, 6.62607015e-27 # cgs


@jit
def log_hminus_continuum(wavelength_um, temperature, pressure,
volume_mixing_ratio_product, truncation_value=-100):
Expand Down Expand Up @@ -149,9 +150,12 @@ def dtauHminusCtm(nus, Tarr, Parr, dParr, volume_mixing_ratio_product, mmw, g):
logg = jnp.log10(g)
ddParr = dParr / Parr
wavelength_um = 1e4 / nus[::-1]
dtauctm = (10 ** (log_hminus_continuum(wavelength_um, Tarr[:, None], Parr,
volume_mixing_ratio_product)
+ lognarr1[:, None] + logkb - logg - logm_ucgs)
* Tarr[:, None] / mmw * ddParr[:, None])
dtauctm = (
10 ** (log_hminus_continuum(wavelength_um, Tarr[:, None], Parr,
volume_mixing_ratio_product) +
lognarr1[:, None] + logkb - logg - logm_ucgs) *
Tarr[:, None] / mmw *
ddParr[:, None]
)

return dtauctm
22 changes: 13 additions & 9 deletions jaxon/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,31 @@
)

lcf = search_lightcurve(
planet_name, mission="Kepler", cadence="long"#, quarter=10 #[10, 11, 12]
planet_name, mission="Kepler", cadence="long"
# quarter=10
).download_all()

slc = lcf.stitch()

phases = ((slc.time.jd - t0) % period) / period
in_eclipse = np.abs(phases - 0.5) < 1.5 * eclipse_half_dur
in_transit = (phases < 1.5 * eclipse_half_dur) | (
phases > 1 - 1.5 * eclipse_half_dur)
out_of_transit = np.logical_not(in_transit)# | in_eclipse)
in_transit = (
(phases < 1.5 * eclipse_half_dur) |
(phases > 1 - 1.5 * eclipse_half_dur)
)
out_of_transit = np.logical_not(in_transit)

slc = slc.flatten(
polyorder=3, break_tolerance=10, window_length=1001, mask=~out_of_transit
).remove_nans()

phases = ((slc.time.jd - t0) % period) / period
in_eclipse = np.abs(phases - 0.5) < 1.5 * eclipse_half_dur
in_transit = (phases < 1.5 * eclipse_half_dur) | (
phases > 1 - 1.5 * eclipse_half_dur)
out_of_transit = np.logical_not(in_transit)# | in_eclipse)
in_transit = (
(phases < 1.5 * eclipse_half_dur) |
(phases > 1 - 1.5 * eclipse_half_dur)
)
out_of_transit = np.logical_not(in_transit)

sc = sigma_clip(
np.ascontiguousarray(slc.flux[out_of_transit], dtype=floatX),
Expand All @@ -47,7 +52,7 @@
)

bin_in_eclipse = np.abs(phase - 0.5) < eclipse_half_dur
unbinned_flux_mean = np.mean(sc[~sc.mask].data) # .mean()
unbinned_flux_mean = np.mean(sc[~sc.mask].data)

unbinned_flux_mean_ppm = 1e6 * (unbinned_flux_mean - 1)
flux_normed = np.ascontiguousarray(
Expand Down Expand Up @@ -80,4 +85,3 @@
eclipse = 1 + pm.math.sum(eclipse_light_curves, axis=-1)

eclipse_numpy = pmx.eval_in_model(eclipse)

29 changes: 12 additions & 17 deletions jaxon/reflected.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
from numpy import pi as pi64
from jax import config, jit, numpy as jnp
from jax import jit, numpy as jnp

from .utils import trapz3d, floatX

Expand Down Expand Up @@ -120,11 +120,12 @@ def h_ml(omega_drag, alpha, theta, phi, C_11, m=1, l=1):
jnp.exp(-jnp.power(tilda_mu(theta, alpha), two) * half))

result = prefactor * (
mu(theta) * m * H(l, theta, alpha) * jnp.cos(m * phi) +
alpha * omega_drag * (tilda_mu(theta, alpha) *
H(l, theta, alpha) -
H(l + one, theta, alpha)) *
jnp.sin(m * phi))
mu(theta) * m * H(l, theta, alpha) * jnp.cos(m * phi) +
alpha * omega_drag * (tilda_mu(theta, alpha) *
H(l, theta, alpha) -
H(l + one, theta, alpha)) *
jnp.sin(m * phi)
)
return result


Expand Down Expand Up @@ -178,7 +179,6 @@ def blackbody2d(wavelengths, temperature):
return blackbody_lambda(wavelengths, temperature)



@jit
def integrate_planck(filt_wavelength, filt_trans,
temperature):
Expand Down Expand Up @@ -231,11 +231,6 @@ def reflected_phase_curve(phases, omega, g, a_rp):
)

abs_alpha = jnp.abs(alpha) # .astype(floatX)
alpha_sort_order = jnp.argsort(alpha)
sin_abs_sort_alpha = jnp.sin(
abs_alpha[alpha_sort_order]) # .astype(floatX)
sort_alpha = alpha[alpha_sort_order] # .astype(floatX)

gamma = jnp.sqrt(1 - omega)
eps = (1 - gamma) / (1 + gamma)

Expand Down Expand Up @@ -321,11 +316,11 @@ def I(alpha, Phi):
I_L = 1 / np.pi * (Phi * cos_alpha -
0.5 * jnp.sin(alpha - 2 * Phi))
I_C = -1 / (24 * cos_alpha_2) * (
-3 * jnp.sin(alpha / 2 - Phi) +
jnp.sin(3 * alpha / 2 - 3 * Phi) +
6 * jnp.sin(3 * alpha / 2 - Phi) -
6 * jnp.sin(alpha / 2 + Phi) +
24 * jnp.sin(alpha / 2) ** 4 * I_0
-3 * jnp.sin(alpha / 2 - Phi) +
jnp.sin(3 * alpha / 2 - 3 * Phi) +
6 * jnp.sin(3 * alpha / 2 - Phi) -
6 * jnp.sin(alpha / 2 + Phi) +
24 * jnp.sin(alpha / 2) ** 4 * I_0
)

return I_S, I_L, I_C
Expand Down
44 changes: 29 additions & 15 deletions jaxon/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
from jax import numpy as jnp, jit

from exojax.spec import (
moldb, contdb, initspec, SijT, molinfo, gamma_natural, planck,
doppler_sigma
moldb, contdb, initspec, molinfo, planck
)
from exojax.spec.modit import setdgm_exomol, exomol, xsmatrix
from exojax.spec.rtransfer import nugrid, dtauM, dtauCIA, rtrun

from .hatp7 import g, rprs
from .hatp7 import g
from .tp import get_Tarr, Parr
from .continuum import dtauHminusCtm

mmw = 2.33 # mean molecular weight
mmw = 2.33 # mean molecular weight
mmrH2 = 0.74

nus_kepler, wav_kepler, res_kepler = nugrid(348, 970, 10, "nm", xsmode="modit")
Expand All @@ -24,28 +23,42 @@

nus_vis, wav_vis, res_vis = nugrid(300, 6000, 100, "nm", xsmode="modit")

cdbH2H2=contdb.CdbCIA('.database/H2-H2_2011.cia', [nus_vis.min(), nus_vis.max()])
cdbH2He=contdb.CdbCIA('.database/H2-He_2011.cia', [nus_vis.min(), nus_vis.max()])
mdbTiO=moldb.MdbExomol(
cdbH2H2 = contdb.CdbCIA(
'.database/H2-H2_2011.cia', [nus_vis.min(), nus_vis.max()]
)
cdbH2He = contdb.CdbCIA(
'.database/H2-He_2011.cia', [nus_vis.min(), nus_vis.max()]
)
mdbTiO = moldb.MdbExomol(
'.database/TiO/48Ti-16O/Toto/',
[nus_kepler.min(), nus_kepler.max()],
crit=1e-18
)

cnu_TiO, indexnu_TiO, R_TiO, pmarray_TiO = initspec.init_modit(mdbTiO.nu_lines, nus)
cnu_TiO_vis, indexnu_TiO_vis, R_TiO_vis, pmarray_TiO_vis = initspec.init_modit(mdbTiO.nu_lines, nus_vis)
cnu_TiO, indexnu_TiO, R_TiO, pmarray_TiO = initspec.init_modit(
mdbTiO.nu_lines, nus
)
cnu_TiO_vis, indexnu_TiO_vis, R_TiO_vis, pmarray_TiO_vis = initspec.init_modit(
mdbTiO.nu_lines, nus_vis
)


Pref = 1 # bar

Pref = 1 # bar
fT = lambda T0, alpha: T0[:, None]*(Parr[None, :]/Pref)**alpha[:, None]
T0_test = np.array([1000.0,1700.0,1000.0,1700.0])
alpha_test = np.array([0.15,0.15,0.05,0.05])

def fT(T0, alpha):
return T0[:, None] * (Parr[None, :]/Pref) ** alpha[:, None]


T0_test = np.array([1000.0, 1700.0, 1000.0, 1700.0])
alpha_test = np.array([0.15, 0.15, 0.05, 0.05])
res = 0.2
dgm_ngammaL_TiO = setdgm_exomol(
mdbTiO, fT, Parr, R_TiO, mdbTiO.molmass,
res, T0_test, alpha_test
)


@jit
def exojax_spectrum(temperatures, vmr_prod, mmr_TiO, Parr, dParr, nus, wav,
res, cnu_TiO, indexnu_TiO):
Expand Down Expand Up @@ -77,8 +90,9 @@ def exojax_spectrum(temperatures, vmr_prod, mmr_TiO, Parr, dParr, nus, wav,
dgm_ngammaL_TiO
)

dtaum_TiO = dtauM(dParr, xsmdit3D, mmr_TiO * jnp.ones_like(Parr),
mdbTiO.molmass, g)
dtaum_TiO = dtauM(
dParr, xsmdit3D, mmr_TiO * jnp.ones_like(Parr), mdbTiO.molmass, g
)

dtau = dtau_hminus + dtaucH2H2 + dtaucHeH2 + dtaum_TiO
sourcef = planck.piBarr(Tarr, nus)
Expand Down
11 changes: 6 additions & 5 deletions jaxon/thermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,12 @@ def h_ml(omega_drag, alpha, theta, phi, C_11, m=1, l=1):
jnp.exp(-jnp.power(tilda_mu(theta, alpha), two) * half))

result = prefactor * (
mu(theta) * m * H(l, theta, alpha) * jnp.cos(m * phi) +
alpha * omega_drag * (tilda_mu(theta, alpha) *
H(l, theta, alpha) -
H(l + one, theta, alpha)) *
jnp.sin(m * phi))
mu(theta) * m * H(l, theta, alpha) * jnp.cos(m * phi) +
alpha * omega_drag * (tilda_mu(theta, alpha) *
H(l, theta, alpha) -
H(l + one, theta, alpha)) *
jnp.sin(m * phi)
)
return result


Expand Down
8 changes: 3 additions & 5 deletions jaxon/tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from exojax.spec import rtransfer as rt

NP = 50
Parr, dParr, k=rt.pressure_layer(NP=NP, logPtop=-5, logPbtm=2.5)
Parr_fine, dParr_fine, k=rt.pressure_layer(NP=100, logPtop=-5, logPbtm=2.5)
mmw = 2.33 #mean molecular weight
Parr, dParr, k = rt.pressure_layer(NP=NP, logPtop=-5, logPbtm=2.5)
Parr_fine, dParr_fine, k = rt.pressure_layer(NP=100, logPtop=-5, logPbtm=2.5)
mmw = 2.33 # mean molecular weight
mmrH2 = 0.74

element_number = 3
Expand Down Expand Up @@ -146,8 +146,6 @@ def __call__(self, x_vector):
return values




def piecewise_poly(log_p, domain_boundaries, dof_values, element_number,
polynomial_order):
pp = PiecewisePolynomial(
Expand Down
1 change: 1 addition & 0 deletions jaxon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
two = np.cast[floatX](2)
half = np.cast[floatX](0.5)


@jit
def sum2d(z):
"""
Expand Down
13 changes: 11 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@ packages = find:
python_requires = >=3.5
setup_requires = setuptools_scm
install_requires =


numpy
jax
exojax
astropy
celerite2
arviz
numpyro
kelp
pymc3
pymc3_ext
exoplanet

[options.extras_require]
all =
Expand Down
8 changes: 8 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,16 @@ deps =
devdeps: git+https://github.com/astropy/astropy.git#egg=astropy

numpy
jax
exojax
astropy
celerite2
arviz
numpyro
kelp
pymc3
pymc3_ext
exoplanet

# The following indicates which extras_require from setup.cfg will be installed
extras =
Expand Down

0 comments on commit b6b0df3

Please sign in to comment.