Remove the jax_enable_mlir flag. MLIR is now the only supported code path.

This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.

PiperOrigin-RevId: 439324450
hawkinsp authored and jax authors committed Apr 4, 2022
1 parent e1bbbf5 commit 1b8be90
Showing 9 changed files with 12 additions and 102 deletions.
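
As the commit message notes, MLIR (MHLO) is now the only lowering path. A minimal sketch of the user-visible effect, assuming a JAX build from around this change in which jitted functions expose the ahead-of-time .lower() API and Lowered.compiler_ir(dialect="mhlo"): the lowered IR can be inspected directly, and no flag is involved.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      return jnp.sin(x) + 1.0

    # Lowering goes through the MLIR (MHLO) pipeline unconditionally; the
    # result is an MLIR module whose text can be printed.
    lowered = f.lower(jnp.float32(2.0))
    print(lowered.compiler_ir(dialect="mhlo"))
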
7 changes: 0 additions & 7 deletions jax/_src/config.py
@@ -687,13 +687,6 @@ def _update_disable_jit_thread_local(val):
" * \"remove_frames\": removes hidden frames from tracebacks, and adds "
" the unfiltered traceback as a __cause__ of the exception.\n")

-enable_mlir = config.define_bool_state(
-    name='jax_enable_mlir',
-    default=lib.mlir_api_version >= 1,
-    help=('Enables an experimental code path that compiles JAX programs via '
-          'emitting the MLIR MHLO dialect.'))


# This flag is temporary and for internal use.
# TODO(tianjianlu): Removes after providing the information in BCOO meta data.
bcoo_cusparse_lowering = config.define_bool_state(
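
With the define_bool_state entry above deleted, code that still toggles the flag fails when the option is looked up. A rough sketch of what a caller would now see (the exact exception type and message are an assumption, not taken from this diff):

    import jax

    try:
      jax.config.update("jax_enable_mlir", True)
    except AttributeError as err:
      # The option no longer exists, e.g. "Unrecognized config option: jax_enable_mlir".
      print(err)
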
12 changes: 3 additions & 9 deletions jax/_src/dispatch.py
@@ -255,15 +255,9 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module: Union[str, xc.XlaComputation]
module_name = f"jit_{fun.__name__}"
-  if config.jax_enable_mlir:
-    module = mlir.lower_jaxpr_to_module(
-        module_name, closed_jaxpr, backend.platform,
-        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
-  else:
-    module = xla.lower_jaxpr_to_xla_module(
-        module_name, closed_jaxpr, backend.platform, axis_env,
-        name_stack, tuple_args, donated_invars, replicated_args=None,
-        arg_partitions=None, out_partitions=None)
+  module = mlir.lower_jaxpr_to_module(
+      module_name, closed_jaxpr, backend.platform,
+      mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
return XlaComputation(
name, module, False, donated_invars, nreps=nreps, device=device,
backend=backend, tuple_args=tuple_args, in_avals=abstract_args,
36 changes: 9 additions & 27 deletions jax/interpreters/pxla.py
@@ -1061,17 +1061,11 @@ def lower_parallel_callable(
tuple_args = should_tuple_args(shards)
module_name = f"pmap_{fun.__name__}"
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
-    if config.jax_enable_mlir:
-      module = mlir.lower_jaxpr_to_module(
-          module_name, closed_jaxpr, backend.platform, mlir.ReplicaAxisContext(axis_env),
-          name_stack, donated_invars, replicated_args=replicated_args,
-          arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
-          result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
-    else:
-      module = xla.lower_jaxpr_to_xla_module(
-          module_name, closed_jaxpr, backend.platform, axis_env,
-          name_stack, tuple_args, donated_invars, replicated_args,
-          parts.arg_parts, parts.out_parts)
+    module = mlir.lower_jaxpr_to_module(
+        module_name, closed_jaxpr, backend.platform, mlir.ReplicaAxisContext(axis_env),
+        name_stack, donated_invars, replicated_args=replicated_args,
+        arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
+        result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
return PmapComputation(module, pci=pci, replicas=replicas, parts=parts,
shards=shards, tuple_args=tuple_args)

@@ -2255,25 +2249,20 @@ def lower_mesh_computation(
    if auto_spmd_lowering:
      in_partitions = None
      out_partitions = None
-      out_partitions_t = None
    else:
      global_sharding_spec = mesh_sharding_specs(global_axis_sizes, mesh.axis_names)
      in_partitions = [global_sharding_spec(aval, aval_in_axes).sharding_proto()
                       if aval is not core.abstract_unit else None
                       for aval, aval_in_axes in safe_zip(global_in_avals, in_axes)]
      out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
                        for aval, aval_out_axes in safe_zip(global_out_avals, out_axes)]
-      out_partitions_t = xla.tuple_sharding_proto(out_partitions)
    replicated_args = [False] * len(in_jaxpr_avals)
-    partitions_proto = True
    axis_ctx = mlir.SPMDAxisContext(mesh)
    axis_env = axis_ctx.axis_env
  else:
    replicated_args = [not axis for axis in in_axes]
    in_partitions = None
    out_partitions = None
-    out_partitions_t = None
-    partitions_proto = False
    axis_env = xla.AxisEnv(nreps=mesh.size,
                           names=tuple(global_axis_sizes.keys()),
                           sizes=tuple(global_axis_sizes.values()))
@@ -2282,17 +2271,10 @@
module: Union[str, xc.XlaComputation]
module_name = f"{api_name}_{fun_name}"
with core.extend_axis_env_nd(mesh.shape.items()):
-    if config.jax_enable_mlir:
-      module = mlir.lower_jaxpr_to_module(
-          module_name, closed_jaxpr, backend.platform, axis_ctx, name_stack,
-          donated_invars, replicated_args=replicated_args,
-          arg_shardings=in_partitions, result_shardings=out_partitions)
-    else:
-      module = xla.lower_jaxpr_to_xla_module(
-          module_name, closed_jaxpr, backend.platform, axis_env,
-          name_stack, tuple_args, donated_invars, replicated_args,
-          in_partitions, out_partitions_t,
-          partitions_are_protos=partitions_proto)
+    module = mlir.lower_jaxpr_to_module(
+        module_name, closed_jaxpr, backend.platform, axis_ctx, name_stack,
+        donated_invars, replicated_args=replicated_args,
+        arg_shardings=in_partitions, result_shardings=out_partitions)

return MeshComputation(
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
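
Both the pmap and mesh lowering paths above now call the same mlir.lower_jaxpr_to_module entry point used for jit. A sketch of observing this through the public API, assuming this revision already exposes .lower() on pmapped functions (if it does not, the jit example near the top applies unchanged):

    import jax
    import jax.numpy as jnp

    # One leading-axis element per local device.
    x = jnp.arange(jax.local_device_count(), dtype=jnp.float32)

    pmapped = jax.pmap(lambda v: v * 2.0)
    # The lowered per-replica computation is again an MHLO module.
    print(pmapped.lower(x).compiler_ir(dialect="mhlo"))
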
54 changes: 0 additions & 54 deletions jax/interpreters/xla.py
@@ -25,7 +25,6 @@
from typing import (Any, Callable, Deque, Dict, List, NamedTuple, Optional,
Sequence, Set, Type, Tuple, Union)
from typing_extensions import Protocol
-import warnings

import numpy as np

@@ -34,7 +33,6 @@
from jax._src import ad_util
from jax._src import device_array
from jax._src import dtypes
-from jax._src import profiler
from jax import linear_util as lu
from jax._src import source_info_util
from jax._src.abstract_arrays import (make_shaped_array, array_types)
@@ -764,58 +762,6 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):
return tuple(out_donated_args)


-@profiler.annotate_function
-def lower_jaxpr_to_xla_module(
-    fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
-    name_stack: Union[source_info_util.NameStack, str], tuple_args: bool,
-    donated_invars: Sequence[bool], replicated_args: Optional[Sequence[bool]],
-    arg_partitions: Optional[Any],
-    out_partitions: Optional[Any],
-    partitions_are_protos: bool = False
-) -> xc.XlaComputation:
-  """Lowers a closed jaxpr to a top-level XLA module."""
-  c = xc.XlaBuilder(fn_name)
-  xla_consts = _xla_consts(c, jaxpr.consts)
-  xla_args, donated_invars = _xla_callable_args(
-      c, jaxpr.in_avals, tuple_args, donated_invars=donated_invars,
-      replicated=replicated_args, partitions=arg_partitions,
-      partitions_proto=partitions_are_protos)
-  ctx = TranslationContext(c, platform, axis_env, name_stack)
-  out_nodes = jaxpr_subcomp(ctx, jaxpr.jaxpr, xla_consts, *xla_args)
-  # Replace tokens with a dummy array value, because the runtime cannot
-  # handle token arguments.
-  out_aval_lens = [len(aval_to_xla_shapes(a)) for a in jaxpr.out_avals]
-  out_nodes = util.flatten(
-      [[_make_token_return_value(c)] if a is core.abstract_token
-       else v for a, v in zip(jaxpr.out_avals,
-                              util.unflatten(out_nodes, out_aval_lens))])
-
-  # There is a non-zero cost to building an output tuple, particularly on TPU.
-  # Avoid it if the output arity is 1.
-  if out_partitions is None:
-    output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
-  else:
-    build_out_tuple = partial(xops.Tuple, c, out_nodes)
-    if partitions_are_protos:
-      output = with_sharding_proto(c, out_partitions, build_out_tuple)
-    else:
-      output = with_sharding(c, out_partitions, build_out_tuple)
-
-  platforms_with_donation = ("gpu", "tpu")
-  if platform in platforms_with_donation:
-    donated_invars = set_up_aliases(
-        c, xla_args, c.GetShape(output), donated_invars, tuple_args)
-  if any(donated_invars):
-    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
-    unused_donations = [str(c.GetShape(a))
-                        for a, d in zip(xla_args, donated_invars) if d]
-    msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
-    if platform not in platforms_with_donation:
-      msg = f"Donation is not implemented for {platform}.\n{msg}"
-    warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
-  return c.build(output)


xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind

1 change: 0 additions & 1 deletion tests/filecheck/array.filecheck.py
@@ -27,7 +27,6 @@

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
jax.config.update("jax_enable_x64", True)


1 change: 0 additions & 1 deletion tests/filecheck/math.filecheck.py
@@ -26,7 +26,6 @@

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
jax.config.update("jax_enable_x64", True)


1 change: 0 additions & 1 deletion tests/filecheck/names.filecheck.py
@@ -24,7 +24,6 @@

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
jax.config.update("jax_enable_x64", True)


1 change: 0 additions & 1 deletion tests/filecheck/shapes.filecheck.py
@@ -25,7 +25,6 @@

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
jax.config.update("jax_enable_x64", True)


1 change: 0 additions & 1 deletion tests/filecheck/subcomputations.filecheck.py
@@ -24,7 +24,6 @@

from jax.tests.filecheck.jax_filecheck_helpers import print_ir

jax.config.update("jax_enable_mlir", True)
jax.config.update("jax_enable_x64", True)


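
The filecheck tests above now set only jax_enable_x64; the MHLO they check comes from the default lowering path. A standalone sketch of the same kind of check, without the repository's print_ir helper (whose interface is not shown in this diff):

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)

    @jax.jit
    def f(x):
      return jnp.sin(x)

    # The default lowering emits MHLO, so MHLO op names can be asserted on directly.
    mhlo_text = str(f.lower(jnp.float64(1.0)).compiler_ir(dialect="mhlo"))
    assert "mhlo.sine" in mhlo_text
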
