
[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers #9098

Merged: 3 commits into huggingface:master on Dec 14, 2020

Conversation

@patrickvonplaten (Contributor) commented on Dec 14, 2020

What does this PR do?

In Transformers, the cache should always have the same structure. This becomes especially important for composite models like RAG and EncoderDecoder that expect all models to have the same cache.

Bart and T5 used different cache formats, with Bart's deviating the most from the library's standard cache.
This PR aligns the past_key_values cache of Bart/RAG with all other models in the library. In general, the philosophy is:

the past_key_value cache should have exactly one level of nesting per layer, no matter whether the model is decoder-only (e.g. GPT2) or seq2seq (e.g. BART). This was not correctly refactored in BART (it should have been implemented 1-to-1 as in T5). There are no breaking changes here, though.

  • the past_key_value tuple for each layer should always be a tuple of tensors, not a tuple of tuples
  • for decoder-only models (GPT2), the tuple for each layer contains 2 tensors: the key and value states of the self-attention
  • for seq2seq models (BART/T5), the tuple for each layer contains 4 tensors: the key and value states of the uni-directional self-attention, plus the saved key and value states for the cross-attention (see the sketch below)
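
To make the convention concrete, here is an illustrative sketch of the cache layout (not code from this PR; the tensor shapes are generic placeholders):

```python
import torch

# Decoder-only model (e.g. GPT2): each layer caches 2 tensors.
# Placeholder shape: (batch, num_heads, seq_len, head_dim).
gpt2_layer_past = (
    torch.zeros(1, 12, 5, 64),  # self-attention key states
    torch.zeros(1, 12, 5, 64),  # self-attention value states
)

# Seq2seq model (e.g. BART/T5): each layer caches 4 tensors.
seq2seq_layer_past = (
    torch.zeros(1, 12, 5, 64),   # uni-directional self-attention key states
    torch.zeros(1, 12, 5, 64),   # uni-directional self-attention value states
    torch.zeros(1, 12, 20, 64),  # cross-attention key states (over the encoder sequence)
    torch.zeros(1, 12, 20, 64),  # cross-attention value states
)

# past_key_values contains exactly one such tuple per layer: one level of
# nesting per layer, and each per-layer entry is a flat tuple of tensors.
past_key_values = tuple(seq2seq_layer_past for _ in range(6))
```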

This doesn't break any backward compatibility and should fix some RAG problems (@ratthachat). All RAG and Bart slow tests pass, and the changes only concern the tuple structure.

This PR is blocking the TFBart refactor, so it will be merged right away.

cc @LysandreJik, @sgugger, @patil-suraj for info.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@@ -535,7 +535,6 @@ def config_and_inputs(self):
     n_docs=self.n_docs,
     retrieval_vector_size=self.retrieval_vector_size,
     max_combined_length=self.max_combined_length,
-    use_cache=False,
@patrickvonplaten (Author):

use_cache was not tested here before because of the previous discrepancy between Bart and T5 -> it should work now

@patrickvonplaten (Author):

now use_cache is correctly tested in RAG
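
As a quick sanity check of the aligned format (a hedged sketch, not the exact test added in this PR; Bart is used here for brevity and the checkpoint name is just an example):

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

inputs = tokenizer("My friends are cool", return_tensors="pt")
outputs = model(
    input_ids=inputs.input_ids,
    decoder_input_ids=inputs.input_ids[:, :1],
    use_cache=True,
)

past = outputs.past_key_values
# One entry per decoder layer ...
assert len(past) == model.config.decoder_layers
# ... and each entry is a flat tuple of 4 tensors (self-attn k/v, cross-attn k/v).
for layer_past in past:
    assert isinstance(layer_past, tuple) and len(layer_past) == 4
    assert all(isinstance(t, torch.Tensor) for t in layer_past)
```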

@@ -758,18 +756,18 @@ def test_rag_sequence_generate_beam(self):
     generator_tokenizer=rag_decoder_tokenizer,
 )

-    rag_token = self.sequence_model
-    rag_token.set_retriever(rag_retriever)
+    rag_sequence = self.sequence_model
@patrickvonplaten (Author):

that's a sequence test, not a token test => so the results here change slightly

@@ -407,7 +407,7 @@ def forward(
     hidden_states: torch.Tensor,
     encoder_hidden_states: torch.Tensor,
     encoder_attn_mask: Optional[torch.Tensor] = None,
-    past_key_value: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
@patrickvonplaten (Author):

the past_key_value cache should have exactly one level of nesting per layer, no matter whether the model is decoder-only (e.g. GPT2) or seq2seq (e.g. BART). This was not correctly refactored in BART (it should have been implemented 1-to-1 as in T5). There are no breaking changes here, though.

For GPT2, the tuple for each layer contains 2 tensors: key and value states
For BART/T5, the tuple for each layer contains 4 tensors: key and value states of uni-directional self-attention, saved key and value states for cross-attention
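
Inside a seq2seq decoder layer, the flat 4-tensor tuple is then split into its self-attention and cross-attention halves. A minimal sketch of that convention (illustrative helper, not the exact Bart code):

```python
from typing import Optional, Tuple
import torch

def split_layer_past(past_key_value: Optional[Tuple[torch.Tensor]]):
    """Split a per-layer cache tuple into (self-attn k/v, cross-attn k/v)."""
    if past_key_value is None:
        return None, None
    self_attn_past_key_value = past_key_value[:2]    # positions 0-1: self-attention key/value states
    cross_attn_past_key_value = past_key_value[-2:]  # positions 2-3: cross-attention key/value states
    return self_attn_past_key_value, cross_attn_past_key_value
```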

@patrickvonplaten (Author):

cc @patil-suraj for information

@@ -1284,12 +1285,9 @@ def _force_token_id_to_be_generated(scores, token_id) -> None:

@staticmethod
def _reorder_cache(past, beam_idx):
@patrickvonplaten (Author):

this makes the re-ordering easier
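
With a flat tuple of tensors per layer, beam re-ordering reduces to an index_select on every cached tensor. A sketch of what _reorder_cache can look like under this format (it mirrors the idea, not a verbatim copy of the PR's code):

```python
import torch

def _reorder_cache(past, beam_idx):
    """Re-order every cached tensor of every layer along the batch (beam) dimension."""
    reordered_past = ()
    for layer_past in past:
        # layer_past is a flat tuple of tensors (2 for decoder-only, 4 for seq2seq models)
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
        )
    return reordered_past
```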

@@ -1057,23 +1061,17 @@ def question_encoder(self):
     def _reorder_cache(past, beam_idx):
         """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""

-        def _reorder_stacked(hidden_states):
-            n_docs = hidden_states.shape[0] // beam_idx.shape[0]
+        def _reorder_stacked(hidden_states, new_order):
@patrickvonplaten (Author):

refactor RAG's cache re-ordering according to Bart
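
For RAG, every cached tensor has an extra n_docs factor folded into the batch dimension, so the re-ordering first un-folds that dimension, selects the surviving beams, and folds it back. A hedged sketch of the idea (illustrative, not a verbatim copy of the new RAG code):

```python
import torch

def _reorder_stacked(hidden_states: torch.Tensor, new_order: torch.Tensor) -> torch.Tensor:
    """Re-order a cached tensor whose leading dimension is (batch_size * n_docs)."""
    n_docs = hidden_states.shape[0] // new_order.shape[0]
    hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])  # un-fold docs dim
    hidden_states = hidden_states.index_select(0, new_order)                  # pick surviving beams
    return hidden_states.view(-1, *hidden_states.shape[2:])                   # fold docs dim back

def _reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),
        )
    return reordered_past
```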

@patrickvonplaten changed the title from "[RAG] Fix RAG cache" to "[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers" on Dec 14, 2020
@patrickvonplaten merged commit fa1ddce into huggingface:master on Dec 14, 2020
@ratthachat mentioned this pull request on Dec 25, 2020