[quant][pt2e] Add reference representation for dynamic quantized linear (pytorch#108073)

Summary:
As titled: add a reference (decomposed integer) representation for dynamically
quantized linear to the PT2E representation rewrite. This adds the
_qdq_dynamic_quantized_linear match pattern, the _reference_dynamic_quantized_linear
replacement, and a corresponding _RewriteInfo entry in _REWRITE_INFO_LIST. The
_test_representation helper gains a fixed_output_tol argument, since the dynamic
linear output is fp32 and there is no output quantize node to derive a tolerance from.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_representation_dynamic_linear
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- 'test_representation_dynamic_linear'
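
For context, a minimal sketch of the end-to-end flow the new test drives (not part
of this diff; the capture_pre_autograd_graph entry point and the
use_reference_representation flag are assumed from the PT2E workflow at the time of
this change):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 5),)
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
)

m = capture_pre_autograd_graph(M().eval(), example_inputs)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration run (dynamic quantization computes qparams at runtime)
# use_reference_representation=True applies the rewrite added in rewrite.py below,
# replacing the q/dq pattern with the decomposed integer representation
m = convert_pt2e(m, use_reference_representation=True)
print(m(*example_inputs))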

Reviewed By: kimishpatel

Differential Revision: D48703076

Pull Request resolved: pytorch#108073
Approved by: https://github.com/andrewor14
jerryzh168 committed Aug 30, 2023
1 parent c9cbdaf commit 8308b69
Showing 2 changed files with 124 additions and 8 deletions.
45 changes: 37 additions & 8 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -513,6 +513,7 @@ def _test_representation(
quantizer: Quantizer,
ref_node_occurrence: Dict[ns, int],
non_ref_node_occurrence: Dict[ns, int],
fixed_output_tol: float = None,
output_scale_idx: int = 3,
) -> torch.nn.Module:
""" TODO: need to implement output checking based on output_scale once
@@ -542,17 +543,22 @@ def _test_representation(
self.checkGraphModuleNodes(model_copy, expected_node_occurrence=non_ref_node_occurrence)
pt2e_quant_output_copy = model_copy(*example_inputs)

idx = 0
for n in model_copy.graph.nodes:
if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
idx += 1
if idx == output_scale_idx:
output_scale = n.args[1]
assert output_scale is not None

output_tol = None
if fixed_output_tol is not None:
output_tol = fixed_output_tol
else:
idx = 0
for n in model_copy.graph.nodes:
if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
idx += 1
if idx == output_scale_idx:
output_tol = n.args[1]
assert output_tol is not None

# make sure the result is off by one at most in the quantized integer representation
self.assertTrue(
torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_scale + 1e-5)
torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_tol + 1e-5)
)

@skipIfNoQNNPACK
Expand Down Expand Up @@ -2201,6 +2207,29 @@ def forward(self, x):
non_ref_node_occurrence={}
)

def test_representation_dynamic_linear(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(5, 5)

def forward(self, x):
return self.linear(x)

quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 5),)

self._test_representation(
M().eval(),
example_inputs,
quantizer,
ref_node_occurrence={},
non_ref_node_occurrence={},
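            # the dynamic path returns fp32 and has no output quantize op to read a
            # scale from, so compare against a fixed tolerance instead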
fixed_output_tol=1e-4,
)

def test_representation_conv2d(self):
class M(torch.nn.Module):
def __init__(self):
87 changes: 87 additions & 0 deletions torch/ao/quantization/pt2e/representation/rewrite.py
@@ -85,6 +85,72 @@ def _reference_quantized_linear(
return out_i8


_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
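    # one example value per parameter of the two patterns below:
    # x_fp32, x_quant_min, x_quant_max, x_eps,
    # weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    # bias_fp32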
torch.randn((2, 5), dtype=torch.float),
-128,
127,
torch.finfo(torch.float32).eps,
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
torch.zeros(1, dtype=torch.int),
torch.tensor([-127], dtype=torch.int),
torch.tensor([127], dtype=torch.int),
torch.randn(1, dtype=torch.float),
)


def _qdq_dynamic_quantized_linear(
x_fp32, x_quant_min, x_quant_max, x_eps,
weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
bias_fp32,
):
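    # pattern to match: dynamically compute qparams for the activation,
    # quantize/dequantize it, dequantize the int8 weight, then run a float linear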
x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
return out_fp32

def _reference_dynamic_quantized_linear(
x_fp32, x_quant_min, x_quant_max, x_eps,
weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
bias_fp32,
):
x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
# decomposed representation for quantize_per_tensor
# TODO: use out_dtype(mul, ...) here when the op is ready
x_fp32 = x_fp32 / x_scale # fp32
    # rounding modes might differ across backends;
    # PyTorch rounds half to even, which most backends also use
x_fp32 = torch.round(x_fp32) # fp32
x_i32 = x_fp32.to(dtype=torch.int32) # int32
x_i32 = x_i32 + x_zero_point # int32
# clamp works for fp32, int32 and int8 dtypes
x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32
x_i8 = x_i32.to(dtype=torch.int8)

weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

x_i16 = x_i8.to(torch.int16)
weight_i16 = weight_i8.to(torch.int16)
    # always pass bias=None so that the same representation works whether or not
    # bias_scale == x_scale * weight_scale
acc_i32 = out_dtype(
torch.ops.aten.linear.default,
torch.int32,
x_i16 - x_zero_point,
weight_i16 - weight_zero_point,
None)
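    # fold the fp32 bias into the int32 accumulator: acc_i32 is the output in units
    # of (x_scale * weight_scale), so quantizing the bias with that same scale lets
    # the two be added directly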
bias_scale = x_scale * weight_scale
bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
acc_i32 = acc_i32 + bias_i32
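    # dequantize: rescale the int32 accumulator back to fp32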
out_fp32 = acc_i32 * (x_scale * weight_scale)
return out_fp32


_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
torch.randn(1, dtype=torch.float),
@@ -465,6 +531,27 @@ class _RewriteInfo:
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None

_REWRITE_INFO_LIST = [
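    # dynamic quantized linear: the -128 / 127 / eps literals baked in while tracing
    # the example inputs are replaced with the existing x_quant_min, x_quant_max and
    # x_eps placeholders (indices 1, 2, 3) so the pattern matches any quant range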
_RewriteInfo(
_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
_qdq_dynamic_quantized_linear,
_reference_dynamic_quantized_linear,
partial(
_replace_literals_with_existing_placeholders,
literal_to_ph_idx={
-128: 1,
127: 2,
torch.finfo(torch.float32).eps: 3
}
),
partial(
_replace_literals_with_existing_placeholders,
literal_to_ph_idx={
-128: 1,
127: 2,
torch.finfo(torch.float32).eps: 3
}
),
),
_RewriteInfo(
_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
_qdq_quantized_linear,
