Skip to content

Commit

Permalink
instancenorm: static quant graph mode support (pytorch#39096)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#39096

Hooks up instancenorm for graph mode static quant

Test Plan:
```
python test/test_quantization.py TestQuantizeScriptPTSQOps.test_instance_norm
```

Imported from OSS

Differential Revision: D21885258

fbshipit-source-id: 650cc5b162dda044866176fea6c345082d9788ed
  • Loading branch information
vkuzo authored and facebook-github-bot committed Jun 7, 2020
1 parent b443ca2 commit ebdff07
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
20 changes: 20 additions & 0 deletions test/quantization/test_quantize_script.py
Expand Up @@ -2148,6 +2148,26 @@ def test_group_norm(self):
FileCheck().check_not("aten::group_norm") \
.run(m.graph)

def test_instance_norm(self):
    """Graph mode static quant: verify aten::instance_norm is replaced by
    quantized::instance_norm for the 1d, 2d and 3d module variants.

    Each case quantizes an InstanceNormNd module via _test_op_impl (which
    checks the quantized op appears) and then asserts the fp32 op is gone
    from the resulting graph.
    """
    # TODO: handle affine == False (separate PR)
    # (module class, input shape) — shapes are (N, C, *spatial) with C=4
    # matching the num_features argument below.
    cases = [
        (torch.nn.InstanceNorm1d, (1, 4, 10)),
        (torch.nn.InstanceNorm2d, (1, 4, 10, 1)),
        (torch.nn.InstanceNorm3d, (1, 4, 10, 1, 1)),
    ]
    for norm_cls, shape in cases:
        # Two calibration samples; the int label is unused by the op but
        # matches the (input, target) pair format _test_op_impl expects.
        data = [(torch.rand(shape, dtype=torch.float),
                 torch.randint(0, 1, (1,), dtype=torch.long))
                for _ in range(2)]
        m = self._test_op_impl(norm_cls(4, affine=True), data,
                               "quantized::instance_norm")
        FileCheck().check_not("aten::instance_norm") \
            .run(m.graph)

def test_quantize_general_shape_ops(self):
""" A test that checks dequantize will be swapped for
all supported general shape ops like aten::flatten
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/passes/quantization/helper.cpp
Expand Up @@ -22,6 +22,7 @@ std::vector<std::string> _static_quantizable_call_funcs = {
"hardswish",
"layer_norm",
"group_norm",
"instance_norm",
};

std::vector<std::string> _static_quantizable_aten_funcs = {
Expand All @@ -34,6 +35,7 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
"hardswish",
"layer_norm",
"group_norm",
"instance_norm",
};

std::vector<std::string> _dynamic_quantizable_call_funcs = {
Expand Down
14 changes: 14 additions & 0 deletions torch/csrc/jit/passes/quantization/quantization_patterns.h
Expand Up @@ -620,6 +620,19 @@ graph(%a_quant, %num_groups, %weight, %bias, %eps, %cudnn_enabled, %output_scale
%r = quantized::group_norm(%a_quant, %num_groups, %weight, %bias, %eps, %output_scale, %output_zero_point)
return (%r) )";

// quantized::instance_norm
// Match pattern for the subgraph rewriter: a dequantize of the input,
// the fp32 aten::instance_norm, then a re-quantize of the result using
// the observed output scale / zero point.
std::string instance_norm = R"(
graph(%a_quant, %weight, %bias, %running_mean, %running_var, %use_input_stats, %momentum, %eps, %cudnn_enabled, %output_scale, %output_zero_point, %scalar_type):
%a_dequant = aten::dequantize(%a_quant)
%r_in = aten::instance_norm(%a_dequant, %weight, %bias, %running_mean, %running_var, %use_input_stats, %momentum, %eps, %cudnn_enabled)
%r = aten::quantize_per_tensor(%r_in, %output_scale, %output_zero_point, %scalar_type)
return (%r) )";

// Replacement pattern: a single fused quantized::instance_norm. Note the
// graph signature must stay identical to the match pattern above, but the
// running stats / momentum / cudnn_enabled / scalar_type inputs are not
// forwarded — the quantized op takes only weight, bias, eps and the
// output qparams.
std::string quantized_instance_norm = R"(
graph(%a_quant, %weight, %bias, %running_mean, %running_var, %use_input_stats, %momentum, %eps, %cudnn_enabled, %output_scale, %output_zero_point, %scalar_type):
%r = quantized::instance_norm(%a_quant, %weight, %bias, %eps, %output_scale, %output_zero_point)
return (%r) )";

// ============= General Ops that inherit quantization parameters from input
// tensor =============
auto avg_pool1d = getInputTensorQParamOpFusionInfo(
Expand Down Expand Up @@ -820,6 +833,7 @@ graph(%a_quant, %num_groups, %weight, %bias, %eps, %cudnn_enabled, %output_scale
{"quantized::hardswish", hardswish, quantized_hardswish},
{"quantized::layer_norm", layer_norm, quantized_layer_norm},
{"quantized::group_norm", group_norm, quantized_group_norm},
{"quantized::instance_norm", instance_norm, quantized_instance_norm},
avg_pool1d,
avg_pool2d,
avg_pool3d,
Expand Down

0 comments on commit ebdff07

Please sign in to comment.