-
Notifications
You must be signed in to change notification settings - Fork 309
Add FNet Encoder Layer #43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
59cc53c
Add rough code for FNet Encoder
abheesht17 e520e4d
Format code
abheesht17 a005ae2
Minor doc-string changes
abheesht17 7bd4a8a
Merge branch 'master' into fnet-encoder
abheesht17 a5575ff
Format __init__.py
abheesht17 c75e09f
Address review comments - 1
abheesht17 6d85073
Add detailed comment about padding masks
abheesht17 514cee1
Merge branch 'keras-team:master' into fnet-encoder
abheesht17 63b7f22
Add kernel and bias initialisers
abheesht17 debafb7
Add unit tests for the layer
abheesht17 669ca9c
Address review comments - 2
abheesht17 f213472
Address review comments - 3
abheesht17 0a16b30
Address review comments - 4
abheesht17 ada4eb1
Minor change
abheesht17 2def525
Merge branch 'keras-team:master' into fnet-encoder
abheesht17 8ce1d2e
Correct doc-string
abheesht17 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| # Copyright 2022 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """FNet encoder block implementation based on `keras.layers.Layer`.""" | ||
|
|
||
| import tensorflow as tf | ||
| from tensorflow import keras | ||
|
|
||
|
|
||
| class FNetEncoder(keras.layers.Layer): | ||
| """FNet encoder. | ||
|
|
||
| This class follows the architecture of FNet encoder layer in paper | ||
| "FNet: Mixing Tokens with Fourier Transforms" | ||
| (https://arxiv.org/abs/2105.03824). Users can instantiate multiple instances | ||
| of this class to stack up the encoder. | ||
|
|
||
| Note on padding: In the official FNet code, padding tokens are added to the | ||
| the input. However, the padding masks are deleted, i.e., mixing of | ||
| all tokens is done. This is because certain frequencies will be zeroed | ||
| out if we apply padding masks in every encoder layer. Hence, we don't | ||
| take padding mask as input in the call() function. | ||
|
|
||
| Args: | ||
| intermediate_dim: int. The hidden size of feedforward network. | ||
| dropout: float, defaults to 0. The dropout value, applied in the | ||
| feedforward network. | ||
| activation: string or `tf.keras.activations`, defaults to "relu". The | ||
| activation function of feedforward network. | ||
| layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer | ||
| normalization components. | ||
| kernel_initializer: "string" or `tf.keras.initializers` initializer, | ||
| defaults to "glorot_uniform". The kernel initializer for the dense | ||
| layers. | ||
| bias_initializer: "string" or `tf.keras.initializers` initializer, | ||
| defaults to "zeros". The bias initializer for the dense layers. | ||
| name: string, defaults to None. The name of the layer. | ||
| **kwargs: other keyword arguments. | ||
|
|
||
| Examples: | ||
|
|
||
| ```python | ||
| # Create a single FNet encoder layer. | ||
| encoder = keras_nlp.layers.FNetEncoder( | ||
| intermediate_dim=64) | ||
|
|
||
| # Create a simple model containing the encoder. | ||
| input = tf.keras.Input(shape=[4, 6]) | ||
| output = encoder(input) | ||
| model = tf.keras.Model(inputs=input, outputs=output) | ||
|
|
||
| # Call encoder on the inputs. | ||
| input_data = tf.random.uniform(shape=[1, 10, 64]) | ||
| output = model(input_data) | ||
| ``` | ||
|
|
||
| References: | ||
| [Lee-Thorp et al., 2021](https://arxiv.org/abs/2105.03824) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| intermediate_dim, | ||
| dropout=0, | ||
| activation="relu", | ||
| layer_norm_epsilon=1e-5, | ||
| kernel_initializer="glorot_uniform", | ||
| bias_initializer="zeros", | ||
| name=None, | ||
| **kwargs | ||
| ): | ||
| super().__init__(name=name, **kwargs) | ||
| self.intermediate_dim = intermediate_dim | ||
| self.dropout = dropout | ||
| self.activation = keras.activations.get(activation) | ||
| self.layer_norm_epsilon = layer_norm_epsilon | ||
| self.kernel_initializer = keras.initializers.get(kernel_initializer) | ||
| self.bias_initializer = keras.initializers.get(bias_initializer) | ||
|
|
||
| def build(self, input_shape): | ||
| # Create layers based on input shape. | ||
| feature_size = input_shape[-1] | ||
|
|
||
| # Layer Norm layers. | ||
| self._mixing_layer_norm = keras.layers.LayerNormalization( | ||
| epsilon=self.layer_norm_epsilon | ||
| ) | ||
| self._output_layer_norm = keras.layers.LayerNormalization( | ||
| epsilon=self.layer_norm_epsilon | ||
| ) | ||
|
|
||
| # Feedforward layers. | ||
| self._intermediate_dense = keras.layers.Dense( | ||
| self.intermediate_dim, | ||
| activation=self.activation, | ||
| kernel_initializer=self.kernel_initializer, | ||
| bias_initializer=self.bias_initializer, | ||
| ) | ||
| self._output_dense = keras.layers.Dense( | ||
| feature_size, | ||
| kernel_initializer=self.kernel_initializer, | ||
| bias_initializer=self.bias_initializer, | ||
| ) | ||
| self._output_dropout = keras.layers.Dropout(rate=self.dropout) | ||
|
|
||
| def call(self, inputs): | ||
| """Forward pass of the FNetEncoder. | ||
|
|
||
| Args: | ||
| inputs: a Tensor. The input data to TransformerEncoder, should be | ||
| of shape [batch_size, sequence_length, feature_dim]. | ||
|
|
||
| Returns: | ||
| A Tensor of the same shape as the `inputs`. | ||
| """ | ||
|
|
||
| def fourier_transform(input): | ||
| # Apply FFT on the input and take the real part. | ||
| # Before we apply fourier transform, let's convert the dtype of the | ||
| # input tensor to complex64. | ||
| input = tf.cast(input, tf.complex64) | ||
| mixing_output = tf.math.real(tf.signal.fft2d(input)) | ||
| return mixing_output | ||
|
|
||
| def add_and_norm(input1, input2, norm_layer): | ||
| return norm_layer(input1 + input2) | ||
|
|
||
| def feed_forward(input): | ||
| x = self._intermediate_dense(input) | ||
| x = self._output_dense(x) | ||
| return self._output_dropout(x) | ||
|
|
||
| mixing_output = fourier_transform(inputs) | ||
|
|
||
| mixing_output = add_and_norm( | ||
| inputs, mixing_output, self._mixing_layer_norm | ||
| ) | ||
|
|
||
| feed_forward_output = feed_forward(mixing_output) | ||
|
|
||
| x = add_and_norm( | ||
| mixing_output, feed_forward_output, self._output_layer_norm | ||
| ) | ||
| return x | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "intermediate_dim": self.intermediate_dim, | ||
| "dropout": self.dropout, | ||
| "activation": keras.activations.serialize(self.activation), | ||
| "layer_norm_epsilon": self.layer_norm_epsilon, | ||
| "kernel_initializer": keras.initializers.serialize( | ||
| self.kernel_initializer | ||
| ), | ||
| "bias_initializer": keras.initializers.serialize( | ||
| self.bias_initializer | ||
| ), | ||
| } | ||
| ) | ||
| return config | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| # Copyright 2022 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Tests for FNet Encoder.""" | ||
|
|
||
| import os | ||
|
|
||
| import tensorflow as tf | ||
| from tensorflow import keras | ||
|
|
||
| from keras_nlp.layers import fnet_encoder | ||
|
|
||
|
|
||
| class FNetEncoderTest(tf.test.TestCase): | ||
| def test_valid_call(self): | ||
| encoder = fnet_encoder.FNetEncoder(intermediate_dim=4) | ||
| model = keras.Sequential( | ||
| [ | ||
| keras.Input(shape=(4, 6)), | ||
| encoder, | ||
| ] | ||
| ) | ||
| input = tf.random.uniform(shape=[2, 4, 6]) | ||
| model(input) | ||
|
|
||
| def test_get_config_and_from_config(self): | ||
| encoder = fnet_encoder.FNetEncoder( | ||
| intermediate_dim=4, | ||
| kernel_initializer="HeNormal", | ||
| bias_initializer="Zeros", | ||
| ) | ||
| config = encoder.get_config() | ||
| expected_config_subset = { | ||
| "intermediate_dim": 4, | ||
| "dropout": 0, | ||
| "activation": "relu", | ||
| "layer_norm_epsilon": 1e-5, | ||
| "kernel_initializer": keras.initializers.serialize( | ||
| keras.initializers.HeNormal() | ||
| ), | ||
| "bias_initializer": keras.initializers.serialize( | ||
| keras.initializers.Zeros() | ||
| ), | ||
| } | ||
| self.assertEqual(config, {**config, **expected_config_subset}) | ||
|
|
||
| restored_encoder = fnet_encoder.FNetEncoder.from_config( | ||
| config, | ||
| ) | ||
| self.assertEqual( | ||
| restored_encoder.get_config(), {**config, **expected_config_subset} | ||
| ) | ||
|
|
||
| def test_value_error_when_invalid_kernel_initializer(self): | ||
| with self.assertRaises(ValueError): | ||
| fnet_encoder.FNetEncoder( | ||
| intermediate_dim=4, | ||
| dropout=0.5, | ||
| kernel_initializer="Invalid", | ||
| ) | ||
|
|
||
| def test_one_training_step_of_fnet_encoder(self): | ||
| encoder = fnet_encoder.FNetEncoder(intermediate_dim=4) | ||
| inputs = keras.Input(shape=(4, 6)) | ||
| x = encoder(inputs) | ||
| x = keras.layers.Dense(1, activation="sigmoid")(x) | ||
| model = keras.Model(inputs=inputs, outputs=x) | ||
|
|
||
| data = tf.random.uniform(shape=[2, 4, 6]) | ||
| label = tf.cast(data[:, :, 0] >= 0.5, dtype=tf.int32) | ||
|
|
||
| loss_fn = keras.losses.BinaryCrossentropy(from_logits=False) | ||
| optimizer = keras.optimizers.Adam() | ||
| with tf.GradientTape() as tape: | ||
| pred = model(data) | ||
| loss = loss_fn(label, pred) | ||
| grad = tape.gradient(loss, model.trainable_variables) | ||
| self.assertGreater(len(grad), 1) | ||
| optimizer.apply_gradients(zip(grad, model.trainable_variables)) | ||
|
|
||
| def test_checkpointing_fnet_encoder(self): | ||
| encoder1 = fnet_encoder.FNetEncoder( | ||
| intermediate_dim=4, | ||
| ) | ||
|
|
||
| encoder2 = fnet_encoder.FNetEncoder( | ||
| intermediate_dim=4, | ||
| ) | ||
| data = tf.random.uniform(shape=[2, 4, 6]) | ||
| encoder1(data) | ||
| encoder2(data) | ||
| # The weights of encoder1 and encoder2 are different. | ||
| self.assertFalse( | ||
| all( | ||
| encoder1._output_dense.trainable_variables[0][0] | ||
| == encoder2._output_dense.trainable_variables[0][0] | ||
| ) | ||
| ) | ||
| checkpoint = tf.train.Checkpoint(encoder1) | ||
| checkpoint2 = tf.train.Checkpoint(encoder2) | ||
| temp_dir = self.get_temp_dir() | ||
| save_path = checkpoint.save(temp_dir) | ||
| checkpoint2.restore(save_path) | ||
|
|
||
| encoder1_output = encoder1(data) | ||
| encoder2_output = encoder2(data) | ||
| self.assertAllClose(encoder1_output, encoder2_output) | ||
|
|
||
| def test_save_model(self): | ||
| model = keras.Sequential( | ||
| [ | ||
| keras.Input(shape=(4, 6)), | ||
| fnet_encoder.FNetEncoder( | ||
| intermediate_dim=4, | ||
| ), | ||
| ] | ||
| ) | ||
| data = tf.random.uniform(shape=[2, 4, 6]) | ||
| model(data) | ||
| path = os.path.join(self.get_temp_dir(), "model") | ||
| model.save(path) | ||
| loaded_model = keras.models.load_model(path) | ||
|
|
||
| model_output = model(data) | ||
| loaded_model_output = loaded_model(data) | ||
| self.assertAllClose(model_output, loaded_model_output) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this choice of initializer the best practice for this layer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FNet repo: https://github.com/google-research/google-research/blob/6fd7a1a0872e9d27c1f1764836bbf9048a7903e7/f_net/models.py#L39-L42
I believe it is
tf.keras.initializers.RandomNormal(mean=0.0, stddev=2e-2)in Keras. Will change it to this.Intermediate Dense Layer
Output Dense Layer:
https://github.com/google-research/google-research/blob/master/f_net/layers.py#L73-L81
Should I set it as
tf.keras.initializers.RandomNormal(mean=0.0, stddev=2e-2)for all, or should I separate the bias initializer for the output dense layer and set that to"zeros"?