From 59cc53cbea715426263f95714a5688d78f4d47ca Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Tue, 15 Mar 2022 18:09:02 +0530 Subject: [PATCH 01/13] Add rough code for FNet Encoder --- keras_nlp/layers/fnet_encoder.py | 137 +++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 keras_nlp/layers/fnet_encoder.py diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py new file mode 100644 index 0000000000..84b1edeaed --- /dev/null +++ b/keras_nlp/layers/fnet_encoder.py @@ -0,0 +1,137 @@ +"""Transformer encoder block implementation based on `keras.layers.Layer`.""" + +import tensorflow as tf + +from tensorflow import keras + + +class FNetEncoder(keras.layers.Layer): + """FNet encoder. + + This class follows the architecture of FNet encoder layer in paper + "FNet: Mixing Tokens with Fourier Transforms" + (https://arxiv.org/abs/2105.03824). Users can instantiate multiple instances + of this class to stack up the encoder. + + Args: + intermediate_dim: int, defaults to 3072. The hidden size of feedforward + network. + dropout: float, defaults to 0.1. the dropout value, applied in the + feedforward network. + activation: string or `tf.keras.activations`, defaults to "gelu". The + activation function of feedforward network. + layer_norm_epsilon: float, defaults to 1e-12. The epsilon value in layer + normalization components. + name: string, defaults to None. The name of the layer. + **kwargs: other keyword arguments. + + Examples: + + ```python + # Create a single FNet encoder layer. + encoder = keras_nlp.layer.FNetEncoder( + intermediate_dim=64) + + # Create a simple model containing the encoder. + input = tf.keras.Input(shape=[4, 6]) + output = encoder(input) + model = tf.keras.Model(inputs=input, outputs=output) + + # Call encoder on the inputs. + input_data = tf.random.uniform(shape=[1, 10, 64]) + output = model(input_data) + + ``` + + References: + [Lee-Thorp et al., 2021](https://arxiv.org/abs/2105.03824) + """ + + def __init__( + self, + intermediate_dim=3072, + dropout=0.1, + activation="gelu", + layer_norm_epsilon=1e-12, + name=None, + **kwargs + ): + super().__init__(name=name, **kwargs) + self.intermediate_dim = intermediate_dim + self.dropout = dropout + self.activation = activation + self.layer_norm_epsilon = layer_norm_epsilon + self._built = False + + def _build(self, input_shape): + # Create layers based on input shape. + self._built = True + feature_size = input_shape[-1] + + # Layer Norm layers. + self._mixing_layer_norm = keras.layers.LayerNormalization(epsilon=self.layer_norm_epsilon) + self._output_layer_norm = keras.layers.LayerNormalization(epsilon=self.layer_norm_epsilon) + + # Feedforward layer. + self._intermediate_dense = keras.layers.Dense( + self.intermediate_dim, activation=self.activation + ) + self._output_dense = keras.layers.Dense(feature_size) + self._output_dropout = keras.layers.Dropout(rate=self.dropout) + + def _fourier_transform(self, input): + # Apply FFT on the input and take the real part. Before we apply fourier + # transform, let's convert the dtype of the input tensor to complex64. + input = tf.cast(input, tf.complex64) + mixing_output = tf.math.real(tf.signal.fft2d(input)) + return mixing_output + + def _add_and_norm(self, input1, input2, norm_layer): + return norm_layer(input1 + input2) + + def _feed_forward(self, input): + x = self._intermediate_dense(input) + x = self._output_dense(x) + return self._output_dropout(x) + + def call(self, inputs): + """Forward pass of the FNetEncoder. 
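+
+        Applies the Fourier mixing sublayer followed by the feedforward
+        sublayer, each with a residual connection and layer normalization.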
+ + Args: + inputs: a Tensor. The input data to TransformerEncoder, should be + of shape [batch_size, sequence_length, feature_dim]. + + Returns: + A Tensor of the same shape as the `inputs`. + """ + + if not self._built: + self._build(inputs.shape) + + # Apply fourier transform on the input. Note: We don't have padding + # tokens in the official FNet code. + # https://github.com/google-research/google-research/blob/master/f_net/layers.py#L137 + mixing_output = self._fourier_transform(inputs) + + # LayerNorm layer. + mixing_output = self._add_and_norm(inputs, mixing_output, + self._mixing_layer_norm) + + # Feedforward layer. + feed_forward_output = self._feed_forward(mixing_output) + + # LayerNorm layer. + x = self._add_and_norm(mixing_output, feed_forward_output, self._output_layer_norm) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "dropout": self.dropout, + "activation": self.activation, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config From e520e4d792940d02960caa845182e957a1863467 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Tue, 15 Mar 2022 18:32:52 +0530 Subject: [PATCH 02/13] Format code --- keras_nlp/layers/fnet_encoder.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index 84b1edeaed..6db93e3b8d 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -1,7 +1,20 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Transformer encoder block implementation based on `keras.layers.Layer`.""" import tensorflow as tf - from tensorflow import keras @@ -69,8 +82,12 @@ def _build(self, input_shape): feature_size = input_shape[-1] # Layer Norm layers. - self._mixing_layer_norm = keras.layers.LayerNormalization(epsilon=self.layer_norm_epsilon) - self._output_layer_norm = keras.layers.LayerNormalization(epsilon=self.layer_norm_epsilon) + self._mixing_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon + ) + self._output_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon + ) # Feedforward layer. self._intermediate_dense = keras.layers.Dense( @@ -114,14 +131,17 @@ def call(self, inputs): mixing_output = self._fourier_transform(inputs) # LayerNorm layer. - mixing_output = self._add_and_norm(inputs, mixing_output, - self._mixing_layer_norm) + mixing_output = self._add_and_norm( + inputs, mixing_output, self._mixing_layer_norm + ) # Feedforward layer. feed_forward_output = self._feed_forward(mixing_output) # LayerNorm layer. 
- x = self._add_and_norm(mixing_output, feed_forward_output, self._output_layer_norm) + x = self._add_and_norm( + mixing_output, feed_forward_output, self._output_layer_norm + ) return x def get_config(self): From a005ae2296a23a94224f5e7775371e45ba53a24b Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Tue, 15 Mar 2022 18:49:23 +0530 Subject: [PATCH 03/13] Minor doc-string changes --- keras_nlp/layers/__init__.py | 1 + keras_nlp/layers/fnet_encoder.py | 4 ++-- keras_nlp/layers/transformer_decoder.py | 2 +- keras_nlp/layers/transformer_encoder.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/keras_nlp/layers/__init__.py b/keras_nlp/layers/__init__.py index ce654b8d05..74c8ef8f0b 100644 --- a/keras_nlp/layers/__init__.py +++ b/keras_nlp/layers/__init__.py @@ -14,3 +14,4 @@ from keras_nlp.layers.transformer_decoder import TransformerDecoder from keras_nlp.layers.transformer_encoder import TransformerEncoder +from keras_nlp.layers.fnet_encoder import FNetEncoder diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index 6db93e3b8d..aa0b86f00c 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Transformer encoder block implementation based on `keras.layers.Layer`.""" +"""FNet encoder block implementation based on `keras.layers.Layer`.""" import tensorflow as tf from tensorflow import keras @@ -42,7 +42,7 @@ class FNetEncoder(keras.layers.Layer): ```python # Create a single FNet encoder layer. - encoder = keras_nlp.layer.FNetEncoder( + encoder = keras_nlp.layers.FNetEncoder( intermediate_dim=64) # Create a simple model containing the encoder. diff --git a/keras_nlp/layers/transformer_decoder.py b/keras_nlp/layers/transformer_decoder.py index 5bea0697a0..1d1cc11a77 100644 --- a/keras_nlp/layers/transformer_decoder.py +++ b/keras_nlp/layers/transformer_decoder.py @@ -45,7 +45,7 @@ class TransformerDecoder(keras.layers.Layer): Examples: ```python # Create a single transformer decoder layer. - decoder = keras_nlp.layer.TransformerDecoder( + decoder = keras_nlp.layers.TransformerDecoder( intermediate_dim=64, num_heads=8) # Create a simple model containing the decoder. diff --git a/keras_nlp/layers/transformer_encoder.py b/keras_nlp/layers/transformer_encoder.py index 4cc04debab..48a017af67 100644 --- a/keras_nlp/layers/transformer_encoder.py +++ b/keras_nlp/layers/transformer_encoder.py @@ -44,7 +44,7 @@ class TransformerEncoder(keras.layers.Layer): ```python # Create a single transformer decoder layer. - encoder = keras_nlp.layer.TransformerEncoder( + encoder = keras_nlp.layers.TransformerEncoder( intermediate_dim=64, num_heads=8) # Create a simple model containing the decoder. From a5575ffa3e4af98991b2bb0bacc8b88a2f7b2796 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Tue, 15 Mar 2022 19:10:30 +0530 Subject: [PATCH 04/13] Format __init__.py --- keras_nlp/layers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/layers/__init__.py b/keras_nlp/layers/__init__.py index 74c8ef8f0b..0397ce6be2 100644 --- a/keras_nlp/layers/__init__.py +++ b/keras_nlp/layers/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from keras_nlp.layers.fnet_encoder import FNetEncoder from keras_nlp.layers.transformer_decoder import TransformerDecoder from keras_nlp.layers.transformer_encoder import TransformerEncoder -from keras_nlp.layers.fnet_encoder import FNetEncoder From c75e09f72c437639c3a569b08cdb7d83993c1f6d Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Sun, 20 Mar 2022 15:45:56 +0530 Subject: [PATCH 05/13] Address review comments - 1 --- keras_nlp/layers/fnet_encoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index aa0b86f00c..4e685d544a 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -29,7 +29,7 @@ class FNetEncoder(keras.layers.Layer): Args: intermediate_dim: int, defaults to 3072. The hidden size of feedforward network. - dropout: float, defaults to 0.1. the dropout value, applied in the + dropout: float, defaults to 0.1. The dropout value, applied in the feedforward network. activation: string or `tf.keras.activations`, defaults to "gelu". The activation function of feedforward network. @@ -97,8 +97,9 @@ def _build(self, input_shape): self._output_dropout = keras.layers.Dropout(rate=self.dropout) def _fourier_transform(self, input): - # Apply FFT on the input and take the real part. Before we apply fourier - # transform, let's convert the dtype of the input tensor to complex64. + # Apply FFT on the input and take the real part. + # Before we apply fourier transform, let's convert the dtype of the + # input tensor to complex64. input = tf.cast(input, tf.complex64) mixing_output = tf.math.real(tf.signal.fft2d(input)) return mixing_output From 6d850739d1d996360a13bf1e2a986f1506b644cb Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Tue, 22 Mar 2022 08:40:40 +0530 Subject: [PATCH 06/13] Add detailed comment about padding masks --- keras_nlp/layers/fnet_encoder.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index 4e685d544a..b0e5471488 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -53,7 +53,6 @@ class FNetEncoder(keras.layers.Layer): # Call encoder on the inputs. input_data = tf.random.uniform(shape=[1, 10, 64]) output = model(input_data) - ``` References: @@ -126,8 +125,13 @@ def call(self, inputs): if not self._built: self._build(inputs.shape) - # Apply fourier transform on the input. Note: We don't have padding - # tokens in the official FNet code. + # Apply fourier transform on the input. + # Note: In the official FNet code, padding tokens are added to the + # the input. However, the padding masks are deleted, i.e., mixing of + # all tokens is done. This is because certain frequencies will be zeroed + # out if we apply padding masks in every encoder layer. 
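+        # Consequently, unlike TransformerEncoder, this layer's call() does
+        # not accept a padding mask argument.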
+ # Code references: + # https://github.com/google-research/google-research/blob/master/f_net/input_pipeline.py#L107-L109 # https://github.com/google-research/google-research/blob/master/f_net/layers.py#L137 mixing_output = self._fourier_transform(inputs) From 63b7f22c406acd7adbdc303ff60c79305a622b89 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 23 Mar 2022 01:33:39 +0530 Subject: [PATCH 07/13] Add kernel and bias initialisers --- keras_nlp/layers/fnet_encoder.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index b0e5471488..c0871c1ffb 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -35,6 +35,10 @@ class FNetEncoder(keras.layers.Layer): activation function of feedforward network. layer_norm_epsilon: float, defaults to 1e-12. The epsilon value in layer normalization components. + kernel_initializer: tf.keras.initializers initializer, defaults to + "glorot_uniform". The kernel initializer for the dense layers. + bias_initializer: tf.keras.initializers initializer, defaults to + "zeros". The bias initializer for the dense layers. name: string, defaults to None. The name of the layer. **kwargs: other keyword arguments. @@ -65,14 +69,18 @@ def __init__( dropout=0.1, activation="gelu", layer_norm_epsilon=1e-12, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", name=None, **kwargs ): super().__init__(name=name, **kwargs) self.intermediate_dim = intermediate_dim self.dropout = dropout - self.activation = activation + self.activation = keras.activations.get(activation) self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) self._built = False def _build(self, input_shape): @@ -90,9 +98,16 @@ def _build(self, input_shape): # Feedforward layer. 
self._intermediate_dense = keras.layers.Dense( - self.intermediate_dim, activation=self.activation + self.intermediate_dim, + activation=self.activation, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + ) + self._output_dense = keras.layers.Dense( + feature_size, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, ) - self._output_dense = keras.layers.Dense(feature_size) self._output_dropout = keras.layers.Dropout(rate=self.dropout) def _fourier_transform(self, input): @@ -155,8 +170,14 @@ def get_config(self): { "intermediate_dim": self.intermediate_dim, "dropout": self.dropout, - "activation": self.activation, + "activation": keras.activations.serialize(self.activation), "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), } ) return config From debafb7315dff5632ba87e9d835e343fa26f8974 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 23 Mar 2022 01:57:59 +0530 Subject: [PATCH 08/13] Add unit tests for the layer --- keras_nlp/layers/fnet_encoder_test.py | 136 ++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 keras_nlp/layers/fnet_encoder_test.py diff --git a/keras_nlp/layers/fnet_encoder_test.py b/keras_nlp/layers/fnet_encoder_test.py new file mode 100644 index 0000000000..fd43dd31a9 --- /dev/null +++ b/keras_nlp/layers/fnet_encoder_test.py @@ -0,0 +1,136 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for FNet Encoder.""" + +import os + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.layers import fnet_encoder + + +class FNetEncoderTest(tf.test.TestCase): + def test_valid_call(self): + encoder = fnet_encoder.FNetEncoder(intermediate_dim=4) + model = keras.Sequential( + [ + keras.Input(shape=(4, 6)), + encoder, + ] + ) + input = tf.random.uniform(shape=[2, 4, 6]) + model(input) + + def test_get_config_and_from_config(self): + encoder = fnet_encoder.FNetEncoder( + intermediate_dim=4, + kernel_initializer="HeNormal", + bias_initializer="Zeros", + ) + config = encoder.get_config() + expected_config_subset = { + "intermediate_dim": 4, + "dropout": 0.1, + "activation": "gelu", + "layer_norm_epsilon": 1e-12, + "kernel_initializer": keras.initializers.serialize( + keras.initializers.HeNormal() + ), + "bias_initializer": keras.initializers.serialize( + keras.initializers.Zeros() + ), + } + self.assertEqual(config, {**config, **expected_config_subset}) + + restored_encoder = fnet_encoder.FNetEncoder.from_config( + config, + ) + self.assertEqual( + restored_encoder.get_config(), {**config, **expected_config_subset} + ) + + def test_value_error_when_invalid_kernel_initializer(self): + with self.assertRaises(ValueError): + fnet_encoder.FNetEncoder( + intermediate_dim=4, + dropout=0.5, + kernel_initializer="Invalid", + ) + + def test_one_training_step_of_fnet_encoder(self): + encoder = fnet_encoder.FNetEncoder(intermediate_dim=4) + inputs = keras.Input(shape=(4, 6)) + x = encoder(inputs) + x = keras.layers.Dense(1, activation="sigmoid")(x) + model = keras.Model(inputs=inputs, outputs=x) + + data = tf.random.uniform(shape=[2, 4, 6]) + label = tf.cast(data[:, :, 0] >= 0.5, dtype=tf.int32) + + loss_fn = keras.losses.BinaryCrossentropy(from_logits=False) + optimizer = keras.optimizers.Adam() + with tf.GradientTape() as tape: + pred = model(data) + loss = loss_fn(label, pred) + grad = tape.gradient(loss, model.trainable_variables) + self.assertGreater(len(grad), 1) + optimizer.apply_gradients(zip(grad, model.trainable_variables)) + + def test_checkpointing_fnet_encoder(self): + encoder1 = fnet_encoder.FNetEncoder( + intermediate_dim=4, + ) + + encoder2 = fnet_encoder.FNetEncoder( + intermediate_dim=4, + ) + data = tf.random.uniform(shape=[2, 4, 6]) + encoder1(data) + encoder2(data) + # The weights of encoder1 and encoder2 are different. 
+ self.assertFalse( + all( + encoder1._output_dense.trainable_variables[0][0] + == encoder2._output_dense.trainable_variables[0][0] + ) + ) + checkpoint = tf.train.Checkpoint(encoder1) + checkpoint2 = tf.train.Checkpoint(encoder2) + temp_dir = self.get_temp_dir() + save_path = checkpoint.save(temp_dir) + checkpoint2.restore(save_path) + + encoder1_output = encoder1(data) + encoder2_output = encoder2(data) + self.assertAllClose(encoder1_output, encoder2_output) + + def test_save_model(self): + model = keras.Sequential( + [ + keras.Input(shape=(4, 6)), + fnet_encoder.FNetEncoder( + intermediate_dim=4, + ), + ] + ) + data = tf.random.uniform(shape=[2, 4, 6]) + model(data) + path = os.path.join(self.get_temp_dir(), "model") + model.save(path) + loaded_model = keras.models.load_model(path) + + model_output = model(data) + loaded_model_output = loaded_model(data) + self.assertAllClose(model_output, loaded_model_output) From 669ca9c6143594603bce8fc258d3c911b52b8c94 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 23 Mar 2022 09:12:35 +0530 Subject: [PATCH 09/13] Address review comments - 2 --- keras_nlp/layers/fnet_encoder.py | 49 ++++++++++++++------------------ 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index c0871c1ffb..0933e31182 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -69,8 +69,8 @@ def __init__( dropout=0.1, activation="gelu", layer_norm_epsilon=1e-12, - kernel_initializer="glorot_uniform", - bias_initializer="zeros", + kernel_initializer=tf.keras.initializers.RandomNormal(stddev=2e-2), + bias_initializer=tf.keras.initializers.RandomNormal(stddev=2e-2), name=None, **kwargs ): @@ -81,11 +81,9 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) self.bias_initializer = keras.initializers.get(bias_initializer) - self._built = False - def _build(self, input_shape): + def build(self, input_shape): # Create layers based on input shape. - self._built = True feature_size = input_shape[-1] # Layer Norm layers. @@ -110,22 +108,6 @@ def _build(self, input_shape): ) self._output_dropout = keras.layers.Dropout(rate=self.dropout) - def _fourier_transform(self, input): - # Apply FFT on the input and take the real part. - # Before we apply fourier transform, let's convert the dtype of the - # input tensor to complex64. - input = tf.cast(input, tf.complex64) - mixing_output = tf.math.real(tf.signal.fft2d(input)) - return mixing_output - - def _add_and_norm(self, input1, input2, norm_layer): - return norm_layer(input1 + input2) - - def _feed_forward(self, input): - x = self._intermediate_dense(input) - x = self._output_dense(x) - return self._output_dropout(x) - def call(self, inputs): """Forward pass of the FNetEncoder. @@ -137,8 +119,21 @@ def call(self, inputs): A Tensor of the same shape as the `inputs`. """ - if not self._built: - self._build(inputs.shape) + def _fourier_transform(input): + # Apply FFT on the input and take the real part. + # Before we apply fourier transform, let's convert the dtype of the + # input tensor to complex64. 
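+            # tf.signal.fft2d computes the FFT over the innermost two axes,
+            # i.e., the sequence and hidden dimensions.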
+ input = tf.cast(input, tf.complex64) + mixing_output = tf.math.real(tf.signal.fft2d(input)) + return mixing_output + + def _add_and_norm(input1, input2, norm_layer): + return norm_layer(input1 + input2) + + def _feed_forward(input): + x = self._intermediate_dense(input) + x = self._output_dense(x) + return self._output_dropout(x) # Apply fourier transform on the input. # Note: In the official FNet code, padding tokens are added to the @@ -148,18 +143,18 @@ def call(self, inputs): # Code references: # https://github.com/google-research/google-research/blob/master/f_net/input_pipeline.py#L107-L109 # https://github.com/google-research/google-research/blob/master/f_net/layers.py#L137 - mixing_output = self._fourier_transform(inputs) + mixing_output = _fourier_transform(inputs) # LayerNorm layer. - mixing_output = self._add_and_norm( + mixing_output = _add_and_norm( inputs, mixing_output, self._mixing_layer_norm ) # Feedforward layer. - feed_forward_output = self._feed_forward(mixing_output) + feed_forward_output = _feed_forward(mixing_output) # LayerNorm layer. - x = self._add_and_norm( + x = _add_and_norm( mixing_output, feed_forward_output, self._output_layer_norm ) return x From f2134729ffc2db21841c8264ddf20d4ed99a0b70 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Wed, 23 Mar 2022 09:45:41 +0530 Subject: [PATCH 10/13] Address review comments - 3 --- keras_nlp/layers/fnet_encoder.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index 0933e31182..7dbc9366af 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -26,6 +26,12 @@ class FNetEncoder(keras.layers.Layer): (https://arxiv.org/abs/2105.03824). Users can instantiate multiple instances of this class to stack up the encoder. + Note on padding: In the official FNet code, padding tokens are added to the + the input. However, the padding masks are deleted, i.e., mixing of + all tokens is done. This is because certain frequencies will be zeroed + out if we apply padding masks in every encoder layer. Hence, we don't + take padding mask as input in the call() function. + Args: intermediate_dim: int, defaults to 3072. The hidden size of feedforward network. @@ -65,7 +71,7 @@ class FNetEncoder(keras.layers.Layer): def __init__( self, - intermediate_dim=3072, + intermediate_dim, dropout=0.1, activation="gelu", layer_norm_epsilon=1e-12, @@ -136,13 +142,6 @@ def _feed_forward(input): return self._output_dropout(x) # Apply fourier transform on the input. - # Note: In the official FNet code, padding tokens are added to the - # the input. However, the padding masks are deleted, i.e., mixing of - # all tokens is done. This is because certain frequencies will be zeroed - # out if we apply padding masks in every encoder layer. - # Code references: - # https://github.com/google-research/google-research/blob/master/f_net/input_pipeline.py#L107-L109 - # https://github.com/google-research/google-research/blob/master/f_net/layers.py#L137 mixing_output = _fourier_transform(inputs) # LayerNorm layer. 
From 0a16b303b12ce672a48994732e8043d7cb003b2e Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Thu, 24 Mar 2022 08:18:18 +0530 Subject: [PATCH 11/13] Address review comments - 4 --- keras_nlp/layers/fnet_encoder.py | 28 ++++++++++++--------------- keras_nlp/layers/fnet_encoder_test.py | 6 +++--- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index 7dbc9366af..158a8f33d9 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -72,11 +72,11 @@ class FNetEncoder(keras.layers.Layer): def __init__( self, intermediate_dim, - dropout=0.1, - activation="gelu", - layer_norm_epsilon=1e-12, - kernel_initializer=tf.keras.initializers.RandomNormal(stddev=2e-2), - bias_initializer=tf.keras.initializers.RandomNormal(stddev=2e-2), + dropout=0, + activation="relu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", name=None, **kwargs ): @@ -125,7 +125,7 @@ def call(self, inputs): A Tensor of the same shape as the `inputs`. """ - def _fourier_transform(input): + def fourier_transform(input): # Apply FFT on the input and take the real part. # Before we apply fourier transform, let's convert the dtype of the # input tensor to complex64. @@ -133,27 +133,23 @@ def _fourier_transform(input): mixing_output = tf.math.real(tf.signal.fft2d(input)) return mixing_output - def _add_and_norm(input1, input2, norm_layer): + def add_and_norm(input1, input2, norm_layer): return norm_layer(input1 + input2) - def _feed_forward(input): + def feed_forward(input): x = self._intermediate_dense(input) x = self._output_dense(x) return self._output_dropout(x) - # Apply fourier transform on the input. - mixing_output = _fourier_transform(inputs) + mixing_output = fourier_transform(inputs) - # LayerNorm layer. - mixing_output = _add_and_norm( + mixing_output = add_and_norm( inputs, mixing_output, self._mixing_layer_norm ) - # Feedforward layer. - feed_forward_output = _feed_forward(mixing_output) + feed_forward_output = feed_forward(mixing_output) - # LayerNorm layer. - x = _add_and_norm( + x = add_and_norm( mixing_output, feed_forward_output, self._output_layer_norm ) return x diff --git a/keras_nlp/layers/fnet_encoder_test.py b/keras_nlp/layers/fnet_encoder_test.py index fd43dd31a9..f1025d9a6b 100644 --- a/keras_nlp/layers/fnet_encoder_test.py +++ b/keras_nlp/layers/fnet_encoder_test.py @@ -42,9 +42,9 @@ def test_get_config_and_from_config(self): config = encoder.get_config() expected_config_subset = { "intermediate_dim": 4, - "dropout": 0.1, - "activation": "gelu", - "layer_norm_epsilon": 1e-12, + "dropout": 0, + "activation": "relu", + "layer_norm_epsilon": 1e-5, "kernel_initializer": keras.initializers.serialize( keras.initializers.HeNormal() ), From ada4eb168a80945ed0429ef8c282e9122e0630f5 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Thu, 24 Mar 2022 08:22:27 +0530 Subject: [PATCH 12/13] Minor change --- keras_nlp/layers/fnet_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index 158a8f33d9..0dac9f5708 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -100,7 +100,7 @@ def build(self, input_shape): epsilon=self.layer_norm_epsilon ) - # Feedforward layer. + # Feedforward layers. 
self._intermediate_dense = keras.layers.Dense( self.intermediate_dim, activation=self.activation, From 8ce1d2ed9fcbe7d5d1192e33df768434ccd393b7 Mon Sep 17 00:00:00 2001 From: abheesht17 Date: Sat, 26 Mar 2022 07:38:57 +0530 Subject: [PATCH 13/13] Correct doc-string --- keras_nlp/layers/fnet_encoder.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/keras_nlp/layers/fnet_encoder.py b/keras_nlp/layers/fnet_encoder.py index 0dac9f5708..4793739236 100644 --- a/keras_nlp/layers/fnet_encoder.py +++ b/keras_nlp/layers/fnet_encoder.py @@ -33,18 +33,18 @@ class FNetEncoder(keras.layers.Layer): take padding mask as input in the call() function. Args: - intermediate_dim: int, defaults to 3072. The hidden size of feedforward - network. - dropout: float, defaults to 0.1. The dropout value, applied in the + intermediate_dim: int. The hidden size of feedforward network. + dropout: float, defaults to 0. The dropout value, applied in the feedforward network. - activation: string or `tf.keras.activations`, defaults to "gelu". The + activation: string or `tf.keras.activations`, defaults to "relu". The activation function of feedforward network. - layer_norm_epsilon: float, defaults to 1e-12. The epsilon value in layer + layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer normalization components. - kernel_initializer: tf.keras.initializers initializer, defaults to - "glorot_uniform". The kernel initializer for the dense layers. - bias_initializer: tf.keras.initializers initializer, defaults to - "zeros". The bias initializer for the dense layers. + kernel_initializer: "string" or `tf.keras.initializers` initializer, + defaults to "glorot_uniform". The kernel initializer for the dense + layers. + bias_initializer: "string" or `tf.keras.initializers` initializer, + defaults to "zeros". The bias initializer for the dense layers. name: string, defaults to None. The name of the layer. **kwargs: other keyword arguments.
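For reference, below is a minimal sketch of how the final `FNetEncoder` layer from this patch series might be stacked into a small text classifier. The vocabulary size, sequence length, embedding width, and layer count are illustrative placeholders (not taken from the patches), and the sketch assumes the `keras_nlp.layers.FNetEncoder` export added in patch 04. Note that the real FNet model also adds position and type embeddings, which are omitted here for brevity.

```python
import tensorflow as tf
from tensorflow import keras

import keras_nlp

# Illustrative hyperparameters (placeholders, not from the patches).
VOCAB_SIZE = 1000
SEQ_LENGTH = 64
EMBED_DIM = 32
NUM_LAYERS = 2

token_ids = keras.Input(shape=(SEQ_LENGTH,), dtype="int32")

# Token embeddings only; FNet proper also uses position embeddings.
x = keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM)(token_ids)

# Stack FNet encoder blocks, as the class docstring suggests.
for _ in range(NUM_LAYERS):
    x = keras_nlp.layers.FNetEncoder(intermediate_dim=128, dropout=0.1)(x)

# Pool over the sequence dimension and classify.
x = keras.layers.GlobalAveragePooling1D()(x)
outputs = keras.layers.Dense(2, activation="softmax")(x)

model = keras.Model(token_ids, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Smoke test on random token ids and binary labels.
dummy_ids = tf.random.uniform(
    shape=(8, SEQ_LENGTH), maxval=VOCAB_SIZE, dtype=tf.int32
)
dummy_labels = tf.random.uniform(shape=(8,), maxval=2, dtype=tf.int32)
model.fit(dummy_ids, dummy_labels, epochs=1, verbose=0)
```

Because the Fourier mixing sublayer has no trainable weights, the only parameters in each encoder block come from the two layer norms and the feedforward MLP, which keeps this model considerably smaller than an attention-based equivalent of the same depth.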