Skip to content

Commit

Permalink
Fix most test failures under NumPy 1.21.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed Jun 22, 2021
1 parent f885366 commit 75c9bf0
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def _normal(key, shape, dtype) -> jnp.ndarray:
@partial(jit, static_argnums=(1, 2))
def _normal_real(key, shape, dtype) -> jnp.ndarray:
_check_shape("normal", shape)
lo = np.nextafter(np.array(-1., dtype), 0., dtype=dtype)
lo = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype)
hi = np.array(1., dtype)
u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type]
return np.array(np.sqrt(2), dtype) * lax.erf_inv(u)
Expand Down
5 changes: 4 additions & 1 deletion jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,10 @@ def _safe_subtract(x, y, *, dtype):
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))

def scalar_mul(xs, a):
return tree_map(lambda x: np.multiply(x, a, dtype=_dtype(x)), xs)
def mul(x):
dtype = _dtype(x)
return np.multiply(x, np.array(a, dtype=dtype), dtype=dtype)
return tree_map(mul, xs)


def rand_like(rng, x):
Expand Down
11 changes: 5 additions & 6 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def test_jit_kwargs(self, one, two, three):
self.jit.cache_clear()

def f(x, y, z):
print(x, y, z)
side.append(None)
return 100 * x + 10 * y + z

Expand Down Expand Up @@ -1158,7 +1157,7 @@ def inner(y):
self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)))

def test_linear_transpose_abstract(self):
x = types.SimpleNamespace(shape=(3,), dtype=np.float32)
x = types.SimpleNamespace(shape=(3,), dtype=np.dtype(np.float32))
y = jnp.arange(3, dtype=np.float32)
transpose_fun = api.linear_transpose(lambda x: 2 * x, x)
z, = transpose_fun(y)
Expand Down Expand Up @@ -1399,7 +1398,7 @@ def fun(A, b, x):
class MyArgArray(object):
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
self.dtype = np.dtype(dtype)

A = MyArgArray((3, 4), jnp.float32)
b = MyArgArray((5,), jnp.float32)
Expand All @@ -1426,7 +1425,7 @@ def fun(x, y):
class MyArgArray(object):
def __init__(self, shape, dtype, named_shape):
self.shape = shape
self.dtype = dtype
self.dtype = jnp.dtype(dtype)
self.named_shape = named_shape

x = MyArgArray((3, 2), jnp.float32, {'i': 10})
Expand Down Expand Up @@ -3194,7 +3193,7 @@ def test_scalar_literals(self):

def test_abstract_inputs(self):
jaxpr = api.make_jaxpr(lambda x: x + 2.)(
types.SimpleNamespace(shape=(), dtype=np.float32))
types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)))
self.assertEqual(jaxpr.in_avals[0].shape, ())
self.assertEqual(jaxpr.in_avals[0].dtype, np.float32)

Expand Down Expand Up @@ -3262,7 +3261,7 @@ def f(x):
return x - lax.psum(x, 'i')

x = types.SimpleNamespace(
shape=(2, 3), dtype=jnp.float32, named_shape={'i': 10})
shape=(2, 3), dtype=jnp.dtype(jnp.float32), named_shape={'i': 10})
jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x)
named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars]
self.assertEqual(named_shapes, [{'i': 10}, {}])
Expand Down
9 changes: 8 additions & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,13 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
op_record("true_divide", 2, all_dtypes, all_shapes, jtu.rand_nonzero,
["rev"], inexact=True),
op_record("ediff1d", 3, [np.int32], all_shapes, jtu.rand_default, []),
op_record("unwrap", 1, float_dtypes, nonempty_nonscalar_array_shapes,
# TODO(phawkins): np.unwrap does not correctly promote its default period
# argument under NumPy 1.21 for bfloat16 inputs. It works fine if we
# explicitly pass a bfloat16 value that does not need promition. We should
# probably add a custom test harness for unwrap that tests the period
# argument anyway.
op_record("unwrap", 1, [t for t in float_dtypes if t != dtypes.bfloat16],
nonempty_nonscalar_array_shapes,
jtu.rand_default, ["rev"],
# numpy.unwrap always returns float64
check_dtypes=False,
Expand Down Expand Up @@ -5512,6 +5518,7 @@ def testWrappedSignaturesMatch(self):
'ones': ['order', 'like'],
'ones_like': ['subok', 'order'],
'tri': ['like'],
'unwrap': ['period'],
'zeros_like': ['subok', 'order']
}

Expand Down
2 changes: 1 addition & 1 deletion tests/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def tree_unflatten(cls, meta, data):
STRS = []
for tree_str in TREE_STRINGS:
tree_str = re.escape(tree_str)
tree_str = tree_str.replace("__main__", "(__main__|tree_util_test)")
tree_str = tree_str.replace("__main__", ".*")
STRS.append(tree_str)
TREE_STRINGS = STRS

Expand Down

0 comments on commit 75c9bf0

Please sign in to comment.