Skip to content

Commit

Permalink
dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Dec 4, 2018
1 parent 2b0c479 commit 36e72be
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions exoplanet/theano_ops/kepler/contact_points_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,26 @@ def setUp(self):

def test_infer_shape(self):
np.random.seed(42)
args = [tt.vector() for i in range(6)]
args = [tt.dvector() for i in range(6)]
vals = [np.random.rand(50) for i in range(6)]
self._compile_and_check(args,
self.op(*args),
vals,
self.op_class)

def test_basic(self):
a = 100.0
e = 0.3
w = 0.1
i = 0.5*np.pi
r = 0.1
R = 1.1
a = np.float64(100.0)
e = np.float64(0.3)
w = np.float64(0.1)
i = np.float64(0.5*np.pi)
r = np.float64(0.1)
R = np.float64(1.1)

M_expect = np.array([0.88452506, 0.8863776, 0.90490204, 0.90675455])
M_calc = theano.function([], self.op(a, e, w, i, r, R))()
print(M_expect, M_calc)

utt.assert_allclose(M_expect, M_calc)
utt.assert_allclose(M_calc, M_expect)


class TestCircularContactPoints(utt.InferShapeTester):
Expand All @@ -50,22 +51,23 @@ def setUp(self):

def test_infer_shape(self):
np.random.seed(42)
args = [tt.vector() for i in range(4)]
args = [tt.dvector() for i in range(4)]
vals = [np.random.rand(50) for i in range(4)]
self._compile_and_check(args,
self.op(*args),
vals,
self.op_class)

def test_basic(self):
a = 100.0
e = 0.0
w = 0.0
i = 0.5*np.pi
r = 0.1
R = 1.1
a = np.float64(100.0)
e = np.float64(0.0)
w = np.float64(0.0)
i = np.float64(0.5*np.pi)
r = np.float64(0.1)
R = np.float64(1.1)

M_circ = theano.function([], self.op(a, i, r, R))()
M_gen = theano.function([], ContactPointsOp()(a, e, w, i, r, R))()
print(M_circ, M_gen)

utt.assert_allclose(M_circ, M_gen)

0 comments on commit 36e72be

Please sign in to comment.