Skip to content

Commit

Permalink
custom prng: add shard methods to PRNGKeyArrayImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 1, 2023
1 parent eb41e9c commit 3a7ccf7
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 6 deletions.
7 changes: 4 additions & 3 deletions jax/_src/array.py
Expand Up @@ -18,8 +18,8 @@
import operator as op
import numpy as np
import functools
from typing import (Sequence, Tuple, Callable, Optional, List, cast, Set,
TYPE_CHECKING)
from typing import (Any, Callable, List, Optional, Sequence, Set, Tuple,
Union, cast, TYPE_CHECKING)

from jax._src import abstract_arrays
from jax._src import api
Expand All @@ -46,6 +46,7 @@
Shape = Tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.


class Shard:
Expand All @@ -61,7 +62,7 @@ class Shard:
"""

def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
data: Optional[ArrayImpl] = None):
data: Union[None, ArrayImpl, PRNGKeyArrayImpl] = None):
self._device = device
self._sharding = sharding
self._global_shape = global_shape
Expand Down
43 changes: 40 additions & 3 deletions jax/_src/prng.py
Expand Up @@ -17,7 +17,8 @@
from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Set, Sequence, Tuple, Union
from typing import (Any, Callable, Hashable, Iterator, List, NamedTuple,
Set, Sequence, Tuple, Union)

import numpy as np

Expand Down Expand Up @@ -62,6 +63,7 @@
zip, unsafe_zip = safe_zip, zip

Device = xc.Device
Shard = Any # TODO(jakevdp): fix circular imports and import Shard

UINT_DTYPES = {
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
Expand Down Expand Up @@ -216,9 +218,17 @@ def delete(self) -> None: ...
def is_deleted(self) -> bool: ...
@abc.abstractmethod
def on_device_size_in_bytes(self) -> int: ...
@property
@abc.abstractmethod
def addressable_shards(self) -> List[Shard]: ...
@property
@abc.abstractmethod
def global_shards(self) -> List[Shard]: ...
@abc.abstractmethod
def addressable_data(self, index: int) -> PRNGKeyArray: ...

# TODO(jakevdp): potentially add tolist(), tobytes(), device_buffer, device_buffers,
# addressable_data(), addressable_shards(), global_shards(), __cuda_interface__()
# TODO(jakevdp): potentially add tolist(), tobytes(),
# device_buffer, device_buffers, __cuda_interface__()


class PRNGKeyArrayImpl(PRNGKeyArray):
Expand Down Expand Up @@ -291,6 +301,33 @@ def dtype(self):
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]

def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))

@property
def addressable_shards(self) -> List[Shard]:
return [
type(s)(
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
data=PRNGKeyArrayImpl(self.impl, s._data),
)
for s in self._base_array.addressable_shards
]

@property
def global_shards(self) -> List[Shard]:
return [
type(s)(
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
data=PRNGKeyArrayImpl(self.impl, s._data),
)
for s in self._base_array.global_shards
]

@property
def sharding(self):
phys_sharding = self._base_array.sharding
Expand Down
4 changes: 4 additions & 0 deletions tests/random_test.py
Expand Up @@ -1950,6 +1950,10 @@ def test_array_impl_attributes(self):
self.assertEqual(key.devices(), key._base_array.devices())
self.assertEqual(key.on_device_size_in_bytes, key._base_array.on_device_size_in_bytes)
self.assertEqual(key.unsafe_buffer_pointer, key._base_array.unsafe_buffer_pointer)
self.assertArraysEqual(key.addressable_data(0)._base_array,
key._base_array.addressable_data(0))
self.assertLen(key.addressable_shards, len(key._base_array.addressable_shards))
self.assertLen(key.global_shards, len(key._base_array.global_shards))

def test_delete(self):
key = self.make_keys(10)
Expand Down

0 comments on commit 3a7ccf7

Please sign in to comment.