Fix some test failures under H100.
It seems that under H100, the default matmul precision is a little lower than it historically was on A100. Opt out of tensorcore matmuls, when they are enabled, for tests that fail due to precision issues.

Happily, this also allows us to remove a number of TPU special cases for the same reason.

PiperOrigin-RevId: 539101155
hawkinsp authored and jax authors committed Jun 9, 2023
1 parent 00f2a8c commit 803c729
Showing 5 changed files with 34 additions and 43 deletions.
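For context, the fix swaps per-backend tolerance overrides (and a @jtu.skip_on_devices('tpu') marker) for jax.default_matmul_precision("float32"), which forces full-float32 matmuls for the decorated test. Below is a minimal, hedged sketch of that pattern, not taken from the commit; the function and array names are illustrative only.

# Hedged sketch (not from this commit): illustrative names throughout.
# jax.default_matmul_precision can be used as a decorator or a context manager
# to request full float32 matmul precision, opting out of lower-precision
# tensorcore (or TPU bfloat16) accumulation for the enclosed computation.
import jax
import jax.numpy as jnp
import numpy as np

@jax.default_matmul_precision("float32")
def full_precision_matmul(a, b):
  # Inside this function, jnp.dot defaults to "float32" precision.
  return jnp.dot(a, b)

a = np.random.randn(16, 16).astype(np.float32)
b = np.random.randn(16, 16).astype(np.float32)
np.testing.assert_allclose(full_precision_matmul(a, b), np.dot(a, b), rtol=1e-5)

# Equivalent context-manager form for scoping a single computation:
with jax.default_matmul_precision("float32"):
  out = jnp.dot(a, b)

The decorator form matches how the updated tests apply the override; the context-manager form is equivalent for scoping a single computation.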
14 changes: 6 additions & 8 deletions tests/batching_test.py
@@ -49,6 +49,7 @@ def testConstantFunction(self):
expected = 3 * np.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)

+@jax.default_matmul_precision("float32")
def testNestedBatchingMatMat(self):
matvec = vmap(jnp.vdot, in_axes=(0, None))
matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)
@@ -59,9 +60,7 @@ def testNestedBatchingMatMat(self):

ans = matmat(A, B)
expected = np.dot(A, B)
-self.assertAllClose(
-    ans, expected, check_dtypes=False,
-    rtol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None)
+self.assertAllClose(ans, expected, check_dtypes=False)

jaxpr = make_jaxpr(matmat)(A, B)
self.assertLen(jaxpr.jaxpr.eqns, 1)
@@ -98,6 +97,7 @@ def loss(params, data):
self.assertEqual(dW.shape, (batch_size,) + W.shape)
self.assertEqual(db.shape, (batch_size,) + b.shape)

+@jax.default_matmul_precision("float32")
def testJacobians(self):
def jacbwd(f, x):
y, pullback = vjp(f, x)
@@ -118,8 +118,7 @@ def jacfwd(f, x):
f = lambda x: jnp.tanh(jnp.dot(A, x) + b)

x = R(3)
-self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False,
-    rtol={np.float32:1e-2} if jtu.device_under_test() == "tpu" else None)
+self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)

def testBatchOfCompile(self):
side = []
@@ -201,6 +200,7 @@ def testNpGtrThan(self):
expected_ans = x > 1.0
self.assertAllClose(ans, expected_ans)

+@jax.default_matmul_precision("float32")
def testNpMaximumPerExampleGrad(self):
R = self.rng().randn
x = R(10, 5)
@@ -218,9 +218,7 @@ def testNpMaximumPerExampleGrad(self):
jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex)
expected_ans = jnp.transpose(expected_ans)

-self.assertAllClose(
-    ans[i], expected_ans, check_dtypes=False,
-    rtol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None)
+self.assertAllClose(ans[i], expected_ans, check_dtypes=False)

def testDotGeneral(self):
R = self.rng().randn
4 changes: 2 additions & 2 deletions tests/lax_numpy_einsum_test.py
@@ -291,6 +291,7 @@ def test_einsum_path(self):
C = self.rng().rand(10, 10)
np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, optimize='greedy')

+@jax.default_matmul_precision("float32")
def test_einsum_kpmurphy_example(self):
# code from an email with @murphyk
N, C, D, K, T = 2, 3, 4, 5, 6
@@ -309,9 +310,8 @@ def test_einsum_kpmurphy_example(self):
L[n,c] = s

path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
-rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path),
-    check_dtypes=False, rtol=rtol)
+    check_dtypes=False)

def test_contraction_broadcasting(self):
r = self.rng()
2 changes: 1 addition & 1 deletion tests/lax_scipy_sparse_test.py
@@ -420,7 +420,7 @@ def test_gmres_pytree(self):
self.assertAlmostEqual(expected["a"], actual["a"], places=5)
self.assertAlmostEqual(expected["b"], actual["b"], places=5)

-@jtu.skip_on_devices('tpu')
+@jax.default_matmul_precision("float32")
def test_gmres_matmul(self):
A = CustomOperator(2 * jnp.eye(3))
b = jnp.arange(9.0).reshape(3, 3)
5 changes: 3 additions & 2 deletions tests/pmap_test.py
@@ -1724,6 +1724,7 @@ def f(args_list):
self.assertAllClose(f([np.array([i] * ndevices) for i in range(500)]),
jnp.array([sum(vals)] * ndevices))

+@jax.default_matmul_precision("float32")
def testPostProcessMap2(self):
# code from https://github.com/google/jax/issues/2787
def vv(x, y):
@@ -1743,8 +1744,8 @@ def distributed_matrix_vector(x, y):
y = random.normal(key, (10, 50, 1))
result = batched_mvm(y)
expected = jnp.einsum('ij,njk->nik', x, y)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
-self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)
+self.assertAllClose(result, expected, check_dtypes=False, atol=1e-3,
+    rtol=1e-3)

@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
52 changes: 22 additions & 30 deletions tests/xmap_test.py
@@ -1348,6 +1348,7 @@ def f(x, y):
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))

@jtu.with_mesh([('r1', 2)])
+@jax.default_matmul_precision("float32")
def testPdotBatchingShardUncontractedDim(self):
def f(x, y):
return lax.pdot(x, y, 'i')
@@ -1375,6 +1376,7 @@ def f(x, y):
for axis_resources, mesh_data in s(schedules_from_pdot_spec(
pdot_spec, lhs_shape, rhs_shape))
)))
+@jax.default_matmul_precision("float32")
def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
mesh_data):
rng = jtu.rand_default(self.rng())
@@ -1398,9 +1400,7 @@ def pdot_fun(x, y):
result = fun(lhs, rhs)

expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(result, expected, check_dtypes=False,
-    atol=tol, rtol=tol)
+self.assertAllClose(result, expected, check_dtypes=False)

@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": f"_{next(test_counter)}",
@@ -1412,6 +1412,7 @@ def pdot_fun(x, y):
for axis_resources, mesh_data in s(schedules_from_pdot_spec(
pdot_spec, lhs_shape, rhs_shape))
)))
+@jax.default_matmul_precision("float32")
def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
axis_resources, mesh_data):
rng = jtu.rand_default(self.rng())
@@ -1441,11 +1442,8 @@ def pdot_fun(x, y, out_bar):
with jtu.with_mesh(mesh_data):
lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)

-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
-    atol=tol, rtol=tol)
-self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
-    atol=tol, rtol=tol)
+self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False)
+self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False)

def test_xeinsum_vector_dot(self):
rng = self.rng()
@@ -1465,6 +1463,7 @@ def test_xeinsum_outer_product(self):
expected = np.einsum('i,j->ij', x, y)
self.assertAllClose(out, expected, check_dtypes=True)

+@jax.default_matmul_precision("float32")
def test_xeinsum_matmul(self):
rng = self.rng()
x = rng.randn(3, 4)
@@ -1475,9 +1474,7 @@ def check(spec):
in_axes=(['i', 'j'], ['j', 'k']),
out_axes=['i', 'k'])(x, y)
expected = np.einsum('ij,jk->ik', x, y)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(out, expected, check_dtypes=True,
-    atol=tol, rtol=tol)
+self.assertAllClose(out, expected, check_dtypes=True)
check('{i,j},{j,k}->{i,k}')
check('{i,j},{k,j}->{k,i}') # order of named axes in the spec doesn't matter!
check('{j},{k,j}->{k}')
@@ -1500,14 +1497,14 @@ def test_xeinsum_no_named_axes_batch_vector_dot(self):
expected = np.einsum('ij,ij->i', x, y)
self.assertAllClose(out, expected, check_dtypes=True)

+@jax.default_matmul_precision("float32")
def test_xeinsum_no_named_axes_batch_matmul(self):
rng = np.random.RandomState(0)
x = rng.randn(3, 5, 4)
y = rng.randn(3, 4, 2)
out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
expected = np.einsum('bij,bjk->bik', x, y)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
+self.assertAllClose(out, expected, check_dtypes=True)

def test_xeinsum_no_named_axes_reduce_sum(self):
rng = self.rng()
@@ -1518,15 +1515,16 @@ def test_xeinsum_no_named_axes_reduce_sum(self):
self.assertAllClose(out, expected, check_dtypes=True)


+@jax.default_matmul_precision("float32")
def test_xeinsum_no_named_axes_reduce_and_contract(self):
rng = np.random.RandomState(0)
x = rng.randn(3, 5, 4)
y = rng.randn(2, 4, 2)
out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
expected = np.einsum('bij,cjk->ik', x, y)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
+self.assertAllClose(out, expected, check_dtypes=True)

+@jax.default_matmul_precision("float32")
def test_xeinsum_named_axes_reduce(self):
rng = np.random.RandomState(0)
x = rng.randn(3, 4)
@@ -1537,12 +1535,11 @@ def check(spec):
in_axes=(['i', 'j'], ['k']),
out_axes=['i', 'k'])(x, y)
expected = np.einsum('ij,k->ik', x, y)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(out, expected, check_dtypes=True,
-    atol=tol, rtol=tol)
+self.assertAllClose(out, expected, check_dtypes=True)
check('{i,j},{k}->{i,k}')

@jtu.with_mesh([('x', 2), ('y', 2)])
+@jax.default_matmul_precision("float32")
def test_xeinsum_named_axes_reduce_with_mesh(self):
rng = np.random.RandomState(0)
x = rng.randn(6, 4)
@@ -1554,16 +1551,15 @@ def check(spec):
out_axes=['i', 'k'],
axis_resources={'i': 'x', 'k': 'y'})(x, y)
expected = np.einsum('ij,k->ik', x, y)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(out, expected, check_dtypes=True,
-    atol=tol, rtol=tol)
+self.assertAllClose(out, expected, check_dtypes=True)

check('{i,j},{k}->{i,k}')
check('{i,j},{k}->{k,i}') # order of named axes in the spec doesn't matter!
check('{j,i},{k}->{i,k}')
check('{j,i},{k}->{k,i}')

@jtu.with_mesh([('x', 2), ('y', 2)])
+@jax.default_matmul_precision("float32")
def test_xeinsum_named_axes_batch_matmul_with_mesh(self):
rng = np.random.RandomState(0)
x = rng.randn(8, 3, 4)
@@ -1575,14 +1571,13 @@ def check(spec):
out_axes=['b', 'i', 'k'],
axis_resources={'b': 'x', 'j': 'y'})(x, y)
expected = np.einsum('bij,bjk->bik', x, y)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(out, expected, check_dtypes=True,
-    atol=tol, rtol=tol)
+self.assertAllClose(out, expected, check_dtypes=True)

check('{b,i,j},{b,j,k}->{b,i,k}')
check('{j,i,b},{j,b,k}->{i,b,k}') # order of named axes in the spec doesn't matter!

@jtu.with_mesh([('x', 2), ('y', 2)])
+@jax.default_matmul_precision("float32")
def test_xeinsum_named_axes_unary_reduce_with_mesh(self):
rng = np.random.RandomState(0)
x = rng.randn(8, 6, 4)
@@ -1593,15 +1588,14 @@ def check(spec):
out_axes=['b'],
axis_resources={'b': 'x', 'i': 'y'})(x)
expected = np.einsum('bij->b', x)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(out, expected, check_dtypes=True,
-    atol=tol, rtol=tol)
+self.assertAllClose(out, expected, check_dtypes=True)

check('{b,i,j}->{b}')
check('{b,j,i}->{b}') # order of named axes in the spec doesn't matter!
check('{i,j,b}->{b}')

@jtu.with_mesh([('x', 2), ('y', 2)])
+@jax.default_matmul_precision("float32")
def test_xeinsum_mixed_axes_unary_reduce_with_mesh(self):
rng = np.random.RandomState(0)
x = rng.randn(8, 6, 4, 5)
@@ -1612,9 +1606,7 @@ def check(spec):
out_axes=['b', ...],
axis_resources={'b': 'x', 'i': 'y'})(x)
expected = np.einsum('bijk->bk', x)
-tol = 1e-1 if jtu.device_under_test() == "tpu" else None
-self.assertAllClose(out, expected, check_dtypes=True,
-    atol=tol, rtol=tol)
+self.assertAllClose(out, expected, check_dtypes=True)

check('jk{i,b}->k{b}')
