diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
index 210e57df76b..7c2046a640f 100644
--- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
+++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
@@ -673,6 +673,12 @@ def _enable_module_qat(self, module: Module):
         self._qat_enabled = True
         self._calibrate_if_possible(module)
 
+        # mark export mode for module Conv layers
+        module.export_with_qlinearconv = self._quantize_conv_activations
+        if hasattr(module, "module"):
+            # for DP/DDP unwrapping
+            module.module.export_with_qlinearconv = self._quantize_conv_activations
+
     def _calibrate_if_possible(self, module):
         if self.num_calibration_steps == 0 and self._calibration_dataloader:
             warnings.warn(
diff --git a/src/sparseml/pytorch/utils/exporter.py b/src/sparseml/pytorch/utils/exporter.py
index b00ba190cda..a987287aba7 100644
--- a/src/sparseml/pytorch/utils/exporter.py
+++ b/src/sparseml/pytorch/utils/exporter.py
@@ -498,7 +498,15 @@ def export_onnx(
                 quantize_torch_qat_export,
             )
 
-            quantize_torch_qat_export(model=file_path, output_file_path=file_path)
+            use_qlinearconv = hasattr(module, "export_with_qlinearconv") and (
+                module.export_with_qlinearconv
+            )
+
+            quantize_torch_qat_export(
+                model=file_path,
+                output_file_path=file_path,
+                use_qlinearconv=use_qlinearconv,
+            )
 
             if skip_input_quantize:
                 try:
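The diff above wires a flag through two components: the quantization modifier stamps `export_with_qlinearconv` onto the module (and onto the inner module when the model is wrapped in DataParallel/DistributedDataParallel), and the exporter later reads that attribute to decide whether `quantize_torch_qat_export` should emit QLinearConv nodes. Below is a minimal standalone sketch of that attribute handshake, not SparseML's actual code; `DummyModel`, `mark_for_export`, and `export` are hypothetical stand-ins for illustration, while `export_with_qlinearconv` is the real attribute name from the diff.

```python
from torch import nn


class DummyModel(nn.Module):
    # hypothetical model with a Conv layer, standing in for a real QAT model
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


def mark_for_export(module: nn.Module, quantize_conv_activations: bool):
    # mirrors the modifier-side change: stamp the flag on the module, and
    # also on the wrapped inner module when running under DP/DDP, so the
    # flag survives unwrapping before export
    module.export_with_qlinearconv = quantize_conv_activations
    if hasattr(module, "module"):
        module.module.export_with_qlinearconv = quantize_conv_activations


def export(module: nn.Module):
    # mirrors the exporter-side change: default to False when the attribute
    # was never set (e.g. the model was not trained with this modifier)
    use_qlinearconv = getattr(module, "export_with_qlinearconv", False)
    print(f"exporting with use_qlinearconv={use_qlinearconv}")


model = nn.parallel.DataParallel(DummyModel())
mark_for_export(model, quantize_conv_activations=True)
export(model)         # wrapped model carries the flag
export(model.module)  # so does the inner model, via the hasattr branch
```

Note that the exporter's `hasattr(module, "export_with_qlinearconv") and module.export_with_qlinearconv` is truth-equivalent to `getattr(module, "export_with_qlinearconv", False)`, so models exported without this modifier fall back to the pre-existing export path.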