current rng setup is full of footguns in jax #18426
A reference re: JAX RNG mechanics.
Right -- random ops in Keras Core are basically always intended to be called with a `SeedGenerator`. This is a gotcha for sure. Maybe we can do something to resolve it. Some considerations:
In both cases, the int needs to come from somewhere.
If someone is writing a fully custom training loop in JAX, they will either be aware of this or immediately shoot their foot off.
I don't quite understand the distinction here... maybe you could write some pseudocode to clarify it?
Ideally we would not require the user to pass RNG info all the way through model init in JAX but nowhere else. (And we cannot assume that model init happens in eager mode; in distributed training there isn't necessarily enough space on any given accelerator to support that.)
Yes, but isn't that what we expect? Anyone who writes JAX is aware of PRNG handling and the consequences of not handling it properly. Anything involving the PRNG should be explicit rather than implicit. It also makes debugging easier.
The difficulty is that there will be some reference to an RNG seed variable that you'll have to take into account, something like the sketch below.
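A minimal sketch of the idea, assuming the seed state is tracked as one of the non-trainable variables, and mirroring the `stateless_call` usage shown later in this thread:

```python
import jax

# Hypothetical sketch: the RNG seed lives in the non-trainable
# variables, so a custom jitted step must thread it through and
# return its updated value.
@jax.jit
def forward(non_trainable_vars, x):
    # non_trainable_vars includes the seed state; every random op
    # inside the call consumes and advances it.
    outputs, non_trainable_vars = layer.stateless_call(non_trainable_vars, x)
    return outputs, non_trainable_vars  # updated seed travels back out
```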
Anyone writing custom training loops is going to need to know about this API. If they forget it -- well, we're back to the current status quo, which is that their unseeded RNG calls are unchanged across iterations. Maybe not that bad if you think of it like that.
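To make that status quo concrete, here's a minimal sketch in plain JAX (not Keras code) of how a seed acquired at trace time freezes the randomness:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # The key is created during tracing, so it is baked into the
    # compiled function as a constant: every call reuses it.
    key = jax.random.PRNGKey(0)
    return x + jax.random.normal(key, x.shape)

x = jnp.zeros(3)
print(f(x))  # some random-looking values...
print(f(x))  # ...and the exact same values again
```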
It seems like a clear improvement over the current scenario, where you not only have to remember to do this, but the default APIs don't even suggest it and make it very difficult to do.
+1 to proper RNG management. This is extremely important to JAX users. JAX offers good reproducibility out of the box (with some RNG learning curve for the user). It's fine if Keras can simplify the API with automatic `jax.random.split`(s) in the right places, but "reproducibility out of the box" should remain.
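For reference, a minimal sketch of the standard JAX splitting pattern being referred to:

```python
import jax

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)      # advance the stream
noise = jax.random.normal(subkey, (3,))  # use each subkey exactly once
key, subkey = jax.random.split(key)      # split again for the next op
mask = jax.random.uniform(subkey, (3,))
```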
I looked at this more closely. What I can propose is this:
So you can manage your unseeded random op calls like this:

```python
@jax.jit
def jitted_random_numbers(seed):
    rng_state = keras.random.global_rng_state()
    rng_state.assign(seed)
    x = keras.random.normal(...)
    y = keras.random.uniform(...)
    return x, y, rng_state.value
```

You could even have something a bit more intuitive, like this:

```python
@jax.jit
def jitted_random_numbers(seed):
    keras.random.set_global_rng_state(seed)
    x = keras.random.normal(...)
    y = keras.random.uniform(...)
    # the .value won't be necessary in a future JAX version
    return x, y, keras.random.global_rng_state().value
```

The default behavior would be unchanged from now (unseeded random op calls are deterministic per traced function execution). Does that work?
We have that already, via `SeedGenerator` and `stateless_call`:

```python
class RandomLayer(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1234)

    def call(self, x):
        return x + keras.random.normal(..., seed=self.seed_generator)

layer = RandomLayer()
# layer.non_trainable_variables includes the seed
non_trainable_vars = layer.non_trainable_variables
# stateless_call returns the updated seed value
outputs, non_trainable_vars = layer.stateless_call(non_trainable_vars, x)
```
If this results in the Dropout layer silently not working (same dropout mask at each invocation), then this is not a good solution. You suggested catching this with an error in the compiled case; that sounds right. How doable is it?
Wouldn't it be better to handle the RNG at the model level, as part of the model state? The standard training loop with dropout in JAX usually looks something like the sketch below (omitting non-trainable variables for simplicity, but not RNGs).
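A minimal sketch of that standard pattern, assuming a flax-style `apply` with an `rngs` dict; `model`, `compute_loss`, and `apply_gradients` are placeholder names:

```python
import jax

def train_step(params, dropout_rng, x, y):
    # Split so each step gets a fresh dropout mask while the carried
    # rng state remains fully reproducible.
    dropout_rng, step_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        logits = model.apply(params, x, rngs={"dropout": step_rng})
        return compute_loss(logits, y)

    grads = jax.grad(loss_fn)(params)
    params = apply_gradients(params, grads)
    return params, dropout_rng  # the rng state is part of the loop state
```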
Could Keras have something very similar? Something like the sketch below, perhaps.
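A purely hypothetical sketch of what such a Keras analog could look like; none of these names are a real Keras API:

```python
import jax

# Hypothetical, not an actual Keras API:
model = build_model()
state = model.init_state(seed=123)  # params, non-trainables, and rng together

@jax.jit
def train_step(state, x, y):
    outputs, state = model.stateless_apply(state, x, training=True)
    # ...compute loss/gradients and update the params inside `state`...
    return state  # the advanced rng state rides along in `state`
```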
Compared to your proposal above, I'd like to build my dropout-like layers with `SeedGenerator()` without params and still be able to assign a seed to all these layers at once through the Model abstraction. Same thing for weight initializers, by the way.
That's roughly how it works, except it's actually much simpler and more intuitive.
This sounds like a clean solution to me. The only thing I would expect in this scenario is a way to access the random state at any step, to ensure that the RNG is being handled properly. That way I can validate that everything on the model side is working as expected.
OK, that sounds good. Three questions:
Also,
- Yes.
- Initializers are only meant to be called once, and integer-seeded initializers always return the same value, just like integer-seeded ops.
- Yes.
- Call
- No need. But if you do want to manage your random seed sequence yourself, via whatever algorithm you choose, you have that option (you can just subclass `SeedGenerator`; see the sketch after this list).
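A hypothetical sketch of that last option, assuming `SeedGenerator` exposes a `state` variable and a `next()` method as in the snippets elsewhere in this thread:

```python
import keras

class CountingSeedGenerator(keras.random.SeedGenerator):
    # Hypothetical: advance the seed state with a plain counter
    # instead of a split-based scheme.
    def next(self):
        new_state = self.state.value + 1  # any deterministic update rule
        self.state.assign(new_state)
        return new_state
```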
OK, thank you for your answers. For dropout-style random layers as well as weight initializers, this looks good. I especially like that the same setup (i.e. seed=123) is the correct one in both cases. As for RNG splitting: I think the problem to solve is RNG determinism, even in a distributed setting where execution order is not fully deterministic. The theory of why it is needed seems complicated (ref) and math-intensive. I don't have the full background, so I will accept your conclusion that this mechanism is "not needed" at face value. I fear, however, that most users will prefer relying on the standard "split" mechanism implemented in TF and JAX rather than investing the time to analyze and be convinced by the assertion that the mechanism is useless. If you have a proof of this assertion, please put it forward, but consider the difficulty of then communicating it to all users and convincing them.
@fchollet I looked at the `SeedGenerator` implementation.
I totally agree with @martin-gorner here. It's very hard to convince users to validate another PRNG implementation in their daily workflow. A simple thing to do would be to leverage the JAX PRNG implementation in the `SeedGenerator`.
Sure, I'm open to having a backend-specific seed-generation implementation.
Thank you. I guess we can refactor our seed generator like this:

```python
import jax
import jax.numpy as jnp
import numpy as np

from keras_core import backend

# make_default_seed() and global_seed_generator() are assumed to be
# defined elsewhere in the module.


class SeedGenerator:
    def __init__(self, seed=None, **kwargs):
        if seed is None:
            seed = jax.random.PRNGKey(make_default_seed())
        self._initial_seed = seed

        def seed_initializer(*args, **kwargs):
            return backend.convert_to_tensor(np.asarray(seed), dtype="uint32")

        self.state = backend.Variable(
            seed_initializer,
            shape=(2,),
            dtype="uint32",
            trainable=False,
            name="seed_generator_state",
        )

    def split_seed(self, seed_state):
        return jax.random.split(seed_state)

    def next(self):
        seed_state = jnp.array(
            backend.convert_to_numpy(self.state), dtype="uint32"
        )
        seed_state, seed_sub_state = self.split_seed(seed_state)
        self.state.assign(seed_sub_state)
        return seed_sub_state


def draw_seed(seed):
    if isinstance(seed, SeedGenerator):
        return seed.next()
    elif isinstance(seed, int):
        seed = jax.random.PRNGKey(seed)
        return SeedGenerator(seed=seed).next()
    elif seed is None:
        return global_seed_generator().next()
    raise ValueError(
        "Argument `seed` must be either an integer "
        "or an instance of `SeedGenerator`. "
        f"Received: seed={seed} (of type {type(seed)})"
    )
```
Thoughts, @fchollet @martin-gorner @GallagherCommaJack?
That sounds good at a high level, but the `SeedGenerator` needs to work across all backends, not just JAX.
@fchollet got it! I will make the changes accordingly, and will make a PR. Thanks for the pointers |
@fchollet I refactored the `SeedGenerator` class:

```python
from keras_core import backend
from keras_core.backend.common import global_state


class SeedGenerator:
    def __init__(self, seed=None, **kwargs):
        custom_backend = kwargs.pop("backend", None)
        if kwargs:
            raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
        if custom_backend is not None:
            self.backend = custom_backend
        else:
            self.backend = backend

        if seed is None:
            seed = self.backend.random.make_default_seed()
        else:
            seed = self.backend.random.make_initial_seed(seed)

        # TF's stateless RNG ops expect int32 keys; JAX uses uint32.
        if backend.backend() == "tensorflow":
            seed_dtype = "int32"
        else:
            seed_dtype = "uint32"

        self._initial_seed = seed
        self.state = self.backend.Variable(
            self.backend.convert_to_tensor(seed),
            shape=tuple(seed.shape),
            dtype=seed_dtype,
            trainable=False,
            name="seed_generator_state",
        )

    def next(self):
        seed_state = self.backend.convert_to_tensor(self.state)
        seed_state, seed_sub_state = self.backend.random.get_next_state(
            seed_state
        )
        self.state.assign(seed_sub_state)
        return seed_sub_state


def global_seed_generator():
    gen = global_state.get_global_attribute("global_seed_generator")
    if gen is None:
        gen = SeedGenerator()
        global_state.set_global_attribute("global_seed_generator", gen)
    return gen


def global_rng_state():
    return global_seed_generator().state


def draw_seed(seed):
    if isinstance(seed, SeedGenerator):
        return seed.next()
    elif isinstance(seed, int):
        return SeedGenerator(seed=seed).next()
    elif seed is None:
        return global_seed_generator().next()
    raise ValueError(
        "Argument `seed` must be either an integer "
        "or an instance of `SeedGenerator`. "
        f"Received: seed={seed} (of type {type(seed)})"
    )
```

One way to handle the differences in torch is to check the backend type. Please let me know what you think. I can make a PR and we can discuss the modifications within the PR itself; that would be much easier for you to review and comment on.
I have figured out a way to make everything work seamlessly. Will make a PR tomorrow |
Hi @GallagherCommaJack @fchollet, |
Right now, unseeded calls to e.g. `keras.random.uniform` are going to acquire static seeds at trace time. This has a few undesirable consequences. To get around this, some kind of RNG state management is necessary. Flax does this with hierarchical management of RNGs from the `Scope`. Such an approach is fairly complex, however, and there might be simpler options, e.g. a single global `rng` state, which gets included with the training state in `model.fit`; unseeded RNG calls would then do something along the lines of the sketch below.
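A hedged sketch of that idea, assuming a single global key and JAX's `split`; the accessor names are made up:

```python
import jax

def _draw_from_global_state():
    state = get_global_rng_state()   # hypothetical accessor
    state, subkey = jax.random.split(state)
    set_global_rng_state(state)      # hypothetical setter
    return subkey

def uniform(shape):
    # An unseeded call would transparently consume one subkey,
    # advancing the global state as a side effect.
    return jax.random.uniform(_draw_from_global_state(), shape)
```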