diff --git a/jax/experimental/array.py b/jax/experimental/array.py index a43f022f68ee..88f142d0a457 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -477,6 +477,7 @@ def _value(self) -> np.ndarray: setattr(Array, "__hash__", None) setattr(Array, "__array_priority__", 100) + def make_array_from_callback(shape: Shape, sharding: Sharding, data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array: device_to_index_map = sharding.devices_indices_map(shape) @@ -492,6 +493,14 @@ def make_array_from_callback(shape: Shape, sharding: Sharding, return Array(aval, sharding, arrays, committed=True) +def make_array_from_single_device_arrays(shape: Shape, sharding: Sharding, + arrays: Sequence[Array]) -> Array: + # All input arrays should be committed. Checking it is expensive on + # single-controller systems. + aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False) + return Array(aval, sharding, arrays, committed=True) + + core.pytype_aval_mappings[Array] = abstract_arrays.canonical_concrete_aval xla.pytype_aval_mappings[Array] = op.attrgetter('aval') xla.canonicalize_dtype_handlers[Array] = pxla.identity