[quant][fx][bc-breaking] Add required example_inputs argument to prepare_fx and prepare_qat_fx (pytorch#77608)

Summary:
Pull Request resolved: pytorch#77608

X-link: pytorch/fx2trt#76

X-link: facebookresearch/d2go#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

X-link: facebookresearch/ClassyVision#791

X-link: facebookresearch/mobile-vision#68

FX Graph Mode Quantization needs to know whether the value produced by an fx node is a floating point Tensor before it can decide whether to insert an observer/fake_quantize module, since we only insert observer/fake_quantize modules for floating point Tensors.
Currently we support this with some hacks, such as the NON_OBSERVABLE_ARG_DICT rules (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the PyTorch code base.
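
To illustrate why value types matter, here is a minimal sketch (the toy model below is illustrative and not part of this commit): the indices produced by torch.argmax are an integer Tensor and must not get an observer, while the conv and embedding outputs are floating point Tensors and should.
```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 1)
        self.emb = nn.Embedding(8, 4)

    def forward(self, x):
        y = self.conv(x)              # float Tensor -> observe
        idx = torch.argmax(y, dim=1)  # int64 Tensor -> do not observe
        return self.emb(idx)          # float Tensor -> observe

# Running the model on example inputs lets the prepare step see the actual
# dtypes flowing through the graph instead of relying on hard-coded rules.
example_inputs = (torch.randn(1, 3, 4, 4),)
out = ToyModel()(*example_inputs)
```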

As we discussed in the design review, we need to ask users to provide example args (and possibly kwargs) so that we can infer the types in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx APIs to require users to provide example arguments through a new example_inputs argument. Note that this API does not support kwargs. Supporting kwargs could make pytorch#76496 (comment) simpler, but that case will be rare, and even then it can be worked around with positional arguments; also, torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) take only positional args, so we will use a single example_inputs argument for now.

If needed, we can extend the API later with an optional example_kwargs, e.g. for cases where forward takes many arguments and it makes more sense to pass them by keyword.
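
A hedged sketch of the positional workaround mentioned above (the model, shapes, and qconfig here are illustrative assumptions, not part of this commit): arguments that are normally passed by keyword can simply be supplied positionally in example_inputs.
```python
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import prepare_fx

class ModelWithKwarg(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x, attention_mask=None):
        out = self.linear(x)
        if attention_mask is not None:
            out = out * attention_mask
        return out

m = ModelWithKwarg().eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
# attention_mask is supplied positionally since example_inputs takes no kwargs
example_inputs = (torch.randn(2, 16), torch.ones(2, 16))
mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
```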

BC-breaking Note:
Before:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)
# or
m = prepare_qat_fx(m, qconfig_dict)
```
After:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
# or
m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
```
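
For context, a minimal end-to-end sketch of the new workflow (the calibration loop and fbgemm qconfig below are illustrative, not part of this commit):
```python
import torch
from torchvision.models import resnet18
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

m = resnet18().eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
example_inputs = (torch.randn(1, 3, 224, 224),)

mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
# calibrate with representative data before converting
with torch.no_grad():
    for _ in range(8):
        mp(torch.randn(1, 3, 224, 224))
mq = convert_fx(mp)
```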

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels

Imported from OSS


Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 716e5992ebe99cfb90be669357f56b214d692aef
jerryzh168 authored and facebook-github-bot committed May 21, 2022
1 parent c0abd83 commit 9e0b8bb
Showing 9 changed files with 536 additions and 340 deletions.
3 changes: 2 additions & 1 deletion test/quantization/bc/test_backward_compatibility.py
@@ -171,9 +171,10 @@ def _do_quant_transforms(
m: torch.nn.Module,
input_tensor: torch.Tensor,
) -> torch.nn.Module:
example_inputs = (input_tensor,)
# do the quantizaton transforms and save result
qconfig = torch.quantization.get_default_qconfig('fbgemm')
mp = quantize_fx.prepare_fx(m, {'': qconfig})
mp = quantize_fx.prepare_fx(m, {'': qconfig}, example_inputs=example_inputs)
mp(input_tensor)
mq = quantize_fx.convert_fx(mp)
return mq
28 changes: 14 additions & 14 deletions test/quantization/dbr/test_quantize_dbr.py
@@ -82,7 +82,7 @@ def _test_auto_tracing(

# compare it against FX
if do_fx_comparison:
m_copy_p = prepare_fx(m_copy, {'': qconfig})
m_copy_p = prepare_fx(m_copy, {'': qconfig}, example_inputs=example_args)
out_m_copy_p = m_copy_p(*example_args)
# print(m_copy_p)
m_copy_q = convert_fx(m_copy_p)
@@ -1236,11 +1236,11 @@ def test_qconfig_dict_unsupported_does_not_crash_when_empty(self):
"""
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
qconfig_dict = {'': torch.quantization.default_qconfig}
example_inputs = (torch.randn(1, 1, 1, 1),)
# this modifies qconfig_dict inplace to include more keys
mp = prepare_fx(m, qconfig_dict)
example_args = (torch.randn(1, 1, 1, 1),)
mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
# need this line to not crash
mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)

def _test_serialization(self, model, input_shape):
example_inputs = (torch.randn(*input_shape),)
@@ -1324,15 +1324,15 @@ def test_jit_tracing_removes_aliases(self):
),
)
qconfig_dict = {'': torch.quantization.default_qconfig}
example_args = (torch.randn(1, 1, 1, 1),)
mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
example_inputs = (torch.randn(1, 1, 1, 1),)
mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)
mq = _quantize_dbr.convert(mp)
mqs = torch.jit.trace(mq, example_args)
mqs = torch.jit.trace(mq, example_inputs)
FileCheck().check_count("aten::alias", 5, exactly=True).run(
mqs.inlined_graph)
res1 = mqs(*example_args)
res1 = mqs(*example_inputs)
mqs = remove_redundant_aliases(mqs)
res2 = mqs(*example_args)
res2 = mqs(*example_inputs)
self.assertTrue(torch.allclose(res1, res2))
# TODO(future PR): figure out why aliasing still appears in the inlined
# graph, and if that is fixed then just check the inlined graph.
@@ -1609,11 +1609,11 @@ def test_mobilenet_v2_removes_aliases(self):
m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False)\
.eval().float()
qconfig_dict = {'': torch.quantization.default_qconfig}
example_args = (torch.randn(1, 3, 224, 224),)
mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
example_inputs = (torch.randn(1, 3, 224, 224),)
mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)
mq = _quantize_dbr.convert(mp)
mqs = torch.jit.trace(mq, example_args)
res1 = mqs(*example_args)
mqs = torch.jit.trace(mq, example_inputs)
res1 = mqs(*example_inputs)
mqs = remove_redundant_aliases(mqs)
res2 = mqs(*example_args)
res2 = mqs(*example_inputs)
self.assertTrue(torch.allclose(res1, res2))
70 changes: 59 additions & 11 deletions test/quantization/fx/test_equalize_fx.py
@@ -274,7 +274,14 @@ def test_input_weight_equalization_prepare(self):

for (M, node_occurrence) in tests:
m = M().eval()
prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
# TODO[quant-example-inputs]: if shape is important we need to define a example_inputs for each test
# for now we do not need shape so this can be fixed later
example_inputs = (torch.randn(1, 1, 1, 1),)
prepared = prepare_fx(
m,
specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict)
self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)

def test_input_weight_equalization_branching(self):
Expand Down Expand Up @@ -305,7 +312,10 @@ def forward(self, x):
}

m = TestBranchingWithoutEqualizationModel().eval()
prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
example_inputs = (torch.randn(1, 5),)
prepared = prepare_fx(
m, specific_qconfig_dict, example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict)
self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)

# Tests that we will add an equalization observer because there is only
@@ -326,7 +336,10 @@ def forward(self, x):
}

m = TestBranchingWithEqualizationModel().eval()
prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
example_inputs = (torch.randn(1, 5),)
prepared = prepare_fx(
m, specific_qconfig_dict, example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict)
self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)

@skipIfNoFBGEMM
@@ -353,17 +366,22 @@ def test_input_weight_equalization_convert(self):
elif ndim == 4:
x = torch.rand((16, 3, 224, 224))

example_inputs = (x,)
prepared = prepare_fx(
copy.deepcopy(m),
specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict
)
output = prepared(x)

convert_ref = _convert_equalization_ref(prepared)
convert_ref_output = convert_ref(x)

prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
prepared = prepare_fx(
m, specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict)
prepared(x)
convert_fx(prepared) # Check if compile
self.assertEqual(output, convert_ref_output)
@@ -411,8 +429,12 @@ def test_input_weight_equalization_equalization_scales(self):
m = M().eval()
exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())

prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
prepared(x)
example_inputs = (x,)
prepared = prepare_fx(
m, specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict)
prepared(*example_inputs)
convert_ref = _convert_equalization_ref(prepared)
convert_ref(x)

@@ -460,7 +482,11 @@ def test_input_weight_equalization_weights_bias(self):
exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)

prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
example_inputs = (x,)
prepared = prepare_fx(
m, specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict)
prepared(x)
convert_ref = _convert_equalization_ref(prepared)
convert_ref(x)
@@ -516,7 +542,11 @@ def test_input_weight_equalization_activation_values(self):
exp_inp_act_vals = self.get_expected_inp_act_vals(m, x, exp_eq_scales, exp_weights, exp_bias)
exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights)

prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
example_inputs = (x,)
prepared = prepare_fx(
m, specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict)
prepared(x)
convert_ref = _convert_equalization_ref(prepared)
convert_ref(x)
@@ -751,7 +781,13 @@ def test_input_weight_equalization_graphs(self):

for (M, node_list) in tests:
m = M().eval()
prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
# TODO[quant-example-inputs]: if shape is important we need to define a example_inputs for each test
# for now we do not need shape so this can be fixed later
example_inputs = (torch.randn(1, 1, 1, 1),)
prepared = prepare_fx(
m, specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict)
equalized_quantized_model = convert_fx(prepared)

# Check the order of nodes in the graph
@@ -771,7 +807,12 @@ def test_input_weight_equalization_results(self):
m = M().eval()

# No equalization
prepared = prepare_fx(copy.deepcopy(m), specific_qconfig_dict, equalization_qconfig_dict={})
example_inputs = (x,)
prepared = prepare_fx(
copy.deepcopy(m),
specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict={})
prepared(x)
quantized = convert_fx(prepared) # Check if compile
quantized_output = quantized(x)
Expand All @@ -780,6 +821,7 @@ def test_input_weight_equalization_results(self):
prepared = prepare_fx(
copy.deepcopy(m),
specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=default_equalization_qconfig_dict
)
prepared(x)
@@ -817,7 +859,12 @@ def forward(self, x):
[0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

# Quantize the float model
prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
example_inputs = (x,)
prepared_model = prepare_fx(
copy.deepcopy(float_model),
specific_qconfig_dict,
example_inputs=example_inputs
)
prepared_model(x)
quantized_model = convert_fx(copy.deepcopy(prepared_model))

Expand All @@ -832,6 +879,7 @@ def forward(self, x):
prepared_model = prepare_fx(
copy.deepcopy(float_model),
specific_qconfig_dict,
example_inputs=example_inputs,
equalization_qconfig_dict=selective_equalization_qconfig_dict,
)
prepared_model(x)
