Skip to content

Commit

Permalink
remove PRNGKeyArray ABC
Browse files Browse the repository at this point in the history
We don't expose the `PRNGKeyArray` symbol publicly any longer and we only implement the interface in one place.

PiperOrigin-RevId: 602470550
  • Loading branch information
froystig authored and jax authors committed Jan 29, 2024
1 parent 37b6d22 commit a043325
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 126 deletions.
4 changes: 2 additions & 2 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.
PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.

def _get_device(a: ArrayImpl) -> Device:
assert len(a.devices()) == 1
Expand All @@ -69,7 +69,7 @@ class Shard:
"""

def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
data: None | ArrayImpl | PRNGKeyArrayImpl = None):
data: None | ArrayImpl | PRNGKeyArray = None):
self._device = device
self._sharding = sharding
self._global_shape = global_shape
Expand Down
154 changes: 30 additions & 124 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations

import abc
from collections.abc import Iterator, Sequence
from functools import partial, reduce
import math
Expand Down Expand Up @@ -136,101 +135,6 @@ def _check_prng_key_data(impl, key_data: typing.Array):


class PRNGKeyArray(jax.Array):
"""An array whose elements are PRNG keys"""

@abc.abstractmethod
def unsafe_buffer_pointer(self) -> int: ...

@abc.abstractmethod
def block_until_ready(self) -> PRNGKeyArray: ...

@abc.abstractmethod
def copy_to_host_async(self) -> None: ...

@property
@abc.abstractmethod
def shape(self) -> tuple[int, ...]: ...

@property
@abc.abstractmethod
def ndim(self) -> int: ...

@property
@abc.abstractmethod
def size(self) -> int: ...

@property
@abc.abstractmethod
def dtype(self): ...

@property
@abc.abstractmethod
def itemsize(self): ...

@property
@abc.abstractmethod
def sharding(self): ...

@property
@abc.abstractmethod
def at(self) -> _IndexUpdateHelper: ... # type: ignore[override]

@abc.abstractmethod
def __len__(self) -> int: ...
@abc.abstractmethod
def __iter__(self) -> Iterator[PRNGKeyArray]: ...

@abc.abstractmethod
def reshape(self, *args, order='C') -> PRNGKeyArray: ...

@property
@abc.abstractmethod
def T(self) -> PRNGKeyArray: ...
@abc.abstractmethod
def __getitem__(self, _) -> PRNGKeyArray: ...
@abc.abstractmethod
def ravel(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def squeeze(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def swapaxes(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def take(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def transpose(self, *_, **__) -> PRNGKeyArray: ...
@abc.abstractmethod
def flatten(self, *_, **__) -> PRNGKeyArray: ...

@property
@abc.abstractmethod
def is_fully_addressable(self) -> bool: ...
@property
@abc.abstractmethod
def is_fully_replicated(self) -> bool: ...
@abc.abstractmethod
def device(self) -> Device: ...
@abc.abstractmethod
def devices(self) -> set[Device]: ...
@abc.abstractmethod
def delete(self) -> None: ...
@abc.abstractmethod
def is_deleted(self) -> bool: ...
@abc.abstractmethod
def on_device_size_in_bytes(self) -> int: ...
@property
@abc.abstractmethod
def addressable_shards(self) -> list[Shard]: ...
@property
@abc.abstractmethod
def global_shards(self) -> list[Shard]: ...
@abc.abstractmethod
def addressable_data(self, index: int) -> PRNGKeyArray: ...

# TODO(jakevdp): potentially add tolist(), tobytes(),
# device_buffer, device_buffers, __cuda_interface__()


class PRNGKeyArrayImpl(PRNGKeyArray):
"""An array of PRNG keys backed by an RNG implementation.
This class lifts the definition of a PRNG, provided in the form of a
Expand All @@ -243,6 +147,8 @@ class behave like an array whose base elements are keys, hiding the
wrapper methods around the PRNG implementation functions (``split``,
``random_bits``, ``fold_in``).
"""
# TODO(jakevdp): potentially add tolist(), tobytes(),
# device_buffer, device_buffers, __cuda_interface__()

_impl: PRNGImpl
_base_array: typing.Array
Expand Down Expand Up @@ -295,8 +201,8 @@ def itemsize(self):
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]

def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))
def addressable_data(self, index: int) -> PRNGKeyArray:
return PRNGKeyArray(self._impl, self._base_array.addressable_data(index))

@property
def addressable_shards(self) -> list[Shard]:
Expand All @@ -305,7 +211,7 @@ def addressable_shards(self) -> list[Shard]:
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
data=PRNGKeyArrayImpl(self._impl, s._data),
data=PRNGKeyArray(self._impl, s._data),
)
for s in self._base_array.addressable_shards
]
Expand All @@ -317,7 +223,7 @@ def global_shards(self) -> list[Shard]:
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
data=PRNGKeyArrayImpl(self._impl, s._data),
data=PRNGKeyArray(self._impl, s._data),
)
for s in self._base_array.global_shards
]
Expand All @@ -336,7 +242,7 @@ def __len__(self):
raise TypeError('len() of unsized object')
return len(self._base_array)

def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
def __iter__(self) -> Iterator[PRNGKeyArray]:
if self._is_scalar():
raise TypeError('iteration over a 0-d key array')
# TODO(frostig): we may want to avoid iteration by slicing because
Expand All @@ -348,7 +254,7 @@ def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
# * return iter over these unpacked slices
# Whatever we do, we'll want to do it by overriding
# ShapedArray._iter when the element type is KeyTy...
return (PRNGKeyArrayImpl(self._impl, k) for k in iter(self._base_array))
return (PRNGKeyArray(self._impl, k) for k in iter(self._base_array))

def __repr__(self):
return (f'Array({self.shape}, dtype={self.dtype.name}) overlaying:\n'
Expand Down Expand Up @@ -381,26 +287,26 @@ def swapaxes(self, *_, **__) -> PRNGKeyArray: assert False
def take(self, *_, **__) -> PRNGKeyArray: assert False
def transpose(self, *_, **__) -> PRNGKeyArray: assert False

_set_array_base_attributes(PRNGKeyArrayImpl, include=[
_set_array_base_attributes(PRNGKeyArray, include=[
*(f"__{op}__" for op in _array_operators),
'at', 'flatten', 'ravel', 'reshape',
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])

api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')
api_util._shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')

def prngkeyarrayimpl_flatten(x):
def prngkeyarray_flatten(x):
return (x._base_array,), x._impl

def prngkeyarrayimpl_unflatten(impl, children):
def prngkeyarray_unflatten(impl, children):
base_array, = children
return PRNGKeyArrayImpl(impl, base_array)
return PRNGKeyArray(impl, base_array)

tree_util_internal.dispatch_registry.register_node(
PRNGKeyArrayImpl, prngkeyarrayimpl_flatten, prngkeyarrayimpl_unflatten)
PRNGKeyArray, prngkeyarray_flatten, prngkeyarray_unflatten)


# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArrayImpl:
def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray:
return random_seed(seed, impl=impl)


Expand Down Expand Up @@ -499,7 +405,7 @@ def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
def result_handler(sticky_device, aval):
def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return PRNGKeyArrayImpl(aval.dtype._impl, buf)
return PRNGKeyArray(aval.dtype._impl, buf)
return handler

@staticmethod
Expand All @@ -524,7 +430,7 @@ def local_sharded_result_handler(aval, sharding, indices):

# set up a handler that calls the physical one and wraps back up
def handler(bufs):
return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))

return handler

Expand All @@ -539,7 +445,7 @@ def global_sharded_result_handler(aval, out_sharding, committed,
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
is_out_sharding_from_xla)
def handler(bufs):
return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
return handler

@staticmethod
Expand All @@ -551,7 +457,7 @@ def make_sharded_array(aval, sharding, arrays, committed):
phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
phys_result = phys_handler(phys_arrays)
return PRNGKeyArrayImpl(aval.dtype._impl, phys_result)
return PRNGKeyArray(aval.dtype._impl, phys_result)

@staticmethod
def device_put_sharded(vals, aval, sharding, devices):
Expand Down Expand Up @@ -617,26 +523,26 @@ def __hash__(self) -> int:



core.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval

xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x


def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, sharding):
def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
arr = x._base_array
phys_sharding = make_key_array_phys_sharding(
x.aval, sharding, is_sharding_from_xla=False)
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)


pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler


def key_array_constant_handler(x):
arr = x._base_array
return mlir.get_constant_handler(type(arr))(arr)
mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler)


# -- primitives
Expand Down Expand Up @@ -681,7 +587,7 @@ def iterated_vmap_binary_bcast(shape1, shape2, f):
return f


def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArrayImpl:
def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray:
# Avoid overflow error in X32 mode by first converting ints to int64.
# This breaks JIT invariance for large ints, but supports the common
# use-case of instantiating with Python hashes in X32 mode.
Expand All @@ -704,7 +610,7 @@ def random_seed_abstract_eval(seeds_aval, *, impl):
@random_seed_p.def_impl
def random_seed_impl(seeds, *, impl):
base_arr = random_seed_impl_base(seeds, impl=impl)
return PRNGKeyArrayImpl(impl, base_arr)
return PRNGKeyArray(impl, base_arr)

def random_seed_impl_base(seeds, *, impl):
seed = iterated_vmap_unary(np.ndim(seeds), impl.seed)
Expand Down Expand Up @@ -736,7 +642,7 @@ def random_split_abstract_eval(keys_aval, *, shape):
def random_split_impl(keys, *, shape):
base_arr = random_split_impl_base(
keys._impl, keys._base_array, keys.ndim, shape=shape)
return PRNGKeyArrayImpl(keys._impl, base_arr)
return PRNGKeyArray(keys._impl, base_arr)

def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape))
Expand Down Expand Up @@ -773,7 +679,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
def random_fold_in_impl(keys, msgs):
base_arr = random_fold_in_impl_base(
keys._impl, keys._base_array, msgs, keys.shape)
return PRNGKeyArrayImpl(keys._impl, base_arr)
return PRNGKeyArray(keys._impl, base_arr)

def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
fold_in = iterated_vmap_binary_bcast(
Expand Down Expand Up @@ -877,7 +783,7 @@ def random_wrap_abstract_eval(base_arr_aval, *, impl):

@random_wrap_p.def_impl
def random_wrap_impl(base_arr, *, impl):
return PRNGKeyArrayImpl(impl, base_arr)
return PRNGKeyArray(impl, base_arr)

def random_wrap_lowering(ctx, base_arr, *, impl):
return [base_arr]
Expand Down

0 comments on commit a043325

Please sign in to comment.