
Conversation

@jessechancy (Contributor)

Made encoder sequence an optional parameter, added testing for this change.

@google-cla (bot) commented Jun 1, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@chenmoneygithub (Contributor) left a comment

Thanks for the PR! Dropped some comments.

)

if encoder_sequence is not None:
# Encoder-decoder attention.
@chenmoneygithub (Contributor):

Move this comment above `self._encoder_decoder_attention_layer`.

self._feedforward_layernorm,
)
else:
# Skip Encoder-Decoder attention, Feedforward.
@chenmoneygithub (Contributor):

This is a bit confusing: the comma "," could suggest that Feedforward is skipped as well.

Maybe just say "# Skip Encoder-Decoder attention if no encoder_sequence is provided."?

output = decoder(decoder_input)
model = keras.Model(
inputs=decoder_input,
outputs = output,
@chenmoneygithub (Contributor):

Remove the spaces surrounding "=": `outputs=output`

use_causal_mask=True,
)

def test_valid_call_without_encoder_with_mask(self):
@chenmoneygithub (Contributor):

We can delete this test case because it is covered by test_valid_call_with_mask

@jessechancy (Contributor, Author)

Right now there are two things that indicate decoder-only mode but can conflict. The first is the `decoder_only` attribute passed in at initialization. The second is implicit in the optional `encoder_sequence` parameter. These are the current behaviors for these two:

  1. decoder_only = True and encoder_seq is not None: ignore building encoder layers and don't run encoder_seq
  2. decoder_only = True and encoder_seq is None: ignore building encoder layers and don't run encoder_seq
  3. decoder_only = False and encoder_seq is not None: build encoder layers and run encoder_seq through model
  4. decoder_only = False and encoder_seq is None: build encoder layers but don't use layers and don't run encoder_seq through model

I added this comment to the docstring
"""
If decoder_only is set to True, the encoder-decoder attention layer will not
be built, and any encoder output passed to TransformerDecoder will be ignored.
If decoder_only is set to False but no encoder sequence is provided,
TransformerDecoder will run as decoder only.
"""

Let me know if this is enough to explain and whether it is intuitive, or if I should make any changes, thanks!

@chenmoneygithub (Contributor)

@jessechancy Yea, we should throw an explicit error message to our users if the two places contradict:

  1. decoder_only = True and encoder_seq is not None: raise ValueError("encoder_seq should be None for decoder-only models").
  2. decoder_only = True and encoder_seq is None: allowed case.
  3. decoder_only = False and encoder_seq is not None: allowed case.
  4. decoder_only = False and encoder_seq is None: raise ValueError("encoder_seq is empty ...")

@mattdangerw Does this look good to you?
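
For illustration, a minimal sketch of how this validation could look inside `call()`, assuming a `decoder_only` constructor flag and the argument names used in the discussion above (not the final implementation, which later dropped `decoder_only` entirely):

def call(self, decoder_sequence, encoder_sequence=None):
    # Hypothetical consistency check between the constructor flag and the
    # call arguments, covering the four cases listed above.
    if self.decoder_only and encoder_sequence is not None:
        raise ValueError(
            "`encoder_sequence` should be None for decoder-only models."
        )
    if not self.decoder_only and encoder_sequence is None:
        raise ValueError(
            "`encoder_sequence` must be provided when `decoder_only=False`."
        )
    # ... rest of the forward pass ...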

@jessechancy (Contributor, Author) commented Jun 7, 2022

self._feedforward_layernorm,
)

if self._encoder_decoder_attention_layer is None:
@mattdangerw (Member):

One minor comment. It might be nice if you rename the `self_attended` variable to `attention_output`, and do something like this:

attention_output = self._add_and_norm(...)

if encoder_sequence is not None:
    ... cross attention ...
    attention_output = self._add_and_norm(...)

feed_forward_output = self._feed_forward(attention_output)
return self._add_and_norm(...)

So basically bring this back to the single return statement. As a reader, that would make it much clearer how the computation is flowing overall with and without encoder_sequence.
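
For reference, a slightly fuller sketch of that single-return flow, assuming the `_add_and_norm` and `_feed_forward` helpers from the snippet above and illustrative sublayer names (not the exact implementation):

def call(self, decoder_sequence, encoder_sequence=None):
    # Self-attention block; a causal mask is always applied to the decoder.
    self_attended = self._self_attention_layer(
        decoder_sequence, decoder_sequence, use_causal_mask=True
    )
    attention_output = self._add_and_norm(decoder_sequence, self_attended)

    # Optional cross-attention block, only when an encoder sequence is given.
    if encoder_sequence is not None:
        cross_attended = self._cross_attention_layer(
            attention_output, encoder_sequence
        )
        attention_output = self._add_and_norm(attention_output, cross_attended)

    # Feed-forward block, then the single return statement.
    feed_forward_output = self._feed_forward(attention_output)
    return self._add_and_norm(attention_output, feed_forward_output)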

@mattdangerw (Member) left a comment

Thanks! This looks great. Left a few comments, mostly minor.

[guide](https://keras.io/guides/understanding_masking_and_padding/)
for more details.
If decoder_only is set to True, the encoder layer would not be built,
@mattdangerw (Member):

We have removed this argument. We should remove docs too.

We should update the class-level docs with a few things:

  • In the second paragraph about masking, add as a first sentence: "This layer will always apply a causal mask to the decoder attention layer."
  • Add a new paragraph. Some suggested text below:
This layer can be called with either one or two inputs as follows:

 - `layer(decoder_sequence)`: no cross-attention will be built into the decoder
   block. This is useful when building a "decoder-only" transformer such as GPT-2.
 - `layer(decoder_sequence, encoder_sequence)`: cross-attention will be built into
   the encoder block. This is useful when building an "encoder-decoder" transformer,
   such as the original transformer model described in Attention is All You Need.
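
For illustration, a hedged usage sketch of the two call patterns (assuming the `keras_nlp.layers.TransformerDecoder` constructor arguments used in the tests; a single instance keeps whichever pattern it was first called with, so two instances are shown):

import keras_nlp
from tensorflow import keras

decoder_input = keras.Input(shape=(10, 64))
encoder_input = keras.Input(shape=(10, 64))

# Decoder-only pattern: no cross-attention is built.
decoder_only_layer = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=128, num_heads=2
)
decoder_only_output = decoder_only_layer(decoder_input)

# Encoder-decoder pattern: cross-attention is built into the decoder block.
encoder_decoder_layer = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=128, num_heads=2
)
encoder_decoder_output = encoder_decoder_layer(decoder_input, encoder_input)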

defaults to "zeros". The bias initializer for
the dense and multiheaded attention layers.
name: string, defaults to None. The name of the layer.
decoder_only: bool, defaults to False. If True, only the decoder layers
@mattdangerw (Member):

remove

self.supports_masking = True

def _build(self, input_shape):
def _build(self, input_shape, cross_attention):
@mattdangerw (Member):

Maybe `include_cross_attention`, so it is more obvious this is a boolean value?
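
A minimal sketch of what that could look like, using the `_cross_attention_layer` name suggested later in this review (attribute names are illustrative, not the actual implementation):

def _build(self, input_shape, include_cross_attention):
    # ... build the self-attention and feed-forward sublayers as before ...

    # Only build the cross-attention sublayer when requested, so a
    # decoder-only call never creates unused weights.
    self._cross_attention_layer = None
    if include_cross_attention:
        self._cross_attention_layer = keras.layers.MultiHeadAttention(
            num_heads=self.num_heads,  # assumed attribute name
            key_dim=input_shape[-1] // self.num_heads,
        )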

raise ValueError(
f"The number of call arguments to "
f"`keras_nlp.layers.TransformerDecoder` should not change."
f"\nUse `layer(decoder_sequence, encoder_sequence)` to "
@mattdangerw (Member):

remove all the \n in both error messages

encoder_sequence, encoder_padding_mask, encoder_attention_mask
)
# Encoder-decoder attention.
encoder_decoder_attended = self._encoder_decoder_attention_layer(
@mattdangerw (Member):

let's clean up some variable names

_encoder_decoder_attention_layer -> _cross_attention_layer
_enc_dec_attention_dropout -> _cross_attention_dropout
_enc_dec_attention_layernorm -> _cross_attention_layernorm
encoder_decoder_attended -> cross_attended

output = decoder(encoder_input, decoder_input)
output = decoder(decoder_input, encoder_input)
# should raise ValueError if encoder_input is not provided
try:
@mattdangerw (Member):

Remove the try block. You can add a separate test for these using `self.assertRaises(ValueError)`. There are other examples in this test file.
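
A minimal sketch of such a test, assuming the fixtures and imports already used in this test file (names are illustrative, not the exact test that was added):

def test_value_error_when_encoder_input_added_later(self):
    decoder = keras_nlp.layers.TransformerDecoder(
        intermediate_dim=4, num_heads=2
    )
    decoder_sequence = tf.random.uniform(shape=[2, 4, 6])
    encoder_sequence = tf.random.uniform(shape=[2, 4, 6])
    # The first call without an encoder sequence builds a decoder-only layer.
    decoder(decoder_sequence)
    # Adding an encoder sequence in a later call should raise.
    with self.assertRaises(ValueError):
        decoder(decoder_sequence, encoder_sequence)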

use_causal_mask=True,
output = decoder(decoder_input)
# should raise ValueError if encoder_input is provided
try:
@mattdangerw (Member):

Same here, remove the try/except.

self.assertGreater(len(grad), 1)
optimizer.apply_gradients(zip(grad, model.trainable_variables))

def test_one_training_step_of_transformer_without_encoder(self):
@mattdangerw (Member):

without_encoder -> without_cross_attention

here and elsewhere

model_output = model(decoder_sequence)
loaded_model_output = loaded_model(decoder_sequence)
self.assertAllClose(model_output, loaded_model_output)

@mattdangerw (Member):

Remove extra newlines

@mattdangerw (Member) left a comment

LGTM! Thanks.

One nit, and I think there are some format issues still.

decoder block. This is useful when building a "decoder-only"
transformer such as GPT-2.
`layer(decoder_sequence, encoder_sequence)`: cross-attention will be
built into the encoder block. This is useful when building an
@mattdangerw (Member):

encoder block -> decoder block

@chenmoneygithub merged commit f9f52ca into keras-team:master on Jun 13, 2022