Add SimpleCell. #3697

Merged
merged 1 commit on Feb 22, 2024
5 changes: 5 additions & 0 deletions docs/api_reference/flax.linen/layers.rst
@@ -118,6 +118,10 @@ Recurrent
:module: flax.linen
:class: OptimizedLSTMCell

.. flax_module::
:module: flax.linen
:class: SimpleCell

.. flax_module::
:module: flax.linen
:class: GRUCell
@@ -168,6 +172,7 @@ BatchApply
RNNCellBase
LSTMCell
OptimizedLSTMCell
SimpleCell
GRUCell
RNN
Bidirectional
1 change: 1 addition & 0 deletions flax/linen/__init__.py
@@ -129,6 +129,7 @@
from .recurrent import (
Bidirectional as Bidirectional,
ConvLSTMCell as ConvLSTMCell,
SimpleCell as SimpleCell,
GRUCell as GRUCell,
MGUCell as MGUCell,
LSTMCell as LSTMCell,
118 changes: 116 additions & 2 deletions flax/linen/recurrent.py
@@ -392,6 +392,120 @@ def num_feature_axes(self) -> int:
return 1


class SimpleCell(RNNCellBase):
r"""Simple cell.

The mathematical definition of the cell is as follows

.. math::

\begin{array}{ll}
h' = \tanh(W_i x + b_i + W_h h)
\end{array}

where x is the input and h is the output of the previous time step.

If ``residual`` is ``True``,

.. math::

\begin{array}{ll}
h' = \tanh(W_i x + b_i + W_h h + h)
\end{array}

Example usage::

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.SimpleCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)

Attributes:
features: number of output features.
activation_fn: activation function used to compute the new hidden state
(default: tanh).
kernel_init: initializer function for the kernels that transform
the input (default: lecun_normal).
recurrent_kernel_init: initializer function for the kernels that transform
the hidden state (default: initializers.orthogonal()).
bias_init: initializer for the bias parameters (default: initializers.zeros_init()).
dtype: the dtype of the computation (default: None).
param_dtype: the dtype passed to parameter initializers (default: float32).
carry_init: initializer for the carry state (default: initializers.zeros_init()).
residual: pre-activation residual connection (https://arxiv.org/abs/1801.06105).
"""

features: int
activation_fn: Callable[..., Any] = tanh
kernel_init: Initializer = default_kernel_init
recurrent_kernel_init: Initializer = initializers.orthogonal()
bias_init: Initializer = initializers.zeros_init()
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
carry_init: Initializer = initializers.zeros_init()
residual: bool = False

@compact
def __call__(self, carry, inputs):
"""Simple cell.

Args:
carry: the hidden state of the Simple cell,
initialized using ``SimpleCell.initialize_carry``.
inputs: an ndarray with the input for the current time step.
All dimensions except the final are considered batch dimensions.

Returns:
A tuple with the new carry and the output.
"""
hidden_features = carry.shape[-1]
# input and recurrent layers are summed so only one needs a bias.
dense_h = partial(
Dense,
features=hidden_features,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.recurrent_kernel_init,
)
dense_i = partial(
Dense,
features=hidden_features,
use_bias=True,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)
new_carry = dense_i(name='i')(inputs) + dense_h(name='h')(carry)
if self.residual:
new_carry += carry
new_carry = self.activation_fn(new_carry)
return new_carry, new_carry

@nowrap
def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]):
"""Initialize the RNN cell carry.

Args:
rng: random number generator passed to the init_fn.
input_shape: a tuple providing the shape of the input to the cell.

Returns:
An initialized carry for the given RNN cell.
"""
batch_dims = input_shape[:-1]
mem_shape = batch_dims + (self.features,)
return self.carry_init(rng, mem_shape, self.param_dtype)

@property
def num_feature_axes(self) -> int:
return 1


class GRUCell(RNNCellBase):
r"""GRU cell.

@@ -406,7 +520,7 @@ class GRUCell(RNNCellBase):
h' = (1 - z) * n + z * h \\
\end{array}

- where x is the input and h, is the output of the previous time step.
+ where x is the input and h is the output of the previous time step.

Example usage::

@@ -519,7 +633,7 @@ class MGUCell(RNNCellBase):
h' = (1 - f) * n + f * h \\
\end{array}

- where x is the input and h, is the output of the previous time step.
+ where x is the input and h is the output of the previous time step.

Example usage::

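The recurrence implemented by ``SimpleCell.__call__`` is compact enough to restate outside Flax. The following is a minimal pure-JAX sketch of one step, not part of this PR; the parameter names (``W_i``, ``b_i``, ``W_h``) are illustrative, and the shapes are chosen to match the parameterized test below (input size 3, ``features=4``)::

    import jax
    import jax.numpy as jnp

    def simple_cell_step(params, h, x, residual=False):
      # h' = tanh(W_i x + b_i + W_h h); with residual=True, add h before the tanh.
      pre = x @ params['W_i'] + params['b_i'] + h @ params['W_h']
      if residual:
        pre = pre + h
      new_h = jnp.tanh(pre)
      return new_h, new_h  # carry and output are the same array

    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
    params = {
      'W_i': jax.random.normal(k1, (3, 4)),  # input kernel: (in_features, features)
      'W_h': jax.random.normal(k2, (4, 4)),  # recurrent kernel: (features, features)
      'b_i': jnp.zeros((4,)),                # bias: (features,)
    }
    h0 = jnp.zeros((2, 4))                   # zero-initialized carry for a batch of 2
    x = jax.random.normal(k3, (2, 3))
    carry, out = simple_cell_step(params, h0, x)

Note that only the input projection carries a bias: as the comment in ``__call__`` says, the two projections are summed, so a second bias would be redundant.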
11 changes: 10 additions & 1 deletion tests/linen/linen_test.py
@@ -1020,6 +1020,13 @@ def test_lstm(self):
)

@parameterized.parameters(
{
'module_cls': nn.SimpleCell,
'expected_param_shapes': {
'i': {'kernel': (3, 4), 'bias': (4,)},
'h': {'kernel': (4, 4)},
},
},
{
'module_cls': nn.GRUCell,
'expected_param_shapes': {
@@ -1068,7 +1075,9 @@ def test_gated_units(self, module_cls, expected_param_shapes):
)

@parameterized.parameters(
{'module_cls': nn.GRUCell}, {'module_cls': nn.MGUCell}
{'module_cls': nn.SimpleCell},
{'module_cls': nn.GRUCell},
{'module_cls': nn.MGUCell},
)
def test_complex_input_gated_units(self, module_cls):
module_instance = module_cls(features=4)
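As a usage note beyond this PR's single-step tests: the new cell composes with the existing ``nn.RNN`` wrapper to unroll over a time axis. A short sketch, assuming the ``nn.RNN`` API in current ``flax.linen``, with shapes mirroring the test above (input size 3, ``features=4``)::

    import flax.linen as nn
    import jax, jax.numpy as jnp

    x = jax.random.normal(jax.random.key(0), (2, 5, 3))  # (batch, time, in_features)
    rnn = nn.RNN(nn.SimpleCell(features=4))
    variables = rnn.init(jax.random.key(1), x)            # carry is initialized internally
    y = rnn.apply(variables, x)                           # output shape (2, 5, 4)

The resulting parameter tree contains the same shapes the parameterized test asserts: a (3, 4) input kernel with a (4,) bias under ``'i'``, and a (4, 4) recurrent kernel under ``'h'``.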