Skip to content

Commit

Permalink
Merge pull request #392 from google/infix-operator-tests
Browse files Browse the repository at this point in the history
tests for numpy operator overloading (some fail!)
  • Loading branch information
mattjj committed Feb 17, 2019
2 parents e6f6810 + 9bf830e commit 0bcf3a3
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from absl.testing import absltest
from absl.testing import parameterized
import six

import numpy as onp

Expand Down Expand Up @@ -141,6 +142,7 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("real", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("mod", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("sinc", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
op_record("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
Expand Down Expand Up @@ -187,6 +189,49 @@ def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
op_record("argmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_equal(), []),
]

JAX_OPERATOR_OVERLOADS = [
op_record("__add__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__radd__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__sub__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__rsub__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__mul__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__rmul__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__eq__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__ne__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__lt__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__gt__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__ge__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__neg__", 1, number_dtypes, all_shapes, jtu.rand_default(), []),
op_record("__pow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), []),
op_record("__rpow__", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), []),
op_record("__mod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("__rmod__", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("__floordiv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("__rfloordiv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("__truediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("__rtruediv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("__abs__", 1, number_dtypes, all_shapes, jtu.rand_default(), []),
# TODO(mattjj): __invert__ fails on bool dtypes because ~True == -2
op_record("__invert__", 1, int_dtypes, all_shapes, jtu.rand_default(), []),
# TODO(mattjj): investigate these failures
# op_record("__or__", 2, number_dtypes, all_shapes, jtu.rand_bool(), []),
# op_record("__ror__", 2, number_dtypes, all_shapes, jtu.rand_bool(), []),
# op_record("__and__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
# op_record("__rand__", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
# op_record("__xor__", 2, number_dtypes, all_shapes, jtu.rand_bool(), []),
# op_record("__rxor__", 2, number_dtypes, all_shapes, jtu.rand_bool(), []),
# op_record("__divmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
# op_record("__rdivmod__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
# TODO(mattjj): lshift, rshift
]

if six.PY2:
JAX_OPERATOR_OVERLOADS += [
op_record("__div__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
op_record("__rdiv__", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), []),
]


CombosWithReplacement = itertools.combinations_with_replacement


Expand Down Expand Up @@ -237,6 +282,21 @@ def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
dtypes),
"rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "name": rec.name}
for shapes in filter(
_shapes_are_broadcast_compatible,
CombosWithReplacement(rec.shapes, rec.nargs))
for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
for rec in JAX_OPERATOR_OVERLOADS))
def testOperatorOverload(self, name, rng, shapes, dtypes):
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
fun = lambda x, *xs: getattr(x, name)(*xs)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(
Expand Down

0 comments on commit 0bcf3a3

Please sign in to comment.