From 5aaa15df845aeedf999e6197d6d7bbdcf1f4e5c4 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 25 Sep 2023 07:59:57 -0700 Subject: [PATCH] Remove the skip_on_xla_cpu_mlir decorator. We no longer test this variant in CI, so we don't need code to skip it. PiperOrigin-RevId: 568219651 --- jax/_src/test_util.py | 13 ------------- tests/api_test.py | 3 --- tests/pjit_test.py | 3 --- tests/pmap_test.py | 3 --- tests/xmap_test.py | 3 --- 5 files changed, 25 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 423242cef69e..87515f9bcddc 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -397,19 +397,6 @@ def undo(): return undo -def skip_on_xla_cpu_mlir(test_method): - """A decorator to skip tests when MLIR lowering is enabled.""" - @functools.wraps(test_method) - def test_method_wrapper(self, *args, **kwargs): - xla_flags = os.getenv('XLA_FLAGS') or '' - if '--xla_cpu_use_xla_runtime' in xla_flags: - test_name = getattr(test_method, '__name__', '[unknown test]') - raise unittest.SkipTest( - f'{test_name} not supported on XLA:CPU MLIR') - return test_method(self, *args, **kwargs) - return test_method_wrapper - - def skip_on_flag(flag_name, skip_value): """A decorator for test methods to skip the test when flags are set.""" def skip(test_method): # pylint: disable=missing-docstring diff --git a/tests/api_test.py b/tests/api_test.py index 3788e9b7796c..d289e2297def 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1204,7 +1204,6 @@ def test_jit_lower_compile_as_text(self): self.assertIsInstance(f.as_text(), (str, type(None))) self.assertIsInstance(g.as_text(), (str, type(None))) - @jtu.skip_on_xla_cpu_mlir def test_jit_lower_cost_analysis(self): # TODO(b/261771737): add support for uncompiled cost analysis in C API. if "PJRT C API" in xla_bridge.get_backend().platform_version: @@ -1214,14 +1213,12 @@ def test_jit_lower_cost_analysis(self): f.cost_analysis() # doesn't raise g.cost_analysis() # doesn't raise - @jtu.skip_on_xla_cpu_mlir def test_jit_lower_compile_cost_analysis(self): f = self.jit(lambda x: x).lower(1.).compile() g = self.jit(lambda x: x + 4).lower(1.).compile() self.assertIsNotNone(f.cost_analysis()) self.assertIsNotNone(g.cost_analysis()) - @jtu.skip_on_xla_cpu_mlir def test_jit_lower_compile_memory_analysis(self): f = self.jit(lambda x: x).lower(1.).compile() g = self.jit(lambda x: x + 4).lower(1.).compile() diff --git a/tests/pjit_test.py b/tests/pjit_test.py index a4d7f020c7be..ca4addd5ed55 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1118,7 +1118,6 @@ def f(x, y): self.assertIsInstance(f.as_text(), (str, type(None))) @jtu.with_mesh([('x', 2), ('y', 2)]) - @jtu.skip_on_xla_cpu_mlir def testLowerCostAnalysis(self): @partial(pjit, in_shardings=P(('x', 'y'),), @@ -1132,7 +1131,6 @@ def f(x, y): f.cost_analysis() # doesn't raise @jtu.with_mesh([('x', 2), ('y', 2)]) - @jtu.skip_on_xla_cpu_mlir def testLowerCompileCostAnalysis(self): @partial(pjit, in_shardings=P(('x', 'y'),), @@ -1146,7 +1144,6 @@ def f(x, y): f.cost_analysis() # doesn't raise @jtu.with_mesh([('x', 2), ('y', 2)]) - @jtu.skip_on_xla_cpu_mlir def testLowerCompileMemoryAnalysis(self): @partial(pjit, in_shardings=P(('x', 'y'),), diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 86c190fe4046..eae1d6130936 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -300,7 +300,6 @@ def testLowerCompileAsText(self): f = f.lower(x).compile() self.assertIsInstance(f.as_text(), (str, type(None))) - @jtu.skip_on_xla_cpu_mlir def testLowerCostAnalysis(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) @@ -308,7 +307,6 @@ def testLowerCostAnalysis(self): f = f.lower(x) f.cost_analysis() # doesn't raise - @jtu.skip_on_xla_cpu_mlir def testLowerCompileCostAnalysis(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) @@ -316,7 +314,6 @@ def testLowerCompileCostAnalysis(self): f = f.lower(x).compile() f.cost_analysis() # doesn't raise - @jtu.skip_on_xla_cpu_mlir def testLowerCompileMemoryAnalysis(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 534de67577e2..dc424c7fdbb4 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -765,7 +765,6 @@ def testLowerCompileAsText(self): f = f.lower(x).compile() self.assertIsInstance(f.as_text(), (str, type(None))) - @jtu.skip_on_xla_cpu_mlir def testLowerCostAnalysis(self): # TODO(b/261771737): add support for uncompiled cost analysis in C API. if "PJRT C API" in xla_bridge.get_backend().platform_version: @@ -775,14 +774,12 @@ def testLowerCostAnalysis(self): f = f.lower(x) f.cost_analysis() # doesn't raise - @jtu.skip_on_xla_cpu_mlir def testLowerCompileCostAnalysis(self): f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...]) x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2)) f = f.lower(x).compile() f.cost_analysis() # doesn't raise - @jtu.skip_on_xla_cpu_mlir def testLowerCompileMemoryAnalysis(self): f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...]) x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))