Skip to content

Commit

Permalink
Fixed PyTorch QAT bug for imagenet model and ssd model. (#1385)
Browse files Browse the repository at this point in the history
  • Loading branch information
PenghuiCheng committed Oct 25, 2022
1 parent faa131b commit 97c9466
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 1 deletion.
Expand Up @@ -195,6 +195,7 @@ def training_func_for_nc(model):
quantizer = Quantization(args.config)
quantizer.model = common.Model(model)
quantizer.q_func = training_func_for_nc
quantizer.calib_dataloader = val_loader
quantizer.eval_dataloader = val_loader
q_model = quantizer.fit()
q_model.save(args.tuned_checkpoint)
Expand Down
Expand Up @@ -399,6 +399,7 @@ def training_func_for_nc(model):
quantizer.model = common.Model(ssd300)
quantizer.eval_func = eval_func
quantizer.q_func = training_func_for_nc
quantizer.calib_dataloader = val_dataloader
q_model = quantizer.fit()
q_model.save(args.tuned_checkpoint)

Expand Down
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/pytorch.py
Expand Up @@ -2606,7 +2606,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
self.tune_cfg["approach"] = self.approach
self.tune_cfg["framework"] = "pytorch_fx"
# pragma: no cover
if self.approach != 'post_training_dynamic_quant' and self.version > Version("1.12.1"):
if self.approach != 'post_training_dynamic_quant' and self.version.release >= Version("1.13.0").release:
assert dataloader is not None, "Please pass a dataloader to quantizer!"
example_inputs = get_example_inputs(dataloader)
else:
Expand Down

0 comments on commit 97c9466

Please sign in to comment.