Skip to content

Commit

Permalink
Merge pull request #398 from hawkinsp/master
Browse files Browse the repository at this point in the history
Fix some test failures.
  • Loading branch information
hawkinsp committed Feb 17, 2019
2 parents 901a5e5 + 5511567 commit 848b769
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
26 changes: 17 additions & 9 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,15 @@

OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "shapes", "rng", "diff_modes", "test_name"])
["name", "nargs", "dtypes", "shapes", "rng", "diff_modes", "test_name",
"check_dtypes"])


def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None,
check_dtypes=True):
test_name = test_name or name
return OpRecord(name, nargs, dtypes, shapes, rng, diff_modes, test_name)
return OpRecord(name, nargs, dtypes, shapes, rng, diff_modes, test_name,
check_dtypes)

JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
Expand Down Expand Up @@ -109,7 +112,9 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
]

JAX_COMPOUND_OP_RECORDS = [
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default(), []),
# angle has inconsistent 32/64-bit return types across numpy versions.
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default(), [],
check_dtypes=False),
op_record("atleast_1d", 1, default_dtypes, all_shapes, jtu.rand_default(), []),
op_record("atleast_2d", 1, default_dtypes, all_shapes, jtu.rand_default(), []),
op_record("atleast_3d", 1, default_dtypes, all_shapes, jtu.rand_default(), []),
Expand Down Expand Up @@ -271,16 +276,19 @@ def _GetArgsMaker(self, rng, shapes, dtypes):
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
"onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
"check_dtypes": rec.check_dtypes}
for shapes in filter(
_shapes_are_broadcast_compatible,
CombosWithReplacement(rec.shapes, rec.nargs))
for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS)))
def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
JAX_COMPOUND_OP_RECORDS)))
def testOp(self, onp_op, lnp_op, rng, shapes, dtypes, check_dtypes):
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker,
check_dtypes=check_dtypes)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=check_dtypes)

@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
Expand Down
3 changes: 2 additions & 1 deletion tests/scipy_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]

self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

@genNamedParametersNArgs(3, jtu.rand_positive())
Expand Down

0 comments on commit 848b769

Please sign in to comment.