diff --git a/test/models_densenet_test.py b/test/models_densenet_test.py index 72132d807..d8395e18a 100644 --- a/test/models_densenet_test.py +++ b/test/models_densenet_test.py @@ -113,21 +113,32 @@ def _test_quantize_model(self, model_config): _find_block_full_path(model.features, block_name) for block_name in heads.keys() ] + # TODO[quant-example-inputs]: The dimension here is random, if we need to + # use dimension/rank in the future we'd need to get the correct dimensions + standalone_example_inputs = (torch.randn(1, 3, 3, 3),) # we need to keep the modules used in head standalone since # it will be accessed with path name directly in execution prepare_custom_config_dict["standalone_module_name"] = [ ( head, {"": tq.default_qconfig}, + standalone_example_inputs, {"input_quantized_idxs": [0], "output_quantized_idxs": []}, None, ) for head in head_path_from_blocks ] - model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig}) + # TODO[quant-example-inputs]: The dimension here is random, if we need to + # use dimension/rank in the future we'd need to get the correct dimensions + example_inputs = (torch.randn(1, 3, 3, 3),) + model.initial_block = prepare_fx( + model.initial_block, {"": tq.default_qconfig}, example_inputs + ) + model.features = prepare_fx( model.features, {"": tq.default_qconfig}, + example_inputs, prepare_custom_config_dict, ) model.set_heads(heads) @@ -148,8 +159,8 @@ def test_small_densenet(self): self._test_model(MODELS["small_densenet"]) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantized_small_densenet(self): self._test_quantize_model(MODELS["small_densenet"]) diff --git a/test/models_mlp_test.py b/test/models_mlp_test.py index 293a84835..2a1bfa1d4 100644 --- a/test/models_mlp_test.py +++ b/test/models_mlp_test.py @@ -26,23 +26,20 @@ def test_build_model(self): self.assertEqual(output.shape, torch.Size([2, 1])) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantize_model(self): - if get_torch_version() >= [1, 11]: - import torch.ao.quantization as tq - from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx - else: - import torch.quantization as tq - from torch.quantization.quantize_fx import convert_fx, prepare_fx + import torch.ao.quantization as tq + from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx config = {"name": "mlp", "input_dim": 3, "output_dim": 1, "hidden_dims": [2]} model = build_model(config) self.assertTrue(isinstance(model, ClassyModel)) model.eval() - model.mlp = prepare_fx(model.mlp, {"": tq.default_qconfig}) + example_inputs = (torch.rand(1, 3),) + model.mlp = prepare_fx(model.mlp, {"": tq.default_qconfig}, example_inputs) model.mlp = convert_fx(model.mlp) tensor = torch.tensor([[1, 2, 3]], dtype=torch.float) diff --git a/test/models_regnet_test.py b/test/models_regnet_test.py index b8090f0fb..53238d765 100644 --- a/test/models_regnet_test.py +++ b/test/models_regnet_test.py @@ -169,19 +169,18 @@ def test_quantize_model(self, config): Test that the model builds using a config using either model_params or model_name and calls fx graph mode quantization apis """ - if get_torch_version() < [1, 8]: - self.skipTest("FX Graph Modee Quantization is only availablee from 1.8") - if get_torch_version() >= [1, 11]: - import torch.ao.quantization as tq - from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx - else: - import torch.quantization as tq - from torch.quantization.quantize_fx import convert_fx, prepare_fx + if get_torch_version() < [1, 13]: + self.skipTest( + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13" + ) + import torch.ao.quantization as tq + from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx model = build_model(config) assert isinstance(model, RegNet) model.eval() - model.stem = prepare_fx(model.stem, {"": tq.default_qconfig}) + example_inputs = (torch.rand(1, 3, 3, 3),) + model.stem = prepare_fx(model.stem, {"": tq.default_qconfig}, example_inputs) model.stem = convert_fx(model.stem) diff --git a/test/models_resnext_test.py b/test/models_resnext_test.py index a2b660a67..f98d745c0 100644 --- a/test/models_resnext_test.py +++ b/test/models_resnext_test.py @@ -107,18 +107,29 @@ def _post_training_quantize(model, input): ] # we need to keep the modules used in head standalone since # it will be accessed with path name directly in execution + # TODO[quant-example-inputs]: Fix the shape if it is needed in quantization + standalone_example_inputs = (torch.rand(1, 3, 3, 3),) prepare_custom_config_dict["standalone_module_name"] = [ ( head, {"": tq.default_qconfig}, + standalone_example_inputs, {"input_quantized_idxs": [0], "output_quantized_idxs": []}, None, ) for head in head_path_from_blocks ] - model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig}) + # TODO[quant-example-inputs]: Fix the shape if it is needed in quantization + example_inputs = (torch.rand(1, 3, 3, 3),) + model.initial_block = prepare_fx( + model.initial_block, {"": tq.default_qconfig}, example_inputs + ) + model.blocks = prepare_fx( - model.blocks, {"": tq.default_qconfig}, prepare_custom_config_dict + model.blocks, + {"": tq.default_qconfig}, + example_inputs, + prepare_custom_config_dict, ) model.set_heads(heads) @@ -222,8 +233,8 @@ def test_small_resnext(self): self._test_model(MODELS["small_resnext"]) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantized_small_resnext(self): self._test_quantize_model(MODELS["small_resnext"]) @@ -232,8 +243,8 @@ def test_small_resnet(self): self._test_model(MODELS["small_resnet"]) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantized_small_resnet(self): self._test_quantize_model(MODELS["small_resnet"]) @@ -242,8 +253,8 @@ def test_small_resnet_se(self): self._test_model(MODELS["small_resnet_se"]) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantized_small_resnet_se(self): self._test_quantize_model(MODELS["small_resnet_se"])