Add layout support to make_array_from_callback.
PiperOrigin-RevId: 625048520
yashk2810 authored and jax authors committed Apr 15, 2024
1 parent b9a853d commit eb92a5c
Showing 4 changed files with 92 additions and 38 deletions.
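
Before the diff itself, a minimal sketch of the call pattern this commit enables, adapted from the test added in tests/layout_test.py below. It assumes at least two local devices (so the (2, 1) mesh can be built) and constructs the mesh from jax.devices() instead of the test-only helper:

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Build a 2x1 mesh and a sharding for an (8, 2) array.
    mesh = Mesh(np.asarray(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)

    # Compile a trivial function to obtain a concrete Layout
    # (sharding + device-local layout) for its output.
    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
    layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts()

    # New in this commit: a Layout is accepted where a Sharding was required.
    out = jax.make_array_from_callback(np_inp.shape, layout,
                                       lambda idx: np_inp[idx])
    assert out.layout == layout

Passing Layout(DeviceLocalLayout.AUTO, s), or a Layout whose sharding is not a jax.sharding.Sharding, now raises a TypeError, as exercised by the new test cases at the bottom of the diff.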
46 changes: 37 additions & 9 deletions jax/_src/array.py
@@ -37,6 +37,7 @@
from jax._src import tree_util
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_extension as xe
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
@@ -45,7 +46,7 @@
from jax._src.sharding_impls import (
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.layout import DeviceLocalLayout, Layout, AutoLayout
from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method

@@ -644,7 +645,7 @@ def _value(self) -> np.ndarray:
setattr(ArrayImpl, "__array_priority__", 100)

def make_array_from_callback(
shape: Shape, sharding: Sharding,
shape: Shape, sharding: Sharding | Layout,
data_callback: Callable[[Index | None], ArrayLike]) -> ArrayImpl:
"""Returns a ``jax.Array`` via data fetched from ``data_callback``.
@@ -683,6 +684,17 @@ def make_array_from_callback(
>>> arr.addressable_data(0).shape
(4, 2)
"""
dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
if isinstance(dll, AutoLayout):
raise TypeError(
"`DeviceLocalLayout.AUTO` cannot be used in place of a device-local"
f" layout when calling `jax.make_array_from_callback`. Got {sharding}")
sharding = sharding.sharding if isinstance(sharding, Layout) else sharding # type: ignore
if not isinstance(sharding, Sharding):
raise TypeError(
f"sharding should be an instance of `jax.sharding`. Got {sharding} of"
f" type {type(sharding)}")

if sharding.is_fully_replicated:
devices = list(sharding._internal_device_list.addressable_device_list) # type: ignore
per_device_values = [data_callback((slice(None),) * len(shape))] * len(devices)
@@ -701,7 +713,7 @@ def make_array_from_callback(

# first value can be numpy array, python scalar, etc.
if (sharding.is_fully_replicated and not isinstance(first_value, ArrayImpl)
and not dtypes.issubdtype(aval.dtype, dtypes.extended)):
and not dtypes.issubdtype(aval.dtype, dtypes.extended) and dll is None):
# Do this check outside because `batched_device_put` won't do these checks
# like ArrayImpl.
if shape != first_value.shape:
@@ -712,11 +724,25 @@ def make_array_from_callback(
return pxla.batched_device_put(
aval, sharding, per_device_values, devices, committed=True)

if (sharding.is_fully_replicated and isinstance(first_value, ArrayImpl) and
first_value.is_fully_replicated and
first_value.sharding._device_assignment == devices):
# After minimum jaxlib version >= 0.4.26, merge this condition into the
# following if block.
if xla_extension_version >= 256 and isinstance(first_value, ArrayImpl):
maybe_default_layout = pxla._maybe_get_default_layout(
Layout(dll, sharding), None, sharding, aval)
layout_eq = first_value.layout.device_local_layout == maybe_default_layout
else:
layout_eq = True

if (isinstance(first_value, ArrayImpl)
and first_value._committed
and sharding.is_fully_replicated
and first_value.is_fully_replicated
and first_value.sharding._device_assignment == tuple(devices)
and layout_eq):
return first_value

if dll is not None:
devices = [Layout(dll, SingleDeviceSharding(d)) for d in devices]
arrays = api.device_put(per_device_values, devices)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
@@ -806,11 +832,13 @@ def make_array_from_single_device_arrays(
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
if any(isinstance(arr, core.Tracer) for arr in arrays):
raise ValueError("jax.make_array_from_single_device_arrays requires a list of concrete arrays as input. "
f"got types {set(map(type, arrays))}")
raise ValueError(
"jax.make_array_from_single_device_arrays requires a list of concrete"
f" arrays as input. got types {set(map(type, arrays))}")
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
committed=True)
# TODO(phawkins): ideally the cast() could be checked.
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
committed=True)
29 changes: 29 additions & 0 deletions jax/_src/interpreters/pxla.py
@@ -1865,6 +1865,35 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
"extra data movement anyway, so maybe you don't want it after all).")


@lru_cache(maxsize=2048)
def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval
) -> DeviceLocalLayout | None:
if is_unspecified_or_auto(sharding):
return None
# TODO(yashkatariya): Figure out how layouts work with extended dtypes.
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return None
if not core.is_constant_shape(aval.shape):
return None
shard_shape = sharding.shard_shape(aval.shape)
d = sharding._device_assignment[0]
# If a backend doesn't implement `get_default_layout` return `None` to avoid
# cache misses. This can happen when you have `jit(f, in_shardings=s)`. On
# first call you pass it a sharded array with layout and on second call you
# pass a numpy array. The layouts should be the same to get cache hits.
try:
al = DeviceLocalLayout(
d.client.get_default_layout(aval.dtype, shard_shape, d))
except:
return None
# argument does not have `.layout` property. ShapedArray, numpy array, etc
# are some examples.
if arg_layout is None:
return al if jit_in_layout is None else arg_layout # arg_layout is None
# If arg has a `.layout` property, then return device_local_layout as is.
return arg_layout.device_local_layout


@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
30 changes: 1 addition & 29 deletions jax/_src/pjit.py
@@ -1281,34 +1281,6 @@ def pjit_check_aval_sharding(
pjit_p.multiple_results = True


@lru_cache(maxsize=2048)
def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval):
if is_unspecified_or_auto(sharding):
return None
# TODO(yashkatariya): Figure out how layouts work with extended dtypes.
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return None
if not core.is_constant_shape(aval.shape):
return None
shard_shape = sharding.shard_shape(aval.shape)
d = sharding._device_assignment[0]
# If a backend doesn't implement `get_default_layout` return `None` to avoid
# cache misses. This can happen when you have `jit(f, in_shardings=s)`. On
# first call you pass it a sharded array with layout and on second call you
# pass a numpy array. The layouts should be the same to get cache hits.
try:
al = DeviceLocalLayout(
d.client.get_default_layout(aval.dtype, shard_shape, d))
except:
return None
# argument does not have `.layout` property. ShapedArray, ShapedDtypeStruct,
# numpy array, etc are some examples.
if arg_layout is None:
return al if jit_in_layout is None else arg_layout # arg_layout is None
# If arg has a `.layout` property, then return device_local_layout as is.
return arg_layout.device_local_layout


def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
# If device or backend is set, return the default layout. This is because you
# can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
@@ -1321,7 +1293,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
for arg, jit_in_l, rs, aval in safe_zip(
args, jit_in_layouts, resolved_in_shardings, in_avals):
arg_layout, committed = (
_maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval),
pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval),
getattr(arg, '_committed', True))
# Sharding can be unspecified when array is committed if it's a PmapSharding.
is_pmap_sharding = (is_unspecified(rs) or
25 changes: 25 additions & 0 deletions tests/layout_test.py
@@ -344,6 +344,31 @@ def test_layout_on_sds(self):
' layout in a `ShapeDtypeStruct`'):
jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO))

def test_make_array_from_callback(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)

layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts()

out = jax.make_array_from_callback(np_inp.shape, layout,
lambda idx: np_inp[idx])
self.assertArraysEqual(out, np_inp)
self.assertEqual(out.layout, layout)

with self.assertRaisesRegex(
TypeError,
'`DeviceLocalLayout.AUTO` cannot be used in place of a device-local'
' layout'):
jax.make_array_from_callback(np_inp.shape, Layout(DLL.AUTO, s),
lambda idx: np_inp[idx])

with self.assertRaisesRegex(
TypeError, 'sharding should be an instance of `jax.sharding`'):
jax.make_array_from_callback(
np_inp.shape, Layout(None, None), lambda idx: np_inp[idx])


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
