Skip to content

Commit

Permalink
Remove use_stablehlo as minimum mlir_api_version >= 43
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 512176274
  • Loading branch information
yashk2810 authored and jax authors committed Feb 24, 2023
1 parent aa5e229 commit d84ac22
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 37 deletions.
17 changes: 4 additions & 13 deletions jax/_src/dispatch.py
Expand Up @@ -56,7 +56,6 @@
from jax._src.interpreters import batching
from jax._src.interpreters import xla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax._src.lib import pmap_lib
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -993,20 +992,12 @@ def hlo(self) -> xc.XlaComputation:
use_tuple_args=self.compile_args["tuple_args"])

def mhlo(self) -> ir.Module:
if use_stablehlo:
return super().mhlo()
else:
if self.is_trivial():
raise ValueError("A trivial computation has no MHLO")
return self._hlo
return super().mhlo()

def stablehlo(self) -> ir.Module:
if use_stablehlo:
if self.is_trivial():
raise ValueError("A trivial computation has no StableHLO")
return self._hlo
else:
return super().stablehlo()
if self.is_trivial():
raise ValueError("A trivial computation has no StableHLO")
return self._hlo

def compile(self) -> XlaCompiledComputation:
if self._executable is None:
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/parallel.py
Expand Up @@ -37,7 +37,7 @@
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, use_stablehlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.util import (
unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis)
Expand Down Expand Up @@ -990,7 +990,7 @@ def _all_to_all_lowering(ctx, x, *,
else:
other_args = {}
return hlo.AllToAllOp(
x if use_stablehlo else [x],
x,
split_dimension=mlir.i64_attr(split_axis),
concat_dimension=mlir.i64_attr(concat_axis),
split_count=mlir.i64_attr(split_count),
Expand Down
6 changes: 1 addition & 5 deletions jax/_src/lib/mlir/dialects/__init__.py
Expand Up @@ -24,8 +24,4 @@
import jaxlib.mlir.dialects.stablehlo as stablehlo

# Alias that is set up to abstract away the transition from MHLO to StableHLO.
use_stablehlo = xla_client.mlir_api_version >= 42
if use_stablehlo:
import jaxlib.mlir.dialects.stablehlo as hlo
else:
import jaxlib.mlir.dialects.mhlo as hlo # type: ignore[no-redef]
import jaxlib.mlir.dialects.stablehlo as hlo
24 changes: 7 additions & 17 deletions jax/_src/stages.py
Expand Up @@ -44,7 +44,6 @@
from jax._src import traceback_util
from jax._src import util
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax._src.lib import xla_client as xc


Expand Down Expand Up @@ -287,30 +286,21 @@ def hlo(self) -> xc.XlaComputation:

def mhlo(self) -> ir.Module:
"""Return an MHLO representation of this computation."""
if use_stablehlo:
module_str = xla_extension.mlir.stablehlo_to_mhlo(
mlir.module_to_bytecode(self.stablehlo()))
with self.stablehlo().context:
return ir.Module.parse(module_str)
else:
raise NotImplementedError("must override")
module_str = xla_extension.mlir.stablehlo_to_mhlo(
mlir.module_to_bytecode(self.stablehlo()))
with self.stablehlo().context:
return ir.Module.parse(module_str)

def stablehlo(self) -> ir.Module:
"""Return a StableHLO representation of this computation."""
if use_stablehlo:
raise NotImplementedError("must override")
else:
module_str = xla_extension.mlir.mhlo_to_stablehlo(
mlir.module_to_bytecode(self.mhlo()))
with self.mhlo().context:
return ir.Module.parse(module_str)
raise NotImplementedError("must override")

def compile(self) -> Executable:
raise NotImplementedError("must override")

def as_text(self, dialect: Optional[str] = None) -> str:
if dialect is None:
dialect = "stablehlo" if use_stablehlo else "mhlo"
dialect = "stablehlo"
if dialect == "mhlo":
return str(self.mhlo())
elif dialect == "stablehlo":
Expand All @@ -322,7 +312,7 @@ def as_text(self, dialect: Optional[str] = None) -> str:

def compiler_ir(self, dialect: Optional[str] = None) -> Any:
if dialect is None:
dialect = "stablehlo" if use_stablehlo else "mhlo"
dialect = "stablehlo"
if dialect == "mhlo":
return self.mhlo()
elif dialect == "stablehlo":
Expand Down

0 comments on commit d84ac22

Please sign in to comment.