diff --git a/keras_nlp/layers/__init__.py b/keras_nlp/layers/__init__.py
index 0397ce6be2..d2f5b57895 100644
--- a/keras_nlp/layers/__init__.py
+++ b/keras_nlp/layers/__init__.py
@@ -13,5 +13,6 @@
 # limitations under the License.
 
 from keras_nlp.layers.fnet_encoder import FNetEncoder
+from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
 from keras_nlp.layers.transformer_decoder import TransformerDecoder
 from keras_nlp.layers.transformer_encoder import TransformerEncoder
diff --git a/keras_nlp/layers/sine_position_encoding.py b/keras_nlp/layers/sine_position_encoding.py
new file mode 100644
index 0000000000..4112ca8a18
--- /dev/null
+++ b/keras_nlp/layers/sine_position_encoding.py
@@ -0,0 +1,94 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sinusoidal position encoding layer."""
+
+import tensorflow as tf
+from tensorflow import keras
+
+
+class SinePositionEncoding(keras.layers.Layer):
+    """Sinusoidal positional encoding layer.
+
+    This layer calculates the position encoding as a mix of sine and cosine
+    functions with geometrically increasing wavelengths. Defined and
+    formalized in [Attention is All You Need](https://arxiv.org/abs/1706.03762).
+
+    Takes as input an embedded token tensor. The input must have shape
+    [batch_size, sequence_length, feature_size]. This layer will return a
+    positional encoding of the same size as the embedded token tensor, which
+    can be added directly to the embedded token tensor.
+
+    Args:
+        max_wavelength: The maximum angular wavelength of the sine/cosine
+            curves, as described in Attention is All You Need. Defaults to
+            10000.
+
+    Example:
+    ```python
+    # create a simple embedding layer with sinusoidal positional encoding
+    seq_len = 100
+    vocab_size = 1000
+    embedding_dim = 32
+    inputs = keras.Input((seq_len,), dtype=tf.float32)
+    embedding = keras.layers.Embedding(
+        input_dim=vocab_size, output_dim=embedding_dim
+    )(inputs)
+    positional_encoding = keras_nlp.layers.SinePositionEncoding()(embedding)
+    outputs = embedding + positional_encoding
+    ```
+
+    References:
+        [Attention is All You Need](https://arxiv.org/abs/1706.03762)
+    """
+
+    def __init__(
+        self,
+        max_wavelength=10000,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.max_wavelength = max_wavelength
+
+    def call(self, inputs):
+        input_shape = tf.shape(inputs)
+        # the sequence length is the second-to-last dimension of the inputs
+        seq_length = input_shape[-2]
+        hidden_size = input_shape[-1]
+        position = tf.cast(tf.range(seq_length), self.compute_dtype)
+        min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
+        timescales = tf.pow(
+            min_freq,
+            tf.cast(2 * (tf.range(hidden_size) // 2), self.compute_dtype)
+            / tf.cast(hidden_size, self.compute_dtype),
+        )
+        angles = tf.expand_dims(position, 1) * tf.expand_dims(timescales, 0)
+        # even indices are sine, odd are cosine
+        cos_mask = tf.cast(tf.range(hidden_size) % 2, self.compute_dtype)
+        sin_mask = 1 - cos_mask
+        # the encodings have shape [seq_length, hidden_size]
+        positional_encodings = (
+            tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask
+        )
+
+        return tf.broadcast_to(positional_encodings, input_shape)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "max_wavelength": self.max_wavelength,
+            }
+        )
+        return config
diff --git a/keras_nlp/layers/sine_position_encoding_test.py b/keras_nlp/layers/sine_position_encoding_test.py
new file mode 100644
index 0000000000..5a6d28d47a
--- /dev/null
+++ b/keras_nlp/layers/sine_position_encoding_test.py
@@ -0,0 +1,122 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for sinusoidal positional encoding."""
+
+
+import tensorflow as tf
+from tensorflow import keras
+
+from keras_nlp.layers import sine_position_encoding
+
+
+class SinePositionEncodingTest(tf.test.TestCase):
+    def test_valid_call(self):
+        pos_encoding = sine_position_encoding.SinePositionEncoding()
+        model = keras.Sequential(
+            [
+                keras.Input(shape=(4, 6)),
+                pos_encoding,
+            ]
+        )
+        input = tf.random.uniform(shape=[2, 4, 6])
+        model(input)
+
+    def test_static_layer_output_shape(self):
+        pos_encoding = sine_position_encoding.SinePositionEncoding()
+        seq_length = 100
+        hidden_size = 32
+        inputs = keras.Input(shape=(seq_length, hidden_size))
+        outputs = pos_encoding(inputs)
+
+        # When using static positional encoding shapes, the output is expected
+        # to be the same as the input shape in all dimensions.
+        expected_output_shape = [None, seq_length, hidden_size]
+        self.assertEqual(expected_output_shape, outputs.shape.as_list())
+
+    def test_dynamic_layer_output_shape(self):
+        pos_encoding = sine_position_encoding.SinePositionEncoding()
+        hidden_size = 32
+        inputs = keras.Input(shape=(None, hidden_size))
+        outputs = pos_encoding(inputs)
+
+        # When using dynamic positional encoding shapes, the output is expected
+        # to be the same as the input shape in all dimensions but may be None.
+        expected_output_shape = [None, None, hidden_size]
+        self.assertEqual(expected_output_shape, outputs.shape.as_list())
+
+    # test with multiple dimensions before the sequence length
+    def test_multi_dimension_layer_output_shape(self):
+        pos_encoding = sine_position_encoding.SinePositionEncoding()
+        seq_length = 100
+        hidden_size = 32
+        inputs = keras.Input(shape=(None, seq_length, hidden_size))
+        outputs = pos_encoding(inputs)
+
+        # When using multiple dimensions before the sequence length, the
+        # output is expected to be the same as the input shape in all
+        # dimensions.
+        expected_output_shape = [None, None, seq_length, hidden_size]
+        self.assertEqual(expected_output_shape, outputs.shape.as_list())
+
+    def test_output_correct_values(self):
+        pos_encoding = sine_position_encoding.SinePositionEncoding()
+        model = keras.Sequential(
+            [
+                keras.Input(shape=(4, 6)),
+                pos_encoding,
+            ]
+        )
+        input = tf.random.uniform(shape=[1, 4, 6])
+        output = model(input)
+
+        # compare position encoding values for positions 0 and 3
+        expected_encoding_position_0 = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
+        expected_encoding_position_3 = [
+            0.14112,
+            -0.9899925,
+            0.1387981,
+            0.9903207,
+            0.00646326,
+            0.99997914,
+        ]
+        self.assertAllClose(output[0, 0, :], expected_encoding_position_0)
+        self.assertAllClose(output[0, 3, :], expected_encoding_position_3)
+
+    def test_get_config_and_from_config(self):
+        pos_encoding = sine_position_encoding.SinePositionEncoding(
+            max_wavelength=1000,
+        )
+        config = pos_encoding.get_config()
+        expected_config_subset = {
+            "max_wavelength": 1000,
+        }
+        self.assertEqual(config, {**config, **expected_config_subset})
+        restored_pos_encoding = (
+            sine_position_encoding.SinePositionEncoding.from_config(config)
+        )
+        self.assertEqual(
+            restored_pos_encoding.get_config(),
+            {**config, **expected_config_subset},
+        )
+
+    def test_float16_dtype(self):
+        pos_encoding = sine_position_encoding.SinePositionEncoding(
+            dtype="float16"
+        )
+        seq_length = 100
+        hidden_size = 32
+        inputs = keras.Input(shape=(seq_length, hidden_size))
+        outputs = pos_encoding(inputs)
+
+        # output dtype for this layer should be tf.float16
+        self.assertEqual(outputs.dtype, tf.float16)
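
The hard-coded expectations in `test_output_correct_values` can be re-derived outside of TensorFlow. The following is a minimal NumPy sketch of the same computation (not part of the diff; `sine_position_encoding_ref` is an illustrative name), useful for checking the values for positions 0 and 3 with `hidden_size=6`:

```python
# Standalone reference for the layer's `call` computation, for
# double-checking test_output_correct_values. Illustrative only.
import numpy as np


def sine_position_encoding_ref(seq_length, hidden_size, max_wavelength=10000):
    position = np.arange(seq_length, dtype=np.float32)
    min_freq = 1.0 / max_wavelength
    # Each pair of dimensions shares a wavelength; even indices use
    # sine, odd indices use cosine, matching the masks in `call`.
    timescales = min_freq ** (2 * (np.arange(hidden_size) // 2) / hidden_size)
    angles = position[:, None] * timescales[None, :]
    return np.where(
        np.arange(hidden_size) % 2 == 0, np.sin(angles), np.cos(angles)
    )


print(sine_position_encoding_ref(4, 6)[3])
# [ 0.14112   -0.9899925  0.1387981  0.9903207  0.00646326 0.99997914]
```

Position 0 reduces to `sin(0) = 0` on even indices and `cos(0) = 1` on odd ones, giving the `[0, 1, 0, 1, 0, 1]` pattern asserted in the test.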
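
Since `keras_nlp/layers/__init__.py` now exports `SinePositionEncoding` alongside `TransformerEncoder`, a typical composition looks like the sketch below (again illustrative, not part of the diff). Because `call` derives the encoding from the runtime input shape and returns it via `tf.broadcast_to`, the sequence length can stay dynamic; no maximum length is baked into the layer.

```python
import tensorflow as tf
from tensorflow import keras

import keras_nlp

vocab_size = 1000
embedding_dim = 32

# sequence length is left dynamic; the encoding adapts at call time
token_ids = keras.Input(shape=(None,), dtype=tf.int32)
x = keras.layers.Embedding(vocab_size, embedding_dim)(token_ids)
x = x + keras_nlp.layers.SinePositionEncoding()(x)
outputs = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=64, num_heads=2
)(x)
model = keras.Model(token_ids, outputs)
```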