[ao] updating embedding_bag support for fx and eager
Summary: our docs said dynamic embedding bag quantization wasn't supported, but it actually is (at least at the same level as embeddings); it just wasn't previously tested or listed.
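
For reference, the eager-mode flow the new test exercises looks roughly like the sketch below. The constructor arguments, example inputs, and offsets are illustrative rather than copied from the test, and a quantized engine such as fbgemm must be available (the test itself is gated by skipIfNoFBGEMM):

    import torch
    from torch.ao.quantization import (
        default_dynamic_qconfig,
        float_qparams_weight_only_qconfig,
        quantize_dynamic,
    )

    class EmbeddingBagWithLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # mode='sum' and include_last_offset=True are illustrative
            # assumptions for this sketch.
            self.emb = torch.nn.EmbeddingBag(
                num_embeddings=10, embedding_dim=12,
                mode='sum', include_last_offset=True)
            self.fc = torch.nn.Linear(5, 5)

        def forward(self, indices, offsets, linear_in):
            return self.emb(indices, offsets), self.fc(linear_in)

    model = EmbeddingBagWithLinear().eval()

    # Weight-only quantization for the embedding bag, dynamic quantization
    # for the linear layer, keyed by module type.
    qconfig_dict = {
        torch.nn.EmbeddingBag: float_qparams_weight_only_qconfig,
        torch.nn.Linear: default_dynamic_qconfig,
    }
    q_model = quantize_dynamic(model, qconfig_dict)

    indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6])
    offsets = torch.tensor([0, 5, 10])  # two bags; last offset == len(indices)
    print(q_model(indices, offsets, torch.randn(5, 5)))

After conversion, q_model.emb is a QuantizedEmbeddingBag and q_model.fc is a DynamicQuantizedLinear, which is what the updated assertions in the eager test below check.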

Test Plan: python test/test_quantization.py -k "test_embedding"

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 829622fa5a662f8cb1bab42bcaf463cf234288e8
Pull Request resolved: pytorch/pytorch#107623
HDCharles committed Nov 20, 2023
1 parent d4189d8 commit b4d79f6
Showing 3 changed files with 38 additions and 6 deletions.
5 changes: 2 additions & 3 deletions docs/source/quantization.rst
@@ -115,7 +115,6 @@ for a more comprehensive overview of the tradeoffs between these quantization
types.

Operator coverage varies between dynamic and static quantization and is captured in the table below.
Note that for FX quantization, the corresponding functionals are also supported.

+---------------------------+-------------------+--------------------+
| |Static | Dynamic |
@@ -135,7 +134,7 @@ Note that for FX quantization, the corresponding functionals are also supported.
|nn.EmbeddingBag | Y (activations | |
| | are in fp32) | Y |
+---------------------------+-------------------+--------------------+
|nn.Embedding | Y | N |
|nn.Embedding | Y | Y |
+---------------------------+-------------------+--------------------+
| nn.MultiheadAttention | Y (through | Not supported |
| | custom modules) | |
@@ -881,7 +880,7 @@ Note that for FX Graph Mode Quantization, the corresponding functionals are also
|nn.EmbeddingBag | Y (activations | |
| | are in fp32) | Y |
+---------------------------+-------------------+--------------------+
|nn.Embedding | Y | N |
|nn.Embedding | Y | Y |
+---------------------------+-------------------+--------------------+
|nn.MultiheadAttention |Not Supported | Not supported |
+---------------------------+-------------------+--------------------+
27 changes: 24 additions & 3 deletions test/quantization/eager/test_quantize_eager_ptq.py
@@ -1475,7 +1475,7 @@ def checkHooksIsPresent(model):
checkHooksIsPresent(model)

@skipIfNoFBGEMM
def test_embedding_ops_dynamic(self):
def test_embedding_bag_dynamic(self):
class EmbeddingBagWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1496,9 +1496,30 @@ def forward(self, indices, offsets, linear_in):
q_model = quantize_dynamic(model, qconfig_dict)

q_model(indices, offsets, torch.randn(5, 5))
self.assertTrue('QuantizedEmbedding' in str(q_model))
self.assertTrue('DynamicQuantizedLinear' in str(q_model))
self.assertTrue('QuantizedEmbeddingBag' in str(q_model.emb))
self.assertTrue('DynamicQuantizedLinear' in str(q_model.fc))

@skipIfNoFBGEMM
def test_embedding_ops_dynamic(self):
class EmbeddingWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = torch.nn.Embedding(
num_embeddings=10, embedding_dim=12, scale_grad_by_freq=False)
self.fc = torch.nn.Linear(5, 5)

def forward(self, indices, linear_in):
return self.emb(indices), self.fc(linear_in)
model = EmbeddingWithLinear().eval()
qconfig_dict = {
torch.nn.Embedding : float_qparams_weight_only_qconfig,
torch.nn.Linear: default_dynamic_qconfig
}
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
q_model = quantize_dynamic(model, qconfig_dict)
self.assertTrue('QuantizedEmbedding' in str(q_model.emb))
self.assertTrue('DynamicQuantizedLinear' in str(q_model.fc))
q_model(indices, torch.randn(5, 5))

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
12 changes: 12 additions & 0 deletions test/quantization/fx/test_quantize_fx.py
@@ -8636,12 +8636,24 @@ def forward(self, indices):
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
example_inputs = (indices,)
quantized_node = ns.call_module(nnq.Embedding)

# check dynamic quant
self.checkGraphModeFxOp(
model,
example_inputs,
QuantType.DYNAMIC,
quantized_node,
custom_qconfig_dict={"": qconfig_type}
)
model = M().eval()

configs = [
(qconfig_type, ns.call_module(nnq.Embedding)),
(None, ns.call_module(nn.Embedding)),
(default_qconfig, ns.call_module(nn.Embedding)),
]

# check static quantization
for qconfig, node in configs:
qconfig_dict = {"": qconfig}
m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
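
Outside the test harness (checkGraphModeFxOp is a test helper), a standalone FX graph mode version of the dynamic-quant check might look like the sketch below; the module definition and the global qconfig mapping are illustrative assumptions, with qconfig_type in the truncated test assumed to be the weight-only embedding qconfig:

    import torch
    from torch.ao.quantization import float_qparams_weight_only_qconfig
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    model = M().eval()
    indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8,
                            3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
    example_inputs = (indices,)

    # The empty-string key applies the weight-only qconfig globally, as in
    # the test's custom_qconfig_dict.
    qconfig_dict = {"": float_qparams_weight_only_qconfig}
    prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
    quantized = convert_fx(prepared)

    print(quantized)           # the embedding is swapped for its quantized counterpart
    print(quantized(indices))  # fp32 output from the weight-only quantized embedding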
