Remove the fallback to lower_xla_callable that exists for `jit(pmap)` cases when `Array` is enabled and add minimal support to `lower_sharding_computation`.

The `jit(pmap)` codepath is added to `lower_sharding_computation` so that the `lower_xla_callable` codepath can be deleted once `jax.Array` is enabled by default. This will help clean up the codebase and remove tech debt.

* Round-trip through the host for `Array`s that have a `PmapSharding` and come through the `jit` path (exactly like SDAs).

* For other cases, i.e. when `num_replicas > 1`, default to the `_execute_replicated` path in dispatch.py from `lower_sharding_computation`. This is exactly the same as what happens in `lower_xla_callable`; see the sketch below.
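
A minimal repro of the `jit(pmap)` pattern this change targets (a sketch, not part of the commit, assuming a 2022-era JAX where the `jax_array` config flag exists):

```python
# Hypothetical end-to-end exercise of the new codepath: pmap output Arrays
# carry a PmapSharding, and jit(pmap(...)) now lowers through
# lower_sharding_computation instead of falling back to lower_xla_callable.
import jax
import jax.numpy as jnp

jax.config.update("jax_array", True)  # assumption: flag spelling of this era

f = jax.jit(jax.pmap(lambda x: x * 2))
x = jnp.arange(float(jax.device_count() * 4)).reshape(jax.device_count(), 4)
y = f(x)  # on multi-device backends this also emits the jit-of-pmap warning
```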

PiperOrigin-RevId: 471033420
yashk2810 authored and jax authors committed Aug 30, 2022
1 parent 2d14e35 commit 83d7e3f
Showing 3 changed files with 129 additions and 93 deletions.
113 changes: 54 additions & 59 deletions jax/_src/dispatch.py
@@ -97,6 +97,8 @@ def arg_spec(x: Any) -> ArgSpec:
aval = xla.abstractify(x)
try:
if config.jax_array:
if isinstance(x.sharding, PmapSharding):
return aval, None
return aval, (x.sharding if x._committed else None)
else:
return aval, x._device
@@ -275,10 +277,6 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
xla.xla_call_p.def_impl(_xla_call_impl)


TracedJaxprInfo = collections.namedtuple(
'TracedJaxprInfo', ['jaxpr', 'out_jaxpr_avals', 'consts'])


def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
*arg_specs):
# TODO(yashkatariya): Remove the local imports from here when the functions
@@ -288,26 +286,16 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,

in_avals, in_shardings = util.unzip2(arg_specs)

with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"in {elapsed_time} sec"):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
fun, in_avals, debug_info=pe.debug_info_final(fun, "jit"))
traced_jaxpr_info = TracedJaxprInfo(jaxpr, out_jaxpr_avals, consts)

# If jaxpr has the pmap primitive or if `backend` is provided on `jit`, then
# take the lower_xla_callable lowering path. This is because pmap's programming
# model is not compatible with lower_sharding_computation.
# Specifying backend on `jit` is not supported when Array is enabled. So take
# the `lower_xla_callable` path which can handle it.
if (jaxpr_has_primitive(jaxpr, 'xla_pmap') or
any(isinstance(s, sharding.PmapSharding) for s in in_shardings) or
backend is not None):
# TODO(yashkatariya): Figure out what to do with the backend argument in `jit`
if backend is not None:
arg_specs = tuple(
(a, s._device) if isinstance(s, sharding.SingleDeviceSharding) else (a, None)
for a, s in zip(in_avals, in_shardings))
return lower_xla_callable(
fun, None, backend, name, donated_invars, False, keep_unused, *arg_specs,
traced_jaxpr_info=traced_jaxpr_info).compile().unsafe_call
fun, None, backend, name, donated_invars, False, keep_unused,
*arg_specs).compile().unsafe_call

committed = any(i is not None for i in in_shardings)
da = pjit._get_and_check_device_assignment(
@@ -318,10 +306,10 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
# the number of output avals at this stage. lower_sharding_computation will
# apply it to all out_avals.
return pxla.lower_sharding_computation(
fun, 'xla_callable', name, in_shardings, pjit._UNSPECIFIED,
fun, 'jit', name, in_shardings, pjit._UNSPECIFIED,
donated_invars, in_avals,
in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
committed=committed, traced_jaxpr_info=traced_jaxpr_info).compile(
committed=committed).compile(
_allow_propagation_to_outputs=True).unsafe_call


@@ -359,11 +347,32 @@ def should_tuple_args(num_args: int, platform: str):
return num_args > 100


def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
if nreps > 1:
warnings.warn(
f"The jitted function {name} includes a pmap. Using "
"jit-of-pmap can lead to inefficient data movement, as the outer jit "
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.")

if nreps > xb.device_count(backend):
raise ValueError(
f"compiling computation `{name}` that requires {nreps} replicas, but "
f"only {xb.device_count(backend)} XLA devices are available.")

if xb.process_count() > 1 and (nreps > 1 or
jaxpr_has_primitive(jaxpr, "xla_pmap")):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")


@profiler.annotate_function
def lower_xla_callable(
fun: lu.WrappedFun, device, backend, name, donated_invars,
always_lower: bool, keep_unused: bool, *arg_specs,
traced_jaxpr_info: Optional[TracedJaxprInfo] = None):
always_lower: bool, keep_unused: bool, *arg_specs):
"""Lower into XLA.
Args:
@@ -386,16 +395,11 @@ def lower_xla_callable(
assert abstract_args == (None,) * len(abstract_args)
abstract_args = [aval for aval, _ in fun.in_type]

if traced_jaxpr_info is None:
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
fun, pe.debug_info_final(fun, "jit"))
out_avals, kept_outputs = util.unzip2(out_type)
else:
jaxpr, out_avals, consts = traced_jaxpr_info
kept_outputs = [True] * len(out_avals)
out_type = tuple(zip(out_avals, kept_outputs))
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
fun, pe.debug_info_final(fun, "jit"))
out_avals, kept_outputs = util.unzip2(out_type)

if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
@@ -447,25 +451,7 @@
msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
logging.log(log_priority, msg)

if nreps > 1:
warnings.warn(
f"The jitted function {name} includes a pmap. Using "
"jit-of-pmap can lead to inefficient data movement, as the outer jit "
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.")

if nreps > xb.device_count(backend):
raise ValueError(
f"compiling computation `{name}` that requires {nreps} replicas, but "
f"only {xb.device_count(backend)} XLA devices are available.")

if xb.process_count() > 1 and (nreps > 1 or
jaxpr_has_primitive(jaxpr, "xla_pmap")):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")
raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr)

# pass long arg lists as tuple for TPU
tuple_args = should_tuple_args(len(abstract_args), backend.platform)
@@ -860,7 +846,8 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
result_handler: Callable,
has_unordered_effects: bool,
ordered_effects: List[core.Effect],
kept_var_idx, has_host_callbacks: bool, *args):
kept_var_idx, has_host_callbacks: bool,
*args, from_lower_sharding_computation: bool = False):
if has_unordered_effects or ordered_effects:
# TODO(sharadmv): support jit-of-pmap with effects
raise NotImplementedError(
@@ -874,6 +861,8 @@
out_flat = [bufs[0] for bufs in out_bufs_flat_rep]
check_special(name, out_flat)
out_bufs = unflatten(out_flat, output_buffer_counts)
if from_lower_sharding_computation:
return result_handler(out_bufs)
return result_handler(None, out_bufs)


@@ -1015,6 +1004,17 @@ def compile_or_get_cached(backend, computation, compile_options,
return backend_compile(backend, computation, compile_options, host_callbacks)


def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
if ordered_effects or has_unordered_effects:
num_output_tokens = len(ordered_effects)
# TODO(sharadmv): remove check when minimum jaxlib version is bumped
if not can_execute_with_token:
num_output_tokens += has_unordered_effects
buffer_counts = ([1] * num_output_tokens) + buffer_counts
return buffer_counts


class XlaCompiledComputation(stages.XlaExecutable):
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
keepalive: Any):
@@ -1049,13 +1049,8 @@ def from_xla_computation(name: str, xla_computation: Optional[ir.Module],
"in {elapsed_time} sec"):
compiled = compile_or_get_cached(backend, xla_computation, options,
host_callbacks)
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
if ordered_effects or has_unordered_effects:
num_output_tokens = len(ordered_effects)
# TODO(sharadmv): remove check when minimum jaxlib version is bumped
if not can_execute_with_token:
num_output_tokens += has_unordered_effects
buffer_counts = ([1] * num_output_tokens) + buffer_counts
buffer_counts = get_buffer_counts(out_avals, ordered_effects,
has_unordered_effects)
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts, # type: ignore # noqa: F811
result_handler, has_unordered_effects,
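
As a rough illustration of the checks now factored into `raise_warnings_or_errors_for_jit_of_pmap` (a hedged sketch, not from the commit), the warning can be observed like this:

```python
# Hypothetical capture of the jit-of-pmap warning raised by the helper above.
import warnings

import jax
import jax.numpy as jnp

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    jax.jit(jax.pmap(jnp.sin))(jnp.ones((jax.device_count(), 3)))

# With more than one device the traced jaxpr has nreps > 1, so `caught`
# contains the "jit-of-pmap can lead to inefficient data movement" warning;
# with a single device nreps == 1 and nothing is emitted.
```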
9 changes: 8 additions & 1 deletion jax/experimental/array.py
@@ -33,7 +33,8 @@
from jax._src.numpy.ndarray import ndarray
from jax.interpreters import pxla, xla, mlir
from jax.experimental.sharding import (
Sharding, SingleDeviceSharding, XLACompatibleSharding, device_replica_id_map)
Sharding, SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map)

Shape = Tuple[int, ...]
Device = xc.Device
@@ -446,6 +447,12 @@ def _array_shard_arg(x, devices, indices, mode):
if isinstance(x.sharding, SingleDeviceSharding):
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
# If a PmapSharding exists, do a round trip via the host. This happens when
# an input Array with a PmapSharding takes the jit path,
# i.e. `apply_primitive` or `xla_callable_uncached`. `jit(pmap)` is the most
# common case where this occurs.
elif isinstance(x.sharding, PmapSharding):
return pxla.device_put(x._value, devices, replicate=True)
else:
return x._arrays
pxla.shard_arg_handlers[Array] = _array_shard_arg
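
A small sketch of when `_array_shard_arg` takes the new host round-trip branch (assumptions: `jax_array` enabled and `Array.sharding` exposed, as in this era):

```python
# Hypothetical check: pmap outputs carry a PmapSharding, so feeding one into
# a plain jitted function re-shards it via a host round trip, exactly like
# SDAs did on the old path.
import jax
import jax.numpy as jnp
from jax.experimental.sharding import PmapSharding

jax.config.update("jax_array", True)  # assumption: flag spelling of this era

y = jax.pmap(lambda x: x + 1)(jnp.ones((jax.device_count(), 2)))
assert isinstance(y.sharding, PmapSharding)

z = jax.jit(lambda a: a * 2)(y)  # hits pxla.device_put(x._value, ...) above
```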
100 changes: 67 additions & 33 deletions jax/interpreters/pxla.py
@@ -2664,8 +2664,7 @@ def lower_sharding_computation(
global_in_avals: Sequence[core.ShapedArray],
in_is_global: Sequence[bool],
keep_unused: bool,
committed: bool,
traced_jaxpr_info: Optional[dispatch.TracedJaxprInfo] = None):
committed: bool):
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
The caller of this code can pass in a singleton _UNSPECIFIED because the
@@ -2686,22 +2685,19 @@

name_stack = new_name_stack(wrap_name(fun_name, api_name))

# 1. Trace to jaxpr and preprocess/verify it
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec"):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))

log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority,
"Compiling %s (%d) for with global shapes and types %s. "
"Argument mapping: %s.",
getattr(fun, '__name__', '<unnamed function>'), id(fun),
global_in_avals, in_shardings)

# 1. Trace to jaxpr and preprocess/verify it
if traced_jaxpr_info is None:
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"for sharded computation in {elapsed_time} sec"):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
fun, global_in_avals, debug_info=pe.debug_info_final(fun, "sharded computation"))
else:
jaxpr, out_jaxpr_avals, consts = traced_jaxpr_info

if _is_unspecified(out_shardings):
out_shardings = (_UNSPECIFIED,) * len(out_jaxpr_avals)

@@ -2721,30 +2717,45 @@
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx

_sanitize_mesh_jaxpr(jaxpr)
if not first_sharding.is_fully_addressable():
check_multihost_collective_allowlist(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

# Look at the number of replicas present in the jaxpr. In
# lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
# handled here so as to deprecate the lower_xla_callable codepath when
# `jax.Array` is turned on by default.
# TODO(yashkatariya): Remove this when `jit(pmap)` is removed.
nreps = dispatch.jaxpr_replicas(jaxpr)
dispatch.raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)

# 2. Build up the HLO
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)

in_op_shardings: Optional[List[Optional[xc.OpSharding]]]
out_op_shardings: Optional[List[Optional[xc.OpSharding]]]
axis_ctx: mlir.ShardingContext
axis_ctx: mlir.AxisContext

in_op_shardings = [
None if aval is core.abstract_token else i._to_xla_op_sharding(aval.ndim)
for aval, i in safe_zip(global_in_avals, in_shardings)
]
# TODO(yashkatariya): Fix the HLO produced if out_partitions is
# [None, OpShardingProto] so that it has the sharding annotations.
out_op_shardings = [
None if _is_unspecified(o) or aval is core.abstract_token else o._to_xla_op_sharding(aval.ndim)
for aval, o in safe_zip(global_out_avals, out_shardings)
]
replicated_args = [False] * len(global_in_avals)
axis_ctx = mlir.ShardingContext(first_sharding)
if nreps == 1:
in_op_shardings = [
None if aval is core.abstract_token else i._to_xla_op_sharding(aval.ndim)
for aval, i in safe_zip(global_in_avals, in_shardings)
]
# TODO(yashkatariya): Fix the HLO produced if out_partitions is
# [None, OpShardingProto] so that it has the sharding annotations.
out_op_shardings = [
None if _is_unspecified(o) or aval is core.abstract_token else o._to_xla_op_sharding(aval.ndim)
for aval, o in safe_zip(global_out_avals, out_shardings)
]
replicated_args = [False] * len(global_in_avals)
axis_ctx = mlir.ShardingContext(first_sharding)
else:
# This path is triggered for `jit(pmap)` cases.
replicated_args = None
in_op_shardings = None
out_op_shardings = None
axis_env = xla.AxisEnv(nreps, (), ())
axis_ctx = mlir.ReplicaAxisContext(axis_env)

closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module: Union[str, xc.XlaComputation]
@@ -2800,7 +2811,8 @@ def lower_sharding_computation(
kept_var_idx=kept_var_idx,
backend=backend,
device_assignment=device_assignment,
committed=committed)
committed=committed,
pmap_nreps=nreps)


@profiler.annotate_function
@@ -3101,19 +3113,28 @@ def from_hlo(name: str,
kept_var_idx: Set[int],
backend: xb.XlaBackend,
device_assignment: Sequence[xc.Device],
committed: bool) -> MeshExecutable:
committed: bool,
pmap_nreps: int = 1) -> MeshExecutable:
dev: np.ndarray
if auto_spmd_lowering:
assert mesh is not None and spmd_lowering
dev = mesh.devices
num_replicas, num_partitions = 1, mesh.size
else:
dev = np.array(device_assignment)
if spmd_lowering:
if pmap_nreps > 1:
num_replicas, num_partitions = pmap_nreps, 1
elif spmd_lowering:
num_replicas, num_partitions = 1, dev.size
else:
num_replicas, num_partitions = dev.size, 1
xla_device_assignment = dev.reshape((num_replicas, num_partitions))

if pmap_nreps > 1:
# In `jit` device_assignment is set to None when num_replicas > 1. Do
# the same thing here too.
xla_device_assignment = None
else:
xla_device_assignment = dev.reshape((num_replicas, num_partitions))

compile_options = xb.get_compile_options(
num_replicas=num_replicas,
@@ -3168,10 +3189,23 @@ def from_hlo(name: str,
global_out_avals, out_shardings, committed) # type: ignore # arg-type
handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
input_indices, InputsHandlerMode.pjit_or_xmap)
unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
handle_outs, unordered_effects,
ordered_effects, keepalive,
bool(host_callbacks), kept_var_idx)

# This path is taken for `jit(pmap)` cases. Nothing else should flow
# through this path. This is exactly the same as what happens in `jit`.
if pmap_nreps > 1:
has_unordered_effects = bool(unordered_effects)
buffer_counts = dispatch.get_buffer_counts(
global_out_avals, ordered_effects, has_unordered_effects)
unsafe_call = partial(
dispatch._execute_replicated, name, xla_executable, None,
buffer_counts, handle_outs, has_unordered_effects, ordered_effects,
kept_var_idx, bool(host_callbacks),
from_lower_sharding_computation=True)
else:
unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
handle_outs, unordered_effects,
ordered_effects, keepalive,
bool(host_callbacks), kept_var_idx)

return MeshExecutable(xla_executable, unsafe_call, input_avals,
in_shardings, out_shardings, auto_spmd_lowering)
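
The replica/partition bookkeeping added to `MeshExecutable.from_hlo` can be distilled into this standalone sketch (logic mirrored from the diff; the helper name is ours, not JAX API):

```python
# Hypothetical distillation of the branch added above: jit(pmap) compiles as
# pure replication, SPMD lowering as pure partitioning, and the non-SPMD
# mesh path as one replica per device.
from typing import Tuple

def replica_partition_counts(pmap_nreps: int, spmd_lowering: bool,
                             n_devices: int) -> Tuple[int, int]:
    if pmap_nreps > 1:       # jit(pmap): replicate the whole computation
        return pmap_nreps, 1
    elif spmd_lowering:      # SPMD partitioning across the device assignment
        return 1, n_devices
    else:
        return n_devices, 1

assert replica_partition_counts(8, False, 8) == (8, 1)  # jit(pmap)
assert replica_partition_counts(1, True, 8) == (1, 8)   # SPMD lowering
assert replica_partition_counts(1, False, 8) == (8, 1)  # non-SPMD mesh
```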
