Skip to content

Commit

Permalink
gp tutorial and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Nov 22, 2018
1 parent 1d37031 commit 65433c3
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 90 deletions.
171 changes: 85 additions & 86 deletions docs/_static/notebooks/intro-to-pymc3.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ exoplanet

tutorials/quickstart
tutorials/intro-to-pymc3
tutorials/gp


.. toctree::
Expand Down
1 change: 1 addition & 0 deletions docs/user/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Scalable Gaussian processes
.. autoclass:: exoplanet.gp.terms.ComplexTerm
.. autoclass:: exoplanet.gp.terms.SHOTerm
.. autoclass:: exoplanet.gp.terms.Matern32Term
.. autoclass:: exoplanet.gp.terms.RotationTerm


Estimators
Expand Down
139 changes: 137 additions & 2 deletions exoplanet/gp/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__all__ = [
"Term", "TermSum", "TermProduct", "TermDiff",
"RealTerm", "ComplexTerm",
"SHOTerm", "Matern32Term",
"SHOTerm", "Matern32Term", "RotationTerm",
]

import numpy as np
Expand All @@ -17,6 +17,12 @@


class Term(object):
"""The abstract base "term" that is the superclass of all other terms
Subclasses should overload the :func:`terms.Term.get_real_coefficients`
and :func:`terms.Term.get_complex_coefficients` methods.
"""

parameter_names = tuple()

Expand All @@ -27,7 +33,7 @@ def __init__(self, **kwargs):
raise ValueError(("Missing required parameter {0}. "
"Provide {0} or log_{0}").format(name))
value = kwargs[name] if name in kwargs \
else tt.exp(kwargs["log_" + name], name=name)
else tt.exp(kwargs["log_" + name])
setattr(self, name, tt.cast(value, self.dtype))

self.coefficients = self.get_coefficients()
Expand Down Expand Up @@ -211,6 +217,27 @@ def get_coefficients(self):


class RealTerm(Term):
r"""The simplest celerite term
This term has the form
.. math::
k(\tau) = a_j\,e^{-c_j\,\tau}
with the parameters ``a`` and ``c``.
Strictly speaking, for a sum of terms, the parameter ``a`` could be
allowed to go negative but since it is somewhat subtle to ensure positive
definiteness, we recommend keeping both parameters strictly positive.
Advanced users can build a custom term that has negative coefficients but
care should be taken to ensure positivity.
Args:
a or log_a: The lamplitude of the term.
c or log_c: The lexponent of the term.
"""

parameter_names = ("a", "c")

Expand All @@ -222,6 +249,27 @@ def get_real_coefficients(self):


class ComplexTerm(Term):
r"""A general celerite term
This term has the form
.. math::
k(\tau) = \frac{1}{2}\,\left[(a_j + b_j)\,e^{-(c_j+d_j)\,\tau}
+ (a_j - b_j)\,e^{-(c_j-d_j)\,\tau}\right]
with the parameters ``a``, ``b``, ``c``, and ``d``.
This term will only correspond to a positive definite kernel (on its own)
if :math:`a_j\,c_j \ge b_j\,d_j`.
Args:
a or log_a: The real part of amplitude.
b or log_b: The imaginary part of amplitude.
c or log_c: The real part of the exponent.
d or log_d: The imaginary part of exponent.
"""

parameter_names = ("a", "b", "c", "d")

Expand All @@ -235,6 +283,23 @@ def get_complex_coefficients(self):


class SHOTerm(Term):
r"""A term representing a stochastically-driven, damped harmonic oscillator
The PSD of this term is
.. math::
S(\omega) = \sqrt{\frac{2}{\pi}} \frac{S_0\,\omega_0^4}
{(\omega^2-{\omega_0}^2)^2 + {\omega_0}^2\,\omega^2/Q^2}
with the parameters ``S0``, ``Q``, and ``w0``.
Args:
S0 or log_S0: The parameter :math:`S_0`.
Q or log_Q: The parameter :math:`Q`.
w0 or log_w0: The parameter :math:`\omega_0`.
"""

parameter_names = ("S0", "w0", "Q")

Expand Down Expand Up @@ -275,6 +340,34 @@ def underdamped():


class Matern32Term(Term):
r"""A term that approximates a Matern-3/2 function
This term is defined as
.. math::
k(\tau) = \sigma^2\,\left[
\left(1+1/\epsilon\right)\,e^{-(1-\epsilon)\sqrt{3}\,\tau/\rho}
\left(1-1/\epsilon\right)\,e^{-(1+\epsilon)\sqrt{3}\,\tau/\rho}
\right]
with the parameters ``sigma`` and ``rho``. The parameter ``eps``
controls the quality of the approximation since, in the limit
:math:`\epsilon \to 0` this becomes the Matern-3/2 function
.. math::
\lim_{\epsilon \to 0} k(\tau) = \sigma^2\,\left(1+
\frac{\sqrt{3}\,\tau}{\rho}\right)\,
\exp\left(-\frac{\sqrt{3}\,\tau}{\rho}\right)
Args:
sigma or log_sigma: The parameter :math:`\sigma`.
rho or log_rho: The parameter :math:`\rho`.
eps (Optional[float]): The value of the parameter :math:`\epsilon`.
(default: `0.01`)
"""

parameter_names = ("sigma", "rho")

Expand All @@ -294,3 +387,45 @@ def get_complex_coefficients(self):
tt.reshape(w0, (w0.size,)),
tt.reshape(self.eps, (w0.size,))
)


class RotationTerm(TermSum):
r"""A mixture of two SHO terms that can be used to model stellar rotation
This term has two modes in Fourier space: one at ``period`` and one at
``0.5 * period``. This can be a good descriptive model for a wide range of
stochastic variability in stellar time series from rotation to pulsations.
Args:
amp or log_amp: The amplitude of the variability.
period or log_period: The primary period of variability.
Q0 or log_Q0: The quality factor (or really the quality factor minus
one half) for the secondary oscillation.
deltaQ or log_deltaQ: The difference between the quality factors of
the first and the second modes. This parameterization (if
``deltaQ > 0``) ensures that the primary mode alway has higher
quality.
mix: The fractional amplitude of the secondary mode compared to the
primary. This should probably always be ``0 < mix < 1``.
"""

parameter_names = ("amp", "Q0", "deltaQ", "period", "mix")

def __init__(self, **kwargs):
super(RotationTerm, self).__init__(**kwargs)

# One term with a period of period
Q1 = 0.5 + self.Q0 + self.deltaQ
w1 = 4*np.pi*Q1/(self.period * tt.sqrt(4*Q1**2-1))
S1 = self.amp / (w1 * Q1)

# Another term at half the period
Q2 = 0.5 + self.Q0
w2 = 8*np.pi*Q2/(self.period * tt.sqrt(4*Q2**2-1))
S2 = self.mix * self.amp / (w1 * Q2)

self.terms = (
SHOTerm(S0=S1, w0=w1, Q=Q1),
SHOTerm(S0=S2, w0=w2, Q=Q2))
self.coefficients = self.get_coefficients()
35 changes: 33 additions & 2 deletions exoplanet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,44 @@
import pymc3 as pm


def eval_in_model(var, point=None, model=None, return_func=False, **kwargs):
def eval_in_model(var, point=None, return_func=False, model=None, **kwargs):
"""Evaluate a Theano tensor or PyMC3 variable in a PyMC3 model
This method builds a Theano function for evaluating a node in the graph
given the required parameters. This will also cache the compiled Theano
function in the current ``pymc3.Model`` to reduce the overhead of calling
this function many times.
Args:
var: The variable or tensor to evaluate.
point (Optional): A ``dict`` of input parameter values. This can be
``model.test_point`` (default), the result of ``pymc3.find_MAP``,
a point in a ``pymc3.MultiTrace`` or any other representation of
the input parameters.
return_func (Optional[bool]): If ``False`` (default), return the
evaluated variable. If ``True``, return the result, the Theano
function and the list of arguments for that function.
Returns:
Depending on ``return_func``, either the value of ``var`` at ``point``,
or this value, the Theano function, and the input arguments.
"""
model = pm.modelcontext(model)
if point is None:
point = model.test_point

# Cache the function if it has previously been compiled
if not hasattr(model, "_exoplanet_eval_funcs"):
model._exoplanet_eval_funcs = dict()
kwargs["on_unused_input"] = kwargs.get("on_unused_input", "ignore")
func = theano.function(model.vars, var, **kwargs)
func = model._exoplanet_eval_funcs.get(
var, theano.function(model.vars, var, **kwargs))
model._exoplanet_eval_funcs[var] = func

# Work out the arguments
args = [point[k.name] for k in model.vars]

if return_func:
return func(*args), func, args
return func(*args)

0 comments on commit 65433c3

Please sign in to comment.