Skip to content

Commit

Permalink
Remove the skip_on_xla_cpu_mlir decorator.
Browse files Browse the repository at this point in the history
We no longer test this variant in CI, so we don't need code to skip it.

PiperOrigin-RevId: 568219651
  • Loading branch information
hawkinsp authored and jax authors committed Sep 25, 2023
1 parent d3f5e7f commit 5aaa15d
Show file tree
Hide file tree
Showing 5 changed files with 0 additions and 25 deletions.
13 changes: 0 additions & 13 deletions jax/_src/test_util.py
Expand Up @@ -397,19 +397,6 @@ def undo():
return undo


def skip_on_xla_cpu_mlir(test_method):
"""A decorator to skip tests when MLIR lowering is enabled."""
@functools.wraps(test_method)
def test_method_wrapper(self, *args, **kwargs):
xla_flags = os.getenv('XLA_FLAGS') or ''
if '--xla_cpu_use_xla_runtime' in xla_flags:
test_name = getattr(test_method, '__name__', '[unknown test]')
raise unittest.SkipTest(
f'{test_name} not supported on XLA:CPU MLIR')
return test_method(self, *args, **kwargs)
return test_method_wrapper


def skip_on_flag(flag_name, skip_value):
"""A decorator for test methods to skip the test when flags are set."""
def skip(test_method): # pylint: disable=missing-docstring
Expand Down
3 changes: 0 additions & 3 deletions tests/api_test.py
Expand Up @@ -1204,7 +1204,6 @@ def test_jit_lower_compile_as_text(self):
self.assertIsInstance(f.as_text(), (str, type(None)))
self.assertIsInstance(g.as_text(), (str, type(None)))

@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_cost_analysis(self):
# TODO(b/261771737): add support for uncompiled cost analysis in C API.
if "PJRT C API" in xla_bridge.get_backend().platform_version:
Expand All @@ -1214,14 +1213,12 @@ def test_jit_lower_cost_analysis(self):
f.cost_analysis() # doesn't raise
g.cost_analysis() # doesn't raise

@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_compile_cost_analysis(self):
f = self.jit(lambda x: x).lower(1.).compile()
g = self.jit(lambda x: x + 4).lower(1.).compile()
self.assertIsNotNone(f.cost_analysis())
self.assertIsNotNone(g.cost_analysis())

@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_compile_memory_analysis(self):
f = self.jit(lambda x: x).lower(1.).compile()
g = self.jit(lambda x: x + 4).lower(1.).compile()
Expand Down
3 changes: 0 additions & 3 deletions tests/pjit_test.py
Expand Up @@ -1118,7 +1118,6 @@ def f(x, y):
self.assertIsInstance(f.as_text(), (str, type(None)))

@jtu.with_mesh([('x', 2), ('y', 2)])
@jtu.skip_on_xla_cpu_mlir
def testLowerCostAnalysis(self):
@partial(pjit,
in_shardings=P(('x', 'y'),),
Expand All @@ -1132,7 +1131,6 @@ def f(x, y):
f.cost_analysis() # doesn't raise

@jtu.with_mesh([('x', 2), ('y', 2)])
@jtu.skip_on_xla_cpu_mlir
def testLowerCompileCostAnalysis(self):
@partial(pjit,
in_shardings=P(('x', 'y'),),
Expand All @@ -1146,7 +1144,6 @@ def f(x, y):
f.cost_analysis() # doesn't raise

@jtu.with_mesh([('x', 2), ('y', 2)])
@jtu.skip_on_xla_cpu_mlir
def testLowerCompileMemoryAnalysis(self):
@partial(pjit,
in_shardings=P(('x', 'y'),),
Expand Down
3 changes: 0 additions & 3 deletions tests/pmap_test.py
Expand Up @@ -300,23 +300,20 @@ def testLowerCompileAsText(self):
f = f.lower(x).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))

@jtu.skip_on_xla_cpu_mlir
def testLowerCostAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
f.cost_analysis() # doesn't raise

@jtu.skip_on_xla_cpu_mlir
def testLowerCompileCostAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
f.cost_analysis() # doesn't raise

@jtu.skip_on_xla_cpu_mlir
def testLowerCompileMemoryAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
Expand Down
3 changes: 0 additions & 3 deletions tests/xmap_test.py
Expand Up @@ -765,7 +765,6 @@ def testLowerCompileAsText(self):
f = f.lower(x).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))

@jtu.skip_on_xla_cpu_mlir
def testLowerCostAnalysis(self):
# TODO(b/261771737): add support for uncompiled cost analysis in C API.
if "PJRT C API" in xla_bridge.get_backend().platform_version:
Expand All @@ -775,14 +774,12 @@ def testLowerCostAnalysis(self):
f = f.lower(x)
f.cost_analysis() # doesn't raise

@jtu.skip_on_xla_cpu_mlir
def testLowerCompileCostAnalysis(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
f = f.lower(x).compile()
f.cost_analysis() # doesn't raise

@jtu.skip_on_xla_cpu_mlir
def testLowerCompileMemoryAnalysis(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
Expand Down

0 comments on commit 5aaa15d

Please sign in to comment.