
[Longformer] Output both local attentions and global attentions when output_attentions=True -> Good Second Issue #7514

Closed
patrickvonplaten opened this issue Oct 1, 2020 · 8 comments · Fixed by #7562
Labels
Good Second Issue Issues that are more difficult to do than "Good First" issues - give it a try if you want!

Comments

@patrickvonplaten
Contributor

patrickvonplaten commented Oct 1, 2020

🚀 Feature request

Good Second Issue - A more advanced issue for contributors who want to dive deeper into Longformer's attention mechanism.

Longformer currently outputs only global attentions, which is suboptimal because users might be interested in the local attentions as well. I propose changing the output_attentions logic in Longformer as follows:

attentions should correspond to the "local" attentions, and we add a new output, global_attention, that contains the global attentions. IMO this is consistent with the naming of attention_mask and global_attention_mask, and it is the cleanest way to implement the feature.

Implementing this feature means that Longformer will require its own ModelOutput classes:
BaseModelOutput => LongformerBaseModelOutput or BaseModelOutputWithGlobalAttention (I prefer the first name, though)
BaseModelOutputWithPooling => ...
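A minimal sketch of what the proposed output class could look like, assuming it mirrors BaseModelOutput and adds one extra field. The class name LongformerBaseModelOutputSketch is hypothetical, a plain dataclass stands in for transformers' ModelOutput, and the field is spelled global_attentions to match the name used later in this thread.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


# Hypothetical sketch: a plain dataclass standing in for a Longformer-specific
# ModelOutput. The real class would hold torch tensors, one per layer.
@dataclass
class LongformerBaseModelOutputSketch:
    last_hidden_state: Any = None
    hidden_states: Optional[Tuple[Any, ...]] = None
    # "local" (sliding-window) attention weights, one entry per layer
    attentions: Optional[Tuple[Any, ...]] = None
    # attention weights of the tokens selected by global_attention_mask,
    # one entry per layer
    global_attentions: Optional[Tuple[Any, ...]] = None


out = LongformerBaseModelOutputSketch(
    attentions=("local_layer_0", "local_layer_1"),
    global_attentions=("global_layer_0", "global_layer_1"),
)
```

Keeping both fields on one object mirrors the attention_mask / global_attention_mask pairing argued for above: callers that want the global weights must ask for them by name.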

Also some tests will have to be adapted.

This is a slightly more difficult issue, so I'm happy to help with it. One should understand the difference between local and global attention, and how Longformer's attention differs from, e.g., Bert's attention in general.
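To make the local/global distinction concrete, here is a simplified sketch of which key positions each query position may attend to. It assumes an even window with window // 2 neighbours on each side and ignores padding and the chunked implementation; the function name is hypothetical. In Bert every row would contain all positions. Here only the rows of global tokens do (roughly what the global attentions describe), while every other row is its sliding window plus the global tokens (roughly what the local attentions describe).

```python
from typing import List, Set


def longformer_attention_pattern(seq_len: int, window: int,
                                 global_idx: Set[int]) -> List[Set[int]]:
    """For every query position, return the set of key positions it attends to.

    Simplified model (assumption): each non-global token sees window // 2
    tokens on each side plus all global tokens; global tokens see every token.
    """
    half = window // 2
    pattern = []
    for i in range(seq_len):
        if i in global_idx:
            # Global token: a full row, exactly like Bert-style attention.
            keys = set(range(seq_len))
        else:
            # Local token: its sliding window, clipped to the sequence,
            # plus every global token.
            lo, hi = max(0, i - half), min(seq_len - 1, i + half)
            keys = set(range(lo, hi + 1)) | set(global_idx)
        pattern.append(keys)
    return pattern


pattern = longformer_attention_pattern(seq_len=8, window=2, global_idx={0})
```

With 8 tokens, a window of 2 and token 0 global, the pattern holds 34 query/key pairs instead of the 64 a dense Bert-style matrix would need; the gap grows with sequence length.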

For more detail check out discussion here: #5646

@patrickvonplaten patrickvonplaten added the Good Second Issue Issues that are more difficult to do than "Good First" issues - give it a try if you want! label Oct 1, 2020
@patrickvonplaten patrickvonplaten self-assigned this Oct 1, 2020
@patrickvonplaten patrickvonplaten changed the title [Longformer] Output both local attentions and global attentions when output_attentions=True [Longformer] Output both local attentions and global attentions when output_attentions=True -> Second Good Issue Oct 1, 2020
@patrickvonplaten patrickvonplaten changed the title [Longformer] Output both local attentions and global attentions when output_attentions=True -> Second Good Issue [Longformer] Output both local attentions and global attentions when output_attentions=True -> Good Second Issue Oct 1, 2020
@gui11aume
Contributor

I am working on a pull request to address this. I don't see any major challenge so far, but it made me realize how different attentions are in Bert-like models and in Longformer. Why not replace attentions in Longformer with local_attentions?

This means that the interface of Longformer would become incompatible with every other Transformer, but maybe it should be? I don't think there is a way to plug Longformer attentions into code that expects Bert-like attentions and get meaningful results, so users always have to write a special case for Longformer if they use it. As is, the risk is that they get bogus output and won't realize it until they carefully read the documentation (which has not been written yet).

What are your thoughts on this @patrickvonplaten?
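As an illustration of the special case users would have to write, the sketch below scatters one row of banded local attention back onto absolute key positions. It assumes a row layout of x weights for the global tokens followed by window + 1 weights for relative offsets from -window // 2 to +window // 2; treat that layout and the helper name local_row_to_dense as assumptions, not as the library's API.

```python
from typing import List, Sequence


def local_row_to_dense(row: Sequence[float], i: int, seq_len: int,
                       global_idx: Sequence[int], window: int) -> List[float]:
    """Scatter one query's local attention row onto absolute key positions.

    Assumed layout of `row` (hypothetical helper): len(global_idx) weights for
    the global keys, then window + 1 weights for relative offsets from
    -window // 2 to +window // 2 around query position i.
    Weights that land on the same key are summed.
    """
    x = len(global_idx)
    half = window // 2
    if len(row) != x + window + 1:
        raise ValueError("row length must be len(global_idx) + window + 1")
    dense = [0.0] * seq_len
    # First x columns: fixed, absolute positions of the global tokens.
    for col, j in enumerate(global_idx):
        dense[j] += row[col]
    # Remaining columns: offsets relative to the query position i.
    for k, weight in enumerate(row[x:]):
        j = i + (k - half)
        if 0 <= j < seq_len:  # offsets falling outside the sequence are dropped
            dense[j] += weight
    return dense


dense = local_row_to_dense([0.1, 0.2, 0.4, 0.3], i=3, seq_len=6,
                           global_idx=[0], window=2)
```

Here the row [0.1, 0.2, 0.4, 0.3] for query 3 maps to [0.1, 0.0, 0.2, 0.4, 0.3, 0.0]. Code that instead reads row[j] as the weight on key j would credit key 1 with 0.2 and key 3 with 0.3, which is the kind of silent, bogus output described above.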

@gui11aume
Contributor

I have made the pull request.

I checked that the Longformer tests passed with my changes, and I added one more test to check the output of attention probabilities.

Quite stupidly, I made the pull request to the master branch; I am sorry about this. I left it as is to avoid duplicating pull requests for now. You can reject it, and I will make a cleaner pull request to a separate branch.

@patrickvonplaten patrickvonplaten added this to On Hold - Wait for Main Contributor in Medium Contribution Proposals - Advanced Oct 27, 2020
@patrickvonplaten
Contributor Author

Sorry to have been so inactive on this issue :-/ I will find time to solve it in about a week :-) This issue is related as well: https://github.com/huggingface/transformers/pull/8007/files#r514633097.

@gui11aume
Contributor

No worries, there is no hurry on my side. Anyway, the issue is a little trickier than it looks because you guys have to decide how to encode attention probabilities when they are too large to be represented by a dense matrix. Let me know if there is anything I can do to help.
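Some back-of-the-envelope arithmetic shows why a dense matrix is not an option. The sketch counts stored attention weights per head and per layer for batch size 1, assuming seq_len=4096, a window of 512 and a single global token (illustrative values, not taken from this thread) and 4-byte float32 storage.

```python
def attention_weight_counts(seq_len: int, window: int, num_global: int) -> dict:
    """Attention weights stored per head and per layer, batch size 1."""
    return {
        # Bert-style: every query attends to every key
        "dense": seq_len * seq_len,
        # banded local attentions: x global columns + window + 1 relative columns
        "local": seq_len * (num_global + window + 1),
        # global attentions: one full-length row per global token
        "global": num_global * seq_len,
    }


counts = attention_weight_counts(seq_len=4096, window=512, num_global=1)
# float32 = 4 bytes per weight; convert to MiB
mib = {name: n * 4 / 2**20 for name, n in counts.items()}
```

Dense storage needs 64 MiB per head and per layer, versus about 8 MiB for the banded local weights plus 16 KiB for the global ones. The dense size grows quadratically with seq_len, while the other two grow linearly.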

@patrickvonplaten patrickvonplaten moved this from On Hold - Wait for Main Contributor to Done in Medium Contribution Proposals - Advanced Nov 11, 2020
@gui11aume
Contributor

Hi @patrickvonplaten. I have not used 🤗 Transformers since our discussion in November 2020. Today I came back to it (transformers version: 4.4.2) and realized that this issue is still not completely solved. I could open a new issue, but I believe the fix is really simple, so I hope we can address it here: in some models, the global attentions are computed and stored in outputs, but they are not returned at the very last stage.

If I am not mistaken, the issue is in modeling_longformer.py. At lines 1784-1789 the code is

return LongformerMaskedLMOutput(
    loss=masked_lm_loss,
    logits=prediction_scores,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
)

but I think it should be

return LongformerMaskedLMOutput(
    loss=masked_lm_loss,
    logits=prediction_scores,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
    global_attentions=outputs.global_attentions,  # <=====
)

The same goes for lines 1876 and 2124 (lines 2029 and 2235 are fine).
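To catch this kind of omission in every task head at once, a test can compare the base model's output with the head's output field by field. The sketch below uses hypothetical stand-in dataclasses and a hypothetical dropped_fields helper (not the real transformers classes) so it runs without torch; with the real library, the equivalent check would call the model with output_attentions=True and assert that outputs.global_attentions is not None.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


# Hypothetical stand-ins for the base model output and a task-head output;
# only the fields relevant to this bug are modelled.
@dataclass
class BaseOutputSketch:
    hidden_states: Optional[Tuple[Any, ...]] = None
    attentions: Optional[Tuple[Any, ...]] = None
    global_attentions: Optional[Tuple[Any, ...]] = None


@dataclass
class MaskedLMOutputSketch:
    loss: Any = None
    logits: Any = None
    hidden_states: Optional[Tuple[Any, ...]] = None
    attentions: Optional[Tuple[Any, ...]] = None
    global_attentions: Optional[Tuple[Any, ...]] = None


def dropped_fields(base_output, head_output,
                   names=("hidden_states", "attentions", "global_attentions")):
    """Fields the base model filled in but the task head did not forward."""
    return [n for n in names
            if getattr(base_output, n) is not None
            and getattr(head_output, n, None) is None]


outputs = BaseOutputSketch(attentions=("local",), global_attentions=("global",))

# Mirrors the original return statement: global_attentions is never passed on.
buggy = MaskedLMOutputSketch(logits="logits",
                             hidden_states=outputs.hidden_states,
                             attentions=outputs.attentions)
# Mirrors the proposed fix.
fixed = MaskedLMOutputSketch(logits="logits",
                             hidden_states=outputs.hidden_states,
                             attentions=outputs.attentions,
                             global_attentions=outputs.global_attentions)
```

dropped_fields(outputs, buggy) reports ["global_attentions"], while the fixed head reports nothing, so the same check can be looped over every task head.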

@patrickvonplaten
Contributor Author

This sounds correct to me! Would you mind opening a new PR?

@gui11aume
Contributor

I will do it, no problem.

gui11aume added a commit to gui11aume/transformers that referenced this issue Mar 25, 2021
@gui11aume
Contributor

I made a minimal pull request #10906.
