Skip to content

Commit

Permalink
tweak broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 10, 2021
1 parent 9577860 commit d11bba9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,9 +2134,9 @@ def f():
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()

# TODO(jakevdp): re-enable this if possible.
@unittest.skipIf(True, "broken by convert_element_type change.")
def test_xla_computation_zeros_doesnt_device_put(self):
raise SkipTest("broken test") # TODO(mattjj): fix

if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")

Expand Down
12 changes: 8 additions & 4 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,13 +967,17 @@ def test_prng_errors(self):
with self.assertRaises(OverflowError):
api.jit(random.PRNGKey)(seed)

def test_random_split_doesnt_device_put(self):
# TODO(mattjj): Enable this after fixing convert_element_type.
raise SkipTest("Broken by convert_element_type.")
def test_random_split_doesnt_device_put_during_tracing(self):
raise SkipTest("broken test") # TODO(mattjj): fix

if not config.omnistaging_enabled:
raise SkipTest("test is omnistaging-specific")

key = random.PRNGKey(1)
with jtu.count_device_put() as count:
api.jit(random.split)(key)
key, _ = random.split(key, 2)
self.assertEqual(count[0], 0)
self.assertEqual(count[0], 1) # 1 for the argument device_put call



Expand Down

0 comments on commit d11bba9

Please sign in to comment.