Support Quantization Aware Fine-tuning in all models (pytorch) #10639
Comments
Hello! Would I-BERT, available on …
@LysandreJik, thanks for the useful reference. As I understand it, the I-BERT model manually implements the architectural components (kernels, int8 layer norm, etc.) needed to make quantization work for BERT; if I'm not wrong, the objective is to run as much of BERT as possible in int8. QAT in PyTorch takes a different approach: the model is trained entirely in floating point, but noise that mimics quantization error is injected into the forward pass and therefore shows up in the gradients. It essentially hands the "optimizing for quantization error" part to gradient descent, foregoing any need to alter architectures or the fp32/fp16 training regime. This approach would be broader and apply to all architectures without re-implementation. Maybe we can have a "qat" flag in the config that enables fake quantization and dequantization (which introduces quantization noise into parts of the gradients).
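To make the idea concrete, here is a minimal sketch of how such a "qat" path could lean on PyTorch's stock eager-mode QAT utilities (`get_default_qat_qconfig`, `prepare_qat`, `convert`). None of this exists in transformers today; the model name is just an example, and a real flow would also need `QuantStub`/`DeQuantStub` placements plus handling for modules without QAT mappings (embeddings, layer norm).

```python
import torch
from transformers import BertForSequenceClassification

# The model stays fp32; prepare_qat() inserts FakeQuantize modules so the
# forward pass (and hence the gradients) sees quantization noise.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.train()

# Attach a QAT qconfig and swap supported modules (e.g. nn.Linear) for
# their fake-quantized QAT counterparts.
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# ... run the usual fine-tuning loop here ...

# After fine-tuning, convert the fake-quantized model to a real int8 model.
model.eval()
int8_model = torch.quantization.convert(model)
```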
Do you have an idea of the changes required for that? Could you do a PoC and show us so that we can discuss it?
@LysandreJik Can you take a look at this implementation? It's a functioning QAT-aware BERT fine-tuning implementation. The process is described in this paper: Q8BERT: Quantized 8Bit BERT.
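For readers who don't want to dig through the repository, here is a small sketch of the core Q8BERT-style technique as described in the paper: symmetric linear 8-bit fake quantization with a straight-through estimator so the rounding step doesn't block gradients. The class names are made up for illustration; this is not code from the linked implementation.

```python
import torch

class FakeQuantSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, num_bits=8):
        # Symmetric linear quantization: scale by the max absolute value.
        qmax = 2 ** (num_bits - 1) - 1                      # 127 for int8
        scale = x.detach().abs().max().clamp(min=1e-8) / qmax
        return torch.clamp(torch.round(x / scale), -qmax, qmax) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        return grad_output, None

class QuantLinear(torch.nn.Linear):
    """nn.Linear whose weights and inputs are fake-quantized in forward."""
    def forward(self, x):
        w_q = FakeQuantSTE.apply(self.weight)
        x_q = FakeQuantSTE.apply(x)
        return torch.nn.functional.linear(x_q, w_q, self.bias)
```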
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This is a feature I'd like to see as well, as dynamic quantization leads to a huge accuracy drop in my use case. My understanding is that a possible implementation of QAT could also easily be expanded to support static quantization.
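For context, the dynamic-quantization baseline referred to here is roughly the following (a sketch using the stock PyTorch API; the model name is just an example). Only `nn.Linear` weights are quantized ahead of time, activations are quantized on the fly, and there is no fine-tuning, which is why accuracy can drop.

```python
import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased").eval()

# Quantize nn.Linear weights to int8; activations are quantized dynamically
# at inference time, with no retraining to compensate for the error.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```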
@sai-prasanna Is it possible to load BERT-base (FP32) weights into Q8BERT?
🚀 Feature request
PyTorch supports mimicking quantization errors while training models.
Here is the tutorial on this. For our NLP transformers, it requires a "fake quantization" operation to be applied to the embeddings. I found this repository that converts BERT to support this.
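As a rough illustration of that "fake quantization on the embeddings" step, here is a sketch that rounds the fp32 embedding output through an int8 grid so downstream layers (and gradients) see quantization noise. The scale here is a crude hand-computed placeholder rather than a learned observer, and the model/tokenizer names are just examples.

```python
import torch
from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Quantization aware training", return_tensors="pt")

emb = model.embeddings(inputs["input_ids"])            # fp32 embeddings
scale = emb.detach().abs().max().item() / 127.0        # crude per-tensor scale

# Quantize to int8 and immediately dequantize back to fp32 ("fake" quantization).
fake_quant_emb = torch.fake_quantize_per_tensor_affine(emb, scale, 0, -128, 127)
```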
Motivation
I think quantization-aware fine-tuning (if it works) will help many use cases where dynamic quantization alone doesn't suffice to maintain the performance of the quantized model. Supporting it out of the box will remove the duplication of model code in end use cases.
Your contribution
I can work on this ASAP. I would appreciate initial thoughts on what the MVP for it would be, any thoughts on the API (should we take a "qat" boolean in the config?), and any pitfalls I should be aware of.