Commit
Remove indices and devices from shard_arg_handlers and shard_args.
This only affects the Python dispatch path. It has no impact on the speed of the C++ dispatch path (which is why benchmarks are **not** regressing).

If your code ends up taking the Python dispatch path, then something is already going wrong anyway.

PiperOrigin-RevId: 596081987
yashk2810 authored and jax authors committed Jan 5, 2024
1 parent ed62f28 commit b8098b1
Showing 11 changed files with 108 additions and 168 deletions.
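
At a glance, the shard_arg handler contract shrinks from four positional arguments to two; devices and per-device indices are now derived from the sharding itself. A simplified before/after sketch (illustrative names, not the verbatim source):

    # Old contract: callers threaded devices and per-device indices through.
    def old_handler(arg, devices, arg_indices, sharding): ...

    # New contract: the sharding alone determines both.
    def new_handler(arg, sharding):
      indices = tuple(sharding.addressable_devices_indices_map(arg.shape).values())
      devices = sharding._addressable_device_assignment
      ...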
4 changes: 3 additions & 1 deletion jax/_src/api.py
@@ -63,6 +63,7 @@
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
@@ -1845,7 +1846,8 @@ def cache_miss(*args, **kwargs):
return out, fastpath_data

cpp_mapped_f = pmap_lib.pmap(
-      fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
+      fun, cache_miss, static_broadcasted_tuple,
+      pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
pytree_registry=tree_util.default_registry)
_pmap_cache_clears.add(cpp_mapped_f)

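This gate selects the Python fallback matching the calling convention of the installed C++ runtime: newer runtimes pass only (arg, sharding), while older ones still pass devices and indices, which the temp_shard_arg shim (added in pxla.py below) accepts and ignores. A sketch of the selection, assuming xla_extension_version tracks the installed jaxlib:

    from jax._src.lib import xla_extension_version
    from jax._src.interpreters import pxla

    # Newer runtimes call shard_arg(arg, sharding); older ones use the
    # four-argument convention that the shim adapts.
    shard_fn = (pxla.shard_arg if xla_extension_version >= 229
                else pxla.temp_shard_arg)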
11 changes: 5 additions & 6 deletions jax/_src/array.py
@@ -834,8 +834,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
if not candidates_list:
# This array isn't sharded correctly. Reshard it via host roundtrip.
# TODO(skye): more efficient reshard?
-    return pxla.shard_arg(x._value, devices, indices, sharding,
-                          canonicalize=False)
+    return pxla.shard_arg(x._value, sharding, canonicalize=False)
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
@@ -848,27 +847,27 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
return pxla.batched_device_put(x.aval, sharding, bufs, devices)


-def _array_shard_arg(x, devices, indices, sharding):
+def _array_shard_arg(x, sharding):
x._check_if_deleted()

x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
+  indices = sharding.addressable_devices_indices_map(x.shape).values()
if not x.is_fully_addressable:
if tuple(x_indices) == tuple(indices):
return x
else:
raise NotImplementedError(
"Cannot reshard an input that is not fully addressable")
else:
+    devices = pxla.get_addressable_devices_for_shard_arg(sharding)
if tuple(x_indices) == tuple(indices):
-      return xc.copy_array_to_devices_with_sharding(
-          x, list(devices), sharding)
+      return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding)
# Resharding starts here:
if dispatch.is_single_device_sharding(x.sharding):
return shard_device_array(x, devices, indices, sharding)
else:
return shard_sharded_device_array_slow_path(x, devices, indices, sharding)


pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg


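The array handler now recomputes both index maps from shardings instead of trusting caller-supplied indices. A hedged sketch of the comparison that picks the fast path, using only public Sharding methods:

    import jax
    import numpy as np

    x = jax.device_put(np.arange(8.0), jax.devices()[0])
    target = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
    t_indices = target.addressable_devices_indices_map(x.shape).values()
    # Matching index tuples select the cheap pass-through/copy path;
    # mismatched ones fall through to the resharding paths.
    assert tuple(x_indices) == tuple(t_indices)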
9 changes: 2 additions & 7 deletions jax/_src/dispatch.py
@@ -124,12 +124,8 @@ def __init__(self):
def get_token_input(self, eff: core.Effect,
devices: list[Device]) -> jax.Array:
tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
-    s = NamedSharding(pxla.Mesh(devices, axis_names=["dev"]),
-                      PartitionSpec([]))
-    indices = tuple(
-        s.addressable_devices_indices_map(tok.shape).values())
-    sharded_tok = pxla.shard_args(devices, [indices], [s], [tok])[0]
+    s = jax.sharding.GSPMDSharding.get_replicated(devices)
+    sharded_tok = pxla.shard_args([s], [tok])[0]
self.current_tokens[eff] = sharded_tok
return sharded_tok

@@ -331,8 +327,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:

def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
result_handler = pxla.global_aval_to_result_handler(aval, s, committed, False)
-  map_ = s.devices_indices_map(aval.shape)  # type: ignore
-  return result_handler(pxla.shard_arg(x, list(map_), list(map_.values()), s))
+  return result_handler(pxla.shard_arg(x, s))

def _override_get_device_assignment(sharding, *args, **kwargs):
da = sharding._device_assignment
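A hedged, runnable sketch of the new token flow in get_token_input (the private pxla import is for illustration only):

    import jax
    import numpy as np
    from jax._src.interpreters import pxla

    devices = jax.local_devices()
    s = jax.sharding.GSPMDSharding.get_replicated(devices)
    tok = np.zeros(0, np.bool_)
    # One sharding per argument; no devices/indices lists anymore.
    (sharded_tok,) = pxla.shard_args([s], [tok])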
156 changes: 73 additions & 83 deletions jax/_src/interpreters/pxla.py
@@ -106,69 +106,55 @@ class WeakRefList(list):

def identity(x): return x

-def shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
-  """Returns a list of size len(devices) containing per-device buffers.
-
-  For the C++ pmap path, we fallback to Python (this function) to shard
-  arguments that are not supported by the C++ `ShardArg`.
-
-  Args:
-    arg: The Python argument.
-    devices: The list of devices to shard over.
-    arg_indices: A list of `len(devices)` indices to use to shard the argument.
-  """
+def shard_arg(arg, sharding, canonicalize=True):
   if canonicalize:
     arg = xla.canonicalize_dtype(arg)
-  return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)
+  return shard_arg_handlers[type(arg)](arg, sharding)


 @profiler.annotate_function
 def shard_args(
-    devices: Sequence[xb.xla_client.Device],
-    indices: Sequence[Sequence[Index]],
-    shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    args,
+    shardings: Sequence[sharding_impls.XLACompatibleSharding], args,
 ) -> Sequence[jax.Array]:
-  """Shard each argument data array along its leading axis.
-
-  Args:
-    devices: sequence of Devices mapping replica index to a physical device.
-    indices: sequence of the same length as `args` describing how each arg
-      should be sharded/replicated across `devices`. Each element in `indices`
-      is the same length as `devices`.
-    args: a sequence of JaxTypes representing arguments to be sharded according
-      to `indices` and placed on `devices`.
-
-  Returns:
-    A list of length matching args, containing lists of per-device buffers
-    for each argument.
-  """
-  return [shard_arg(arg, devices, indices[i], shardings[i])
-          for i, arg in enumerate(args)]
+  return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]


-shard_arg_handlers: dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}
+shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
+@lru_cache(maxsize=1024)
+def get_addressable_devices_for_shard_arg(
+    s: sharding_impls.XLACompatibleSharding) -> tuple[xc.Device, ...]:
+  return s._addressable_device_assignment
+
+@lru_cache(maxsize=1024)
+def _get_replicated_slices(num_addressable_devices: int):
+  return ((slice(None),),) * num_addressable_devices

-def _shard_token(x, devices, indices, sharding):
+def _shard_token(x, sharding):
+  devices = get_addressable_devices_for_shard_arg(sharding)
+  indices = _get_replicated_slices(len(devices))
   zeros = np.zeros((), dtype=np.dtype(np.bool_))
   aval = api_util.shaped_abstractify(zeros)
-  return batched_device_put(aval, sharding, [zeros for i in indices], devices)
+  return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
shard_arg_handlers[core.Token] = _shard_token
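
For intuition, with four addressable devices the token handler builds trivial whole-value indices, one per device; a quick check, assuming the private helper stays reachable at module level:

    from jax._src.interpreters import pxla

    assert pxla._get_replicated_slices(4) == ((slice(None),),) * 4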

-def _masked_array_error(x, devices, indices, sharding):
+def _masked_array_error(x, sharding):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error

-def _shard_array(x, devices, indices, sharding):
+def _shard_array(x, sharding):
+  indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
+  devices = get_addressable_devices_for_shard_arg(sharding)
if x.dtype == dtypes.float0:
x = np.zeros(x.shape, dtype=np.dtype(bool))
aval = api_util.shaped_abstractify(x)
return batched_device_put(aval, sharding, [x[i] for i in indices], devices)
for _t in array_types:
shard_arg_handlers[_t] = _shard_array

-def _shard_darray(x, devices, indices, sharding):
-  return shard_arg(x._data, devices, indices, sharding)
+def _shard_darray(x, sharding):
+  return shard_arg(x._data, sharding)
shard_arg_handlers[core.DArray] = _shard_darray
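
Third-party array-like types register in the same table under the new two-argument contract. A hedged sketch for a hypothetical wrapper type, delegating the way _shard_darray does:

    from jax._src.interpreters import pxla

    class Wrapped:  # hypothetical container holding an array payload
      def __init__(self, data):
        self.data = data

    def _shard_wrapped(x, sharding):
      # Delegate to the payload's own handler via the public entry point.
      return pxla.shard_arg(x.data, sharding)

    pxla.shard_arg_handlers[Wrapped] = _shard_wrapped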

def batched_device_put(aval: core.ShapedArray,
@@ -183,7 +169,7 @@ def batched_device_put(aval: core.ShapedArray,
if len(bufs) == len(xs):
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)
-  return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
+  return xc.batched_device_put(aval, sharding, xs, list(devices), committed)  # type: ignore

def shard_aval(size, axis: int, aval):
try:
@@ -849,8 +835,8 @@ def build_execute_fun(self):
if spec.sharding_spec is not None else None)
handle_outs = local_avals_to_results_handler(self.local_output_avals,
self.output_shardings)
-    handle_args = InputsHandler(self.compiled.local_devices(),
-                                self.input_shardings, input_indices)
+    handle_args = InputsHandler(self.input_shardings,
+                                self.compiled.local_devices(), input_indices)
execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
self.backend, handle_args, handle_outs,
self.unordered_effects,
@@ -1054,9 +1040,8 @@ def _get_pmap_sharding(devices, specs):
class InputsHandler:
__slots__ = ("handler", "local_devices", "in_shardings", "input_indices")

-  def __init__(self, local_devices, in_shardings, input_indices):
-    self.handler = partial(
-        shard_args, local_devices, input_indices, in_shardings)
+  def __init__(self, in_shardings, local_devices=None, input_indices=None):
+    self.handler = partial(shard_args, in_shardings)
self.local_devices = local_devices
self.in_shardings = in_shardings
self.input_indices = input_indices
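
Accordingly, an InputsHandler can now be built from shardings alone; the trailing parameters remain only for the pre-229 fallback. A hedged construction sketch:

    import jax
    import numpy as np
    from jax._src.interpreters import pxla

    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    handler = pxla.InputsHandler([s])   # handler.handler == partial(shard_args, [s])
    (arr,) = handler.handler([np.arange(4.0)])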
@@ -2248,37 +2233,36 @@ def cost_analysis(self) -> dict[str, float]:
return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())


-@lru_cache(maxsize=1024)
-def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
-  if ndim is None:
-    return ((slice(None),),) * num_addressable_devices
-  else:
-    return ((slice(None),) * ndim,) * num_addressable_devices
-
-def _get_input_indices(
-    avals: Sequence[ShapedArray],
-    shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
-) -> Sequence[tuple[Index | None, ...]]:
-
-  input_indices = []
-  if not isinstance(da_object, _DeviceAssignment):
-    da_object = _create_da_object(tuple(da_object))
-  num_addressable_devices = len(da_object.addressable_device_list)
-
-  for aval, sharding in zip(avals, shardings):
-    if aval is core.abstract_token:
-      index = _get_replicated_slices(num_addressable_devices, None)
-    else:
-      if sharding.is_fully_replicated:
-        index = _get_replicated_slices(num_addressable_devices, aval.ndim)
-      else:
-        index = tuple(
-            sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
-    input_indices.append(index)
-
-  return input_indices
+if xla_extension_version < 229:
+  def _get_input_indices(
+      avals: Sequence[ShapedArray],
+      shardings: Sequence[sharding_impls.XLACompatibleSharding],
+      da_object: _DeviceAssignment | Sequence[xc.Device],  # type: ignore
+  ) -> Sequence[tuple[Index | None, ...]]:
+
+    input_indices = []
+    if not isinstance(da_object, _DeviceAssignment):
+      da_object = _create_da_object(tuple(da_object))
+    num_addressable_devices = len(da_object.addressable_device_list)
+
+    def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
+      if ndim is None:
+        return ((slice(None),),) * num_addressable_devices
+      else:
+        return ((slice(None),) * ndim,) * num_addressable_devices
+
+    for aval, sharding in zip(avals, shardings):
+      if aval is core.abstract_token:
+        index = _get_replicated_slices(num_addressable_devices, None)
+      else:
+        if sharding.is_fully_replicated:
+          index = _get_replicated_slices(num_addressable_devices, aval.ndim)
+        else:
+          index = tuple(
+              sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
+      input_indices.append(index)
+
+    return input_indices
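
Concretely, on this pre-229 path a fully replicated rank-2 aval across two addressable devices gets whole-array slices for each device; a standalone check with names local to this sketch:

    ndim, num_addressable_devices = 2, 2
    index = ((slice(None),) * ndim,) * num_addressable_devices
    assert index == ((slice(None), slice(None)),
                     (slice(None), slice(None)))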


def get_gspmd_shardings_from_executable(
@@ -2604,10 +2588,13 @@ class UnloadedMeshExecutable:
all_args_info: AllArgsInfo | None

def build_unsafe_call(self):
-    input_indices = _get_input_indices(self.input_avals, self.input_shardings,
-                                       self.device_assignment)
-    handle_args = InputsHandler(self.xla_executable.local_devices(),
-                                self.input_shardings, input_indices)
+    if xla_extension_version >= 229:
+      handle_args = InputsHandler(self.input_shardings)
+    else:
+      input_indices = _get_input_indices(self.input_avals, self.input_shardings,
+                                         self.device_assignment)
+      handle_args = InputsHandler(
+          self.input_shardings, self.xla_executable.local_devices(), input_indices)
handle_outs = global_avals_to_results_handler(
self.output_avals, self.output_shardings, self.committed,
self.are_out_shardings_from_xla) # type: ignore # arg-type
@@ -2755,6 +2742,7 @@ class MeshExecutableFastpathData(NamedTuple):
out_avals: Sequence[ShapedArray]
out_committed: Sequence[bool]
kept_var_bitvec: Iterable[bool]
+  # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
arg_handler_devices: Sequence[xc.Device]
arg_handler_indices: Sequence[tuple[Index | None, ...]]

@@ -2865,13 +2853,20 @@ def aot_cache_miss(*args, **kwargs):
return outs, fastpath_data

if xla_extension_version >= 226:
-      return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-                          tree_util.dispatch_registry, shard_arg)
+      return xc._xla.pjit(
+          self.unsafe_call.name, None, aot_cache_miss, [], [], [],
+          tree_util.dispatch_registry,
+          shard_arg if xla_extension_version >= 229 else temp_shard_arg)  # type: ignore
else:
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], # type: ignore
tree_util.dispatch_registry)


+# TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
+def temp_shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
+  return shard_arg(arg, sharding)
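
The shim keeps the old positional signature for older runtimes while ignoring the now-redundant arguments; a hedged sanity check:

    import jax
    import numpy as np
    from jax._src.interpreters import pxla

    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    # devices and arg_indices are accepted but unused under the new contract.
    arr = pxla.temp_shard_arg(np.arange(3.0), None, None, s)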


def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: core.JaxprDebugInfo | None = None):
if len(ref_avals) != len(arg_avals):
@@ -2926,20 +2921,15 @@ def _compile_replicated_mesh_executable_from_hlo(
in_shardings = semantics_in_shardings.shardings
out_shardings = semantics_out_shardings.shardings

-  input_indices = _get_input_indices(global_in_avals, in_shardings, da)  # type: ignore
-  if pmap_nreps > 1:
-    # For a jit wrapping a pmap, replicate each input index to match the
-    # devices of the replicated jit computation.
-    input_indices = [index * pmap_nreps for index in input_indices]
kept_var_idx = set(kept_var_idx)
# Will compute out_handler with executable information.
unsafe_call = backend.compile_replicated(
is_trivial=False, name=name, computation=computation,
compile_options=compile_options, host_callbacks=host_callbacks,
has_unordered_effects=has_unordered_effects,
-      ordered_effects=ordered_effects, in_avals=global_in_avals,
-      in_indices=input_indices, in_shardings=in_shardings,
-      kept_var_idx=kept_var_idx,
+      device_assignment=da, ordered_effects=ordered_effects,
+      in_avals=global_in_avals,
+      in_shardings=in_shardings, kept_var_idx=kept_var_idx,
out_avals=global_out_avals, out_shardings=out_shardings,
committed=committed, pmap_nreps=pmap_nreps)
xla_executable = None
11 changes: 7 additions & 4 deletions jax/_src/pjit.py
@@ -234,7 +234,8 @@ def cache_miss(*args, **kwargs):
getattr(fun, "__name__", "<unnamed function>"),
fun, cache_miss, static_argnums, static_argnames,
donate_argnums, tree_util.dispatch_registry,
-        pxla.shard_arg, _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
+        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
+        _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
else:
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"),
@@ -1348,9 +1349,11 @@ def call_impl_cache_miss(*args_, **kwargs_):
has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, None, None)
if xla_extension_version >= 226:
-    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
-                        tree_util.dispatch_registry, pxla.shard_arg,
-                        _get_cpp_global_cache(has_explicit_sharding))(*args)
+    return xc._xla.pjit(
+        name, f, call_impl_cache_miss, [], [], donated_argnums,
+        tree_util.dispatch_registry,
+        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
+        _get_cpp_global_cache(has_explicit_sharding))(*args)
else:
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
tree_util.dispatch_registry,
15 changes: 3 additions & 12 deletions jax/_src/prng.py
@@ -636,20 +636,11 @@ def __hash__(self) -> int:
xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x


-def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
-  aval = x.aval
-  key_shape = aval.dtype._impl.key_shape
+def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, sharding):
   arr = x._base_array
-
-  # TODO(yashkatariya,frostig): This assumes that the last dimensions are not
-  # sharded. This is only true when enable_custom_prng is True.
-  trailing_inds = [slice(None)] * len(key_shape)
-  phys_indices = [(*inds, *trailing_inds) for inds in indices]
   phys_sharding = make_key_array_phys_sharding(
-      aval, sharding, is_sharding_from_xla=False)
-  return pxla.shard_arg_handlers[type(arr)](
-      arr, devices, phys_indices, phys_sharding
-  )
+      x.aval, sharding, is_sharding_from_xla=False)
+  return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)


pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
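A hedged usage sketch of the key-array path: placing typed PRNG keys now derives the physical sharding from the logical one inside the handler, with no hand-built physical indices. Assumes new-style typed keys and at least one local device:

    import jax

    keys = jax.random.split(jax.random.key(0), jax.local_device_count())
    s = jax.sharding.PositionalSharding(jax.local_devices())
    # device_put routes through key_array_shard_arg_handler, which shards the
    # underlying uint32 base array with the derived physical sharding.
    sharded_keys = jax.device_put(keys, s)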
