From 4846183847d5e0ca22f1271435a4b70f3a62e01e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 20 May 2022 15:21:47 -0700 Subject: [PATCH] Add required example_inputs argument to prepare_fx and prepare_qat_fx (#77608) Summary: X-link: https://github.com/pytorch/pytorch/pull/77608 X-link: https://github.com/pytorch/fx2trt/pull/76 X-link: https://github.com/facebookresearch/d2go/pull/249 X-link: https://github.com/fairinternal/ClassyVision/pull/104 X-link: https://github.com/pytorch/benchmark/pull/916 Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/791 X-link: https://github.com/facebookresearch/mobile-vision/pull/68 FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert observer/fake_quantize module or not, since we only insert observer/fake_quantize module for floating point Tensors. Currently we have some hacks to support this by defining some rules like NON_OBSERVABLE_ARG_DICT (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base. As we discussed in the design review, we'd need to ask users to provide sample args and sample keyword args so that we can infer the type in a more robust way. This PR starts with changing the prepare_fx and prepare_qat_fx api to require user to either provide example arguments thrugh example_inputs, Note this api doesn't support kwargs, kwargs can make https://github.com/pytorch/pytorch/pull/76496#discussion_r861230047 (comment) simpler, but it will be rare, and even then we can still workaround with positional arguments, also torch.jit.trace(https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp: https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140 just have single positional args, we'll just use a single example_inputs argument for now. If needed, we can extend the api with an optional example_kwargs. e.g. in case when there are a lot of arguments for forward and it makes more sense to pass the arguments by keyword BC-breaking Note: Before: ```python m = resnet18(...) m = prepare_fx(m, qconfig_dict) # or m = prepare_qat_fx(m, qconfig_dict) ``` After: ```python m = resnet18(...) m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) # or m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) ``` Reviewed By: vkuzo, andrewor14 Differential Revision: D35984526 fbshipit-source-id: 68861512ebb3dd1ae6e069a144c30eee2d63e5da --- test/models_densenet_test.py | 17 ++++++++++++++--- test/models_mlp_test.py | 15 ++++++--------- test/models_regnet_test.py | 17 ++++++++--------- test/models_resnext_test.py | 27 +++++++++++++++++++-------- 4 files changed, 47 insertions(+), 29 deletions(-) diff --git a/test/models_densenet_test.py b/test/models_densenet_test.py index 72132d807..d8395e18a 100644 --- a/test/models_densenet_test.py +++ b/test/models_densenet_test.py @@ -113,21 +113,32 @@ def _test_quantize_model(self, model_config): _find_block_full_path(model.features, block_name) for block_name in heads.keys() ] + # TODO[quant-example-inputs]: The dimension here is random, if we need to + # use dimension/rank in the future we'd need to get the correct dimensions + standalone_example_inputs = (torch.randn(1, 3, 3, 3),) # we need to keep the modules used in head standalone since # it will be accessed with path name directly in execution prepare_custom_config_dict["standalone_module_name"] = [ ( head, {"": tq.default_qconfig}, + standalone_example_inputs, {"input_quantized_idxs": [0], "output_quantized_idxs": []}, None, ) for head in head_path_from_blocks ] - model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig}) + # TODO[quant-example-inputs]: The dimension here is random, if we need to + # use dimension/rank in the future we'd need to get the correct dimensions + example_inputs = (torch.randn(1, 3, 3, 3),) + model.initial_block = prepare_fx( + model.initial_block, {"": tq.default_qconfig}, example_inputs + ) + model.features = prepare_fx( model.features, {"": tq.default_qconfig}, + example_inputs, prepare_custom_config_dict, ) model.set_heads(heads) @@ -148,8 +159,8 @@ def test_small_densenet(self): self._test_model(MODELS["small_densenet"]) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantized_small_densenet(self): self._test_quantize_model(MODELS["small_densenet"]) diff --git a/test/models_mlp_test.py b/test/models_mlp_test.py index 293a84835..2a1bfa1d4 100644 --- a/test/models_mlp_test.py +++ b/test/models_mlp_test.py @@ -26,23 +26,20 @@ def test_build_model(self): self.assertEqual(output.shape, torch.Size([2, 1])) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantize_model(self): - if get_torch_version() >= [1, 11]: - import torch.ao.quantization as tq - from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx - else: - import torch.quantization as tq - from torch.quantization.quantize_fx import convert_fx, prepare_fx + import torch.ao.quantization as tq + from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx config = {"name": "mlp", "input_dim": 3, "output_dim": 1, "hidden_dims": [2]} model = build_model(config) self.assertTrue(isinstance(model, ClassyModel)) model.eval() - model.mlp = prepare_fx(model.mlp, {"": tq.default_qconfig}) + example_inputs = (torch.rand(1, 3),) + model.mlp = prepare_fx(model.mlp, {"": tq.default_qconfig}, example_inputs) model.mlp = convert_fx(model.mlp) tensor = torch.tensor([[1, 2, 3]], dtype=torch.float) diff --git a/test/models_regnet_test.py b/test/models_regnet_test.py index b8090f0fb..53238d765 100644 --- a/test/models_regnet_test.py +++ b/test/models_regnet_test.py @@ -169,19 +169,18 @@ def test_quantize_model(self, config): Test that the model builds using a config using either model_params or model_name and calls fx graph mode quantization apis """ - if get_torch_version() < [1, 8]: - self.skipTest("FX Graph Modee Quantization is only availablee from 1.8") - if get_torch_version() >= [1, 11]: - import torch.ao.quantization as tq - from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx - else: - import torch.quantization as tq - from torch.quantization.quantize_fx import convert_fx, prepare_fx + if get_torch_version() < [1, 13]: + self.skipTest( + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13" + ) + import torch.ao.quantization as tq + from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx model = build_model(config) assert isinstance(model, RegNet) model.eval() - model.stem = prepare_fx(model.stem, {"": tq.default_qconfig}) + example_inputs = (torch.rand(1, 3, 3, 3),) + model.stem = prepare_fx(model.stem, {"": tq.default_qconfig}, example_inputs) model.stem = convert_fx(model.stem) diff --git a/test/models_resnext_test.py b/test/models_resnext_test.py index a2b660a67..f98d745c0 100644 --- a/test/models_resnext_test.py +++ b/test/models_resnext_test.py @@ -107,18 +107,29 @@ def _post_training_quantize(model, input): ] # we need to keep the modules used in head standalone since # it will be accessed with path name directly in execution + # TODO[quant-example-inputs]: Fix the shape if it is needed in quantization + standalone_example_inputs = (torch.rand(1, 3, 3, 3),) prepare_custom_config_dict["standalone_module_name"] = [ ( head, {"": tq.default_qconfig}, + standalone_example_inputs, {"input_quantized_idxs": [0], "output_quantized_idxs": []}, None, ) for head in head_path_from_blocks ] - model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig}) + # TODO[quant-example-inputs]: Fix the shape if it is needed in quantization + example_inputs = (torch.rand(1, 3, 3, 3),) + model.initial_block = prepare_fx( + model.initial_block, {"": tq.default_qconfig}, example_inputs + ) + model.blocks = prepare_fx( - model.blocks, {"": tq.default_qconfig}, prepare_custom_config_dict + model.blocks, + {"": tq.default_qconfig}, + example_inputs, + prepare_custom_config_dict, ) model.set_heads(heads) @@ -222,8 +233,8 @@ def test_small_resnext(self): self._test_model(MODELS["small_resnext"]) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantized_small_resnext(self): self._test_quantize_model(MODELS["small_resnext"]) @@ -232,8 +243,8 @@ def test_small_resnet(self): self._test_model(MODELS["small_resnet"]) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantized_small_resnet(self): self._test_quantize_model(MODELS["small_resnet"]) @@ -242,8 +253,8 @@ def test_small_resnet_se(self): self._test_model(MODELS["small_resnet_se"]) @unittest.skipIf( - get_torch_version() < [1, 8], - "FX Graph Modee Quantization is only availablee from 1.8", + get_torch_version() < [1, 13], + "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13", ) def test_quantized_small_resnet_se(self): self._test_quantize_model(MODELS["small_resnet_se"])