Skip to content

Commit

Permalink
reintroduce the Threefry GPU kernel lowering, under a flag
Browse files Browse the repository at this point in the history
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
  • Loading branch information
froystig authored and jax authors committed May 1, 2024
1 parent 9bf1148 commit 3f95407
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 11 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ Remember to align the itemized text with the first line of an item within a list
be created and threaded in and out of computations to build up dependency.
The singleton object `core.token` has been removed, users now should create
and use fresh `core.Token` objects instead.
* On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
by default. This choice can improve runtime memory usage at a compile-time
cost. Prior behavior, which produces a kernel call, can be recovered with
`jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
default causes issues, please file a bug. Otherwise, we intend to remove
this flag in a future release.

* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
Expand Down
14 changes: 14 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def trace_context():
dynamic_shapes.value, numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value,
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
Expand Down Expand Up @@ -811,6 +812,7 @@ class _GlobalExtraJitContext(NamedTuple):
dynamic_shapes: bool = False
random_seed_offset: int = 0
threefry_partitionable: bool = False
threefry_gpu_kernel_lowering: bool = False
softmax_custom_jvp: bool = False
xla_profile_version: int = 0

Expand Down Expand Up @@ -845,6 +847,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
dynamic_shapes: bool | None = None
random_seed_offset: int | None = None
threefry_partitionable: bool | None = None
threefry_gpu_kernel_lowering: bool | None = None
softmax_custom_jvp: bool | None = None
xla_profile_version: int | None = None

Expand Down Expand Up @@ -1083,6 +1086,17 @@ def _update_jax_memories_thread_local(val):
update_thread_local_hook=lambda val: update_thread_local_jit_state(
threefry_partitionable=val))

threefry_gpu_kernel_lowering = define_bool_state(
name='jax_threefry_gpu_kernel_lowering',
default=False,
help=('On GPU, lower threefry PRNG operations to a kernel implementation. '
'This makes compile times faster at a potential runtime memory '
'cost.'),
update_global_hook=lambda val: _update_global_jit_state(
threefry_gpu_kernel_lowering=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
threefry_gpu_kernel_lowering=val))


softmax_custom_jvp = define_bool_state(
name='jax_softmax_custom_jvp',
Expand Down
61 changes: 54 additions & 7 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@
from jax._src.interpreters import xla
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir import ir
from jax._src.lib import gpu_prng
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.array_methods import (
_array_operators, _set_array_base_attributes, _IndexUpdateHelper)
Expand Down Expand Up @@ -1002,17 +1003,63 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
return tuple(x)


_threefry2x32_lowering_rule = mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False),
multiple_results=True)

_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),
multiple_results=True)


def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2):
if not config.threefry_gpu_kernel_lowering.value: # back to default lowering
return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2)

aval_out, aval_out_2 = ctx.avals_out
assert aval_out == aval_out_2
k1_aval, k2_aval, x1_aval, x2_aval = ctx.avals_in
rank = len(aval_out.shape)
if 0 in aval_out.shape:
zeros = mlir.full_like_aval(ctx, 0, aval_out)
return [zeros, zeros]
def _broadcast(x, aval):
return mlir.broadcast_in_dim(ctx, x, aval_out,
broadcast_dimensions=range(rank - len(aval.shape), rank))

out_len = reduce(op.mul, aval_out.shape, 1)
if not core.is_constant_dim(out_len):
length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
length = mlir.hlo.convert(
ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
length)
output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
else:
length = int(out_len) # will be passed statically
output_shape = None

return lowering_func(
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
output_shape)

threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True
threefry2x32_p.def_impl(partial(dispatch.apply_primitive, threefry2x32_p))
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)
batching.defbroadcasting(threefry2x32_p)
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False),
multiple_results=True))
mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),
multiple_results=True), platform='cpu')
mlir.register_lowering(
threefry2x32_p, _threefry2x32_lowering_rule)
mlir.register_lowering(
threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu')
mlir.register_lowering(
threefry2x32_p,
partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32),
platform='cuda')
mlir.register_lowering(
threefry2x32_p,
partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32),
platform='rocm')


def iota_2x32_shape(shape):
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def _check_lowering(lowering) -> None:
# Their backwards compatibility is tested by back_compat_test.py.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"dynamic_ducc_fft",
"dynamic_ducc_fft", "cu_threefry2x32",
# cholesky on CPU
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
# eigh on CPU
Expand Down
6 changes: 3 additions & 3 deletions tests/export_back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@


@jtu.with_config(jax_legacy_prng_key='allow',
jax_debug_key_reuse=False)
jax_debug_key_reuse=False,
jax_threefry_gpu_kernel_lowering=True)
class CompatTest(bctu.CompatTestBase):
def test_dummy(self):
# Tests the testing mechanism. Let this test run on all platforms
Expand Down Expand Up @@ -573,12 +574,11 @@ def func():
self.run_one_test(func, data)

def test_cuda_threefry2x32(self):
# TODO(frostig): remove after 2024-11-01
def func(x):
return jax.random.uniform(x, (2, 4), dtype=np.float32)

data = self.load_testdata(cuda_threefry2x32.data_2023_03_15)
self.run_one_test(func, data, expect_current_custom_calls=[])
self.run_one_test(func, data)

def test_sharding(self):
# Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
Expand Down
10 changes: 10 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,16 @@ def testPRNGValues(self, make_key):
random.key_data(random.fold_in(make_key(seed), 4)),
np.array([2285895361, 433833334], dtype='uint32'))

@jtu.run_on_devices("gpu")
def test_threefry_gpu_kernel_lowering(self):
f = lambda key: jax.random.uniform(key, (1,))
with jax._src.config.threefry_gpu_kernel_lowering(False):
hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
self.assertNotIn("cu_threefry2x32", hlo_text)
with jax._src.config.threefry_gpu_kernel_lowering(True):
hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
self.assertIn("cu_threefry2x32", hlo_text)

@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_random_seed_offset(self, make_key):
k1 = make_key(17)
Expand Down

0 comments on commit 3f95407

Please sign in to comment.