
Fix TF LED/Longformer attentions computation #10007

Merged
merged 9 commits into from Feb 10, 2021

Conversation

@jplu (Contributor) commented Feb 4, 2021

What does this PR do?

This PR fixes the test test_saved_model_with_attentions_output for TF Longformer and LED that was failing due to an issue in computing some shapes in the attentions.

All the slow tests are now passing 🎉

attn_probs = tf.where(
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
Contributor

I don't really understand this change here. The correct shape is given by attn_probs, so I don't understand why we can't just use shape_list(attn_probs). IMO, something like:

attn_probs = tf.where(
    masked_index,
    tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
    attn_probs,
)

should work, no?

Contributor

Also, I think it's better to make the dtype dependent on the dtype of attn_probs.

@jplu (Contributor Author), Feb 8, 2021

No need here: the default dtype of tf.zeros is always float (float16 if AMP is activated, float32 if not).

@patrickvonplaten (Contributor) left a comment

Thanks for making the test pass! Could you give some more background on why
a) tf.tile should be used instead of tf.broadcast_to &
b) why we cannot simply use the shape of attn_probs since we apply the mask on attn_probs itself? So we know that shape_list(masked_index) == shape_list(attn_probs)

@jplu (Contributor Author) commented Feb 8, 2021

> a) tf.tile should be used instead of tf.broadcast_to &

There are two reasons for this. The first is that broadcast_to does a reshape + tile, and here we don't need the reshape; a tile alone is enough. The second is that broadcast_to is not compliant with ONNXRuntime.
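
A minimal sketch of the difference, with toy shapes instead of the real Longformer tensors (all names and sizes here are illustrative only):

import tensorflow as tf

# Toy stand-ins for the real tensors (hypothetical shapes).
is_index_masked = tf.constant([[True, False, True]])  # [batch_size, seq_len]
attn_probs = tf.random.uniform((1, 3, 2, 5))          # [batch_size, seq_len, num_heads, window]

# Both variants start from the same [batch_size, seq_len, 1, 1] view of the mask.
expanded = is_index_masked[:, :, None, None]

# broadcast_to expands the size-1 axes to the (dynamic) shape of attn_probs;
# internally this amounts to a reshape + tile, and per the comment above it is
# the op that does not export cleanly to ONNXRuntime.
mask_broadcast = tf.broadcast_to(expanded, tf.shape(attn_probs))

# tf.tile repeats the existing size-1 axes by an explicit multiple per axis,
# reaching the same target shape without the implicit reshape.
mask_tile = tf.tile(expanded, multiples=(1, 1, 2, 5))

assert mask_broadcast.shape == mask_tile.shape == attn_probs.shape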

> b) why we cannot simply use the shape of attn_probs since we apply the mask on attn_probs itself? So we know that shape_list(masked_index) == shape_list(attn_probs)

This part is a bit tricky to explain. The issue is that attn_probs does not always have the same shape: if is_global_attn is True, the shape of attn_probs is [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1], while if it is False the shape is [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]. Because the shape can differ from one execution to the next when running in graph mode, the shape pre-computed for attn_probs by TF tracing was [batch_size, seq_len, self.num_heads, variable], where variable cannot be determined. As a consequence, attn_probs never had a fully defined shape at the end, which created a conflict in the tf.where. To solve this, we also had to create a mask with a fixed shape that depends on is_global_attn.
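
A hedged sketch of that idea, with toy sizes and simplified names rather than the exact code from this PR (is_global_attn is treated as a plain Python bool here for simplicity): the last dimension of the mask is computed from is_global_attn instead of being read off attn_probs, the mask is tiled to that fully defined shape, and the masking is done with a tf.where against zeros of the same shape.

import tensorflow as tf

# Hypothetical toy sizes, not the real model configuration.
NUM_HEADS = 2
WINDOW = 3        # stands in for one_sided_attn_window_size
MAX_GLOBAL = 1    # stands in for max_num_global_attn_indices

def mask_attn_probs(attn_probs, is_index_masked, is_global_attn):
    # The last dimension is derived from is_global_attn, so it is a concrete
    # value at tracing time instead of an unknown taken from attn_probs.
    last_dim = WINDOW * 2 + (MAX_GLOBAL if is_global_attn else 0) + 1
    masked_index = tf.tile(
        is_index_masked[:, :, None, None],  # [batch_size, seq_len, 1, 1]
        (1, 1, NUM_HEADS, last_dim),        # -> [batch_size, seq_len, num_heads, last_dim]
    )
    return tf.where(
        masked_index,
        tf.zeros(tf.shape(masked_index), dtype=attn_probs.dtype),
        attn_probs,
    )

# Usage with random toy tensors:
probs = tf.random.uniform((1, 4, NUM_HEADS, WINDOW * 2 + MAX_GLOBAL + 1))
masked = tf.constant([[True, False, False, True]])  # [batch_size, seq_len]
print(mask_attn_probs(probs, masked, is_global_attn=True).shape)  # (1, 4, 2, 8)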

I don't know if it is clear enough or not. Don't hesitate to tell me if there is something you don't get.

@patrickvonplaten (Contributor) commented


Thanks for the explanation! Just tried it out, and it's cool to see that your change fixes the test!

@LysandreJik (Member) left a comment

If the slow tests pass, LGTM

@jplu (Contributor Author) commented Feb 10, 2021

The entire list of slow tests is ok!

@jplu (Contributor Author) commented Feb 10, 2021

@sgugger Feel free to merge if it looks ok for you!

@sgugger (Collaborator) left a comment

Thanks for fixing!

@sgugger sgugger merged commit 22a32cf into huggingface:master Feb 10, 2021
@jplu jplu deleted the fix-tf-led-long branch February 10, 2021 17:57