[Quantization] Quantization API #309
Conversation
Thanks @Aalanli ! Good progress!
I left some comments on the minor issues.
python/hidet/graph/nn/linear.py
Outdated
class SymQuantLinearTransposed(Module):
    def __init__(self, weight: Tensor, bias: Optional[Tensor] = None, quant_type: str = 'int8'):
        super().__init__()
        self.in_features = weight.shape[0]
        self.out_features = weight.shape[1]
        qweight, scale = ops.symmetric_quantize(weight, quant_type=quant_type, dims=[-1])
        self.qweight = qweight
        self.scale = scale
        self.bias = bias

    def extra_str(self) -> str:
        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)

    def forward(self, x: Tensor) -> Tensor:
        x = ops.matmul(x, ops.symmetric_dequantize(ops.barrier(self.qweight), self.scale, dims=[-1]))
        if self.bias is not None:
            x = ops.add(x, self.bias)
        return x
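For readers following the review: `ops.symmetric_quantize` / `ops.symmetric_dequantize` with `dims=[-1]` are assumed here to implement the standard symmetric int8 scheme, where a per-axis absmax scale maps each slice onto the range ±127. A minimal NumPy sketch of that scheme (function names are illustrative, not the hidet API):

```python
import numpy as np

def symmetric_quantize(w: np.ndarray, axis: int = -1):
    # Per-axis scale: map the largest magnitude along `axis` to the int8 limit 127.
    scale = np.abs(w).max(axis=axis, keepdims=True) / 127.0
    qw = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return qw, scale

def symmetric_dequantize(qw: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Recover an approximation of the original weights.
    return qw.astype(np.float32) * scale

w = np.random.randn(4, 8).astype(np.float32)
qw, scale = symmetric_quantize(w)
w_hat = symmetric_dequantize(qw, scale)
# Round-trip error is bounded by half a quantization step per element.
assert np.all(np.abs(w - w_hat) <= scale / 2 + 1e-6)
```

Symmetric quantization needs no zero-point, which is why the module above only stores `qweight` and `scale`.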
Maybe we can also put all the quantization nn layers into a sub-namespace like hidet.graph.nn.quantized (like torch people used torch.nn.quantized) or hidet.graph.nn.quant (what you used in ops).
Right, I think this module is currently not needed, since quantization is applied during a graph pass anyway. Also, the copying mechanisms won't work here when converting from torch.
@@ -15,7 +15,7 @@
from hidet.ir.compute import reduce
@xinli-git, could you help to have a look at the change of norm? Thanks!
In the future, let's try to unify the schedule template for different data types, which will reduce the complexity of maintenance.
Sorry, this change in norm is exactly the same as the earlier one; I needed to apply the same fix for some tests to pass.
Hi @Aalanli, I forgot one thing. It is recommended to put some of the code in the
Add extensible quantization API.
See examples/quantization/gpt2.py for usage example.
On gpt2 with the first 500 examples of the wikitext-2-raw-v1 test split:
original f32 ppl: 129.88427568662286
original f32 acc: [top-1: 0.291, top-5: 0.486, top-10: 0.561]
quantized f16 ppl: 131.41456528937462
quantized f16 acc: [top-1: 0.288, top-5: 0.482, top-10: 0.556]
quantized f16 -> int8 ppl: 131.11489348364347
quantized f16 -> int8 acc: [top-1: 0.284, top-5: 0.481, top-10: 0.554]
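For context on the numbers above: perplexity is the exponentiated mean per-token negative log-likelihood, so the f32 → int8 gap (129.88 → 131.11) corresponds to a shift of under 0.01 nats in mean per-token loss. A minimal sketch of the metric itself, independent of the hidet API:

```python
import math

def perplexity(nlls):
    """Perplexity = exp of the mean per-token negative log-likelihood."""
    return math.exp(sum(nlls) / len(nlls))

# Sanity check: a constant per-token NLL of ln(p) gives perplexity p.
assert abs(perplexity([math.log(129.88)] * 4) - 129.88) < 1e-9
```

This is why the small quantized-vs-original ppl difference lines up with the nearly unchanged top-k accuracies.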
Currently supported: