[T5] Fix Cross Attention position bias #4499

Merged
merged 2 commits into huggingface:master on May 26, 2020

Conversation

@ZhuBaohe (Contributor) commented May 21, 2020

This PR fixes the cross-attention position bias assignment in the T5Stack class.

@codecov-commenter commented May 21, 2020

Codecov Report

Merging #4499 into master will decrease coverage by 0.01%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4499      +/-   ##
==========================================
- Coverage   77.83%   77.82%   -0.02%     
==========================================
  Files         123      123              
  Lines       20514    20514              
==========================================
- Hits        15968    15964       -4     
- Misses       4546     4550       +4     
Impacted Files                       Coverage Δ
src/transformers/modeling_t5.py      83.53% <100.00%> (ø)
src/transformers/modeling_tf_t5.py   95.16% <100.00%> (ø)
src/transformers/hf_api.py           93.06% <0.00%> (-4.96%) ⬇️
src/transformers/file_utils.py       73.85% <0.00%> (+0.41%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a086527...e9775b2.

@patrickvonplaten (Contributor)

Hi @ZhuBaohe,

Thanks for your PR! Can you explain in a bit more detail what the fix is doing here? :-)

@ZhuBaohe (Contributor, Author) commented May 25, 2020

@patrickvonplaten

This fixes a bug where the variable encoder_decoder_position_bias was incorrectly assigned the cross-attention weights instead of the cross-attention position bias.

See line 745 of modeling_t5.py, as follows:

# layer_outputs = hidden-states,                   -> 0
#                 key-value-states,                -> 1
#                 (self-attention weights),        -> 2
#                 (self-attention position bias),  -> 3
#                 (cross-attention weights),       -> 4
#                 (cross-attention position bias)  -> 5

encoder_decoder_position_bias should be assigned from layer_outputs[5] instead of layer_outputs[4].
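
For illustration, a minimal sketch of the corrected assignment inside the T5Stack decoder loop (paraphrased, so the surrounding code may not match the file at this revision exactly):

# Sketch of the T5Stack loop over decoder blocks (paraphrased, not the exact source).
# The position biases computed by the first block are reused by all later blocks,
# so they must be picked out of layer_outputs at the right indices.
if i == 0:
    # index 3: self-attention position bias (see the tuple layout above)
    position_bias = layer_outputs[3]
    if self.is_decoder and encoder_hidden_states is not None:
        # Bug: layer_outputs[4] holds the cross-attention *weights*.
        # Fix: layer_outputs[5] holds the cross-attention position *bias*.
        encoder_decoder_position_bias = layer_outputs[5]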

@patrickvonplaten changed the title from "fix T5" to "[T5] Fix Cross Attention position bias" on May 26, 2020
@patrickvonplaten (Contributor)

Great, I agree with you. Previously the attention weights of the cross attention layer were taken instead of the bias.

@LysandreJik @thomwolf I am quite surprised that we did not see an error earlier. I checked the slow tests, and the summarization / translation results are the same as before.

So good to merge for me!

@LysandreJik (Member) left a comment:


Indeed, thanks @ZhuBaohe

@LysandreJik merged commit a163c9c into huggingface:master on May 26, 2020
@ZhuBaohe deleted the t5 branch on May 26, 2020 at 12:59
@thomwolf (Member)

Surprising indeed @patrickvonplaten, I did fix a similar bug when implementing T5.

We should switch to NamedTuples one day 😄
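
For illustration only, such a NamedTuple could look roughly like the sketch below (the class and field names are hypothetical, not from the codebase):

from typing import NamedTuple, Optional

import torch


class T5BlockOutput(NamedTuple):
    """Hypothetical named layout for a T5 block's outputs (illustration only)."""
    hidden_states: torch.Tensor
    key_value_states: Optional[tuple]
    self_attention_weights: Optional[torch.Tensor]
    self_attention_position_bias: Optional[torch.Tensor]
    cross_attention_weights: Optional[torch.Tensor]
    cross_attention_position_bias: Optional[torch.Tensor]


# Accessing fields by name would make the index mix-up fixed in this PR much harder to write:
# encoder_decoder_position_bias = layer_outputs.cross_attention_position_bias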
