
Update past_key_values in GPT-2 #9596

Conversation

forest1988
Contributor

@forest1988 forest1988 commented Jan 14, 2021

What does this PR do?

It seems GPT-2 and BartDecoder use different formats for past_key_values.
As advised by @patrickvonplaten,
I opened this PR to change GPT-2's cache format from a single stacked tensor per layer to a tuple of 2 tensors per layer.
Once this is done, past_key_values in GPT-2 can be handled in the same way as in Bart.
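To make the difference concrete, here is a minimal sketch of the two layouts (tensor shapes and the layer count are made-up values for illustration, not taken from the actual modeling code):

import torch

batch, heads, seq_len, head_dim, n_layers = 2, 12, 5, 64, 12
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# old GPT-2 style: one stacked tensor per layer, shape (2, batch, heads, seq_len, head_dim)
old_layer_past = torch.stack((key, value))

# new style (matching BartDecoder): a tuple of 2 tensors per layer
new_layer_past = (key, value)

# the full cache is then a tuple with one entry per layer
past_key_values = tuple(new_layer_past for _ in range(n_layers))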

Sorry, some errors remain; this PR is [WIP].
I would appreciate your advice on how to update generation_utils.py.
Can I modify _reorder_cache so that past changes from Tuple[torch.Tensor] to Tuple[Tuple[torch.Tensor]],
or should I also consider the other output variations, outputs.mems and outputs.past_buckets_states?

Fixes #9391

From patrickvonplaten:

This PR cleans up the _reorder_cache logic. _reorder_cache in generation_utils.py now defaults to raising a NotImplementedError, forcing each model to implement its corresponding _reorder_cache in its modeling_...py file itself. This is cleaner because _reorder_cache strongly differs from model to model. In addition, this PR makes sure that gradient_checkpointing can only be used if the model is in training mode, and makes sure that use_cache is disabled when training with gradient_checkpointing enabled, to prevent errors.
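As a rough illustration of the use_cache rule described above (a hypothetical helper for illustration, not the code added by this PR):

def resolve_use_cache(use_cache: bool, gradient_checkpointing: bool, training: bool) -> bool:
    """Force the key/value cache off when gradient checkpointing is active during training."""
    if gradient_checkpointing and training:
        return False
    return use_cache

# e.g. inside a model's forward pass (names illustrative):
# use_cache = resolve_use_cache(use_cache, self.config.gradient_checkpointing, self.training)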

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

GPT2: @LysandreJik, @patrickvonplaten

@forest1988
Contributor Author

The CircleCI error messages are as follows.

In run_tests_torch:

-- Docs: https://docs.pytest.org/en/stable/warnings.html
=========================== short test summary info ============================
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_beam_sample_generate
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_beam_search_generate
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_beam_search_generate_dict_outputs_use_cache
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_gpt2_gradient_checkpointing
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_group_beam_search_generate
==== 5 failed, 4202 passed, 1775 skipped, 744 warnings in 216.47s (0:03:36) ====

Exited with code exit status 1
CircleCI received exit code 1

In run_tests_flax:

FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_beam_sample_generate
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_beam_search_generate
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_beam_search_generate_dict_outputs_use_cache
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_gpt2_gradient_checkpointing
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_group_beam_search_generate
==== 5 failed, 4172 passed, 1805 skipped, 751 warnings in 282.27s (0:04:42) ====

Exited with code exit status 1
CircleCI received exit code 1

@forest1988
Contributor Author

Is there a difference between past_key_value and layer_past? I understand that they both represent the contents of past_key_values, i.e. the cache of each layer, but are they different?

I first thought it might be a difference between causal language models and seq2seq language models, but it seems that both past_key_value and layer_past are used in modeling_bart.py.

And as for the contents of layer_past, should they be named past_state, as in the following part of modeling_bart.py?

@staticmethod
def _reorder_cache(past, beam_idx):
    reordered_past = ()
    for layer_past in past:
        # cached cross_attention states don't have to be reordered -> they are always the same
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
        )
    return reordered_past
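For comparison, a decoder-only model like GPT-2 has no cached cross-attention states, so once its cache is a tuple of tuples, its _reorder_cache could presumably just reorder every tensor in each layer's tuple. A minimal sketch (my understanding, not necessarily the exact code this PR ends up with):

@staticmethod
def _reorder_cache(past, beam_idx):
    # past has one entry per layer; each entry is a tuple of (key, value) tensors
    # whose first dimension is the batch (beam) dimension
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past
    )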

@forest1988
Contributor Author

I've updated generation_utils.py, and it seems that mems in transfo_xl and xlnet now cause new errors.

=========================== short test summary info ============================
FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_gpt2_gradient_checkpointing
FAILED tests/test_modeling_transfo_xl.py::TransfoXLModelTest::test_beam_sample_generate
FAILED tests/test_modeling_transfo_xl.py::TransfoXLModelTest::test_beam_sample_generate_dict_output
FAILED tests/test_modeling_transfo_xl.py::TransfoXLModelTest::test_beam_search_generate
FAILED tests/test_modeling_transfo_xl.py::TransfoXLModelTest::test_beam_search_generate_dict_output
FAILED tests/test_modeling_transfo_xl.py::TransfoXLModelTest::test_group_beam_search_generate
FAILED tests/test_modeling_transfo_xl.py::TransfoXLModelTest::test_group_beam_search_generate_dict_output
FAILED tests/test_modeling_xlnet.py::XLNetModelTest::test_beam_sample_generate
FAILED tests/test_modeling_xlnet.py::XLNetModelTest::test_beam_sample_generate_dict_output
FAILED tests/test_modeling_xlnet.py::XLNetModelTest::test_beam_search_generate
FAILED tests/test_modeling_xlnet.py::XLNetModelTest::test_beam_search_generate_dict_output
FAILED tests/test_modeling_xlnet.py::XLNetModelTest::test_group_beam_search_generate
FAILED tests/test_modeling_xlnet.py::XLNetModelTest::test_group_beam_search_generate_dict_output
=== 13 failed, 4194 passed, 1775 skipped, 743 warnings in 205.38s (0:03:25) ====

Exited with code exit status 1
CircleCI received exit code 1

@dataclass
class XLNetModelOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetModel`.

    Args:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.

            ``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
            ``num_predict`` corresponds to ``sequence_length``.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states. Can be used (see :obj:`mems` input) to speed up sequential decoding.
            The token ids which have their past given to this model should not be passed as :obj:`input_ids` as they
            have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

It seems mems is something similar to past_key_values.
Is there any difference between these two elements beyond the name?
Also, is it safe to change mems from List[torch.Tensor] to Tuple[Tuple[torch.Tensor]]?
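To make the structural question concrete, here is how I picture the two cache styles (shapes and sizes are made up for illustration, and the mems layout is my understanding rather than a quote from the code):

import torch

n_layers, mem_len, batch, hidden, heads = 2, 5, 3, 8, 4

# mems-style cache (TransfoXL / XLNet): one tensor per layer, with the batch in dim 1
mems = [torch.zeros(mem_len, batch, hidden) for _ in range(n_layers)]

# past_key_values-style cache (Bart, and GPT-2 after this PR): one (key, value) pair per layer
past_key_values = tuple(
    (
        torch.zeros(batch, heads, mem_len, hidden // heads),
        torch.zeros(batch, heads, mem_len, hidden // heads),
    )
    for _ in range(n_layers)
)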

@patrickvonplaten
Contributor

Hey @forest1988,

Your PR looks very nice! Yes, it is actually expected that XLNet and TransfoXL fail, since they have also been using the "default" _reorder_cache function of generation_utils.py. Could you make the following changes to correct this:

  1. Copy the old _reorder_cache function (the one before your changes) from generation_utils.py to both the modeling_xlnet.py and modeling_transfo_xl.py files, so that those models keep the same function as before (a sketch of that old function is shown after this list).
  2. Copy the current _reorder_cache function of generation_utils.py into modeling_gpt2.py.
  3. Add a default _reorder_cache function to generation_utils.py that looks as follows:
def _reorder_cache(self, past, beam_idx):
    raise NotImplementedError(...)
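For reference, the old default in generation_utils.py looked roughly like this (a sketch from memory; the exact signature may differ). It reorders each layer's single cached tensor along dimension 1, which is the batch (beam) dimension both for GPT-2's old stacked tensors and for the mems of XLNet/TransfoXL:

@staticmethod
def _reorder_cache(past, beam_idx):
    # each layer_past is a single tensor whose dim 1 is the batch (beam) dimension
    return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)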

@forest1988
Contributor Author

I've just updated the torch.utils.checkpoint.checkpoint check in modeling_gpt2.py, referring to modeling_bart.py.

@patrickvonplaten
Contributor

This way it's much cleaner and correct :-) The reason I'm proposing this change is that the _reorder_cache function is so different for each model that there should be no default implementation. A default function could confuse people who want to add a new model into thinking it works out of the box, when in most cases it just doesn't. A clear error message such as the following is better:

def _reorder_cache(self, past, beam_idx):
    raise NotImplementedError(f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}")

@patrickvonplaten
Contributor

I think this should solve the problems, let me know if you need more help :-)

@forest1988
Contributor Author

Thank you for your advice! I'll update _reorder_cache soon and commit it.

@forest1988 forest1988 force-pushed the forest1988-fix-gpt2-past_key_values branch from 89ee453 to d04b10c on January 14, 2021 17:00
@forest1988
Contributor Author

forest1988 commented Jan 14, 2021

Hi @patrickvonplaten,

Thanks to your kind advice, I was able to solve the problem of _reorder_cache in GPT-2, XLNet, TransfoXL (and CTRL).
Referring to modeling_bart.py, where _reorder_cache is placed in the ConditionalGeneration model, I added _reorder_cache to the LMHead model of each causal language model.

The last remaining failing test is:

FAILED tests/test_modeling_gpt2.py::GPT2ModelTest::test_gpt2_gradient_checkpointing

I think I should either modify test_gpt2_gradient_checkpointing so that it uses use_cache=False, or reconsider my previous update and re-modify the usage of checkpoint in modeling_gpt2.py.

I've just updated the torch.utils.checkpoint.checkpoint check in modeling_gpt2.py, referring to modeling_bart.py.
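For context, the pattern in modeling_bart.py that I'm referring to wraps each block in a small closure before handing it to torch.utils.checkpoint.checkpoint and skips caching while checkpointing. A self-contained toy version of that pattern (the block and all names here are made up for illustration; this is not the actual GPT-2 code):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    # stand-in for a transformer block, for illustration only
    def __init__(self, hidden_size=16):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, use_cache=False):
        return torch.tanh(self.linear(hidden_states))


def run_block(block, hidden_states, use_cache, gradient_checkpointing, training):
    # mimic the modeling_bart.py pattern: disable the cache and checkpoint while training
    if gradient_checkpointing and training:
        use_cache = False  # caching is incompatible with checkpointed re-computation

        def create_custom_forward(module):
            # checkpoint() only passes tensor arguments through, so capture flags in a closure
            def custom_forward(*inputs):
                return module(*inputs, use_cache=use_cache)
            return custom_forward

        return checkpoint(create_custom_forward(block), hidden_states)
    return block(hidden_states, use_cache=use_cache)


block = ToyBlock()
x = torch.randn(2, 4, 16, requires_grad=True)
out = run_block(block, x, use_cache=True, gradient_checkpointing=True, training=True)
out.sum().backward()  # gradients still flow through the checkpointed block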

@forest1988
Contributor Author

All checks have passed!
I appreciate all your help.

However, in the documentation of _reorder_cache, there are references to both past_key_values and mems regardless of which object the model actually uses.
I think we can either fix that and mention only the one that is used, or leave the reference to both to show that the aim of the function is the same.
If it needs to be modified, please let me know.

@forest1988 forest1988 changed the title from "[WIP] Update past_key_values in GPT-2" to "Update past_key_values in GPT-2" on Jan 14, 2021
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
generation step.

For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
Contributor

remove those lines and past_key_values above

Contributor Author

I cleaned it as well.

Contributor

@patrickvonplaten patrickvonplaten left a comment

The PR looks very nice - thanks so much for taking the time to tackle this @forest1988. Let's wait a bit to see how to proceed with gradient_checkpointing in GPT2, as this question will come up more often. IMO, use_cache should always be False for training, so either we update all use_cache in the models with use_cache = not self.training and (use_cache if use_cache is not None else self.config.use_cache), or we force it somehow in the Trainer. Similarly, gradient_checkpointing should never be set to True when the model is not training, IMO (we could also automatically disable this using self.training). Let's see what @LysandreJik and @sgugger think.

Collaborator

@sgugger sgugger left a comment

This is not a part of the library I'm very familiar with, so the changes look okay on my side, but I'm no expert.

Member

@LysandreJik LysandreJik left a comment

These changes look good to me! Thanks for taking care of it @forest1988.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Great work @forest1988,

I hope it's fine for you that I went into the PR to do some final fixes. Thanks a lot for cleaning this up :-)

@forest1988
Contributor Author

Hi @patrickvonplaten,

I hope it's fine for you that I went into the PR to do some final fixes. Thanks a lot for cleaning this up :-)

Of course! Thank you for adding fixes to make this PR more valuable!

Member

@LysandreJik LysandreJik left a comment

Your commit looks good to me @patrickvonplaten! Thanks.

Collaborator

@sgugger sgugger left a comment

The new changes look good to me, thanks!

@patrickvonplaten patrickvonplaten merged commit b020a73 into huggingface:master Jan 19, 2021
@patrickvonplaten
Contributor

Awesome, merging - great job @forest1988 !

@forest1988
Contributor Author

Thank you for your advice and encouraging comments!
It’s my pleasure to have opened this PR!

@@ -232,7 +232,7 @@ def forward(
             value = torch.cat((past_value, value), dim=-2)

         if use_cache is True:
-            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+            present = (key.transpose(-2, -1), value)  # transpose to have same shapes
Contributor

This is the reason for the recent failure of the slow test:

RUN_SLOW=1 pytest tests/test_onnx.py::OnnxExportTestCase::test_export_pytorch

Can you fix the onnx part easily? @mfuntowicz @Narsil


Successfully merging this pull request may close these issues.

Similar usage of past_key_values in CausalLM and Seq2SeqLM