[random] make PRNG impl attributes private
jakevdp committed Oct 18, 2023
1 parent cf65480 commit 0da4be5
Showing 5 changed files with 57 additions and 63 deletions.
2 changes: 1 addition & 1 deletion docs/jep/9263-typed-keys.md
@@ -330,7 +330,7 @@ True
True
```
And in addition to `key.dtype._rules` as outlined for extended dtypes in
- general, PRNG dtypes define `key.dtype.impl`, which contains the metadata
+ general, PRNG dtypes define `key.dtype._impl`, which contains the metadata
that defines the PRNG implementation. The PRNG implementation is currently
defined by the non-public `jax._src.prng.PRNGImpl` class. For now, `PRNGImpl`
isn't meant to be a public API, but we might revisit this soon to allow for
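To make the doc's point concrete, here is a short illustrative sketch (not part of the diff itself): a new-style typed key exposes its PRNG implementation only through its extended dtype, and after this commit that attribute is spelled `_impl` to mark it private.

```python
import jax

key = jax.random.key(0)   # new-style typed PRNG key
print(key.dtype)          # key<fry>  (the default threefry2x32 implementation)
# The implementation metadata hangs off the now-private dtype attribute:
# `key.dtype._impl` is a `jax._src.prng.PRNGImpl`, internal and subject to change,
# which is exactly why this commit renames it.
```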
97 changes: 45 additions & 52 deletions jax/_src/prng.py
@@ -258,14 +258,13 @@ class behave like an array whose base elements are keys, hiding the
``random_bits``, ``fold_in``).
"""

- # TODO(frostig,vanderplas): hide impl attribute
- impl: PRNGImpl
+ _impl: PRNGImpl
_base_array: typing.Array

def __init__(self, impl, key_data: Any):
assert not isinstance(key_data, core.Tracer)
_check_prng_key_data(impl, key_data)
- self.impl = impl
+ self._impl = impl
self._base_array = key_data

def block_until_ready(self):
@@ -277,11 +276,11 @@ def copy_to_host_async(self):

@property
def aval(self):
- return keys_shaped_array(self.impl, self.shape)
+ return keys_shaped_array(self._impl, self.shape)

@property
def shape(self):
- return base_arr_shape_to_keys_shape(self.impl, self._base_array.shape)
+ return base_arr_shape_to_keys_shape(self._impl, self._base_array.shape)

@property
def size(self):
@@ -293,7 +292,7 @@ def ndim(self):

@property
def dtype(self):
- return KeyTy(self.impl)
+ return KeyTy(self._impl)

@property
def itemsize(self):
@@ -318,7 +317,7 @@ def unsafe_raw_array(self):
return self._base_array

def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
- return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))
+ return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))

@property
def addressable_shards(self) -> list[Shard]:
@@ -327,7 +326,7 @@ def addressable_shards(self) -> list[Shard]:
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
- data=PRNGKeyArrayImpl(self.impl, s._data),
+ data=PRNGKeyArrayImpl(self._impl, s._data),
)
for s in self._base_array.addressable_shards
]
@@ -339,7 +338,7 @@ def global_shards(self) -> list[Shard]:
device=s._device,
sharding=s._sharding,
global_shape=s._global_shape,
- data=PRNGKeyArrayImpl(self.impl, s._data),
+ data=PRNGKeyArrayImpl(self._impl, s._data),
)
for s in self._base_array.global_shards
]
@@ -350,7 +349,7 @@ def sharding(self):
return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)

def _is_scalar(self):
- base_ndim = len(self.impl.key_shape)
+ base_ndim = len(self._impl.key_shape)
return self._base_array.ndim == base_ndim

def __len__(self):
Expand All @@ -370,26 +369,21 @@ def __iter__(self) -> Iterator[PRNGKeyArrayImpl]:
# * return iter over these unpacked slices
# Whatever we do, we'll want to do it by overriding
# ShapedArray._iter when the element type is KeyTy...
- return (PRNGKeyArrayImpl(self.impl, k) for k in iter(self._base_array))
-
- # TODO(frostig): are all of the stackable methods below (reshape,
- # concat, broadcast_to, expand_dims), and the stackable registration,
- # still needed? If, with some work, none are needed, then do we want
- # to remove stackables altogether? This may be the only application.
+ return (PRNGKeyArrayImpl(self._impl, k) for k in iter(self._base_array))

def __repr__(self):
return (f'Array({self.shape}, dtype={self.dtype.name}) overlaying:\n'
f'{self._base_array}')

def pprint(self):
pp_keys = pp.text('shape = ') + pp.text(str(self.shape))
- pp_impl = pp.text('impl = ') + self.impl.pprint()
+ pp_impl = pp.text('impl = ') + self._impl.pprint()
return str(pp.group(
pp.text('PRNGKeyArray:') +
pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl)))

def copy(self):
- return self.__class__(self.impl, self._base_array.copy())
+ return self.__class__(self._impl, self._base_array.copy())

__hash__ = None # type: ignore[assignment]
__array_priority__ = 100
@@ -418,7 +412,7 @@ def transpose(self, *_, **__) -> PRNGKeyArray: assert False
ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like # type: ignore[has-type]

def prngkeyarrayimpl_flatten(x):
- return (x._base_array,), x.impl
+ return (x._base_array,), x._impl

def prngkeyarrayimpl_unflatten(impl, children):
base_array, = children
@@ -448,15 +442,15 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
if dispatch.is_single_device_sharding(sharding):
return sharding
elif isinstance(sharding, PmapSharding):
- key_shape = aval.dtype.impl.key_shape
+ key_shape = aval.dtype._impl.key_shape
trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape)
phys_sharding_spec = sharding_specs.ShardingSpec(
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
mesh_mapping=sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=sharding.devices,
sharding_spec=phys_sharding_spec)
elif isinstance(sharding, NamedSharding):
- key_shape = aval.dtype.impl.key_shape
+ key_shape = aval.dtype._impl.key_shape
trailing_spec = [None] * len(key_shape)
return NamedSharding(
sharding.mesh,
@@ -473,26 +467,26 @@ class KeyTyRules:

@staticmethod
def full(shape, fill_value, dtype):
- physical_shape = (*shape, *dtype.impl.key_shape)
+ physical_shape = (*shape, *dtype._impl.key_shape)
if hasattr(fill_value, 'dtype') and jnp.issubdtype(fill_value.dtype, dtypes.prng_key):
key_data = jnp.broadcast_to(random_unwrap(fill_value), physical_shape)
else:
key_data = lax.full(physical_shape, fill_value, dtype=np.dtype('uint32'))
# TODO(frostig,mattjj,vanderplas,lenamartens): consider this consumed from
# the outset.
- return random_wrap(key_data, impl=dtype.impl)
+ return random_wrap(key_data, impl=dtype._impl)

@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:
- return core.ShapedArray(dtype.impl.key_shape, jnp.dtype('uint32'))
+ return core.ShapedArray(dtype._impl.key_shape, jnp.dtype('uint32'))

@staticmethod
def physical_const(val) -> Array:
return val._base_array

@staticmethod
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
- key_shape = aval.dtype.impl.key_shape
+ key_shape = aval.dtype._impl.key_shape
op_sharding_proto = hlo_sharding.to_proto() # type: ignore
new_op_sharding = op_sharding_proto.clone()
tad = list(new_op_sharding.tile_assignment_dimensions)
@@ -506,19 +500,19 @@ def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
if dispatch.is_single_device_sharding(phys_sharding):
return phys_sharding
elif isinstance(phys_sharding, PmapSharding):
- key_shape = aval.dtype.impl.key_shape
+ key_shape = aval.dtype._impl.key_shape
logical_sharding_spec = sharding_specs.ShardingSpec(
sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)],
mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=phys_sharding.devices,
sharding_spec=logical_sharding_spec)
elif isinstance(phys_sharding, NamedSharding):
- key_shape = aval.dtype.impl.key_shape
+ key_shape = aval.dtype._impl.key_shape
return pxla.create_mesh_pspec_sharding(
phys_sharding.mesh,
PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
else:
- key_shape = aval.dtype.impl.key_shape
+ key_shape = aval.dtype._impl.key_shape
phys_op_sharding = phys_sharding._to_xla_hlo_sharding(
aval.ndim + len(key_shape)).to_proto()
logical_op_sharding = phys_op_sharding.clone()
@@ -532,13 +526,13 @@ def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
def result_handler(sticky_device, aval):
def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
- return PRNGKeyArrayImpl(aval.dtype.impl, buf)
+ return PRNGKeyArrayImpl(aval.dtype._impl, buf)
return handler

@staticmethod
def local_sharded_result_handler(aval, sharding, indices):
phys_aval = core.physical_aval(aval)
- key_shape = aval.dtype.impl.key_shape
+ key_shape = aval.dtype._impl.key_shape
phys_handler_maker = pxla.local_result_handlers[core.ShapedArray]

# set up a grounded sharding (with a grounded sharding spec)
@@ -557,7 +551,7 @@ def local_sharded_result_handler(aval, sharding, indices):

# set up a handler that calls the physical one and wraps back up
def handler(bufs):
- return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
+ return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))

return handler

@@ -572,7 +566,7 @@ def global_sharded_result_handler(aval, out_sharding, committed,
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
is_out_sharding_from_xla)
def handler(bufs):
- return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
+ return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs))
return handler

@staticmethod
Expand All @@ -584,15 +578,15 @@ def make_sharded_array(aval, sharding, arrays, committed):
phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
phys_result = phys_handler(phys_arrays)
- return PRNGKeyArrayImpl(aval.dtype.impl, phys_result)
+ return PRNGKeyArrayImpl(aval.dtype._impl, phys_result)

@staticmethod
def device_put_sharded(vals, aval, sharding, devices):
physical_aval = keys_aval_to_base_arr_aval(aval)
physical_buffers = tree_util.tree_map(random_unwrap, vals)
physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
- return random_wrap(physical_result, impl=aval.dtype.impl)
+ return random_wrap(physical_result, impl=aval.dtype._impl)

@staticmethod
def device_put_replicated(val, aval, sharding, devices):
@@ -601,34 +595,33 @@ def device_put_replicated(val, aval, sharding, devices):
physical_buf = random_unwrap(val)
physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
- return random_wrap(physical_result, impl=aval.dtype.impl)
+ return random_wrap(physical_result, impl=aval.dtype._impl)


class KeyTy(dtypes.ExtendedDType):
- # TODO(frostig,vanderplas): hide impl attribute
- impl: PRNGImpl # TODO(mattjj,frostig): protocol really
+ _impl: PRNGImpl # TODO(mattjj,frostig): protocol really
_rules = KeyTyRules
type = dtypes.prng_key

def __init__(self, impl):
- self.impl = impl
+ self._impl = impl

@property
def name(self) -> str:
- return f'key<{self.impl.tag}>'
+ return f'key<{self._impl.tag}>'

@property
def itemsize(self) -> int:
- return math.prod(self.impl.key_shape) * np.dtype('uint32').itemsize
+ return math.prod(self._impl.key_shape) * np.dtype('uint32').itemsize

def __repr__(self) -> str:
return self.name

def __eq__(self, other):
- return type(other) is KeyTy and self.impl == other.impl
+ return type(other) is KeyTy and self._impl == other._impl

def __hash__(self) -> int:
- return hash((self.__class__, self.impl))
+ return hash((self.__class__, self._impl))



@@ -640,7 +633,7 @@ def __hash__(self) -> int:

def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
aval = x.aval
- key_shape = aval.dtype.impl.key_shape
+ key_shape = aval.dtype._impl.key_shape
arr = x._base_array

# TODO(yashkatariya,frostig): This assumes that the last dimensions are not
@@ -752,21 +745,21 @@ def random_split(keys, shape: Shape):

@random_split_p.def_abstract_eval
def random_split_abstract_eval(keys_aval, *, shape):
- return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, *shape))
+ return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape))

@random_split_p.def_impl
def random_split_impl(keys, *, shape):
base_arr = random_split_impl_base(
- keys.impl, keys._base_array, keys.ndim, shape=shape)
- return PRNGKeyArrayImpl(keys.impl, base_arr)
+ keys._impl, keys._base_array, keys.ndim, shape=shape)
+ return PRNGKeyArrayImpl(keys._impl, base_arr)

def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape))
return split(base_arr)

def random_split_lowering(ctx, keys, *, shape):
aval, = ctx.avals_in
- impl = aval.dtype.impl
+ impl = aval.dtype._impl
split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, shape))
split_lowering = mlir.lower_fun(split, multiple_results=False)
return mlir.delegate_lowering(
@@ -794,8 +787,8 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):
base_arr = random_fold_in_impl_base(
- keys.impl, keys._base_array, msgs, keys.shape)
- return PRNGKeyArrayImpl(keys.impl, base_arr)
+ keys._impl, keys._base_array, msgs, keys.shape)
+ return PRNGKeyArrayImpl(keys._impl, base_arr)

def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
fold_in = iterated_vmap_binary_bcast(
@@ -804,7 +797,7 @@ def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):

def random_fold_in_lowering(ctx, keys, msgs):
keys_aval, msgs_aval = ctx.avals_in
- impl = keys_aval.dtype.impl
+ impl = keys_aval.dtype._impl
fold_in = iterated_vmap_binary_bcast(
keys_aval.shape, msgs_aval.shape, impl.fold_in)
fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
@@ -843,7 +836,7 @@ def random_bits_abstract_eval(keys_aval, *, bit_width, shape):

@random_bits_p.def_impl
def random_bits_impl(keys, *, bit_width, shape):
- return random_bits_impl_base(keys.impl, keys._base_array, keys.ndim,
+ return random_bits_impl_base(keys._impl, keys._base_array, keys.ndim,
bit_width=bit_width, shape=shape)

def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):
@@ -853,7 +846,7 @@ def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):

def random_bits_lowering(ctx, keys, *, bit_width, shape):
aval, = ctx.avals_in
- impl = aval.dtype.impl
+ impl = aval.dtype._impl
bits = iterated_vmap_unary(
aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
bits_lowering = mlir.lower_fun(bits, multiple_results=False)
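All of the lowering rules above recover the implementation from the abstract value's extended dtype (now `aval.dtype._impl`) rather than from a runtime value. A user-visible consequence, sketched here purely for illustration, is that keys produced by `split` or `fold_in` carry the same implementation-tagged dtype as their inputs:

```python
import jax

key = jax.random.key(0)
subkeys = jax.random.split(key, 3)   # shape (3,) array of typed keys
assert subkeys.dtype == key.dtype    # the implementation travels with the dtype
folded = jax.random.fold_in(key, 7)
assert folded.dtype == key.dtype
```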
5 changes: 3 additions & 2 deletions jax/_src/random.py
@@ -308,7 +308,7 @@ def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
def _key_impl(keys: KeyArray) -> PRNGImpl:
assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
keys_dtype = typing.cast(prng.KeyTy, keys.dtype)
- return keys_dtype.impl
+ return keys_dtype._impl

def key_impl(keys: KeyArray) -> Hashable:
keys, _ = _check_prng_key(keys)
@@ -1501,7 +1501,8 @@ def poisson(key: KeyArray,
dtypes.check_user_dtype_supported(dtype)
# TODO(frostig): generalize underlying poisson implementation and
# remove this check
- key_impl = key.dtype.impl # type: ignore[union-attr]
+ keys_dtype = typing.cast(prng.KeyTy, key.dtype)
+ key_impl = keys_dtype._impl
if key_impl is not prng.threefry_prng_impl:
raise NotImplementedError(
'`poisson` is only implemented for the threefry2x32 RNG, '
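For downstream code that previously peeked at `key.dtype.impl`, a hedged sketch of the forward-compatible route is to go through the library's accessor rather than the private attribute, assuming `jax.random.key_impl` is exported in the JAX version at hand:

```python
import jax

key = jax.random.key(0)
# Describes the PRNG implementation without touching private attributes:
print(jax.random.key_impl(key))
```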
