Skip to content

Commit

Permalink
Add jax_array coverage to debug_nans_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 478079509
  • Loading branch information
yashk2810 authored and jax authors committed Sep 30, 2022
1 parent ec41de2 commit fb8558c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
6 changes: 1 addition & 5 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,9 @@ jax_test(
srcs = ["custom_object_test.py"],
)

py_test(
jax_test(
name = "debug_nans_test",
srcs = ["debug_nans_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)

py_test(
Expand Down
2 changes: 1 addition & 1 deletion tests/debug_nans_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def testPjit(self):
if jax.device_count() < 2:
raise SkipTest("test requires >=2 devices")

p = jax.experimental.PartitionSpec('x')
p = pjit.PartitionSpec('x')
f = pjit.pjit(lambda x: 0. / x,
in_axis_resources=p,
out_axis_resources=p)
Expand Down

0 comments on commit fb8558c

Please sign in to comment.