DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2 people authored and jax authors committed Sep 6, 2022
1 parent e9204e3 commit b7e4e44
Showing 11 changed files with 166 additions and 47 deletions.
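In short: `lower_sharding_computation` now runs dead-code elimination through the shared `dce_jaxpr` machinery and short-circuits "trivial" jaxprs — those with no equations whose outputs are just forwarded inputs or constants — so they execute without building or compiling an XLA module. A minimal sketch of what "trivial" means here, assuming the public `jax.make_jaxpr` API (exact printed form varies by version):

```python
import jax

# Forwarding an input produces a jaxpr with no equations -- nothing to compile.
trivial = jax.make_jaxpr(lambda x, y: x)(1.0, 2.0)
print(trivial.jaxpr.eqns)          # []
print(trivial)                     # { lambda ; a:f32[] b:f32[]. let  in (a,) }

# Real work produces equations and still goes through the normal XLA path.
nontrivial = jax.make_jaxpr(lambda x: x * 2.0)(1.0)
print(len(nontrivial.jaxpr.eqns))  # 1
```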
33 changes: 16 additions & 17 deletions jax/_src/dispatch.py
@@ -302,22 +302,25 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
(i for i in in_shardings if i is not None), pxla.EMPTY_ENV.physical_mesh)
in_shardings = [sharding.OpShardingSharding.get_replicated(da) if i is None else i
for i in in_shardings]
if not in_shardings:
inp_device_assignment = da
else:
inp_device_assignment = None

# Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
# the number of output avals at this stage. lower_sharding_computation will
# apply it to all out_avals.
return pxla.lower_sharding_computation(
fun, 'jit', name, in_shardings, pjit._UNSPECIFIED,
donated_invars, in_avals,
in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
committed=committed).compile(
committed=committed, inp_device_assignment=inp_device_assignment).compile(
_allow_propagation_to_outputs=True).unsafe_call


def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, keep_unused, *arg_specs):
# TODO(yashkatariya): Remove the `and arg_specs` from here once
# lower_sharding_computation supports no avals as input.
if config.jax_array and arg_specs:
if config.jax_array:
return sharded_lowering(fun, device, backend, name,
donated_invars, keep_unused, *arg_specs)
else:
@@ -327,6 +330,10 @@ def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
xla_callable = lu.cache(_xla_callable_uncached)


def is_single_device_sharding(sharding) -> bool:
return len(sharding.device_set) == 1


@contextlib.contextmanager
def log_elapsed_time(fmt: str):
if _on_exit:
@@ -517,19 +524,11 @@ def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:

def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
# TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
# applications that do not produce used outputs. Must handle side-effecting
# primitives and nested jaxpr.
used.update(
v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var))
kept_const_idx, new_constvars = util.unzip2(
(i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
kept_var_idx, new_invars = util.unzip2(
(i, v) for i, v in enumerate(jaxpr.invars) if v in used)
new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns,
jaxpr.effects)
return new_jaxpr, set(kept_const_idx), set(kept_var_idx)
used_outputs = [True] * len(jaxpr.outvars)
new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
kept_const_idx = {i for i, b in enumerate(used_consts) if b}
kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
return new_jaxpr, kept_const_idx, kept_var_idx


# We can optionally set a Jaxpr rewriter that can be applied just before
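For context on the dispatch.py change above: `_prune_unused_inputs` now delegates to `pe.dce_jaxpr_consts` instead of the hand-rolled variable scan, which also prunes equations whose outputs are never used. A small sketch of the underlying helpers, assuming the `jax.interpreters.partial_eval` layout at this revision:

```python
import jax
from jax.interpreters import partial_eval as pe

# `y` is never used, so DCE drops it from the invars.
closed = jax.make_jaxpr(lambda x, y: x * 2.0)(1.0, 2.0)
dced, used_inputs = pe.dce_jaxpr(closed.jaxpr, used_outputs=[True])
print(used_inputs)            # [True, False]
print(len(dced.invars))       # 1

# The new dce_jaxpr_consts additionally reports which constvars survive.
dced2, used_consts, used_ins = pe.dce_jaxpr_consts(closed.jaxpr, [True])
print(used_consts, used_ins)  # [] [True, False] for a jaxpr with no constvars
```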
3 changes: 2 additions & 1 deletion jax/_src/lax/lax.py
@@ -30,6 +30,7 @@
from jax._src import api
from jax._src import api_util
from jax._src import device_array
from jax._src import dispatch
from jax import linear_util as lu
from jax._src import dtypes
from jax import tree_util
@@ -1322,7 +1323,7 @@ def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
# (so it works in staged-out code as well as 'eager' code). Related to
# equi-sharding.
if (config.jax_array and hasattr(x, 'sharding') and
not isinstance(x.sharding, sharding.SingleDeviceSharding)):
not dispatch.is_single_device_sharding(x.sharding)):
return array.make_array_from_callback(
fill_shape, x.sharding, lambda idx: val[idx]) # type: ignore[arg-type]
return val
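The lax.py change above (and the matching ones in prng.py and array.py below) swap `isinstance(..., SingleDeviceSharding)` checks for the new duck-typed `dispatch.is_single_device_sharding`, so any sharding whose `device_set` contains exactly one device takes the single-device fast path. An illustrative sketch, assuming a single-device runtime and the `jax.experimental.sharding` layout at this revision:

```python
import jax
from jax._src import dispatch
from jax.experimental import sharding

dev = jax.devices()[0]
print(dispatch.is_single_device_sharding(sharding.SingleDeviceSharding(dev)))  # True

# A replicated OpShardingSharding over one device is also "single device",
# which the old isinstance() check would have rejected.
op = sharding.OpShardingSharding.get_replicated([dev])
print(dispatch.is_single_device_sharding(op))  # True
```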
4 changes: 2 additions & 2 deletions jax/_src/prng.py
@@ -31,7 +31,7 @@
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.experimental.sharding import (
MeshPspecSharding, SingleDeviceSharding, PmapSharding, OpShardingSharding)
MeshPspecSharding, PmapSharding, OpShardingSharding)

from jax._src import dispatch
from jax._src import dtypes
@@ -364,7 +364,7 @@ def global_sharded_result_handler(aval, out_sharding, committed,
phys_handler_maker = pxla.global_result_handlers[
(core.ShapedArray, output_type)]

if isinstance(out_sharding, SingleDeviceSharding):
if dispatch.is_single_device_sharding(out_sharding):
phys_sharding = out_sharding
elif isinstance(out_sharding, MeshPspecSharding):
trailing_spec = [None] * len(key_shape)
1 change: 1 addition & 0 deletions jax/core.py
@@ -2162,6 +2162,7 @@ def _unmap_dshaped_array(
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
}

@contextmanager
6 changes: 3 additions & 3 deletions jax/experimental/array.py
@@ -451,7 +451,7 @@ def _device_put_array(x, device: Optional[Device]):
# TODO(yashkatariya): Remove this restriction and the round trip via host
# once lowering to XLA goes through `lower_mesh_computation`.
assert x.is_fully_addressable()
if isinstance(x.sharding, SingleDeviceSharding):
if dispatch.is_single_device_sharding(x.sharding):
x = dispatch._copy_device_array_to_device(pxla._set_aval(x._arrays[0]), device)
return (x,)
else:
@@ -462,7 +462,7 @@ def _device_put_array(x, device: Optional[Device]):


def _array_pmap_shard_arg(x, devices, indices, mode):
if isinstance(x.sharding, SingleDeviceSharding):
if dispatch.is_single_device_sharding(x.sharding):
return pxla._shard_device_array(x, devices, indices, mode)

if x._fast_path_args is None:
@@ -484,7 +484,7 @@ def _array_shard_arg(x, devices, indices, mode):
if mode == pxla.InputsHandlerMode.pmap:
return _array_pmap_shard_arg(x, devices, indices, mode)
else:
if isinstance(x.sharding, SingleDeviceSharding):
if dispatch.is_single_device_sharding(x.sharding):
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
# If PmapSharding exists, then do a round trip via host. This will happen
22 changes: 22 additions & 0 deletions jax/interpreters/partial_eval.py
@@ -961,6 +961,17 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr

@weakref_lru_cache
def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
"""Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr."""
config.jax_enable_checks and core.check_jaxpr(jaxpr)
constvars, invars = split_list(jaxpr.invars, [n])
lifted_jaxpr = Jaxpr(constvars=tuple(constvars), invars=invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr

def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
config.jax_enable_checks and core.check_jaxpr(jaxpr)
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
@@ -1306,6 +1317,17 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate = (instantiate,) * len(jaxpr.invars)
return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate))


def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, List[bool], List[bool]]:
jaxpr_ = convert_constvars_jaxpr(jaxpr)
new_jaxpr_, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs)
used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)])
new_jaxpr = convert_invars_to_constvars(new_jaxpr_, sum(used_consts))
return new_jaxpr, used_consts, used_inputs


@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
instantiate: Tuple[bool, ...]
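The new `convert_invars_to_constvars` added above is the inverse of the existing `convert_constvars_jaxpr`, which is what lets `dce_jaxpr_consts` reuse plain `dce_jaxpr`: constvars are temporarily promoted to leading invars, DCE'd, then demoted back. A round-trip sketch; the `jnp.cos(np.arange(...))` capture is just an illustrative way to force a constvar:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import partial_eval as pe

# jnp.cos(...) is computed eagerly during tracing, so its result is captured
# as a constvar of the closed jaxpr rather than recomputed inside it.
closed = jax.make_jaxpr(lambda x: x + jnp.cos(np.arange(3.0)))(np.ones(3))
jaxpr = closed.jaxpr
n_consts = len(jaxpr.constvars)                  # typically 1 here

as_invars = pe.convert_constvars_jaxpr(jaxpr)    # constvars -> leading invars
restored = pe.convert_invars_to_constvars(as_invars, n_consts)
print(len(restored.constvars), len(restored.invars))  # n_consts, original invars
```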
123 changes: 109 additions & 14 deletions jax/interpreters/pxla.py
@@ -2670,7 +2670,8 @@ def lower_sharding_computation(
global_in_avals: Sequence[core.ShapedArray],
in_is_global: Sequence[bool],
keep_unused: bool,
committed: bool):
committed: bool,
inp_device_assignment: Optional[Sequence[xc.Device]] = None):
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
The caller of this code can pass in a singleton _UNSPECIFIED because the
@@ -2680,14 +2681,19 @@
"""
# Device assignment across all inputs and outputs should be the same. This
# is checked in pjit.
if _is_unspecified(out_shardings):
backend, first_sharding = _get_backend_from_shardings(in_shardings)
if inp_device_assignment is not None:
device_assignment = inp_device_assignment
backend = xb.get_device_backend(device_assignment[0])
first_sharding = None
else:
# type ignore because mypy can't understand that out_shardings that are
# UNSPECIFIED singleton are filtered above.
backend, first_sharding = _get_backend_from_shardings(
it.chain(in_shardings, out_shardings)) # type: ignore
device_assignment = first_sharding._device_assignment
if _is_unspecified(out_shardings):
backend, first_sharding = _get_backend_from_shardings(in_shardings)
else:
# type ignore because mypy can't understand that out_shardings that are
# UNSPECIFIED singleton are filtered above.
backend, first_sharding = _get_backend_from_shardings(
it.chain(in_shardings, out_shardings)) # type: ignore
device_assignment = first_sharding._device_assignment

name_stack = new_name_stack(wrap_name(fun_name, api_name))

Expand All @@ -2696,6 +2702,7 @@ def lower_sharding_computation(
"in {elapsed_time} sec"):
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))
kept_outputs = [True] * len(global_out_avals)

log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
@@ -2721,10 +2728,29 @@
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx

if not first_sharding.is_fully_addressable():
process_index = xb.process_index()
local_device_assignment = [d for d in device_assignment
if d.process_index == process_index]
if len(device_assignment) != len(local_device_assignment):
check_multihost_collective_allowlist(jaxpr)

has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if (not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(_is_unspecified(o) for o in out_shardings) and # type: ignore
not hasattr(backend, "compile_replicated")):
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr, consts=consts,
global_in_avals=global_in_avals, global_out_avals=global_out_avals,
in_shardings=in_shardings,
device_assignment=device_assignment, committed=committed,
kept_var_idx=kept_var_idx, keepalive=None)

# Look at the number of replicas present in the jaxpr. In
# lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
# handled here so as to deprecate the lower_xla_callable codepath when
@@ -2807,6 +2833,7 @@ def lower_sharding_computation(
return MeshComputation(
str(name_stack),
module,
False,
donated_invars,
mesh=None,
global_in_avals=global_in_avals,
@@ -2969,6 +2996,7 @@ def lower_mesh_computation(
return MeshComputation(
str(name_stack),
module,
False,
donated_invars,
mesh=mesh,
global_in_avals=global_in_avals,
@@ -2994,16 +3022,19 @@ class MeshComputation(stages.XlaLowering):
_executable: Optional[MeshExecutable]

def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
donated_invars: Sequence[bool], **compile_args):
is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
self._name = name
self._hlo = hlo
self.is_trivial = is_trivial
self._donated_invars = donated_invars
self.compile_args = compile_args
self._executable = None

# -- stages.XlaLowering overrides

def hlo(self) -> xc.XlaComputation:
if self.is_trivial:
raise ValueError("A trivial computation has no HLO")
# this is a method for api consistency with dispatch.XlaComputation
if isinstance(self._hlo, xc.XlaComputation):
return self._hlo
@@ -3012,6 +3043,8 @@ def hlo(self) -> xc.XlaComputation:
use_tuple_args=self.compile_args["tuple_args"])

def mhlo(self) -> ir.Module:
if self.is_trivial:
raise ValueError("A trivial computation has no MHLO")
if isinstance(self._hlo, xc.XlaComputation):
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
with mlir.make_ir_context():
@@ -3022,10 +3055,13 @@ def compile(self,
_allow_propagation_to_outputs : bool = False,
_allow_compile_replicated : bool = True) -> MeshExecutable:
if self._executable is None:
self._executable = MeshExecutable.from_hlo(
self._name, self._hlo, **self.compile_args,
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
_allow_compile_replicated=_allow_compile_replicated) # type: ignore
if self.is_trivial:
self._executable = MeshExecutable.from_trivial_jaxpr(**self.compile_args)
else:
self._executable = MeshExecutable.from_hlo(
self._name, self._hlo, **self.compile_args,
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
_allow_compile_replicated=_allow_compile_replicated) # type: ignore
return self._executable


@@ -3245,6 +3281,32 @@ def from_hlo(name: str,
return MeshExecutable(xla_executable, unsafe_call, input_avals,
in_shardings, out_shardings, auto_spmd_lowering)

@staticmethod
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
in_shardings, device_assignment,
committed, kept_var_idx, keepalive) -> MeshExecutable:
assert keepalive is None
out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, device_assignment)
if config.jax_array or config.jax_parallel_functions_output_gda:
are_global = [True] * len(global_out_avals)
else:
are_global = [False] * len(global_out_avals)
_, indices, _ = _get_input_metadata(global_out_avals, out_shardings,
are_global)
process_index = xb.process_index()
local_device_assignment = [d for d in device_assignment
if d.process_index == process_index]
handle_ins = InputsHandler(local_device_assignment, out_shardings, indices,
InputsHandlerMode.pjit_or_xmap)
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed,
[False] * len(global_out_avals))
unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
handle_outs, kept_var_idx)
return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings,
out_shardings, False)

# -- stages.XlaExecutable overrides

def xla_extension_executable(self):
@@ -3259,6 +3321,39 @@ def call(self, *args):
return self.unsafe_call(*args)


def _out_shardings_for_trivial(
jaxpr: core.Jaxpr, consts: Sequence[Any],
in_shardings: Sequence[XLACompatibleSharding],
device_assignment: Sequence[xc.Device],
) -> List[XLACompatibleSharding]:
# For each jaxpr output, compute a Sharding by:
# * if the output is a forwarded input, get the corresponding in_sharding;
# * if the output is a constant Array, get its .sharding attribute;
# * otherwise, the output is a literal or numpy.ndarray constant, so give it
# a replicated sharding
from jax.experimental import array
from jax.experimental import sharding
rep = sharding.OpShardingSharding(
device_assignment, sharding._get_replicated_op_sharding())
shardings: Dict[core.Var, sharding.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.Array):
shardings[constvar] = constval.sharding
map(shardings.setdefault, jaxpr.invars, in_shardings)
return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
for x in jaxpr.outvars]


def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
env: Dict[core.Var, Any] = {}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
for v in jaxpr.outvars]
return out_handler(in_handler(outs))


@lru_cache()
def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
from jax.experimental.sharding import MeshPspecSharding
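Taken together, the pxla.py changes mean that once `jax_array` is enabled, a jit whose jaxpr is trivial is served by `MeshExecutable.from_trivial_jaxpr` / `_execute_trivial` rather than an XLA executable. A rough end-to-end sketch; exact behavior depends on the surrounding jit dispatch machinery and the `jax_array` flag at this revision:

```python
import jax
jax.config.update('jax_array', True)  # the new fast path is gated on jax_array

@jax.jit
def keep_first(x, y):
  return x                    # no equations: a trivial jaxpr

print(keep_first(1.0, 2.0))   # 1.0, served without building an HLO module

# Zero-argument jits also lower now (via inp_device_assignment); a constant-only
# output is likewise trivial.
@jax.jit
def constant_only():
  return 3

print(constant_only())        # 3
```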
5 changes: 1 addition & 4 deletions tests/api_test.py
@@ -1070,10 +1070,7 @@ def test_jit_lower_no_prunning(self):
jitted_f = self.jit(lambda x, y: x, keep_unused=True)
with jtu.count_device_put() as count:
_ = jitted_f(1, 2)
if config.jax_array:
self.assertEqual(count[0], 2)
else:
self.assertEqual(count[0], 1)
self.assertEqual(count[0], 1)

@jtu.ignore_warning(category=DeprecationWarning)
def test_jit_lower_compile_compiler_ir(self):
