Skip to content

Commit

Permalink
Fixing #145 (#146)
Browse files Browse the repository at this point in the history
* fixing #145

* pinning arviz version number

* don't pin pymc version
  • Loading branch information
dfm committed Feb 24, 2021
1 parent 55d2252 commit 16cc4ef
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
]
SETUP_REQUIRES = ["setuptools>=40.6.0", "setuptools_scm"]
INSTALL_REQUIRES = [
"pybind11>=2.4",
"numpy>=1.13.0",
"pymc3>=3.5",
"arviz<0.11",
"astropy>=3.1",
"pymc3-ext>=0.0.1",
]
Expand All @@ -38,7 +38,6 @@
"scipy",
"nose",
"parameterized",
"arviz",
"pytest",
"pytest-cov>=2.6.1",
"pytest-env",
Expand Down
3 changes: 2 additions & 1 deletion src/exoplanet/theano_ops/kepler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = ["kepler"]

import numpy as np
import theano
import theano.tensor as tt

Expand All @@ -28,7 +29,7 @@ def perform(self, node, inputs, outputs):
M, ecc = inputs
sinf = resize_or_set(outputs, 0, M.shape)
cosf = resize_or_set(outputs, 1, M.shape)
driver.kepler(M, ecc, sinf, cosf)
driver.kepler(M % (2 * np.pi), ecc, sinf, cosf)

def grad(self, inputs, gradients):
M, e = inputs
Expand Down
15 changes: 15 additions & 0 deletions tests/theano_ops/kepler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ def test_pi(self):
assert np.allclose(np.sin(f), sinf0)
assert np.allclose(np.cos(f), cosf0)

def test_twopi(self):
e = np.linspace(0, 1.0, 100)[:-1]
M = 2 * np.pi + np.zeros_like(e)

M_t = tt.dvector()
e_t = tt.dvector()
func = theano.function([M_t, e_t], self.op(M_t, e_t))
sinf0, cosf0 = func(np.zeros_like(M), e)
sinf, cosf = func(M, e)

assert np.all(np.isfinite(sinf0))
assert np.all(np.isfinite(cosf0))
assert np.allclose(sinf, sinf0)
assert np.allclose(cosf, cosf0)

def test_solver(self):
e = np.linspace(0, 1, 500)[:-1]
E = np.linspace(-300, 300, 1001)
Expand Down

0 comments on commit 16cc4ef

Please sign in to comment.