# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# File: keras_nlp/layers/fnet_encoder.py
#
# NOTE: the same patch also exports this layer from
# keras_nlp/layers/__init__.py by adding, ahead of the existing
# TransformerDecoder / TransformerEncoder imports:
#     from keras_nlp.layers.fnet_encoder import FNetEncoder

"""FNet encoder block implementation based on `keras.layers.Layer`."""

import tensorflow as tf
from tensorflow import keras


class FNetEncoder(keras.layers.Layer):
    """FNet encoder.

    This class follows the architecture of FNet encoder layer in paper
    "FNet: Mixing Tokens with Fourier Transforms"
    (https://arxiv.org/abs/2105.03824). Users can instantiate multiple instances
    of this class to stack up the encoder.

    Note on padding: In the official FNet code, padding tokens are added to
    the input. However, the padding masks are deleted, i.e., mixing of
    all tokens is done. This is because certain frequencies will be zeroed
    out if we apply padding masks in every encoder layer. Hence, we don't
    take padding mask as input in the call() function.

    Args:
        intermediate_dim: int. The hidden size of feedforward network.
        dropout: float, defaults to 0. The dropout value, applied in the
            feedforward network.
        activation: string or `tf.keras.activations`, defaults to "relu". The
            activation function of feedforward network.
        layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer
            normalization components.
        kernel_initializer: "string" or `tf.keras.initializers` initializer,
            defaults to "glorot_uniform". The kernel initializer for the dense
            layers.
        bias_initializer: "string" or `tf.keras.initializers` initializer,
            defaults to "zeros". The bias initializer for the dense layers.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments.

    Examples:

    ```python
    # Create a single FNet encoder layer.
    encoder = keras_nlp.layers.FNetEncoder(
        intermediate_dim=64)

    # Create a simple model containing the encoder.
    input = tf.keras.Input(shape=[10, 64])
    output = encoder(input)
    model = tf.keras.Model(inputs=input, outputs=output)

    # Call encoder on the inputs.
    input_data = tf.random.uniform(shape=[1, 10, 64])
    output = model(input_data)
    ```

    References:
        [Lee-Thorp et al., 2021](https://arxiv.org/abs/2105.03824)
    """

    def __init__(
        self,
        intermediate_dim,
        dropout=0,
        activation="relu",
        layer_norm_epsilon=1e-5,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        name=None,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        # Resolve string identifiers eagerly so an invalid name fails at
        # construction time rather than on the first call.
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)

    def build(self, input_shape):
        """Create the sublayers once the feature dimension is known.

        Args:
            input_shape: shape of the layer input; its last entry is used as
                the output width of the feedforward network so that the
                residual connection is shape-compatible.
        """
        feature_size = input_shape[-1]

        # Layer Norm layers.
        self._mixing_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon
        )
        self._output_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon
        )

        # Feedforward layers.
        self._intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            activation=self.activation,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self._output_dense = keras.layers.Dense(
            feature_size,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self._output_dropout = keras.layers.Dropout(rate=self.dropout)

    def _fourier_transform(self, inputs):
        """Return the real part of the 2D FFT of `inputs`, in `inputs.dtype`.

        `tf.signal.fft2d` transforms the two innermost dimensions and only
        accepts complex tensors, so the input is cast to a complex dtype
        first. The real part is then cast back to the dtype of `inputs`;
        without this the residual addition in `call()` would mix dtypes
        whenever the layer does not compute in float32 (e.g. under a
        `mixed_float16` policy or with `dtype="float64"`).
        """
        # Use double-precision complex numbers only when the input itself is
        # double precision; everything else goes through complex64.
        complex_dtype = (
            tf.complex128 if inputs.dtype == tf.float64 else tf.complex64
        )
        spectrum = tf.signal.fft2d(tf.cast(inputs, complex_dtype))
        return tf.cast(tf.math.real(spectrum), inputs.dtype)

    def _feed_forward(self, inputs):
        """Apply the two dense layers followed by dropout."""
        x = self._intermediate_dense(inputs)
        x = self._output_dense(x)
        return self._output_dropout(x)

    def call(self, inputs):
        """Forward pass of the FNetEncoder.

        Args:
            inputs: a Tensor. The input data to FNetEncoder, should be
                of shape [batch_size, sequence_length, feature_dim].

        Returns:
            A Tensor of the same shape as the `inputs`.
        """
        # Token mixing sublayer: Fourier transform + residual + layer norm.
        mixing_output = self._fourier_transform(inputs)
        mixing_output = self._mixing_layer_norm(inputs + mixing_output)

        # Feedforward sublayer: dense network + residual + layer norm.
        feed_forward_output = self._feed_forward(mixing_output)
        return self._output_layer_norm(mixing_output + feed_forward_output)

    def get_config(self):
        """Return the constructor arguments in serializable form."""
        config = super().get_config()
        config.update(
            {
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
            }
        )
        return config


# File: keras_nlp/layers/fnet_encoder_test.py
#
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for FNet Encoder."""

import os

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers import fnet_encoder


class FNetEncoderTest(tf.test.TestCase):
    """Call, config, training, checkpoint and saving tests for FNetEncoder."""

    def test_valid_call(self):
        """The layer can be called inside a Sequential model."""
        layer = fnet_encoder.FNetEncoder(intermediate_dim=4)
        model = keras.Sequential([keras.Input(shape=(4, 6)), layer])
        batch = tf.random.uniform(shape=[2, 4, 6])
        model(batch)

    def test_get_config_and_from_config(self):
        """`get_config()` holds the expected keys and round-trips."""
        layer = fnet_encoder.FNetEncoder(
            intermediate_dim=4,
            kernel_initializer="HeNormal",
            bias_initializer="Zeros",
        )
        config = layer.get_config()

        he_normal_config = keras.initializers.serialize(
            keras.initializers.HeNormal()
        )
        zeros_config = keras.initializers.serialize(
            keras.initializers.Zeros()
        )
        expected_subset = {
            "intermediate_dim": 4,
            "dropout": 0,
            "activation": "relu",
            "layer_norm_epsilon": 1e-5,
            "kernel_initializer": he_normal_config,
            "bias_initializer": zeros_config,
        }
        # Overlaying the expected entries must not change the config, i.e.
        # every expected key is present with the expected value.
        self.assertEqual(config, {**config, **expected_subset})

        restored = fnet_encoder.FNetEncoder.from_config(
            config,
        )
        restored_config = restored.get_config()
        self.assertEqual(restored_config, {**config, **expected_subset})

    def test_value_error_when_invalid_kernel_initializer(self):
        """An unknown initializer name is rejected at construction time."""
        bad_kwargs = {
            "intermediate_dim": 4,
            "dropout": 0.5,
            "kernel_initializer": "Invalid",
        }
        with self.assertRaises(ValueError):
            fnet_encoder.FNetEncoder(**bad_kwargs)

    def test_one_training_step_of_fnet_encoder(self):
        """One optimizer step runs end to end through the layer."""
        layer = fnet_encoder.FNetEncoder(intermediate_dim=4)
        model_input = keras.Input(shape=(4, 6))
        encoded = layer(model_input)
        prob = keras.layers.Dense(1, activation="sigmoid")(encoded)
        model = keras.Model(inputs=model_input, outputs=prob)

        features = tf.random.uniform(shape=[2, 4, 6])
        labels = tf.cast(features[:, :, 0] >= 0.5, dtype=tf.int32)

        loss_fn = keras.losses.BinaryCrossentropy(from_logits=False)
        optimizer = keras.optimizers.Adam()
        with tf.GradientTape() as tape:
            predictions = model(features)
            loss = loss_fn(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        self.assertGreater(len(gradients), 1)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    def test_checkpointing_fnet_encoder(self):
        """Restoring a checkpoint copies weights between two layers."""
        source = fnet_encoder.FNetEncoder(
            intermediate_dim=4,
        )
        target = fnet_encoder.FNetEncoder(
            intermediate_dim=4,
        )
        data = tf.random.uniform(shape=[2, 4, 6])
        # Call both layers once so that their weights are created.
        source(data)
        target(data)

        # The weights of the two freshly built layers are different.
        source_row = source._output_dense.trainable_variables[0][0]
        target_row = target._output_dense.trainable_variables[0][0]
        self.assertFalse(all(source_row == target_row))

        source_checkpoint = tf.train.Checkpoint(source)
        target_checkpoint = tf.train.Checkpoint(target)
        save_path = source_checkpoint.save(self.get_temp_dir())
        target_checkpoint.restore(save_path)

        source_output = source(data)
        target_output = target(data)
        self.assertAllClose(source_output, target_output)

    def test_save_model(self):
        """A saved and reloaded model reproduces the original outputs."""
        layer = fnet_encoder.FNetEncoder(
            intermediate_dim=4,
        )
        model = keras.Sequential([keras.Input(shape=(4, 6)), layer])
        data = tf.random.uniform(shape=[2, 4, 6])
        model(data)

        save_dir = os.path.join(self.get_temp_dir(), "model")
        model.save(save_dir)
        reloaded = keras.models.load_model(save_dir)

        expected = model(data)
        actual = reloaded(data)
        self.assertAllClose(expected, actual)