Fix TF LED/Longformer attentions computation #10007
Conversation
attn_probs = tf.where(
    tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
I don't really understand this change here. The correct shape is given by attn_probs -> I don't understand why we cannot just use shape_list(attn_probs)? IMO, something like:
attn_probs = tf.where(
    masked_index,
    tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
    attn_probs,
)
should work, no?
Also, it's better to make the dtype dependent on the type of attn_probs, I think.
No need here, the default dtype of tf.zeros is always float (float16 if AMP is activated, or float32 if not).
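For context, here is a minimal sketch of the two options being discussed, with made-up shapes for illustration (this is not the PR's actual code):

import tensorflow as tf

# Illustrative shapes only, not the model's real ones.
attn_probs = tf.random.uniform((2, 8, 4, 5), dtype=tf.float32)
masked_index = tf.random.uniform((2, 8, 4, 5)) > 0.5

# Relying on tf.zeros' float default, as suggested above:
zeros_default = tf.zeros(tf.shape(masked_index))

# Tying the dtype explicitly to attn_probs, as in the earlier suggestion:
zeros_matched = tf.zeros(tf.shape(masked_index), dtype=attn_probs.dtype)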
Thanks for making the test pass! Could you give some more background on why
a) tf.tile should be used instead of tf.broadcast_to, and
b) we cannot simply use the shape of attn_probs, since we apply the mask on attn_probs itself? So we know that shape_list(masked_index) == shape_list(attn_probs).
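As a concrete illustration of the two ops in question (toy shapes, not the model's code), both calls below expand the [batch, seq_len] mask to the 4D attention shape and produce the same boolean values; broadcast_to takes a full target shape, while tile takes per-axis repetition counts:

import tensorflow as tf

# Toy shapes for illustration only.
batch, seq_len, num_heads, window = 2, 8, 4, 5
is_index_masked = tf.random.uniform((batch, seq_len)) > 0.5
attn_probs = tf.random.uniform((batch, seq_len, num_heads, window))

# broadcast_to expands the [batch, seq_len, 1, 1] mask to a full target shape...
mask_broadcast = tf.broadcast_to(is_index_masked[:, :, None, None], tf.shape(attn_probs))
# ...while tile repeats the mask a given number of times along each axis.
mask_tiled = tf.tile(is_index_masked[:, :, None, None], (1, 1, num_heads, window))

# Same boolean values either way for this pattern.
tf.debugging.assert_equal(mask_broadcast, mask_tiled)

# Zero out the attention probabilities at the masked positions.
attn_probs = tf.where(mask_tiled, tf.zeros_like(attn_probs), attn_probs)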
There are two reasons for this, but this part is a bit tricky to explain, so I don't know if it is clear enough or not. Don't hesitate to tell me if there is something you don't get.
Thanks for the explanation - I just tried it out, and it's cool to see that your change fixes the test!
If the slow tests pass, LGTM
The entire list of slow tests is ok!
@sgugger Feel free to merge if it looks ok to you!
Thanks for fixing!
What does this PR do?
This PR fixes the test test_saved_model_with_attentions_output for TF Longformer and LED, which was failing due to an issue in computing some shapes in the attentions. All the slow tests are now passing 🎉
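For readers unfamiliar with the shape_list helper mentioned throughout the conversation, here is a rough sketch of the idea (a hedged reconstruction, not the library's exact implementation): it returns static dimensions where TensorFlow knows them and falls back to dynamic tf.shape() entries where it does not, which matters when shapes are only partially known at tracing time.

import tensorflow as tf

# Hedged reconstruction of a shape_list-style helper (not copied from transformers):
# mix statically known dims with dynamic ones so downstream ops get usable shapes.
def shape_list(tensor):
    static = tensor.shape.as_list()  # may contain None under tracing
    dynamic = tf.shape(tensor)       # always concrete at runtime
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]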