Skip to content

Commit

Permalink
Merge pull request #3354 from chiamp:dunder
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567429834
  • Loading branch information
Flax Authors committed Sep 21, 2023
2 parents 79915d2 + dbc9254 commit d33a33e
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/guides/dropout.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ the training step function. Refer to the
from flax.training import train_state

class TrainState(train_state.TrainState): #!
key: jax.random.KeyArray #!
key: jax.Array #!

state = TrainState.create( #!
apply_fn=my_model.apply,
Expand Down
2 changes: 1 addition & 1 deletion examples/seq2seq/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import numpy as np

Array = jax.Array
PRNGKey = jax.random.KeyArray
PRNGKey = jax.Array
LSTMCarry = Tuple[Array, Array]


Expand Down
2 changes: 1 addition & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@

traceback_util.register_exclusion(__file__)

KeyArray = Union[jax.Array, jax.random.KeyArray] # pylint: disable=invalid-name
KeyArray = jax.Array
RNGSequences = Dict[str, KeyArray]
Array = Any # pylint: disable=invalid-name

Expand Down
8 changes: 4 additions & 4 deletions flax/linen/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from typing_extensions import Protocol

A = TypeVar('A')
PRNGKey = jax.random.KeyArray
PRNGKey = jax.Array
Shape = Tuple[int, ...]
Dtype = Any # this could be a real type?
Array = jax.Array
Expand Down Expand Up @@ -750,7 +750,7 @@ def __call__(
inputs: jax.Array,
*,
initial_carry: Optional[Carry] = None,
init_key: Optional[random.KeyArray] = None,
init_key: Optional[PRNGKey] = None,
seq_lengths: Optional[Array] = None,
return_carry: Optional[bool] = None,
time_major: Optional[bool] = None,
Expand Down Expand Up @@ -976,7 +976,7 @@ def __call__(
inputs: jax.Array,
*,
initial_carry: Optional[Carry] = None,
init_key: Optional[random.KeyArray] = None,
init_key: Optional[PRNGKey] = None,
seq_lengths: Optional[Array] = None,
return_carry: Optional[bool] = None,
time_major: Optional[bool] = None,
Expand All @@ -1000,7 +1000,7 @@ def __call__(
inputs: jax.Array,
*,
initial_carry: Optional[Carry] = None,
init_key: Optional[random.KeyArray] = None,
init_key: Optional[PRNGKey] = None,
seq_lengths: Optional[Array] = None,
return_carry: Optional[bool] = None,
time_major: Optional[bool] = None,
Expand Down
2 changes: 1 addition & 1 deletion flax/linen/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from jax import random
import jax.numpy as jnp

KeyArray = Union[jax.Array, jax.random.KeyArray]
KeyArray = jax.Array


class Dropout(Module):
Expand Down

0 comments on commit d33a33e

Please sign in to comment.