Skip to content

Commit

Permalink
Add make_array_from_single_device_arrays to prepare to rename of th…
Browse files Browse the repository at this point in the history
…e concrete `Array` to `ArrayImpl`.

PiperOrigin-RevId: 476965287
  • Loading branch information
yashk2810 authored and jax authors committed Sep 26, 2022
1 parent e034432 commit b2b60d9
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions jax/experimental/array.py
Expand Up @@ -477,6 +477,7 @@ def _value(self) -> np.ndarray:
setattr(Array, "__hash__", None)
setattr(Array, "__array_priority__", 100)


def make_array_from_callback(shape: Shape, sharding: Sharding,
data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array:
device_to_index_map = sharding.devices_indices_map(shape)
Expand All @@ -492,6 +493,14 @@ def make_array_from_callback(shape: Shape, sharding: Sharding,
return Array(aval, sharding, arrays, committed=True)


def make_array_from_single_device_arrays(shape: Shape, sharding: Sharding,
arrays: Sequence[Array]) -> Array:
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
return Array(aval, sharding, arrays, committed=True)


core.pytype_aval_mappings[Array] = abstract_arrays.canonical_concrete_aval
xla.pytype_aval_mappings[Array] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[Array] = pxla.identity
Expand Down

0 comments on commit b2b60d9

Please sign in to comment.