This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Conversation

@bfineran

To accommodate custom quantization ranges, #542 added a pass that propagates quantization range information from FakeQuantize objects to their Observer objects (this propagation is otherwise missing due to a small PyTorch bug).

An unintended side effect is that, for default symmetric uint8 quantization (range [0, 255]), torch now treats the range as custom, so the zero point is calculated as (255 - 0) // 2 = 127 instead of the desired 128, which was previously hardcoded as the default.
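The off-by-one can be seen with plain integer arithmetic; a minimal sketch of the two code paths described above, assuming a symmetric range [qmin, qmax]:

```python
# Zero point for symmetric uint8 quantization over [0, 255].
qmin, qmax = 0, 255

# Custom-range path: the zero point is derived from the range bounds
# with integer floor division, landing one below the midpoint.
custom_zero_point = (qmax - qmin) // 2  # 127

# Default path: torch hardcodes the midpoint for symmetric uint8.
default_zero_point = 128

print(custom_zero_point, default_zero_point)  # 127 128
```

The one-unit difference is exactly the discrepancy this patch removes by skipping the propagation pass for default ranges.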

This change addresses the bug by skipping the propagation pass for objects with default symmetric quantization ranges.

Testing can be done using the second input of QATMatMul from our BERT integration, which should have a symmetric zero point of 128. Prior to this patch, the example below shows it as 127.

```python
import torch

from sparseml.pytorch.sparsification import QuantizationModifier
from transformers.models.bert.modeling_bert import QATMatMul

# wrap QATMatMul in a Sequential so the modifier has a layer to wrap
dummy_sequential = torch.nn.Sequential(QATMatMul())
QuantizationModifier().apply(dummy_sequential)
qat_matmul = dummy_sequential[0]

# run a forward pass so the observers record a range
_ = qat_matmul(torch.randn(10, 10), torch.randn(10, 10))

# print the affected zero point: 127 before the fix, 128 after
print(qat_matmul.input_quant_stubs[1].activation_post_process.zero_point)
```

@bfineran bfineran merged commit 11b98ff into main Mar 10, 2022
@bfineran bfineran deleted the default-symmetric-zero_point-fix branch March 10, 2022 17:07
bfineran added a commit that referenced this pull request Apr 21, 2022
* fix QATWrapper not properly overwritting qconfig properties for symmetric activations (#724)

* re-add fix symmetric zero points for unit8 quantization (#604) (#725)
markurtz added a commit that referenced this pull request May 2, 2022
* Avoid numerically unstable log (#694)

* fix QAT->Quant conversion of repeated Gemm layers with no activation QDQ (#698)

* Revert rn residual quant (#691)

* Revert ResNet definition to not quantize input to add op in residual branches.

* Correct typo.

Co-authored-by: Mark Kurtz <mark@neuralmagic.com>

* Fix: Add linebreak before 'Supplied' for better readability (#701)

* Bump notebook in /research/information_retrieval/doc2query (#679)

Bumps [notebook](http://jupyter.org) from 6.4.1 to 6.4.10.

---
updated-dependencies:
- dependency-name: notebook
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Mark Kurtz <mark@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>

* Added integration to masked_language_modeling training command (#707)

* Switch off fp16 on QAT start (#703)

* Switch off fp16 on QAT start

* address: review comments

* Disable fp16 when torch version is lesser than `1.9`

* Fix transformer prediction step (#716)

* Fix for prediction step when teacher model has more inputs than student.

* Updated signature of prediction_step method.

* Style and quality fixes.

* bump main to 0.13 (#696)

Co-authored-by: dhuang <dhuang@dhuangs-MacBook-Pro.local>

* Fix: default python log calls to debug level (#719)

* Feature/integrations (#688)

* added tutorials to root readme split by domain

* readme update

* edited text/structure

* grammar edits

* fix QATWrapper not properly overwritting qconfig properties for symmetric activations (#724)

* re-add fix symmetric zero points for unit8 quantization (#604) (#725)

* Fix 'self' and 'disable' not working for transformers distillation (#731)

* Click refactor for SparseML-PyTorch integration with Image Classification models (#711)

* Click refactor for SparseML-PyTorch integration

* Click refactor for `Pruning Sensitivity` analysis (#714)

* Click refactor for SparseML-PyTorch pr_sensitivity analysis integration

* Review comments from @KSGulin

* Click refactor for SparseML-PyTorch `lr-analysis` integration (#713)

* Click refactor for SparseML-PyTorch lr-analysis integration

* Review comments from @KSGulin

* Click refactor for SparseML PyTorch `export` integration (#712)

* Click refactor for SparseML-PyTorch export integration

* Review comments from @KSGulin

* Addressed all review comments from @bfineran, @dbogunowicz and @KSGulin

* Regenerate and Update the train-cli docstring due to changes in a few cli-args

* `nm_argparser.py` not needed anymore

* removed `nm_argparser.py` from init

* Remove All CLI args aliases and updated doctrings accordingly

* [Fix] Follow-up fix for #731 (Fix 'self' and 'disable' not working for transformers distillation) (#737)

* initial commit

* added more files and fixed quality

* Update trainer.py

* Added flag to exclude quantization of embedding activations. (#738)

* Added flag to exclude quantization of embedding activations.

* Updated testing to contemplate quantize_embedding_activations flag.

* Updated testing to contemplate quantize_embedding_activations flag.

* Updated debugging

* Revert "Updated debugging"

This reverts commit 449703d.

* Corrected order of arguments to pass assertion.

* Update src/sparseml/version.py

Co-authored-by: Eldar Kurtic <eldar.ciki@gmail.com>
Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
Co-authored-by: Alexandre Marques <alexandre@neuralmagic.com>
Co-authored-by: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: dhuangnm <74931910+dhuangnm@users.noreply.github.com>
Co-authored-by: dhuang <dhuang@dhuangs-MacBook-Pro.local>
Co-authored-by: Ricky Costa <79061523+InquestGeronimo@users.noreply.github.com>
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
