From 97c9466ce4e5c9acaa55a727ed90e3d38bdf8bbc Mon Sep 17 00:00:00 2001 From: "Cheng, Penghui" Date: Tue, 25 Oct 2022 11:36:51 +0800 Subject: [PATCH] Fixed PyTorch QAT bug for imagenet model and ssd model. (#1385) --- .../torchvision_models/quantization/qat/fx/main.py | 1 + .../ssd_resnet34/quantization/qat/fx/ssd/main.py | 1 + neural_compressor/adaptor/pytorch.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/pytorch/image_recognition/torchvision_models/quantization/qat/fx/main.py b/examples/pytorch/image_recognition/torchvision_models/quantization/qat/fx/main.py index 34e59f89372..c7246d000ef 100644 --- a/examples/pytorch/image_recognition/torchvision_models/quantization/qat/fx/main.py +++ b/examples/pytorch/image_recognition/torchvision_models/quantization/qat/fx/main.py @@ -195,6 +195,7 @@ def training_func_for_nc(model): quantizer = Quantization(args.config) quantizer.model = common.Model(model) quantizer.q_func = training_func_for_nc + quantizer.calib_dataloader = val_loader quantizer.eval_dataloader = val_loader q_model = quantizer.fit() q_model.save(args.tuned_checkpoint) diff --git a/examples/pytorch/object_detection/ssd_resnet34/quantization/qat/fx/ssd/main.py b/examples/pytorch/object_detection/ssd_resnet34/quantization/qat/fx/ssd/main.py index ab9aabf839b..4c3492c12b4 100644 --- a/examples/pytorch/object_detection/ssd_resnet34/quantization/qat/fx/ssd/main.py +++ b/examples/pytorch/object_detection/ssd_resnet34/quantization/qat/fx/ssd/main.py @@ -399,6 +399,7 @@ def training_func_for_nc(model): quantizer.model = common.Model(ssd300) quantizer.eval_func = eval_func quantizer.q_func = training_func_for_nc + quantizer.calib_dataloader = val_dataloader q_model = quantizer.fit() q_model.save(args.tuned_checkpoint) diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 14fde3c46bc..c93bde3c7f0 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -2606,7 +2606,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): self.tune_cfg["approach"] = self.approach self.tune_cfg["framework"] = "pytorch_fx" # pragma: no cover - if self.approach != 'post_training_dynamic_quant' and self.version > Version("1.12.1"): + if self.approach != 'post_training_dynamic_quant' and self.version.release >= Version("1.13.0").release: assert dataloader is not None, "Please pass a dataloader to quantizer!" example_inputs = get_example_inputs(dataloader) else: