From 671790730e69276333e9553d3d6799c36bfa5e60 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 12 Dec 2023 18:31:07 -0800 Subject: [PATCH] introduce a config flag to control a random seed offset --- jax/_src/config.py | 13 +++++++++++++ jax/_src/prng.py | 2 ++ tests/random_test.py | 8 ++++++++ 3 files changed, 23 insertions(+) diff --git a/jax/_src/config.py b/jax/_src/config.py index 12d29fc55a82..41a2c4921ffd 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -210,6 +210,7 @@ def _trace_context(self): self.jax_numpy_rank_promotion, self.jax_default_matmul_precision, self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion, self.jax_default_device, + self.jax_random_seed_offset, self.jax_threefry_partitionable, self.jax_softmax_custom_jvp, self.jax_enable_memories, @@ -655,6 +656,7 @@ class _GlobalExtraJitContext(NamedTuple): numpy_dtype_promotion: str | None = None default_matmul_precision: Any | None = None dynamic_shapes: bool = False + random_seed_offset: int = 0 threefry_partitionable: bool = False softmax_custom_jvp: bool = False xla_profile_version: int = 0 @@ -682,6 +684,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): numpy_dtype_promotion: str | None = None default_matmul_precision: Any | None = None dynamic_shapes: bool = False + random_seed_offset: int = 0 threefry_partitionable: bool = False softmax_custom_jvp: bool = False xla_profile_version: int = 0 @@ -868,6 +871,16 @@ def _update_jax_memories_thread_local(val): 'computations. Logging is performed with `logging` at WARNING ' 'level.')) +random_seed_offset = define_int_state( + name='jax_random_seed_offset', + default=0, + help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), + update_global_hook=lambda val: _update_global_jit_state( + random_seed_offset=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + random_seed_offset=val) +) + legacy_prng_key = define_enum_state( name='jax_legacy_prng_key', enum_values=['allow', 'warn', 'error'], diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 187c7645fbec..4158b4eb53ef 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -733,6 +733,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArrayIm seeds_arr = jnp.asarray(np.int64(seeds)) else: seeds_arr = jnp.asarray(seeds) + if config.random_seed_offset.value: + seeds_arr += config.random_seed_offset.value return random_seed_p.bind(seeds_arr, impl=impl) random_seed_p = core.Primitive('random_seed') diff --git a/tests/random_test.py b/tests/random_test.py index 52ce462e3e43..bec04233a57a 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -399,6 +399,14 @@ def testPRNGValues(self, make_key): random.key_data(random.fold_in(k, 4)), np.array([2285895361, 433833334], dtype='uint32')) + @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) + def test_random_seed_offset(self, make_key): + k1 = make_key(17) + with config.random_seed_offset(3): + k2 = make_key(17) + eq = k1 == k2 if k2.ndim == 0 else all(k1 == k2) + self.assertFalse(eq) + @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_random_bits_error(self, make_key): msg = 'dtype argument .* must be an unsigned int dtype'