Skip to content

Commit

Permalink
Adding mcmc example
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Mar 26, 2020
1 parent 06ae6da commit 5d73215
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 36 deletions.
8 changes: 5 additions & 3 deletions docs/index.rst
Expand Up @@ -7,22 +7,24 @@ Here's a simple example:

.. plot::

from retrieval import transit_depth
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np

from retrieval import Planet

temperatures = np.arange(1000, 3000, 500) * u.K

planet = Planet(1 * u.M_jup, 1 * u.R_jup, 1e-3 * u.bar, 2.2 * u.u)

for temperature in temperatures:
sp = transit_depth(temperature)
sp = planet.transit_depth(temperature)

ax = sp.plot(label=temperature)

ax.set_xlabel('Wavelength [$\mu$m]')
ax.set_ylabel('Transit depth')
ax.legend()
plt.tight_layout()
plt.show()


Expand Down
26 changes: 26 additions & 0 deletions example/fit_spectrum.py
@@ -0,0 +1,26 @@
import sys
sys.path.insert(0, '../')

from retrieval import Planet

import numpy as np
from scipy.optimize import fmin_l_bfgs_b
import astropy.units as u

example_spectrum = np.load('../retrieval/data/example_spectrum.npy')

planet = Planet(1 * u.M_jup, 1 * u.R_jup, 1e-3 * u.bar, 2.2 * u.u)


def minimize(p):
temperature = p[0] * u.K
return np.sum((example_spectrum[:, 1] -
planet.transit_depth(temperature).flux)**2 /
example_spectrum[:, 2]**2)

initp = [1700] # K

bestp = fmin_l_bfgs_b(minimize, initp, approx_grad=True,
bounds=[[500, 5000]])[0][0] * u.K

print(bestp)
26 changes: 26 additions & 0 deletions example/generate_spectrum.py
@@ -0,0 +1,26 @@
import sys
sys.path.insert(0, '../')

from retrieval import Planet

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np

temperature = 1500 * u.K

planet = Planet(1 * u.M_jup, 1 * u.R_jup, 1e-3 * u.bar, 2.2 * u.u)
sp = planet.transit_depth(temperature)

output_path = '../retrieval/data/example_spectrum.npy'

np.save(output_path, np.vstack([sp.wavelength.value, sp.flux.value,
sp.flux.mean().value / 100 *
np.ones(len(sp.flux))]).T)

ax = sp.plot(label=temperature)

ax.set_xlabel('Wavelength [$\mu$m]')
ax.set_ylabel('Transit depth')
ax.legend()
plt.show()
42 changes: 42 additions & 0 deletions example/mcmc.py
@@ -0,0 +1,42 @@
import sys
sys.path.insert(0, '../')

from retrieval import Planet

import numpy as np
import astropy.units as u
from emcee import EnsembleSampler
from multiprocessing import Pool
import matplotlib.pyplot as plt

example_spectrum = np.load('../retrieval/data/example_spectrum.npy')

planet = Planet(1 * u.M_jup, 1 * u.R_jup, 1e-3 * u.bar, 2.2 * u.u)


def lnprior(theta):
temperature = theta[0]

if 500 < temperature < 5000:
return 0
return -np.inf


def lnlikelihood(theta):
temperature = theta[0] * u.K
model = planet.transit_depth(temperature).flux
return -0.5 * np.sum((example_spectrum[:, 1] - model)**2 /
example_spectrum[:, 2]**2)

nwalkers = 10
ndim = 1

p0 = [[1500 + 10 * np.random.randn()] for i in range(nwalkers)]

with Pool() as pool:
sampler = EnsembleSampler(nwalkers, ndim, lnlikelihood, pool=pool)
sampler.run_mcmc(p0, 1000)

plt.hist(sampler.flatchain)
plt.xlabel('Temperature [K]')
plt.show()
67 changes: 36 additions & 31 deletions retrieval/core.py
@@ -1,44 +1,49 @@
import numpy as np
import astropy.units as u
from astropy.constants import G, k_B, R_jup, M_jup, R_sun
from astropy.constants import G, k_B

from .opacity import water_opacity
from .spectrum import Spectrum


__all__ = ['transit_depth']
__all__ = ['Planet']

gamma = 0.57721


def transit_depth(temperature):
class Planet(object):
"""
Compute the transit depth with wavelength at ``temperature``.
Parameters
----------
temperature : `~astropy.units.Quantity`
Returns
-------
sp : `~retrieval.Spectrum`
Transit depth spectrum
Properties of an exoplanet.
"""
wavenumber, kappa = water_opacity(temperature)

g = G * M_jup / R_jup**2
rstar = 1 * R_sun

R0 = R_jup
P0 = 1e-3 * u.bar

mu = 2 * u.u

scale_height = k_B * temperature / mu / g
tau = P0 * kappa / g * np.sqrt(2.0 * np.pi * R0 / scale_height)
r = R0 + scale_height * (gamma + np.log(tau))

depth = (r / rstar) ** 2
wavelength = wavenumber.to(u.um, u.spectral())

return Spectrum(wavelength, depth)
def __init__(self, mass, radius, pressure, mu):
self.mass = mass
self.radius = radius
self.pressure = pressure
self.mu = mu

def transit_depth(self, temperature, rstar=1 * u.R_sun):
"""
Compute the transit depth with wavelength at ``temperature``.
Parameters
----------
temperature : `~astropy.units.Quantity`
Returns
-------
sp : `~retrieval.Spectrum`
Transit depth spectrum
"""
wavenumber, kappa = water_opacity(temperature)

g = G * self.mass / self.radius**2
P0 = self.pressure

scale_height = k_B * temperature / self.mu / g
tau = P0 * kappa / g * np.sqrt(2 * np.pi * self.radius / scale_height)
r = self.radius + scale_height * (gamma + np.log(tau))

depth = (r / rstar).decompose() ** 2
wavelength = wavenumber.to(u.um, u.spectral())

return Spectrum(wavelength, depth)
Binary file added retrieval/data/example_spectrum.npy
Binary file not shown.
6 changes: 4 additions & 2 deletions retrieval/spectrum.py
@@ -1,4 +1,5 @@
import matplotlib.pyplot as plt
import numpy as np

__all__ = ['Spectrum']

Expand All @@ -12,8 +13,9 @@ class Spectrum(object):
flux, transit depth, etc.
"""
def __init__(self, wavelength, flux):
self.wavelength = wavelength
self.flux = flux
sort = np.argsort(wavelength)
self.wavelength = wavelength[sort]
self.flux = flux[sort]

def plot(self, ax=None, **kwargs):
"""
Expand Down
11 changes: 11 additions & 0 deletions retrieval/tests/test_core.py
@@ -0,0 +1,11 @@
import astropy.units as u

from ..core import Planet


def test_radius():
temperature = 1500 * u.K
planet = Planet(1 * u.M_jup, 1 * u.R_jup, 1e-3 * u.bar, 2.2 * u.u)
sp = planet.transit_depth(temperature)

assert abs(sp.flux.mean() - 0.01075) < 1e-5

0 comments on commit 5d73215

Please sign in to comment.