diff --git a/test/quantization/bc/test_backward_compatibility.py b/test/quantization/bc/test_backward_compatibility.py
index b89d43c3e3e5..ef8f49b3b2e2 100644
--- a/test/quantization/bc/test_backward_compatibility.py
+++ b/test/quantization/bc/test_backward_compatibility.py
@@ -171,9 +171,10 @@ def _do_quant_transforms(
     m: torch.nn.Module,
     input_tensor: torch.Tensor,
 ) -> torch.nn.Module:
+    example_inputs = (input_tensor,)
     # do the quantizaton transforms and save result
     qconfig = torch.quantization.get_default_qconfig('fbgemm')
-    mp = quantize_fx.prepare_fx(m, {'': qconfig})
+    mp = quantize_fx.prepare_fx(m, {'': qconfig}, example_inputs=example_inputs)
     mp(input_tensor)
     mq = quantize_fx.convert_fx(mp)
     return mq
diff --git a/test/quantization/dbr/test_quantize_dbr.py b/test/quantization/dbr/test_quantize_dbr.py
index cd6dd6968ad0..4d7a2c5aabaa 100644
--- a/test/quantization/dbr/test_quantize_dbr.py
+++ b/test/quantization/dbr/test_quantize_dbr.py
@@ -82,7 +82,7 @@ def _test_auto_tracing(
 
         # compare it against FX
         if do_fx_comparison:
-            m_copy_p = prepare_fx(m_copy, {'': qconfig})
+            m_copy_p = prepare_fx(m_copy, {'': qconfig}, example_inputs=example_args)
             out_m_copy_p = m_copy_p(*example_args)
             # print(m_copy_p)
             m_copy_q = convert_fx(m_copy_p)
@@ -1236,11 +1236,11 @@ def test_qconfig_dict_unsupported_does_not_crash_when_empty(self):
         """
         m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
         qconfig_dict = {'': torch.quantization.default_qconfig}
+        example_inputs = (torch.randn(1, 1, 1, 1),)
         # this modifies qconfig_dict inplace to include more keys
-        mp = prepare_fx(m, qconfig_dict)
-        example_args = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         # need this line to not crash
-        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
+        mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)
 
     def _test_serialization(self, model, input_shape):
         example_inputs = (torch.randn(*input_shape),)
@@ -1324,15 +1324,15 @@ def test_jit_tracing_removes_aliases(self):
             ),
         )
         qconfig_dict = {'': torch.quantization.default_qconfig}
-        example_args = (torch.randn(1, 1, 1, 1),)
-        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)
         mq = _quantize_dbr.convert(mp)
-        mqs = torch.jit.trace(mq, example_args)
+        mqs = torch.jit.trace(mq, example_inputs)
         FileCheck().check_count("aten::alias", 5, exactly=True).run(
             mqs.inlined_graph)
-        res1 = mqs(*example_args)
+        res1 = mqs(*example_inputs)
         mqs = remove_redundant_aliases(mqs)
-        res2 = mqs(*example_args)
+        res2 = mqs(*example_inputs)
         self.assertTrue(torch.allclose(res1, res2))
         # TODO(future PR): figure out why aliasing still appears in the inlined
         # graph, and if that is fixed then just check the inlined graph.
@@ -1609,11 +1609,11 @@ def test_mobilenet_v2_removes_aliases(self):
         m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False)\
             .eval().float()
         qconfig_dict = {'': torch.quantization.default_qconfig}
-        example_args = (torch.randn(1, 3, 224, 224),)
-        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
+        example_inputs = (torch.randn(1, 3, 224, 224),)
+        mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)
         mq = _quantize_dbr.convert(mp)
-        mqs = torch.jit.trace(mq, example_args)
-        res1 = mqs(*example_args)
+        mqs = torch.jit.trace(mq, example_inputs)
+        res1 = mqs(*example_inputs)
         mqs = remove_redundant_aliases(mqs)
-        res2 = mqs(*example_args)
+        res2 = mqs(*example_inputs)
         self.assertTrue(torch.allclose(res1, res2))
diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py
index a552fc0a0f1e..2e211853aa16 100644
--- a/test/quantization/fx/test_equalize_fx.py
+++ b/test/quantization/fx/test_equalize_fx.py
@@ -274,7 +274,14 @@ def test_input_weight_equalization_prepare(self):
 
         for (M, node_occurrence) in tests:
             m = M().eval()
-            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            # TODO[quant-example-inputs]: if shape is important we need to define an example_inputs for each test
+            # for now we do not need shape so this can be fixed later
+            example_inputs = (torch.randn(1, 1, 1, 1),)
+            prepared = prepare_fx(
+                m,
+                specific_qconfig_dict,
+                example_inputs=example_inputs,
+                equalization_qconfig_dict=default_equalization_qconfig_dict)
             self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
 
     def test_input_weight_equalization_branching(self):
@@ -305,7 +312,10 @@ def forward(self, x):
         }
 
         m = TestBranchingWithoutEqualizationModel().eval()
-        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        example_inputs = (torch.randn(1, 5),)
+        prepared = prepare_fx(
+            m, specific_qconfig_dict, example_inputs=example_inputs,
+            equalization_qconfig_dict=default_equalization_qconfig_dict)
         self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)
 
        # Tests that we will add an equalization observer because there is only
@@ -326,7 +336,10 @@ def forward(self, x):
         }
 
         m = TestBranchingWithEqualizationModel().eval()
-        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        example_inputs = (torch.randn(1, 5),)
+        prepared = prepare_fx(
+            m, specific_qconfig_dict, example_inputs=example_inputs,
+            equalization_qconfig_dict=default_equalization_qconfig_dict)
         self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)
 
     @skipIfNoFBGEMM
@@ -353,9 +366,11 @@ def test_input_weight_equalization_convert(self):
             elif ndim == 4:
                 x = torch.rand((16, 3, 224, 224))
 
+            example_inputs = (x,)
             prepared = prepare_fx(
                 copy.deepcopy(m),
                 specific_qconfig_dict,
+                example_inputs=example_inputs,
                 equalization_qconfig_dict=default_equalization_qconfig_dict
             )
             output = prepared(x)
@@ -363,7 +378,10 @@ def test_input_weight_equalization_convert(self):
             convert_ref = _convert_equalization_ref(prepared)
             convert_ref_output = convert_ref(x)
 
-            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared = prepare_fx(
+                m, specific_qconfig_dict,
+                example_inputs=example_inputs,
+                equalization_qconfig_dict=default_equalization_qconfig_dict)
             prepared(x)
             convert_fx(prepared)  # Check if compile
             self.assertEqual(output, convert_ref_output)
@@ -411,8 +429,12 @@ def test_input_weight_equalization_equalization_scales(self):
             m = M().eval()
             exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
 
-            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
-            prepared(x)
+            example_inputs = (x,)
+            prepared = prepare_fx(
+                m, specific_qconfig_dict,
+                example_inputs=example_inputs,
+                equalization_qconfig_dict=default_equalization_qconfig_dict)
+            prepared(*example_inputs)
             convert_ref = _convert_equalization_ref(prepared)
             convert_ref(x)
@@ -460,7 +482,11 @@ def test_input_weight_equalization_weights_bias(self):
             exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
             exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)
 
-            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            example_inputs = (x,)
+            prepared = prepare_fx(
+                m, specific_qconfig_dict,
+                example_inputs=example_inputs,
+                equalization_qconfig_dict=default_equalization_qconfig_dict)
             prepared(x)
             convert_ref = _convert_equalization_ref(prepared)
             convert_ref(x)
@@ -516,7 +542,11 @@ def test_input_weight_equalization_activation_values(self):
             exp_inp_act_vals = self.get_expected_inp_act_vals(m, x, exp_eq_scales, exp_weights, exp_bias)
             exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights)
 
-            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            example_inputs = (x,)
+            prepared = prepare_fx(
+                m, specific_qconfig_dict,
+                example_inputs=example_inputs,
+                equalization_qconfig_dict=default_equalization_qconfig_dict)
             prepared(x)
             convert_ref = _convert_equalization_ref(prepared)
             convert_ref(x)
@@ -751,7 +781,13 @@ def test_input_weight_equalization_graphs(self):
 
         for (M, node_list) in tests:
             m = M().eval()
-            prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+            # TODO[quant-example-inputs]: if shape is important we need to define an example_inputs for each test
+            # for now we do not need shape so this can be fixed later
+            example_inputs = (torch.randn(1, 1, 1, 1),)
+            prepared = prepare_fx(
+                m, specific_qconfig_dict,
+                example_inputs=example_inputs,
+                equalization_qconfig_dict=default_equalization_qconfig_dict)
             equalized_quantized_model = convert_fx(prepared)
 
             # Check the order of nodes in the graph
@@ -771,7 +807,12 @@ def test_input_weight_equalization_results(self):
 
             m = M().eval()
 
             # No equalization
-            prepared = prepare_fx(copy.deepcopy(m), specific_qconfig_dict, equalization_qconfig_dict={})
+            example_inputs = (x,)
+            prepared = prepare_fx(
+                copy.deepcopy(m),
+                specific_qconfig_dict,
+                example_inputs=example_inputs,
+                equalization_qconfig_dict={})
             prepared(x)
             quantized = convert_fx(prepared)  # Check if compile
             quantized_output = quantized(x)
@@ -780,6 +821,7 @@ def test_input_weight_equalization_results(self):
             prepared = prepare_fx(
                 copy.deepcopy(m),
                 specific_qconfig_dict,
+                example_inputs=example_inputs,
                 equalization_qconfig_dict=default_equalization_qconfig_dict
             )
             prepared(x)
@@ -817,7 +859,12 @@ def forward(self, x):
             [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])
 
         # Quantize the float model
-        prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
+        example_inputs = (x,)
+        prepared_model = prepare_fx(
+            copy.deepcopy(float_model),
+            specific_qconfig_dict,
+            example_inputs=example_inputs
+        )
         prepared_model(x)
         quantized_model = convert_fx(copy.deepcopy(prepared_model))
 
@@ -832,6 +879,7 @@ def forward(self, x):
         prepared_model = prepare_fx(
             copy.deepcopy(float_model),
             specific_qconfig_dict,
+            example_inputs=example_inputs,
             equalization_qconfig_dict=selective_equalization_qconfig_dict,
         )
         prepared_model(x)
diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py
index 4559c6389be6..a83c5cb5eefe 100644
--- a/test/quantization/fx/test_numeric_suite_fx.py
+++ b/test/quantization/fx/test_numeric_suite_fx.py
@@ -295,7 +295,7 @@ class TestFXGraphMatcher(QuantizationTestCase):
 
     @skipIfNoFBGEMM
     def test_simple_mod(self):
         m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
-        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         results = get_matching_subgraph_pairs(mp, mq)
@@ -322,7 +322,7 @@ def forward(self, x):
             return F.linear(x, self.w, self.b)
 
         m = M().eval()
-        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         results = get_matching_subgraph_pairs(mp, mq)
@@ -340,7 +340,7 @@ def forward(self, x):
     @skipIfNoFBGEMM
     def test_simple_fusion(self):
         m = LinearReluFunctional().eval()
-        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(4, 4),))
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         results = get_matching_subgraph_pairs(mp, mq)
@@ -363,7 +363,7 @@ def test_simple_mod_multi(self):
             ),
             nn.Conv2d(1, 1, 1),
         ).eval()
-        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         # assume success if no exceptions
@@ -380,7 +380,8 @@ def forward(self, x, y):
             return z
 
         m = M().eval()
-        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+        example_inputs = (torch.randn(1), torch.randn(1))
+        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         # assume success if no exceptions
@@ -392,8 +393,9 @@ def test_matching_failure_node_count(self):
         # different counts of matchable nodes fails
         m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
         m2 = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
-        mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig})
-        mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig})
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
+        mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
         with self.assertRaises(GraphMatchingException) as ex:
             results = get_matching_subgraph_pairs(mp1, mp2)
@@ -402,8 +404,10 @@ def test_matching_failure_node_type(self):
         # verify that matching graphs with non-matching node types fails
         m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
         m2 = nn.Sequential(nn.Linear(1, 1)).eval()
-        mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig})
-        mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig})
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
+        example_inputs = (torch.randn(1, 1),)
+        mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
         with self.assertRaises(GraphMatchingException) as ex:
             results = get_matching_subgraph_pairs(mp1, mp2)
@@ -421,7 +425,8 @@ def forward(self, x0):
             return x2
 
         m = M().eval()
-        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+        example_inputs = (torch.randn(1),)
+        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         results = get_matching_subgraph_pairs(mp, mq)
@@ -456,7 +461,8 @@ def forward(self, x0):
             return a1
 
         m = M().eval()
-        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+        example_inputs = (torch.randn(1),)
+        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         results = get_matching_subgraph_pairs(mp, mq)
@@ -499,7 +505,8 @@ def forward(self, x):
             '': torch.ao.quantization.default_qconfig,
             'module_name': [('conv2', None)],
         }
-        mp = prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         results = get_matching_subgraph_pairs(mp, mq)
@@ -541,8 +548,9 @@ def forward(self, x):
         m1 = M().eval()
         m2 = M().eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        m1p = prepare_fx(m1, qconfig_dict)
-        m2p = prepare_fx(m2, qconfig_dict)
+        example_inputs = (torch.randn(1),)
+        m1p = prepare_fx(m1, qconfig_dict, example_inputs=example_inputs)
+        m2p = prepare_fx(m2, qconfig_dict, example_inputs=example_inputs)
         results = get_matching_subgraph_pairs(m1p, m2p)
         base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
         sigmoid_name_0 = 'base_op_' + get_base_name_for_op(
@@ -733,8 +741,9 @@ def forward(self, x):
             return x
 
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        m1 = prepare_fx(M1().eval(), qconfig_dict)
-        m2 = prepare_fx(M2().eval(), qconfig_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        m1 = prepare_fx(M1().eval(), qconfig_dict, example_inputs=example_inputs)
+        m2 = prepare_fx(M2().eval(), qconfig_dict, example_inputs=example_inputs)
         base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
         add_op_to_sets_of_related_ops(
@@ -760,7 +769,8 @@ def test_results_order(self):
             nn.Conv2d(1, 1, 1),
             nn.Linear(1, 1),
         ).eval()
-        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
         results = get_matching_subgraph_pairs(mp, mq)
@@ -782,7 +792,8 @@ def test_mobilenet_v2(self):
         # verify that mobilenetv2 graph is able to be matched
         import torchvision
         m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).eval().float()
-        mp = prepare_fx(copy.deepcopy(m), {'': torch.ao.quantization.default_qconfig})
+        example_inputs = (torch.randn(1, 3, 224, 224),)
+        mp = prepare_fx(copy.deepcopy(m), {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
         # assume success if no exceptions
         results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp)
         mp_copy = copy.deepcopy(mp)
@@ -796,9 +807,11 @@ def test_mobilenet_v2_qat(self):
         # verify that mobilenetv2 graph is able to be matched
         import torchvision
         m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float()
+        example_inputs = (torch.randn(1, 3, 224, 224),)
         mp = prepare_qat_fx(
             copy.deepcopy(m),
-            {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')})
+            {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')},
+            example_inputs=example_inputs)
         # assume success if no exceptions
         results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp)
         mp_copy = copy.deepcopy(mp)
@@ -809,12 +822,12 @@ class FXNumericSuiteQuantizationTestCase(QuantizationTestCase):
 
     def _test_extract_weights(
-        self, m, results_len=0, qconfig_dict=None, prepare_fn=prepare_fx
+        self, m, example_inputs, results_len=0, qconfig_dict=None, prepare_fn=prepare_fx
     ):
         m = torch.fx.symbolic_trace(m)
         if qconfig_dict is None:
             qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = prepare_fn(copy.deepcopy(m), qconfig_dict)
+        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs)
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
@@ -848,7 +861,7 @@ def _test_match_activations(
             m.eval()
         else:
             m.train()
-        mp = prepare_fn(copy.deepcopy(m), qconfig_dict)
+        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data)
         mp(*data)
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
@@ -911,7 +924,7 @@ def _test_match_shadow_activations(
             m.eval()
         else:
             m.train()
-        mp = prepare_fn(copy.deepcopy(m), qconfig_dict)
+        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data)
         mp(*data)
         mp_copy = copy.deepcopy(mp)
         mq = convert_fx(mp_copy)
@@ -969,26 +982,30 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
     @skipIfNoFBGEMM
     def test_extract_weights_mod_ptq(self):
         m = AllConvAndLinearFusionModules().eval()
-        self._test_extract_weights(m, results_len=14)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        self._test_extract_weights(m, example_inputs, results_len=14)
 
     @skipIfNoFBGEMM
     def test_extract_weights_mod_qat(self):
         m = AllConvAndLinearFusionModules().train()
         qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
+        example_inputs = (torch.randn(1, 1, 1, 1),)
         self._test_extract_weights(
-            m, results_len=14, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
+            m, example_inputs, results_len=14, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
 
     @skipIfNoFBGEMM
     def test_extract_weights_linear_fun_ptq(self):
         m = LinearReluLinearFunctional().eval()
-        self._test_extract_weights(m, results_len=2)
+        example_inputs = (torch.randn(1, 4),)
+        self._test_extract_weights(m, example_inputs, results_len=2)
 
     @skipIfNoFBGEMM
     def test_extract_weights_linear_fun_qat(self):
         m = LinearReluLinearFunctional().train()
         qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
+        example_inputs = (torch.randn(1, 4),)
         self._test_extract_weights(
-            m, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
+            m, example_inputs, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
 
     @skipIfNoFBGEMM
     def test_extract_weights_conv_fun_ptq(self):
@@ -999,7 +1016,8 @@ def test_extract_weights_conv_fun_ptq(self):
         b2d = torch.randn(1)
         b3d = torch.randn(1)
         m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).eval()
-        self._test_extract_weights(m, results_len=6)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        self._test_extract_weights(m, example_inputs, results_len=6)
 
     @skipIfNoFBGEMM
     def test_extract_weights_conv_fun_qat(self):
@@ -1011,8 +1029,9 @@ def test_extract_weights_conv_fun_qat(self):
         b3d = torch.randn(1)
         m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).train()
         qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
+        example_inputs = (torch.randn(1, 1, 1, 1),)
         self._test_extract_weights(
-            m, results_len=6, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
+            m, example_inputs, results_len=6, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
 
     @skipIfNoFBGEMM
     def test_extract_weights_dynamic(self):
@@ -1023,7 +1042,8 @@ def test_extract_weights_dynamic(self):
                 (nn.Linear, default_dynamic_qconfig),
             ],
         }
-        self._test_extract_weights(m, results_len=1, qconfig_dict=qconfig_dict)
+        example_inputs = (torch.randn(1, 1),)
+        self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)
 
     @skipIfNoFBGEMM
     def test_extract_weights_fqn(self):
@@ -1032,7 +1052,8 @@ def test_extract_weights_fqn(self):
             nn.Conv2d(1, 1, 1),
         ).eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         mq = convert_fx(copy.deepcopy(mp))
         results = extract_weights('a', mp, 'b', mq)
         fqn_a_0 = results['_0_0']['weight']['a'][0]['fqn']
@@ -1110,7 +1131,8 @@ def test_match_activations_fqn(self):
             nn.Conv2d(1, 1, 1),
         ).eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         mq = convert_fx(copy.deepcopy(mp))
         mp_ns, mq_ns = add_loggers('a', mp, 'b', mq, OutputLogger)
         datum = torch.randn(1, 1, 1, 1)
@@ -1187,7 +1209,8 @@ def test_shadow_activations_fqn(self):
             nn.Conv2d(1, 1, 1),
         ).eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         mq = convert_fx(copy.deepcopy(mp))
         mp_shadows_mq = add_shadow_loggers('a', mp, 'b', mq, OutputLogger)
         datum = torch.randn(1, 1, 1, 1)
@@ -1254,7 +1277,8 @@ def test_add_mul_inputs_activations(self):
     def test_linear_fp16_weights(self):
         qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
         m = LinearReluFunctional().eval()
-        self._test_extract_weights(m, results_len=1, qconfig_dict=qconfig_dict)
+        example_inputs = (torch.randn(1, 4),)
+        self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)
 
     @skipIfNoFBGEMM
     def test_linear_fp16_activations(self):
@@ -1292,7 +1316,8 @@ def test_linear_fp16_shadow_activations(self):
     def test_linear_fp16_vs_linear_fp16_shadow_activations(self):
         m = LinearFunctional().eval()
         qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
-        mp = prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(1, 4),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         mq1 = convert_fx(copy.deepcopy(mp))
         mq2 = convert_fx(copy.deepcopy(mp))
         mq1_shadows_mq2 = _add_shadow_loggers_impl(
@@ -1332,8 +1357,9 @@ def _test_int8_shadows_int8_impl(self, m):
         Verify that shadowing works where both modules are int8
         """
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = prepare_fx(m, qconfig_dict)
-        mp(torch.randn(4, 1, 4, 4))
+        example_inputs = (torch.randn(4, 1, 4, 4),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        mp(*example_inputs)
         mq1 = convert_fx(copy.deepcopy(mp))
         mq2 = convert_fx(mp)
         mq1_shadows_mq2 = add_shadow_loggers('a', mq1, 'b', mq2, OutputLogger)
@@ -1377,7 +1403,12 @@ def forward(self, x):
         prepare_custom_config_dict = {
             'non_traceable_module_class': [M1],
         }
-        mp1 = prepare_fx(m, qconfig_dict, prepare_custom_config_dict)
+        example_inputs = (torch.randn(1),)
+        mp1 = prepare_fx(
+            m,
+            qconfig_dict,
+            example_inputs=example_inputs,
+            prepare_custom_config_dict=prepare_custom_config_dict)
         mp2 = copy.deepcopy(mp1)
         unmatchable_types_map = get_unmatchable_types_map()
         unmatchable_types_map['mods_unmatchable'].add(M1)
@@ -1419,8 +1450,13 @@ def forward(self, x):
         # quantize without tracing through UserModule
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
         prepare_custom_config_dict = {'non_traceable_module_name': ['user_module']}
-        mp = prepare_fx(m, qconfig_dict, prepare_custom_config_dict)
-        mp(torch.randn(1, 1, 1))
+        example_inputs = (torch.randn(1, 1, 1),)
+        mp = prepare_fx(
+            m,
+            qconfig_dict,
+            example_inputs=example_inputs,
+            prepare_custom_config_dict=prepare_custom_config_dict)
+        mp(*example_inputs)
         mq = convert_fx(copy.deepcopy(mp))
 
         # weight extraction should not crash
@@ -1661,8 +1697,9 @@ def forward(self, x):
             return x
 
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        m1 = prepare_fx(M1().eval(), qconfig_dict)
-        m2 = prepare_fx(M2().eval(), qconfig_dict)
+        example_inputs = (torch.randn(1, 1),)
+        m1 = prepare_fx(M1().eval(), qconfig_dict, example_inputs=example_inputs)
+        m2 = prepare_fx(M2().eval(), qconfig_dict, example_inputs=example_inputs)
         data = torch.randn(1, 1)
 
         base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
@@ -1731,7 +1768,8 @@ def test_layer_names(self):
             nn.Sigmoid(),
         ).eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
 
         # extract weights
@@ -1767,7 +1805,9 @@ def test_layer_names(self):
     def test_extend_logger_results_with_comparison(self):
         m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = torch.ao.quantization.quantize_fx.prepare_fx(
+            m, qconfig_dict, example_inputs=example_inputs)
         mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
 
         # extract weights
@@ -1792,7 +1832,9 @@ def test_extend_logger_results_with_comparison(self):
     def test_int8_shadows_fp32_simple(self):
         m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1), nn.ReLU()).eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = torch.ao.quantization.quantize_fx.prepare_fx(
+            m, qconfig_dict, example_inputs=example_inputs)
         mp(torch.randn(1, 1, 1, 1))
         mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
         mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
@@ -1846,8 +1888,9 @@ def forward(self, x):
 
         m = M().eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict)
-        mp(torch.randn(1, 1, 1, 1))
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        mp(*example_inputs)
         mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
         mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp))
         mp_shadows_mq = add_shadow_loggers(
@@ -1862,18 +1905,18 @@ def forward(self, x):
     def test_loggers_preserve_qat_numerics(self):
         m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1))
         qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
-        mp = prepare_qat_fx(m, qconfig_dict)
-        mp(torch.randn(1, 1, 1, 1))
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs)
+        mp(*example_inputs)
         mc = convert_fx(copy.deepcopy(mp))
         mp.apply(torch.ao.quantization.disable_observer)
 
-        datum = torch.randn(1, 1, 1, 1)
-        ref_fp32 = mp(datum)
-        ref_int8 = mc(datum)
+        ref_fp32 = mp(*example_inputs)
+        ref_int8 = mc(*example_inputs)
 
         mp_ns, mc_ns = add_loggers('fp32', mp, 'int8', mc, OutputLogger)
-        ref_fp32_ns = mp_ns(datum)
-        ref_int8_ns = mc_ns(datum)
+        ref_fp32_ns = mp_ns(*example_inputs)
+        ref_int8_ns = mc_ns(*example_inputs)
         self.assertEqual(ref_fp32, ref_fp32_ns)
         self.assertEqual(ref_int8, ref_int8_ns)
@@ -1881,17 +1924,17 @@ def test_loggers_preserve_qat_numerics(self):
     def test_shadow_loggers_preserve_qat_numerics(self):
         m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1))
         qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
-        mp = prepare_qat_fx(m, qconfig_dict)
-        mp(torch.randn(1, 1, 1, 1))
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs)
+        mp(*example_inputs)
         mc = convert_fx(copy.deepcopy(mp))
         mp.apply(torch.ao.quantization.disable_observer)
 
-        datum = torch.randn(1, 1, 1, 1)
-        ref_fp32 = mp(datum)
-        ref_int8 = mc(datum)
+        ref_fp32 = mp(*example_inputs)
+        ref_int8 = mc(*example_inputs)
 
         mc_shadows_mp = add_shadow_loggers('int8', mc, 'fp32', mp, OutputLogger)
-        ref_shadow = mc_shadows_mp(datum)
+        ref_shadow = mc_shadows_mp(*example_inputs)
         self.assertEqual(ref_fp32, ref_shadow)
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@@ -1940,8 +1983,9 @@ def test_add_shadow_loggers_cuda(self):
 
     def test_fp16_shadows_fp32(self):
         m = LinearReluFunctional().eval()
+        example_inputs = (torch.randn(1, 4),)
         qconfig_dict = {"": torch.ao.quantization.float16_static_qconfig}
-        mp = prepare_fx(copy.deepcopy(m), qconfig_dict)
+        mp = prepare_fx(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs)
         mq = convert_fx(mp, is_reference=True)
         mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger)
@@ -2005,7 +2049,8 @@ def test_compare_weights_conv(self):
         )
         for m, in test_cases:
             m.eval()
-            self._test_extract_weights(m, results_len=1)
+            example_inputs = (torch.randn(1, 3, 5, 5),)
+            self._test_extract_weights(m, example_inputs, results_len=1)
 
     @skipIfNoFBGEMM
     def test_compare_weights_linear(self):
@@ -2018,15 +2063,19 @@ def test_compare_weights_linear(self):
         )
         for m, qconfig_dict in test_cases:
             m.eval()
+            example_inputs = (torch.randn(1, 3, 5, 5),)
             res = self._test_extract_weights(
-                m, results_len=1, qconfig_dict=qconfig_dict)
+                m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)
 
     @skipIfNoFBGEMM
     def test_compare_weights_lstm_dynamic(self):
         qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
+        lstm_input = torch.rand((1, 1, 2))
+        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
+        example_inputs = (lstm_input, lstm_hidden)
         m = LSTMwithHiddenDynamicModel().eval()
         res = self._test_extract_weights(
-            m, results_len=1, qconfig_dict=qconfig_dict)
+            m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)
 
     @skipIfNoFBGEMM
     def test_compare_activations_conv(self):
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index 27a83c5e7874..a83808119b67 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -248,7 +248,7 @@ def forward(self, x):
         # TODO: if we decide to do that in the future, this test needs to
         # be updated
         # train mode fuse_fx is called in prepare_qat_fx
-        m = prepare_qat_fx(m, {})
+        m = prepare_qat_fx(m, {}, example_inputs=(torch.randn(1, 1, 1, 1),))
         expected_nodes = [
             ns.call_module(nni.ConvBn1d),
             ns.call_module(nni.ConvBn2d),
@@ -401,8 +401,10 @@ def test_qconfig_fused_module(self):
 
         for M, node_list in tests:
             m = M().eval()
-            prepared = prepare_fx(m, qconfig_dict)
-            prepared(torch.rand(5, 5))
+            example_inputs = (torch.rand(5, 5),)
+            prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+
+            prepared(*example_inputs)
             quantized = convert_fx(prepared)
             self.checkGraphModuleNodes(quantized, expected_node_list=node_list)
@@ -435,7 +437,7 @@ def forward(self, x):
                 (torch.nn.ReLU, get_default_qconfig('fbgemm')),
             ],
         }
-        m = prepare_fx(model, qconfig_dict)
+        m = prepare_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),))
         self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.nn.intrinsic.modules.fused.LinearReLU))
@@ -732,7 +734,7 @@ def forward(self, x):
                 (torch.nn.ReLU, default_qat_qconfig),
             ],
         }
-        prepared = prepare_qat_fx(model, qconfig_dict)
+        prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),))
         self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.nn.intrinsic.qat.LinearReLU))
 
     def _get_conv_linear_test_cases(self, is_reference):
@@ -1020,7 +1022,8 @@ def forward(self, x):
         m = M(torch.rand(1, 1)).eval()
         qconfig = default_dynamic_qconfig
         qconfig_dict = {'': qconfig}
-        prepared = prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.rand(1, 1),)
+        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         quantized = convert_fx(prepared, is_reference=True)
         qparams = (quantized._scale_0, quantized._zero_point_0)
         weight_obs = qconfig.weight()
@@ -1199,7 +1202,7 @@ def forward(self, x):
             node_occurrence[weight_prepack_node] = 0
         m = ModuleClass(*module_constructor_inputs).eval()
         qconfig_dict = {"": float16_dynamic_qconfig}
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, example_inputs=inputs)
         m = convert_fx(m, is_reference=is_reference)
         self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
@@ -1232,12 +1235,12 @@ def forward(self, x):
         device = torch.device('cuda:0')
         model.to(device)
 
+        example_inputs = (torch.randn(4, 1, 4, 4, device=device),)
         # QAT prepare
-        model = prepare_qat_fx(model, qconfig_dict)
+        model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
 
         # ensure that running an input on CUDA works without any needed changes
-        input = torch.randn(4, 1, 4, 4, device=device)
-        model(input)
+        model(*example_inputs)
 
         # ensure all buffers and parameters are on the device we expect
         model_devices = {p.device for p in model.parameters()} | \
@@ -1258,13 +1261,13 @@ def __init__(self):
         def forward(self, x):
             return {"output": self.conv(x["input"])}
 
-        dict_input = {"input": torch.randn(1, 1, 1, 1)}
+        example_inputs = ({"input": torch.randn(1, 1, 1, 1)},)
         m = M().eval()
         qconfig_dict = {"": default_qconfig}
-        m = prepare_fx(m, qconfig_dict)
-        m(dict_input)
+        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(dict_input)
+        m(*example_inputs)
 
     @override_qengines
     def test_attention(self):
@@ -1287,7 +1290,7 @@ def forward(self, x):
             r = torch.mm(k, v)
             return q * k + r
 
-        tensor_input = torch.randn(3, 1, 1, 1)
+        example_inputs = (torch.randn(3, 1, 1, 1),)
         m = M().eval()
         qconfig_dict = {
             "": None,
@@ -1296,10 +1299,10 @@ def forward(self, x):
             ]
         }
         # make sure it runs
-        m = prepare_fx(m, qconfig_dict)
-        m(tensor_input)
+        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(tensor_input)
+        m(*example_inputs)
 
     def _test_standalone_module(
@@ -1341,7 +1344,7 @@ def forward(self, x):
             x = self.conv2(x)
             return x
 
-        data = torch.randn(1, 1, 1, 1)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
         # instantiate M and RefM and align the parameters
         original_m = M().eval()
         original_ref_m = RefM().eval()
@@ -1351,13 +1354,14 @@ def forward(self, x):
         original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())
 
         for is_name in [True, False]:
+            sm_example_inputs = example_inputs
             if is_name:
                 prepare_config = {
-                    "standalone_module_name": [("standalone", None, interface_config, None)]
+                    "standalone_module_name": [("standalone", None, sm_example_inputs, interface_config, None)]
                 }
             else:
                 prepare_config = {
-                    "standalone_module_class": [(StandaloneModule, None, interface_config, None)]
+                    "standalone_module_class": [(StandaloneModule, None, sm_example_inputs, interface_config, None)]
                 }
 
             original_m_copy = copy.deepcopy(original_m)
@@ -1366,9 +1370,12 @@ def forward(self, x):
             qconfig_dict = {"": default_qconfig}
             # check prepared model
             m = prepare_fx(
-                original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config)
+                original_m_copy,
+                qconfig_dict,
+                example_inputs=example_inputs,
+                prepare_custom_config_dict=prepare_config)
             # calibration
-            m(data)
+            m(*example_inputs)
             self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
             self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
@@ -1376,13 +1383,17 @@ def forward(self, x):
             m = convert_fx(m)
             self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
             self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
-            res = m(data)
+            res = m(*example_inputs)
 
             # quantize the reference model
-            ref_m = prepare_fx(original_ref_m_copy, qconfig_dict)
-            ref_m(data)
+            ref_m = prepare_fx(
+                original_ref_m_copy,
+                qconfig_dict,
+                example_inputs=example_inputs,
+            )
+            ref_m(*example_inputs)
             ref_m = convert_fx(ref_m)
-            ref_res = ref_m(data)
+            ref_res = ref_m(*example_inputs)
             self.assertEqual(res, ref_res)
 
     def test_standalone_module_float_interface(self):
@@ -1471,11 +1482,11 @@ def forward(self, x):
 
         m = M().eval()
         qconfig_dict = {"": default_qconfig, "module_name": [("conv2", None)]}
-        m = prepare_fx(m, qconfig_dict)
-        data = torch.randn(1, 1, 1, 1)
-        m(data)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(data)
+        m(*example_inputs)
         # first conv is quantized, second conv is not quantized
         node_list = [
             ns.call_function(torch.quantize_per_tensor),
@@ -1499,11 +1510,11 @@ def forward(self, x):
 
         m = M().eval()
         qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
-        m = prepare_fx(m, qconfig_dict)
-        data = torch.randn(1, 1, 1, 1)
-        m(data)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(data)
+        m(*example_inputs)
         # first conv is quantized, second conv is not quantized
         node_list = [
             ns.call_function(torch.quantize_per_tensor),
@@ -1541,10 +1552,11 @@ def forward(self, x):
                 (torch.nn.ReLU, default_qat_qconfig),
             ],
         }
-        m = prepare_qat_fx(model, qconfig_dict)
-        m(torch.rand(5, 5))
+        example_inputs = (torch.rand(5, 5),)
+        m = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(torch.rand(5, 5))
+        m(*example_inputs)
         node_list = [
             ns.call_function(torch.quantize_per_tensor),
             ns.call_module(nniq.LinearReLU),
@@ -1563,11 +1575,12 @@ def forward(self, x, y):
 
         m = M().eval()
         qconfig_dict = {"object_type": [(operator.add, default_qconfig)]}
-        m = prepare_fx(m, qconfig_dict)
         data = torch.randn(1, 1, 1, 1)
-        m(data, data)
+        example_inputs = (data, data)
+        m = prepare_fx(m, qconfig_dict, example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(data, data)
+        m(*example_inputs)
         # first conv is quantized, second conv is not quantized
         node_list = [
             ns.call_function(torch.quantize_per_tensor),
@@ -1590,11 +1603,11 @@ def forward(self, x):
 
         m = M().eval()
         qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
-        m = prepare_fx(m, qconfig_dict)
-        data = torch.randn(1, 1, 1, 1)
-        m(data)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(data)
+        m(*example_inputs)
         # first conv is quantized, second conv is not quantized
         node_list = [
             ns.call_function(torch.quantize_per_tensor),
@@ -1636,7 +1649,7 @@ def forward(self, x):
             "object_type": [(nn.Conv2d, object_type_qconfig)],
             "module_name_regex": [("module_conv*", module_name_regex_qconfig)],
             "module_name": [("module_conv2", module_name_qconfig)]}
-        m_prep = prepare_fx(m, qconfig_dict)
+        m_prep = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1),))
         self.assertEqual(m_prep.linear.qconfig.activation.p.func, global_qconfig.activation.p.func)
         self.assertEqual(m_prep.linear.qconfig.weight.p.func, global_qconfig.weight.p.func)
         self.assertEqual(m_prep.conv.qconfig.activation.p.func, object_type_qconfig.activation.p.func)
@@ -1702,11 +1715,11 @@ def forward(self, x):
                 ("m2.m1", torch.add, 0, torch.ao.quantization.default_qconfig),
             ],
         }
-        m = prepare_fx(m, qconfig_dict)
-        data = torch.randn(1, 1, 1, 1)
-        m(data)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(data)
+        m(*example_inputs)
 
         node_list = [
             # m3
@@ -1759,11 +1772,11 @@ def forward(self, x):
                 ("", torch.add, 1, None),
             ],
         }
-        m = prepare_fx(m, qconfig_dict)
-        data = torch.randn(1, 1, 1, 1)
-        m(data)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(data)
+        m(*example_inputs)
 
         node_list = [
             ns.call_function(torch.quantize_per_tensor),
@@ -1819,7 +1832,7 @@ def forward(self, x):
         m = model(relu).eval()
         qconfig_dict = torch.ao.quantization.get_default_qconfig_dict("fbgemm")
         # should not crash as in https://github.com/pytorch/pytorch/issues/75825
-        prepare_fx(m, qconfig_dict)
+        prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
 
     def test_qconfig_dict_validity(self):
         r"""
@@ -1831,7 +1844,7 @@ def test_qconfig_dict_validity(self):
         qconfig_dict = {"object_typo": [(torch.nn.Conv2d, default_qconfig)]}
 
         with self.assertRaises(ValueError) as context:
-            m = prepare_fx(m, qconfig_dict)
+            m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
         self.assertTrue(
             'Expected qconfig_dict to have the following keys:' in str(context.exception)
         )
@@ -1848,7 +1861,11 @@ def test_prepare_custom_config_dict_validity(self):
         prepare_custom_config_dict = {"typo": None}
 
         with self.assertRaises(ValueError) as context:
-            m = prepare_fx(m, qconfig_dict, prepare_custom_config_dict)
+            m = prepare_fx(
+                m,
+                qconfig_dict,
+                example_inputs=(torch.randn(1, 3, 3, 3),),
+                prepare_custom_config_dict=prepare_custom_config_dict)
         self.assertTrue(
             'Expected prepare_custom_config_dict to have the following keys:'
             in str(context.exception)
@@ -1863,7 +1880,7 @@ def test_convert_custom_config_dict_validity(self):
         """
         m = ConvModel().eval()
         qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
         convert_custom_config_dict = {"typo": None}
 
         with self.assertRaises(ValueError) as context:
@@ -1885,11 +1902,11 @@ def forward(self, x):
 
         m = M().eval()
         qconfig_dict = {'': default_qconfig}
-        m = prepare_fx(m, qconfig_dict)
-        data = torch.randn(1, 1, 1, 1)
-        m(data)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
-        m(data)
+        m(*example_inputs)
         for name, module in m.named_modules():
             self.assertFalse(hasattr(module, 'qconfig'),
                              'qconfig is not removed for ' + name)
@@ -1901,7 +1918,7 @@ def forward(self, x):
 
         m = M().eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),))
         m = convert_fx(m)
 
     def test_default_quant_after_none_qconfig(self):
@@ -1924,7 +1941,7 @@ def forward(self, x):
                 ("conv1", None)
             ]
         }
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
         m = convert_fx(m)
 
     def test_qconfig_for_call_method(self):
@@ -1988,13 +2005,14 @@ def forward(self, x):
                 (qconfig_dict1, node_list1),
                 (qconfig_dict2, node_list2)
         ]:
+            example_inputs = (torch.randn(2, 1, 3, 3),)
             m = M().eval()
-            m = prepare_fx(m, qconfig_dict)
+            m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
             m(torch.randn(2, 1, 3, 3))
             m = convert_fx(m)
             self.checkGraphModuleNodes(m, expected_node_list=node_list)
             # make sure it runs
-            m(torch.randn(2, 1, 3, 3))
+            m(*example_inputs)
 
     def test_qconfig_for_call_func(self):
         class Linear(torch.nn.Module):
             def __init__(self):
@@ -2021,9 +2039,10 @@ def forward(self, x):
                 return x
 
         model = M().eval()
+        example_inputs = (torch.rand(5, 5),)
         qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]}
-        m = prepare_fx(model, qconfig_dict)
-        m(torch.rand(5, 5))
+        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
 
         node_list = [
@@ -2051,7 +2070,12 @@ def forward(self, x):
         prepare_custom_config_dict = {
             "preserved_attributes": ["preserved_attr"]
         }
-        m = prepare_fx(m, {"": default_qconfig}, prepare_custom_config_dict)
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        m = prepare_fx(
+            m,
+            {"": default_qconfig},
+            example_inputs=example_inputs,
+            prepare_custom_config_dict=prepare_custom_config_dict)
 
         def assertAttrPreserved(m):
             self.assertTrue(hasattr(m, "preserved_attr"))
@@ -2069,12 +2093,13 @@ def test_qat_and_script(self):
         model = LinearModelWithSubmodule().train()
         qengine = torch.backends.quantized.engine
         qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)}
-        model = prepare_qat_fx(model, qconfig_dict)
+        x = torch.randn(5, 5)
+        example_inputs = (x,)
+        model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
 
         # ensure scripting works
         scripted = torch.jit.script(model)
         # run one round to make sure model runs
-        x = torch.randn(5, 5)
         scripted(x)
         FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \
                    .run(scripted.graph)
@@ -2104,10 +2129,10 @@ def test_save_observer_state_dict(self):
         orig = LinearModelWithSubmodule().eval()
         model = orig
         qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
-        model = prepare_fx(model, qconfig_dict)
+        x = torch.randn(5, 5)
+        model = prepare_fx(model, qconfig_dict, example_inputs=(x,))
 
         # run it through input
-        x = torch.randn(5, 5)
         model(x)
 
         quant = convert_fx(model)
@@ -2120,7 +2145,7 @@ def test_save_observer_state_dict(self):
 
         # Load the stats into new model
         model_2 = orig
-        model_2 = prepare_fx(model_2, qconfig_dict)
+        model_2 = prepare_fx(model_2, qconfig_dict, example_inputs=(x,))
 
         loaded_dict = torch.load(b)
         torch.ao.quantization.load_observer_state_dict(model_2, loaded_dict)
@@ -2208,7 +2233,6 @@ def forward(self, x):
                 x = self.linear2(x)
                 return x
 
-        data = torch.randn(3, 3)
         # instantiate M and RefM and align the parameters
         original_m = M().eval()
         original_ref_m = RefM().eval()
@@ -2255,13 +2279,15 @@ def forward(self, x):
             }
         }
 
+        example_inputs = (torch.randn(3, 3),)
         # check prepared model
         m = prepare_fx(
             original_m,
             qconfig_dict,
+            example_inputs=example_inputs,
             prepare_custom_config_dict=prepare_custom_config_dict)
         # calibration
-        m(data)
+        m(*example_inputs)
         # all activation observers are inserted in the top level module
         count_check = {
             ns.call_module(torch.ao.quantization.MinMaxObserver): num_observers
@@ -2280,13 +2306,13 @@ def forward(self, x):
         }
         self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
         self.assertEqual(type(m.custom), quantized_module_class)
-        res = m(data)
+        res = m(*example_inputs)
 
         # quantize the reference model
-        ref_m = prepare_fx(original_ref_m, qconfig_dict)
-        ref_m(data)
+        ref_m = prepare_fx(original_ref_m, qconfig_dict, example_inputs=example_inputs)
+        ref_m(*example_inputs)
         ref_m = convert_fx(ref_m)
-        ref_res = ref_m(data)
+        ref_res = ref_m(*example_inputs)
         self.assertEqual(res, ref_res)
 
     @skipIfNoFBGEMM
@@ -2360,16 +2386,18 @@ def forward(self, x0):
             }
         }
         m = M().eval()
+        example_inputs = (torch.randn(3, 3),)
         m = prepare_fx(
             m,
             {"": default_qconfig},
+            example_inputs=example_inputs,
             prepare_custom_config_dict=prepare_custom_config_dict)
 
         # make sure it works
         m = convert_fx(
             m,
             convert_custom_config_dict=convert_custom_config_dict)
         # make sure it runs
-        m(torch.randn(3, 3))
+        m(*example_inputs)
 
     @skipIfNoFBGEMM
     def test_non_traceable_module(self):
@@ -2415,6 +2443,7 @@ def forward(self, x):
         }
         m = prepare_fx(
             m, qconfig_dict,
+            example_inputs=({"key": torch.randn(1)},),
             prepare_custom_config_dict=prepare_custom_config_dict)
 
         node_occurrence = {
@@ -2441,9 +2470,10 @@ def forward(self, x):
         m = M()
         m.eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        prepared = prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(4, 1, 4, 4),)
+        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         # calibrate
-        prepared(torch.randn(4, 1, 4, 4))
+        prepared(*example_inputs)
         # copy
         prepared_copy = copy.deepcopy(prepared)
         # quantize, should run with no errors
@@ -2460,21 +2490,21 @@ def __init__(self):
         def forward(self, x):
             return self.linear(x)
 
-        data = torch.rand(8, 5)
+        example_inputs = (torch.rand(8, 5),)
         m = M().eval()
-        m = prepare_fx(m, {"": default_qconfig})
+        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
         m = convert_fx(m)
         # test deepcopy
         m_copy = copy.deepcopy(m)
-        self.assertEqual(m_copy(data), m(data))
+        self.assertEqual(m_copy(*example_inputs), m(*example_inputs))
         # test state_dict
         state_dict = m.state_dict()
         m_new = M().eval()
-        m_new = prepare_fx(m_new, {"": default_qconfig})
+        m_new = prepare_fx(m_new, {"": default_qconfig}, example_inputs=example_inputs)
         m_new = convert_fx(m_new)
         m_new.load_state_dict(state_dict)
-        self.assertEqual(m_new(data), m(data))
+        self.assertEqual(m_new(*example_inputs), m(*example_inputs))
 
     def test_dequantize(self):
         r""" Test to make sure dequantize node are placed before
@@ -2542,12 +2572,14 @@ def forward(self, x):
         # quantized input, quantized output
         m = M()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
+        example_inputs = (torch.randn(1, 1, 4, 4),)
         m.eval()
         mp = torch.ao.quantization.quantize_fx.prepare_fx(
             m, qconfig_dict,
+            example_inputs=example_inputs,
             prepare_custom_config_dict=prepare_custom_config_dict)
         self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
-        mp(torch.randn(1, 1, 4, 4))
+        mp(*example_inputs)
 
         mq = torch.ao.quantization.quantize_fx.convert_fx(mp)
         self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check)
@@ -2612,7 +2644,7 @@ def test_convtranspose_per_channel_fails_early(self):
         m.eval()
         qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
         with self.assertRaises(AssertionError) as context:
-            mp = prepare_fx(m, qconfig_dict)
+            mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
         self.assertTrue(
             str(context.exception) ==
             'Per channel weight observer is not supported yet for ConvTranspose{n}d.')
@@ -2644,8 +2676,9 @@ def forward(self, x):
 
         model = M().eval()
         qconfig_dict = {"": default_qconfig}
-        m = prepare_fx(model, qconfig_dict)
-        m(torch.rand(5, 5))
+        example_inputs = (torch.rand(5, 5),)
+        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
         keys = m.state_dict().keys()
         quant_scale_count = quant_zero_point = scale_count = zero_point_count = 0
@@ -2663,7 +2696,7 @@ def forward(self, x):
         self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict")
         self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict")
         # ensure it runs
-        m(torch.rand(5, 5))
+        m(*example_inputs)
         # ensure it is scriptable
         scripted = torch.jit.script(m)
         scripted_keys = scripted.state_dict().keys()
@@ -2708,9 +2741,10 @@ def forward(self, x):
             return x
 
         model = M().eval()
+        example_inputs = (torch.rand(5, 5),)
         qconfig_dict = {"": default_qconfig}
-        m = prepare_fx(model, qconfig_dict)
-        m(torch.rand(5, 5))
+        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
         assert hasattr(m, "mods1_0_packed_weight_0")
         assert hasattr(m, "mods1_1_packed_weight_0")
@@ -2743,10 +2777,11 @@ def forward(self, x):
             return x
 
         model = M().eval()
         qconfig_dict = {"": float16_dynamic_qconfig}
-        m = prepare_fx(model, qconfig_dict)
+        example_inputs = (torch.rand(5, 5),)
+        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
         m = convert_fx(m)
         # make sure it runs
-        m(torch.randn(5, 5))
+        m(*example_inputs)
 
     def test_getattr_with_nontensor_result(self):
@@ -2785,9 +2820,10 @@ def forward(self, x):
 
         for cls in (M1, M2, M3):
             m = cls().eval()
-            m(torch.rand(4, 4, 4, 4))
+            example_inputs = (torch.rand(4, 4, 4, 4),)
+            m(*example_inputs)
             qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-            mp = prepare_fx(m, qconfig_dict)
+            mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
             mp(torch.rand(4, 4, 4, 4))
             mc = convert_fx(mp)
@@ -2841,7 +2877,7 @@ def _check_node_not_observed(model, arg_node, node):
     def _test_dtype_propagation(self, model, node_info_to_non_tensor_args, *args):
         model.eval()
         qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")}
-        prepared_model = prepare_fx(model, qconfig_dict)
+        prepared_model = prepare_fx(model, qconfig_dict, example_inputs=tuple(args))
         self._check_not_observed(prepared_model, node_info_to_non_tensor_args)
         prepared_model(*args)
@@ -3035,12 +3071,13 @@ def forward(self, x):
             return x
 
         m = M().eval()
-        m(torch.rand(4, 1, 4, 4))
+        example_inputs = (torch.rand(4, 1, 4, 4),)
+        m(*example_inputs)
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = prepare_fx(m, qconfig_dict)
-        mp(torch.rand(4, 1, 4, 4))
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        mp(*example_inputs)
         mc = convert_fx(mp)
-        mc(torch.rand(4, 1, 4, 4))
+        mc(*example_inputs)
 
     def test_fp32_sum(self):
         """
@@ -3073,12 +3110,13 @@ def forward(self, x):
 
         for cls in (M1, M2):
             m = cls().eval()
-            m(torch.rand(4, 1, 4, 4))
+            example_inputs = (torch.rand(4, 1, 4, 4),)
+            m(*example_inputs)
             qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-            mp = prepare_fx(m, qconfig_dict)
-            mp(torch.rand(4, 1, 4, 4))
+            mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+            mp(*example_inputs)
             mc = convert_fx(mp)
-            mc(torch.rand(4, 1, 4, 4))
+            mc(*example_inputs)
 
     def test_fusion_pattern_unquantized(self):
         """
@@ -3113,8 +3151,9 @@ def forward(self, x):
                 ('child', None),
             ],
         }
-        mp = prepare_fx(m, qconfig_dict)
-        mp(torch.rand(1, 1, 1, 1))
+        example_inputs = (torch.rand(1, 1, 1, 1),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
+        mp(*example_inputs)
         mc = convert_fx(mp)
 
     def test_state_dict(self):
@@ -3133,7 +3172,7 @@ def forward(self, x):
 
         m = M1().eval()
         qconfig_dict = {"": default_qconfig}
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 30),))
         m = convert_fx(m)
         state_dict = m.state_dict()
         self.assertTrue("_packed_weight_0" in state_dict)
@@ -3154,7 +3193,7 @@ def forward(self, x):
 
         m = M2().eval()
         qconfig_dict = {"": default_qconfig}
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
         m = convert_fx(m)
         state_dict = m.state_dict()
         self.assertTrue("_packed_weight_0" in state_dict)
@@ -3164,7 +3203,7 @@ def forward(self, x):
         data = torch.rand(1, 3, 5, 5)
         ref_res = m(data)
         m = M2().eval()
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, (data,))
         m = convert_fx(m)
         res = m(data)
         weight, bias = m._packed_weight_0.unpack()
@@ -3186,7 +3225,7 @@ def checkModel(m, data, ref_weight, ref_bias, ref_res):
 
         # Test save to disk and load back
         m = M2().eval()
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, example_inputs=(data,))
         m = convert_fx(m)
         m.load_state_dict(state_dict)
         with TemporaryFileName() as fname:
@@ -3229,8 +3268,9 @@ def forward(self, x):
                 (torch.nn.functional.linear, float16_dynamic_qconfig),
             ],
         }
-        m = prepare_fx(model, qconfig_dict)
-        m(torch.rand(5, 5))
+        example_inputs = (torch.rand(5, 5),)
+        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m, _remove_qconfig=False)
         self.assertTrue(hasattr(m.mods2, 'qconfig'))
@@ -3250,7 +3290,7 @@ def forward(self, x):
         m = M().eval()
         qconfig_dict = {"": float16_static_qconfig}
         # make sure quantization runs
-        m = prepare_fx(m, qconfig_dict)
+        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),))
         m = convert_fx(m)
 
     def test_qparams_fqn(self):
@@ -3288,8 +3328,9 @@ def forward(self, x):
                 (torch.nn.functional.relu, default_qconfig),
             ],
         }
-        m = prepare_fx(model, qconfig_dict)
-        m(torch.rand(5, 5))
+        example_inputs = (torch.rand(5, 5),)
+        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
         keys = m.state_dict().keys()
         m(torch.randn(5, 5))
@@ -3321,12 +3362,13 @@ def forward(self, x):
 
         m = M().eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = prepare_fx(m, qconfig_dict)
+        example_inputs = (torch.randn(4, 4, 4, 4),)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         # if an observer is inserted after _user_func_with_complex_return_type,
         # the following call will fail
-        mp(torch.randn(4, 4, 4, 4))
+        mp(*example_inputs)
         mc = convert_fx(mp)
-        mc(torch.randn(4, 4, 4, 4))
+        mc(*example_inputs)
@@ -3352,11 +3394,12 @@ def forward(self, x):
                 (torch.nn.functional.linear, default_qconfig),
             ],
         }
-        m = prepare_fx(model, qconfig_dict)
-        m(torch.rand(5, 5))
+        example_inputs = (torch.rand(5, 5),)
+        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
+        m(*example_inputs)
         m = convert_fx(m)
         keys = m.state_dict().keys()
-        m(torch.randn(5, 5))
+        m(*example_inputs)
         dequant = 0
         quant = 0
         for n in m.graph.nodes:
@@ -3375,7 +3418,7 @@ def test_quant_output_always_observed(self):
         """
        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
         prepare_custom_config_dict = {'output_quantized_idxs': [0]}
-        data = (torch.randn(4, 1, 4, 4),)
+        example_inputs = (torch.randn(4, 1, 4, 4),)
 
         # non-quantizeable node, quantized output
         class M1(torch.nn.Module):
@@ -3389,7 +3432,7 @@ def forward(self, x):
 
         m1 = M1()
         self.checkGraphModeFxOp(
-            m1, data, QuantType.QAT,
+            m1, example_inputs, QuantType.QAT,
             prepare_expected_node_occurrence={
                 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
             },
@@ -3410,7 +3453,7 @@ def forward(self, x):
 
         m2 = M2()
         self.checkGraphModeFxOp(
-            m2, data, QuantType.QAT,
+            m2, example_inputs, QuantType.QAT,
             prepare_expected_node_occurrence={
                 # one for weights, one for activations
                 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
@@ -3432,7 +3475,7 @@ def forward(self, x):
 
         m3 = M3()
         self.checkGraphModeFxOp(
-            m3, data, QuantType.QAT,
+            m3, example_inputs, QuantType.QAT,
             prepare_expected_node_occurrence={
                 # one for weights, one for activations
                 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
@@ -3452,7 +3495,11 @@ def forward(self, x):
             return x
 
         m = M().eval()
-        m = prepare_fx(m, {"": default_qconfig}, prepare_custom_config_dict={"preserved_attributes": ["attr"]})
+        m = prepare_fx(
+            m,
+            {"": default_qconfig},
+            example_inputs=(torch.randn(1),),
+            prepare_custom_config_dict={"preserved_attributes": ["attr"]})
         self.assertTrue(hasattr(m, "attr"))
         m2 = copy.deepcopy(m)
         self.assertTrue(hasattr(m2, "attr"))
@@ -3475,7 +3522,7 @@ def forward(self, x):
 
         m = M().eval()
         qconfig_dict = {'': torch.ao.quantization.default_qconfig}
-        mp = prepare_fx(m, qconfig_dict)
+        mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
example_inputs=(torch.randn(1, 1, 1, 1),)) mc = convert_fx(mp) def test_shape_followed_by_quantized_op(self): @@ -3497,9 +3544,10 @@ def forward(self, x): # make sure quantization runs m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.randn(2, 2, 4, 4),) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) m = convert_fx(m) - m(torch.randn(2, 2, 4, 4)) + m(*example_inputs) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, ns.call_method("dequantize"): 1 @@ -3517,7 +3565,7 @@ def forward(self, x): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1, 1, 3, 3),)) m = convert_fx(m) # Make sure this runs without error m = torch.fx.Transformer(m).transform() @@ -3558,7 +3606,8 @@ def forward(self, x): qconfig = default_qconfig actpp_module_class = torch.ao.quantization.MinMaxObserver - m = prepare(m, {"": qconfig}) + example_inputs = (torch.randn(1, 3, 3, 3),) + m = prepare(m, {"": qconfig}, example_inputs=example_inputs) # check that there is a duplicated observer instance actpp_module_count = 0 for name, module in m.named_modules(remove_duplicate=False): @@ -3620,12 +3669,16 @@ def forward(self, x): return x m = M().eval() - m = prepare_fx(m, {"": torch.ao.quantization.QConfig( - activation=torch.ao.quantization.HistogramObserver.with_args( - qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 - ), weight=torch.ao.quantization.default_per_channel_weight_observer)}) + example_inputs = (torch.rand(2, 1, 5, 5),) + m = prepare_fx( + m, + {"": torch.ao.quantization.QConfig( + activation=torch.ao.quantization.HistogramObserver.with_args( + qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 + ), weight=torch.ao.quantization.default_per_channel_weight_observer)}, + example_inputs=example_inputs) m = convert_fx(m, is_reference=True) - m(torch.rand(2, 1, 5, 5)) + m(*example_inputs) def test_preserve_tuple(self): """ Test tuple input type is preserved @@ -3642,7 +3695,8 @@ def forward(self, inputs: torch.Tensor, state: List[torch.Tensor]): return self.lstm(inputs, (h, c)) m = LSTM().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.randn(5, 3, 50), torch.randn(2, 3, 50), torch.randn(2, 3, 50)) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) # make sure the arg[1] of lstm module is a tuple for n in m.graph.nodes: if n.target == "lstm": @@ -3654,7 +3708,7 @@ def forward(self, x): return torch.nn.functional.relu(x) m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1),)) m_copy = copy.deepcopy(m) m = convert_fx(m) m_ref = convert_fx(m_copy, is_reference=True) @@ -3716,9 +3770,10 @@ def forward(self, x): qconfig_dict = { "": qconfig } - m = prepare_fx(model, qconfig_dict) + example_inputs = (torch.rand(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) - m(torch.rand(5, 5)) + m(*example_inputs) node_list = [ ns.call_module(nniqd.LinearReLU), ns.call_module(nniqd.LinearReLU), @@ -3756,9 +3811,10 @@ def forward(self, x): qconfig_dict = { "": qconfig } - m = prepare_fx(model, qconfig_dict) + example_inputs = (torch.randn(5, 5),) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) - m(torch.rand(5, 5)) + m(*example_inputs) node_list = [ ns.call_module(nniqd.LinearReLU), ns.call_module(nniqd.LinearReLU), @@ -3796,9 +3852,10 @@ def forward(self, 
x):
         qconfig_dict = {
             "": qconfig
         }
-        m = prepare_fx(model, qconfig_dict)
+        example_inputs = (torch.rand(5, 5, 5),)
+        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
         m = convert_fx(m)
-        m(torch.rand(5, 5, 5))
+        m(*example_inputs)
         node_list = [
             ns.call_module(nniqd.LinearReLU),
             ns.call_module(nniqd.LinearReLU),
@@ -3828,13 +3885,13 @@ def forward(self, x):

         for M in [M1, M2]:
             m = M().eval()
-            m = prepare_fx(m, {"": default_qconfig})
+            example_inputs = (torch.randn(5, 10),)
+            m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
             m_copy = copy.deepcopy(m)
             m = convert_fx(m, is_reference=False)
             m_ref = convert_fx(m_copy, is_reference=True)
-            data = torch.randn(5, 10)
-            result = m(data)
-            result_ref = m_ref(data)
+            result = m(*example_inputs)
+            result_ref = m_ref(*example_inputs)
             self.assertTrue(torch.equal(result, result_ref))

     def test_ref_conv_module(self):
@@ -3866,11 +3923,11 @@ def forward(self, x):

         for dim, M in itertools.product([1, 2, 3], [M1, M2]):
             m = M(dim).eval()
-            m = prepare_fx(m, {"": default_qconfig})
+            data = self.img_data_dict[dim][0][0]
+            m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,))
             m_copy = copy.deepcopy(m)
             m = convert_fx(m, is_reference=False)
             m_ref = convert_fx(m_copy, is_reference=True)
-            data = self.img_data_dict[dim][0][0]
             result = m(data)
             result_ref = m_ref(data)
             self.assertTrue(torch.equal(result, result_ref))
@@ -3885,7 +3942,7 @@ def forward(self, x):
                 return x

         m = M().eval()
-        m = prepare_fx(m, {"": default_qconfig})
+        m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.rand(3),))
         m = convert_fx(m)
         occurrence = {
             ns.call_function(torch.quantize_per_tensor): 2,
@@ -3930,7 +3987,7 @@ def forward(self, x):

         model = M().eval()
-        prepared = prepare_fx(model, {"": default_qconfig})
+        prepared = prepare_fx(model, {"": default_qconfig}, example_inputs=(torch.randn(1, 5),))
         name_list = []
         for name, mod in prepared.named_modules():
             if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver):
@@ -3959,11 +4016,11 @@ def forward(self, x):

         for dim in range(1, len(convs) + 1):
             m = M(dim).eval()
-            m = prepare_fx(m, {"": default_qconfig})
+            data = self.img_data_dict[dim][0][0]
+            m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,))
             m_ref = copy.deepcopy(m)
             m_ref = convert_fx(m_ref, is_reference=True)
             m = convert_fx(m)
-            data = self.img_data_dict[dim][0][0]
             out_ref = m_ref(data)
             out = m(data)
             # check that reference pattern for quantized conv module is fused
@@ -4013,8 +4070,9 @@ def forward(self, x):
                 (nn.Linear, get_default_qat_qconfig("fbgemm")),
             ],
         }
-        prepared = prepare_qat_fx(model, qconfig_dict)
-        prepared(torch.rand(5, 5))
+        example_inputs = (torch.rand(5, 5),)
+        prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
+        prepared(*example_inputs)
         if check == "module_name":
             convert_qconfig_dict = {"": None,
                                     "object_type": [
@@ -4133,7 +4191,8 @@ def forward(self, x):
         options = itertools.product([M1, M2], [True, False])
         for M, is_qat in options:
             m = M1().eval()
-            m = prepare_fx(m, get_default_qconfig_dict())
+            example_inputs = (torch.randn(1, 3, 3, 3),)
+            m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
             m = convert_fx(m)
             node_list = [
                 ns.call_function(torch.quantize_per_tensor),
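A note on the values being threaded through these tests: `example_inputs` is a tuple of positional arguments for the model's forward, so for a single-input model the trailing comma is what makes it a tuple. A minimal illustration in plain Python (shapes arbitrary):

import torch

example_inputs = (torch.randn(1, 5),)  # a 1-tuple containing one tensor
not_a_tuple = (torch.randn(1, 5))      # parentheses alone do not make a tuple;
                                       # this is just a tensor
assert isinstance(example_inputs, tuple)
assert isinstance(not_a_tuple, torch.Tensor)
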
@@ -4146,7 +4205,7 @@ def forward(self, x):
                 expected_node_list=node_list)

         m = M2().eval()
-        m = prepare_fx(m, get_default_qconfig_dict())
+        m = prepare_fx(m, get_default_qconfig_dict(), example_inputs=example_inputs)
         m = convert_fx(m)
         node_occurrence = {
             ns.call_function(torch.quantize_per_tensor): 0,
@@ -4171,7 +4230,7 @@ def forward(self, x):
                 return x

         m = M().eval()
-        mp = prepare_fx(m, get_default_qconfig_dict())
+        mp = prepare_fx(m, get_default_qconfig_dict(), example_inputs=(torch.randn(1, 1),))

         found_stack_trace = False
         for n in mp.graph.nodes:
@@ -4233,11 +4292,14 @@ def forward(self, x):
             "non_traceable_module_class": [UnTraceableModuleClass],
             "non_traceable_module_name": ["untraceable_module_name"],
         }
+        example_inputs = (torch.randn(2, 2),)
         mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx(
-            mod.train(), qconfig_dict, prepare_custom_config_dict
+            mod.train(), qconfig_dict, example_inputs=example_inputs,
+            prepare_custom_config_dict=prepare_custom_config_dict
         )
         mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx(
-            mod.train(), qconfig_dict, prepare_custom_config_dict
+            mod.train(), qconfig_dict, example_inputs=example_inputs,
+            prepare_custom_config_dict=prepare_custom_config_dict
         )
         self.assertTrue(
             isinstance(mod_prep.untraceable_module_class.linear, torch.nn.Linear)
         )
@@ -4287,7 +4349,7 @@ def forward(self, x):
         for backend in backends:
             m = M().eval()
             qconfig_dict = func(backend)
-            m = prepare_fx(m, qconfig_dict)
+            m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
             for name, mod in m.named_modules():
                 if is_activation_post_process(mod) and mod.dtype == torch.quint8:
                     if backend == "fbgemm":
@@ -4316,10 +4378,11 @@ def _test(prepare_fn, qconfig_dict):
             m = LinearModel()
             m1 = copy.deepcopy(m)
             m1.train()
-            prepare_fn(m1, qconfig_dict)
+            example_inputs = (torch.randn(1, 5),)
+            prepare_fn(m1, qconfig_dict, example_inputs=example_inputs)
             m2 = copy.deepcopy(m)
             m2.eval()
-            prepare_fn(m2, qconfig_dict)
+            prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)

         # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
         _test(prepare_fx, get_default_qconfig_dict())
@@ -5016,7 +5079,8 @@ def forward(self, x, y):
             return x

         m = M().eval()
-        m = prepare_fx(m, {"": default_qconfig})
+        example_inputs = (torch.randn(3), torch.randn(3))
+        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
         m = convert_fx(m)
         node_occurrence = {
             ns.call_function(torch.quantize_per_tensor): 2,
@@ -5027,7 +5091,7 @@ def forward(self, x, y):
         # check the model is scriptable
         m = torch.jit.script(m)
         # check the model is runnable
-        m(torch.randn(3), torch.randn(3))
+        m(*example_inputs)

     @skipIfNoFBGEMM
     def test_mul_relu(self):
@@ -5037,9 +5101,9 @@ def test_mul_relu(self):
             operator.mul, operator.imul)

     # TODO(future PR): make more generic
-    def _test_quantized_add_mul_qat(self, model, expected_node_occurrence):
+    def _test_quantized_add_mul_qat(self, model, example_inputs, expected_node_occurrence):
         qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
-        mp = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        mp = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
         self.checkGraphModuleNodes(
             mp, expected_node_occurrence=expected_node_occurrence)
@@ -5060,10 +5124,11 @@ def forward(self, x):
             return x

         m = M()
+        example_inputs = (torch.randn(1, 1, 1, 1),)
         expected_node_occurrence = {
             ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5,
         }
-        self._test_quantized_add_mul_qat(m, expected_node_occurrence)
+        self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence)

     @skipIfNoFBGEMM
     def test_quantized_mul_qat(self):
@@ -5082,10 +5147,11 @@ def forward(self, x):
             return x

         m = M()
+        example_inputs = (torch.randn(1, 1,
1, 1),) expected_node_occurrence = { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, } - self._test_quantized_add_mul_qat(m, expected_node_occurrence) + self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) def test_int8_input_no_unnecessary_fq(self): """ @@ -5105,6 +5171,7 @@ def forward(self, x): m = M(0.5) mp = torch.ao.quantization.quantize_fx.prepare_qat_fx( m, {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}, + example_inputs=(torch.randn(1),), prepare_custom_config_dict={"input_quantized_idxs": [0]}) expected_node_occurrence = { ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 1, @@ -5128,8 +5195,8 @@ def forward(self, x, y): y = self.conv2(y) return torch.cat([x, y], 1) - data = (torch.randn(1, 2, 5, 5, dtype=torch.float), - torch.randn(1, 2, 5, 5, dtype=torch.float)) + example_inputs = (torch.randn(1, 2, 5, 5, dtype=torch.float), + torch.randn(1, 2, 5, 5, dtype=torch.float)) quantized_node = ns.call_function(torch.cat) options = itertools.product(self.static_quant_types, [True, False]) for quant_type, is_reference in options: @@ -5158,7 +5225,7 @@ def forward(self, x, y): self.checkGraphModeFxOp( M(), - data, + example_inputs, quant_type, quantized_node, expected_node_list=converted_node_list, @@ -5167,7 +5234,7 @@ def forward(self, x, y): # check cat is using the same observer for input and output m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) # two inputs and one output of torch.cat are using same observer, so we have # 2 observers that's replicated all_observers = len(dict(m.named_modules(remove_duplicate=False))) @@ -5175,7 +5242,7 @@ def forward(self, x, y): self.assertEqual(all_observers, distinct_observers + 2) # make sure the converted model runs m = convert_fx(m) - m(*data) + m(*example_inputs) @skipIfNoFBGEMM def test_qbatch_norm(self): @@ -5630,6 +5697,7 @@ def forward(self, x, y): data_x = torch.randn((2, 2, 2,)) data_y = torch.randn((2, 2, 2,)) + example_inputs = (data_x, data_y) qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} is_reference = True node_list = [ @@ -5637,10 +5705,10 @@ def forward(self, x, y): ] m = M().eval() - m_prep = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict) - m_prep(data_x, data_y) + m_prep = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + m_prep(*example_inputs) m_quant = torch.ao.quantization.quantize_fx.convert_fx(m_prep, is_reference=is_reference) - m_quant(data_x, data_y) + m_quant(*example_inputs) self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) @@ -5823,12 +5891,12 @@ def forward(self, x): x = self.conv2(x) return x - data = torch.rand(1, 3, 10, 10) + example_inputs = (torch.rand(1, 3, 10, 10),) # This model is not executable since we just put all ops # in the same forward m = M().eval() qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(m, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_fx(prepared) @@ -5859,7 +5927,7 @@ def forward(self, x): # Checking the is_reference output m = M().eval() qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(m, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_fx(prepared, is_reference=True) @@ -5884,7 +5952,10 @@ def forward(self, x): m = M().eval() # nothing to fuse so 
skipping the fuse step qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(m, qconfig_dict, prepare_custom_config_dict={"input_quantized_idxs": [0]}) + example_inputs = (torch.randn(1, 3, 3, 3),) + prepared = prepare_fx( + m, qconfig_dict, example_inputs=example_inputs, + prepare_custom_config_dict={"input_quantized_idxs": [0]}) # not runnable quantized = convert_fx(prepared) @@ -5950,7 +6021,8 @@ def forward(self, x): m = M().eval() # nothing to fuse so skipping the fuse step qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 3, 3, 3),) + prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) # not runnable quantized = convert_fx(prepared) @@ -5983,7 +6055,7 @@ def forward(self, x): return x m = M().eval() - m = prepare_fx(m, {"": default_reuse_input_qconfig}) + m = prepare_fx(m, {"": default_reuse_input_qconfig}, example_inputs=(torch.randn(1),)) m = convert_fx(m) # make sure it runs m(torch.rand(1)) @@ -5998,12 +6070,13 @@ def forward(self, xs): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.rand(1, 2),) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 0 }) m = convert_fx(m) - m(torch.rand(1, 2)) + m(*example_inputs) class M2(torch.nn.Module): def forward(self, xs): @@ -6012,7 +6085,8 @@ def forward(self, xs): return x m2 = M2().eval() - m2 = prepare_fx(m2, {"": default_qconfig}) + example_inputs = ([torch.rand(1, 2)],) + m2 = prepare_fx(m2, {"": default_qconfig}, example_inputs=example_inputs) self.checkGraphModuleNodes(m2, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 1 }) @@ -6021,7 +6095,7 @@ def forward(self, xs): ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize") ]) - m2([torch.rand(1, 2)]) + m2(*example_inputs) # testing prepare recognizes non-Tensor input for getitem class M3(torch.nn.Module): @@ -6032,7 +6106,8 @@ def forward(self, x): return x m3 = M3().eval() - m3 = prepare_fx(m3, {"": default_qconfig}) + example_inputs = (torch.rand(1, 2, 3, 4),) + m3 = prepare_fx(m3, {"": default_qconfig}, example_inputs=example_inputs) self.checkGraphModuleNodes(m3, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 1 }) @@ -6041,7 +6116,7 @@ def forward(self, x): ns.call_function(torch.quantize_per_tensor), ns.call_method("dequantize") ]) - m3(torch.rand(1, 2, 3, 4)) + m3(*example_inputs) @skipIfNoFBGEMM @@ -6090,12 +6165,12 @@ def forward(self, x): # nothing to fuse so skipping the fuse step m_copy = copy.deepcopy(m) qconfig_dict = {'': qconfig} - prepared = prepare(m, qconfig_dict) + example_inputs = (torch.rand(3, 3, 3, 3),) + prepared = prepare(m, qconfig_dict, example_inputs=example_inputs) prepared_copy = copy.deepcopy(prepared) # check that prepare does not change model result if eval_mode: - r = torch.rand(3, 3, 3, 3) - self.assertEqual(m_copy(r), prepared_copy(r)) + self.assertEqual(m_copy(*example_inputs), prepared_copy(*example_inputs)) # check the correct number of activation_post_process is inserted expected_activation_post_process = FixedQParamsObserver if eval_mode else FixedQParamsFakeQuantize count_check = { @@ -6183,7 +6258,7 @@ def forward(self, x): x = self.ff6.cat([x]) return x - data = torch.rand(3, 3) + example_inputs = (torch.rand(3, 3),) # Note: QAT test succeeded by chance, to make it actually work # we need to 
fix eager mode FloatFunctional by removing # activation_post_process in add_scalar and mul_scalar @@ -6204,13 +6279,13 @@ def forward(self, x): prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx qconfig_dict = {"": qconfig} - m = prepare_fx_function(m, qconfig_dict) + m = prepare_fx_function(m, qconfig_dict, example_inputs=example_inputs) node_occurrence = { ns.call_module(expected_act_post_process): 7, ns.call_module(torch.nn.quantized.FloatFunctional): 0 } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - m(data) + m(*example_inputs) node_list = [ ns.call_function(torch.quantize_per_tensor), ns.call_function(torch.ops.quantized.add), @@ -6228,7 +6303,7 @@ def forward(self, x): ref_m.qconfig = qconfig prepare_function = prepare_qat if is_qat else prepare ref_m = prepare_function(ref_m) - ref_m(data) + ref_m(*example_inputs) ref_m = convert(ref_m) # FX Graph Mode and Eager Mode now diverages in numerics of add_scalar and mul_scalar # self.assertEqual(m(data), ref_m(data)) @@ -6245,6 +6320,7 @@ def forward(self, indices): for qconfig_type in [float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit]: model = M().eval() indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + example_inputs = (indices,) quantized_node = ns.call_module(nnq.Embedding) configs = [ (qconfig_type, ns.call_module(nnq.Embedding)), @@ -6254,14 +6330,14 @@ def forward(self, indices): for qconfig, node in configs: qconfig_dict = {"": qconfig} - m = prepare_fx(model, qconfig_dict) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 0 }) m = convert_fx(m) self.checkGraphModuleNodes(m, expected_node=node) # make sure it runs - m(indices) + m(*example_inputs) def test_embedding_bag(self): class M(torch.nn.Module): @@ -6275,7 +6351,7 @@ def forward(self, indices, offsets): indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) offsets = torch.tensor([0, 19, 20, 28, 28, 32]) quantized_node = ns.call_module(nnq.EmbeddingBag) - inputs = (indices, offsets) + example_inputs = (indices, offsets) for dtype in [torch.quint8, torch.quint4x2]: model = M().eval() @@ -6286,7 +6362,7 @@ def forward(self, indices, offsets): weight=float_qparams_observer) self.checkGraphModeFxOp( model, - inputs, + example_inputs, QuantType.DYNAMIC, quantized_node, custom_qconfig_dict={"": float_qparams_qconfig} @@ -6296,14 +6372,14 @@ def forward(self, indices, offsets): for qconfig in [None, default_qconfig]: qconfig_dict = {"": default_qconfig} m = M().eval() - m = prepare_fx(model, qconfig_dict) + m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ ns.call_module(torch.ao.quantization.MinMaxObserver): 0 }) m = convert_fx(m) self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag)) # make sure it runs - m(*inputs) + m(*example_inputs) def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input): options = itertools.product(qconfigs, module_type_strs) @@ -6323,7 +6399,7 @@ def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_inp (x, qconfig) for x in module_types ] } - model_graph = prepare_fx(model_graph, graph_qconfig_dict) + model_graph = prepare_fx(model_graph, graph_qconfig_dict, 
example_inputs=(sample_input,)) model_graph = convert_fx(model_graph) self.assertEqual(model_eager(sample_input), model_graph(sample_input)) self.checkScriptable(model_graph, [[sample_input]], True) @@ -6407,7 +6483,8 @@ def forward(self, x): (torch.nn.functional.linear, default_qconfig) ] } - m = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 4),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) expected_occurrence = { # input and weight of first and second linear, output of first and second linear ns.call_module(torch.ao.quantization.MinMaxObserver): 6, @@ -6456,7 +6533,8 @@ def forward(self, x): (torch.nn.functional.linear, default_qconfig) ] } - m = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 4),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) expected_occurrence = { # input and weight of linear, output of linear ns.call_module(torch.ao.quantization.MinMaxObserver): 3, @@ -6489,7 +6567,8 @@ def forward(self, x, mask): return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) + example_inputs = (torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool()) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) expected_occurrence = { ns.call_module(torch.ao.quantization.MinMaxObserver): 0 } @@ -6497,8 +6576,7 @@ def forward(self, x, mask): m, expected_node_occurrence=expected_occurrence) m = convert_fx(m) - m(torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool()) - return m + m(*example_inputs) def test_chunk(self): class M(torch.nn.Module): @@ -6507,11 +6585,11 @@ def forward(self, x): x = x + y return x m = M().eval() - m = prepare_fx(m, {"": default_qconfig}) - data = torch.rand(2, 2, 2, 2) - m(data) + example_inputs = (torch.rand(2, 2, 2, 2),) + m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) + m(*example_inputs) m = convert_fx(m) - m(data) + m(*example_inputs) # make sure everything runs def test_ref_pattern_multi_use(self): @@ -6536,7 +6614,8 @@ def forward(self, x): (torch.nn.ReLU, get_default_qconfig("fbgemm")), ], } - m = prepare_fx(m, qconfig_dict) + example_inputs = (torch.randn(1, 5),) + m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) m = convert_fx(m) expected_occurrence = { ns.call_function(torch.quantize_per_tensor): 1, @@ -6557,9 +6636,10 @@ def forward(self, x, y): return z m = M().eval() + example_inputs = (torch.randn(2, 2), torch.randn(2, 2)) qconfig_dict = {"": torch.ao.quantization.default_qconfig} - mp = prepare_fx(m, qconfig_dict) - mp(torch.randn(2, 2), torch.randn(2, 2)) + mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mq = convert_fx(mp) expected_occurrence = { ns.call_function(torch.matmul): 0, @@ -6569,7 +6649,7 @@ def forward(self, x, y): mq, expected_node_occurrence=expected_occurrence) # verify no crash - res = mq(torch.randn(2, 2), torch.randn(2, 2)) + res = mq(*example_inputs) class TestQuantizeFxModels(QuantizationTestCase): @skipIfNoFBGEMM @@ -6589,12 +6669,13 @@ def forward(self, x): return y input = torch.randn((5, 1, 6, 6)).to('cuda') + example_inputs = (input,) model = Net().to('cuda').eval() qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} - model_prepared = prepare_fx(model, qconfig_dict) - model_prepared(input) + model_prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) + model_prepared(*example_inputs) model_quantized = convert_fx(model_prepared, is_reference=True) - out = model_quantized(input) + out = model_quantized(*example_inputs) 
self.assertEqual(out.device.type, 'cuda') @skipIfNoFBGEMM @@ -6618,7 +6699,7 @@ def forward(self, x): input = torch.randn((5, 1, 6, 6)).to(device) model = Net().to(device).eval() qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} - model_prepared = prepare_fx(model, qconfig_dict) + model_prepared = prepare_fx(model, qconfig_dict, example_inputs=(input,)) model_prepared(input) model_prepared.to(device_after) model_quantized = convert_fx(model_prepared, is_reference=True) @@ -6644,8 +6725,8 @@ def forward(self, x): input = torch.randn((5, 1, 6, 6)).to(device) model = Net().to(device).eval() qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} - model_prepared_first = prepare_fx(model, qconfig_dict) - model_prepared_second = prepare_fx(model, qconfig_dict) + model_prepared_first = prepare_fx(model, qconfig_dict, example_inputs=(input,)) + model_prepared_second = prepare_fx(model, qconfig_dict, example_inputs=(input,)) model_prepared_first(input) state_dict = model_prepared_first.state_dict() del model_prepared_first @@ -6660,10 +6741,11 @@ def test_model_dropout(self): from torchvision import models m = models.mobilenet_v3_small() qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')} - mp = prepare_qat_fx(m, qconfig_dict) - mp(torch.randn(1, 3, 224, 224)) + example_inputs = (torch.randn(1, 3, 224, 224),) + mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) + mp(*example_inputs) mq = convert_fx(mp) - res = mq(torch.randn(1, 3, 224, 224)) + res = mq(*example_inputs) def _test_model_impl( self, mode, name, model, eager_quantizable_model, @@ -6813,7 +6895,7 @@ def _test_building_block(self, quant_type, BB): eager = eager_prepare(eager) qconfig_dict = {"": qconfig} - graph = graph_prepare(graph, qconfig_dict) + graph = graph_prepare(graph, qconfig_dict, example_inputs=(data[0][0],)) eager_out = eager(data[0][0]) graph_out = graph(data[0][0]) @@ -6946,7 +7028,7 @@ def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, eval_output = [[torch.randint(0, 10, (12, 1))]] model = EmbeddingBagLinear().train() - prepared_fx_model = prepare_qat_fx(model, qconfig_dict) + prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, qconfig_dict=qconfig_dict) @@ -6986,7 +7068,7 @@ def forward(self, input: torch.Tensor): eval_output = [[torch.randint(0, 10, (12, 1))]] model = EmbeddingLinear().train() - prepared_fx_model = prepare_qat_fx(model, qconfig_dict) + prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, qconfig_dict=qconfig_dict) @@ -7048,8 +7130,8 @@ def forward(self, x): activation=ref_fake_quant, weight=ref_weight_fake_quant ) qconfig_dict = {"": ref_qat_qconfig} - - prepared_ref = prepare_qat_fx(model, qconfig_dict) + example_inputs = (torch.randn(1, 5),) + prepared_ref = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) custom_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( observer=MovingAverageMinMaxObserver, @@ -7069,7 +7151,7 @@ def forward(self, x): activation=custom_fake_quant, weight=custom_weight_fake_quant ) custom_qconfig_dict = {"": custom_qconfig} - prepared = prepare_qat_fx(model, custom_qconfig_dict) + prepared = prepare_qat_fx(model, custom_qconfig_dict, example_inputs=example_inputs) 
         prepared.to(device)
         prepared_ref.to(device)
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 086b65e13c90..ad0113ab5285 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -34,7 +34,7 @@
 from torch.ao.quantization.quantization_types import (
     Pattern,
-    NodePattern
+    NodePattern,
 )

 from ._equalize import (
@@ -231,7 +231,7 @@ def prepare_get_standalone_module_configs(
     prepare_custom_config_dict: Dict[str, Any],
     parent_qconfig: QConfigAny,
     parent_backend_config_dict: Optional[Dict[str, Any]],
-) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+) -> Tuple[Dict[str, Any], Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
     """
     Returns the standalone module qconfig_dict and prepare_config_dict
     for `node`, assuming that the module pointed to by `node` is
@@ -239,8 +239,11 @@ def prepare_get_standalone_module_configs(
     standalone_module_name = str(node.target)
     standalone_module_type = type(modules[standalone_module_name])  # type: ignore[index]
-    sm_qconfig_dict, sm_prepare_config_dict, sm_backend_config_dict = \
-        get_standalone_module_configs(standalone_module_name, standalone_module_type, prepare_custom_config_dict)
+    sm_qconfig_dict, sm_example_inputs, sm_prepare_config_dict, \
+        sm_backend_config_dict = get_standalone_module_configs(
+            standalone_module_name,
+            standalone_module_type,
+            prepare_custom_config_dict)
     # fallback to use parent module's qconfig if user didn't specify qconfig dict
     if sm_qconfig_dict is None:
         sm_qconfig_dict = {"": parent_qconfig}
@@ -250,7 +253,7 @@ def prepare_get_standalone_module_configs(
     # as well, this can be added later
     if sm_backend_config_dict is None:
         sm_backend_config_dict = parent_backend_config_dict
-    return sm_qconfig_dict, sm_prepare_config_dict, sm_backend_config_dict
+    return sm_qconfig_dict, sm_example_inputs, sm_prepare_config_dict, sm_backend_config_dict

 def qat_swap_modules(
         root: torch.nn.Module,
@@ -529,7 +532,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
     else:
         # custom flow for standalone modules
-        _sm_qconfig_dict, sm_prepare_config_dict, _sm_backend_config_dict = \
+        _sm_qconfig_dict, _, sm_prepare_config_dict, _sm_backend_config_dict = \
             prepare_get_standalone_module_configs(
                 node, modules, prepare_custom_config_dict, qconfig,
                 backend_config_dict)
@@ -1301,8 +1304,8 @@ def run_prepare_fx_on_standalone_modules(
         elif not qhandler.is_standalone_module():
             continue

-        sm_qconfig_dict, sm_prepare_config_dict, sm_backend_config_dict = \
-            prepare_get_standalone_module_configs(
+        sm_qconfig_dict, sm_example_inputs, sm_prepare_config_dict, \
+            sm_backend_config_dict = prepare_get_standalone_module_configs(
                 root_node, modules, prepare_custom_config_dict, qconfig,
                 backend_config_dict)

         standalone_module = modules[root_node.target]
@@ -1313,7 +1316,8 @@ def run_prepare_fx_on_standalone_modules(
                 standalone_module,
                 sm_qconfig_dict,
                 is_qat,
-                sm_prepare_config_dict,
+                example_inputs=sm_example_inputs,
+                prepare_custom_config_dict=sm_prepare_config_dict,
                 backend_config_dict=sm_backend_config_dict)
         preserved_attributes = \
             set(sm_prepare_config_dict.get("preserved_attributes", []))
@@ -1349,6 +1353,7 @@ def prepare(
         qconfig_dict: Any,
         is_qat: bool,
         node_name_to_scope: Dict[str, Tuple[str, type]],
+        example_inputs: Tuple[Any, ...],
         prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
         equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
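Since `get_standalone_module_configs` (changed in qconfig_utils.py below) now unpacks four values per entry, each standalone-module config tuple carries its own example inputs. A hedged sketch of what one entry might look like under the inferred five-element layout; the module name, qconfig, and shapes here are illustrative, not taken from this patch:

import torch

# Assumed layout, inferred from the x[1]..x[4] unpacking below:
# (name_or_class, qconfig_dict, example_inputs,
#  prepare_custom_config_dict, backend_config_dict)
prepare_custom_config_dict = {
    "standalone_module_name": [(
        "submodule",                                  # hypothetical module name
        {"": torch.ao.quantization.default_qconfig},  # submodule qconfig_dict
        (torch.randn(1, 3, 8, 8),),                   # submodule example_inputs
        None,                                         # submodule prepare config
        None,                                         # submodule backend config
    )],
}
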
diff --git a/torch/ao/quantization/fx/qconfig_utils.py b/torch/ao/quantization/fx/qconfig_utils.py
index 4884ef08d0d6..79d3508facf2 100644
--- a/torch/ao/quantization/fx/qconfig_utils.py
+++ b/torch/ao/quantization/fx/qconfig_utils.py
@@ -319,9 +319,9 @@ def get_standalone_module_configs(
         custom_config_dict.get("standalone_module_name", [])
     standalone_module_class_configs = \
         custom_config_dict.get("standalone_module_class", [])
-    class_config_map = {x[0]: (x[1], x[2], x[3]) for x in standalone_module_class_configs}
-    name_config_map = {x[0]: (x[1], x[2], x[3]) for x in standalone_module_name_configs}
-    config = class_config_map.get(module_type, (None, None, None))
+    class_config_map = {x[0]: (x[1], x[2], x[3], x[4]) for x in standalone_module_class_configs}
+    name_config_map = {x[0]: (x[1], x[2], x[3], x[4]) for x in standalone_module_name_configs}
+    config = class_config_map.get(module_type, (None, None, None, None))
     # name config has precedence over type config
     config = name_config_map.get(module_name, config)
     return config
diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py
index 64de11818bd1..1d16ce817144 100644
--- a/torch/ao/quantization/quantize_fx.py
+++ b/torch/ao/quantization/quantize_fx.py
@@ -19,7 +19,6 @@
 from .fx.utils import graph_pretty_str  # noqa: F401
 from .fx.utils import get_custom_module_class_keys  # noqa: F401

-
 def _check_is_graph_module(model: torch.nn.Module) -> None:
     if not isinstance(model, GraphModule):
         raise ValueError(
@@ -178,6 +177,7 @@ def _prepare_fx(
         model: torch.nn.Module,
         qconfig_dict: Any,
         is_qat: bool,
+        example_inputs: Tuple[Any, ...],
         prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
         equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
@@ -247,6 +247,7 @@ def _prepare_fx(
         qconfig_dict,
         is_qat,
         tracer.node_name_to_scope,
+        example_inputs=example_inputs,
         prepare_custom_config_dict=prepare_custom_config_dict,
         equalization_qconfig_dict=equalization_qconfig_dict,
         backend_config_dict=backend_config_dict,
@@ -262,6 +263,7 @@ def _prepare_standalone_module_fx(
         model: torch.nn.Module,
         qconfig_dict: Any,
         is_qat: bool,
+        example_inputs: Tuple[Any, ...],
         prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
 ) -> GraphModule:
@@ -291,6 +293,7 @@ def _prepare_standalone_module_fx(
         model,
         qconfig_dict,
         is_qat,
+        example_inputs,
         prepare_custom_config_dict,
         backend_config_dict=backend_config_dict,
         is_standalone_module=True,
@@ -340,6 +343,7 @@ def fuse_fx(
 def prepare_fx(
         model: torch.nn.Module,
         qconfig_dict: Any,
+        example_inputs: Tuple[Any, ...],
         prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
         equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
         backend_config_dict: Optional[Dict[str, Any]] = None,
@@ -450,6 +454,7 @@ def prepare_fx(
           "preserved_attributes": ["preserved_attr"],
         }

+    * `example_inputs`: (required) Example inputs for the model's forward function
     * `equalization_qconfig_dict`: equalization_qconfig_dict is a dictionary
       with a similar structure as qconfig_dict except it will contain
       configurations specific to equalization techniques such as input-weight
@@ -480,7 +485,8 @@ def calibrate(model, data_loader):
                 model(image)

         qconfig_dict = {"": qconfig}
-        prepared_model = prepare_fx(float_model, qconfig_dict)
+        example_inputs = (torch.randn(1, 3, 224, 224),)
+        prepared_model = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)

         # Run calibration
         calibrate(prepared_model, sample_inference_data)
@@ -490,6
+496,7 @@ def calibrate(model, data_loader): model, qconfig_dict, False, # is_qat + example_inputs, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict, @@ -499,6 +506,7 @@ def calibrate(model, data_loader): def prepare_qat_fx( model: torch.nn.Module, qconfig_dict: Any, + example_inputs: Tuple[Any, ...], prepare_custom_config_dict: Optional[Dict[str, Any]] = None, backend_config_dict: Optional[Dict[str, Any]] = None, ) -> ObservedGraphModule: @@ -507,6 +515,7 @@ def prepare_qat_fx( Args: * `model`: torch.nn.Module model, must be in train mode * `qconfig_dict`: see :func:`~torch.ao.quantization.prepare_fx` + * `example_inputs`: see :func:`~torch.ao.quantization.prepare_fx` * `prepare_custom_config_dict`: see :func:`~torch.ao.quantization.prepare_fx` * `backend_config_dict`: see :func:`~torch.ao.quantization.prepare_fx` @@ -538,6 +547,7 @@ def train_loop(model, train_data): model, qconfig_dict, True, # is_qat + example_inputs, prepare_custom_config_dict, backend_config_dict=backend_config_dict, ) diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 42b1ea20c225..7687e28aa915 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -864,6 +864,7 @@ def checkGraphModeFxOp( qconfig_dict = custom_qconfig_dict prepared = prepare( model, qconfig_dict, + example_inputs=inputs, prepare_custom_config_dict=prepare_custom_config_dict, backend_config_dict=backend_config_dict) if not quant_type == QuantType.DYNAMIC:
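For callers of the public API, the migration is mechanical: build an `example_inputs` tuple and pass it to `prepare_fx`/`prepare_qat_fx`. A minimal sketch of the new calling convention (the model and input shape are placeholders):

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 1)).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
example_inputs = (torch.randn(1, 3, 224, 224),)

# before this change: prepared = prepare_fx(float_model, qconfig_dict)
prepared = prepare_fx(float_model, qconfig_dict, example_inputs=example_inputs)
prepared(*example_inputs)        # calibration pass
quantized = convert_fx(prepared)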