From 9e0b8bba10c06cc1037421e90ff43b4d6b5ec1a5 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 20 May 2022 23:19:28 -0700
Subject: [PATCH] [quant][fx][bc-breaking] Add required example_inputs argument to prepare_fx and prepare_qat_fx (#77608)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77608

X-link: https://github.com/pytorch/fx2trt/pull/76
X-link: https://github.com/facebookresearch/d2go/pull/249
X-link: https://github.com/fairinternal/ClassyVision/pull/104
X-link: https://github.com/pytorch/benchmark/pull/916
X-link: https://github.com/facebookresearch/ClassyVision/pull/791
X-link: https://github.com/facebookresearch/mobile-vision/pull/68

FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert an observer/fake_quantize module, since we only insert observer/fake_quantize modules for floating point Tensors. Currently we support this through hacks such as the NON_OBSERVABLE_ARG_DICT rules (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base.

As we discussed in the design review, we need users to provide example args (and, eventually, example kwargs) so that we can infer the types in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx APIs to require users to provide example arguments through a new example_inputs argument.

Note that this API does not support kwargs. Kwargs could make https://github.com/pytorch/pytorch/pull/76496#discussion_r861230047 (comment) simpler, but that case is rare, and even then we can still work around it with positional arguments. Also, torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) only take positional args, so we use a single example_inputs argument for now. If needed, we can extend the API later with an optional example_kwargs, e.g. for the case where forward takes many arguments and it makes more sense to pass them by keyword.

BC-breaking Note:

Before:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)
# or
m = prepare_qat_fx(m, qconfig_dict)
```

After:
```python
m = resnet18(...)
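# example_inputs is a tuple of positional arguments matching the model's forward()
# signature; the (1, 3, 224, 224) shape below assumes resnet18's standard image input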
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) # or m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) ``` Test Plan: python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxOps python test/test_quantization.py TestQuantizeFxModels Imported from OSS **Static Docs Preview: classyvision** |[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D35984526/V44/classyvision/)| |**Modified Pages**| Reviewed By: vkuzo, andrewor14 Differential Revision: D35984526 fbshipit-source-id: 716e5992ebe99cfb90be669357f56b214d692aef --- .../bc/test_backward_compatibility.py | 3 +- test/quantization/dbr/test_quantize_dbr.py | 28 +- test/quantization/fx/test_equalize_fx.py | 70 ++- test/quantization/fx/test_numeric_suite_fx.py | 177 ++++-- test/quantization/fx/test_quantize_fx.py | 554 ++++++++++-------- torch/ao/quantization/fx/prepare.py | 23 +- torch/ao/quantization/fx/qconfig_utils.py | 6 +- torch/ao/quantization/quantize_fx.py | 14 +- .../testing/_internal/common_quantization.py | 1 + 9 files changed, 536 insertions(+), 340 deletions(-) diff --git a/test/quantization/bc/test_backward_compatibility.py b/test/quantization/bc/test_backward_compatibility.py index b89d43c3e3e5..ef8f49b3b2e2 100644 --- a/test/quantization/bc/test_backward_compatibility.py +++ b/test/quantization/bc/test_backward_compatibility.py @@ -171,9 +171,10 @@ def _do_quant_transforms( m: torch.nn.Module, input_tensor: torch.Tensor, ) -> torch.nn.Module: + example_inputs = (input_tensor,) # do the quantizaton transforms and save result qconfig = torch.quantization.get_default_qconfig('fbgemm') - mp = quantize_fx.prepare_fx(m, {'': qconfig}) + mp = quantize_fx.prepare_fx(m, {'': qconfig}, example_inputs=example_inputs) mp(input_tensor) mq = quantize_fx.convert_fx(mp) return mq diff --git a/test/quantization/dbr/test_quantize_dbr.py b/test/quantization/dbr/test_quantize_dbr.py index cd6dd6968ad0..4d7a2c5aabaa 100644 --- a/test/quantization/dbr/test_quantize_dbr.py +++ b/test/quantization/dbr/test_quantize_dbr.py @@ -82,7 +82,7 @@ def _test_auto_tracing( # compare it against FX if do_fx_comparison: - m_copy_p = prepare_fx(m_copy, {'': qconfig}) + m_copy_p = prepare_fx(m_copy, {'': qconfig}, example_inputs=example_args) out_m_copy_p = m_copy_p(*example_args) # print(m_copy_p) m_copy_q = convert_fx(m_copy_p) @@ -1236,11 +1236,11 @@ def test_qconfig_dict_unsupported_does_not_crash_when_empty(self): """ m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval() qconfig_dict = {'': torch.quantization.default_qconfig} + example_inputs = (torch.randn(1, 1, 1, 1),) # this modifies qconfig_dict inplace to include more keys - mp = prepare_fx(m, qconfig_dict) - example_args = (torch.randn(1, 1, 1, 1),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # need this line to not crash - mp = _quantize_dbr.prepare(m, qconfig_dict, example_args) + mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs) def _test_serialization(self, model, input_shape): example_inputs = (torch.randn(*input_shape),) @@ -1324,15 +1324,15 @@ def test_jit_tracing_removes_aliases(self): ), ) qconfig_dict = {'': torch.quantization.default_qconfig} - example_args = (torch.randn(1, 1, 1, 1),) - mp = _quantize_dbr.prepare(m, qconfig_dict, example_args) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs) mq = _quantize_dbr.convert(mp) - mqs = torch.jit.trace(mq, example_args) + mqs = torch.jit.trace(mq, 
example_inputs) FileCheck().check_count("aten::alias", 5, exactly=True).run( mqs.inlined_graph) - res1 = mqs(*example_args) + res1 = mqs(*example_inputs) mqs = remove_redundant_aliases(mqs) - res2 = mqs(*example_args) + res2 = mqs(*example_inputs) self.assertTrue(torch.allclose(res1, res2)) # TODO(future PR): figure out why aliasing still appears in the inlined # graph, and if that is fixed then just check the inlined graph. @@ -1609,11 +1609,11 @@ def test_mobilenet_v2_removes_aliases(self): m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False)\ .eval().float() qconfig_dict = {'': torch.quantization.default_qconfig} - example_args = (torch.randn(1, 3, 224, 224),) - mp = _quantize_dbr.prepare(m, qconfig_dict, example_args) + example_inputs = (torch.randn(1, 3, 224, 224),) + mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs) mq = _quantize_dbr.convert(mp) - mqs = torch.jit.trace(mq, example_args) - res1 = mqs(*example_args) + mqs = torch.jit.trace(mq, example_inputs) + res1 = mqs(*example_inputs) mqs = remove_redundant_aliases(mqs) - res2 = mqs(*example_args) + res2 = mqs(*example_inputs) self.assertTrue(torch.allclose(res1, res2)) diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py index a552fc0a0f1e..2e211853aa16 100644 --- a/test/quantization/fx/test_equalize_fx.py +++ b/test/quantization/fx/test_equalize_fx.py @@ -274,7 +274,14 @@ def test_input_weight_equalization_prepare(self): for (M, node_occurrence) in tests: m = M().eval() - prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict) + # TODO[quant-example-inputs]: if shape is important we need to define a example_inputs for each test + # for now we do not need shape so this can be fixed later + example_inputs = (torch.randn(1, 1, 1, 1),) + prepared = prepare_fx( + m, + specific_qconfig_dict, + example_inputs=example_inputs, + equalization_qconfig_dict=default_equalization_qconfig_dict) self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence) def test_input_weight_equalization_branching(self): @@ -305,7 +312,10 @@ def forward(self, x): } m = TestBranchingWithoutEqualizationModel().eval() - prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict) + example_inputs = (torch.randn(1, 5),) + prepared = prepare_fx( + m, specific_qconfig_dict, example_inputs=example_inputs, + equalization_qconfig_dict=default_equalization_qconfig_dict) self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence) # Tests that we will add an equalization observer because there is only @@ -326,7 +336,10 @@ def forward(self, x): } m = TestBranchingWithEqualizationModel().eval() - prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict) + example_inputs = (torch.randn(1, 5),) + prepared = prepare_fx( + m, specific_qconfig_dict, example_inputs=example_inputs, + equalization_qconfig_dict=default_equalization_qconfig_dict) self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence) @skipIfNoFBGEMM @@ -353,9 +366,11 @@ def test_input_weight_equalization_convert(self): elif ndim == 4: x = torch.rand((16, 3, 224, 224)) + example_inputs = (x,) prepared = prepare_fx( copy.deepcopy(m), specific_qconfig_dict, + example_inputs=example_inputs, equalization_qconfig_dict=default_equalization_qconfig_dict ) output = prepared(x) @@ -363,7 +378,10 @@ def 
test_input_weight_equalization_convert(self): convert_ref = _convert_equalization_ref(prepared) convert_ref_output = convert_ref(x) - prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict) + prepared = prepare_fx( + m, specific_qconfig_dict, + example_inputs=example_inputs, + equalization_qconfig_dict=default_equalization_qconfig_dict) prepared(x) convert_fx(prepared) # Check if compile self.assertEqual(output, convert_ref_output) @@ -411,8 +429,12 @@ def test_input_weight_equalization_equalization_scales(self): m = M().eval() exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy()) - prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict) - prepared(x) + example_inputs = (x,) + prepared = prepare_fx( + m, specific_qconfig_dict, + example_inputs=example_inputs, + equalization_qconfig_dict=default_equalization_qconfig_dict) + prepared(*example_inputs) convert_ref = _convert_equalization_ref(prepared) convert_ref(x) @@ -460,7 +482,11 @@ def test_input_weight_equalization_weights_bias(self): exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy()) exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales) - prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict) + example_inputs = (x,) + prepared = prepare_fx( + m, specific_qconfig_dict, + example_inputs=example_inputs, + equalization_qconfig_dict=default_equalization_qconfig_dict) prepared(x) convert_ref = _convert_equalization_ref(prepared) convert_ref(x) @@ -516,7 +542,11 @@ def test_input_weight_equalization_activation_values(self): exp_inp_act_vals = self.get_expected_inp_act_vals(m, x, exp_eq_scales, exp_weights, exp_bias) exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights) - prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict) + example_inputs = (x,) + prepared = prepare_fx( + m, specific_qconfig_dict, + example_inputs=example_inputs, + equalization_qconfig_dict=default_equalization_qconfig_dict) prepared(x) convert_ref = _convert_equalization_ref(prepared) convert_ref(x) @@ -751,7 +781,13 @@ def test_input_weight_equalization_graphs(self): for (M, node_list) in tests: m = M().eval() - prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict) + # TODO[quant-example-inputs]: if shape is important we need to define a example_inputs for each test + # for now we do not need shape so this can be fixed later + example_inputs = (torch.randn(1, 1, 1, 1),) + prepared = prepare_fx( + m, specific_qconfig_dict, + example_inputs=example_inputs, + equalization_qconfig_dict=default_equalization_qconfig_dict) equalized_quantized_model = convert_fx(prepared) # Check the order of nodes in the graph @@ -771,7 +807,12 @@ def test_input_weight_equalization_results(self): m = M().eval() # No equalization - prepared = prepare_fx(copy.deepcopy(m), specific_qconfig_dict, equalization_qconfig_dict={}) + example_inputs = (x,) + prepared = prepare_fx( + copy.deepcopy(m), + specific_qconfig_dict, + example_inputs=example_inputs, + equalization_qconfig_dict={}) prepared(x) quantized = convert_fx(prepared) # Check if compile quantized_output = quantized(x) @@ -780,6 +821,7 @@ def test_input_weight_equalization_results(self): prepared = prepare_fx( copy.deepcopy(m), specific_qconfig_dict, + example_inputs=example_inputs, 
equalization_qconfig_dict=default_equalization_qconfig_dict ) prepared(x) @@ -817,7 +859,12 @@ def forward(self, x): [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]]) # Quantize the float model - prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict) + example_inputs = (x,) + prepared_model = prepare_fx( + copy.deepcopy(float_model), + specific_qconfig_dict, + example_inputs=example_inputs + ) prepared_model(x) quantized_model = convert_fx(copy.deepcopy(prepared_model)) @@ -832,6 +879,7 @@ def forward(self, x): prepared_model = prepare_fx( copy.deepcopy(float_model), specific_qconfig_dict, + example_inputs=example_inputs, equalization_qconfig_dict=selective_equalization_qconfig_dict, ) prepared_model(x) diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index 4559c6389be6..a83c5cb5eefe 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -295,7 +295,7 @@ class TestFXGraphMatcher(QuantizationTestCase): @skipIfNoFBGEMM def test_simple_mod(self): m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval() - mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),)) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) @@ -322,7 +322,7 @@ def forward(self, x): return F.linear(x, self.w, self.b) m = M().eval() - mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),)) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) @@ -340,7 +340,7 @@ def forward(self, x): @skipIfNoFBGEMM def test_simple_fusion(self): m = LinearReluFunctional().eval() - mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(4, 4),)) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) @@ -363,7 +363,7 @@ def test_simple_mod_multi(self): ), nn.Conv2d(1, 1, 1), ).eval() - mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),)) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) # assume success if no exceptions @@ -380,7 +380,8 @@ def forward(self, x, y): return z m = M().eval() - mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + example_inputs = (torch.randn(1), torch.randn(1)) + mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) # assume success if no exceptions @@ -392,8 +393,9 @@ def test_matching_failure_node_count(self): # different counts of matchable nodes fails m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval() m2 = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval() - mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}) - mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) + mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) with self.assertRaises(GraphMatchingException) as ex: results = 
get_matching_subgraph_pairs(mp1, mp2) @@ -402,8 +404,10 @@ def test_matching_failure_node_type(self): # verify that matching graphs with non-matching node types fails m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval() m2 = nn.Sequential(nn.Linear(1, 1)).eval() - mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}) - mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) + example_inputs = (torch.randn(1, 1),) + mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) with self.assertRaises(GraphMatchingException) as ex: results = get_matching_subgraph_pairs(mp1, mp2) @@ -421,7 +425,8 @@ def forward(self, x0): return x2 m = M().eval() - mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + example_inputs = (torch.randn(1),) + mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) @@ -456,7 +461,8 @@ def forward(self, x0): return a1 m = M().eval() - mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + example_inputs = (torch.randn(1),) + mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) @@ -499,7 +505,8 @@ def forward(self, x): '': torch.ao.quantization.default_qconfig, 'module_name': [('conv2', None)], } - mp = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) @@ -541,8 +548,9 @@ def forward(self, x): m1 = M().eval() m2 = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - m1p = prepare_fx(m1, qconfig_dict) - m2p = prepare_fx(m2, qconfig_dict) + example_inputs = (torch.randn(1),) + m1p = prepare_fx(m1, qconfig_dict, example_inputs=example_inputs) + m2p = prepare_fx(m2, qconfig_dict, example_inputs=example_inputs) results = get_matching_subgraph_pairs(m1p, m2p) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() sigmoid_name_0 = 'base_op_' + get_base_name_for_op( @@ -733,8 +741,9 @@ def forward(self, x): return x qconfig_dict = {'': torch.ao.quantization.default_qconfig} - m1 = prepare_fx(M1().eval(), qconfig_dict) - m2 = prepare_fx(M2().eval(), qconfig_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + m1 = prepare_fx(M1().eval(), qconfig_dict, example_inputs=example_inputs) + m2 = prepare_fx(M2().eval(), qconfig_dict, example_inputs=example_inputs) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() add_op_to_sets_of_related_ops( @@ -760,7 +769,8 @@ def test_results_order(self): nn.Conv2d(1, 1, 1), nn.Linear(1, 1), ).eval() - mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) results = get_matching_subgraph_pairs(mp, mq) @@ -782,7 +792,8 @@ def test_mobilenet_v2(self): # verify that mobilenetv2 graph is able to be matched import torchvision m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).eval().float() - mp = 
prepare_fx(copy.deepcopy(m), {'': torch.ao.quantization.default_qconfig}) + example_inputs = (torch.randn(1, 3, 224, 224),) + mp = prepare_fx(copy.deepcopy(m), {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs) # assume success if no exceptions results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp) mp_copy = copy.deepcopy(mp) @@ -796,9 +807,11 @@ def test_mobilenet_v2_qat(self): # verify that mobilenetv2 graph is able to be matched import torchvision m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float() + example_inputs = (torch.randn(1, 3, 224, 224),) mp = prepare_qat_fx( copy.deepcopy(m), - {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}) + {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}, + example_inputs=example_inputs) # assume success if no exceptions results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp) mp_copy = copy.deepcopy(mp) @@ -809,12 +822,12 @@ def test_mobilenet_v2_qat(self): class FXNumericSuiteQuantizationTestCase(QuantizationTestCase): def _test_extract_weights( - self, m, results_len=0, qconfig_dict=None, prepare_fn=prepare_fx + self, m, example_inputs, results_len=0, qconfig_dict=None, prepare_fn=prepare_fx ): m = torch.fx.symbolic_trace(m) if qconfig_dict is None: qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fn(copy.deepcopy(m), qconfig_dict) + mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) @@ -848,7 +861,7 @@ def _test_match_activations( m.eval() else: m.train() - mp = prepare_fn(copy.deepcopy(m), qconfig_dict) + mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data) mp(*data) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) @@ -911,7 +924,7 @@ def _test_match_shadow_activations( m.eval() else: m.train() - mp = prepare_fn(copy.deepcopy(m), qconfig_dict) + mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data) mp(*data) mp_copy = copy.deepcopy(mp) mq = convert_fx(mp_copy) @@ -969,26 +982,30 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase): @skipIfNoFBGEMM def test_extract_weights_mod_ptq(self): m = AllConvAndLinearFusionModules().eval() - self._test_extract_weights(m, results_len=14) + example_inputs = (torch.randn(1, 1, 1, 1),) + self._test_extract_weights(m, example_inputs, results_len=14) @skipIfNoFBGEMM def test_extract_weights_mod_qat(self): m = AllConvAndLinearFusionModules().train() qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} + example_inputs = (torch.randn(1, 1, 1, 1),) self._test_extract_weights( - m, results_len=14, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) + m, example_inputs, results_len=14, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) @skipIfNoFBGEMM def test_extract_weights_linear_fun_ptq(self): m = LinearReluLinearFunctional().eval() - self._test_extract_weights(m, results_len=2) + example_inputs = (torch.randn(1, 4),) + self._test_extract_weights(m, example_inputs, results_len=2) @skipIfNoFBGEMM def test_extract_weights_linear_fun_qat(self): m = LinearReluLinearFunctional().train() qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} + example_inputs = (torch.randn(1, 4),) self._test_extract_weights( - m, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) + m, example_inputs, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) @skipIfNoFBGEMM def 
test_extract_weights_conv_fun_ptq(self): @@ -999,7 +1016,8 @@ def test_extract_weights_conv_fun_ptq(self): b2d = torch.randn(1) b3d = torch.randn(1) m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).eval() - self._test_extract_weights(m, results_len=6) + example_inputs = (torch.randn(1, 1, 1, 1),) + self._test_extract_weights(m, example_inputs, results_len=6) @skipIfNoFBGEMM def test_extract_weights_conv_fun_qat(self): @@ -1011,8 +1029,9 @@ def test_extract_weights_conv_fun_qat(self): b3d = torch.randn(1) m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).train() qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} + example_inputs = (torch.randn(1, 1, 1, 1),) self._test_extract_weights( - m, results_len=6, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) + m, example_inputs, results_len=6, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx) @skipIfNoFBGEMM def test_extract_weights_dynamic(self): @@ -1023,7 +1042,8 @@ def test_extract_weights_dynamic(self): (nn.Linear, default_dynamic_qconfig), ], } - self._test_extract_weights(m, results_len=1, qconfig_dict=qconfig_dict) + example_inputs = (torch.randn(1, 1),) + self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def test_extract_weights_fqn(self): @@ -1032,7 +1052,8 @@ def test_extract_weights_fqn(self): nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mq = convert_fx(copy.deepcopy(mp)) results = extract_weights('a', mp, 'b', mq) fqn_a_0 = results['_0_0']['weight']['a'][0]['fqn'] @@ -1110,7 +1131,8 @@ def test_match_activations_fqn(self): nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mq = convert_fx(copy.deepcopy(mp)) mp_ns, mq_ns = add_loggers('a', mp, 'b', mq, OutputLogger) datum = torch.randn(1, 1, 1, 1) @@ -1187,7 +1209,8 @@ def test_shadow_activations_fqn(self): nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mq = convert_fx(copy.deepcopy(mp)) mp_shadows_mq = add_shadow_loggers('a', mp, 'b', mq, OutputLogger) datum = torch.randn(1, 1, 1, 1) @@ -1254,7 +1277,8 @@ def test_add_mul_inputs_activations(self): def test_linear_fp16_weights(self): qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig} m = LinearReluFunctional().eval() - self._test_extract_weights(m, results_len=1, qconfig_dict=qconfig_dict) + example_inputs = (torch.randn(1, 4),) + self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def test_linear_fp16_activations(self): @@ -1292,7 +1316,8 @@ def test_linear_fp16_shadow_activations(self): def test_linear_fp16_vs_linear_fp16_shadow_activations(self): m = LinearFunctional().eval() qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig} - mp = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 4),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mq1 = convert_fx(copy.deepcopy(mp)) mq2 = convert_fx(copy.deepcopy(mp)) mq1_shadows_mq2 = _add_shadow_loggers_impl( @@ -1332,8 +1357,9 @@ def 
_test_int8_shadows_int8_impl(self, m): Verify that shadowing works where both modules are int8 """ qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) - mp(torch.randn(4, 1, 4, 4)) + example_inputs = (torch.randn(4, 1, 4, 4),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mq1 = convert_fx(copy.deepcopy(mp)) mq2 = convert_fx(mp) mq1_shadows_mq2 = add_shadow_loggers('a', mq1, 'b', mq2, OutputLogger) @@ -1377,7 +1403,12 @@ def forward(self, x): prepare_custom_config_dict = { 'non_traceable_module_class': [M1], } - mp1 = prepare_fx(m, qconfig_dict, prepare_custom_config_dict) + example_inputs = (torch.randn(1),) + mp1 = prepare_fx( + m, + qconfig_dict, + example_inputs=example_inputs, + prepare_custom_config_dict=prepare_custom_config_dict) mp2 = copy.deepcopy(mp1) unmatchable_types_map = get_unmatchable_types_map() unmatchable_types_map['mods_unmatchable'].add(M1) @@ -1419,8 +1450,13 @@ def forward(self, x): # quantize without tracing through UserModule qconfig_dict = {'': torch.ao.quantization.default_qconfig} prepare_custom_config_dict = {'non_traceable_module_name': ['user_module']} - mp = prepare_fx(m, qconfig_dict, prepare_custom_config_dict) - mp(torch.randn(1, 1, 1)) + example_inputs = (torch.randn(1, 1, 1),) + mp = prepare_fx( + m, + qconfig_dict, + example_inputs=example_inputs, + prepare_custom_config_dict=prepare_custom_config_dict) + mp(*example_inputs) mq = convert_fx(copy.deepcopy(mp)) # weight extraction should not crash @@ -1661,8 +1697,9 @@ def forward(self, x): return x qconfig_dict = {'': torch.ao.quantization.default_qconfig} - m1 = prepare_fx(M1().eval(), qconfig_dict) - m2 = prepare_fx(M2().eval(), qconfig_dict) + example_inputs = (torch.randn(1, 1),) + m1 = prepare_fx(M1().eval(), qconfig_dict, example_inputs=example_inputs) + m2 = prepare_fx(M2().eval(), qconfig_dict, example_inputs=example_inputs) data = torch.randn(1, 1) base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() @@ -1731,7 +1768,8 @@ def test_layer_names(self): nn.Sigmoid(), ).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) # extract weights @@ -1767,7 +1805,9 @@ def test_layer_names(self): def test_extend_logger_results_with_comparison(self): m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = torch.ao.quantization.quantize_fx.prepare_fx( + m, qconfig_dict, example_inputs=example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) # extract weights @@ -1792,7 +1832,9 @@ def test_extend_logger_results_with_comparison(self): def test_int8_shadows_fp32_simple(self): m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1), nn.ReLU()).eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = torch.ao.quantization.quantize_fx.prepare_fx( + m, qconfig_dict, example_inputs=example_inputs) mp(torch.randn(1, 1, 1, 1)) mq = 
torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) @@ -1846,8 +1888,9 @@ def forward(self, x): m = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict) - mp(torch.randn(1, 1, 1, 1)) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) mp_shadows_mq = add_shadow_loggers( @@ -1862,18 +1905,18 @@ def forward(self, x): def test_loggers_preserve_qat_numerics(self): m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)) qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} - mp = prepare_qat_fx(m, qconfig_dict) - mp(torch.randn(1, 1, 1, 1)) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mc = convert_fx(copy.deepcopy(mp)) mp.apply(torch.ao.quantization.disable_observer) - datum = torch.randn(1, 1, 1, 1) - ref_fp32 = mp(datum) - ref_int8 = mc(datum) + ref_fp32 = mp(*example_inputs) + ref_int8 = mc(*example_inputs) mp_ns, mc_ns = add_loggers('fp32', mp, 'int8', mc, OutputLogger) - ref_fp32_ns = mp_ns(datum) - ref_int8_ns = mc_ns(datum) + ref_fp32_ns = mp_ns(*example_inputs) + ref_int8_ns = mc_ns(*example_inputs) self.assertEqual(ref_fp32, ref_fp32_ns) self.assertEqual(ref_int8, ref_int8_ns) @@ -1881,17 +1924,17 @@ def test_loggers_preserve_qat_numerics(self): def test_shadow_loggers_preserve_qat_numerics(self): m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)) qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} - mp = prepare_qat_fx(m, qconfig_dict) - mp(torch.randn(1, 1, 1, 1)) + example_inputs = (torch.randn(1, 1, 1, 1),) + mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mc = convert_fx(copy.deepcopy(mp)) mp.apply(torch.ao.quantization.disable_observer) - datum = torch.randn(1, 1, 1, 1) - ref_fp32 = mp(datum) - ref_int8 = mc(datum) + ref_fp32 = mp(*example_inputs) + ref_int8 = mc(*example_inputs) mc_shadows_mp = add_shadow_loggers('int8', mc, 'fp32', mp, OutputLogger) - ref_shadow = mc_shadows_mp(datum) + ref_shadow = mc_shadows_mp(*example_inputs) self.assertEqual(ref_fp32, ref_shadow) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @@ -1940,8 +1983,9 @@ def test_add_shadow_loggers_cuda(self): def test_fp16_shadows_fp32(self): m = LinearReluFunctional().eval() + example_inputs = (torch.randn(1, 4),) qconfig_dict = {"": torch.ao.quantization.float16_static_qconfig} - mp = prepare_fx(copy.deepcopy(m), qconfig_dict) + mp = prepare_fx(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs) mq = convert_fx(mp, is_reference=True) mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger) @@ -2005,7 +2049,8 @@ def test_compare_weights_conv(self): ) for m, in test_cases: m.eval() - self._test_extract_weights(m, results_len=1) + example_inputs = (torch.randn(1, 3, 5, 5),) + self._test_extract_weights(m, example_inputs, results_len=1) @skipIfNoFBGEMM def test_compare_weights_linear(self): @@ -2018,15 +2063,19 @@ def test_compare_weights_linear(self): ) for m, qconfig_dict in test_cases: m.eval() + example_inputs = (torch.randn(1, 3, 5, 5),) res = self._test_extract_weights( - m, results_len=1, 
qconfig_dict=qconfig_dict) + m, example_inputs, results_len=1, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def test_compare_weights_lstm_dynamic(self): qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]} + lstm_input = torch.rand((1, 1, 2)) + lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2)) + example_inputs = (lstm_input, lstm_hidden) m = LSTMwithHiddenDynamicModel().eval() res = self._test_extract_weights( - m, results_len=1, qconfig_dict=qconfig_dict) + m, example_inputs, results_len=1, qconfig_dict=qconfig_dict) @skipIfNoFBGEMM def test_compare_activations_conv(self): diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 27a83c5e7874..a83808119b67 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -248,7 +248,7 @@ def forward(self, x): # TODO: if we decide to do that in the future, this test needs to # be updated # train mode fuse_fx is called in prepare_qat_fx - m = prepare_qat_fx(m, {}) + m = prepare_qat_fx(m, {}, example_inputs=(torch.randn(1, 1, 1, 1),)) expected_nodes = [ ns.call_module(nni.ConvBn1d), ns.call_module(nni.ConvBn2d), @@ -401,8 +401,10 @@ def test_qconfig_fused_module(self): for M, node_list in tests: m = M().eval() - prepared = prepare_fx(m, qconfig_dict) - prepared(torch.rand(5, 5)) + example_inputs = (torch.rand(5, 5),) + prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + + prepared(*example_inputs) quantized = convert_fx(prepared) self.checkGraphModuleNodes(quantized, expected_node_list=node_list) @@ -435,7 +437,7 @@ def forward(self, x): (torch.nn.ReLU, get_default_qconfig('fbgemm')), ], } - m = prepare_fx(model, qconfig_dict) + m = prepare_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),)) self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.nn.intrinsic.modules.fused.LinearReLU)) @@ -732,7 +734,7 @@ def forward(self, x): (torch.nn.ReLU, default_qat_qconfig), ], } - prepared = prepare_qat_fx(model, qconfig_dict) + prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),)) self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.nn.intrinsic.qat.LinearReLU)) def _get_conv_linear_test_cases(self, is_reference): @@ -1020,7 +1022,8 @@ def forward(self, x): m = M(torch.rand(1, 1)).eval() qconfig = default_dynamic_qconfig qconfig_dict = {'': qconfig} - prepared = prepare_fx(m, qconfig_dict) + example_inputs = (torch.rand(1, 1),) + prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) quantized = convert_fx(prepared, is_reference=True) qparams = (quantized._scale_0, quantized._zero_point_0) weight_obs = qconfig.weight() @@ -1199,7 +1202,7 @@ def forward(self, x): node_occurrence[weight_prepack_node] = 0 m = ModuleClass(*module_constructor_inputs).eval() qconfig_dict = {"": float16_dynamic_qconfig} - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=inputs) m = convert_fx(m, is_reference=is_reference) self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) @@ -1232,12 +1235,12 @@ def forward(self, x): device = torch.device('cuda:0') model.to(device) + example_inputs = (torch.randn(4, 1, 4, 4, device=device),) # QAT prepare - model = prepare_qat_fx(model, qconfig_dict) + model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) # ensure that running an input on CUDA works without any needed changes - input = torch.randn(4, 1, 4, 4, device=device) - model(input) + model(*example_inputs) # 
ensure all buffers and parameters are on the device we expect model_devices = {p.device for p in model.parameters()} | \ @@ -1258,13 +1261,13 @@ def __init__(self): def forward(self, x): return {"output": self.conv(x["input"])} - dict_input = {"input": torch.randn(1, 1, 1, 1)} + example_inputs = ({"input": torch.randn(1, 1, 1, 1)},) m = M().eval() qconfig_dict = {"": default_qconfig} - m = prepare_fx(m, qconfig_dict) - m(dict_input) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(dict_input) + m(*example_inputs) @override_qengines def test_attention(self): @@ -1287,7 +1290,7 @@ def forward(self, x): r = torch.mm(k, v) return q * k + r - tensor_input = torch.randn(3, 1, 1, 1) + example_inputs = (torch.randn(3, 1, 1, 1),) m = M().eval() qconfig_dict = { "": None, @@ -1296,10 +1299,10 @@ def forward(self, x): ] } # make sure it runs - m = prepare_fx(m, qconfig_dict) - m(tensor_input) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(tensor_input) + m(*example_inputs) def _test_standalone_module( self, @@ -1341,7 +1344,7 @@ def forward(self, x): x = self.conv2(x) return x - data = torch.randn(1, 1, 1, 1) + example_inputs = (torch.randn(1, 1, 1, 1),) # instantiate M and RefM and align the parameters original_m = M().eval() original_ref_m = RefM().eval() @@ -1351,13 +1354,14 @@ def forward(self, x): original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) for is_name in [True, False]: + sm_example_inputs = example_inputs if is_name: prepare_config = { - "standalone_module_name": [("standalone", None, interface_config, None)] + "standalone_module_name": [("standalone", None, sm_example_inputs, interface_config, None)] } else: prepare_config = { - "standalone_module_class": [(StandaloneModule, None, interface_config, None)] + "standalone_module_class": [(StandaloneModule, None, sm_example_inputs, interface_config, None)] } original_m_copy = copy.deepcopy(original_m) @@ -1366,9 +1370,12 @@ def forward(self, x): qconfig_dict = {"": default_qconfig} # check prepared model m = prepare_fx( - original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config) + original_m_copy, + qconfig_dict, + example_inputs=example_inputs, + prepare_custom_config_dict=prepare_config) # calibration - m(data) + m(*example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check) @@ -1376,13 +1383,17 @@ def forward(self, x): m = convert_fx(m) self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check) - res = m(data) + res = m(*example_inputs) # quantize the reference model - ref_m = prepare_fx(original_ref_m_copy, qconfig_dict) - ref_m(data) + ref_m = prepare_fx( + original_ref_m_copy, + qconfig_dict, + example_inputs=example_inputs, + ) + ref_m(*example_inputs) ref_m = convert_fx(ref_m) - ref_res = ref_m(data) + ref_res = ref_m(*example_inputs) self.assertEqual(res, ref_res) def test_standalone_module_float_interface(self): @@ -1471,11 +1482,11 @@ def forward(self, x): m = M().eval() qconfig_dict = {"": default_qconfig, "module_name": [("conv2", None)]} - m = prepare_fx(m, qconfig_dict) - data = torch.randn(1, 1, 1, 1) - m(data) + example_inputs = (torch.randn(1, 1, 1, 1),) + m = prepare_fx(m, qconfig_dict, 
example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data) + m(*example_inputs) # first conv is quantized, second conv is not quantized node_list = [ ns.call_function(torch.quantize_per_tensor), @@ -1499,11 +1510,11 @@ def forward(self, x): m = M().eval() qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]} - m = prepare_fx(m, qconfig_dict) - data = torch.randn(1, 1, 1, 1) - m(data) + example_inputs = (torch.randn(1, 1, 1, 1),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data) + m(*example_inputs) # first conv is quantized, second conv is not quantized node_list = [ ns.call_function(torch.quantize_per_tensor), @@ -1541,10 +1552,11 @@ def forward(self, x): (torch.nn.ReLU, default_qat_qconfig), ], } - m = prepare_qat_fx(model, qconfig_dict) - m(torch.rand(5, 5)) + example_inputs = (torch.rand(5, 5),) + m = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(torch.rand(5, 5)) + m(*example_inputs) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_module(nniq.LinearReLU), @@ -1563,11 +1575,12 @@ def forward(self, x, y): m = M().eval() qconfig_dict = {"object_type": [(operator.add, default_qconfig)]} - m = prepare_fx(m, qconfig_dict) data = torch.randn(1, 1, 1, 1) - m(data, data) + example_inputs = (data, data) + m = prepare_fx(m, qconfig_dict, example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data, data) + m(*example_inputs) # first conv is quantized, second conv is not quantized node_list = [ ns.call_function(torch.quantize_per_tensor), @@ -1590,11 +1603,11 @@ def forward(self, x): m = M().eval() qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} - m = prepare_fx(m, qconfig_dict) - data = torch.randn(1, 1, 1, 1) - m(data) + example_inputs = (torch.randn(1, 1, 1, 1),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data) + m(*example_inputs) # first conv is quantized, second conv is not quantized node_list = [ ns.call_function(torch.quantize_per_tensor), @@ -1636,7 +1649,7 @@ def forward(self, x): "object_type": [(nn.Conv2d, object_type_qconfig)], "module_name_regex": [("module_conv*", module_name_regex_qconfig)], "module_name": [("module_conv2", module_name_qconfig)]} - m_prep = prepare_fx(m, qconfig_dict) + m_prep = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1),)) self.assertEqual(m_prep.linear.qconfig.activation.p.func, global_qconfig.activation.p.func) self.assertEqual(m_prep.linear.qconfig.weight.p.func, global_qconfig.weight.p.func) self.assertEqual(m_prep.conv.qconfig.activation.p.func, object_type_qconfig.activation.p.func) @@ -1702,11 +1715,11 @@ def forward(self, x): ("m2.m1", torch.add, 0, torch.ao.quantization.default_qconfig), ], } - m = prepare_fx(m, qconfig_dict) - data = torch.randn(1, 1, 1, 1) - m(data) + example_inputs = (torch.randn(1, 1, 1, 1),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data) + m(*example_inputs) node_list = [ # m3 @@ -1759,11 +1772,11 @@ def forward(self, x): ("", torch.add, 1, None), ], } - m = prepare_fx(m, qconfig_dict) - data = torch.randn(1, 1, 1, 1) - m(data) + example_inputs = (torch.randn(1, 1, 1, 1),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data) + m(*example_inputs) node_list = [ ns.call_function(torch.quantize_per_tensor), @@ -1819,7 +1832,7 
@@ def forward(self, x): m = model(relu).eval() qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm") # should not crash as in https://github.com/pytorch/pytorch/issues/75825 - prepare_fx(m, qconfig_dict) + prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) def test_qconfig_dict_validity(self): r""" @@ -1831,7 +1844,7 @@ def test_qconfig_dict_validity(self): qconfig_dict = {"object_typo": [(torch.nn.Conv2d, default_qconfig)]} with self.assertRaises(ValueError) as context: - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) self.assertTrue( 'Expected qconfig_dict to have the following keys:' in str(context.exception) ) @@ -1848,7 +1861,11 @@ def test_prepare_custom_config_dict_validity(self): prepare_custom_config_dict = {"typo": None} with self.assertRaises(ValueError) as context: - m = prepare_fx(m, qconfig_dict, prepare_custom_config_dict) + m = prepare_fx( + m, + qconfig_dict, + example_inputs=(torch.randn(1, 3, 3, 3),), + prepare_custom_config_dict=prepare_custom_config_dict) self.assertTrue( 'Expected prepare_custom_config_dict to have the following keys:' in str(context.exception) @@ -1863,7 +1880,7 @@ def test_convert_custom_config_dict_validity(self): """ m = ConvModel().eval() qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) convert_custom_config_dict = {"typo": None} with self.assertRaises(ValueError) as context: @@ -1885,11 +1902,11 @@ def forward(self, x): m = M().eval() qconfig_dict = {'': default_qconfig} - m = prepare_fx(m, qconfig_dict) - data = torch.randn(1, 1, 1, 1) - m(data) + example_inputs = (torch.randn(1, 1, 1, 1),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data) + m(*example_inputs) for name, module in m.named_modules(): self.assertFalse(hasattr(module, 'qconfig'), 'qconfig is not removed for ' + name) @@ -1901,7 +1918,7 @@ def forward(self, x): m = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),)) m = convert_fx(m) def test_default_quant_after_none_qconfig(self): @@ -1924,7 +1941,7 @@ def forward(self, x): ("conv1", None) ] } - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) m = convert_fx(m) def test_qconfig_for_call_method(self): @@ -1988,13 +2005,14 @@ def forward(self, x): (qconfig_dict1, node_list1), (qconfig_dict2, node_list2) ]: + example_inputs = (torch.randn(2, 1, 3, 3),) m = M().eval() - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m(torch.randn(2, 1, 3, 3)) m = convert_fx(m) self.checkGraphModuleNodes(m, expected_node_list=node_list) # make sure it runs - m(torch.randn(2, 1, 3, 3)) + m(*example_inputs) def test_qconfig_for_call_func(self): class Linear(torch.nn.Module): @@ -2021,9 +2039,10 @@ def forward(self, x): return x model = M().eval() + example_inputs = (torch.rand(5, 5),) qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]} - m = prepare_fx(model, qconfig_dict) - m(torch.rand(5, 5)) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) node_list = [ @@ -2051,7 +2070,12 @@ def forward(self, x): prepare_custom_config_dict = { "preserved_attributes": 
["preserved_attr"] } - m = prepare_fx(m, {"": default_qconfig}, prepare_custom_config_dict) + example_inputs = (torch.randn(1, 1, 1, 1),) + m = prepare_fx( + m, + {"": default_qconfig}, + example_inputs=example_inputs, + prepare_custom_config_dict=prepare_custom_config_dict) def assertAttrPreserved(m): self.assertTrue(hasattr(m, "preserved_attr")) @@ -2069,12 +2093,13 @@ def test_qat_and_script(self): model = LinearModelWithSubmodule().train() qengine = torch.backends.quantized.engine qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)} - model = prepare_qat_fx(model, qconfig_dict) + x = torch.randn(5, 5) + example_inputs = (x,) + model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) # ensure scripting works scripted = torch.jit.script(model) # run one round to make sure model runs - x = torch.randn(5, 5) scripted(x) FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \ .run(scripted.graph) @@ -2104,10 +2129,10 @@ def test_save_observer_state_dict(self): orig = LinearModelWithSubmodule().eval() model = orig qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')} - model = prepare_fx(model, qconfig_dict) + x = torch.randn(5, 5) + model = prepare_fx(model, qconfig_dict, example_inputs=(x,)) # run it through input - x = torch.randn(5, 5) model(x) quant = convert_fx(model) @@ -2120,7 +2145,7 @@ def test_save_observer_state_dict(self): # Load the stats into new model model_2 = orig - model_2 = prepare_fx(model_2, qconfig_dict) + model_2 = prepare_fx(model_2, qconfig_dict, example_inputs=(x,)) loaded_dict = torch.load(b) torch.ao.quantization.load_observer_state_dict(model_2, loaded_dict) @@ -2208,7 +2233,6 @@ def forward(self, x): x = self.linear2(x) return x - data = torch.randn(3, 3) # instantiate M and RefM and align the parameters original_m = M().eval() original_ref_m = RefM().eval() @@ -2255,13 +2279,15 @@ def forward(self, x): } } + example_inputs = (torch.randn(3, 3),) # check prepared model m = prepare_fx( original_m, qconfig_dict, + example_inputs=example_inputs, prepare_custom_config_dict=prepare_custom_config_dict) # calibration - m(data) + m(*example_inputs) # all activation observers are inserted in the top level module count_check = { ns.call_module(torch.ao.quantization.MinMaxObserver): num_observers @@ -2280,13 +2306,13 @@ def forward(self, x): } self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) self.assertEqual(type(m.custom), quantized_module_class) - res = m(data) + res = m(*example_inputs) # quantize the reference model - ref_m = prepare_fx(original_ref_m, qconfig_dict) - ref_m(data) + ref_m = prepare_fx(original_ref_m, qconfig_dict, example_inputs=example_inputs) + ref_m(*example_inputs) ref_m = convert_fx(ref_m) - ref_res = ref_m(data) + ref_res = ref_m(*example_inputs) self.assertEqual(res, ref_res) @skipIfNoFBGEMM @@ -2360,16 +2386,18 @@ def forward(self, x0): } } m = M().eval() + example_inputs = (torch.randn(3, 3),) m = prepare_fx( m, {"": default_qconfig}, + example_inputs=example_inputs, prepare_custom_config_dict=prepare_custom_config_dict) # make sure it works m = convert_fx( m, convert_custom_config_dict=convert_custom_config_dict) # make sure it runs - m(torch.randn(3, 3)) + m(*example_inputs) @skipIfNoFBGEMM def test_non_traceable_module(self): @@ -2415,6 +2443,7 @@ def forward(self, x): } m = prepare_fx( m, qconfig_dict, + example_inputs=({"key": torch.randn(1)},), prepare_custom_config_dict=prepare_custom_config_dict) node_occurrence = { @@ -2441,9 
+2470,10 @@ def forward(self, x): m = M() m.eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - prepared = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(4, 1, 4, 4),) + prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # calibrate - prepared(torch.randn(4, 1, 4, 4)) + prepared(*example_inputs) # copy prepared_copy = copy.deepcopy(prepared) # quantize, should run with no errors @@ -2460,21 +2490,21 @@ def __init__(self): def forward(self, x): return self.linear(x) - data = torch.rand(8, 5) + example_inputs = (torch.rand(8, 5),) m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m = convert_fx(m) # test deepcopy m_copy = copy.deepcopy(m) - self.assertEqual(m_copy(data), m(data)) + self.assertEqual(m_copy(*example_inputs), m(*example_inputs)) # test state_dict state_dict = m.state_dict() m_new = M().eval() - m_new = prepare_fx(m_new, {"": default_qconfig}) + m_new = prepare_fx(m_new, {"": default_qconfig}, example_inputs=example_inputs) m_new = convert_fx(m_new) m_new.load_state_dict(state_dict) - self.assertEqual(m_new(data), m(data)) + self.assertEqual(m_new(*example_inputs), m(*example_inputs)) def test_dequantize(self): r""" Test to make sure dequantize node are placed before @@ -2542,12 +2572,14 @@ def forward(self, x): # quantized input, quantized output m = M() qconfig_dict = {'': torch.ao.quantization.default_qconfig} + example_inputs = (torch.randn(1, 1, 4, 4),) m.eval() mp = torch.ao.quantization.quantize_fx.prepare_fx( m, qconfig_dict, + example_inputs=example_inputs, prepare_custom_config_dict=prepare_custom_config_dict) self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) - mp(torch.randn(1, 1, 4, 4)) + mp(*example_inputs) mq = torch.ao.quantization.quantize_fx.convert_fx(mp) self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) @@ -2612,7 +2644,7 @@ def test_convtranspose_per_channel_fails_early(self): m.eval() qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')} with self.assertRaises(AssertionError) as context: - mp = prepare_fx(m, qconfig_dict) + mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) self.assertTrue( str(context.exception) == 'Per channel weight observer is not supported yet for ConvTranspose{n}d.') @@ -2644,8 +2676,9 @@ def forward(self, x): model = M().eval() qconfig_dict = {"": default_qconfig} - m = prepare_fx(model, qconfig_dict) - m(torch.rand(5, 5)) + example_inputs = (torch.rand(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) keys = m.state_dict().keys() quant_scale_count = quant_zero_point = scale_count = zero_point_count = 0 @@ -2663,7 +2696,7 @@ def forward(self, x): self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict") self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict") # ensure it runs - m(torch.rand(5, 5)) + m(*example_inputs) # ensure it is scriptable scripted = torch.jit.script(m) scripted_keys = scripted.state_dict().keys() @@ -2708,9 +2741,10 @@ def forward(self, x): return x model = M().eval() + example_inputs = (torch.rand(5, 5),) qconfig_dict = {"": default_qconfig} - m = prepare_fx(model, qconfig_dict) - m(torch.rand(5, 5)) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) assert hasattr(m, 
"mods1_0_packed_weight_0") assert hasattr(m, "mods1_1_packed_weight_0") @@ -2743,10 +2777,11 @@ def forward(self, x): return x model = M().eval() qconfig_dict = {"": float16_dynamic_qconfig} - m = prepare_fx(model, qconfig_dict) + example_inputs = (torch.rand(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) # make sure it runs - m(torch.randn(5, 5)) + m(*example_inputs) def test_getattr_with_nontensor_result(self): """ @@ -2785,9 +2820,10 @@ def forward(self, x): for cls in (M1, M2, M3): m = cls().eval() - m(torch.rand(4, 4, 4, 4)) + example_inputs = (torch.rand(4, 4, 4, 4),) + m(*example_inputs) qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) mp(torch.rand(4, 4, 4, 4)) mc = convert_fx(mp) @@ -2841,7 +2877,7 @@ def _check_node_not_observed(model, arg_node, node): def _test_dtype_propagation(self, model, node_info_to_non_tensor_args, *args): model.eval() qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} - prepared_model = prepare_fx(model, qconfig_dict) + prepared_model = prepare_fx(model, qconfig_dict, example_inputs=tuple(args)) self._check_not_observed(prepared_model, node_info_to_non_tensor_args) prepared_model(*args) @@ -3035,12 +3071,13 @@ def forward(self, x): return x m = M().eval() - m(torch.rand(4, 1, 4, 4)) + example_inputs = (torch.rand(4, 1, 4, 4),) + m(*example_inputs) qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) - mp(torch.rand(4, 1, 4, 4)) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mc = convert_fx(mp) - mc(torch.rand(4, 1, 4, 4)) + mc(*example_inputs) def test_fp32_sum(self): """ @@ -3073,12 +3110,13 @@ def forward(self, x): for cls in (M1, M2): m = cls().eval() - m(torch.rand(4, 1, 4, 4)) + example_inputs = (torch.rand(4, 1, 4, 4),) + m(*example_inputs) qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) - mp(torch.rand(4, 1, 4, 4)) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mc = convert_fx(mp) - mc(torch.rand(4, 1, 4, 4)) + mc(*example_inputs) def test_fusion_pattern_unquantized(self): """ @@ -3113,8 +3151,9 @@ def forward(self, x): ('child', None), ], } - mp = prepare_fx(m, qconfig_dict) - mp(torch.rand(1, 1, 1, 1)) + example_inputs = (torch.rand(1, 1, 1, 1),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mc = convert_fx(mp) def test_state_dict(self): @@ -3133,7 +3172,7 @@ def forward(self, x): m = M1().eval() qconfig_dict = {"": default_qconfig} - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 30),)) m = convert_fx(m) state_dict = m.state_dict() self.assertTrue("_packed_weight_0" in state_dict) @@ -3154,7 +3193,7 @@ def forward(self, x): m = M2().eval() qconfig_dict = {"": default_qconfig} - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) m = convert_fx(m) state_dict = m.state_dict() self.assertTrue("_packed_weight_0" in state_dict) @@ -3164,7 +3203,7 @@ def forward(self, x): data = torch.rand(1, 3, 5, 5) ref_res = m(data) m = M2().eval() - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, (data,)) m = convert_fx(m) res = m(data) weight, bias = m._packed_weight_0.unpack() @@ -3186,7 +3225,7 @@ def checkModel(m, data, ref_weight, ref_bias, ref_res): # 
Test save to disk and load back m = M2().eval() - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(data,)) m = convert_fx(m) m.load_state_dict(state_dict) with TemporaryFileName() as fname: @@ -3229,8 +3268,9 @@ def forward(self, x): (torch.nn.functional.linear, float16_dynamic_qconfig), ], } - m = prepare_fx(model, qconfig_dict) - m(torch.rand(5, 5)) + example_inputs = (torch.rand(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m, _remove_qconfig=False) self.assertTrue(hasattr(m.mods2, 'qconfig')) @@ -3250,7 +3290,7 @@ def forward(self, x): m = M().eval() qconfig_dict = {"": float16_static_qconfig} # make sure quantization runs - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),)) m = convert_fx(m) def test_qparams_fqn(self): @@ -3288,8 +3328,9 @@ def forward(self, x): (torch.nn.functional.relu, default_qconfig), ], } - m = prepare_fx(model, qconfig_dict) - m(torch.rand(5, 5)) + example_inputs = (torch.rand(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) keys = m.state_dict().keys() m(torch.randn(5, 5)) @@ -3321,12 +3362,13 @@ def forward(self, x): m = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(4, 4, 4, 4),) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # if an observer is inserted after _user_func_with_complex_return_type, # the following call will fail - mp(torch.randn(4, 4, 4, 4)) + mp(*example_inputs) mc = convert_fx(mp) - mc(torch.randn(4, 4, 4, 4)) + mc(*example_inputs) def test_fold_quant_dequant(self): """ Test that the sequence of quant-dequant nodes in the @@ -3352,11 +3394,12 @@ def forward(self, x): (torch.nn.functional.linear, default_qconfig), ], } - m = prepare_fx(model, qconfig_dict) - m(torch.rand(5, 5)) + example_inputs = (torch.rand(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) keys = m.state_dict().keys() - m(torch.randn(5, 5)) + m(*example_inputs) dequant = 0 quant = 0 for n in m.graph.nodes: @@ -3375,7 +3418,7 @@ def test_quant_output_always_observed(self): """ qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} prepare_custom_config_dict = {'output_quantized_idxs': [0]} - data = (torch.randn(4, 1, 4, 4),) + example_inputs = (torch.randn(4, 1, 4, 4),) # non-quantizeable node, quantized output class M1(torch.nn.Module): @@ -3389,7 +3432,7 @@ def forward(self, x): m1 = M1() self.checkGraphModeFxOp( - m1, data, QuantType.QAT, + m1, example_inputs, QuantType.QAT, prepare_expected_node_occurrence={ ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, }, @@ -3410,7 +3453,7 @@ def forward(self, x): m2 = M2() self.checkGraphModeFxOp( - m2, data, QuantType.QAT, + m2, example_inputs, QuantType.QAT, prepare_expected_node_occurrence={ # one for weights, one for activations ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, @@ -3432,7 +3475,7 @@ def forward(self, x): m3 = M3() self.checkGraphModeFxOp( - m3, data, QuantType.QAT, + m3, example_inputs, QuantType.QAT, prepare_expected_node_occurrence={ # one for weights, one for activations ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, @@ -3452,7 +3495,11 @@ def forward(self, x): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}, 
prepare_custom_config_dict={"preserved_attributes": ["attr"]}) + m = prepare_fx( + m, + {"": default_qconfig}, + example_inputs=(torch.randn(1),), + prepare_custom_config_dict={"preserved_attributes": ["attr"]}) self.assertTrue(hasattr(m, "attr")) m2 = copy.deepcopy(m) self.assertTrue(hasattr(m2, "attr")) @@ -3475,7 +3522,7 @@ def forward(self, x): m = M().eval() qconfig_dict = {'': torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) + mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) mc = convert_fx(mp) def test_shape_followed_by_quantized_op(self): @@ -3497,9 +3544,10 @@ def forward(self, x): # make sure quantization runs m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.randn(2, 2, 4, 4),) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m = convert_fx(m) - m(torch.randn(2, 2, 4, 4)) + m(*example_inputs) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_method("dequantize"): 1 @@ -3517,7 +3565,7 @@ def forward(self, x): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1, 1, 3, 3),)) m = convert_fx(m) # Make sure this runs without error m = torch.fx.Transformer(m).transform() @@ -3558,7 +3606,8 @@ def forward(self, x): qconfig = default_qconfig actpp_module_class = torch.ao.quantization.MinMaxObserver - m = prepare(m, {"": qconfig}) + example_inputs = (torch.randn(1, 3, 3, 3),) + m = prepare(m, {"": qconfig}, example_inputs=example_inputs) # check that there is a duplicated observer instance actpp_module_count = 0 for name, module in m.named_modules(remove_duplicate=False): @@ -3620,12 +3669,16 @@ def forward(self, x): return x m = M().eval() - m = prepare_fx(m, {"": torch.ao.quantization.QConfig( - activation=torch.ao.quantization.HistogramObserver.with_args( - qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 - ), weight=torch.ao.quantization.default_per_channel_weight_observer)}) + example_inputs = (torch.rand(2, 1, 5, 5),) + m = prepare_fx( + m, + {"": torch.ao.quantization.QConfig( + activation=torch.ao.quantization.HistogramObserver.with_args( + qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 + ), weight=torch.ao.quantization.default_per_channel_weight_observer)}, + example_inputs=example_inputs) m = convert_fx(m, is_reference=True) - m(torch.rand(2, 1, 5, 5)) + m(*example_inputs) def test_preserve_tuple(self): """ Test tuple input type is preserved @@ -3642,7 +3695,8 @@ def forward(self, inputs: torch.Tensor, state: List[torch.Tensor]): return self.lstm(inputs, (h, c)) m = LSTM().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.randn(5, 3, 50), torch.randn(2, 3, 50), torch.randn(2, 3, 50)) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) # make sure the arg[1] of lstm module is a tuple for n in m.graph.nodes: if n.target == "lstm": @@ -3654,7 +3708,7 @@ def forward(self, x): return torch.nn.functional.relu(x) m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1),)) m_copy = copy.deepcopy(m) m = convert_fx(m) m_ref = convert_fx(m_copy, is_reference=True) @@ -3716,9 +3770,10 @@ def forward(self, x): qconfig_dict = { "": qconfig } - m = prepare_fx(model, qconfig_dict) + example_inputs = (torch.rand(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) - m(torch.rand(5, 5)) + 
m(*example_inputs) node_list = [ ns.call_module(nniqd.LinearReLU), ns.call_module(nniqd.LinearReLU), @@ -3756,9 +3811,10 @@ def forward(self, x): qconfig_dict = { "": qconfig } - m = prepare_fx(model, qconfig_dict) + example_inputs = (torch.randn(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) - m(torch.rand(5, 5)) + m(*example_inputs) node_list = [ ns.call_module(nniqd.LinearReLU), ns.call_module(nniqd.LinearReLU), @@ -3796,9 +3852,10 @@ def forward(self, x): qconfig_dict = { "": qconfig } - m = prepare_fx(model, qconfig_dict) + example_inputs = (torch.rand(5, 5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) - m(torch.rand(5, 5, 5)) + m(*example_inputs) node_list = [ ns.call_module(nniqd.LinearReLU), ns.call_module(nniqd.LinearReLU), @@ -3828,13 +3885,13 @@ def forward(self, x): for M in [M1, M2]: m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.randn(5, 10),) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m_copy = copy.deepcopy(m) m = convert_fx(m, is_reference=False) m_ref = convert_fx(m_copy, is_reference=True) - data = torch.randn(5, 10) - result = m(data) - result_ref = m_ref(data) + result = m(*example_inputs) + result_ref = m_ref(*example_inputs) self.assertTrue(torch.equal(result, result_ref)) def test_ref_conv_module(self): @@ -3866,11 +3923,11 @@ def forward(self, x): for dim, M in itertools.product([1, 2, 3], [M1, M2]): m = M(dim).eval() - m = prepare_fx(m, {"": default_qconfig}) + data = self.img_data_dict[dim][0][0] + m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) m_copy = copy.deepcopy(m) m = convert_fx(m, is_reference=False) m_ref = convert_fx(m_copy, is_reference=True) - data = self.img_data_dict[dim][0][0] result = m(data) result_ref = m_ref(data) self.assertTrue(torch.equal(result, result_ref)) @@ -3885,7 +3942,7 @@ def forward(self, x): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.rand(3),)) m = convert_fx(m) occurrence = { ns.call_function(torch.quantize_per_tensor): 2, @@ -3930,7 +3987,7 @@ def forward(self, x): model = M().eval() - prepared = prepare_fx(model, {"": default_qconfig}) + prepared = prepare_fx(model, {"": default_qconfig}, example_inputs=(torch.randn(1, 5))) name_list = [] for name, mod in prepared.named_modules(): if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver): @@ -3959,11 +4016,11 @@ def forward(self, x): for dim in range(1, len(convs) + 1): m = M(dim).eval() - m = prepare_fx(m, {"": default_qconfig}) + data = self.img_data_dict[dim][0][0] + m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) m_ref = copy.deepcopy(m) m_ref = convert_fx(m_ref, is_reference=True) m = convert_fx(m) - data = self.img_data_dict[dim][0][0] out_ref = m_ref(data) out = m(data) # check that reference pattern for quantized conv module is fused @@ -4013,8 +4070,9 @@ def forward(self, x): (nn.Linear, get_default_qat_qconfig("fbgemm")), ], } - prepared = prepare_qat_fx(model, qconfig_dict) - prepared(torch.rand(5, 5)) + example_inputs = (torch.rand(5, 5),) + prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) + prepared(*example_inputs) if check == "module_name": convert_qconfig_dict = {"": None, "object_type": [ @@ -4133,7 +4191,8 @@ def forward(self, x): options = itertools.product([M1, M2], [True, False]) for M, is_qat in options: m = M1().eval() - m = prepare_fx(m, 
get_default_qconfig_dict()) + example_inputs = (torch.randn(1, 3, 3, 3),) + m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs) m = convert_fx(m) node_list = [ ns.call_function(torch.quantize_per_tensor), @@ -4146,7 +4205,7 @@ def forward(self, x): expected_node_list=node_list) m = M2().eval() - m = prepare_fx(m, get_default_qconfig_dict()) + m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs) m = convert_fx(m) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 0, @@ -4171,7 +4230,7 @@ def forward(self, x): return x m = M().eval() - mp = prepare_fx(m, get_default_qconfig_dict()) + mp = prepare_fx(m, get_default_qconfig_dict(), example_inputs=(torch.randn(1, 1),)) found_stack_trace = False for n in mp.graph.nodes: @@ -4233,11 +4292,14 @@ def forward(self, x): "non_traceable_module_class": [UnTraceableModuleClass], "non_traceable_module_name": ["untraceable_module_name"], } + example_inputs = (torch.randn(2, 2),) mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx( - mod.train(), qconfig_dict, prepare_custom_config_dict + mod.train(), qconfig_dict, example_inputs=example_inputs, + prepare_custom_config_dict=prepare_custom_config_dict ) mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx( - mod.train(), qconfig_dict, prepare_custom_config_dict + mod.train(), qconfig_dict, example_inputs=example_inputs, + prepare_custom_config_dict=prepare_custom_config_dict ) self.assertTrue( isinstance(mod_prep.untraceable_module_class.linear, torch.nn.Linear) @@ -4287,7 +4349,7 @@ def forward(self, x): for backend in backends: m = M().eval() qconfig_dict = func(backend) - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1))) for name, mod in m.named_modules(): if is_activation_post_process(mod) and mod.dtype == torch.quint8: if backend == "fbgemm": @@ -4316,10 +4378,11 @@ def _test(prepare_fn, qconfig_dict): m = LinearModel() m1 = copy.deepcopy(m) m1.train() - prepare_fn(m1, qconfig_dict) + example_inputs = (torch.randn(1, 5),) + prepare_fn(m1, qconfig_dict, example_inputs=example_inputs) m2 = copy.deepcopy(m) m2.eval() - prepare_fn(m2, qconfig_dict) + prepare_fn(m2, qconfig_dict, example_inputs=example_inputs) # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes _test(prepare_fx, get_default_qconfig_dict()) @@ -5016,7 +5079,8 @@ def forward(self, x, y): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.randn(3), torch.randn(3)) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m = convert_fx(m) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 2, @@ -5027,7 +5091,7 @@ def forward(self, x, y): # check the model is scriptable m = torch.jit.script(m) # check the model is runnable - m(torch.randn(3), torch.randn(3)) + m(*example_inputs) @skipIfNoFBGEMM def test_mul_relu(self): @@ -5037,9 +5101,9 @@ def test_mul_relu(self): operator.mul, operator.imul) # TODO(future PR): make more generic - def _test_quantized_add_mul_qat(self, model, expected_node_occurrence): + def _test_quantized_add_mul_qat(self, model, example_inputs, expected_node_occurrence): qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} - mp = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict) + mp = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes( mp, expected_node_occurrence=expected_node_occurrence) @@ -5060,10 +5124,11 
@@ def forward(self, x): return x m = M() + example_inputs = (torch.randn(1, 1, 1, 1),) expected_node_occurrence = { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, } - self._test_quantized_add_mul_qat(m, expected_node_occurrence) + self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) @skipIfNoFBGEMM def test_quantized_mul_qat(self): @@ -5082,10 +5147,11 @@ def forward(self, x): return x m = M() + example_inputs = (torch.randn(1, 1, 1, 1),) expected_node_occurrence = { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, } - self._test_quantized_add_mul_qat(m, expected_node_occurrence) + self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) def test_int8_input_no_unnecessary_fq(self): """ @@ -5105,6 +5171,7 @@ def forward(self, x): m = M(0.5) mp = torch.ao.quantization.quantize_fx.prepare_qat_fx( m, {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}, + example_inputs=(torch.randn(1),), prepare_custom_config_dict={"input_quantized_idxs": [0]}) expected_node_occurrence = { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 1, @@ -5128,8 +5195,8 @@ def forward(self, x, y): y = self.conv2(y) return torch.cat([x, y], 1) - data = (torch.randn(1, 2, 5, 5, dtype=torch.float), - torch.randn(1, 2, 5, 5, dtype=torch.float)) + example_inputs = (torch.randn(1, 2, 5, 5, dtype=torch.float), + torch.randn(1, 2, 5, 5, dtype=torch.float)) quantized_node = ns.call_function(torch.cat) options = itertools.product(self.static_quant_types, [True, False]) for quant_type, is_reference in options: @@ -5158,7 +5225,7 @@ def forward(self, x, y): self.checkGraphModeFxOp( M(), - data, + example_inputs, quant_type, quantized_node, expected_node_list=converted_node_list, @@ -5167,7 +5234,7 @@ def forward(self, x, y): # check cat is using the same observer for input and output m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) # two inputs and one output of torch.cat are using same observer, so we have # 2 observers that's replicated all_observers = len(dict(m.named_modules(remove_duplicate=False))) @@ -5175,7 +5242,7 @@ def forward(self, x, y): self.assertEqual(all_observers, distinct_observers + 2) # make sure the converted model runs m = convert_fx(m) - m(*data) + m(*example_inputs) @skipIfNoFBGEMM def test_qbatch_norm(self): @@ -5630,6 +5697,7 @@ def forward(self, x, y): data_x = torch.randn((2, 2, 2,)) data_y = torch.randn((2, 2, 2,)) + example_inputs = (data_x, data_y) qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} is_reference = True node_list = [ @@ -5637,10 +5705,10 @@ def forward(self, x, y): ] m = M().eval() - m_prep = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict) - m_prep(data_x, data_y) + m_prep = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m_prep(*example_inputs) m_quant = torch.ao.quantization.quantize_fx.convert_fx(m_prep, is_reference=is_reference) - m_quant(data_x, data_y) + m_quant(*example_inputs) self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) @@ -5823,12 +5891,12 @@ def forward(self, x): x = self.conv2(x) return x - data = torch.rand(1, 3, 10, 10) + example_inputs = (torch.rand(1, 3, 10, 10),) # This model is not executable since we just put all ops # in the same forward m = M().eval() qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(m, qconfig_dict) + prepared = prepare_fx(m, 
qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_fx(prepared) @@ -5859,7 +5927,7 @@ def forward(self, x): # Checking the is_reference output m = M().eval() qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(m, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_fx(prepared, is_reference=True) @@ -5884,7 +5952,10 @@ def forward(self, x): m = M().eval() # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(m, qconfig_dict, prepare_custom_config_dict={"input_quantized_idxs": [0]}) + example_inputs = (torch.randn(1, 3, 3, 3),) + prepared = prepare_fx( + m, qconfig_dict, example_inputs=example_inputs, + prepare_custom_config_dict={"input_quantized_idxs": [0]}) # not runnable quantized = convert_fx(prepared) @@ -5950,7 +6021,8 @@ def forward(self, x): m = M().eval() # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 3, 3, 3),) + prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_fx(prepared) @@ -5983,7 +6055,7 @@ def forward(self, x): return x m = M().eval() - m = prepare_fx(m, {"": default_reuse_input_qconfig}) + m = prepare_fx(m, {"": default_reuse_input_qconfig}, example_inputs=(torch.randn(1),)) m = convert_fx(m) # make sure it runs m(torch.rand(1)) @@ -5998,12 +6070,13 @@ def forward(self, xs): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.rand(1, 2),) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 0 }) m = convert_fx(m) - m(torch.rand(1, 2)) + m(*example_inputs) class M2(torch.nn.Module): def forward(self, xs): @@ -6012,7 +6085,8 @@ def forward(self, xs): return x m2 = M2().eval() - m2 = prepare_fx(m2, {"": default_qconfig}) + example_inputs = ([torch.rand(1, 2)],) + m2 = prepare_fx(m2, {"": default_qconfig}, example_inputs=example_inputs) self.checkGraphModuleNodes(m2, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 1 }) @@ -6021,7 +6095,7 @@ def forward(self, xs): ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize") ]) - m2([torch.rand(1, 2)]) + m2(*example_inputs) # testing prepare recognizes non-Tensor input for getitem class M3(torch.nn.Module): @@ -6032,7 +6106,8 @@ def forward(self, x): return x m3 = M3().eval() - m3 = prepare_fx(m3, {"": default_qconfig}) + example_inputs = (torch.rand(1, 2, 3, 4),) + m3 = prepare_fx(m3, {"": default_qconfig}, example_inputs=example_inputs) self.checkGraphModuleNodes(m3, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 1 }) @@ -6041,7 +6116,7 @@ def forward(self, x): ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize") ]) - m3(torch.rand(1, 2, 3, 4)) + m3(*example_inputs) @skipIfNoFBGEMM @@ -6090,12 +6165,12 @@ def forward(self, x): # nothing to fuse so skipping the fuse step m_copy = copy.deepcopy(m) qconfig_dict = {'': qconfig} - prepared = prepare(m, qconfig_dict) + example_inputs = (torch.rand(3, 3, 3, 3),) + prepared = prepare(m, qconfig_dict, example_inputs=example_inputs) prepared_copy = copy.deepcopy(prepared) # check that prepare does not change model result if eval_mode: - r = torch.rand(3, 3, 3, 3) - self.assertEqual(m_copy(r), 
prepared_copy(r)) + self.assertEqual(m_copy(*example_inputs), prepared_copy(*example_inputs)) # check the correct number of activation_post_process is inserted expected_activation_post_process = FixedQParamsObserver if eval_mode else FixedQParamsFakeQuantize count_check = { @@ -6183,7 +6258,7 @@ def forward(self, x): x = self.ff6.cat([x]) return x - data = torch.rand(3, 3) + example_inputs = (torch.rand(3, 3),) # Note: QAT test succeeded by chance, to make it actually work # we need to fix eager mode FloatFunctional by removing # activation_post_process in add_scalar and mul_scalar @@ -6204,13 +6279,13 @@ def forward(self, x): prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx qconfig_dict = {"": qconfig} - m = prepare_fx_function(m, qconfig_dict) + m = prepare_fx_function(m, qconfig_dict, example_inputs=example_inputs) node_occurrence = { ns.call_module(expected_act_post_process): 7, ns.call_module(torch.nn.quantized.FloatFunctional): 0 } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - m(data) + m(*example_inputs) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.add), @@ -6228,7 +6303,7 @@ def forward(self, x): ref_m.qconfig = qconfig prepare_function = prepare_qat if is_qat else prepare ref_m = prepare_function(ref_m) - ref_m(data) + ref_m(*example_inputs) ref_m = convert(ref_m) # FX Graph Mode and Eager Mode now diverages in numerics of add_scalar and mul_scalar # self.assertEqual(m(data), ref_m(data)) @@ -6245,6 +6320,7 @@ def forward(self, indices): for qconfig_type in [float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit]: model = M().eval() indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + example_inputs = (indices,) quantized_node = ns.call_module(nnq.Embedding) configs = [ (qconfig_type, ns.call_module(nnq.Embedding)), @@ -6254,14 +6330,14 @@ def forward(self, indices): for qconfig, node in configs: qconfig_dict = {"": qconfig} - m = prepare_fx(model, qconfig_dict) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 0 }) m = convert_fx(m) self.checkGraphModuleNodes(m, expected_node=node) # make sure it runs - m(indices) + m(*example_inputs) def test_embedding_bag(self): class M(torch.nn.Module): @@ -6275,7 +6351,7 @@ def forward(self, indices, offsets): indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) offsets = torch.tensor([0, 19, 20, 28, 28, 32]) quantized_node = ns.call_module(nnq.EmbeddingBag) - inputs = (indices, offsets) + example_inputs = (indices, offsets) for dtype in [torch.quint8, torch.quint4x2]: model = M().eval() @@ -6286,7 +6362,7 @@ def forward(self, indices, offsets): weight=float_qparams_observer) self.checkGraphModeFxOp( model, - inputs, + example_inputs, QuantType.DYNAMIC, quantized_node, custom_qconfig_dict={"": float_qparams_qconfig} @@ -6296,14 +6372,14 @@ def forward(self, indices, offsets): for qconfig in [None, default_qconfig]: qconfig_dict = {"": default_qconfig} m = M().eval() - m = prepare_fx(model, qconfig_dict) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 0 }) m = convert_fx(m) self.checkGraphModuleNodes(m, 
expected_node=ns.call_module(nn.EmbeddingBag)) # make sure it runs - m(*inputs) + m(*example_inputs) def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input): options = itertools.product(qconfigs, module_type_strs) @@ -6323,7 +6399,7 @@ def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_inp (x, qconfig) for x in module_types ] } - model_graph = prepare_fx(model_graph, graph_qconfig_dict) + model_graph = prepare_fx(model_graph, graph_qconfig_dict, example_inputs=(sample_input,)) model_graph = convert_fx(model_graph) self.assertEqual(model_eager(sample_input), model_graph(sample_input)) self.checkScriptable(model_graph, [[sample_input]], True) @@ -6407,7 +6483,8 @@ def forward(self, x): (torch.nn.functional.linear, default_qconfig) ] } - m = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 4),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) expected_occurrence = { # input and weight of first and second linear, output of first and second linear ns.call_module(torch.ao.quantization.MinMaxObserver): 6, @@ -6456,7 +6533,8 @@ def forward(self, x): (torch.nn.functional.linear, default_qconfig) ] } - m = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 4),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) expected_occurrence = { # input and weight of linear, output of linear ns.call_module(torch.ao.quantization.MinMaxObserver): 3, @@ -6489,7 +6567,8 @@ def forward(self, x, mask): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool()) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) expected_occurrence = { ns.call_module(torch.ao.quantization.MinMaxObserver): 0 } @@ -6497,8 +6576,7 @@ def forward(self, x, mask): m, expected_node_occurrence=expected_occurrence) m = convert_fx(m) - m(torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool()) - return m + m(*example_inputs) def test_chunk(self): class M(torch.nn.Module): @@ -6507,11 +6585,11 @@ def forward(self, x): x = x + y return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) - data = torch.rand(2, 2, 2, 2) - m(data) + example_inputs = (torch.rand(2, 2, 2, 2),) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data) + m(*example_inputs) # make sure everything runs def test_ref_pattern_multi_use(self): @@ -6536,7 +6614,8 @@ def forward(self, x): (torch.nn.ReLU, get_default_qconfig("fbgemm")), ], } - m = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 5),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) expected_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, @@ -6557,9 +6636,10 @@ def forward(self, x, y): return z m = M().eval() + example_inputs = (torch.randn(2, 2), torch.randn(2, 2)) qconfig_dict = {"": torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) - mp(torch.randn(2, 2), torch.randn(2, 2)) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mq = convert_fx(mp) expected_occurrence = { ns.call_function(torch.matmul): 0, @@ -6569,7 +6649,7 @@ def forward(self, x, y): mq, expected_node_occurrence=expected_occurrence) # verify no crash - res = mq(torch.randn(2, 2), torch.randn(2, 2)) + res = mq(*example_inputs) class TestQuantizeFxModels(QuantizationTestCase): @skipIfNoFBGEMM @@ -6589,12 +6669,13 @@ def forward(self, x): return y input = 
torch.randn((5, 1, 6, 6)).to('cuda') + example_inputs = (input,) model = Net().to('cuda').eval() qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} - model_prepared = prepare_fx(model, qconfig_dict) - model_prepared(input) + model_prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) + model_prepared(*example_inputs) model_quantized = convert_fx(model_prepared, is_reference=True) - out = model_quantized(input) + out = model_quantized(*example_inputs) self.assertEqual(out.device.type, 'cuda') @skipIfNoFBGEMM @@ -6618,7 +6699,7 @@ def forward(self, x): input = torch.randn((5, 1, 6, 6)).to(device) model = Net().to(device).eval() qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} - model_prepared = prepare_fx(model, qconfig_dict) + model_prepared = prepare_fx(model, qconfig_dict, example_inputs=(input,)) model_prepared(input) model_prepared.to(device_after) model_quantized = convert_fx(model_prepared, is_reference=True) @@ -6644,8 +6725,8 @@ def forward(self, x): input = torch.randn((5, 1, 6, 6)).to(device) model = Net().to(device).eval() qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} - model_prepared_first = prepare_fx(model, qconfig_dict) - model_prepared_second = prepare_fx(model, qconfig_dict) + model_prepared_first = prepare_fx(model, qconfig_dict, example_inputs=(input,)) + model_prepared_second = prepare_fx(model, qconfig_dict, example_inputs=(input,)) model_prepared_first(input) state_dict = model_prepared_first.state_dict() del model_prepared_first @@ -6660,10 +6741,11 @@ def test_model_dropout(self): from torchvision import models m = models.mobilenet_v3_small() qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')} - mp = prepare_qat_fx(m, qconfig_dict) - mp(torch.randn(1, 3, 224, 224)) + example_inputs = (torch.randn(1, 3, 224, 224),) + mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mq = convert_fx(mp) - res = mq(torch.randn(1, 3, 224, 224)) + res = mq(*example_inputs) def _test_model_impl( self, mode, name, model, eager_quantizable_model, @@ -6813,7 +6895,7 @@ def _test_building_block(self, quant_type, BB): eager = eager_prepare(eager) qconfig_dict = {"": qconfig} - graph = graph_prepare(graph, qconfig_dict) + graph = graph_prepare(graph, qconfig_dict, example_inputs=(data[0][0],)) eager_out = eager(data[0][0]) graph_out = graph(data[0][0]) @@ -6946,7 +7028,7 @@ def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, eval_output = [[torch.randint(0, 10, (12, 1))]] model = EmbeddingBagLinear().train() - prepared_fx_model = prepare_qat_fx(model, qconfig_dict) + prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, qconfig_dict=qconfig_dict) @@ -6986,7 +7068,7 @@ def forward(self, input: torch.Tensor): eval_output = [[torch.randint(0, 10, (12, 1))]] model = EmbeddingLinear().train() - prepared_fx_model = prepare_qat_fx(model, qconfig_dict) + prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, qconfig_dict=qconfig_dict) @@ -7048,8 +7130,8 @@ def forward(self, x): activation=ref_fake_quant, weight=ref_weight_fake_quant ) qconfig_dict = {"": ref_qat_qconfig} - - prepared_ref = prepare_qat_fx(model, qconfig_dict) + example_inputs = 
(torch.randn(1, 5),) + prepared_ref = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) custom_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( observer=MovingAverageMinMaxObserver, @@ -7069,7 +7151,7 @@ def forward(self, x): activation=custom_fake_quant, weight=custom_weight_fake_quant ) custom_qconfig_dict = {"": custom_qconfig} - prepared = prepare_qat_fx(model, custom_qconfig_dict) + prepared = prepare_qat_fx(model, custom_qconfig_dict, example_inputs=example_inputs) prepared.to(device) prepared_ref.to(device) diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 086b65e13c90..ad0113ab5285 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -34,7 +34,7 @@ from torch.ao.quantization.quantization_types import ( Pattern, - NodePattern + NodePattern, ) from ._equalize import ( @@ -231,7 +231,7 @@ def prepare_get_standalone_module_configs( prepare_custom_config_dict: Dict[str, Any], parent_qconfig: QConfigAny, parent_backend_config_dict: Optional[Dict[str, Any]], -) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: +) -> Tuple[Dict[str, Any], Tuple[Any], Dict[str, Any], Dict[str, Any]]: """ Returns the standalone module qconfig_dict and prepare_config_dict for `node`, assuming that the module pointed to by `node` is @@ -239,8 +239,11 @@ def prepare_get_standalone_module_configs( """ standalone_module_name = str(node.target) standalone_module_type = type(modules[standalone_module_name]) # type: ignore[index] - sm_qconfig_dict, sm_prepare_config_dict, sm_backend_config_dict = \ - get_standalone_module_configs(standalone_module_name, standalone_module_type, prepare_custom_config_dict) + sm_qconfig_dict, sm_example_inputs, sm_prepare_config_dict, \ + sm_backend_config_dict = get_standalone_module_configs( + standalone_module_name, + standalone_module_type, + prepare_custom_config_dict) # fallback to use parent module's qconfig if user didn't specify qconfig dict if sm_qconfig_dict is None: sm_qconfig_dict = {"": parent_qconfig} @@ -250,7 +253,7 @@ def prepare_get_standalone_module_configs( # as well, this can be added later if sm_backend_config_dict is None: sm_backend_config_dict = parent_backend_config_dict - return sm_qconfig_dict, sm_prepare_config_dict, sm_backend_config_dict + return sm_qconfig_dict, sm_example_inputs, sm_prepare_config_dict, sm_backend_config_dict def qat_swap_modules( root: torch.nn.Module, @@ -529,7 +532,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg( else: # custom flow for standalone modules - _sm_qconfig_dict, sm_prepare_config_dict, _sm_backend_config_dict = \ + _sm_qconfig_dict, _, sm_prepare_config_dict, _sm_backend_config_dict = \ prepare_get_standalone_module_configs( node, modules, prepare_custom_config_dict, qconfig, backend_config_dict) @@ -1301,8 +1304,8 @@ def run_prepare_fx_on_standalone_modules( elif not qhandler.is_standalone_module(): continue - sm_qconfig_dict, sm_prepare_config_dict, sm_backend_config_dict = \ - prepare_get_standalone_module_configs( + sm_qconfig_dict, sm_example_inputs, sm_prepare_config_dict, \ + sm_backend_config_dict = prepare_get_standalone_module_configs( root_node, modules, prepare_custom_config_dict, qconfig, backend_config_dict) standalone_module = modules[root_node.target] @@ -1313,7 +1316,8 @@ def run_prepare_fx_on_standalone_modules( standalone_module, sm_qconfig_dict, is_qat, - sm_prepare_config_dict, + example_inputs=sm_example_inputs, + prepare_custom_config_dict=sm_prepare_config_dict, 
backend_config_dict=sm_backend_config_dict) preserved_attributes = \ set(sm_prepare_config_dict.get("preserved_attributes", [])) @@ -1349,6 +1353,7 @@ def prepare( qconfig_dict: Any, is_qat: bool, node_name_to_scope: Dict[str, Tuple[str, type]], + example_inputs: Tuple[Any, ...], prepare_custom_config_dict: Optional[Dict[str, Any]] = None, equalization_qconfig_dict: Optional[Dict[str, Any]] = None, backend_config_dict: Optional[Dict[str, Any]] = None, diff --git a/torch/ao/quantization/fx/qconfig_utils.py b/torch/ao/quantization/fx/qconfig_utils.py index 4884ef08d0d6..79d3508facf2 100644 --- a/torch/ao/quantization/fx/qconfig_utils.py +++ b/torch/ao/quantization/fx/qconfig_utils.py @@ -319,9 +319,9 @@ def get_standalone_module_configs( custom_config_dict.get("standalone_module_name", []) standalone_module_class_configs = \ custom_config_dict.get("standalone_module_class", []) - class_config_map = {x[0]: (x[1], x[2], x[3]) for x in standalone_module_class_configs} - name_config_map = {x[0]: (x[1], x[2], x[3]) for x in standalone_module_name_configs} - config = class_config_map.get(module_type, (None, None, None)) + class_config_map = {x[0]: (x[1], x[2], x[3], x[4]) for x in standalone_module_class_configs} + name_config_map = {x[0]: (x[1], x[2], x[3], x[4]) for x in standalone_module_name_configs} + config = class_config_map.get(module_type, (None, None, None, None)) # name config has precedence over type config config = name_config_map.get(module_name, config) return config diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index 64de11818bd1..1d16ce817144 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -19,7 +19,6 @@ from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 - def _check_is_graph_module(model: torch.nn.Module) -> None: if not isinstance(model, GraphModule): raise ValueError( @@ -178,6 +177,7 @@ def _prepare_fx( model: torch.nn.Module, qconfig_dict: Any, is_qat: bool, + example_inputs: Tuple[Any, ...], prepare_custom_config_dict: Optional[Dict[str, Any]] = None, equalization_qconfig_dict: Optional[Dict[str, Any]] = None, backend_config_dict: Optional[Dict[str, Any]] = None, @@ -247,6 +247,7 @@ def _prepare_fx( qconfig_dict, is_qat, tracer.node_name_to_scope, + example_inputs=example_inputs, prepare_custom_config_dict=prepare_custom_config_dict, equalization_qconfig_dict=equalization_qconfig_dict, backend_config_dict=backend_config_dict, @@ -262,6 +263,7 @@ def _prepare_standalone_module_fx( model: torch.nn.Module, qconfig_dict: Any, is_qat: bool, + example_inputs: Tuple[Any, ...], prepare_custom_config_dict: Optional[Dict[str, Any]] = None, backend_config_dict: Optional[Dict[str, Any]] = None, ) -> GraphModule: @@ -291,6 +293,7 @@ def _prepare_standalone_module_fx( model, qconfig_dict, is_qat, + example_inputs, prepare_custom_config_dict, backend_config_dict=backend_config_dict, is_standalone_module=True, @@ -340,6 +343,7 @@ def fuse_fx( def prepare_fx( model: torch.nn.Module, qconfig_dict: Any, + example_inputs: Tuple[Any, ...], prepare_custom_config_dict: Optional[Dict[str, Any]] = None, equalization_qconfig_dict: Optional[Dict[str, Any]] = None, backend_config_dict: Optional[Dict[str, Any]] = None, @@ -450,6 +454,7 @@ def prepare_fx( "preserved_attributes": ["preserved_attr"], } + * `example_inputs`: (required) Example inputs for forward function of the model * `equalization_qconfig_dict`: equalization_qconfig_dict is a 
dictionary with a similar structure as qconfig_dict except it will contain configurations specific to equalization techniques such as input-weight @@ -480,7 +485,8 @@ def calibrate(model, data_loader): model(image) qconfig_dict = {"": qconfig} - prepared_model = prepare_fx(float_model, qconfig_dict) + example_inputs = (torch.randn(1, 3, 224, 224),) + prepared_model = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs) # Run calibration calibrate(prepared_model, sample_inference_data) @@ -490,6 +496,7 @@ def calibrate(model, data_loader): model, qconfig_dict, False, # is_qat + example_inputs, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict, @@ -499,6 +506,7 @@ def calibrate(model, data_loader): def prepare_qat_fx( model: torch.nn.Module, qconfig_dict: Any, + example_inputs: Tuple[Any, ...], prepare_custom_config_dict: Optional[Dict[str, Any]] = None, backend_config_dict: Optional[Dict[str, Any]] = None, ) -> ObservedGraphModule: @@ -507,6 +515,7 @@ def prepare_qat_fx( Args: * `model`: torch.nn.Module model, must be in train mode * `qconfig_dict`: see :func:`~torch.ao.quantization.prepare_fx` + * `example_inputs`: see :func:`~torch.ao.quantization.prepare_fx` * `prepare_custom_config_dict`: see :func:`~torch.ao.quantization.prepare_fx` * `backend_config_dict`: see :func:`~torch.ao.quantization.prepare_fx` @@ -538,6 +547,7 @@ def train_loop(model, train_data): model, qconfig_dict, True, # is_qat + example_inputs, prepare_custom_config_dict, backend_config_dict=backend_config_dict, ) diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 42b1ea20c225..7687e28aa915 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -864,6 +864,7 @@ def checkGraphModeFxOp( qconfig_dict = custom_qconfig_dict prepared = prepare( model, qconfig_dict, + example_inputs=inputs, prepare_custom_config_dict=prepare_custom_config_dict, backend_config_dict=backend_config_dict) if not quant_type == QuantType.DYNAMIC:
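For reviewers of the standalone-module changes above: with this patch, every entry under "standalone_module_name" (and "standalone_module_class") in prepare_custom_config_dict is expected to be a 5-element tuple, with that module's own example_inputs in the third slot, matching what the updated get_standalone_module_configs unpacks. Below is a minimal sketch of the expected layout against this revision; the Sub/M modules, the "sub" attribute name, and the input shapes are made up purely for illustration and are not part of the diff.

```python
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx


# Illustrative parent/child modules; names and shapes are hypothetical.
class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x)


m = M().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
example_inputs = (torch.randn(1, 5),)

# Standalone-module entries are now 5-element tuples, read by
# get_standalone_module_configs as:
#   (name_or_class, qconfig_dict, example_inputs,
#    prepare_custom_config_dict, backend_config_dict)
# The example_inputs in the third slot are the inputs of the standalone
# submodule itself, not of the parent model.
prepare_custom_config_dict = {
    "standalone_module_name": [
        ("sub", None, (torch.randn(1, 5),), None, None),
    ],
}

mp = prepare_fx(
    m,
    qconfig_dict,
    example_inputs=example_inputs,
    prepare_custom_config_dict=prepare_custom_config_dict,
)
mp(*example_inputs)  # calibration
mq = convert_fx(mp)
```

When the qconfig_dict, prepare_custom_config_dict, or backend_config_dict slot of such an entry is left as None, the standalone module falls back to the parent's qconfig, an empty prepare config, and the parent's backend config respectively, as handled in prepare_get_standalone_module_configs above.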