Skip to content

Commit

Permalink
Migrate JAX from producing MHLO to producing StableHLO
Browse files Browse the repository at this point in the history
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. openxla/stablehlo#744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 495786033
  • Loading branch information
Eugene Burmako authored and jax authors committed Dec 21, 2022
1 parent e89b60e commit 6346256
Show file tree
Hide file tree
Showing 15 changed files with 93 additions and 37 deletions.
12 changes: 6 additions & 6 deletions docs/aot.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ are arrays, JAX does the following in order:
their shape and element type).

2. **Lower** this specialized, staged-out computation to the XLA compiler's
input language, MHLO.
input language, StableHLO.

3. **Compile** the lowered HLO program to produce an optimized executable for
the target device (CPU, GPU, or TPU).
Expand All @@ -45,9 +45,9 @@ way. An example:
>>> print(lowered.as_text())
module @jit_f.0 {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = mhlo.constant dense<2> : tensor<i32>
%1 = mhlo.multiply %0, %arg0 : tensor<i32>
%2 = mhlo.add %1, %arg1 : tensor<i32>
%0 = stablehlo.constant dense<2> : tensor<i32>
%1 = stablehlo.multiply %0, %arg0 : tensor<i32>
%2 = stablehlo.add %1, %arg1 : tensor<i32>
return %2 : tensor<i32>
}
}
Expand Down Expand Up @@ -129,8 +129,8 @@ to invoke the resulting compiled function. Continuing with our example above:
>>> print(lowered_with_x.as_text())
module @jit_f.1 {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
%0 = mhlo.constant dense<14> : tensor<i32>
%1 = mhlo.add %0, %arg0 : tensor<i32>
%0 = stablehlo.constant dense<14> : tensor<i32>
%1 = stablehlo.add %0, %arg0 : tensor<i32>
return %1 : tensor<i32>
}
}
Expand Down
18 changes: 15 additions & 3 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
Expand Down Expand Up @@ -967,9 +968,20 @@ def hlo(self) -> xc.XlaComputation:
use_tuple_args=self.compile_args["tuple_args"])

def mhlo(self) -> ir.Module:
if self.is_trivial():
raise ValueError("A trivial computation has no MHLO")
return self._hlo
if use_stablehlo:
return super().mhlo()
else:
if self.is_trivial():
raise ValueError("A trivial computation has no MHLO")
return self._hlo

def stablehlo(self) -> ir.Module:
if use_stablehlo:
if self.is_trivial():
raise ValueError("A trivial computation has no StableHLO")
return self._hlo
else:
return super().stablehlo()

def compile(self) -> XlaCompiledComputation:
if self._executable is None:
Expand Down
5 changes: 3 additions & 2 deletions jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
import jax._src.util as util
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import mlir_api_version
from jax._src.lib.mlir.dialects import hlo, use_stablehlo

unsafe_map, map = map, safe_map # type: ignore

Expand Down Expand Up @@ -955,7 +956,7 @@ def _all_to_all_lowering(ctx, x, *,
else:
other_args = {}
return hlo.AllToAllOp(
[x],
x if use_stablehlo else [x],
split_dimension=mlir.i64_attr(split_axis),
concat_dimension=mlir.i64_attr(concat_axis),
split_count=mlir.i64_attr(split_count),
Expand Down
8 changes: 5 additions & 3 deletions jax/_src/lib/mlir/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import jaxlib.mlir.dialects.stablehlo as stablehlo

# Alias that is set up to abstract away the transition from MHLO to StableHLO.
# At the moment, it points to MHLO, but in the future it will start to
# conditionally and then unconditionally point to StableHLO.
import jaxlib.mlir.dialects.mhlo as hlo
use_stablehlo = xla_client.mlir_api_version >= 42
if use_stablehlo:
import jaxlib.mlir.dialects.stablehlo as hlo
else:
import jaxlib.mlir.dialects.mhlo as hlo
35 changes: 25 additions & 10 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from jax._src import traceback_util
from jax._src import util
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax.interpreters import mlir
from jax.interpreters import xla

Expand Down Expand Up @@ -169,7 +170,8 @@ def compiler_ir(self, dialect: Optional[str] = None) -> Any:
compiler.
Args:
dialect: Optional string specifying a representation dialect (e.g. "mhlo")
dialect: Optional string specifying a representation dialect
(e.g. "stablehlo")
"""
raise NotImplementedError

Expand Down Expand Up @@ -264,20 +266,31 @@ def hlo(self) -> xc.XlaComputation:

def mhlo(self) -> ir.Module:
"""Return an MHLO representation of this computation."""
raise NotImplementedError("must override")
if use_stablehlo:
module_str = xla_extension.mlir.stablehlo_to_mhlo(
mlir.module_to_string(self.stablehlo()))
with mlir.make_ir_context():
return ir.Module.parse(module_str)
else:
raise NotImplementedError("must override")

def stablehlo(self) -> ir.Module:
"""Return a StableHLO representation of this computation."""
module_str = xla_extension.mlir.mhlo_to_stablehlo(
mlir.module_to_string(self.mhlo()))
with mlir.make_ir_context():
return ir.Module.parse(module_str)
if use_stablehlo:
raise NotImplementedError("must override")
else:
module_str = xla_extension.mlir.mhlo_to_stablehlo(
mlir.module_to_string(self.mhlo()))
with mlir.make_ir_context():
return ir.Module.parse(module_str)

def compile(self) -> Executable:
raise NotImplementedError("must override")

def as_text(self, dialect: Optional[str] = None) -> str:
if dialect is None or dialect == "mhlo":
if dialect is None:
dialect = "stablehlo" if use_stablehlo else "mhlo"
if dialect == "mhlo":
return str(self.mhlo())
elif dialect == "stablehlo":
return str(self.stablehlo())
Expand All @@ -287,7 +300,9 @@ def as_text(self, dialect: Optional[str] = None) -> str:
raise ValueError(f"unknown dialect: {dialect}")

def compiler_ir(self, dialect: Optional[str] = None) -> Any:
if dialect is None or dialect == "mhlo":
if dialect is None:
dialect = "stablehlo" if use_stablehlo else "mhlo"
if dialect == "mhlo":
return self.mhlo()
elif dialect == "stablehlo":
return self.stablehlo()
Expand Down Expand Up @@ -579,7 +594,7 @@ def as_text(self, dialect: Optional[str] = None) -> str:
nor reliable serialization. It is relayed directly to external callers.
Args:
dialect: Optional string specifying a lowering dialect (e.g. "mhlo")
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
"""
return self._lowering.as_text(dialect)

Expand All @@ -594,7 +609,7 @@ def compiler_ir(self, dialect: Optional[str] = None) -> Optional[Any]:
runtime.
Args:
dialect: Optional string specifying a lowering dialect (e.g. "mhlo")
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
"""
try:
return self._lowering.compiler_ir(dialect)
Expand Down
30 changes: 25 additions & 5 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import hlo, use_stablehlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
Expand Down Expand Up @@ -1520,7 +1520,16 @@ def hlo(self) -> xc.XlaComputation:
use_tuple_args=self.compile_args["tuple_args"])

def mhlo(self) -> ir.Module:
return self._hlo
if use_stablehlo:
return super().mhlo()
else:
return self._hlo

def stablehlo(self) -> ir.Module:
if use_stablehlo:
return self._hlo
else:
return super().stablehlo()

@profiler.annotate_function
def compile(self) -> PmapExecutable:
Expand Down Expand Up @@ -3165,9 +3174,20 @@ def hlo(self) -> xc.XlaComputation:
use_tuple_args=self.compile_args["tuple_args"])

def mhlo(self) -> ir.Module:
if self.is_trivial:
raise ValueError("A trivial computation has no MHLO")
return self._hlo
if use_stablehlo:
return super().mhlo()
else:
if self.is_trivial:
raise ValueError("A trivial computation has no MHLO")
return self._hlo

def stablehlo(self) -> ir.Module:
if use_stablehlo:
if self.is_trivial:
raise ValueError("A trivial computation has no StableHLO")
return self._hlo
else:
return super().stablehlo()

def compile(self,
_allow_propagation_to_outputs : bool = False,
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/ducc_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import List

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo


from .hlo_helpers import custom_call
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/gpu_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion jaxlib/gpu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import operator

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion jaxlib/hlo_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Dict, Optional, Sequence, Union
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np


Expand Down
2 changes: 1 addition & 1 deletion jaxlib/lapack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# via CustomCallWithLayout.

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo

import numpy as np
from jaxlib import xla_client
Expand Down
6 changes: 3 additions & 3 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@ def test_diff_executables(self):
cc.initialize_cache(tmpdir)
computation1 = str(jax.jit(lambda x, y: x + y)
.lower(1, 1)
.compiler_ir(dialect="mhlo"))
.compiler_ir())
computation2 = str(jax.jit(lambda x, y: x * y)
.lower(2, 2)
.compiler_ir(dialect="mhlo"))
.compiler_ir())
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()
Expand All @@ -230,7 +230,7 @@ def test_put_executable(self):
cc.initialize_cache(tmpdir)
computation = str(jax.jit(lambda x, y: x + y)
.lower(np.int32(1), np.int32(1))
.compiler_ir(dialect="mhlo"))
.compiler_ir())
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()
Expand Down
2 changes: 2 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ def f(x, y):
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)

@jtu.with_mesh([('x', 2), ('y', 2)])
def testLowerCompilerIR(self):
Expand All @@ -938,6 +939,7 @@ def f(x, y):
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))

@jtu.ignore_warning(category=DeprecationWarning)
@jtu.with_mesh([('x', 2), ('y', 2)])
Expand Down
2 changes: 2 additions & 0 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def testLowerAsText(self):
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)

def testLowerCompilerIR(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
Expand All @@ -285,6 +286,7 @@ def testLowerCompilerIR(self):
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))

@jtu.ignore_warning(category=DeprecationWarning)
def testLowerCompileCompilerIR(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/xmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ def testLowerAsText(self):
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)

def testLowerCompilerIR(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
Expand All @@ -752,6 +753,7 @@ def testLowerCompilerIR(self):
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))

@jtu.ignore_warning(category=DeprecationWarning)
def testLowerCompileCompilerIR(self):
Expand Down

0 comments on commit 6346256

Please sign in to comment.