From 3a7ccf70f23a58abeb670e009176932ee46755f4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 1 Jun 2023 04:10:12 -0700 Subject: [PATCH] custom prng: add shard methods to PRNGKeyArrayImpl --- jax/_src/array.py | 7 ++++--- jax/_src/prng.py | 43 ++++++++++++++++++++++++++++++++++++++++--- tests/random_test.py | 4 ++++ 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 92b2340afdf0..0b2a59f55d04 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -18,8 +18,8 @@ import operator as op import numpy as np import functools -from typing import (Sequence, Tuple, Callable, Optional, List, cast, Set, - TYPE_CHECKING) +from typing import (Any, Callable, List, Optional, Sequence, Set, Tuple, + Union, cast, TYPE_CHECKING) from jax._src import abstract_arrays from jax._src import api @@ -46,6 +46,7 @@ Shape = Tuple[int, ...] Device = xc.Device Index = Tuple[slice, ...] +PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this. class Shard: @@ -61,7 +62,7 @@ class Shard: """ def __init__(self, device: Device, sharding: Sharding, global_shape: Shape, - data: Optional[ArrayImpl] = None): + data: Union[None, ArrayImpl, PRNGKeyArrayImpl] = None): self._device = device self._sharding = sharding self._global_shape = global_shape diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 24ce3d82affc..e285f7bb5a9b 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -17,7 +17,8 @@ from functools import partial, reduce import math import operator as op -from typing import Any, Callable, Hashable, Iterator, NamedTuple, Set, Sequence, Tuple, Union +from typing import (Any, Callable, Hashable, Iterator, List, NamedTuple, + Set, Sequence, Tuple, Union) import numpy as np @@ -62,6 +63,7 @@ zip, unsafe_zip = safe_zip, zip Device = xc.Device +Shard = Any # TODO(jakevdp): fix circular imports and import Shard UINT_DTYPES = { 8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type] @@ -216,9 +218,17 @@ def delete(self) -> None: ... def is_deleted(self) -> bool: ... @abc.abstractmethod def on_device_size_in_bytes(self) -> int: ... + @property + @abc.abstractmethod + def addressable_shards(self) -> List[Shard]: ... + @property + @abc.abstractmethod + def global_shards(self) -> List[Shard]: ... + @abc.abstractmethod + def addressable_data(self, index: int) -> PRNGKeyArray: ... - # TODO(jakevdp): potentially add tolist(), tobytes(), device_buffer, device_buffers, - # addressable_data(), addressable_shards(), global_shards(), __cuda_interface__() + # TODO(jakevdp): potentially add tolist(), tobytes(), + # device_buffer, device_buffers, __cuda_interface__() class PRNGKeyArrayImpl(PRNGKeyArray): @@ -291,6 +301,33 @@ def dtype(self): on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment] unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment] + def addressable_data(self, index: int) -> PRNGKeyArrayImpl: + return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index)) + + @property + def addressable_shards(self) -> List[Shard]: + return [ + type(s)( + device=s._device, + sharding=s._sharding, + global_shape=s._global_shape, + data=PRNGKeyArrayImpl(self.impl, s._data), + ) + for s in self._base_array.addressable_shards + ] + + @property + def global_shards(self) -> List[Shard]: + return [ + type(s)( + device=s._device, + sharding=s._sharding, + global_shape=s._global_shape, + data=PRNGKeyArrayImpl(self.impl, s._data), + ) + for s in self._base_array.global_shards + ] + @property def sharding(self): phys_sharding = self._base_array.sharding diff --git a/tests/random_test.py b/tests/random_test.py index ea77a439ab64..a08ba1b15f2b 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1950,6 +1950,10 @@ def test_array_impl_attributes(self): self.assertEqual(key.devices(), key._base_array.devices()) self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes) self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer) + self.assertArraysEqual(key.addressable_data(0)._base_array, + key._base_array.addressable_data(0)) + self.assertLen(key.addressable_shards, len(key._base_array.addressable_shards)) + self.assertLen(key.global_shards, len(key._base_array.global_shards)) def test_delete(self): key = self.make_keys(10)