Fix test flakiness in autodiff tests for min/max type functions (#2918)
* Fix test flakiness in autodiff tests for clamp, reduce, and reduce-window.

We change the tests to avoid computing numerical gradients in the neighborhood of nondifferentiable points, for example points where the identity of the maximum element in a reduce-max changes. The finite-difference check is only a good approximation of the gradient within an epsilon ball around the evaluation point, and close to such a nondifferentiable point the approximation breaks down, so the comparison against the autodiff gradient can fail spuriously (see the sketch below).

* Only test reduce-grad-mul for float types.
hawkinsp committed May 1, 2020
1 parent 0736679 commit 1b56428
Showing 1 changed file: tests/lax_test.py (37 additions, 45 deletions).
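The failure mode described in the commit message is easy to reproduce outside the test suite. The following sketch is not part of the commit: it compares JAX's autodiff gradient of a max reduction against a central-difference estimate of the kind jtu.check_grads checks it against. Here central_diff is a hypothetical helper written only for this illustration, not the actual jax.test_util code. Once the two largest elements are closer together than the finite-difference step eps, the estimate straddles the nondifferentiable point and no longer matches the exact gradient, which is what made the old tests flaky.

# Minimal sketch, not part of the commit: why numerical gradients of max-like
# reductions are unreliable near points where the argmax changes.
import jax
import jax.numpy as jnp
import numpy as np

def central_diff(f, x, i, eps):
  # Hypothetical helper: central-difference estimate of df/dx[i], similar in
  # spirit to what jtu.check_grads compares the autodiff gradient against.
  e = np.zeros(x.shape, np.float32)
  e[i] = eps
  return (f(x + e) - f(x - e)) / (2 * eps)

eps = 1e-2

# Maximum element well separated: finite differences agree with autodiff.
x = jnp.array([0.0, 1.0])
print(jax.grad(jnp.max)(x))              # [0. 1.]
print(central_diff(jnp.max, x, 0, eps))  # ~0.0, matches the exact gradient

# Maximum element nearly tied: the eps ball straddles the kink, and the
# finite-difference estimate (~0.45) no longer matches the exact gradient (0.0).
y = jnp.array([0.999, 1.0])
print(jax.grad(jnp.max)(y))              # [0. 1.]
print(central_diff(jnp.max, y, 0, eps))  # ~0.45

The commit avoids this regime by drawing operands with unique, well-separated elements (jtu.rand_unique_int) for the max and min reductions, and by sampling floor, ceil, and round inputs away from integer boundaries via jtu.rand_uniform(low=0.1, high=0.4).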
@@ -1741,11 +1741,14 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
 LAX_GRAD_OPS = [
     grad_test_spec(lax.neg, nargs=1, order=2, rng_factory=jtu.rand_default,
                    dtypes=grad_inexact_dtypes),
-    grad_test_spec(lax.floor, nargs=1, order=2, rng_factory=jtu.rand_default,
+    grad_test_spec(lax.floor, nargs=1, order=2,
+                   rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
                    dtypes=grad_float_dtypes),
-    grad_test_spec(lax.ceil, nargs=1, order=2, rng_factory=jtu.rand_default,
+    grad_test_spec(lax.ceil, nargs=1, order=2,
+                   rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
                    dtypes=grad_float_dtypes),
-    grad_test_spec(lax.round, nargs=1, order=2, rng_factory=jtu.rand_default,
+    grad_test_spec(lax.round, nargs=1, order=2,
+                   rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
                    dtypes=grad_float_dtypes),
 
     grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small,
@@ -1908,29 +1911,22 @@ def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
     check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}".format(
-          jtu.format_shape_dtype_string(min_shape, dtype),
-          jtu.format_shape_dtype_string(operand_shape, dtype),
-          jtu.format_shape_dtype_string(max_shape, dtype)),
-       "min_shape": min_shape, "operand_shape": operand_shape,
-       "max_shape": max_shape, "dtype": dtype, "rng_factory": rng_factory}
-      for min_shape, operand_shape, max_shape in [
-          [(), (), ()],
-          [(), (2, 3), ()],
-          [(2, 3), (2, 3), (2, 3)],
-      ]
-      # TODO(phawkins): this test fails for bfloat16.
-      for dtype in [t for t in float_dtypes if t != dtypes.bfloat16]
+      {"testcase_name": "_shape={}".format(
+          jtu.format_shape_dtype_string(shape, dtype)),
+       "shape": shape, "dtype": dtype,
+       "rng_factory": rng_factory}
+      for shape in [(), (2, 3)]
+      for dtype in grad_float_dtypes
       for rng_factory in [jtu.rand_default]))
-  def testClampGrad(self, min_shape, operand_shape, max_shape, dtype, rng_factory):
+  def testClampGrad(self, shape, dtype, rng_factory):
     rng = rng_factory()
-    tol = {dtypes.bfloat16: 1e-1, onp.float16: 1e-1, onp.float32: 1e-2}
-    shapes = [min_shape, operand_shape, max_shape]
-    min, operand, max = (rng(shape, dtype) for shape in shapes)
-    min, max = onp.minimum(min, max), onp.maximum(min, max)  # broadcast
-    eps = 1e-1 if dtypes.finfo(dtype).bits == 16 else 1e-2
-    check_grads(lax.clamp, (min, operand, max), 2, ["fwd", "rev"], tol, tol,
-                eps=eps)
+    operand = rng(shape, dtype)
+    low = operand - dtype(10)
+    high = operand + dtype(10)
+    # Avoids points near the boundary where the gradient may be inaccurate.
+    check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
+    check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
+    check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
@@ -2329,16 +2325,11 @@ def testTransposeGrad(self, shape, dtype, perm, rng_factory):
        .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
        "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
        "dims": dims, "rng_factory": rng_factory}
-      for init_val, op, dtypes in [
-          (0, lax.add, inexact_dtypes),
-          # Precision problems for float16 tests.
-          (-onp.inf, lax.max,
-           [t for t in inexact_dtypes if t not in [onp.float16, dtypes.bfloat16]]),
-          (onp.inf, lax.min,
-           [t for t in inexact_dtypes if t not in [onp.float16, dtypes.bfloat16]]),
-          # The mul test overflows the range of a float16.
-          (1, lax.mul,
-           [t for t in inexact_dtypes if t not in [onp.float16, dtypes.bfloat16]]),
+      for init_val, op, dtypes, rng_factory in [
+          (0, lax.add, inexact_dtypes, jtu.rand_default),
+          (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
+          (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
+          (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
       ]
       for dtype in dtypes
       for shape, dims in [
@@ -2348,14 +2339,13 @@ def testTransposeGrad(self, shape, dtype, perm, rng_factory):
           [(3, 4, 5), (0, 2)],
           [(3, 4, 5), (0, 1, 2)],
           [(3, 1), (1,)],
-      ]
-      for rng_factory in [jtu.rand_default]))
+      ]))
   def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
     rng = rng_factory()
     if jtu.device_under_test() == "tpu" and op is lax.mul:
       raise SkipTest("unimplemented case")
-    tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 4e-2,
-           onp.float64: 1e-3, onp.complex64: 1e-2}
+    tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
+           onp.float64: 1e-3, onp.complex64: 1e-1}
     operand = rng(shape, dtype)
     init_val = onp.asarray(init_val, dtype=dtype)
     reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
@@ -2369,18 +2359,18 @@ def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
        .format(op.__name__, onp.dtype(dtype).name, padding),
        "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
        "rng_factory": rng_factory}
-      for init_val, op, dtypes, rng in [
-          (0, lax.add, float_dtypes, jtu.rand_small),
-          (-onp.inf, lax.max, [onp.float32], jtu.rand_default),
-          (onp.inf, lax.min, [onp.float32], jtu.rand_default),
+      for init_val, op, dtypes, rng_factory in [
+          (0, lax.add, grad_float_dtypes, jtu.rand_small),
+          (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
+          (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
       ]
       for dtype in dtypes
-      for padding in ["VALID", "SAME"]
-      for rng_factory in [jtu.rand_default]))
+      for padding in ["VALID", "SAME"]))
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
+  @jtu.ignore_warning(category=UserWarning,
+                      message="Using reduced precision for gradient.*")
   def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
     rng = rng_factory()
-    tol = {onp.float16: 1e-1, onp.float32: 1e-3}
     init_val = onp.asarray(init_val, dtype=dtype)
 
     # We need this conditional and the corresponding loop logic to be in the
Expand Down Expand Up @@ -2413,11 +2403,13 @@ def fun(operand):
operand = rng(shape, dtype)
if op is lax.add:
eps = 1.
tol = None
else:
# this test can fail if there are duplicates in operand
self.assertEqual(onp.unique(operand).size, operand.size,
msg="test requires operand elements to be unique.")
eps = 1e-2
tol = {onp.float16: 1e-1, onp.float32: 2e-2, onp.float64: 2e-2}
check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
eps)

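The rewritten testClampGrad above relies on the same idea. A second sketch, also not part of the commit: because the test builds low = operand - 10 and high = operand + 10, any two arrays handed to lax.clamp differ elementwise by at least 10, vastly more than the finite-difference step, so clamp never switches branches inside the eps ball and its gradients are exact constants there.

# Minimal sketch, not part of the commit: the bound-separation trick used by
# the new testClampGrad. With the bounds 10 units away from the operand,
# clamp never changes branch inside the eps ball used for gradient checking.
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.3, -1.2, 2.5], dtype=jnp.float32)
low, high = x - 10.0, x + 10.0

# Gradient with respect to the middle (operand) argument: all ones, since
# low < x < high everywhere.
print(jax.grad(lambda t: lax.clamp(low, t, high).sum())(x))    # [1. 1. 1.]

# Gradient with respect to a bound that is never active: all zeros.
print(jax.grad(lambda lo: lax.clamp(lo, x, high).sum())(low))  # [0. 0. 0.]

The three check_grads calls in the test pass the same arrays in different argument orders, so each of clamp's three arguments gets checked while the evaluation point stays far from any point where the gradient jumps.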
