Skip to content

Commit

Permalink
Add support to handle arbitrary shardings to KeyArray. Resolve all th…
Browse files Browse the repository at this point in the history
…e TODOs that were created before.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 471443690
  • Loading branch information
2 people authored and jax authors committed Sep 1, 2022
1 parent bf7525e commit 0584c6a
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 63 deletions.
42 changes: 31 additions & 11 deletions jax/_src/prng.py
Expand Up @@ -287,6 +287,17 @@ def aval_to_ir_types(aval):
phys_aval, = KeyTyRules.physical_avals(aval)
return mlir.aval_to_ir_types(phys_aval)

@staticmethod
def physical_op_sharding(aval, sharding):
op_sharding = sharding._to_xla_op_sharding(aval.ndim)
key_shape = aval.dtype.impl.key_shape

new_op_sharding = op_sharding.clone()
tad = list(new_op_sharding.tile_assignment_dimensions)
tad.extend([1] * len(key_shape))
new_op_sharding.tile_assignment_dimensions = tad
return new_op_sharding

@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
Expand Down Expand Up @@ -338,7 +349,8 @@ def handler(bufs):
return handler

@staticmethod
def global_sharded_result_handler(aval, out_sharding, committed):
def global_sharded_result_handler(aval, out_sharding, committed,
is_out_sharding_from_xla):
phys_aval, = KeyTyRules.physical_avals(aval)
key_shape = aval.dtype.impl.key_shape

Expand All @@ -355,20 +367,20 @@ def global_sharded_result_handler(aval, out_sharding, committed):
if isinstance(out_sharding, SingleDeviceSharding):
phys_sharding = out_sharding
elif isinstance(out_sharding, MeshPspecSharding):
# TODO(yashkatariya,frostig): not covered by tests until we write a test
# that uses pjit with axis resource annotations (using GDA, not Array)
trailing_spec = [None] * len(key_shape)
phys_sharding = MeshPspecSharding(
out_sharding.mesh,
pxla.PartitionSpec(*out_sharding.spec, *trailing_spec))
else:
# TODO(yashkatariya,frostig): implement. Plan: accept an argument in
# this handler indicating whether the sharding came from XLA or not.
# Pass the sharding through if it's from XLA, otherwise maybe create
# a new op sharding with a trivially extended `tile_assignment_dimensions`
raise NotImplementedError

phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
if is_out_sharding_from_xla:
phys_sharding = out_sharding
else:
phys_sharding = OpShardingSharding(
out_sharding._device_assignment,
KeyTyRules.physical_op_sharding(aval, out_sharding))

phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
is_out_sharding_from_xla)
def handler(bufs):
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
return handler
Expand Down Expand Up @@ -492,10 +504,18 @@ def device_put_key_array(x: PRNGKeyArray, device):
dispatch.device_put_handlers[PRNGKeyArray] = device_put_key_array

def key_array_shard_arg_handler(x: PRNGKeyArray, devices, indices, mode):
# TODO(frostig): Remove the need for `core.get_aval`.
key_shape = core.get_aval(x).dtype.impl.key_shape
arr = x.unsafe_raw_array()
return pxla.shard_arg_handlers[type(arr)](arr, devices, indices, mode)

# TODO(yashkatariya,frostig): This assumes that the last dimensions are not
# sharded. This is only true when enable_custom_prng is True.
trailing_inds = [slice(None)] * len(key_shape)
phys_indices = [(*inds, *trailing_inds) for inds in indices]
return pxla.shard_arg_handlers[type(arr)](arr, devices, phys_indices, mode)
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler


def key_array_constant_handler(x, canonicalize_dtypes):
arr = x.unsafe_raw_array()
return mlir.get_constant_handler(type(arr))(arr, canonicalize_dtypes)
Expand Down
5 changes: 3 additions & 2 deletions jax/experimental/array.py
Expand Up @@ -486,12 +486,13 @@ def _array_shard_arg(x, devices, indices, mode):
pxla.shard_arg_handlers[Array] = _array_shard_arg


def _array_global_result_handler(global_aval, out_sharding, committed):
def _array_global_result_handler(global_aval, out_sharding, committed,
is_out_sharding_from_xla):
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed)
global_aval, out_sharding, committed, is_out_sharding_from_xla)

# Calculate the indices and addressable device assignment once during
# compilation and pass it to the constructor.
Expand Down
5 changes: 3 additions & 2 deletions jax/experimental/global_device_array.py
Expand Up @@ -618,10 +618,11 @@ def _gda_shard_arg(x, devices, indices, mode):
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg


def _gda_array_result_handler(global_aval, out_sharding, committed):
def _gda_array_result_handler(global_aval, out_sharding, committed,
is_out_sharding_from_xla):
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed)
global_aval, out_sharding, committed, is_out_sharding_from_xla)
global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec
global_idx_rid = get_shard_indices_replica_ids(global_aval.shape, global_mesh,
out_axis_resources)
Expand Down
11 changes: 7 additions & 4 deletions jax/interpreters/mlir.py
Expand Up @@ -608,10 +608,13 @@ def lower_jaxpr_to_module(
]
out_avals = jaxpr.out_avals
if result_shardings is not None:
out_avals = [
sharded_aval(out_aval, out_sharding)
for out_aval, out_sharding in zip(out_avals, result_shardings)
]
out_avals = []
for out_aval, out_sharding in zip(jaxpr.out_avals, result_shardings):
if (out_aval is not core.abstract_token and
core.is_opaque_dtype(out_aval.dtype)):
out_aval, = out_aval.dtype._rules.physical_avals(out_aval)
out_avals.append(sharded_aval(out_aval, out_sharding))

platforms_with_donation = ("cuda", "rocm", "tpu")
if platform in platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
Expand Down
100 changes: 67 additions & 33 deletions jax/interpreters/pxla.py
Expand Up @@ -249,7 +249,7 @@ def _op_sharding_to_numpy_indices(
axis_indices.append([slice(None)])
elif n_shards > 1:
shard_size, ragged = divmod(dim, n_shards)
assert not ragged, (dim, n_shards, dim)
assert not ragged, (dim, n_shards)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
for i in range(n_shards)])
else:
Expand Down Expand Up @@ -589,7 +589,8 @@ def sda_array_result_handler(aval: ShapedArray, sharding, indices):


def global_aval_to_result_handler(
aval: core.AbstractValue, out_sharding, committed: bool
aval: core.AbstractValue, out_sharding, committed: bool,
is_out_sharding_from_xla: bool
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Expand All @@ -599,6 +600,8 @@ def global_aval_to_result_handler(
Used for creating GSDAs.
global_mesh: The global device mesh that generated this output. Used
for creating GSDAs.
is_out_sharding_from_xla: True, if the out_sharding comes from XLA i.e.
the sharding is extracted from the HLO.
Returns:
A function for handling the Buffers that will eventually be produced
Expand All @@ -611,7 +614,7 @@ def global_aval_to_result_handler(
output_type = OutputType.GlobalDeviceArray
try:
return global_result_handlers[(type(aval), output_type)](
aval, out_sharding, committed)
aval, out_sharding, committed, is_out_sharding_from_xla)
except KeyError as err:
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err
Expand Down Expand Up @@ -1846,13 +1849,15 @@ def local_avals_to_results_handler(
def global_avals_to_results_handler(
global_out_avals: Sequence[ShapedArray],
shardings: Sequence[XLACompatibleSharding],
committed: bool) -> ResultsHandler:
committed: bool,
are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
from jax.experimental.sharding import MeshPspecSharding

if config.jax_parallel_functions_output_gda or config.jax_array:
handlers = [
global_aval_to_result_handler(global_aval, s, committed)
for global_aval, s in safe_zip(global_out_avals, shardings)
global_aval_to_result_handler(global_aval, s, committed, x)
for global_aval, s, x in safe_zip(global_out_avals, shardings,
are_out_shardings_from_xla)
]
return ResultsHandler(handlers, shardings, global_out_avals)
else:
Expand Down Expand Up @@ -2689,7 +2694,7 @@ def lower_sharding_computation(
# 1. Trace to jaxpr and preprocess/verify it
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec"):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))

log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
Expand All @@ -2700,12 +2705,10 @@ def lower_sharding_computation(
global_in_avals, in_shardings)

if _is_unspecified(out_shardings):
out_shardings = (_UNSPECIFIED,) * len(out_jaxpr_avals)
out_shardings = (_UNSPECIFIED,) * len(global_out_avals)

# mypy doesn't understand that out_sharding here is always a sequence.
assert len(out_shardings) == len(out_jaxpr_avals), (len(out_shardings), len(out_jaxpr_avals)) # type: ignore

global_out_avals = out_jaxpr_avals
assert len(out_shardings) == len(global_out_avals), (len(out_shardings), len(global_out_avals)) # type: ignore

if keep_unused:
kept_var_idx = set(range(len(global_in_avals)))
Expand Down Expand Up @@ -2738,16 +2741,25 @@ def lower_sharding_computation(
axis_ctx: mlir.AxisContext

if nreps == 1:
in_op_shardings = [
None if aval is core.abstract_token else i._to_xla_op_sharding(aval.ndim)
for aval, i in safe_zip(global_in_avals, in_shardings)
]
in_op_shardings = []
for aval, i in safe_zip(global_in_avals, in_shardings):
if aval is core.abstract_token:
in_op_shardings.append(None)
elif core.is_opaque_dtype(aval.dtype):
in_op_shardings.append(aval.dtype._rules.physical_op_sharding(aval, i))
else:
in_op_shardings.append(i._to_xla_op_sharding(aval.ndim))

# TODO(yashkatariya): Fix the HLO produced if out_partitions is
# [None, OpShardingProto] has the sharding annotations.
out_op_shardings = [
None if _is_unspecified(o) or aval is core.abstract_token else o._to_xla_op_sharding(aval.ndim)
for aval, o in safe_zip(global_out_avals, out_shardings)
]
out_op_shardings = []
for aval, o in safe_zip(global_out_avals, out_shardings):
if _is_unspecified(o) or aval is core.abstract_token:
out_op_shardings.append(None)
elif core.is_opaque_dtype(aval.dtype):
out_op_shardings.append(aval.dtype._rules.physical_op_sharding(aval, o))
else:
out_op_shardings.append(o._to_xla_op_sharding(aval.ndim))
replicated_args = [False] * len(global_in_avals)
axis_ctx = mlir.ShardingContext(first_sharding)
else:
Expand Down Expand Up @@ -2899,16 +2911,25 @@ def lower_mesh_computation(
out_partitions: Optional[List[Optional[xc.OpSharding]]]
axis_ctx: mlir.AxisContext
if spmd_lowering:
in_partitions = [
None if _is_auto(i) else i._to_xla_op_sharding(aval.ndim)
for aval, i in safe_zip(global_in_avals, in_shardings)
]
in_partitions = []
for aval, i in safe_zip(global_in_avals, in_shardings):
if _is_auto(i):
in_partitions.append(None)
elif core.is_opaque_dtype(aval.dtype):
in_partitions.append(aval.dtype._rules.physical_op_sharding(aval, i))
else:
in_partitions.append(i._to_xla_op_sharding(aval.ndim))

# TODO(yashkatariya): Fix the HLO produced if out_partitions is
# [None, OpShardingProto] has the sharding annotations.
out_partitions = [
None if _is_auto(o) or _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
for aval, o in safe_zip(global_out_avals, out_shardings)
]
out_partitions = []
for aval, o in safe_zip(global_out_avals, out_shardings):
if _is_auto(o) or _is_unspecified(o):
out_partitions.append(None)
elif core.is_opaque_dtype(aval.dtype):
out_partitions.append(aval.dtype._rules.physical_op_sharding(aval, o))
else:
out_partitions.append(o._to_xla_op_sharding(aval.ndim))
replicated_args = [False] * len(in_jaxpr_avals)
axis_ctx = mlir.SPMDAxisContext(mesh)
else:
Expand Down Expand Up @@ -3158,8 +3179,9 @@ def from_hlo(name: str,
assert not auto_spmd_lowering
in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
are_out_shardings_from_xla = [False] * len(global_out_avals)
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed) # type: ignore # arg-type
global_out_avals, out_shardings, committed, are_out_shardings_from_xla) # type: ignore # arg-type
unsafe_call = backend.compile_replicated(computation, compile_options,
host_callbacks, input_avals,
input_indices, in_shardings,
Expand All @@ -3174,22 +3196,34 @@ def from_hlo(name: str,

if auto_spmd_lowering:
assert mesh is not None
in_shardings, out_shardings = _get_mesh_pspec_shardings_from_executable(
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
in_shardings = [x if _is_auto(i) else i
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings_tuple = [
(x, True) if _is_auto(o) else (o, False)
for x, o in safe_zip(out_shardings_xla, out_shardings)
]
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
elif out_shardings and any(_is_unspecified(o) for o in out_shardings):
assert mesh is None
_, out_shardings_xla = _get_op_sharding_shardings_from_executable(
xla_executable, device_assignment,
len(global_in_avals), len(global_out_avals))
out_shardings = [x if _is_unspecified(o) else o
for x, o in safe_zip(out_shardings_xla, out_shardings)]
out_shardings_tuple = [
(x, True) if _is_unspecified(o) else (o, False)
for x, o in safe_zip(out_shardings_xla, out_shardings)
]
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
else:
are_out_shardings_from_xla = [False] * len(global_out_avals)

in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed) # type: ignore # arg-type
handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
input_indices, InputsHandlerMode.pjit_or_xmap)
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed, are_out_shardings_from_xla) # type: ignore # arg-type

# This path is taken for `jit(pmap)` cases. Nothing else should flow
# through this path. This is exactly same to what happens in `jit`.
Expand Down
13 changes: 11 additions & 2 deletions tests/lax_test.py
Expand Up @@ -3018,11 +3018,19 @@ def test_abstract_eval_collective(self):
class FooTyRules:
# handlers

@staticmethod
def physical_avals(aval):
return [core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))]

@staticmethod
def aval_to_ir_types(aval):
aval2 = core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))
aval2, = FooTyRules.physical_avals(aval)
return mlir.aval_to_ir_types(aval2)

@staticmethod
def physical_op_sharding(aval, sharding):
return sharding._to_xla_op_sharding(aval.ndim)

@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
Expand All @@ -3031,7 +3039,8 @@ def handler(_, buf):
return handler

@staticmethod
def global_sharded_result_handler(aval, out_sharding, committed):
def global_sharded_result_handler(aval, out_sharding, committed,
is_out_sharding_from_xla):
def handler(bufs):
buf, = bufs
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
Expand Down

0 comments on commit 0584c6a

Please sign in to comment.