From 69da8393581a2eedc9ef9cee05313b8e93ac7d76 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 25 Sep 2023 09:29:22 -0700 Subject: [PATCH] Remove test code that checks for the se_tpu runtime. This runtime no longer exists. PiperOrigin-RevId: 568242078 --- jax/_src/test_util.py | 8 -------- tests/aot_test.py | 3 --- tests/cache_key_test.py | 2 -- 3 files changed, 13 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 87515f9bcddc..9ebd1965e537 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -323,14 +323,6 @@ def is_cloud_tpu(): return 'libtpu' in xla_bridge.get_backend().platform_version -def is_se_tpu(): - return ( - is_cloud_tpu() and not xla_bridge.using_pjrt_c_api() - ) or xla_bridge.get_backend().platform_version.startswith( - 'StreamExecutor TPU' - ) - - def is_device_tpu_v4(): return jax.devices()[0].device_kind == "TPU v4" diff --git a/tests/aot_test.py b/tests/aot_test.py index 7159d27d2a61..e5dda3016e6a 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -45,9 +45,6 @@ class JaxAotTest(jtu.JaxTestCase): @jtu.run_on_devices('tpu') def test_pickle_pjit_lower(self): - if jtu.is_se_tpu(): - raise unittest.SkipTest('StreamExecutor not supported.') - def fun(x): return x * x diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index a3f904b270e1..9220e75dcd49 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -136,8 +136,6 @@ def test_serialized_compile_options(self): ) @jtu.skip_on_devices("cpu") def test_hash_accelerator_devices(self): - if jtu.is_se_tpu(): - raise unittest.SkipTest("StreamExecutor not supported.") if xla_bridge.using_pjrt_c_api(): # TODO(b/290248051): expose PjRtTopologyDesc in PjRt C API. raise unittest.SkipTest("PjRt C API not yet supported.")