Update Seed Generator #704
Conversation
Thanks for the PR!
This is not intended behavior -- if you do …
@fchollet Sorry for the confusion. It does exactly what you expect. On every call:

```python
for _ in range(5):
    samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=1)
    print(samples)
# output
# [4]
# [4]
# [4]
# [4]
# [4]

#########
for _ in range(5):
    samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=SeedGenerator(1))
    print(samples)
# output
# [4]
# [4]
# [4]
# [4]
# [4]
```

You will get different values only if you are making subsequent calls to the same function. In this case, instead of a global generator, we have a local rng for this function, and we are consuming its state when the same rng (not the seed value) is used over subsequent calls, e.g.:

```python
s = SeedGenerator(1)
for _ in range(5):
    samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=s)
    print(samples)
# output
# [4]
# [0]
# [1]
# [9]
# [9]
```
This is what I mean. The behavior of your code should not change if you put it in a traced function. In this case, …
Okay, I will refactor the code. But to clarify it once more:
Re: the remaining case where we are passing the updated rng state to a seeded function like this:

```python
s = SeedGenerator(1)
for _ in range(3):
    samples = ops.random.randint(shape=(1,), minval=0, maxval=10, seed=s)
    print(samples)
# output
# [4]
# [0]
# [1]
```

This behavior is consistent with the workflows below:

```python
gen = torch.Generator().manual_seed(1)
for _ in range(3):
    samples = torch.randint(low=0, high=10, size=(1,), generator=gen)
    print(samples)
# output
# tensor([5])
# tensor([9])
# tensor([4])
```

```python
gen = tf.random.Generator.from_seed(1)
seed = tf.cast(gen.make_seeds(2)[0], tf.int32)
for _ in range(3):
    seed, subseed = tf.random.split(seed, 2)
    samples = tf.random.stateless_uniform(
        minval=0, maxval=10, shape=(1,), dtype=tf.int32, seed=seed
    )
    print(samples)
# output
# tf.Tensor([0], shape=(1,), dtype=int32)
# tf.Tensor([5], shape=(1,), dtype=int32)
# tf.Tensor([4], shape=(1,), dtype=int32)
```

```python
key = jax.random.PRNGKey(1)
for _ in range(3):
    key, subkey = jax.random.split(key)
    samples = jax.random.randint(key, minval=0, maxval=10, shape=(1,))
    print(samples)
# output
# [5]
# [9]
# [1]
```
@fchollet I have made the changes. Random functions are now stateless, while layers depending on the rng state are stateful and use …

The main change is only for the …
keras_core/backend/jax/random.py (Outdated)

```python
    return None, jax.random.PRNGKey(seed=seed)


def get_next_state(seed):
```
This function should have "seed" somewhere in the name since it's part of the same group of functions, e.g. get_next_seed_state()
Will rename it.
keras_core/backend/jax/random.py (Outdated)

```python
    else:
        return draw_seed(seed)


def make_default_seed():
    return None, jax.random.PRNGKey(42)
```
The default seed should not be fixed; it should be drawn from Python random.
I know. I kept it just for the sake of demonstration. Will replace it with a random selection.
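For reference, a change along those lines might look like this (my sketch of the reviewer's suggestion, not the PR's actual code):

```python
import random as python_random

import jax


def make_default_seed():
    # Draw the default seed from Python's RNG instead of hardcoding 42.
    # The range below is an arbitrary choice for this sketch.
    seed = python_random.randint(1, int(1e9))
    return None, jax.random.PRNGKey(seed)
```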
```python
def seed_initializer(*args, **kwargs):
    dtype = kwargs.get("dtype", None)
    return self.backend.convert_to_tensor([seed, 0], dtype=dtype)

if backend.backend() == "tensorflow":
```
Why is this needed? Can't we use `uint32` everywhere?
The TF random generator uses int32/int64, and with XLA only int32 is allowed. Casting it to uint32 can cause some unintended consequences and would be hard to debug.
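For context, here is a small standalone illustration of the dtype constraints being discussed (my own example, not code from this PR):

```python
import tensorflow as tf

# The stateful generator keeps its state in a signed integer variable
# (int64 on current TF builds):
gen = tf.random.Generator.from_seed(1)
print(gen.state.dtype)

# Stateless ops take a shape-[2] int32/int64 seed; sticking to int32 keeps
# the code compatible with XLA-compiled functions:
samples = tf.random.stateless_uniform(
    shape=(1,), seed=tf.constant([1, 2], dtype=tf.int32)
)
print(samples)
```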
```python
        )
        self.state.assign(seed_state + increment)

    def next(self):
        if backend.backend() == "torch":
```
Can we move the torch-specific code to `get_next_seed_state`?
Let me elaborate on why it is coded this way:

- In TF/JAX, when we get the next state by splitting the seed, we get two states. We assign one of these states to the state of the `SeedGenerator`, and we also pass the same state to the random functions.
- In torch, we don't have a way to control the random state explicitly. We keep passing the same `gen` object to random functions, and they consume and update the state of the `gen` implicitly. Even if we moved this to the torch backend, we would still need this condition, because for torch we get `state, gen` compared to `state, sub_state` in TF/JAX.
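A minimal sketch of that asymmetry (the function body is illustrative, not the exact PR implementation; only the JAX branch of the split case is shown):

```python
import jax


def get_next_seed_state(state, backend):
    if backend == "torch":
        # torch: no explicit split -- the generator object itself is returned
        # twice; random functions consume and update its state implicitly.
        gen = state
        return gen, gen
    elif backend == "jax":
        # JAX: split the key into (new state, sub-state); the caller assigns
        # the new state back to the SeedGenerator and passes the sub-state on.
        new_state, sub_state = jax.random.split(state)
        return new_state, sub_state
```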
Even though I have made changes to the seed generator class, this won't work for jit-compiled TF, because a compiled TF function supports only singleton variable creation, and here we are returning a tensor from the generator. Similarly, there is a weird issue on the JAX side where it keeps complaining about the dtype of the `PRNGKey`. Our only options are: …

I will refactor it one more time to see if we can bypass these issues.
This is too general -- I think specifically we should detect when the global seed generator is being used in a tracing context in JAX, and throw an error at that time. It is fine to use …
If we are disabling tracing + …
There are two core issues with this approach. Listing a few of them:

```python
def random_letter():
    return random.choice(["a", "b", "c"])

def random_number():
    return random.randint(0, 10, shape=())

res = random_letter() + str(random_number())
```

Both random functions in the above example are independent of each other, meaning that we should be able to run them in any order. We don't know which function will be executed by the compiler first. If a global state is used, the result will depend on which function was executed first, making it completely irreproducible. Even if we use a stateful seed here, it introduces another sequential dependency. Neither of these aligns with the XLA semantics used by JAX.
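For contrast, a minimal example (mine, not from the thread) of how JAX's explicit-key style removes that ordering dependency:

```python
import jax

def random_number(key):
    return jax.random.randint(key, shape=(), minval=0, maxval=10)

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
# Each call owns its key, so compilation/evaluation order cannot change
# the results -- there is no hidden global state to race on.
a = random_number(k1)
b = random_number(k2)
print(a, b)
```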
One of the ways to use and update state in layers like `Dropout`:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import tree_util


class SeedGenerator:
    def __init__(self, seed):
        self.seed = seed

    @jax.jit
    def get_next_state(self):
        state, sub_state = jax.random.split(self.seed)
        self.seed = state
        return self

    def _tree_flatten(self):
        children = (self.seed,)
        aux_data = {}
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Need to register this as a pytree
tree_util.register_pytree_node(
    SeedGenerator,
    SeedGenerator._tree_flatten,
    SeedGenerator._tree_unflatten,
)


def _get_concrete_noise_shape(inputs, noise_shape):
    if noise_shape is None:
        return inputs.shape
    concrete_inputs_shape = inputs.shape
    concrete_noise_shape = []
    for i, value in enumerate(noise_shape):
        concrete_noise_shape.append(
            concrete_inputs_shape[i] if value is None else value
        )
    return concrete_noise_shape


def dropout(seed, inputs, rate, noise_shape=None):
    state = seed.get_next_state()
    keep_prob = 1.0 - rate
    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
    mask = jax.random.bernoulli(state.seed, p=keep_prob, shape=noise_shape)
    mask = jax.numpy.broadcast_to(mask, inputs.shape)
    out = jax.lax.select(mask, inputs / keep_prob, jax.numpy.zeros_like(inputs))
    return out, state


class DropoutLayer:
    def __init__(self, drop_rate=0.5, noise_shape=None):
        self.drop_rate = drop_rate
        self.noise_shape = noise_shape

    @partial(jax.jit, static_argnums=(0,))
    def __call__(self, inputs, seed):
        outputs, state = dropout(seed, inputs, self.drop_rate, self.noise_shape)
        return outputs, state


inp = jnp.ones(shape=(1, 3), dtype=jnp.float32)
seed = SeedGenerator(jax.random.PRNGKey(42))
layer = DropoutLayer()
for i in range(5):
    out, seed = layer(inp, seed)
    print(out)

# output:
# [[0. 0. 0.]]
# [[2. 2. 2.]]
# [[0. 2. 2.]]
# [[2. 0. 0.]]
# [[2. 2. 0.]]
```
cc: @mattjj, who can provide some more ideas on this.
This PR is an attempt to fix keras-team/keras#18426. A few notable things:

- I have kept `42` as the default seed so that everyone can test it easily on their side with the same setup. Will replace it with a random number once the PR is approved.
- `torch` is a pain to work with when it comes to random generation. Why? Because it consumes and updates the state of the generator implicitly. Had it followed the same algo as TF and JAX, we could have avoided a couple of `if-else` conditions.

Here is an overview of how the seed generation works now:
Random ops

Scenario 1: No seed value is provided

1. Once the random op is called, it will call the `draw_seed(...)` function.
2. `draw_seed(...)` calls the `next(...)` method to get the next state of the global generator and draws samples from this updated rng state.
3. If no global generator exists yet, it will instantiate a `SeedGenerator()` with `seed=None`, which will initialize the seed state with `42`, and will set the global seed generator to this instance.

Scenario 2: Seed value is provided

1. Once this random op is called, it will call the `draw_seed(...)` function.
2. If the seed is an `int`, it will create a `SeedGenerator` instance by calling the `make_initial_seed(...)` function.
3. Or you can instantiate a `SeedGenerator` and pass it.
4. The function will then use the seed value produced by calling the `next(..)` method in the `draw_seed(...)` function.

The `draw_seed()` function will call the `make_initial_seed(seed)` function, which expects an integer argument. Only integer arguments are supported, to make the behavior consistent across backends. This is a stateless call, meaning if you call the function with the same seed, you will get the same results every time.
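To make that flow concrete, here is a hedged pseudocode sketch; the names (`draw_seed`, `make_initial_seed`, `SeedGenerator`) follow the PR description, but the body is simplified and assumes the `SeedGenerator` class with a `next()` method described here:

```python
_global_seed_generator = None

def draw_seed(seed):
    global _global_seed_generator
    if seed is None:
        # Scenario 1: lazily create the global generator (seed state 42)
        # and advance it on every call.
        if _global_seed_generator is None:
            _global_seed_generator = SeedGenerator(seed=None)
        return _global_seed_generator.next()
    if isinstance(seed, int):
        # Scenario 2: an int seed gets a fresh generator per call (via
        # make_initial_seed under the hood) -> the op behaves statelessly.
        return SeedGenerator(seed).next()
    if isinstance(seed, SeedGenerator):
        # A user-provided generator advances its state across calls.
        return seed.next()
    raise ValueError("`seed` must be None, an int, or a `SeedGenerator`.")
```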
Layers

Scenario 1: No seed value is provided

The layer will call `next(...)` on the global seed generator, which uses `42` as the initial state, and it will then draw seed from this global seed generator on subsequent calls.

Scenario 2: Seed value is provided

Will instantiate a `SeedGenerator`, and subsequent calls will draw seed from this generator.
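As an illustration of the layer-side pattern (a hypothetical `NoisyLayer`, assuming the `keras_core.random.SeedGenerator` and `keras_core.random.uniform` APIs; not code from this PR):

```python
from keras_core import layers, ops, random


class NoisyLayer(layers.Layer):
    # Hypothetical layer, only to illustrate the seeded-layer pattern.
    def __init__(self, seed=None):
        super().__init__()
        # One stateful generator per layer; every call advances its state,
        # so repeated calls produce different noise.
        self.seed_generator = random.SeedGenerator(seed)

    def call(self, x):
        noise = random.uniform(shape=ops.shape(x), seed=self.seed_generator)
        return x + noise
```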
Note: The `SeedGenerator` class now depends on the backend-specific state updates of the rng. In the case of `TF` and `JAX` we will split the state at every call, while in torch the state is consumed implicitly. The `next()` method in the `SeedGenerator` is now fully deterministic and depends on the algos used by the backend for managing the state of the rng.

Though I don't think I have missed anything in testing this on my end, it will be good if someone can confirm the same on their end. cc: @fchollet @martin-gorner @GallagherCommaJack