diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md index f527c0dcfff7..eb2f183b5f40 100644 --- a/docs/jep/9263-typed-keys.md +++ b/docs/jep/9263-typed-keys.md @@ -330,7 +330,7 @@ True True ``` And in addition to `key.dtype._rules` as outlined for extended dtypes in -general, PRNG dtypes define `key.dtype.impl`, which contains the metadata +general, PRNG dtypes define `key.dtype._impl`, which contains the metadata that defines the PRNG implementation. The PRNG implementation is currently defined by the non-public `jax._src.prng.PRNGImpl` class. For now, `PRNGImpl` isn't meant to be a public API, but we might revisit this soon to allow for diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 391eea4510c7..e2b14017b949 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -258,14 +258,13 @@ class behave like an array whose base elements are keys, hiding the ``random_bits``, ``fold_in``). """ - # TODO(frostig,vanderplas): hide impl attribute - impl: PRNGImpl + _impl: PRNGImpl _base_array: typing.Array def __init__(self, impl, key_data: Any): assert not isinstance(key_data, core.Tracer) _check_prng_key_data(impl, key_data) - self.impl = impl + self._impl = impl self._base_array = key_data def block_until_ready(self): @@ -277,11 +276,11 @@ def copy_to_host_async(self): @property def aval(self): - return keys_shaped_array(self.impl, self.shape) + return keys_shaped_array(self._impl, self.shape) @property def shape(self): - return base_arr_shape_to_keys_shape(self.impl, self._base_array.shape) + return base_arr_shape_to_keys_shape(self._impl, self._base_array.shape) @property def size(self): @@ -293,7 +292,7 @@ def ndim(self): @property def dtype(self): - return KeyTy(self.impl) + return KeyTy(self._impl) @property def itemsize(self): @@ -318,7 +317,7 @@ def unsafe_raw_array(self): return self._base_array def addressable_data(self, index: int) -> PRNGKeyArrayImpl: - return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index)) + return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index)) @property def addressable_shards(self) -> list[Shard]: @@ -327,7 +326,7 @@ def addressable_shards(self) -> list[Shard]: device=s._device, sharding=s._sharding, global_shape=s._global_shape, - data=PRNGKeyArrayImpl(self.impl, s._data), + data=PRNGKeyArrayImpl(self._impl, s._data), ) for s in self._base_array.addressable_shards ] @@ -339,7 +338,7 @@ def global_shards(self) -> list[Shard]: device=s._device, sharding=s._sharding, global_shape=s._global_shape, - data=PRNGKeyArrayImpl(self.impl, s._data), + data=PRNGKeyArrayImpl(self._impl, s._data), ) for s in self._base_array.global_shards ] @@ -350,7 +349,7 @@ def sharding(self): return KeyTyRules.logical_op_sharding(self.aval, phys_sharding) def _is_scalar(self): - base_ndim = len(self.impl.key_shape) + base_ndim = len(self._impl.key_shape) return self._base_array.ndim == base_ndim def __len__(self): @@ -370,12 +369,7 @@ def __iter__(self) -> Iterator[PRNGKeyArrayImpl]: # * return iter over these unpacked slices # Whatever we do, we'll want to do it by overriding # ShapedArray._iter when the element type is KeyTy... - return (PRNGKeyArrayImpl(self.impl, k) for k in iter(self._base_array)) - - # TODO(frostig): are all of the stackable methods below (reshape, - # concat, broadcast_to, expand_dims), and the stackable registration, - # still needed? If, with some work, none are needed, then do we want - # to remove stackables altogether? This may be the only application. + return (PRNGKeyArrayImpl(self._impl, k) for k in iter(self._base_array)) def __repr__(self): return (f'Array({self.shape}, dtype={self.dtype.name}) overlaying:\n' @@ -383,13 +377,13 @@ def __repr__(self): def pprint(self): pp_keys = pp.text('shape = ') + pp.text(str(self.shape)) - pp_impl = pp.text('impl = ') + self.impl.pprint() + pp_impl = pp.text('impl = ') + self._impl.pprint() return str(pp.group( pp.text('PRNGKeyArray:') + pp.nest(2, pp.brk() + pp_keys + pp.brk() + pp_impl))) def copy(self): - return self.__class__(self.impl, self._base_array.copy()) + return self.__class__(self._impl, self._base_array.copy()) __hash__ = None # type: ignore[assignment] __array_priority__ = 100 @@ -418,7 +412,7 @@ def transpose(self, *_, **__) -> PRNGKeyArray: assert False ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like # type: ignore[has-type] def prngkeyarrayimpl_flatten(x): - return (x._base_array,), x.impl + return (x._base_array,), x._impl def prngkeyarrayimpl_unflatten(impl, children): base_array, = children @@ -448,7 +442,7 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla): if dispatch.is_single_device_sharding(sharding): return sharding elif isinstance(sharding, PmapSharding): - key_shape = aval.dtype.impl.key_shape + key_shape = aval.dtype._impl.key_shape trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape) phys_sharding_spec = sharding_specs.ShardingSpec( sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), @@ -456,7 +450,7 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla): return PmapSharding(devices=sharding.devices, sharding_spec=phys_sharding_spec) elif isinstance(sharding, NamedSharding): - key_shape = aval.dtype.impl.key_shape + key_shape = aval.dtype._impl.key_shape trailing_spec = [None] * len(key_shape) return NamedSharding( sharding.mesh, @@ -473,18 +467,18 @@ class KeyTyRules: @staticmethod def full(shape, fill_value, dtype): - physical_shape = (*shape, *dtype.impl.key_shape) + physical_shape = (*shape, *dtype._impl.key_shape) if hasattr(fill_value, 'dtype') and jnp.issubdtype(fill_value.dtype, dtypes.prng_key): key_data = jnp.broadcast_to(random_unwrap(fill_value), physical_shape) else: key_data = lax.full(physical_shape, fill_value, dtype=np.dtype('uint32')) # TODO(frostig,mattjj,vanderplas,lenamartens): consider this consumed from # the outset. - return random_wrap(key_data, impl=dtype.impl) + return random_wrap(key_data, impl=dtype._impl) @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: - return core.ShapedArray(dtype.impl.key_shape, jnp.dtype('uint32')) + return core.ShapedArray(dtype._impl.key_shape, jnp.dtype('uint32')) @staticmethod def physical_const(val) -> Array: @@ -492,7 +486,7 @@ def physical_const(val) -> Array: @staticmethod def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: - key_shape = aval.dtype.impl.key_shape + key_shape = aval.dtype._impl.key_shape op_sharding_proto = hlo_sharding.to_proto() # type: ignore new_op_sharding = op_sharding_proto.clone() tad = list(new_op_sharding.tile_assignment_dimensions) @@ -506,19 +500,19 @@ def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding: if dispatch.is_single_device_sharding(phys_sharding): return phys_sharding elif isinstance(phys_sharding, PmapSharding): - key_shape = aval.dtype.impl.key_shape + key_shape = aval.dtype._impl.key_shape logical_sharding_spec = sharding_specs.ShardingSpec( sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)], mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) return PmapSharding(devices=phys_sharding.devices, sharding_spec=logical_sharding_spec) elif isinstance(phys_sharding, NamedSharding): - key_shape = aval.dtype.impl.key_shape + key_shape = aval.dtype._impl.key_shape return pxla.create_mesh_pspec_sharding( phys_sharding.mesh, PartitionSpec(*phys_sharding.spec[:-len(key_shape)])) else: - key_shape = aval.dtype.impl.key_shape + key_shape = aval.dtype._impl.key_shape phys_op_sharding = phys_sharding._to_xla_hlo_sharding( aval.ndim + len(key_shape)).to_proto() logical_op_sharding = phys_op_sharding.clone() @@ -532,13 +526,13 @@ def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding: def result_handler(sticky_device, aval): def handler(_, buf): buf.aval = core.ShapedArray(buf.shape, buf.dtype) - return PRNGKeyArrayImpl(aval.dtype.impl, buf) + return PRNGKeyArrayImpl(aval.dtype._impl, buf) return handler @staticmethod def local_sharded_result_handler(aval, sharding, indices): phys_aval = core.physical_aval(aval) - key_shape = aval.dtype.impl.key_shape + key_shape = aval.dtype._impl.key_shape phys_handler_maker = pxla.local_result_handlers[core.ShapedArray] # set up a grounded sharding (with a grounded sharding spec) @@ -557,7 +551,7 @@ def local_sharded_result_handler(aval, sharding, indices): # set up a handler that calls the physical one and wraps back up def handler(bufs): - return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs)) + return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs)) return handler @@ -572,7 +566,7 @@ def global_sharded_result_handler(aval, out_sharding, committed, phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, is_out_sharding_from_xla) def handler(bufs): - return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs)) + return PRNGKeyArrayImpl(aval.dtype._impl, phys_handler(bufs)) return handler @staticmethod @@ -584,7 +578,7 @@ def make_sharded_array(aval, sharding, arrays, committed): phys_sharding = make_key_array_phys_sharding(aval, sharding, False) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False) phys_result = phys_handler(phys_arrays) - return PRNGKeyArrayImpl(aval.dtype.impl, phys_result) + return PRNGKeyArrayImpl(aval.dtype._impl, phys_result) @staticmethod def device_put_sharded(vals, aval, sharding, devices): @@ -592,7 +586,7 @@ def device_put_sharded(vals, aval, sharding, devices): physical_buffers = tree_util.tree_map(random_unwrap, vals) physical_sharding = make_key_array_phys_sharding(aval, sharding, False) physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices)) - return random_wrap(physical_result, impl=aval.dtype.impl) + return random_wrap(physical_result, impl=aval.dtype._impl) @staticmethod def device_put_replicated(val, aval, sharding, devices): @@ -601,34 +595,33 @@ def device_put_replicated(val, aval, sharding, devices): physical_buf = random_unwrap(val) physical_sharding = make_key_array_phys_sharding(aval, sharding, False) physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices) - return random_wrap(physical_result, impl=aval.dtype.impl) + return random_wrap(physical_result, impl=aval.dtype._impl) class KeyTy(dtypes.ExtendedDType): - # TODO(frostig,vanderplas): hide impl attribute - impl: PRNGImpl # TODO(mattjj,frostig): protocol really + _impl: PRNGImpl # TODO(mattjj,frostig): protocol really _rules = KeyTyRules type = dtypes.prng_key def __init__(self, impl): - self.impl = impl + self._impl = impl @property def name(self) -> str: - return f'key<{self.impl.tag}>' + return f'key<{self._impl.tag}>' @property def itemsize(self) -> int: - return math.prod(self.impl.key_shape) * np.dtype('uint32').itemsize + return math.prod(self._impl.key_shape) * np.dtype('uint32').itemsize def __repr__(self) -> str: return self.name def __eq__(self, other): - return type(other) is KeyTy and self.impl == other.impl + return type(other) is KeyTy and self._impl == other._impl def __hash__(self) -> int: - return hash((self.__class__, self.impl)) + return hash((self.__class__, self._impl)) @@ -640,7 +633,7 @@ def __hash__(self) -> int: def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding): aval = x.aval - key_shape = aval.dtype.impl.key_shape + key_shape = aval.dtype._impl.key_shape arr = x._base_array # TODO(yashkatariya,frostig): This assumes that the last dimensions are not @@ -752,13 +745,13 @@ def random_split(keys, shape: Shape): @random_split_p.def_abstract_eval def random_split_abstract_eval(keys_aval, *, shape): - return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, *shape)) + return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape)) @random_split_p.def_impl def random_split_impl(keys, *, shape): base_arr = random_split_impl_base( - keys.impl, keys._base_array, keys.ndim, shape=shape) - return PRNGKeyArrayImpl(keys.impl, base_arr) + keys._impl, keys._base_array, keys.ndim, shape=shape) + return PRNGKeyArrayImpl(keys._impl, base_arr) def random_split_impl_base(impl, base_arr, keys_ndim, *, shape): split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape)) @@ -766,7 +759,7 @@ def random_split_impl_base(impl, base_arr, keys_ndim, *, shape): def random_split_lowering(ctx, keys, *, shape): aval, = ctx.avals_in - impl = aval.dtype.impl + impl = aval.dtype._impl split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, shape)) split_lowering = mlir.lower_fun(split, multiple_results=False) return mlir.delegate_lowering( @@ -794,8 +787,8 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval): @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): base_arr = random_fold_in_impl_base( - keys.impl, keys._base_array, msgs, keys.shape) - return PRNGKeyArrayImpl(keys.impl, base_arr) + keys._impl, keys._base_array, msgs, keys.shape) + return PRNGKeyArrayImpl(keys._impl, base_arr) def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape): fold_in = iterated_vmap_binary_bcast( @@ -804,7 +797,7 @@ def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape): def random_fold_in_lowering(ctx, keys, msgs): keys_aval, msgs_aval = ctx.avals_in - impl = keys_aval.dtype.impl + impl = keys_aval.dtype._impl fold_in = iterated_vmap_binary_bcast( keys_aval.shape, msgs_aval.shape, impl.fold_in) fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False) @@ -843,7 +836,7 @@ def random_bits_abstract_eval(keys_aval, *, bit_width, shape): @random_bits_p.def_impl def random_bits_impl(keys, *, bit_width, shape): - return random_bits_impl_base(keys.impl, keys._base_array, keys.ndim, + return random_bits_impl_base(keys._impl, keys._base_array, keys.ndim, bit_width=bit_width, shape=shape) def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape): @@ -853,7 +846,7 @@ def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape): def random_bits_lowering(ctx, keys, *, bit_width, shape): aval, = ctx.avals_in - impl = aval.dtype.impl + impl = aval.dtype._impl bits = iterated_vmap_unary( aval.ndim, lambda k: impl.random_bits(k, bit_width, shape)) bits_lowering = mlir.lower_fun(bits, multiple_results=False) diff --git a/jax/_src/random.py b/jax/_src/random.py index 4b0ac00e3d66..63806c954056 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -308,7 +308,7 @@ def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: def _key_impl(keys: KeyArray) -> PRNGImpl: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) keys_dtype = typing.cast(prng.KeyTy, keys.dtype) - return keys_dtype.impl + return keys_dtype._impl def key_impl(keys: KeyArray) -> Hashable: keys, _ = _check_prng_key(keys) @@ -1501,7 +1501,8 @@ def poisson(key: KeyArray, dtypes.check_user_dtype_supported(dtype) # TODO(frostig): generalize underlying poisson implementation and # remove this check - key_impl = key.dtype.impl # type: ignore[union-attr] + keys_dtype = typing.cast(prng.KeyTy, key.dtype) + key_impl = keys_dtype._impl if key_impl is not prng.threefry_prng_impl: raise NotImplementedError( '`poisson` is only implemented for the threefry2x32 RNG, ' diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 1ef0239644d4..ee38031db0a5 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2651,7 +2651,7 @@ def _random_split_impl(keys: TfVal, *, shape, _in_avals, _out_aval): def impl_wrapper(keys: TfVal, *, shape): return prng.random_split_impl_base( - keys_aval.dtype.impl, keys, keys_aval.ndim, shape=shape) + keys_aval.dtype._impl, keys, keys_aval.ndim, shape=shape) converted_impl = _convert_jax_impl( impl_wrapper, multiple_results=False, with_physical_avals=True, @@ -2667,7 +2667,7 @@ def _random_fold_in_impl(keys: TfVal, msgs: TfVal, *, _in_avals, _out_aval): def impl_wrapper(keys: TfVal, msgs: TfVal): return prng.random_fold_in_impl_base( - keys_aval.dtype.impl, keys, msgs, keys_aval.shape) + keys_aval.dtype._impl, keys, msgs, keys_aval.shape) converted_impl = _convert_jax_impl( impl_wrapper, multiple_results=False, with_physical_avals=True, @@ -2683,7 +2683,7 @@ def _random_bits_impl(keys: TfVal, *, bit_width, shape, _in_avals, _out_aval): def impl_wrapper(keys: TfVal, **kwargs): return prng.random_bits_impl_base( - keys_aval.dtype.impl, keys, keys_aval.ndim, + keys_aval.dtype._impl, keys, keys_aval.ndim, bit_width=bit_width, shape=shape) converted_impl = _convert_jax_impl( diff --git a/tests/random_test.py b/tests/random_test.py index 2ad03710e9d0..f878f96f9fa2 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -184,7 +184,7 @@ class PrngTest(jtu.JaxTestCase): def check_key_has_impl(self, key, impl): if jnp.issubdtype(key.dtype, dtypes.prng_key): - self.assertIs(key.impl, impl) + self.assertIs(key._impl, impl) else: self.assertEqual(key.dtype, jnp.dtype('uint32')) self.assertEqual(key.shape, impl.key_shape) @@ -1190,7 +1190,7 @@ def test_reshape(self): newshape = (2, 2) key_func = partial(jnp.reshape, newshape=newshape) - arr_func = partial(jnp.reshape, newshape=(*newshape, *key.impl.key_shape)) + arr_func = partial(jnp.reshape, newshape=(*newshape, *key._impl.key_shape)) self.check_shape(key_func, keys) self.check_against_reference(key_func, arr_func, keys) @@ -1200,7 +1200,7 @@ def test_tile(self): reps = 3 key_func = partial(jnp.tile, reps=reps) - arr_func = lambda x: jnp.tile(x[None], reps=(reps, *(1 for _ in key.impl.key_shape))) + arr_func = lambda x: jnp.tile(x[None], reps=(reps, *(1 for _ in key._impl.key_shape))) self.check_shape(key_func, key) self.check_against_reference(key_func, arr_func, key) @@ -1219,7 +1219,7 @@ def test_broadcast_to(self): shape = (3,) key_func = partial(jnp.broadcast_to, shape=shape) - arr_func = partial(jnp.broadcast_to, shape=(*shape, *key.impl.key_shape)) + arr_func = partial(jnp.broadcast_to, shape=(*shape, *key._impl.key_shape)) self.check_shape(key_func, key) self.check_against_reference(key_func, arr_func, key) @@ -1257,7 +1257,7 @@ def test_ravel(self): keys = random.split(key, 4).reshape(2, 2) key_func = jnp.ravel - arr_func = partial(jnp.reshape, newshape=(4, *key.impl.key_shape)) + arr_func = partial(jnp.reshape, newshape=(4, *key._impl.key_shape)) self.check_shape(key_func, keys) self.check_against_reference(key_func, arr_func, keys)