
Update Seed Generator #704

Closed
wants to merge 11 commits

Conversation

AakashKumarNain (Collaborator) commented on Aug 12, 2023

This PR is an attempt to fix keras-team/keras#18426. A few notable things:

  1. Right now, I have kept 42 as the default seed so that everyone can test it easily on their side with the same setup. I will replace it with a random number once the PR is approved.
  2. torch is a pain to work with when it comes to random generation. Why? Because it consumes and updates the state of the generator implicitly. Had it followed the same algorithm as TF and JAX, we could have avoided a couple of if/else conditions.
  3. I have tested the workflow on my end with all the backends, for both random ops and layers. It works flawlessly! I will update the tests later on.

Here is an overview of how the seed generation works now:

Random ops

Scenario 1: No seed value is provided

samples = ops.random.randint(shape=(1,), minval=0, maxval=10)
  1. Once this random op is called, it will call the draw_seed(...) function.
  2. It will check whether we already have a global seed generator. If yes, it will call the next(...) method to get the next state of the global generator and draw samples from this updated RNG state.
  3. If there is no global seed generator yet, it will create an instance of SeedGenerator() with seed=None, which initializes the seed state with 42, and set the global seed generator to this instance.
  4. Subsequent calls of unseeded random ops will use the next state of this global seed generator.

Scenario 2: seed value is provided

samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1)

1. Once this random op is called, it will call the draw_seed(...) function.
2. If the seed is an int, draw_seed(...) builds the initial seed state by calling the make_initial_seed(seed) function, which expects an integer argument. Only integer arguments are supported, to keep the behavior consistent across backends. This is a stateless call: calling the op with the same integer seed gives the same result every time.
3. Alternatively, you can instantiate a SeedGenerator yourself and pass it as the seed; the op then draws the seed by calling that generator's next(...) method, so reusing the same instance across calls advances its state.
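
To tie the two scenarios together, here is a minimal sketch of the dispatch that draw_seed(...) performs as described above; global_seed_generator() is an assumed helper name, and the real code may differ:

# Sketch only: the dispatch described in the two scenarios above.
def draw_seed(seed=None):
    if isinstance(seed, SeedGenerator):
        # A user-supplied generator: drawing advances its state.
        return seed.next()
    if isinstance(seed, int):
        # Stateless path: the same integer always yields the same seed state.
        return make_initial_seed(seed)
    if seed is None:
        # Fall back to the (lazily created) global SeedGenerator and
        # advance its state.
        return global_seed_generator().next()   # assumed helper
    raise ValueError(f"Invalid seed argument: {seed}")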

Layers

Scenario 1: No seed value is provided

  1. It will check whether a global seed generator exists. If yes, it will draw a seed by calling next(...).
  2. If no global seed generator is present, it will create one with 42 as the initial state, and then draw seeds from this global generator on subsequent calls.

Scenario 2: Seed value is provided

The layer will instantiate a SeedGenerator, and subsequent calls will draw seeds from this generator (see the sketch below).
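
A rough sketch of where such a generator could live in a layer, assuming the usual Layer base class; the layer name, ops.random.uniform, and ops.where are illustrative assumptions rather than the PR's code:

# Illustrative sketch only, not the actual Keras Core Dropout implementation.
class MyDropout(Layer):
    def __init__(self, rate, seed=None):
        super().__init__()
        self.rate = rate
        # With an explicit seed the layer owns its own SeedGenerator;
        # with seed=None the random op falls back to the global generator.
        self.seed_generator = SeedGenerator(seed) if seed is not None else None

    def call(self, inputs, training=False):
        if not training:
            return inputs
        # Each call draws a fresh seed from the generator (local or global),
        # so the mask changes between calls while staying reproducible.
        noise = ops.random.uniform(shape=inputs.shape, seed=self.seed_generator)
        mask = noise >= self.rate
        return ops.where(mask, inputs / (1.0 - self.rate), 0.0)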

Note: The SeedGenerator class now depends on the backend-specific state updates of the RNG. For TF and JAX we split the state at every call, while in torch the state is consumed implicitly. The next() method in SeedGenerator is now fully deterministic and depends on the algorithms the backend uses to manage the RNG state.

Though I don't think I have missed anything while testing this on my end, it would be good if someone could confirm the same on their end. cc: @fchollet @martin-gorner @GallagherCommaJack

fchollet (Member)

Thanks for the PR!

Scenario 2: Seed value is provided

This is not intended behavior -- if you do ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1) you should get the same value at every call. It should not be stateful. It would be very counter-intuitive if it were stateful and the same seed gave you different results at every call.

AakashKumarNain (Collaborator, Author) commented on Aug 12, 2023

Thanks for the PR!

Scenario 2: Seed value is provided

This is not intended behavior -- if you do ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1) you should get the same value at every call. It should not be stateful. It would be very counter-intuitive if it were stateful and the same seed gave you different results at every call.

@fchollet Sorry for the confusion. It does exactly what you expect. On every call ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1) will give you the same result. Here is an example to show the same:

for _ in range(5):
    samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1)
    print(samples)

# output
[4]
[4]
[4]
[4]
[4]

#########

for _ in range(5):
    samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=SeedGenerator(1))
    print(samples)

# output
[4]
[4]
[4]
[4]
[4]

You will get different values only if you are making subsequent calls with the same SeedGenerator instance. In that case, instead of a global generator, we have a local RNG for these calls, and its state is consumed when the same RNG object (not just the same seed value) is reused across calls, e.g.

s = SeedGenerator(1)
for _ in range(5):
    samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=s)
    print(samples)
    
# output
[4]
[0]
[1]
[9]
[9]

fchollet (Member)

You will get different values only if you are making subsequent calls with the same SeedGenerator instance

This is what I mean. The behavior of your code should not change if you put it in a traced function. In this case, ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1) should return the same value regardless of the context, including in the case where it's a traced function that is called multiple times. Basically, we should keep the current RNG seed behavior.

AakashKumarNain (Collaborator, Author)

Okay, I will refactor the code. But to clarify again:

ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1)
The above will return the same result no matter how many times you call it.

ops.random.randint(shape=(1,), minval=0, maxval=10, seed=SeedGenerator(1))
The above will return the same result no matter how many times you call it.

Re: the remaining case, where we pass the updated RNG state to a seeded function like this:

s = SeedGenerator(1)
for _ in range(3):
    samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=s)
    print(samples)
    
# output
[4]
[0]
[1]

This behavior is consistent with the workflows below:

gen = torch.Generator().manual_seed(1)
for _ in range(3):
    samples = torch.randint(low=0, high=10, size=(1,), generator=gen)
    print(samples)

# output
tensor([5])
tensor([9])
tensor([4])

###################

gen = tf.random.Generator.from_seed(1)
seed = tf.cast(gen.make_seeds(2)[0], tf.int32)
for _ in range(3):
    seed, subseed = tf.random.split(seed, 2)
    samples = tf.random.stateless_uniform(minval=0, maxval=10, shape=(1,), dtype=tf.int32, seed=seed)
    print(samples)

# output
tf.Tensor([0], shape=(1,), dtype=int32)
tf.Tensor([5], shape=(1,), dtype=int32)
tf.Tensor([4], shape=(1,), dtype=int32)

########################

key = jax.random.PRNGKey(1)
for _ in range(3):
    key, subkey = jax.random.split(key)
    samples = jax.random.randint(key, minval=0, maxval=10, shape=(1,))
    print(samples)

# output
[5]
[9]
[1]

AakashKumarNain (Collaborator, Author) commented on Aug 13, 2023

@fchollet I have made the changes. Random functions are now stateless, while layers that depend on the RNG state are stateful and use SeedGenerator. I have edited the note at the top with the detailed workflow. Please let me know if you want to change anything else. At this point:

ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1)
This will always return the same result regardless of the context, and

ops.random.randint(shape=(1,), minval=0, maxval=10, seed=None)
The results here will depend on the state of the global random generator

The main change is to the SeedGenerator class, which now uses the backend-specific implementation for obtaining the next state and is fully deterministic.
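
For illustration, a minimal sketch of what the backend-specific next-state function could look like for the JAX backend; the name mirrors the get_next_state shown in the diff below, but the body here is an assumption, not the merged code:

import jax

# Sketch: split the incoming key into the new generator state and the
# sub-state handed to the random op. Deterministic for a given input state.
def get_next_state(seed_state):
    new_state, sub_state = jax.random.split(seed_state)
    return new_state, sub_state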

return None, jax.random.PRNGKey(seed=seed)


def get_next_state(seed):
Member:

This function should have "seed" somewhere in the name since it's part of the same group of functions, e.g. get_next_seed_state()

Collaborator Author:

Will rename it.

else:
    return draw_seed(seed)


def make_default_seed():
    return None, jax.random.PRNGKey(42)
Member:

The default seed should not be fixed, it should be drawn from Python random

Collaborator Author:

I know. I kept it just for the sake of demonstration. I will replace it with a random selection.

keras_core/backend/tensorflow/random.py (outdated; resolved)
def seed_initializer(*args, **kwargs):
    dtype = kwargs.get("dtype", None)
    return self.backend.convert_to_tensor([seed, 0], dtype=dtype)
if backend.backend() == "tensorflow":
Member:

Why is this needed? Can't we use uint32 everywhere?

Collaborator Author:

The TF random generator uses int32/int64, and with XLA only int32 is allowed. Casting it to uint32 could cause some unintended consequences that would be hard to debug.
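
For context, a small standalone TF sketch (not PR code) of the seed dtype in question; stateless TF random ops take a shape-[2] int32/int64 seed tensor like the [seed, 0] initializer above:

import tensorflow as tf

# A shape-[2] int32 seed tensor, mirroring the [seed, 0] initializer above.
seed_state = tf.constant([42, 0], dtype=tf.int32)

# Stateless ops accept int32 seeds, which also keeps XLA happy.
samples = tf.random.stateless_uniform(
    shape=(1,), seed=seed_state, minval=0, maxval=10, dtype=tf.int32
)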

)
self.state.assign(seed_state + increment)
def next(self):
    if backend.backend() == "torch":
Member:

Can we move the torch specific code to get_next_seed_state?

Collaborator Author:

Let me elaborate on why it is coded this way.

  1. In TF/JAX, when we get the next state by splitting the seed, we get two states. We assign one of these states to the state of the SeedGenerator, and we also pass that same state to the random functions.
  2. In torch, we don't have a way to control the random state explicitly. We keep passing the same generator object to the random functions, and they consume and update its state implicitly. Even if we moved this to the torch backend, we would still need this condition, because for torch we get back (state, gen), compared to (state, sub_state) in TF/JAX (see the sketch below).
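
A rough sketch of the branching being discussed; the attribute names and the get_next_seed_state helper (the name suggested above) are assumptions based on this thread, not the exact PR code:

# Sketch only: why next() needs a torch-specific branch.
def next(self):
    if backend.backend() == "torch":
        # self.state holds a torch.Generator; random ops consume and
        # update its state implicitly, so the object itself is returned.
        return self.state
    # TF / JAX: split explicitly; per point 1 above, one of the split
    # states becomes the new generator state and is also handed to the op.
    new_state, _sub_state = backend.random.get_next_seed_state(self.state)
    self.state = new_state
    return new_state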

keras_core/random/seed_generator.py (resolved)
AakashKumarNain (Collaborator, Author) commented on Aug 14, 2023

Even though I have made changes to the seed generator class, this won't work for jit-compiled TF, because a compiled TF function only supports singleton variable creation, and here we are returning a tensor from the generator.

Similarly, there is a weird issue on the JAX side, where it keeps complaining about the dtype of the PRNGKey: TypeError: JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; got dtype=uint32. Also, having a global generator with JAX doesn't make sense, because it will create a side effect when jitted.

Our only options are:

  1. Either keep the current SeedGenerator class as it is. But for JAX users, the current design doesn't guarantee the reproducibility that comes with the JAX PRNG design (unless we have some proof of it).
  2. The other option is to have a SeedGenerator class for each backend. That way, we can overcome a lot of these issues. Also, for the JAX backend, we need to raise an error whenever seed=None is passed, instead of using a global state.

I will refactor it one more time to see if we can bypass these issues

fchollet (Member) commented on Aug 14, 2023

Also, for the JAX backend, we need to raise an error whenever seed=None is passed, instead of using a global state

This is too general -- I think specifically we should detect when the global seed generator is being used in a tracing context in JAX, and throw an error at that time. It is fine to use seed=None in JAX when running eagerly.

The other option is to have a SeedGenerator class for each backend. That way, we can overcome a lot of these issues.

If we are disabling tracing + seed=None in JAX, what are the remaining benefits of having backend-specific generators? Just the seed-splitting algorithm? I am fundamentally not worried about it. The current setup is reproducible and works well. If you are in a multi-worker setting, you know that the seed on each worker deterministically depends on the Python RNG seed state, so you can choose to seed random with the same seed on each worker if you want the same random numbers generated, or seed each worker with e.g. worker_id to get different random behavior on each worker.
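
Concretely, the per-worker choice described above could be as small as the following; worker_id is assumed to come from the multi-worker launcher:

import random

# Same seed on every worker -> identical random numbers everywhere.
random.seed(1234)

# Or seed with the worker id -> different, but still reproducible,
# random behavior on each worker.
# random.seed(worker_id)  # worker_id provided by the multi-worker setup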

AakashKumarNain (Collaborator, Author) commented on Aug 15, 2023

It is fine to use seed=None in JAX when running eagerly.

There are two core issues with this approach:

  1. It diverges from the behavior of the backend. If no seed is passed, we should raise the same error as raised by JAX, irrespective of the execution mode.
  2. It changes the mental model for the backend. The idea of passing a seed to random functions in JAX (and not a global seed generator) is about controlling the entropy explicitly. If we don't mimic the same behavior, we become a "specialized" implementation of the JAX backend.

what are the remaining benefits of having backend specific generators?

Listing a few of them:

  • The PRNG design in JAX makes it easy to escape the sequential execution constraint, making it parallelizable and reproducible at the same time. To illustrate, let's take a simple example:
def random_letter():
    return random.choice(["a", "b", "c"])

def random_number():
   return random.randint(0, 10, shape=())

res = random_letter() + str(random_number())

Both random functions in the above example are independent of each other, meaning we should be able to run them in any order. We don't know which function the compiler will execute first. If a global state is used, the result depends on which function was executed first, making it completely irreproducible. Even if we use a stateful seed here, it introduces another sequential dependency. Neither of these aligns with the XLA semantics used by JAX. (A runnable JAX version of this example, using explicit split keys, appears after the next bullet.)

  • Though you can use any counter-based PRNG algorithm, the only way to ensure that there is no sequential threading of state, and that we stay aligned with functional programming, is to have a splittable PRNG instead of a linear one. Here is the paper for the same.
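
A minimal runnable JAX sketch of the example above (the function bodies are assumptions; only the key-splitting pattern matters): with explicit, split keys, each function gets its own independent subkey, so the result no longer depends on execution order.

import jax

def random_letter(key):
    # Pick one of "a", "b", "c" using the given key.
    idx = jax.random.randint(key, shape=(), minval=0, maxval=3)
    return "abc"[int(idx)]

def random_number(key):
    return int(jax.random.randint(key, shape=(), minval=0, maxval=10))

key = jax.random.PRNGKey(0)
letter_key, number_key = jax.random.split(key)

# Each call gets its own subkey, so the two calls can run in any order
# and still produce the same, reproducible result.
res = random_letter(letter_key) + str(random_number(number_key))
print(res)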

One of the ways to use and update state in layers like Dropout is this:

from functools import partial

import jax
import jax.numpy as jnp
from jax import tree_util


class SeedGenerator:
    def __init__(self, seed):
        self.seed = seed
    
    @jax.jit
    def get_next_state(self):
        state, sub_state = jax.random.split(self.seed)
        self.seed = state
        return self

    def _tree_flatten(self):
        children = (self.seed,)
        aux_data = {}
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children)

# Need to register this as a pytree
tree_util.register_pytree_node(SeedGenerator,
                           SeedGenerator._tree_flatten,
                           SeedGenerator._tree_unflatten)

def _get_concrete_noise_shape(inputs, noise_shape):
    if noise_shape is None:
        return inputs.shape

    concrete_inputs_shape = inputs.shape
    concrete_noise_shape = []
    for i, value in enumerate(noise_shape):
        concrete_noise_shape.append(
            concrete_inputs_shape[i] if value is None else value
        )
    return concrete_noise_shape


def dropout(seed, inputs, rate, noise_shape=None):
    state = seed.get_next_state()
    keep_prob = 1.0 - rate
    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
    mask = jax.random.bernoulli(state.seed, p=keep_prob, shape=noise_shape)
    mask = jax.numpy.broadcast_to(mask, inputs.shape)
    out = jax.lax.select(mask, inputs / keep_prob, jax.numpy.zeros_like(inputs))
    return out, state

class DropoutLayer:
    def __init__(self, drop_rate=0.5, noise_shape=None):
        self.drop_rate = drop_rate
        self.noise_shape = noise_shape
    
    @partial(jax.jit, static_argnums=(0,))
    def __call__(self, inputs, seed):
        outputs, state = dropout(seed, inputs, self.drop_rate, self.noise_shape)
        return outputs, state

inp = jnp.ones(shape=(1, 3), dtype=jnp.float32)
seed = SeedGenerator(jax.random.PRNGKey(42))
layer = DropoutLayer()

for i in range(5):
    out, seed = layer(inp, seed)
    print(out)

# output
[[0. 0. 0.]]
[[2. 2. 2.]]
[[0. 2. 2.]]
[[2. 0. 0.]]
[[2. 2. 0.]]

cc: @mattjj who can provide some more ideas on the same

fchollet closed this on Sep 15, 2023
Successfully merging this pull request may close these issues:
current rng setup is full of footguns in jax