Skip to content

Commit

Permalink
more GP test
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Jan 2, 2019
1 parent 5107974 commit 9cc246c
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .ci/travis.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ if ! command -v conda > /dev/null; then
conda update -q conda
conda create -q --yes -n test python=$PYTHON_VERSION
conda activate test
conda install -q -c conda-forge numpy=$NUMPY_VERSION scipy astropy setuptools pymc3 pytest pytest-cov starry pip
conda install -q -c conda-forge numpy=$NUMPY_VERSION scipy astropy setuptools pymc3 pytest pytest-cov starry celerite pip
pip install parameterized nose coveralls
pip uninstall -y batman-package
git clone https://github.com/lkreidberg/batman.git
Expand Down
72 changes: 72 additions & 0 deletions exoplanet/gp/celerite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

from __future__ import division, print_function

import pytest
import numpy as np

import theano
import theano.tensor as tt

import celerite
import celerite.terms as cterms

from . import terms
from .celerite import GP

Expand All @@ -32,3 +36,71 @@ def test_broadcast_dim():
y = np.sin(x)
diag = np.random.rand(N)
grad(-5.0, -2.0, 1.0, x, y, diag)


def _get_theano_kernel(celerite_kernel):
if isinstance(celerite_kernel, cterms.TermSum):
result = _get_theano_kernel(celerite_kernel.terms[0])
for k in celerite_kernel.terms[1:]:
result += _get_theano_kernel(k)
return result
elif isinstance(celerite_kernel, cterms.TermProduct):
return (
_get_theano_kernel(celerite_kernel.k1) *
_get_theano_kernel(celerite_kernel.k2))
elif isinstance(celerite_kernel, cterms.RealTerm):
return terms.RealTerm(log_a=celerite_kernel.log_a,
log_c=celerite_kernel.log_c)
elif isinstance(celerite_kernel, cterms.ComplexTerm):
if not celerite_kernel.fit_b:
return terms.ComplexTerm(log_a=celerite_kernel.log_a,
b=0.0,
log_c=celerite_kernel.log_c,
log_d=celerite_kernel.log_d)
return terms.ComplexTerm(log_a=celerite_kernel.log_a,
log_b=celerite_kernel.log_b,
log_c=celerite_kernel.log_c,
log_d=celerite_kernel.log_d)
elif isinstance(celerite_kernel, cterms.SHOTerm):
return terms.SHOTerm(log_S0=celerite_kernel.log_S0,
log_Q=celerite_kernel.log_Q,
log_w0=celerite_kernel.log_omega0)
elif isinstance(celerite_kernel, cterms.Matern32Term):
return terms.Matern32Term(log_sigma=celerite_kernel.log_sigma,
log_rho=celerite_kernel.log_rho)
raise NotImplementedError()


@pytest.mark.parametrize(
"celerite_kernel",
[
cterms.RealTerm(log_a=0.1, log_c=0.5),
cterms.RealTerm(log_a=0.1, log_c=0.5) +
cterms.RealTerm(log_a=-0.1, log_c=0.7),
cterms.ComplexTerm(log_a=0.1, log_c=0.5, log_d=0.1),
cterms.ComplexTerm(log_a=0.1, log_b=-0.2, log_c=0.5, log_d=0.1),
cterms.SHOTerm(log_S0=0.1, log_Q=-1, log_omega0=0.5),
cterms.SHOTerm(log_S0=0.1, log_Q=1.0, log_omega0=0.5),
cterms.SHOTerm(log_S0=0.1, log_Q=1.0, log_omega0=0.5) +
cterms.RealTerm(log_a=0.1, log_c=0.4),
cterms.SHOTerm(log_S0=0.1, log_Q=1.0, log_omega0=0.5) *
cterms.RealTerm(log_a=0.1, log_c=0.4),
cterms.Matern32Term(log_sigma=0.1, log_rho=0.4),
]
)
def test_gp(celerite_kernel, seed=1234):
np.random.seed(seed)
x = np.sort(np.random.rand(100))
yerr = np.random.uniform(0.1, 0.5, len(x))
y = np.sin(x)
diag = yerr**2

celerite_gp = celerite.GP(celerite_kernel)
celerite_gp.compute(x, yerr)
celerite_loglike = celerite_gp.log_likelihood(y)

kernel = _get_theano_kernel(celerite_kernel)
gp = GP(kernel, x, diag)
loglike = gp.log_likelihood(y).eval()

assert np.allclose(loglike, celerite_loglike)
5 changes: 1 addition & 4 deletions exoplanet/gp/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,8 @@ class Matern32Term(Term):
parameter_names = ("sigma", "rho")

def __init__(self, **kwargs):
eps = kwargs.pop("eps", None)
self.eps = tt.as_tensor_variable(kwargs.pop("eps", 0.01))
super(Matern32Term, self).__init__(**kwargs)
if eps is None:
eps = tt.as_tensor_variable(0.01)
self.eps = tt.cast(eps, self.dtype)

def get_complex_coefficients(self):
w0 = np.sqrt(3.0) / self.rho
Expand Down
1 change: 0 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
addopts =
-v
--cov=exoplanet
exoplanet
norecursedirs = exoplanet/theano_ops/starry/starry
filterwarnings =
ignore::DeprecationWarning
Expand Down

0 comments on commit 9cc246c

Please sign in to comment.