Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_nlp/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.layers.fnet_encoder import FNetEncoder
from keras_nlp.layers.transformer_decoder import TransformerDecoder
from keras_nlp.layers.transformer_encoder import TransformerEncoder
173 changes: 173 additions & 0 deletions keras_nlp/layers/fnet_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FNet encoder block implementation based on `keras.layers.Layer`."""

import tensorflow as tf
from tensorflow import keras


class FNetEncoder(keras.layers.Layer):
"""FNet encoder.

This class follows the architecture of the FNet encoder layer in the paper
"FNet: Mixing Tokens with Fourier Transforms"
(https://arxiv.org/abs/2105.03824). Users can instantiate multiple instances
of this class to stack up the encoder.

Note on padding: In the official FNet code, padding tokens are added to the
input. However, the padding masks are deleted, i.e., mixing of
all tokens is done. This is because certain frequencies will be zeroed
out if we apply padding masks in every encoder layer. Hence, we don't
take a padding mask as input in the call() function.

Args:
intermediate_dim: int. The hidden size of feedforward network.
dropout: float, defaults to 0. The dropout value, applied in the
feedforward network.
activation: string or `tf.keras.activations`, defaults to "relu". The
activation function of feedforward network.
layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer
normalization components.
kernel_initializer: "string" or `tf.keras.initializers` initializer,
defaults to "glorot_uniform". The kernel initializer for the dense
layers.
bias_initializer: "string" or `tf.keras.initializers` initializer,
defaults to "zeros". The bias initializer for the dense layers.
name: string, defaults to None. The name of the layer.
**kwargs: other keyword arguments.

Examples:

```python
# Create a single FNet encoder layer.
encoder = keras_nlp.layers.FNetEncoder(
intermediate_dim=64)

# Create a simple model containing the encoder.
input = tf.keras.Input(shape=[10, 64])
output = encoder(input)
model = tf.keras.Model(inputs=input, outputs=output)

# Call encoder on the inputs.
input_data = tf.random.uniform(shape=[1, 10, 64])
output = model(input_data)
```

References:
[Lee-Thorp et al., 2021](https://arxiv.org/abs/2105.03824)
"""

def __init__(
self,
intermediate_dim,
dropout=0,
activation="relu",
layer_norm_epsilon=1e-5,
kernel_initializer="glorot_uniform",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this choice of initializer the best practice for this layer?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FNet repo: https://github.com/google-research/google-research/blob/6fd7a1a0872e9d27c1f1764836bbf9048a7903e7/f_net/models.py#L39-L42

I believe it is tf.keras.initializers.RandomNormal(mean=0.0, stddev=2e-2) in Keras. Will change it to this.

Intermediate Dense Layer

kernel_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=2e-2)
bias_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=2e-2)

Output Dense Layer:

kernel_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=2e-2)
bias_initializer = "zeros" # They don't specify the bias initializer in the output dense layer. The default in Flax is "zeros".

https://github.com/google-research/google-research/blob/master/f_net/layers.py#L73-L81

Should I set it as tf.keras.initializers.RandomNormal(mean=0.0, stddev=2e-2) for all, or should I separate the bias initializer for the output dense layer and set that to "zeros"?

bias_initializer="zeros",
name=None,
**kwargs
):
super().__init__(name=name, **kwargs)
self.intermediate_dim = intermediate_dim
self.dropout = dropout
self.activation = keras.activations.get(activation)
self.layer_norm_epsilon = layer_norm_epsilon
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.bias_initializer = keras.initializers.get(bias_initializer)

def build(self, input_shape):
    """Instantiate the sublayers once the input feature width is known.

    Args:
        input_shape: shape of the layer input. Only its last entry (the
            feature dimension) is read, to size the output projection.
    """
    # Both dense layers are configured with the same initializers.
    initializer_kwargs = {
        "kernel_initializer": self.kernel_initializer,
        "bias_initializer": self.bias_initializer,
    }

    # Normalization applied after the mixing step and after the
    # feedforward step, respectively.
    self._mixing_layer_norm = keras.layers.LayerNormalization(
        epsilon=self.layer_norm_epsilon
    )
    self._output_layer_norm = keras.layers.LayerNormalization(
        epsilon=self.layer_norm_epsilon
    )

    # Feedforward network: expand to `intermediate_dim` with the
    # configured activation, then project back to the input feature size.
    self._intermediate_dense = keras.layers.Dense(
        self.intermediate_dim,
        activation=self.activation,
        **initializer_kwargs
    )
    self._output_dense = keras.layers.Dense(
        input_shape[-1],
        **initializer_kwargs
    )
    self._output_dropout = keras.layers.Dropout(rate=self.dropout)

def call(self, inputs):
    """Forward pass of the FNetEncoder.

    Args:
        inputs: a Tensor. The input data to FNetEncoder, should be
            of shape [batch_size, sequence_length, feature_dim].

    Returns:
        A Tensor of the same shape as the `inputs`.
    """
    # Token mixing: apply a 2D FFT and keep only the real part. The FFT op
    # works on complex tensors, so promote the input to complex64 first.
    spectrum = tf.signal.fft2d(tf.cast(inputs, tf.complex64))

    # The real part of a complex64 tensor is float32. Cast it back to the
    # layer's compute dtype so the residual add below does not fail with a
    # dtype mismatch when the layer runs in another precision (for example
    # under a mixed-precision policy). This is a no-op for float32.
    mixing_output = tf.cast(tf.math.real(spectrum), self.compute_dtype)

    # Residual connection followed by layer norm around the mixing step.
    mixing_output = self._mixing_layer_norm(inputs + mixing_output)

    # Feedforward network: intermediate dense -> output dense -> dropout.
    feed_forward_output = self._intermediate_dense(mixing_output)
    feed_forward_output = self._output_dense(feed_forward_output)
    feed_forward_output = self._output_dropout(feed_forward_output)

    # Residual connection followed by layer norm around the feedforward.
    return self._output_layer_norm(mixing_output + feed_forward_output)

def get_config(self):
    """Return the layer configuration as a dictionary.

    The config of the parent class is extended with this layer's
    constructor arguments. The activation and both initializers are
    stored in their serialized form.
    """
    config = super().get_config()
    config["intermediate_dim"] = self.intermediate_dim
    config["dropout"] = self.dropout
    config["activation"] = keras.activations.serialize(self.activation)
    config["layer_norm_epsilon"] = self.layer_norm_epsilon
    config["kernel_initializer"] = keras.initializers.serialize(
        self.kernel_initializer
    )
    config["bias_initializer"] = keras.initializers.serialize(
        self.bias_initializer
    )
    return config
136 changes: 136 additions & 0 deletions keras_nlp/layers/fnet_encoder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for FNet Encoder."""

import os

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers import fnet_encoder


class FNetEncoderTest(tf.test.TestCase):
    """Unit tests for `fnet_encoder.FNetEncoder`."""

    def test_valid_call(self):
        # The encoder should run inside a model and keep the input shape.
        encoder = fnet_encoder.FNetEncoder(intermediate_dim=4)
        model = keras.Sequential(
            [
                keras.Input(shape=(4, 6)),
                encoder,
            ]
        )
        data = tf.random.uniform(shape=[2, 4, 6])
        output = model(data)
        self.assertEqual(output.shape, data.shape)

    def test_get_config_and_from_config(self):
        encoder = fnet_encoder.FNetEncoder(
            intermediate_dim=4,
            kernel_initializer="HeNormal",
            bias_initializer="Zeros",
        )
        config = encoder.get_config()
        expected_config_subset = {
            "intermediate_dim": 4,
            "dropout": 0,
            "activation": "relu",
            "layer_norm_epsilon": 1e-5,
            "kernel_initializer": keras.initializers.serialize(
                keras.initializers.HeNormal()
            ),
            "bias_initializer": keras.initializers.serialize(
                keras.initializers.Zeros()
            ),
        }
        # Overlaying the expected subset on the config must not change it,
        # i.e. every expected key is present with the expected value.
        expected_config = {**config, **expected_config_subset}
        self.assertEqual(config, expected_config)

        # A layer rebuilt from the config must report the same config.
        restored_encoder = fnet_encoder.FNetEncoder.from_config(config)
        self.assertEqual(restored_encoder.get_config(), expected_config)

    def test_value_error_when_invalid_kernel_initializer(self):
        with self.assertRaises(ValueError):
            fnet_encoder.FNetEncoder(
                intermediate_dim=4,
                dropout=0.5,
                kernel_initializer="Invalid",
            )

    def test_one_training_step_of_fnet_encoder(self):
        encoder = fnet_encoder.FNetEncoder(intermediate_dim=4)
        inputs = keras.Input(shape=(4, 6))
        x = encoder(inputs)
        x = keras.layers.Dense(1, activation="sigmoid")(x)
        model = keras.Model(inputs=inputs, outputs=x)

        data = tf.random.uniform(shape=[2, 4, 6])
        label = tf.cast(data[:, :, 0] >= 0.5, dtype=tf.int32)

        loss_fn = keras.losses.BinaryCrossentropy(from_logits=False)
        optimizer = keras.optimizers.Adam()
        with tf.GradientTape() as tape:
            pred = model(data)
            loss = loss_fn(label, pred)
        grad = tape.gradient(loss, model.trainable_variables)
        self.assertGreater(len(grad), 1)
        # Every trainable variable should take part in the loss.
        for g in grad:
            self.assertIsNotNone(g)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))

    def test_checkpointing_fnet_encoder(self):
        encoder1 = fnet_encoder.FNetEncoder(intermediate_dim=4)
        encoder2 = fnet_encoder.FNetEncoder(intermediate_dim=4)
        data = tf.random.uniform(shape=[2, 4, 6])
        # Call both layers once so their variables are created.
        encoder1(data)
        encoder2(data)
        # The weights of encoder1 and encoder2 are different.
        self.assertNotAllClose(
            encoder1._output_dense.trainable_variables[0],
            encoder2._output_dense.trainable_variables[0],
        )

        checkpoint1 = tf.train.Checkpoint(encoder1)
        checkpoint2 = tf.train.Checkpoint(encoder2)
        # `save` takes a file prefix, not a directory. Put the prefix inside
        # the test's temp dir so the checkpoint files are written there
        # rather than next to it.
        save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        save_path = checkpoint1.save(save_prefix)
        checkpoint2.restore(save_path)

        # After restoring, both encoders must produce the same output.
        encoder1_output = encoder1(data)
        encoder2_output = encoder2(data)
        self.assertAllClose(encoder1_output, encoder2_output)

    def test_save_model(self):
        model = keras.Sequential(
            [
                keras.Input(shape=(4, 6)),
                fnet_encoder.FNetEncoder(
                    intermediate_dim=4,
                ),
            ]
        )
        data = tf.random.uniform(shape=[2, 4, 6])
        model(data)
        path = os.path.join(self.get_temp_dir(), "model")
        model.save(path)
        loaded_model = keras.models.load_model(path)

        # The reloaded model must reproduce the original model's output.
        model_output = model(data)
        loaded_model_output = loaded_model(data)
        self.assertAllClose(model_output, loaded_model_output)