Skip to content

Commit

Permalink
Add propagation of param_dtype to carry initializer.
Browse files Browse the repository at this point in the history
This fixes sudden dtype changes when using jax_enable_x64 with RNNs.

PiperOrigin-RevId: 545219222
  • Loading branch information
Flax Team committed Jul 3, 2023
1 parent b05c673 commit 8efa997
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions flax/linen/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def initialize_carry(
batch_dims = input_shape[:-1]
key1, key2 = random.split(rng)
mem_shape = batch_dims + (self.features,)
c = self.carry_init(key1, mem_shape)
h = self.carry_init(key2, mem_shape)
c = self.carry_init(key1, mem_shape, self.param_dtype)
h = self.carry_init(key2, mem_shape, self.param_dtype)
return (c, h)

@property
Expand Down Expand Up @@ -455,7 +455,7 @@ def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]):
"""
batch_dims = input_shape[:-1]
mem_shape = batch_dims + (self.features,)
return self.carry_init(rng, mem_shape)
return self.carry_init(rng, mem_shape, self.param_dtype)

@property
def num_feature_axes(self) -> int:
Expand Down Expand Up @@ -569,8 +569,8 @@ def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]):
batch_dims = input_shape[:-self.num_feature_axes]
key1, key2 = random.split(rng)
mem_shape = batch_dims + signal_dims + (self.features,)
c = self.carry_init(key1, mem_shape)
h = self.carry_init(key2, mem_shape)
c = self.carry_init(key1, mem_shape, self.param_dtype)
h = self.carry_init(key2, mem_shape, self.param_dtype)
return c, h

@property
Expand Down

0 comments on commit 8efa997

Please sign in to comment.