
Conversation

@bfineran (Contributor) commented on Sep 3, 2021

Embedding layers account for a large share of the storage requirements of transformer models. This PR introduces a pathway to quantize embedding values; by default, this behavior will be enabled by the QuantizationModifier.

Python Example

>>> import torch
>>> from sparseml.pytorch.optim import QuantizationModifier
>>> module = torch.nn.Embedding(100, 100)
>>> QuantizationModifier().apply(module)
>>> module
Embedding(
  100, 100
  (activation_post_process): FakeQuantize(
    fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),
    quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1,
    scale=tensor([1.]), zero_point=tensor([0])
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
  (weight_fake_quant): FakeQuantize(
    fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),
    quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_affine, ch_axis=-1,
    scale=tensor([1.]), zero_point=tensor([0])
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
)
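
The ONNX graphs shown below come from exporting the module before and after applying the modifier. The snippet that follows is only a rough sketch of such an export and is not part of this PR: the dummy input shape, output file name, and opset version are illustrative assumptions (SparseML also provides its own export utilities, which is what produced the screenshots below).

>>> # Sketch only: export the fake-quantized embedding to ONNX for inspection.
>>> # Input shape, file name, and opset are assumptions, not values from this PR.
>>> dummy_input = torch.randint(0, 100, (1, 16))  # a batch of token ids in [0, 100)
>>> torch.onnx.export(module, dummy_input, "embedding_qat.onnx", opset_version=11)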

FP32 ONNX Graph

[Screenshot: FP32 ONNX graph]

QAT ONNX Graph

[Screenshot: QAT ONNX graph]

Quant ONNX Graph

[Screenshot: quantized ONNX graph]

@bfineran bfineran self-assigned this Sep 3, 2021
@markurtz (Member) left a comment


Overall looks good; a few small comments.

@rahul-tuli (Member) left a comment


@bfineran bfineran merged commit 0b55a06 into main Sep 8, 2021
@bfineran bfineran deleted the qat-embeddings branch September 8, 2021 15:45
bfineran added a commit that referenced this pull request Sep 8, 2021
* QAT and quant postprocessing for torch.nn.Embedding

* cleanup

* residual optim and logging fixes

* response to comments