From 659d5014084319286c0e5b57531253cc8c97e987 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 21 May 2022 13:49:52 -0700 Subject: [PATCH] Add required example_inputs argument to prepare_fx and prepare_qat_fx Summary: X-link: https://github.com/pytorch/pytorch/pull/77608 X-link: https://github.com/pytorch/fx2trt/pull/76 X-link: https://github.com/facebookresearch/d2go/pull/249 X-link: https://github.com/fairinternal/ClassyVision/pull/104 X-link: https://github.com/pytorch/benchmark/pull/916 X-link: https://github.com/facebookresearch/ClassyVision/pull/791 Pull Request resolved: https://github.com/facebookresearch/mobile-vision/pull/68 FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert observer/fake_quantize module or not, since we only insert observer/fake_quantize module for floating point Tensors. Currently we have some hacks to support this by defining some rules like NON_OBSERVABLE_ARG_DICT (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base. As we discussed in the design review, we'd need to ask users to provide sample args and sample keyword args so that we can infer the type in a more robust way. This PR starts with changing the prepare_fx and prepare_qat_fx api to require user to either provide example arguments thrugh example_inputs, Note this api doesn't support kwargs, kwargs can make https://github.com/pytorch/pytorch/pull/76496#discussion_r861230047 (comment) simpler, but it will be rare, and even then we can still workaround with positional arguments, also torch.jit.trace(https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp: https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140 just have single positional args, we'll just use a single example_inputs argument for now. If needed, we can extend the api with an optional example_kwargs. e.g. in case when there are a lot of arguments for forward and it makes more sense to pass the arguments by keyword BC-breaking Note: Before: ```python m = resnet18(...) m = prepare_fx(m, qconfig_dict) # or m = prepare_qat_fx(m, qconfig_dict) ``` After: ```python m = resnet18(...) m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) # or m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) ``` Reviewed By: vkuzo, andrewor14 Differential Revision: D35984526 fbshipit-source-id: 706c8df71722c9aa5082a6491734f0144f0dd670 --- mobile_cv/arch/tests/test_fbnet_v2_quantize.py | 6 +++++- mobile_cv/arch/tests/test_fbnet_v2_res_block.py | 3 ++- mobile_cv/arch/tests/test_utils_quantize_utils.py | 3 ++- mobile_cv/arch/utils/quantize_utils.py | 6 ++++-- mobile_cv/model_zoo/tools/create_model.py | 3 ++- mobile_cv/model_zoo/tools/model_exporter.py | 6 +++++- 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/mobile_cv/arch/tests/test_fbnet_v2_quantize.py b/mobile_cv/arch/tests/test_fbnet_v2_quantize.py index a197e20b..335f323b 100644 --- a/mobile_cv/arch/tests/test_fbnet_v2_quantize.py +++ b/mobile_cv/arch/tests/test_fbnet_v2_quantize.py @@ -86,7 +86,11 @@ def test_qat(self): model.train() qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig("fbgemm")} - model_prepared = quantize_fx.prepare_qat_fx(model, qconfig_dict) + example_inputs = (torch.rand(2, 3, 8, 8),) + model_prepared = quantize_fx.prepare_qat_fx( + model, qconfig_dict, example_inputs=example_inputs + ) + print(f"Prepared model {model_prepared}") # calibration diff --git a/mobile_cv/arch/tests/test_fbnet_v2_res_block.py b/mobile_cv/arch/tests/test_fbnet_v2_res_block.py index 789cee88..b91237d3 100644 --- a/mobile_cv/arch/tests/test_fbnet_v2_res_block.py +++ b/mobile_cv/arch/tests/test_fbnet_v2_res_block.py @@ -94,7 +94,8 @@ def test_res_block_quantize_partial(self): data = torch.zeros(1, 8, 4, 4) qconfig_dict = qu.get_qconfig_dict(model, qconfig) - model = prepare_fx(model, qconfig_dict) + example_inputs = (data,) + model = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) model = convert_fx(model) print(model) diff --git a/mobile_cv/arch/tests/test_utils_quantize_utils.py b/mobile_cv/arch/tests/test_utils_quantize_utils.py index 415e4400..02bab121 100644 --- a/mobile_cv/arch/tests/test_utils_quantize_utils.py +++ b/mobile_cv/arch/tests/test_utils_quantize_utils.py @@ -361,7 +361,8 @@ def forward(self, x): model = MM().eval() qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") qconfig_dict = qu.get_qconfig_dict(model, qconfig) - model = prepare_fx(model, qconfig_dict) + example_inputs = (torch.rand(1, 1, 3, 3),) + model = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) model = convert_fx(model) print(model) diff --git a/mobile_cv/arch/utils/quantize_utils.py b/mobile_cv/arch/utils/quantize_utils.py index 2cefdde2..4176093b 100644 --- a/mobile_cv/arch/utils/quantize_utils.py +++ b/mobile_cv/arch/utils/quantize_utils.py @@ -186,13 +186,15 @@ def set_quant_config(self, quant_cfg): self.qconfig = quant_cfg return self - def prepare(self, qconfig_dict=None): + def prepare(self, example_inputs, qconfig_dict=None): if qconfig_dict is None: qconfig_dict = get_qconfig_dict(self.model, self.qconfig) if qconfig_dict is None: qconfig_dict = {"": self.qconfig} self._prepared_model = torch.ao.quantization.quantize_fx.prepare_fx( - self.model, qconfig_dict + self.model, + qconfig_dict=qconfig_dict, + example_inputs=example_inputs, ) return self diff --git a/mobile_cv/model_zoo/tools/create_model.py b/mobile_cv/model_zoo/tools/create_model.py index f31debe8..83e5edba 100644 --- a/mobile_cv/model_zoo/tools/create_model.py +++ b/mobile_cv/model_zoo/tools/create_model.py @@ -140,9 +140,10 @@ def convert_int8_jit(args, model, data, folder_name="int8_jit"): ) else: quant = qu.PostQuantizationFX(model) + example_inputs = tuple(data) quant_model = ( quant.set_quant_backend("default") - .prepare() + .prepare(example_inputs=example_inputs) .calibrate_model([data], 1) .convert_model() ) diff --git a/mobile_cv/model_zoo/tools/model_exporter.py b/mobile_cv/model_zoo/tools/model_exporter.py index 8abeccb5..a28537e1 100644 --- a/mobile_cv/model_zoo/tools/model_exporter.py +++ b/mobile_cv/model_zoo/tools/model_exporter.py @@ -309,19 +309,23 @@ def export_to_torchscript_int8( ): cur_loader = itertools.chain([inputs], data_iter) + example_inputs = tuple(inputs) if hasattr(task, "get_quantized_model"): + print("calling get quantized model") ptq_model = task.get_quantized_model(model, cur_loader) model_attrs = _get_model_attributes(ptq_model) + print("after calling get quantized model") elif args.use_graph_mode_quant: print(f"Post quantization using {args.post_quant_backend} backend fx mode...") model_attrs = _get_model_attributes(model) quant = quantize_utils.PostQuantizationFX(model) ptq_model = ( quant.set_quant_backend(args.post_quant_backend) - .prepare() + .prepare(example_inputs=example_inputs) .calibrate_model(cur_loader, 1) .convert_model() ) + print("after calling callback") else: print(f"Post quantization using {args.post_quant_backend} backend...") qa_model = task.get_quantizable_model(model)