From f7c21aaf19412aebc093065391c2a6f538ba58dd Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 4 Apr 2022 22:46:20 -0400 Subject: [PATCH] [quantization-refactor] mark/propagate conv export mode --- .../quantization/modifier_quantization.py | 6 ++++++ src/sparseml/pytorch/utils/exporter.py | 10 +++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 210e57df76b..7c2046a640f 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -673,6 +673,12 @@ def _enable_module_qat(self, module: Module): self._qat_enabled = True self._calibrate_if_possible(module) + # mark export mode for module Conv layers + module.export_with_qlinearconv = self._quantize_conv_activations + if hasattr(module, "module"): + # for DP/DDP unwrapping + module.module.export_with_qlinearconv = self._quantize_conv_activations + def _calibrate_if_possible(self, module): if self.num_calibration_steps == 0 and self._calibration_dataloader: warnings.warn( diff --git a/src/sparseml/pytorch/utils/exporter.py b/src/sparseml/pytorch/utils/exporter.py index b00ba190cda..a987287aba7 100644 --- a/src/sparseml/pytorch/utils/exporter.py +++ b/src/sparseml/pytorch/utils/exporter.py @@ -498,7 +498,15 @@ def export_onnx( quantize_torch_qat_export, ) - quantize_torch_qat_export(model=file_path, output_file_path=file_path) + use_qlinearconv = hasattr(module, "export_with_qlinearconv") and ( + module.export_with_qlinearconv + ) + + quantize_torch_qat_export( + model=file_path, + output_file_path=file_path, + use_qlinearconv=use_qlinearconv, + ) if skip_input_quantize: try: