TF GPT2: clearer model variable naming with @unpack_inputs #16311

Merged: 7 commits into huggingface:main on Mar 29, 2022
Conversation

@cakiki (Contributor) commented on Mar 21, 2022

What does this PR do?

Addresses #16051
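For context, #16051 tracks replacing the hand-written input-processing boilerplate in the TF models with the @unpack_inputs decorator, so that call() bodies can use plain argument names instead of inputs["..."] dict lookups. A quick sketch of the calling conventions that machinery has to normalize (standard public API; the gpt2 checkpoint is just an example):

```python
from transformers import GPT2Tokenizer, TFGPT2Model

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2Model.from_pretrained("gpt2")
enc = tokenizer("Hello world", return_tensors="tf")

# All three call styles must resolve to the same named arguments inside
# call(); @unpack_inputs does that resolution once, in a decorator, so the
# method body reads input_ids directly rather than inputs["input_ids"].
out = model(enc["input_ids"], attention_mask=enc["attention_mask"])  # positional + keyword
out = model(dict(enc))                                               # dict as first argument
out = model(input_ids=enc["input_ids"])                              # keyword-only
```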

Before submitting

Who can review?

@gante @Rocketknight1

@cakiki marked this pull request as ready for review on March 21, 2022 20:41
@cakiki (Contributor, Author) commented on Mar 21, 2022

Some tests failed locally: `12 failed, 37 passed, 1 skipped, 1 warning in 205.75s (0:03:25)`

One such failure is `ValueError: The following keyword arguments are not supported by this model: ['past_key_values']`, even though GPT2 calls this argument `past` rather than `past_key_values`. Shouldn't the test be rewritten?

@HuggingFaceDocBuilderDev commented on Mar 21, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante (Member) commented on Mar 22, 2022

Hey @cakiki 👋 The problem you're seeing is related to another change happening at the same time -- we are refactoring the auto-regressive generation function, and as part of it the `past` input variable in some older models was renamed to `past_key_values`, for consistency across models and frameworks.

(see next comment)

@gante (Member) commented on Mar 22, 2022

@sgugger @patrickvonplaten calling for your opinion here.

TL;DR:

  • In a past generate() refactor PR, I made prepare_inputs_for_generation() uniform across frameworks. In the process, one of the output keys in TF GPT2 was updated from `past` to `past_key_values` -- removing `past` was one of the TODO goals flagged by @patrickvonplaten;
  • However, the current version of the model still expects `past` as the input name, so passing `past_key_values` as a keyword argument raises the error @cakiki is seeing;
  • In PT/FLAX, this input is called `past_key_values`.

To fix it we have two options:

  1. Revert the output of prepare_inputs_for_generation() from `past_key_values` back to `past`;
  2. Update the GPT2 input from `past` to `past_key_values`. It would be an API change, but this variable is mostly used in generate(), right?

I'm pro option 2, but WDYT?
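For concreteness, a toy reproduction of the clash (invented minimal signatures, not the real code):

```python
def call(input_ids=None, past=None, use_cache=None):
    # stand-in for TF GPT-2's call(), which still expects the old name "past"
    return input_ids

# After the refactor, prepare_inputs_for_generation() emits the new key name:
model_inputs = {"input_ids": [[0]], "past_key_values": None}

try:
    call(**model_inputs)
except TypeError as err:
    # Plain Python raises a TypeError here; transformers' input processing
    # surfaces the same mismatch as: ValueError: The following keyword
    # arguments are not supported by this model: ['past_key_values'].
    print(err)
```

Option 1 renames the dict key back; option 2 renames the call() argument, which is the user-facing API change under discussion.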

@gante self-requested a review on March 22, 2022 11:08

@gante (Member) left a review comment:

Other than the issue under discussion, looks solid 🔥

Review thread on src/transformers/models/gpt2/modeling_tf_gpt2.py (outdated, resolved)
@patrickvonplaten (Contributor) commented
I'm in favor of option 1. I don't think I did a good job reviewing the refactor PR - see #15944 (comment).

Backwards compatibility for models such as GPT2 is extremely important, and people do use `past` outside of generate(). I think the easy fix here is to just revert the line above.

The other possibility is to deprecate `past` in general for all TF models, but this should be done over a deprecation cycle so that users are aware of the change.
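One common shape for such a deprecation cycle (a sketch, not the code that was eventually merged):

```python
import warnings

def call(input_ids=None, past_key_values=None, **kwargs):
    # Backward-compatible shim: honor the old name during the deprecation
    # window, but steer callers toward the new one.
    if "past" in kwargs:
        warnings.warn(
            "The `past` argument is deprecated and will be removed in a "
            "future version; use `past_key_values` instead.",
            FutureWarning,
        )
        past_key_values = kwargs.pop("past")
    return past_key_values
```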

@patrickvonplaten (Contributor) commented
Were there other models where we renamed `past` to `past_key_values` without changing the keyword argument name in the forward function?

@gante (Member) commented on Mar 22, 2022

> Were there other models where we renamed `past` to `past_key_values` without changing the keyword argument name in the forward function?

Perhaps; I'm going to check and will open a PR to fix it (including this one).

@cakiki I will fix the issue in a separate PR, and will ping you to rebase with main when it is sorted

@gante (Member) commented on Mar 23, 2022

@cakiki the fix is merged -- rebasing with main should fix the problems you're seeing :)

@gante (Member) commented on Mar 23, 2022

(please rerun the tests locally before merging, and confirm here that they pass)

@cakiki (Contributor, Author) commented on Mar 23, 2022

Rebasing onto main did indeed fix most of the failing tests. The following three are still failing, but they're unrelated to the previous issue.

```
=== short test summary info ===
FAILED tests/gpt2/test_modeling_tf_gpt2.py::TFGPT2ModelTest::test_gpt2_xla_generate - TypeError: function() got an unexpected keyword argument 'jit_compile'
FAILED tests/gpt2/test_modeling_tf_gpt2.py::TFGPT2ModelTest::test_onnx_runtime_optimize - ModuleNotFoundError: No module named 'onnxruntime'
FAILED tests/gpt2/test_modeling_tf_gpt2.py::TFGPT2ModelLanguageGenerationTest::test_lm_generate_gpt2_xla - TypeError: function() got an unexpected keyword argument 'jit_compile'
```

@gante (Member) commented on Mar 23, 2022

@cakiki

  • jit_compile as a flag of tf.function() was added in TF 2.5 -- can you confirm that you have TF >= 2.5? If not, can you try rerunning after updating TF to a version equal to or higher than 2.5? [added a personal TODO to throw an error if TF < 2.5; a quick check is sketched after this list]
  • for the other error, can you try rerunning the tests after reinstalling transformers with pip install -e ".[dev,onnx]"? It should be due to the onnx extra dependencies :)
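A quick local check for the first bullet (tf.function has accepted jit_compile since TF 2.5; on older versions the decorator call below fails with the same TypeError the tests report):

```python
import tensorflow as tf

print(tf.__version__)  # needs to be >= 2.5 for jit_compile

@tf.function(jit_compile=True)  # on TF < 2.5: TypeError, unexpected keyword argument
def double(x):
    return x * 2

print(double(tf.constant(21)))  # tf.Tensor(42, shape=(), dtype=int32)
```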

@cakiki (Contributor, Author) commented on Mar 23, 2022

@gante I uninstalled tensorflow, explicitly pinned it to >=2.5, and that worked. I noticed that pip install .[dev] was trying a bunch of tensorflow versions before finally settling on version 2.3 (fresh virtual env; setup.py pins it to >=2.3).

[screenshot: pip output cycling through several tensorflow versions]

test_onnx_runtime_optimize is the only test still failing:

```
E           onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. In Node, ("tfgp_t2for_sequence_classification_22/GatherV2", GatherV2, "", -1) : ("tfgp_t2for_sequence_classification_22/score/Tensordot:0": tensor(float),"tfgp_t2for_sequence_classification_22/sub:0": tensor(int32),"tfgp_t2for_sequence_classification_22/sub/y:0": tensor(int32),) -> ("logits": tensor(float),) , Error No Op registered for GatherV2 with domain_version of 10

../venv/lib/python3.6/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:370: InvalidGraph

ERROR    tf2onnx.tfonnx:tfonnx.py:303 Failed to convert node 'tfgp_t2for_sequence_classification_22/GatherV2' (fct=<bound method GatherV2.version_1 of <class 'tf2onnx.onnx_opset.tensor.GatherV2'>>)
'OP=GatherV2\nName=tfgp_t2for_sequence_classification_22/GatherV2\nInputs:\n\ttfgp_t2for_sequence_classification_22/score/Tensordot:0=Reshape, [-1, -1, 2], 1\n\ttfgp_t2for_sequence_classification_22/sub:0=Sub, [-1], 6\n\ttfgp_t2for_sequence_classification_22/GatherV2/axis:0=Const, [], 6\nOutpus:\n\ttfgp_t2for_sequence_classification_22/GatherV2:0=[-1, 2], 1'
Traceback (most recent call last):
  File "/media/ssd/BIGSCIENCE/venv/lib/python3.6/site-packages/tf2onnx/tfonnx.py", line 292, in tensorflow_onnx_mapping
    func(g, node, **kwargs, initialized_tables=initialized_tables, dequantize=dequantize)
  File "/media/ssd/BIGSCIENCE/venv/lib/python3.6/site-packages/tf2onnx/onnx_opset/tensor.py", line 444, in version_1
    utils.make_sure(node.get_attr_value("batch_dims", 0) == 0, err_msg)
  File "/media/ssd/BIGSCIENCE/venv/lib/python3.6/site-packages/tf2onnx/utils.py", line 260, in make_sure
    raise ValueError("make_sure failure: " + error_msg % args)
ValueError: make_sure failure: Opset 12 required for batch_dims attribute of GatherV2
```
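The traceback is explicit about the cause: GatherV2 with a batch_dims attribute needs ONNX opset 12, while the conversion here ran with domain_version 10. When converting manually, tf2onnx lets you request a newer opset; a sketch, where `model` and `spec` stand in for a Keras model and its input signature:

```python
import tf2onnx

# Request opset 12 so GatherV2's batch_dims attribute is representable.
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=12)
```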

@cakiki (Contributor, Author) commented on Mar 23, 2022

If it helps, here is my environment:

- `transformers` version: 4.18.0.dev0
- Platform: Linux-4.15.0-171-generic-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.6.9
- Huggingface_hub version: 0.4.0
- PyTorch version (GPU?): 1.10.1+cu102 (True)
- Tensorflow version (GPU?): 2.6.2 (False)
- Flax version (CPU?/GPU?/TPU?): 0.3.5 (cpu)
- Jax version: 0.2.17
- JaxLib version: 0.1.69

@gante (Member) commented on Mar 24, 2022

@cakiki It is settling on TF 2.3 because of python 3.6-related limitations in several packages :( We are actually having an internal discussion about potentially dropping support for python 3.6, since it is causing issues with both TF and PT (e.g. TF 2.8 requires python >= 3.7).

As for the onnx test, let's not worry about it, since it is failing on master as well.

@gante (Member) commented on Mar 24, 2022

I see that there are further errors in the tests; I will take a look (I think I know how to fix them). I will ping here again when the related fix is merged.

Hah, it seems like you got the best model to apply @unpack_inputs on :D

@gante (Member) commented on Mar 29, 2022

@cakiki I've been working on issues related to TF GPT-2, and it seems like I've also solved the errors here. I've tested locally with these changes on top of main, and all tests pass (except the onnx one, which is also failing on main).

We've recently renamed our branch from master to main, so CI won't turn green until we rebase -- which I've just pushed :)

@cakiki (Contributor, Author) commented on Mar 29, 2022

@gante Thank you for the update!

@gante (Member) commented on Mar 29, 2022

The failing test is being tracked internally -- merging

@cakiki thank you for the contribution! (and for the patience)

@gante merged commit ee18d4d into huggingface:main on Mar 29, 2022