Skip to content

Commit

Permalink
streamline tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 18, 2020
1 parent 80c66ca commit e455e48
Showing 1 changed file with 8 additions and 14 deletions.
22 changes: 8 additions & 14 deletions tests/custom_object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,19 @@ def make_sparse_array():


class CustomObjectTest(jtu.JaxTestCase):
def testIdentityFunction(self):
M = make_sparse_array()

@jit
def f(x):
return x

M2 = f(M)
self.assertAllClose(M.data, M2.data)
self.assertAllClose(M.indices, M2.indices)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_compile={}".format(compile),
"compile": compile}
{"testcase_name": "_compile={}_primitive={}".format(compile, primitive),
"compile": compile, "primitive": primitive}
for primitive in [True, False]
for compile in [True, False]))
def testIdentityPrimitive(self, compile):
def testIdentity(self, compile, primitive):
f = identity if primitive else (lambda x: x)
f = jit(f) if compile else f
M = make_sparse_array()
f = jit(identity) if compile else identity
M2 = f(M)
self.assertEqual(M.dtype, M2.dtype)
self.assertEqual(M.index_dtype, M2.index_dtype)
self.assertAllClose(M.data, M2.data)
self.assertAllClose(M.indices, M2.indices)

Expand Down

0 comments on commit e455e48

Please sign in to comment.