Skip to content

Commit

Permalink
Add SimpleCell.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Feb 14, 2024
1 parent 6181239 commit dc9dbdb
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/api_reference/flax.linen/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ Recurrent
:module: flax.linen
:class: OptimizedLSTMCell

.. flax_module::
:module: flax.linen
:class: SimpleCell

.. flax_module::
:module: flax.linen
:class: GRUCell
Expand Down Expand Up @@ -168,6 +172,7 @@ BatchApply
RNNCellBase
LSTMCell
OptimizedLSTMCell
SimpleCell
GRUCell
RNN
Bidirectional
Expand Down
1 change: 1 addition & 0 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
from .recurrent import (
Bidirectional as Bidirectional,
ConvLSTMCell as ConvLSTMCell,
SimpleCell as SimpleCell,
GRUCell as GRUCell,
MGUCell as MGUCell,
LSTMCell as LSTMCell,
Expand Down
107 changes: 107 additions & 0 deletions flax/linen/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,113 @@ def num_feature_axes(self) -> int:
return 1


class SimpleCell(RNNCellBase):
r"""Simple cell.
The mathematical definition of the cell is as follows
.. math::
\begin{array}{ll}
h' = \tanh(W_i x + b + W_h h)
\end{array}
where x is the input and h, is the output of the previous time step.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.SimpleCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
Attributes:
features: number of output features.
activation_fn: activation function used for output and memory update
(default: tanh).
kernel_init: initializer function for the kernels that transform
the input (default: lecun_normal).
recurrent_kernel_init: initializer function for the kernels that transform
the hidden state (default: initializers.orthogonal()).
bias_init: initializer for the bias parameters (default: initializers.zeros_init())
dtype: the dtype of the computation (default: None).
param_dtype: the dtype passed to parameter initializers (default: float32).
residual: pre-activation residual connection (https://arxiv.org/abs/1801.06105).
"""

features: int
activation_fn: Callable[..., Any] = tanh
kernel_init: Initializer = default_kernel_init
recurrent_kernel_init: Initializer = initializers.orthogonal()
bias_init: Initializer = initializers.zeros_init()
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
carry_init: Initializer = initializers.zeros_init()
residual: bool = False

@compact
def __call__(self, carry, inputs):
"""Simple cell.
Args:
carry: the hidden state of the Simple cell,
initialized using ``SimpleCell.initialize_carry``.
inputs: an ndarray with the input for the current time step.
All dimensions except the final are considered batch dimensions.
Returns:
A tuple with the new carry and the output.
"""
hidden_features = carry.shape[-1]
# input and recurrent layers are summed so only one needs a bias.
dense_h = partial(
Dense,
features=hidden_features,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.recurrent_kernel_init,
bias_init=self.bias_init,
)
dense_i = partial(
Dense,
features=hidden_features,
use_bias=True,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)
new_carry = dense_i(name='i')(inputs) + dense_h(name='h')(carry)
if self.residual:
new_carry += carry
new_carry = self.activation_fn(new_carry)
return new_carry, new_carry

@nowrap
def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]):
"""Initialize the RNN cell carry.
Args:
rng: random number generator passed to the init_fn.
input_shape: a tuple providing the shape of the input to the cell.
Returns:
An initialized carry for the given RNN cell.
"""
batch_dims = input_shape[:-1]
mem_shape = batch_dims + (self.features,)
return self.carry_init(rng, mem_shape, self.param_dtype)

@property
def num_feature_axes(self) -> int:
return 1


class GRUCell(RNNCellBase):
r"""GRU cell.
Expand Down
11 changes: 10 additions & 1 deletion tests/linen/linen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,13 @@ def test_lstm(self):
)

@parameterized.parameters(
{
'module_cls': nn.SimpleCell,
'expected_param_shapes': {
'i': {'kernel': (3, 4), 'bias': (4,)},
'h': {'kernel': (4, 4)},
},
},
{
'module_cls': nn.GRUCell,
'expected_param_shapes': {
Expand Down Expand Up @@ -1068,7 +1075,9 @@ def test_gated_units(self, module_cls, expected_param_shapes):
)

@parameterized.parameters(
{'module_cls': nn.GRUCell}, {'module_cls': nn.MGUCell}
{'module_cls': nn.SimpleCell},
{'module_cls': nn.GRUCell},
{'module_cls': nn.MGUCell},
)
def test_complex_input_gated_units(self, module_cls):
module_instance = module_cls(features=4)
Expand Down

0 comments on commit dc9dbdb

Please sign in to comment.