Skip to content

Commit

Permalink
Fix failing device_put rules tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 525179481
  • Loading branch information
romanngg committed Apr 19, 2023
1 parent 055ff43 commit ed115c7
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/rules_test.py
Expand Up @@ -482,6 +482,7 @@ def _compare_jacobians(self, j_fwd, j_rev, j_rule, primitive):
else:
if primitive == lax.reshape_p:
# Reshape Jacobian is special-case defined as identity.
j_rule: np.ndarray
j_rule = j_rule.reshape(j_fwd.shape)

self.assertAllClose(j_fwd, j_rev)
Expand Down Expand Up @@ -581,10 +582,12 @@ def _test_primitive(
def test_unary(self, primitive: Optional[Primitive], shape, dtype, params):
if primitive == lax.device_put_p:
# Can't instantiate devices at test generation time; using subtests.
for device in [None] + jax.devices() + jax.devices('cpu'):
with self.subTest(device=device):
params = {'device': device}
self._test_primitive(primitive, [shape], dtype, params)
devices = [None] + jax.devices() + jax.devices('cpu')
for device in devices:
for src in devices:
with self.subTest(device=device, src=src):
params = {'device': device, 'src': src}
self._test_primitive(primitive, [shape], dtype, params)

else:
self._test_primitive(primitive, [shape], dtype, params)
Expand Down

0 comments on commit ed115c7

Please sign in to comment.