1 change: 1 addition & 0 deletions keras_nlp/layers/__init__.py
@@ -13,5 +13,6 @@
# limitations under the License.

from keras_nlp.layers.fnet_encoder import FNetEncoder
from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
from keras_nlp.layers.transformer_decoder import TransformerDecoder
from keras_nlp.layers.transformer_encoder import TransformerEncoder
94 changes: 94 additions & 0 deletions keras_nlp/layers/sine_position_encoding.py
@@ -0,0 +1,94 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sinusoidal position embedding layer."""

import tensorflow as tf
from tensorflow import keras


class SinePositionEncoding(keras.layers.Layer):
"""Sinusoidal positional encoding layer.

This layer calculates the position encoding as a mix of sine and cosine
    functions with geometrically increasing wavelengths. Defined and formalized
in [Attention is All You Need](https://arxiv.org/abs/1706.03762).

Takes as input an embedded token tensor. The input must have shape
[batch_size, sequence_length, feature_size]. This layer will return a
positional encoding the same size as the embedded token tensor, which
can be added directly to the embedded token tensor.

Args:
max_wavelength: The maximum angular wavelength of the sine/cosine
curves, as described in Attention is All You Need. Defaults to
10000.

Example:
```python
# create a simple embedding layer with sinusoidal positional encoding
seq_len = 100
vocab_size = 1000
embedding_dim = 32
inputs = keras.Input((seq_len,), dtype=tf.float32)
embedding = keras.layers.Embedding(
input_dim=vocab_size, output_dim=embedding_dim
)(inputs)
positional_encoding = keras_nlp.layers.SinePositionEncoding()(embedding)
outputs = embedding + positional_encoding
```

References:
[Attention is All You Need](https://arxiv.org/abs/1706.03762)
"""

def __init__(
self,
max_wavelength=10000,
**kwargs,
):
super().__init__(**kwargs)
self.max_wavelength = max_wavelength

def call(self, inputs):
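        # The computation below follows the sinusoidal formulation from
        # "Attention is All You Need", with `max_wavelength` in place of the
        # paper's fixed constant 10000:
        #   PE(pos, 2i)   = sin(pos / max_wavelength**(2i / hidden_size))
        #   PE(pos, 2i+1) = cos(pos / max_wavelength**(2i / hidden_size))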
input_shape = tf.shape(inputs)
        # The sequence length is the second-to-last dimension of the inputs.
seq_length = input_shape[-2]
hidden_size = input_shape[-1]
position = tf.cast(tf.range(seq_length), self.compute_dtype)
min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
timescales = tf.pow(
min_freq,
tf.cast(2 * (tf.range(hidden_size) // 2), self.compute_dtype)
/ tf.cast(hidden_size, self.compute_dtype),
)
angles = tf.expand_dims(position, 1) * tf.expand_dims(timescales, 0)
# even indices are sine, odd are cosine
cos_mask = tf.cast(tf.range(hidden_size) % 2, self.compute_dtype)
sin_mask = 1 - cos_mask
# embedding shape is [seq_length, hidden_size]
positional_encodings = (
tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask
)

return tf.broadcast_to(positional_encodings, input_shape)

def get_config(self):
config = super().get_config()
config.update(
{
"max_wavelength": self.max_wavelength,
}
)
return config
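A quick way to sanity-check the layer is to recompute the same sinusoidal table directly with NumPy and compare the two. The sketch below is illustrative only (not part of this diff) and assumes the package-level export added in `keras_nlp/layers/__init__.py` above.

```python
import numpy as np
import tensorflow as tf

from keras_nlp.layers import SinePositionEncoding

# Recompute the sinusoidal table with NumPy and compare it to the layer output.
batch_size, seq_length, hidden_size = 2, 8, 16
inputs = tf.random.uniform(shape=(batch_size, seq_length, hidden_size))
layer_output = SinePositionEncoding()(inputs)

position = np.arange(seq_length)[:, None]
index = np.arange(hidden_size)[None, :]
angles = position * (1.0 / 10000) ** (2 * (index // 2) / hidden_size)
expected = np.where(index % 2 == 0, np.sin(angles), np.cos(angles))

# The encoding is broadcast across the batch, so every batch entry matches.
np.testing.assert_allclose(layer_output[0].numpy(), expected, atol=1e-5)
```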
122 changes: 122 additions & 0 deletions keras_nlp/layers/sine_position_encoding_test.py
@@ -0,0 +1,122 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Sinusoidal Positional encoding."""


import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers import sine_position_encoding


class SinePositionEncodingTest(tf.test.TestCase):
def test_valid_call(self):
pos_encoding = sine_position_encoding.SinePositionEncoding()
model = keras.Sequential(
[
keras.Input(shape=(4, 6)),
pos_encoding,
]
)
        inputs = tf.random.uniform(shape=[2, 4, 6])
        model(inputs)

def test_static_layer_output_shape(self):
pos_encoding = sine_position_encoding.SinePositionEncoding()
seq_length = 100
hidden_size = 32
inputs = keras.Input(shape=(seq_length, hidden_size))
outputs = pos_encoding(inputs)

# When using static positional encoding shapes, the output is expected
# to be the same as the input shape in all dimensions.
expected_output_shape = [None, seq_length, hidden_size]
self.assertEqual(expected_output_shape, outputs.shape.as_list())

def test_dynamic_layer_output_shape(self):
pos_encoding = sine_position_encoding.SinePositionEncoding()
hidden_size = 32
inputs = keras.Input(shape=(None, hidden_size))
outputs = pos_encoding(inputs)

# When using dynamic positional encoding shapes, the output is expected
# to be the same as the input shape in all dimensions but may be None.
expected_output_shape = [None, None, hidden_size]
self.assertEqual(expected_output_shape, outputs.shape.as_list())

    # Test an input with an extra dimension before the sequence length.
def test_multi_dimension_layer_output_shape(self):
pos_encoding = sine_position_encoding.SinePositionEncoding()
seq_length = 100
hidden_size = 32
inputs = keras.Input(shape=(None, seq_length, hidden_size))
outputs = pos_encoding(inputs)

        # When using multiple dimensions before sequence length, the output is
# expected to be the same as the input shape in all dimensions.
expected_output_shape = [None, None, seq_length, hidden_size]
self.assertEqual(expected_output_shape, outputs.shape.as_list())

def test_output_correct_values(self):
pos_encoding = sine_position_encoding.SinePositionEncoding()
model = keras.Sequential(
[
keras.Input(shape=(4, 6)),
pos_encoding,
]
)
        inputs = tf.random.uniform(shape=[1, 4, 6])
        output = model(inputs)

        # Compare position encoding values for positions 0 and 3.
expected_encoding_position_0 = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
expected_encoding_position_3 = [
0.14112,
-0.9899925,
0.1387981,
0.9903207,
0.00646326,
0.99997914,
]
self.assertAllClose(output[0, 0, :], expected_encoding_position_0)
self.assertAllClose(output[0, 3, :], expected_encoding_position_3)

def test_get_config_and_from_config(self):
pos_encoding = sine_position_encoding.SinePositionEncoding(
max_wavelength=1000,
)
config = pos_encoding.get_config()
expected_config_subset = {
"max_wavelength": 1000,
}
self.assertEqual(config, {**config, **expected_config_subset})
restored_pos_encoding = (
sine_position_encoding.SinePositionEncoding.from_config(config)
)
self.assertEqual(
restored_pos_encoding.get_config(),
{**config, **expected_config_subset},
)

def test_float16_dtype(self):
pos_encoding = sine_position_encoding.SinePositionEncoding(
dtype="float16"
)
seq_length = 100
hidden_size = 32
inputs = keras.Input(shape=(seq_length, hidden_size))
outputs = pos_encoding(inputs)

# output dtype for this layer should be tf.float16.
self.assertEqual(outputs.dtype, tf.float16)
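For reference, the hard-coded expectations in `test_output_correct_values` follow directly from the sinusoidal formula with `hidden_size=6` and the default `max_wavelength=10000`. A short standalone derivation (illustrative, not part of the test file):

```python
import math

# Recompute the expected encoding for position 3 with hidden_size=6 and
# max_wavelength=10000 (the defaults used in the test above).
pos, hidden_size, max_wavelength = 3, 6, 10000
angles = [
    pos * max_wavelength ** (-(2 * (i // 2)) / hidden_size)
    for i in range(hidden_size)
]
encoding = [
    math.sin(a) if i % 2 == 0 else math.cos(a) for i, a in enumerate(angles)
]
# encoding ≈ [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998]
```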