Support Quantization Aware Fine-tuning in all models (pytorch) #10639

Closed
sai-prasanna opened this issue Mar 11, 2021 · 7 comments

Comments

@sai-prasanna

🚀 Feature request

PyTorch supports mimicking quantization errors while training models (quantization-aware training).
Here is the tutorial on this. For our NLP transformers, it requires a "fake quantization" operation to be applied to the embeddings. I found this repository that converts BERT to support this.
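For reference, a minimal sketch of PyTorch's eager-mode QAT workflow on a toy module (the model, data, and hyperparameters below are placeholders for illustration, not a proposed transformers API):

```python
import torch
from torch import nn

# Toy module standing in for a transformer head (hypothetical, for illustration only).
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # fp32 -> int8 boundary
        self.linear = nn.Linear(16, 2)
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> fp32 boundary

    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))

model = TinyModel().train()
# Attach a QAT qconfig: prepare_qat swaps in FakeQuantize observers for weights/activations.
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# Fine-tune as usual; quantization noise is simulated in every forward pass.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Convert to a real int8 model for inference once fine-tuning is done.
model.eval()
quantized = torch.quantization.convert(model)
```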

Motivation

I think quantization-aware fine-tuning (if it works) will help a lot of use cases where dynamic quantization alone doesn't suffice to maintain the performance of the quantized model. Supporting it out of the box will remove the duplication of model code in end use cases.

Your contribution

I can work on this ASAP. I would appreciate initial thoughts on what the MVP for it would be, any thoughts on the API (should we take in a "qat" boolean in the config?), any pitfalls that I should be aware of, etc.

@LysandreJik
Member

Hello! Would I-BERT, available on master and contributed by @kssteven418, be of interest?

@sai-prasanna
Author

@LysandreJik, thanks for the useful reference. I gather the I-BERT model manually re-implements the architectural components (kernels, int8 layer norm, etc.) to make quantization work for BERT. If I am not wrong, their objective is to train BERT as much as possible in int8. QAT in PyTorch takes the approach of training the model fully in floating point while incorporating noise into the gradients that mimics the noise due to quantization. So it basically hands the "optimizing for quantization error" part to gradient descent, foregoing any need to alter the architecture or the fp32/fp16 training regime.

This approach would be broader and would apply to all architectures without re-implementation. Maybe we can have a "qat" flag in the config that can be used to perform fake quantization and dequantization (which introduces quantization noise into parts of the gradients); a rough sketch of that operation follows below.
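To illustrate what that fake quantize-dequantize could look like, here is a rough sketch of symmetric per-tensor fake quantization with a straight-through estimator (the function name and default bit width are just for illustration; this is not the exact Q8BERT or torch.quantization code):

```python
import torch

def fake_quantize(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Symmetric per-tensor fake quantize-dequantize with a straight-through estimator."""
    qmax = 2 ** (num_bits - 1) - 1                        # 127 for int8
    scale = x.detach().abs().max().clamp(min=1e-8) / qmax
    x_q = (x / scale).round().clamp(-qmax, qmax) * scale  # quantize, then dequantize
    # Straight-through estimator: the forward pass sees the quantized value,
    # while the backward pass treats quantization as the identity.
    return x + (x_q - x).detach()
```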

@LysandreJik
Member

Do you have an idea of the changes required for that? Could you do a PoC and show us, so that we can discuss it?

@sai-prasanna
Author

@LysandreJik Can you take a look at this implementation? It's a working QAT BERT fine-tuning implementation. The process is described in this paper, Q8BERT: Quantized 8Bit BERT.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@pie3636
Contributor

pie3636 commented May 25, 2021

This is a feature I'd like to see as well, as dynamic quantization leads to a huge accuracy drop in my use case. My understanding is that a possible implementation of QAT could also easily be expanded to support static quantization.
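For comparison, a sketch of the eager-mode post-training static quantization flow that such an implementation could be expanded toward (reusing the hypothetical TinyModel from the QAT example above; the calibration data is random purely for illustration):

```python
import torch

# Reuse the hypothetical TinyModel from the QAT example above.
model = TinyModel().eval()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)        # insert observers

# Calibrate on a few representative batches instead of fine-tuning.
with torch.no_grad():
    for _ in range(10):
        model(torch.randn(8, 16))

quantized = torch.quantization.convert(model)          # real int8 weights and ops
```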

@rohanshingade

@sai-prasanna Is it possible to load BERT-base (FP32 model) weights into Q8BERT?
