current rng setup is full of footguns in jax #18426

Open
GallagherCommaJack opened this issue Aug 1, 2023 · 28 comments
Labels
type:feature The user is asking for a new feature.

Comments

@GallagherCommaJack
Contributor

Right now, unseeded calls to e.g. keras.random.uniform acquire static seeds at trace time. This has a few undesirable consequences:

  1. subsequent calls will have the same randomness each time (e.g. dropout will have a fixed mask instead of random each step)
  2. the jax compiler cache will ~never hit, as the constant rng seed values will be different every time

To get around this, some kind of rng state management is necessary. Flax does this with hierarchical management of rngs from the Scope. Such an approach is fairly complex, however, and there might be simpler options, e.g. a single global rng state that gets included with the training state in model.fit. Unseeded rng calls would then do something along the lines of

state.seed, local_seed = jax.random.split(state.seed)
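
For concreteness, here is a rough sketch of both the footgun and the kind of state threading being proposed (everything below is illustrative, not an existing Keras API):

import jax
import keras_core as keras


@jax.jit
def bad_step(x):
    # the seed is baked in as a constant at trace time: each call to the
    # compiled function reuses the same randomness, and the changing constant
    # defeats the compile cache across traces
    return keras.random.uniform(shape=x.shape) < 0.5


@jax.jit
def good_step(rng_state, x):
    # thread an explicit rng state through the step instead
    rng_state, local_seed = jax.random.split(rng_state)
    mask = jax.random.uniform(local_seed, shape=x.shape) < 0.5
    return rng_state, x * mask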

@fchollet
Member

fchollet commented Aug 1, 2023

Right -- random ops in Keras Core are basically always intended to be called with a SeedGenerator instance as the seed argument, since seed=None defaults to an integer seed, which is going to get baked into the graph you're tracing. Integer seeds only really work intuitively in eager mode.

This is a gotcha for sure. Maybe we can do something to resolve it.

Some considerations:

  • Is this a gotcha specifically in the seed=None case, or the more general seed=int case? I'm guessing the latter.
  • If we convert seed=None/int into a seed generator/variable in the background, we're going to need to be tracking it, since it needs to be updated by any train/test step function. Pretty easy to do for built-in methods, but the problem is that anyone writing custom training loops is going to have to be aware of that global state and take it into account.
  • Should there be one global seed variable (used in the seed=None case) shared by all unseeded random ops (with different values for each op) or should there be one seed variable per unseeded op?
  • Could an alternative be to disallow seed=None/int when tracing? It would only work in eager, and require you to pass your own SeedGenerator if you're tracing.

@GallagherCommaJack
Contributor Author

Is this a gotcha specifically in the seed=None case, or the more general seed=int case? I'm guessing the latter.

Both -- the int needs to come from somewhere.

If we convert seed=None/int into a seed generator/variable in the background, we're going to need to be tracking it, since it needs to be updated by any train/test step function. Pretty easy to do for built-in methods, but the problem is that anyone writing custom training loops is going to have to be aware of that global state and take it into account.

If someone is writing a fully custom training loop in jax, they will either be aware of this or immediately shoot their foot off.
That said, there's probably some way to progressively disclose rng handling, e.g. have a "base train step" that, in jax, increments the rng counter.
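
Something like this rough sketch (every name here is hypothetical, just to illustrate the idea):

import jax


class JaxBaseTrainStep:
    """Hypothetical base class that hides the rng bookkeeping from the user."""

    def train_step(self, state, batch):
        # advance the rng once per step so unseeded ops see fresh randomness
        new_rng, step_rng = jax.random.split(state["rng"])
        state = {**state, "rng": new_rng}
        # subclasses implement the actual forward/backward pass
        return self.compute_updates(state, batch, rng=step_rng), state

    def compute_updates(self, state, batch, rng):
        raise NotImplementedError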

Should there be one global seed variable (used in the seed=None case) shared by all unseeded random ops (with different values for each op) or should there be one seed variable per unseeded op?

I don't quite understand the distinction here... maybe you could write some pseudocode to clarify it?

Could an alternative be to disallow seed=None/int when tracing? It would only work in eager, and require you to pass your own SeedGenerator if you're tracing

Ideally we would not require the user to pass rng info all the way through model init in jax but nowhere else (and we cannot assume that model init happens in eager mode; in distributed training there isn't necessarily enough space on any given accelerator to support that).

@AakashKumarNain
Contributor

If we convert seed=None/int into a seed generator/variable in the background, we're going to need to be tracking it, since it needs to be updated by any train/test step function. Pretty easy to do for built-in methods, but the problem is that anyone writing custom training loops is going to have to be aware of that global state and take it into account.

Yes, but isn't that how we expect it to be? Anyone who writes JAX is aware of PRNG handling, and of the consequences of not handling it properly. Anything involving the PRNG should be explicit rather than implicit. It also makes debugging easier.

@fchollet
Member

fchollet commented Aug 1, 2023

The difficulty is that there will be some reference to a RNG seed variable that you'll have to take into account, something like --

def fn(variables):
    trainable_variables = ...
    non_trainable_variables = ...
    return (trainable_variables, non_trainable_variables, keras.random.global_rng_seed())

variables = (trainable_variables, non_trainable_variables, keras.random.global_rng_seed())
fn(variables)

Anyone writing custom training loops is going to need to know about this API. If they forget it -- well, back to the current status quo, which is that their unseeded RNG calls are unchanged across iterations. Maybe not that bad if you think of it like that.

@GallagherCommaJack
Contributor Author

Anyone writing custom training loops is going to need to know about this API. If they forget it -- well, back to the current status quo, which is that their unseeded RNG calls are unchanged across iterations. Maybe not that bad if you think of it like that.

It seems like a clear improvement over the current scenario, where you not only have to remember to do this, but the default APIs don't even suggest it and in fact make it quite difficult.

@martin-gorner
Contributor

+1 to proper RNG management. This is extremely important to JAX users. JAX offers good reproducibility out of the box (with some RNG learning curve for the user). It's fine if Keras can simplify the API with automatic jax.random.split(s) in the right places, but "reproducibility out of the box" should remain.

@fchollet
Member

fchollet commented Aug 2, 2023

I looked at this more closely. What I can propose is this:

  1. Unseeded random ops use a global SeedGenerator
  2. The state of the global seed generator (KerasVariable of size (2,)) is accessible via keras.random.global_rng_state()
  3. You can update the state any time you want via assign or assign_add, just like any other Keras variable.

So you can manage your unseeded random ops calls like this:

@jax.jit
def jitted_random_numbers(seed):
    rng_state = keras.random.global_rng_state()
    rng_state.assign(seed)
    x = keras.random.normal(...)
    y = keras.random.uniform(...)
    return x, y, rng_state.value

You could even have something a bit more intuitive like this:

@jax.jit
def jitted_random_numbers(seed):
    keras.random.set_global_rng_state(seed)
    x = keras.random.normal(...)
    y = keras.random.uniform(...)
    return x, y, keras.random.global_rng_state().value  # the .value won't be necessary in a future JAX version

The default behavior would be unchanged from now (unseeded random op calls are deterministic per traced function execution).

Does that work?

+1 to proper RNG management. This is extremely important to JAX users

We have that already, via the SeedGenerator class. But it does require that you seed your random op calls with a SeedGenerator. The general pattern is this:

class RandomLayer(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1234)

    def call(self, x):
        return x + keras.random.normal(..., seed=self.seed_generator)

layer = RandomLayer()  # layer.non_trainable_variables includes the seed
non_trainable_vars = layer.non_trainable_variables
outputs, non_trainable_vars = layer.stateless_call(non_trainable_vars, x)  # This returns the updated seed value

@martin-gorner
Contributor

martin-gorner commented Aug 2, 2023

The default behavior would be unchanged from now

If this results in the Dropout layer silently not working (same dropout mask at each invocation), then this is not a good solution. You suggested catching this with an error in the compiled case. That sounds right. How doable is it?

@martin-gorner
Contributor

martin-gorner commented Aug 2, 2023

Wouldn't it be better to handle the rng at the model level, as part of the model state?

The standard training loop with dropout in JAX usually looks like this (omitting non-trainable variables for simplicity, but not RNGs):

# rng_key is an instance of jax.random.PRNGKey(seed)
# trainable_vars must be the first param as jax.value_and_grad differentiates against the first param only
def stateless_loss(trainable_vars, rng_key, x, y):
  y_pred = model.apply(trainable_vars, rng_key, x) # rng_key necessary in .apply or error
  loss_val = loss_fn(y, y_pred)
  return loss_val

stateless_grads = jax.value_and_grad(stateless_loss)

state.rng_key = jax.random.PRNGKey(0)
for x, y in dataset:
  # training step, could be a jitted function
  loss_val, grads = stateless_grads(state.trainable_vars, state.rng_key, x, y)
  updates, state.optimizer_state = optimizer.update(grads, state.optimizer_state)
  state.trainable_vars = optax.apply_updates(state.trainable_vars, updates)

  # new rng key so that dropout gives a different value on next iteration
  state.rng_key, _ = jax.random.split(state.rng_key)

@martin-gorner
Contributor

martin-gorner commented Aug 2, 2023

Could Keras have something very similar? Like this:

model.build(data, seed) # internally creates RNG keys for layers that
                        # need them (mostly dropout), stores them in model.rng_keys
state = (model.trainable_variables, model.non_trainable_variables, model.rng_keys, optimizer.variables)

# ignoring model.non_trainable_variables again for simplicity

def stateless_loss(trainable_vars, rng_keys, x, y):
  y_pred = model.stateless_call(trainable_vars, rng_keys, x) # rng_keys necessary in call or error
  loss_val = loss_fn(y, y_pred)
  return loss_val

stateless_grads = jax.value_and_grad(stateless_loss)

for x, y in dataset:
  # training step, could be a jitted function
  loss_val, grads = stateless_grads(state.trainable_vars, state.rng_keys, x, y)
  updates, state.optimizer_state = optimizer.update(grads, state.optimizer_state)
  state.trainable_vars = optimizer.stateless_apply(state.trainable_vars, updates)

  # new rng keys so that dropout gives a different value on the next iteration
  state.rng_keys[0], _ = jax.random.split(state.rng_keys[0])
  state.rng_keys[1], _ = jax.random.split(state.rng_keys[1])
  # etc, there could be convenience function for this: model.advance_all_rng_keys()

@martin-gorner
Contributor

Compared to your proposal above, I'd like to build my dropout-like layers with SeedGenerator() without params and still be able to assign a seed to all these layers at once through the Model abstraction.

Same thing for weight initializers, btw.

@fchollet
Member

fchollet commented Aug 2, 2023

That's roughly how it works, except it's actually much simpler and more intuitive.

  • You don't need to worry about new types like PRNGKey. RNG state is a regular variable and it's part of non_trainable_variables like any other non-trainable element of the model's state.
  • You don't need to manually update your RNG state like this, it gets updated inside the model automatically. You just need to do outputs, non_trainable_variables = model.stateless_call(trainable_variables, non_trainable_variables, inputs) in a loop.
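
A minimal sketch of such a loop, assuming a small Sequential model (the model, shapes, and data here are just for illustration):

import numpy as np
import keras_core as keras

model = keras.Sequential([keras.layers.Dense(4), keras.layers.Dropout(0.5)])
model.build((None, 8))

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables  # includes the RNG seed state

for _ in range(3):
    x = np.random.rand(2, 8).astype("float32")
    outputs, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    # the updated seed state comes back inside non_trainable_variables, so
    # feeding it into the next call keeps the dropout masks changing per step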

@AakashKumarNain
Contributor

You don't need to manually update your RNG state like this, it gets updated inside the model automatically. You just need to do outputs, non_trainable_variables = model.stateless_call(trainable_variables, non_trainable_variables, inputs) in a loop.

This sounds like a clean solution to me. The only thing that I would expect in this scenario is a way to access the random state at any step to ensure that the rng is being handled properly. That way I can validate that everything on the model side is working as expected.

@martin-gorner
Contributor

martin-gorner commented Aug 3, 2023

OK, that sounds good. Three questions:

  1. Looking at the current implementation, is this how you should set seeds for deterministic training?
  • for Dropout: keras.layers.Dropout(seed=123). SeedGenerator is built in so that each invocation gives a different dropout mask.
  • for a layer with weights: keras.layers.Dense(kernel_initializer=keras.initializers.RandomNormal(seed=123)). SeedGenerator is not built in, but that seems to be what users want: the same random weights at each initialization.
  2. What happens for people using just keras.layers.Dense()? Will this result in a different initialization at each instantiation? Is there a way to control the seed of all the initializers in a model at once without using kernel_initializer= and instance_initializer= in every layer explicitly? Maybe through keras.random.set_global_rng_state(seed)?

  3. Looking at the implementation of SeedGenerator.next(), I see this line:
    self.state.assign((seed_state + 1) * 5387 % 933199)
    Shouldn't this be something involving the platform-specific random split APIs like tf.random.split or jax.random.split?
    According to JAX docstrings, the theory behind the "split" mechanism is in this article. I have not read it yet.

@AakashKumarNain
Contributor

Also, jax.random.split is purely deterministic; I am not sure about tf.random.split.

@fchollet
Member

fchollet commented Aug 3, 2023

for Dropout: keras.layers.Dropout(seed=123). SeedGenerator is built in so that each invocation gives a different dropout mask.

Yes

for a layer with weights: keras.layers.Dense(kernel_initializer=keras.initializers.RandomNormal(seed=123)). SeedGenerator is not built in, but that seems to be what users want: the same random weights at each initialization.

Initializers are only meant to be called once, and integer-seeded initializers always return the same value, just like integer-seeded ops.

What happens for people using just keras.layers.Dense()? Will this result in a different initialization at each instantiation?

Yes

Is there a way to control the seed of all the initializers in a model at once without using kernel_initializer= and instance_initializer= in every layer explicitly?

Call keras.utils.set_random_seed(1337) at the start of your program. This provides full determinism, minus backend op-level indeterminism (certain GPU kernels), which is handled differently by each framework (e.g. tf.config.experimental.enable_op_determinism() in TF).

Shouldn't this be something involving the platform-specific random split APIs

No need. But if you do want to manage your random seed sequence yourself via whatever algorithm of your choice, you have that option (you can just subclass SeedGenerator and do your own thing, then pass your custom class instance around to your layers).
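
For example, a rough sketch of such a subclass (assuming the (2,) uint32 state variable described above; this is not an existing Keras class):

import jax
import jax.numpy as jnp
import keras_core as keras


class SplitSeedGenerator(keras.random.SeedGenerator):
    """Illustrative only: advances its state with jax.random.split
    instead of the default counter-based update."""

    def next(self, *args, **kwargs):
        # accept and ignore any extra arguments the base class's callers may pass
        key = jnp.asarray(self.state.value, dtype="uint32")
        new_state, sub_key = jax.random.split(key)
        # keep one half as the new internal state, hand out the other half
        self.state.assign(new_state)
        return sub_key

You would then pass an instance of it wherever a seed argument accepts a SeedGenerator.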

@martin-gorner
Contributor

martin-gorner commented Aug 4, 2023

OK, thank you for your answers. For dropout-style random layers as well as weight initializers, this looks good. I especially like that the same setup (i.e. seed=123) is the correct one in both cases.

As for RNG splitting: I think the problem to solve is rng determinism, even in a distributed setting where execution order is not fully deterministic. The theory of why it is needed seems complicated (ref) and math-intensive. I don't have the full background, so I will take your conclusion that this mechanism is "not needed" at face value.

I fear however that most users will prefer relying on the standard "split" mechanism implemented in TF and JAX rather than investing the time to analyze and be convinced by the assertion that the mechanism is useless. If you have a proof of this assertion, please put it forward, but consider the difficulty of then communicating it to all users and convincing them.

@AakashKumarNain
Contributor

AakashKumarNain commented Aug 5, 2023

@fchollet I looked at the SeedGenerator again, and I have a few more doubts now. Apologies for so many questions but I think this is a very critical aspect to discuss.

  • The JAX PRNG is based on the Threefry counter-based hash function. Does the SeedGenerator implementation follow the same scheme?
  • If the answer to the above question is No, then I guess we need a series of tests like:
    • Can we reproduce the same splits every time when using the same seed in a multi-device setting?
    • Does the current implementation allow JAX to escape the sequential execution order constraint? This is one of the most important aspects for parallel execution without compromising reproducibility.

I fear however that most users will prefer relying on the standard "split" mechanism implemented in TF and JAX rather than investing the time to analyze and be convinced by the assertion that the mechanism is useless.

I totally agree with @martin-gorner here. It's very hard to convince the users to validate another PRNG implementation in their daily workflow.

A simple thing to do would be to leverage the JAX PRNG implementation in the SeedGenerator. Depending on the backend type, we would cast the rng state accordingly and use it everywhere. This reduces the burden of testing another pseudo-random generator. What do you think?

@fchollet
Member

fchollet commented Aug 5, 2023

Sure, I'm open to having a split_seed backend op of some kind. Then we can use it in SeedGenerator.next().

@AakashKumarNain
Contributor

AakashKumarNain commented Aug 6, 2023

Thank you. I guess we can refactor our seed generator like this:

import jax
import jax.numpy as jnp
import numpy as np

from keras_core import backend


class SeedGenerator:
    def __init__(self, seed=None, **kwargs):
        if seed is None:
            # make_default_seed() is assumed to return a fresh integer seed
            seed = jax.random.PRNGKey(make_default_seed())
        self._initial_seed = seed

        def seed_initializer(*args, **kwargs):
            return backend.convert_to_tensor(np.asarray(seed), dtype="uint32")

        self.state = backend.Variable(
            seed_initializer,
            shape=(2,),
            dtype="uint32",
            trainable=False,
            name="seed_generator_state",
        )

    def split_seed(self, seed_state):
        # deterministically derive a new state and a sub-key from the current state
        return jax.random.split(seed_state)

    def next(self):
        seed_state = jnp.array(backend.convert_to_numpy(self.state), dtype="uint32")
        seed_state, seed_sub_state = self.split_seed(seed_state)
        # keep the new state internally, hand out the sub-seed
        self.state.assign(seed_state)
        return seed_sub_state


def draw_seed(seed):
    if isinstance(seed, SeedGenerator):
        return seed.next()
    elif isinstance(seed, int):
        return SeedGenerator(seed=jax.random.PRNGKey(seed)).next()
    elif seed is None:
        return global_seed_generator().next()
    raise ValueError(
        "Argument `seed` must be either an integer "
        "or an instance of `SeedGenerator`. "
        f"Received: seed={seed} (of type {type(seed)})"
    )

@AakashKumarNain
Contributor

Thoughts, @fchollet @martin-gorner @GallagherCommaJack?

@AakashKumarNain
Contributor

AakashKumarNain commented Aug 9, 2023

Good news! I just tested the suggestions I made above for the SeedGenerator class, and I am sure this works perfectly! 🕺
Let me know your thoughts. I will make a PR accordingly

Here are the results from a model I had:

First run: [screenshot of the training logs]

Second run: [screenshot of the training logs]

Btw, these are the results with the JAX backend, with the model trained on 2 GPUs. The only thing that I am worried about is the lack of uint32 support in TF. I am not sure if casting back to int32 can create a problem somehow.
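
If it helps, the concern is mostly about whether the 64 bits of key material survive the round trip; a bit-preserving reinterpretation (as opposed to a value cast) is lossless. A plain numpy illustration of the idea:

import numpy as np

key_u32 = np.array([0, 1234], dtype="uint32")  # a (2,) uint32 key state
as_i32 = key_u32.view("int32")                 # reinterpret the bits; no information lost
back_u32 = as_i32.view("uint32")
assert (key_u32 == back_u32).all()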

@fchollet
Member

fchollet commented Aug 9, 2023

Thank you. I guess we can refactor our seed generator like this:

That sounds good at a high level, but the split_seed method should be a backend function instead, with a different implementation in each backend. We should also not have any reference to PRNGKey outside of the JAX backend.
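
For illustration, the per-backend implementations might look roughly like this (the module layout and the choice of tf.random.experimental.stateless_split are assumptions, not a final design):

# keras_core/backend/jax/random.py (sketch)
import jax
import jax.numpy as jnp

def split_seed(seed_state):
    # seed_state: a (2,) uint32 tensor; returns (new_state, sub_seed)
    new_state, sub_seed = jax.random.split(jnp.asarray(seed_state, dtype="uint32"))
    return new_state, sub_seed


# keras_core/backend/tensorflow/random.py (sketch)
import tensorflow as tf

def split_seed(seed_state):
    # TF stateless RNG seeds are (2,) int32/int64 tensors
    new_state, sub_seed = tf.unstack(
        tf.random.experimental.stateless_split(seed_state, num=2)
    )
    return new_state, sub_seed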

@AakashKumarNain
Contributor

@fchollet got it! I will make the changes accordingly, and will make a PR. Thanks for the pointers

@AakashKumarNain
Contributor

@fchollet I refactored the SeedGenerator class and made changes in the backend. For TF and JAX we have a very uniform implementation, but torch is a bit problematic: torch changes the state of the generator implicitly as it consumes bits from the torch.Generator(...) instance. Here is the simplified version now:

from keras_core import backend
from keras_core.backend.common import global_state


class SeedGenerator:
    def __init__(self, seed=None, **kwargs):
        custom_backend = kwargs.pop("backend", None)
        if kwargs:
            raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
        if custom_backend is not None:
            self.backend = custom_backend
        else:
            self.backend = backend

        # make_default_seed / make_initial_seed are the proposed backend ops
        if seed is None:
            seed = self.backend.random.make_default_seed()
        else:
            seed = self.backend.random.make_initial_seed(seed)

        # TF doesn't handle uint32 seed states well (see above), so use int32 there
        if backend.backend() == "tensorflow":
            seed_dtype = "int32"
        else:
            seed_dtype = "uint32"

        self._initial_seed = seed
        self.state = self.backend.Variable(
            self.backend.convert_to_tensor(seed),
            shape=tuple(seed.shape),
            dtype=seed_dtype,
            trainable=False,
            name="seed_generator_state",
        )

    def next(self):
        seed_state = self.backend.convert_to_tensor(self.state)
        # get_next_state is the proposed backend op wrapping e.g. jax.random.split
        seed_state, seed_sub_state = self.backend.random.get_next_state(seed_state)
        # keep the new state internally, hand out the sub-seed
        self.state.assign(seed_state)
        return seed_sub_state


def global_seed_generator():
    gen = global_state.get_global_attribute("global_seed_generator")
    if gen is None:
        gen = SeedGenerator()
        global_state.set_global_attribute("global_seed_generator", gen)
    return gen


def global_rng_state():
    return global_seed_generator().state


def draw_seed(seed):
    if isinstance(seed, SeedGenerator):
        return seed.next()
    elif isinstance(seed, int):
        return SeedGenerator(seed=seed).next()
    elif seed is None:
        return global_seed_generator().next()
    raise ValueError(
        "Argument `seed` must be either an integer "
        "or an instance of `SeedGenerator`. "
        f"Received: seed={seed} (of type {type(seed)})"
    )

One way to handle the differences in torch is to check the backend type in the SeedGenerator class and modify the returned values accordingly. For example, in TF and JAX we get seed_state, seed_sub_state when we call next(), but in torch it would return the torch generator object along with the current state.
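
For reference, a small standalone illustration of the torch behavior described above (plain torch, no Keras code; the "split" emulation at the end is an assumption, not the final design):

import torch

gen = torch.Generator()
gen.manual_seed(1234)

state_before = gen.get_state()        # snapshot of the generator state (a ByteTensor)
_ = torch.randn(3, generator=gen)     # drawing numbers mutates `gen` in place
state_after = gen.get_state()
assert not torch.equal(state_before, state_after)

# one way to emulate a "split": derive a child seed from the parent generator
# and seed a fresh generator with it
child_seed = int(torch.randint(0, 2**31 - 1, (1,), generator=gen))
child_gen = torch.Generator()
child_gen.manual_seed(child_seed)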

Please let me know what you think. I can make a PR and we can discuss the modifications within the PR itself; that would be much easier for you to review and comment on.

@AakashKumarNain
Contributor

I have figured out a way to make everything work seamlessly. Will make a PR tomorrow

@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023
@aaarrti
Contributor

aaarrti commented Oct 19, 2023

Hi @GallagherCommaJack @fchollet,
do you plan to resume work on this issue? What is the recommended way of handling RNGs?
I believe it would be much appreciated if you could add documentation (or a tutorial) covering this topic 🙃.

@sachinprasadhs sachinprasadhs added the type:feature The user is asking for a new feature. label Apr 23, 2024