Skip to content

Commit

Permalink
introduce a config flag to control a random seed offset
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Dec 13, 2023
1 parent 7305b64 commit 6717907
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
13 changes: 13 additions & 0 deletions jax/_src/config.py
Expand Up @@ -210,6 +210,7 @@ def _trace_context(self):
self.jax_numpy_rank_promotion, self.jax_default_matmul_precision,
self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion,
self.jax_default_device,
self.jax_random_seed_offset,
self.jax_threefry_partitionable,
self.jax_softmax_custom_jvp,
self.jax_enable_memories,
Expand Down Expand Up @@ -655,6 +656,7 @@ class _GlobalExtraJitContext(NamedTuple):
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
random_seed_offset: int = 0
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False
xla_profile_version: int = 0
Expand Down Expand Up @@ -682,6 +684,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
random_seed_offset: int = 0
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False
xla_profile_version: int = 0
Expand Down Expand Up @@ -868,6 +871,16 @@ def _update_jax_memories_thread_local(val):
'computations. Logging is performed with `logging` at WARNING '
'level.'))

random_seed_offset = define_int_state(
name='jax_random_seed_offset',
default=0,
help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
update_global_hook=lambda val: _update_global_jit_state(
random_seed_offset=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
random_seed_offset=val)
)

legacy_prng_key = define_enum_state(
name='jax_legacy_prng_key',
enum_values=['allow', 'warn', 'error'],
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/prng.py
Expand Up @@ -733,6 +733,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArrayIm
seeds_arr = jnp.asarray(np.int64(seeds))
else:
seeds_arr = jnp.asarray(seeds)
if config.random_seed_offset.value:
seeds_arr += config.random_seed_offset.value
return random_seed_p.bind(seeds_arr, impl=impl)

random_seed_p = core.Primitive('random_seed')
Expand Down
8 changes: 8 additions & 0 deletions tests/random_test.py
Expand Up @@ -399,6 +399,14 @@ def testPRNGValues(self, make_key):
random.key_data(random.fold_in(k, 4)),
np.array([2285895361, 433833334], dtype='uint32'))

@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_random_seed_offset(self, make_key):
k1 = make_key(17)
with config.random_seed_offset(3):
k2 = make_key(17)
eq = k1 == k2 if k2.ndim == 0 else all(k1 == k2)
self.assertFalse(eq)

@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_random_bits_error(self, make_key):
msg = 'dtype argument .* must be an unsigned int dtype'
Expand Down

0 comments on commit 6717907

Please sign in to comment.