From 563673576e1fde1984444712c71375700e407162 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 17 Oct 2023 15:47:32 -0700
Subject: [PATCH] [random] cleanup internal implementation

---
 jax/_src/prng.py | 22 +++++++++-------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 4d59fe7f85fe..267a7fe78e07 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -436,10 +436,6 @@ def seed_with_impl(impl: PRNGImpl, seed: int | Array) -> PRNGKeyArrayImpl:
 def keys_shaped_array(impl, shape):
   return core.ShapedArray(shape, KeyTy(impl))
 
-# TODO(frostig): remove in favor of physical_aval call
-def keys_aval_to_base_arr_aval(keys_aval):
-  return core.physical_aval(keys_aval)
-
 def base_arr_shape_to_keys_shape(impl, base_arr_shape):
   base_ndim = len(impl.key_shape)
   return base_arr_shape[:-base_ndim]
@@ -588,7 +584,7 @@ def make_sharded_array(aval, sharding, arrays, committed):
 
   @staticmethod
   def device_put_sharded(vals, aval, sharding, devices):
-    physical_aval = keys_aval_to_base_arr_aval(aval)
+    physical_aval = core.physical_aval(aval)
     physical_buffers = tree_util.tree_map(random_unwrap, vals)
     physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
     physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
@@ -596,7 +592,7 @@ def device_put_sharded(vals, aval, sharding, devices):
 
   @staticmethod
   def device_put_replicated(val, aval, sharding, devices):
-    physical_aval = keys_aval_to_base_arr_aval(aval)
+    physical_aval = core.physical_aval(aval)
     assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
     physical_buf = random_unwrap(val)
     physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
@@ -738,7 +734,7 @@ def random_seed_lowering(ctx, seeds, *, impl):
   seed_lowering = mlir.lower_fun(seed, multiple_results=False)
   return mlir.delegate_lowering(
       ctx, seed_lowering, seeds,
-      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
+      avals_out=map(core.physical_aval, ctx.avals_out))
 
 mlir.register_lowering(random_seed_p, random_seed_lowering)
 
@@ -771,8 +767,8 @@ def random_split_lowering(ctx, keys, *, shape):
   split_lowering = mlir.lower_fun(split, multiple_results=False)
   return mlir.delegate_lowering(
       ctx, split_lowering, keys,
-      avals_in=[keys_aval_to_base_arr_aval(aval)],
-      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
+      avals_in=[core.physical_aval(aval)],
+      avals_out=map(core.physical_aval, ctx.avals_out))
 
 mlir.register_lowering(random_split_p, random_split_lowering)
 
@@ -810,8 +806,8 @@ def random_fold_in_lowering(ctx, keys, msgs):
   fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
   return mlir.delegate_lowering(
       ctx, fold_in_lowering, keys, msgs,
-      avals_in=[keys_aval_to_base_arr_aval(keys_aval), msgs_aval],
-      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
+      avals_in=[core.physical_aval(keys_aval), msgs_aval],
+      avals_out=map(core.physical_aval, ctx.avals_out))
 
 mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)
 
@@ -857,7 +853,7 @@ def random_bits_lowering(ctx, keys, *, bit_width, shape):
   bits = iterated_vmap_unary(
       aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
   bits_lowering = mlir.lower_fun(bits, multiple_results=False)
-  ctx_new = ctx.replace(avals_in=[keys_aval_to_base_arr_aval(aval)])
+  ctx_new = ctx.replace(avals_in=[core.physical_aval(aval)])
   out = bits_lowering(ctx_new, keys)
   ctx.set_tokens_out(ctx_new.tokens_out)
   return out
@@ -925,7 +921,7 @@ def random_unwrap(keys):
 
 @random_unwrap_p.def_abstract_eval
 def random_unwrap_abstract_eval(keys_aval):
-  return keys_aval_to_base_arr_aval(keys_aval)
+  return core.physical_aval(keys_aval)
 
 @random_unwrap_p.def_impl
 def random_unwrap_impl(keys):
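
For context, and not part of the patch itself: core.physical_aval maps the aval of a typed PRNG key array to the aval of the uint32 buffer that backs it, which is all the removed keys_aval_to_base_arr_aval wrapper did (its body was literally "return core.physical_aval(keys_aval)"). Below is a minimal sketch of the same logical-vs-physical key relationship using only public JAX APIs, assuming the default threefry implementation whose key_shape is (2,):

import jax

# Typed key array with logical shape (3,); threefry keys are backed by
# a uint32 array with the impl's key_shape (2,) appended.
keys = jax.random.split(jax.random.key(0), 3)
raw = jax.random.key_data(keys)   # physical (base array) view of the keys

print(keys.shape, keys.dtype)     # (3,) key<fry>
print(raw.shape, raw.dtype)       # (3, 2) uint32

At the aval level this is the same shape/dtype relationship the lowering rules above rely on; the patch simply calls core.physical_aval directly at each call site instead of going through the one-line wrapper.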