
Add QDQBert model and quantization examples of SQUAD task #14066

Merged
42 commits merged Nov 19, 2021

Conversation

@shangz-ai (Contributor) commented Oct 19, 2021

What does this PR do?

This PR includes:

  1. Adds support for a Q/DQ BERT model based on the HF BERT model
    (src/transformers/models/qdqbert/).

The QDQBERT model adds fake quantization operations (pairs of QuantizeLinear/DequantizeLinear ops) to:

  • linear layer inputs and weights
  • matmul inputs
  • residual add inputs

in the BERT model.

The QDQBERT model can be loaded from any checkpoint of an HF BERT model and can perform Quantization Aware Training/Post Training Quantization with support from the PyTorch-Quantization toolkit.
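For illustration, a minimal sketch of loading the model from an existing BERT checkpoint with default quantizer settings (this assumes pytorch-quantization is installed and uses the QDQBertModel class added by this PR; the checkpoint name is just an example):

import pytorch_quantization.nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

from transformers import QDQBertModel

# Set default quantizer behavior before instantiating the model:
# 8-bit max calibration for activations, 8-bit per-channel (axis 0) for weights.
quant_nn.QuantLinear.set_default_quant_desc_input(QuantDescriptor(num_bits=8, calib_method="max"))
quant_nn.QuantLinear.set_default_quant_desc_weight(QuantDescriptor(num_bits=8, axis=(0,)))

# Any HF BERT checkpoint can be used; fake-quant nodes are added around the BERT ops.
model = QDQBertModel.from_pretrained("bert-base-uncased")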

  2. Adds an example of the SQuAD task fine-tuned with the QDQBERT model and run with TensorRT inference
    (transformers/examples/research_projects/quantization-qdqbert/).

In the example, we use the QDQBERT model to do Quantization Aware Training from a pretrained HF BERT model on the SQuAD task. TensorRT can then run inference on the generated ONNX model for optimal INT8 performance out of the box.
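As a rough sketch of the export step (assuming pytorch-quantization's fake-quant ONNX export switch; qat_model is a placeholder for the QAT-finetuned QDQBERT question-answering model, and the input/output names are illustrative):

import torch
import pytorch_quantization.nn as quant_nn

# Export fake-quant ops as QuantizeLinear/DequantizeLinear ONNX nodes so TensorRT can build INT8 engines.
quant_nn.TensorQuantizer.use_fb_fake_quant = True

qat_model.eval()
qat_model.config.return_dict = False  # export plain tensors instead of a ModelOutput

input_ids = torch.ones(1, 128, dtype=torch.long)
attention_mask = torch.ones(1, 128, dtype=torch.long)
token_type_ids = torch.zeros(1, 128, dtype=torch.long)

torch.onnx.export(
    qat_model,
    (input_ids, attention_mask, token_type_ids),
    "qdqbert_squad.onnx",
    opset_version=13,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["start_logits", "end_logits"],
)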

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.

A related discussion on this topic: issue #10639.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@shangz-ai marked this pull request as ready for review October 19, 2021 17:23
@shangz-ai (Contributor, Author):

@LysandreJik @sgugger Thanks!

@shangz-ai (Contributor, Author):

Some CIs failed since the QDQBERT model needs the PyTorch Quantization Toolkit as a dependency (https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization). The dependency can be installed with a simple one-line command:
pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com

I'm thinking of either adding the one-line installation to CI or adding the quantization toolkit installation to the transformers installation (or any other suggestion that is smooth and neat for the HF community) if we want to upstream the model. @LysandreJik @sgugger

Thanks!

@sgugger (Collaborator) left a comment

Sorry for the long time without review, I had somehow missed the notification for this PR.

Thanks a lot for all your work! The new example you add is very cool! There is a little bit of work to polish the PR before we can merge it but this is already in great shape! I have left a few comments, the two main things are:

  • don't define a new tokenizer class since we can re-use the existing BERT tokenizers (I know some old models have subclasses but we don't do this anymore)
  • use "Copied from" statements in your modeling file to keep copies from the BERT modeling files up to date (you can see examples in the RoBERTa modeling file for instance)

One last point is the new dependency added, as you mentioned. Since one cannot import the model without it, you should add a new function that checks whether the necessary modules are installed (see for instance is_scatter_available used for the TAPAS model) and you should only conditionally import in the main init, like it is done for TAPAS. Running make fix-copies will then create the appropriate file with dummy classes so we can still import something called QDQBertModel in the init when the dependency is not there.
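A rough sketch of what such an availability check can look like (the helper name is_pytorch_quantization_available and its exact location in transformers are illustrative here):

import importlib.util

def is_pytorch_quantization_available():
    # True only if the optional pytorch-quantization toolkit is installed;
    # the QDQBERT modeling code is imported in the main init only when this
    # returns True, otherwise dummy classes take its place.
    return importlib.util.find_spec("pytorch_quantization") is not None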

And for proper testing and docs, the pip install line you mention should probably be added to the config.yml file that drives CircleCI.

Let us know if you need help with any of those steps!

Resolved review threads (outdated): docs/source/model_doc/qdqbert.rst (5), src/transformers/models/qdqbert/modeling_qdqbert.py (1)
Collaborator comment on the new tokenization file (diff @@ -0,0 +1,54 @@):
You don't need to define a new tokenizer if you just re-use the BERT one. In the tokenization auto module, just set a line in the auto mapping that maps "qdqbert" to ("BertTokenizer", "BertTokenizerFast") (there is already an example with ibert using the RoBERTa tokenizer, for instance).
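For illustration, the auto-mapping entry would look roughly like this (a sketch of the pattern in src/transformers/models/auto/tokenization_auto.py; the surrounding structure may differ slightly between versions):

from collections import OrderedDict

TOKENIZER_MAPPING_NAMES = OrderedDict(
    [
        # ... other model types ...
        ("qdqbert", ("BertTokenizer", "BertTokenizerFast")),
        # ... other model types ...
    ]
)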

@shangz-ai (Contributor, Author):
addressed in commit: 873d352

@shangz-ai (Contributor, Author) commented Nov 5, 2021:

As for the CI failure of check_code_quality: import pycuda.autoinit is needed, even if not used, in order to initialize the CUDA environment. Any suggestions for resolving this?

For the other two check failures, I'm not sure what the root cause is. I'd be glad to get insights on how to fix them.

@LysandreJik (Member) left a comment

This is in really good shape, thanks a lot for all the effort! Before merging, I'd like to understand if we can't make the code examples in the docstrings a bit clearer by leveraging the QuantDescriptor as you use in your integration test. Doing so would significantly reduce friction for users trying to use the model.

Thank you!

Resolved review threads (outdated): src/transformers/__init__.py (1), src/transformers/models/qdqbert/configuration_qdqbert.py (4)
return hidden_states


class QDQBertAttention(nn.Module):
Member:
That sounds good, thank you @shangz-ai!

Comment on lines 519 to 535
# Override
def test_feed_forward_chunking(self):
pass
Member:
Could you mention in the comment why it is overridden?

@shangz-ai (Contributor, Author) replied Nov 10, 2021:
It should be due to the fact that when using feed_forward_chunking, the tensors change shape because of the chunking. Quantizing those tensors per channel/tensor then changes the scaling factors during calibration, so one cannot obtain identical results before and after chunking.

As a result, I will also remove the chunk_size_feed_forward feature in QDQBert, as IBert did. Does that make sense? Addressed in f5188b7

Member:
Yes that makes sense, thank you!

Comment on lines 530 to 549
import pytorch_quantization.nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

model = QDQBertForMaskedLM.from_pretrained("bert-base-uncased")
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])

input_desc = QuantDescriptor(num_bits=8, calib_method="max")
weight_desc = QuantDescriptor(num_bits=8, axis=((0,)))
quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)
quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)

output = model(input_ids)[0]

expected_shape = torch.Size((1, 11, 768))
self.assertEqual(output.shape, expected_shape)

expected_slice = torch.tensor(
[[[-0.0483, 0.1188, -0.0313], [-0.0606, 0.1435, 0.0199], [-0.0235, 0.1519, 0.0175]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
Member:
If this approach is the canonical approach to using the model, I would favor updating the model docstring examples to reflect that, alongside with inline comments explaining in a few words what is happening. I think this would go a long way to improve usability of the model. Thank you!

@shangz-ai (Contributor, Author):
addressed in 69da5d6

@LysandreJik (Member):

Thanks for working on this! It seems the code quality is not yet passing, could you run the quality scripts? You can do so with the following, from the root of your clone:

pip install -e ".[quality]"
make fixup

@shangz-ai (Contributor, Author) commented Nov 11, 2021:

> Thanks for working on this! It seems the code quality is not yet passing, could you run the quality scripts? You can do so with the following, from the root of your clone:
>
> pip install -e ".[quality]"
> make fixup

Thanks for the comments! This is actually somewhere I want to check.

The code quality failure is from the TensorRT inference script's import pycuda.autoinit. This line of code is needed, but not used, to initialize the CUDA environment. Is there a way I can keep this line in the script and still pass the code quality test? @patrickvonplaten @LysandreJik @sgugger

@shangz-ai changed the title from "Add QDQBert model and QAT example of SQUAD task" to "Add QDQBert model and quantization examples of SQUAD task" on Nov 12, 2021
@patrickvonplaten (Contributor) left a comment

Very clean PR! Thanks for adding the model

@patrickvonplaten (Contributor):

Think there is just one small clean-up left to do:

examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py:29:1: F401 'pycuda.autoinit' imported but unused

@patrickvonplaten (Contributor):

Think there is another line to clean up :-) examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py:29:1: F401 'pycuda.autoinit' imported but unused

@shangz-ai (Contributor, Author):

> Think there is another line to clean up :-) examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py:29:1: F401 'pycuda.autoinit' imported but unused

@patrickvonplaten Is there a workaround for it? pycuda.autoinit is imported to set up the CUDA environment, so it is needed in the script. Thanks!

@sgugger (Collaborator) commented Nov 18, 2021:

You can add a comment at the end of the import line, # noqa: F401, to have it ignored by our styler. To check locally whether the test will pass, just run make quality.
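Concretely, the import line in evaluate-hf-trt-qa.py would look something like this (a sketch; the explanatory comment wording is illustrative):

# pycuda.autoinit is imported only for its side effect of initializing the CUDA context;
# the noqa marker tells flake8 to ignore the "imported but unused" (F401) warning.
import pycuda.autoinit  # noqa: F401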

Note that with the merge of #14431, you will need to rebase your PR on master and replace the lines

        self.init_weights()

by

        # Initialize weights and apply final processing
        self.post_init()

Let us know if you need any help!

@shangz-ai (Contributor, Author):

I rebased the PR, but I'm not sure why the model templates runner CI is failing now.

@@ -571,8 +597,6 @@
_import_structure["generation_utils"] = ["top_k_top_p_filtering"]
_import_structure["modeling_outputs"] = []
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]

# PyTorch models structure
Member:
The templates test is failing because this line was deleted!

@sgugger (Collaborator) commented Nov 19, 2021:

Thanks again for all your work on this!

@sgugger merged commit a59e7c1 into huggingface:master Nov 19, 2021
@shangz-ai deleted the add-qdqbert-model branch November 19, 2021 18:58
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
…e#14066)

* clean up branch for add-qdqbert-model

* README update for QAT example; update docstrings in modeling_qdqbert.py

* Update qdqbert.rst

* Update README.md

* Update README.md

* calibration data using training set; QAT example runs in fp32

* re-use BERTtokenizer for qdqbert

* Update docs/source/model_doc/qdqbert.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update docs/source/model_doc/qdqbert.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update docs/source/model_doc/qdqbert.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove qdqbert tokenizer

* Update qdqbert.rst

* update evaluate-hf-trt-qa.py

* update configuration_qdqbert.py

* update modeling_qdqbert.py: add copied statement; replace assert with ValueError

* update copied from statement

* add is_quantization_available; run make fix-copies

* unittest add require_quantization

* add backend dependency to qdqbert model

* update README; update evaluate script; make style

* lint

* docs qdqbert update

* circleci build_doc add pytorch-quantization for qdqbert

* update README

* update example readme with instructions to upgrade TensorRT to 8.2

* Update src/transformers/models/qdqbert/configuration_qdqbert.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/qdqbert/configuration_qdqbert.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/qdqbert/configuration_qdqbert.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/qdqbert/configuration_qdqbert.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* change quantization to pytorch_quantization for backend requirement

* feed_forward_chunking not supported in QDQBert

* make style

* update model docstrings and comments in testing scripts

* rename example to quantization-qdqbert; rename example scripts from qat to quant

* Update src/transformers/models/qdqbert/modeling_qdqbert.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* rm experimental functions in quant_trainer

* qa cleanup

* make fix-copies for docs index.rst

* fix doctree; use post_init() for qdqbert

* fix early device assignment for qdqbert

* fix CI:Model templates runner

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>