70 changes: 3 additions & 67 deletions examples/bert/bert_model.py
@@ -19,6 +19,8 @@
import tensorflow as tf
from tensorflow import keras

import keras_nlp


def make_attention_mask(inputs, mask):
"""Make a 3D attention mask from a 2D input mask.
@@ -331,72 +333,6 @@ def call(self, query, key_value=None, attention_mask=None):
return self._output_layer_norm(layer_output + attention_output)


class PositionEmbedding(keras.layers.Layer):
"""Creates a positional embedding.

Example:
```python
position_embedding = PositionEmbedding(max_length=100)
inputs = keras.Input((100, 32), dtype=tf.float32)
outputs = position_embedding(inputs)
```


Args:
max_length: The maximum size of the dynamic sequence.
initializer: The initializer to use for the embedding weights. Defaults
to "glorot_uniform".
seq_axis: The axis of the input tensor where we add the embeddings.

Reference: This layer creates a positional embedding as described in
[BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding](https://arxiv.org/abs/1810.04805).
"""

def __init__(
self, max_length, initializer="glorot_uniform", seq_axis=1, **kwargs
):

super().__init__(**kwargs)
if max_length is None:
raise ValueError("`max_length` must be an Integer, not `None`.")
self._max_length = max_length
self._initializer = keras.initializers.get(initializer)
self._seq_axis = seq_axis

def get_config(self):
config = {
"max_length": self._max_length,
"initializer": keras.initializers.serialize(self._initializer),
"seq_axis": self._seq_axis,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))

def build(self, input_shape):
dimension_list = input_shape.as_list()
width = dimension_list[-1]
weight_sequence_length = self._max_length

self._position_embeddings = self.add_weight(
"embeddings",
shape=[weight_sequence_length, width],
initializer=self._initializer,
)

super().build(input_shape)

def call(self, inputs):
input_shape = tf.shape(inputs)
actual_seq_len = input_shape[self._seq_axis]
position_embeddings = self._position_embeddings[:actual_seq_len, :]
new_shape = [1 for _ in inputs.get_shape().as_list()]
new_shape[self._seq_axis] = actual_seq_len
new_shape[-1] = position_embeddings.get_shape().as_list()[-1]
position_embeddings = tf.reshape(position_embeddings, new_shape)
return tf.broadcast_to(position_embeddings, input_shape)


# TODO(mattdangerw): This class is needed for TPU friendly embeddings, we should
# remove it entirely and fix tf.keras.layers.Embedding as needed.
class OnDeviceEmbedding(keras.layers.Layer):
@@ -546,7 +482,7 @@ def __init__(
name="word_embeddings",
)

self._position_embedding_layer = PositionEmbedding(
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name="position_embedding",
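For context, a minimal sketch (with placeholder sizes, not taken from the diff) of how the BERT example's embedding stack composes the shared `keras_nlp.layers.PositionEmbedding` after this change:

```python
import keras_nlp
from tensorflow import keras

# Placeholder sizes for illustration; the real values come from the BERT
# example's configuration.
vocab_size = 30522
hidden_size = 768
max_sequence_length = 512

word_embeddings = keras.layers.Embedding(
    input_dim=vocab_size, output_dim=hidden_size, name="word_embeddings"
)
position_embedding = keras_nlp.layers.PositionEmbedding(
    max_length=max_sequence_length,
    initializer="glorot_uniform",
    name="position_embedding",
)

token_ids = keras.Input(shape=(None,), dtype="int32")
embedded_tokens = word_embeddings(token_ids)
embedded_positions = position_embedding(embedded_tokens)
embeddings = embedded_tokens + embedded_positions
```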
1 change: 1 addition & 0 deletions keras_nlp/layers/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from keras_nlp.layers.fnet_encoder import FNetEncoder
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.sine_position_encoding import SinePositionEncoding
from keras_nlp.layers.transformer_decoder import TransformerDecoder
from keras_nlp.layers.transformer_encoder import TransformerEncoder
113 changes: 113 additions & 0 deletions keras_nlp/layers/position_embedding.py
@@ -0,0 +1,113 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Position embedding implementation based on `keras.layers.Layer`."""

import tensorflow as tf
from tensorflow import keras

SEQUENCE_AXIS = -2


class PositionEmbedding(keras.layers.Layer):
"""Creates a layer which learns a position embedding for inputs sequences.

This class assumes that in the input tensor, the last dimension corresponds
to the features, and the dimension before the last corresponds to the
sequence.

This class accepts `RaggedTensor`s as inputs to process batches of sequences
of different lengths. The one ragged dimension must be the dimension that
corresponds to the sequence, that is, the penultimate dimension.

Args:
max_length: The maximum length of the dynamic sequence.
initializer: The initializer to use for the embedding weights. Defaults
to "glorot_uniform".

Example:
```python
token_embeddings = keras.layers.Embedding(
input_dim=vocab_size, output_dim=embed_dim
)
position_embeddings = keras_nlp.layers.PositionEmbedding(
max_length=max_length
)

embedded_tokens = token_embeddings(inputs)
embedded_positions = position_embeddings(embedded_tokens)
outputs = embedded_tokens + embedded_positions
```

Reference:
[BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding](https://arxiv.org/abs/1810.04805).
"""

def __init__(
self,
max_length,
initializer="glorot_uniform",
**kwargs,
):
super().__init__(**kwargs)
if max_length is None:
raise ValueError("`max_length` must be an integer, not `None`.")
self.max_length = int(max_length)
self.initializer = keras.initializers.get(initializer)

def get_config(self):
config = super().get_config()
config.update(
{
"max_length": self.max_length,
"initializer": keras.initializers.serialize(self.initializer),
}
)
return config

def build(self, input_shape):
feature_size = input_shape[-1]
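# Learn one embedding vector per position up to `max_length`; call() trims
# this table to the actual sequence length of the inputs.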
self.position_embeddings = self.add_weight(
"embeddings",
shape=[self.max_length, feature_size],
initializer=self.initializer,
trainable=True,
)

super().build(input_shape)

def call(self, inputs):
if isinstance(inputs, tf.RaggedTensor):
bounding_shape = inputs.bounding_shape()
position_embeddings = self._trim_and_broadcast_position_embeddings(
bounding_shape,
)
# then apply row lengths to recreate the same ragged shape as inputs
return tf.RaggedTensor.from_tensor(
position_embeddings,
inputs.nested_row_lengths(),
)
else:
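# Dense inputs: trim the table and broadcast it to the full input shape.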
return self._trim_and_broadcast_position_embeddings(
tf.shape(inputs),
)

def _trim_and_broadcast_position_embeddings(self, shape):
sequence_length = shape[SEQUENCE_AXIS]
# trim to match the length of the sequence
position_embeddings = self.position_embeddings[:sequence_length, :]
# then broadcast to add the missing dimensions to match "shape"
return tf.broadcast_to(position_embeddings, shape)
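A short usage sketch (toy sizes, assumed) of the ragged-input path the docstring describes: each row keeps its own length, and the returned position embeddings share the inputs' ragged shape.

```python
import tensorflow as tf
from tensorflow import keras
import keras_nlp

# Two sequences of different lengths, kept ragged end to end.
token_ids = tf.ragged.constant([[1, 2, 3], [4, 5]])

embedding = keras.layers.Embedding(input_dim=100, output_dim=8)
position_embedding = keras_nlp.layers.PositionEmbedding(max_length=10)

embedded_tokens = embedding(token_ids)                    # ragged, shape (2, None, 8)
embedded_positions = position_embedding(embedded_tokens)  # same row lengths as inputs
outputs = embedded_tokens + embedded_positions            # still ragged
```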