[random] cleanup internal implementation
jakevdp committed Oct 17, 2023
1 parent d03bbc0 commit 5636735
Showing 1 changed file with 9 additions and 13 deletions.
jax/_src/prng.py (22 changes: 9 additions & 13 deletions)
@@ -436,10 +436,6 @@ def seed_with_impl(impl: PRNGImpl, seed: int | Array) -> PRNGKeyArrayImpl:
 def keys_shaped_array(impl, shape):
   return core.ShapedArray(shape, KeyTy(impl))
 
-# TODO(frostig): remove in favor of physical_aval call
-def keys_aval_to_base_arr_aval(keys_aval):
-  return core.physical_aval(keys_aval)
-
 def base_arr_shape_to_keys_shape(impl, base_arr_shape):
   base_ndim = len(impl.key_shape)
   return base_arr_shape[:-base_ndim]
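
As context for the cleanup above (not part of the diff): the deleted helper was a trivial wrapper, so core.physical_aval is a drop-in replacement. It maps a key aval of shape S to a physical aval of shape S + impl.key_shape. A minimal sketch of that shape relationship through public APIs, assuming the default threefry implementation whose key_shape is (2,) of uint32:

    import jax
    import jax.numpy as jnp

    key = jax.random.key(0)            # scalar key array, shape ()
    batch = jax.random.split(key, 4)   # key array of shape (4,)
    data = jax.random.key_data(batch)  # physical view of the keys

    # keys shape (4,) + key_shape (2,) -> physical shape (4, 2)
    assert data.shape == (4, 2) and data.dtype == jnp.uint32
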
@@ -588,15 +584,15 @@ def make_sharded_array(aval, sharding, arrays, committed):
 
   @staticmethod
   def device_put_sharded(vals, aval, sharding, devices):
-    physical_aval = keys_aval_to_base_arr_aval(aval)
+    physical_aval = core.physical_aval(aval)
     physical_buffers = tree_util.tree_map(random_unwrap, vals)
     physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
     physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
     return random_wrap(physical_result, impl=aval.dtype.impl)
 
   @staticmethod
   def device_put_replicated(val, aval, sharding, devices):
-    physical_aval = keys_aval_to_base_arr_aval(aval)
+    physical_aval = core.physical_aval(aval)
     assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
     physical_buf = random_unwrap(val)
     physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
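
Both device_put methods above follow one pattern: unwrap keys to their physical uint32 buffers, place those buffers, then rewrap the result as a key array. A hedged sketch of the same round trip via public APIs (the device choice is illustrative):

    import jax

    key = jax.random.key(123)
    physical = jax.random.key_data(key)                  # raw uint32 words
    placed = jax.device_put(physical, jax.devices()[0])  # place physical data
    rewrapped = jax.random.wrap_key_data(placed)         # a key array again
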
@@ -738,7 +734,7 @@ def random_seed_lowering(ctx, seeds, *, impl):
   seed_lowering = mlir.lower_fun(seed, multiple_results=False)
   return mlir.delegate_lowering(
       ctx, seed_lowering, seeds,
-      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
+      avals_out=map(core.physical_aval, ctx.avals_out))
 
 mlir.register_lowering(random_seed_p, random_seed_lowering)
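
The seed lowering delegates to the impl's seed function and reports physical (uint32) avals to MLIR. One way to observe this from outside, a sketch assuming the default implementation (the exact printed IR varies by JAX version):

    import jax

    lowered = jax.jit(jax.random.key).lower(0)
    # The module's result should be the key's physical representation,
    # e.g. tensor<2xui32> for the default threefry implementation.
    print(lowered.as_text())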

@@ -771,8 +767,8 @@ def random_split_lowering(ctx, keys, *, shape):
   split_lowering = mlir.lower_fun(split, multiple_results=False)
   return mlir.delegate_lowering(
       ctx, split_lowering, keys,
-      avals_in=[keys_aval_to_base_arr_aval(aval)],
-      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
+      avals_in=[core.physical_aval(aval)],
+      avals_out=map(core.physical_aval, ctx.avals_out))
 
 mlir.register_lowering(random_split_p, random_split_lowering)
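
The split lowering maps both the input and output avals to their physical counterparts. The observable effect, in a small sketch assuming the default threefry implementation:

    import jax

    key = jax.random.key(0)
    subkeys = jax.random.split(key, 5)  # key array of shape (5,)
    # physical shape is (5,) + key_shape (2,)
    assert jax.random.key_data(subkeys).shape == (5, 2)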

@@ -810,8 +806,8 @@ def random_fold_in_lowering(ctx, keys, msgs):
   fold_in_lowering = mlir.lower_fun(fold_in, multiple_results=False)
   return mlir.delegate_lowering(
       ctx, fold_in_lowering, keys, msgs,
-      avals_in=[keys_aval_to_base_arr_aval(keys_aval), msgs_aval],
-      avals_out=map(keys_aval_to_base_arr_aval, ctx.avals_out))
+      avals_in=[core.physical_aval(keys_aval), msgs_aval],
+      avals_out=map(core.physical_aval, ctx.avals_out))
 
 mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)
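
fold_in swaps only the keys aval for its physical aval; the message aval passes through unchanged. A short usage sketch of the corresponding public API:

    import jax

    key = jax.random.key(0)
    # Derive independent streams from one key and integer data.
    streams = [jax.random.fold_in(key, i) for i in range(3)]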

@@ -857,7 +853,7 @@ def random_bits_lowering(ctx, keys, *, bit_width, shape):
   bits = iterated_vmap_unary(
       aval.ndim, lambda k: impl.random_bits(k, bit_width, shape))
   bits_lowering = mlir.lower_fun(bits, multiple_results=False)
-  ctx_new = ctx.replace(avals_in=[keys_aval_to_base_arr_aval(aval)])
+  ctx_new = ctx.replace(avals_in=[core.physical_aval(aval)])
   out = bits_lowering(ctx_new, keys)
   ctx.set_tokens_out(ctx_new.tokens_out)
   return out
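
Here only the input context is rewritten to the physical aval, since random_bits already produces an ordinary unsigned-integer array. A sketch with the public wrapper jax.random.bits:

    import jax
    import jax.numpy as jnp

    key = jax.random.key(0)
    bits = jax.random.bits(key, (4,), dtype=jnp.uint32)  # raw random bits
    assert bits.shape == (4,) and bits.dtype == jnp.uint32
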
@@ -925,7 +921,7 @@ def random_unwrap(keys):
 
 @random_unwrap_p.def_abstract_eval
 def random_unwrap_abstract_eval(keys_aval):
-  return keys_aval_to_base_arr_aval(keys_aval)
+  return core.physical_aval(keys_aval)
 
 @random_unwrap_p.def_impl
 def random_unwrap_impl(keys):
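
random_unwrap (the primitive behind the public jax.random.key_data) now reports its output aval with a direct core.physical_aval call. A round-trip sketch, again assuming the default threefry implementation:

    import jax
    import jax.numpy as jnp

    key = jax.random.key(7)
    raw = jax.random.key_data(key)        # unwrap: uint32 words, shape (2,)
    assert raw.dtype == jnp.uint32 and raw.shape == (2,)
    same = jax.random.wrap_key_data(raw)  # wrap the words back into a key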
