[Quantization] Quantization API #309

Merged: 31 commits into hidet-org:main on Jul 17, 2023
Conversation

@Aalanli (Collaborator) commented Jul 11, 2023

Add an extensible quantization API.
See examples/quantization/gpt2.py for a usage example.

On gpt2, using the first 500 examples of the wikitext-2-raw-v1 test split:
original f32 ppl: 129.88427568662286
original f32 acc: [top-1: 0.291, top-5: 0.486, top-10: 0.561]

quantized f16 ppl: 131.41456528937462
quantized f16 acc: [top-1: 0.288, top-5: 0.482, top-10: 0.556]

quantized f16 -> int8 ppl: 131.11489348364347
quantized f16 -> int8 acc: [top-1: 0.284, top-5: 0.481, top-10: 0.554]

Currently supported:

  • Symmetric weight quantization (a minimal sketch of the scheme follows this list)
  • Automatic quantization of linear layers
  • Automatic quantization of embedding layers
  • Custom symmetric quantized weight kernel for int8
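For reference, here is a minimal numpy sketch of the symmetric (absmax) quantization arithmetic the list above refers to. It is illustrative only: the actual implementation lives in the hidet ops (ops.symmetric_quantize / ops.symmetric_dequantize), and the per-channel scale shown here is an assumption about the dims=[-1] convention used in the PR.

import numpy as np

def symmetric_quantize(w: np.ndarray, axis: int = -1):
    # One scale per channel: scale = max(|w|) / 127 along `axis`.
    amax = np.max(np.abs(w), axis=axis, keepdims=True)
    scale = np.maximum(amax, 1e-8) / 127.0
    qw = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return qw, scale

def symmetric_dequantize(qw: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # The symmetric scheme has no zero point: dequantization is just a rescale.
    return qw.astype(np.float32) * scale

# Round-trip error is bounded by half a quantization step per element.
w = np.random.randn(768, 3072).astype(np.float32)
qw, scale = symmetric_quantize(w)
assert np.all(np.abs(symmetric_dequantize(qw, scale) - w) <= scale / 2 + 1e-6)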

@yaoyaoding (Member) left a comment:

Thanks @Aalanli! Good progress!

I left some comments on minor issues.

Comment on lines 66 to 83
class SymQuantLinearTransposed(Module):
    def __init__(self, weight: Tensor, bias: Optional[Tensor] = None, quant_type: str = 'int8'):
        super().__init__()
        self.in_features = weight.shape[0]
        self.out_features = weight.shape[1]
        # Quantize the weight once at construction time; keep the int8 weight and its scale.
        qweight, scale = ops.symmetric_quantize(weight, quant_type=quant_type, dims=[-1])
        self.qweight = qweight
        self.scale = scale
        self.bias = bias

    def extra_str(self) -> str:
        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)

    def forward(self, x: Tensor) -> Tensor:
        # Dequantize the stored int8 weight and run the matmul in floating point.
        x = ops.matmul(x, ops.symmetric_dequantize(ops.barrier(self.qweight), self.scale, dims=[-1]))
        if self.bias is not None:
            x = ops.add(x, self.bias)
        return x

@yaoyaoding (Member):

Maybe we can also put all the quantization nn layers into a sub-namespace like hidet.graph.nn.quantized (as torch did with torch.nn.quantized) or hidet.graph.nn.quant (matching what you used in ops).

@Aalanli (Collaborator, Author):

Right, I think this module is not currently needed, since quantization is applied during a graph pass anyway, and the copying mechanisms won't work here when converting from torch.
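For context, a rough sketch of the computation the graph-pass rewrite produces for a float linear layer, written with the ops that appear in the excerpt above; the pass registration and pattern-matching machinery are omitted, the import paths are assumptions, and the function name is purely illustrative:

from typing import Optional
from hidet.graph import ops
from hidet.graph.tensor import Tensor

def quantized_linear(x: Tensor, w: Tensor, b: Optional[Tensor] = None) -> Tensor:
    # Quantize the weight per output column (dims=[-1]), as in SymQuantLinearTransposed.
    qw, scale = ops.symmetric_quantize(w, quant_type='int8', dims=[-1])
    # barrier() presumably keeps later passes from constant-folding the dequantize
    # back into an f32 weight, so the int8 tensor is what ends up stored in the graph.
    y = ops.matmul(x, ops.symmetric_dequantize(ops.barrier(qw), scale, dims=[-1]))
    if b is not None:
        y = ops.add(y, b)
    return y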

@@ -15,7 +15,7 @@
from hidet.ir.compute import reduce

@yaoyaoding (Member):

@xinli-git, could you take a look at the change to norm? Thanks!

In the future, let's try to unify the schedule templates for different data types, which will reduce maintenance complexity.

@Aalanli (Collaborator, Author):

Sorry, this change to norm is exactly the same as the earlier one; I needed to apply the same fix for some tests to pass.

Other review threads (outdated, resolved) on:

  • python/hidet/graph/ops/quant/__init__.py
  • python/hidet/graph/transforms/base.py (two threads)
  • python/hidet/graph/transforms/graph_patterns/base.py (two threads)
  • python/hidet/ir/primitives/cuda/mma.py
@yaoyaoding (Member):

Hi @Aalanli,

I forgot one thing. It is recommended to put some of the code in examples/quantization into the tests, so that we can be sure that future changes elsewhere will not break quantization support.
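For instance, a small round-trip test along these lines could live in the test suite; the helpers used here (hidet.asarray, Tensor.numpy) and the per-row reading of dims=[-1] are assumptions rather than what was actually committed, and the op may require a GPU device depending on how the kernels are registered:

import numpy as np
import hidet
from hidet.graph import ops

def test_symmetric_int8_roundtrip():
    # Hypothetical regression test: quantize a weight, dequantize it, and check
    # that the round-trip error stays within one quantization step per element.
    w = np.random.randn(64, 64).astype(np.float32)
    qw, scale = ops.symmetric_quantize(hidet.asarray(w), quant_type='int8', dims=[-1])
    w_hat = ops.symmetric_dequantize(qw, scale, dims=[-1]).numpy()
    step = np.abs(w).max(axis=-1, keepdims=True) / 127.0  # assumed scale along the last dim
    assert np.all(np.abs(w - w_hat) <= step + 1e-6)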

@Aalanli merged commit e3b01bb into hidet-org:main on Jul 17, 2023; 2 checks passed.
@Aalanli deleted the quant-static-matmul branch on Jul 17, 2023 at 19:01.