Fixed bf16 error in QAT for torch version < 1.11 (#279)
Signed-off-by: Cheng, Penghui <penghui.cheng@intel.com>
PenghuiCheng committed Dec 15, 2022
1 parent 0878bea · commit eda8cb7
neural_compressor/adaptor/pytorch.py (2 additions, 2 deletions)
@@ -1391,10 +1391,10 @@ def _pre_hook_for_qat(self, dataloader=None):
         self.non_quant_dict = self.get_non_quant_modules(self.model.kwargs)
         quantizable_ops = []
         self._get_quantizable_ops_recursively(self.model._model, '', quantizable_ops)
-        self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
         bf16_ops = []
         if self.version.release >= Version("1.11.0").release and self.use_bf16 and \
            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
             self._get_bf16_ops_recursively(self.model._model, '', bf16_ops)
         bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops]
         self.model.model.training = True
@@ -2949,10 +2949,10 @@ def _pre_hook_for_qat(self, dataloader=None):
         quantizable_ops = []
         tmp_model = self.fuse_fx_model(self.model, is_qat=True)
         self._get_quantizable_ops_recursively(tmp_model, '', quantizable_ops)
-        self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
         bf16_ops = []
         if self.version.release >= Version("1.11.0").release and self.use_bf16 and \
            (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1'):  # pragma: no cover
+            self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
             self._get_bf16_ops_recursively(tmp_model, '', bf16_ops)
         bf16_ops_list = [(op) for op in bf16_ops if op not in quantizable_ops]
         quantized_ops = OrderedDict()
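In both hunks the fix moves the query_handler.get_op_types_by_precision("bf16") call inside the existing version guard, so self.bf16_ops stays empty on torch builds older than 1.11. Before the change, self.bf16_ops was populated unconditionally even though the recursive bf16-op walk was already guarded, which is how the bf16 error surfaced on older torch versions during QAT. The sketch below shows the same guard pattern in isolation; it is a minimal standalone approximation, and names such as query_handler and supports_bf16_hardware are hypothetical stand-ins rather than the neural_compressor API.

import os

import torch
from packaging.version import Version

TORCH_BF16_MIN_VERSION = Version("1.11.0")

def collect_bf16_ops(query_handler, use_bf16, supports_bf16_hardware):
    """Return bf16-capable op types, or an empty list on unsupported torch.

    `query_handler` and `supports_bf16_hardware` are hypothetical stand-ins
    for the adaptor's query object and a CPU capability probe.
    """
    # Strip any local version tag such as "+cpu" before comparing releases.
    torch_version = Version(torch.__version__.split("+")[0])
    if (torch_version.release >= TORCH_BF16_MIN_VERSION.release
            and use_bf16
            and (supports_bf16_hardware or os.getenv("FORCE_BF16") == "1")):
        # torch >= 1.11: safe to ask the backend which op types support bf16.
        return query_handler.get_op_types_by_precision("bf16")
    # torch < 1.11 (or bf16 disabled): keep the list empty so QAT never
    # attempts a bf16 conversion on an unsupported runtime.
    return []

Gating the query itself, rather than filtering its result afterwards, keeps every bf16 code path behind a single version check, which is the design choice the commit makes.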
