147 changes: 95 additions & 52 deletions keras_nlp/layers/transformer_decoder.py
@@ -30,12 +30,23 @@ class TransformerDecoder(keras.layers.Layer):
paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
can instantiate multiple instances of this class to stack up a decoder.

This layer will always apply a causal mask to the decoder attention layer.
This layer will correctly compute an attention mask from an implicit
Keras padding mask (for example, by passing `mask_zero=True` to a
`keras.layers.Embedding` layer). See the Masking and Padding
[guide](https://keras.io/guides/understanding_masking_and_padding/)
for more details.

This layer can be called with either one or two inputs. The number of inputs
must be consistent across all calls. The options are as follows:
`layer(decoder_sequence)`: no cross-attention will be built into the
decoder block. This is useful when building a "decoder-only"
transformer such as GPT-2.
`layer(decoder_sequence, encoder_sequence)`: cross-attention will be
built into the decoder block. This is useful when building an
"encoder-decoder" transformer, such as the original transformer
model described in Attention is All You Need.

Args:
intermediate_dim: int, the hidden size of the feedforward network.
num_heads: int, the number of heads in MultiHeadAttention.
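The two call patterns described in the docstring can be exercised as below; this is a minimal sketch, with shapes and hyperparameters chosen for illustration only:

```python
import tensorflow as tf
import keras_nlp

# Decoder-only usage (GPT-2 style): self-attention only.
decoder_only = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=64, num_heads=4
)
decoder_inputs = tf.random.uniform(shape=(2, 10, 32))  # [batch, seq, feature]
outputs = decoder_only(decoder_inputs)

# Encoder-decoder usage: a separate instance, since one instance cannot
# switch between the one-input and two-input patterns after being built.
with_cross = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=64, num_heads=4
)
encoder_outputs = tf.random.uniform(shape=(2, 12, 32))
outputs = with_cross(decoder_inputs, encoder_outputs)
```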
@@ -102,7 +113,7 @@ def __init__(
self._built = False
self.supports_masking = True

def _build(self, input_shape):
def _build(self, input_shape, include_cross_attention):
# Create layers based on input shape.
self._built = True
feature_size = input_shape[-1]
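The `supports_masking = True` flag set here is what lets the layer consume an implicit Keras padding mask, as the docstring above notes. A minimal sketch of that wiring, with a hypothetical vocabulary size:

```python
import keras_nlp
from tensorflow import keras

inputs = keras.Input(shape=(10,), dtype="int32")
# mask_zero=True attaches a padding mask that masking-aware layers
# downstream, including TransformerDecoder, pick up automatically.
x = keras.layers.Embedding(1000, 32, mask_zero=True)(inputs)
x = keras_nlp.layers.TransformerDecoder(intermediate_dim=64, num_heads=4)(x)
model = keras.Model(inputs, x)
```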
@@ -115,29 +126,37 @@ def _build(self, input_shape):
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)
self._encoder_decoder_attention_layer = keras.layers.MultiHeadAttention(
num_heads=self.num_heads,
key_dim=self._attention_head_size,
value_dim=feature_size,
dropout=self.dropout,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)

self._decoder_attention_layernorm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)
self._enc_dec_attention_layernorm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)

self._cross_attention_layer = None
if include_cross_attention:
# Create layers for cross attention.
self._cross_attention_layer = keras.layers.MultiHeadAttention(
num_heads=self.num_heads,
key_dim=self._attention_head_size,
value_dim=feature_size,
dropout=self.dropout,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
)

self._cross_attention_layernorm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)

self._cross_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
)

self._feedforward_layernorm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)

self._self_attention_dropout = keras.layers.Dropout(rate=self.dropout)
self._enc_dec_attentiondropout = keras.layers.Dropout(
rate=self.dropout,
)

# First dense layer in the feedforward network, which maps input
# feature size to dimension `self.intermediate_dim`.
self._intermediate_dense = keras.layers.Dense(
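`_build` defers sublayer creation to the first call so the block can size itself from the incoming feature dimension (`input_shape[-1]`). A stripped-down illustration of the same lazy-build idiom, not the PR's code:

```python
from tensorflow import keras

class LazyBlock(keras.layers.Layer):
    """Creates its sublayers on first call, sized from the input."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._built_sublayers = False

    def _build_sublayers(self, input_shape):
        feature_size = input_shape[-1]
        # Output width matches the input so a residual add is valid.
        self._dense = keras.layers.Dense(feature_size)
        self._built_sublayers = True

    def call(self, inputs):
        if not self._built_sublayers:
            self._build_sublayers(inputs.shape)
        return inputs + self._dense(inputs)
```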
@@ -166,18 +185,20 @@ def _feed_forward(self, input):
def call(
self,
decoder_sequence,
encoder_sequence,
encoder_sequence=None,
decoder_padding_mask=None,
decoder_attention_mask=None,
encoder_padding_mask=None,
encoder_attention_mask=None,
use_causal_mask=False,
):
"""Forward pass of the TransformerDecoder.

Args:
decoder_sequence: a Tensor. The decoder input sequence.
encoder_sequence: a Tensor. The decoder input sequence.
encoder_sequence: a Tensor. The encoder input sequence. For decoder-
only models (like GPT2), this should be left as None. Once the
layer has been called without an encoder_sequence, you cannot
call it again with an encoder_sequence.
decoder_padding_mask: a boolean Tensor, the padding mask of the decoder
sequence, must be of shape [batch_size, decoder_sequence_length].
decoder_attention_mask: a boolean Tensor. Customized decoder
@@ -188,29 +209,45 @@ def call(
encoder_attention_mask: a boolean Tensor. Customized encoder
sequence mask, must be of shape
[batch_size, encoder_sequence_length, encoder_sequence_length].
use_causal_mask: bool, defaults to False. If true, causal mask
(masking out future input) is applied on the decoder sequence.

Returns:
A Tensor of the same shape as the `decoder_sequence`.
"""
has_encoder_sequence = encoder_sequence is not None
if not self._built:
self._build(decoder_sequence.shape)
encoder_mask = merge_padding_and_attention_mask(
encoder_sequence, encoder_padding_mask, encoder_attention_mask
)
self._build(decoder_sequence.shape, has_encoder_sequence)

is_cross_attention = self._cross_attention_layer is not None
if not is_cross_attention and has_encoder_sequence:
raise ValueError(
"The number of call arguments to "
"`keras_nlp.layers.TransformerDecoder` should not change. "
"Use `layer(decoder_sequence, encoder_sequence)` to "
"build a layer with cross attention, or "
"`layer(decoder_sequence)` to build a layer without. "
"This layer has been built without cross attention, but "
"you are trying to call it with encoder_sequence."
)
elif is_cross_attention and not has_encoder_sequence:
raise ValueError(
"The number of call arguments to "
"`keras_nlp.layers.TransformerDecoder` should not change. "
"Use `layer(decoder_sequence, encoder_sequence)` to "
"build a layer with cross attention, or "
"`layer(decoder_sequence)` to build a layer without. "
"This layer has been built with cross attention, but "
"you did not provide encoder_sequence."
)
decoder_mask = merge_padding_and_attention_mask(
decoder_sequence, decoder_padding_mask, decoder_attention_mask
)
if use_causal_mask:
causal_mask = tf.cast(
compute_causal_mask(decoder_sequence),
dtype=tf.int32,
)
if decoder_mask is None:
decoder_mask = causal_mask
else:
decoder_mask = tf.minimum(decoder_mask, causal_mask)
causal_mask = tf.cast(
compute_causal_mask(decoder_sequence),
dtype=tf.int32,
)
if decoder_mask is None:
decoder_mask = causal_mask
else:
decoder_mask = tf.minimum(decoder_mask, causal_mask)

# Decoder input self-attention.
self_attended = self._self_attention_layer(
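The block above now always applies the causal mask, merging it into any user-supplied decoder mask with an elementwise minimum. A small self-contained demo of that merge, using `tf.linalg.band_part` in place of keras_nlp's internal `compute_causal_mask` helper (shapes are illustrative):

```python
import tensorflow as tf

seq_len = 4
# Lower-triangular causal mask: position i may only attend to j <= i.
causal_mask = tf.linalg.band_part(
    tf.ones((1, seq_len, seq_len), tf.int32), -1, 0
)
# A padding mask marking the last position as padding, broadcast from
# [batch, source_len] to [batch, target_len, source_len].
padding_mask = tf.constant([[1, 1, 1, 0]], tf.int32)
decoder_mask = tf.minimum(padding_mask[:, tf.newaxis, :], causal_mask)
print(decoder_mask[0].numpy())
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]
```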
@@ -220,28 +257,34 @@ def call(
attention_mask=decoder_mask,
)
self_attended = self._self_attention_dropout(self_attended)
self_attended = self._add_and_norm(
attention_output = self._add_and_norm(
self_attended, decoder_sequence, self._decoder_attention_layernorm
)
# Encoder-decoder attention.
encoder_decoder_attended = self._encoder_decoder_attention_layer(
query=self_attended,
value=encoder_sequence,
key=encoder_sequence,
attention_mask=encoder_mask,
)
encoder_decoder_attended = self._enc_dec_attentiondropout(
encoder_decoder_attended,
)
encoder_decoder_attended = self._add_and_norm(
encoder_decoder_attended,
self_attended,
self._enc_dec_attention_layernorm,
)

if self._cross_attention_layer is not None:
encoder_mask = merge_padding_and_attention_mask(
encoder_sequence, encoder_padding_mask, encoder_attention_mask
)
# Cross attention.
cross_attended = self._cross_attention_layer(
query=attention_output,
value=encoder_sequence,
key=encoder_sequence,
attention_mask=encoder_mask,
)
cross_attended = self._cross_attention_dropout(
cross_attended,
)
attention_output = self._add_and_norm(
cross_attended,
attention_output,
self._cross_attention_layernorm,
)

# Feedforward.
feed_forward_output = self._feed_forward(encoder_decoder_attended)
feed_forward_output = self._feed_forward(attention_output)
return self._add_and_norm(
encoder_decoder_attended,
attention_output,
feed_forward_output,
self._feedforward_layernorm,
)
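Every stage in `call` funnels through the same dropout, residual add, and layer-norm pattern via `_add_and_norm`. That helper is not shown in this diff; inferred from its call sites, it amounts to something like:

```python
def _add_and_norm(self, input1, input2, norm_layer):
    # Residual connection followed by layer normalization.
    return norm_layer(input1 + input2)
```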