
Conversation

abodinier
Contributor

@abodinier abodinier commented Feb 27, 2023

Function merge_padding_and_attention_mask does not return an output with the desired shape when both padding and attention masks are given.

See the linked issue here: #783 (comment)

Currently, the TransformerEncoder layer breaks if we call it with both padding and attention masks, as sketched below.
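For reference, a minimal sketch of the failing call, assuming the mask shapes keras_nlp documents (padding mask `(batch_size, sequence_length)`, attention mask `(batch_size, sequence_length, sequence_length)`); the layer hyperparameters here are arbitrary:

```python
import tensorflow as tf
import keras_nlp

batch_size, seq_len, dim = 2, 4, 8
encoder = keras_nlp.layers.TransformerEncoder(intermediate_dim=16, num_heads=2)

inputs = tf.random.uniform((batch_size, seq_len, dim))
# Padding mask: (batch, seq_len); attention mask: (batch, seq_len, seq_len).
padding_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]], dtype="bool")
attention_mask = tf.cast(
    tf.linalg.band_part(tf.ones((batch_size, seq_len, seq_len)), -1, 0), "bool"
)

# Before this fix, passing both masks raised a shape error inside
# merge_padding_and_attention_mask.
outputs = encoder(inputs, padding_mask=padding_mask, attention_mask=attention_mask)
```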

Changes:

  • Remove the duplicated tf.newaxis
  • Create a check function to validate the shapes of the padding and attention masks (see the sketch after this list)
  • Update the unit tests to make sure an error is raised if merge_padding_and_attention_mask is given masks with bad shapes
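For concreteness, a rough sketch of the merge utility after these changes, based only on the description above; the helper name `_check_masks_shapes` and the error messages are assumptions, not necessarily the exact code merged:

```python
import tensorflow as tf

def _check_masks_shapes(inputs, padding_mask, attention_mask):
    # Hypothetical check function: validate mask ranks before merging.
    mask = padding_mask
    if hasattr(inputs, "_keras_mask") and mask is None:
        mask = inputs._keras_mask
    if mask is not None and len(mask.shape) != 2:
        raise ValueError(
            "`padding_mask` should have shape (batch_size, target_length). "
            f"Received shape `{mask.shape}`."
        )
    if attention_mask is not None and len(attention_mask.shape) != 3:
        raise ValueError(
            "`attention_mask` should have shape "
            "(batch_size, target_length, source_length). "
            f"Received shape `{attention_mask.shape}`."
        )

def merge_padding_and_attention_mask(inputs, padding_mask, attention_mask):
    _check_masks_shapes(inputs, padding_mask, attention_mask)
    mask = padding_mask
    if hasattr(inputs, "_keras_mask") and mask is None:
        mask = inputs._keras_mask
    if mask is not None:
        # Expand (batch, seq) -> (batch, 1, seq) exactly once.
        mask = tf.cast(tf.expand_dims(mask, axis=1), "int32")
    if attention_mask is not None:
        attention_mask = tf.cast(attention_mask, "int32")
        if mask is None:
            return attention_mask
        # The bug: a second `[:, tf.newaxis, :]` was applied to `mask` here,
        # producing a rank-4 tensor. Merging the already-expanded mask
        # directly keeps the result at (batch, target_length, source_length).
        return tf.minimum(mask, attention_mask)
    return mask
```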

@google-cla

google-cla bot commented Feb 27, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@abodinier abodinier changed the title Fix merge padding and attention mask Function merge_padding_and_attention_mask does not return an output with the desired shape when both padding and attention masks are given Feb 27, 2023
@atharvapurdue
Contributor

atharvapurdue commented Mar 1, 2023

Hi @abodinier , great job!
I think there is a small change required: the merged_mask variable is assigned but never used. The check should pass after that.

Member

@mattdangerw mattdangerw left a comment


This looks great! Some minor comments; this also still has a few minor formatting errors (run ./shell/format.sh), and it needs to be merged with master.

Thanks!

@mattdangerw mattdangerw merged commit 4c94a0d into keras-team:master Mar 11, 2023