From ff1725771587a691a9841641cfc24a7fe47ba234 Mon Sep 17 00:00:00 2001
From: Chang Wang
Date: Tue, 20 Dec 2022 20:05:15 +0800
Subject: [PATCH] Fix DLRM OOM issue (#299)

Signed-off-by: changwa1
---
 .../ptq/eager/dlrm_s_pytorch_tune.py            |  2 +-
 .../quantization/ptq/fx/dlrm_s_pytorch_tune.py  |  2 +-
 .../quantization/ptq/ipex/dlrm_s_pytorch.py     |  2 +-
 neural_compressor/adaptor/pytorch.py            | 10 +++++++++-
 neural_compressor/model/torch_model.py          | 17 +++++++++++------
 5 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/examples/pytorch/recommendation/dlrm/quantization/ptq/eager/dlrm_s_pytorch_tune.py b/examples/pytorch/recommendation/dlrm/quantization/ptq/eager/dlrm_s_pytorch_tune.py
index 1ae0a3544ce..90ecf9659ee 100644
--- a/examples/pytorch/recommendation/dlrm/quantization/ptq/eager/dlrm_s_pytorch_tune.py
+++ b/examples/pytorch/recommendation/dlrm/quantization/ptq/eager/dlrm_s_pytorch_tune.py
@@ -842,7 +842,7 @@ def loss_fn_wrap(Z, T, use_gpu, device):
             args.print_freq = ld_nbatches
             args.test_freq = 0
 
-        del ld_model
+        del(ld_model)
 
         print(
             "Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format(
diff --git a/examples/pytorch/recommendation/dlrm/quantization/ptq/fx/dlrm_s_pytorch_tune.py b/examples/pytorch/recommendation/dlrm/quantization/ptq/fx/dlrm_s_pytorch_tune.py
index d8a4169960e..e279a03fc3e 100644
--- a/examples/pytorch/recommendation/dlrm/quantization/ptq/fx/dlrm_s_pytorch_tune.py
+++ b/examples/pytorch/recommendation/dlrm/quantization/ptq/fx/dlrm_s_pytorch_tune.py
@@ -847,7 +847,7 @@ def loss_fn_wrap(Z, T, use_gpu, device):
             args.print_freq = ld_nbatches
             args.test_freq = 0
 
-        del ld_model
+        del(ld_model)
 
         print(
             "Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format(
diff --git a/examples/pytorch/recommendation/dlrm/quantization/ptq/ipex/dlrm_s_pytorch.py b/examples/pytorch/recommendation/dlrm/quantization/ptq/ipex/dlrm_s_pytorch.py
index d78e7ac209c..605068273b6 100644
--- a/examples/pytorch/recommendation/dlrm/quantization/ptq/ipex/dlrm_s_pytorch.py
+++ b/examples/pytorch/recommendation/dlrm/quantization/ptq/ipex/dlrm_s_pytorch.py
@@ -821,7 +821,7 @@ def run():
                 )
             )
             print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100))
-        del ld_model
+        del(ld_model)
 
     ext_dist.barrier()
     print("time/loss/accuracy (if enabled):")
diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index f32985a8290..b16f0976e42 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -2584,7 +2584,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
                 self.q_dataloader.batch(batch_size)
                 logger.info('Recovery `calibration.dataloader.batchsize` {} according \
                             to config.yaml' .format(batch_size))
-            del init_model
+            del(init_model)
         with open(self.ipex_config_path, 'r') as f:
             self.cfgs = json.load(f)
             if self.version.release < Version("1.12.0").release:
@@ -2776,6 +2776,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
         from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
         try:
             q_model = copy.deepcopy(model)
+            q_model.fp32_model = model.fp32_model
         except Exception as e:  # pragma: no cover
             logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
                 repr(e)))
@@ -2983,6 +2984,13 @@ def _pre_hook_for_qat(self, dataloader=None):
             # so set it to None.
             example_inputs = None
 
+        # For export API, deepcopy fp32_model
+        try:
+            self.model.fp32_model = copy.deepcopy(self.model.fp32_model)
+        except Exception as e:  # pragma: no cover
+            logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
+                repr(e)))
+
         if self.sub_module_list is None:
             if self.version.release >= Version("1.13.0").release:  # pragma: no cover
                 # pylint: disable=E1123
diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py
index 5560158358e..438f9846a00 100644
--- a/neural_compressor/model/torch_model.py
+++ b/neural_compressor/model/torch_model.py
@@ -45,12 +45,7 @@ def __init__(self, model, **kwargs):
         self.q_config = None
         self._workspace_path = ''
         self.is_quantized = False
-        try:
-            self.fp32_model = copy.deepcopy(model)
-        except Exception as e:  # pragma: no cover
-            logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
-                repr(e)))
-            self.fp32_model = model
+        self.fp32_model = model
         self.kwargs = kwargs if kwargs else None
 
     def __repr__(self):
@@ -93,6 +88,16 @@ def model(self, model):
         """ Setter to model """
         self._model = model
 
+    @property
+    def fp32_model(self):
+        """ Getter to model """
+        return self._fp32_model
+
+    @fp32_model.setter
+    def fp32_model(self, fp32_model):
+        """ Setter to model """
+        self._fp32_model = fp32_model
+
     def register_forward_pre_hook(self):
         self.handles.append(
             self._model.register_forward_pre_hook(self.generate_forward_pre_hook()))
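The idea behind the fix: the torch model wrapper no longer deep-copies the FP32 model in its __init__ (which kept two full copies of the model in memory and triggered the OOM on DLRM's large embedding tables). Instead, fp32_model becomes a property that initially just references the live model, and the deep copy is deferred to the spots that actually need a separate copy (quantize() and, for the export API, _pre_hook_for_qat), still falling back to in-place use when the copy fails. A minimal sketch of that pattern follows; ModelWrapper and snapshot_fp32 are illustrative names, not the library's API.

    import copy
    import logging

    logger = logging.getLogger(__name__)


    class ModelWrapper:
        """Illustrative stand-in for the torch model wrapper touched by this
        patch: it keeps a reference to the FP32 model instead of deep-copying
        it in __init__, so only one copy of the weights lives in memory."""

        def __init__(self, model):
            self.fp32_model = model  # reference, not copy.deepcopy(model)

        @property
        def fp32_model(self):
            return self._fp32_model

        @fp32_model.setter
        def fp32_model(self, fp32_model):
            self._fp32_model = fp32_model


    def snapshot_fp32(wrapper):
        """Deep-copy the FP32 model only where a separate copy is actually
        needed (the patch does this in quantize() and _pre_hook_for_qat),
        falling back to in-place use if the copy itself fails."""
        try:
            wrapper.fp32_model = copy.deepcopy(wrapper.fp32_model)
        except Exception as e:
            logger.warning("Fail to deep copy the model due to %s, inplace is used now.", repr(e))
        return wrapper

The design point is that the expensive copy now happens only on the quantization path, right before the wrapped model is mutated, instead of unconditionally every time a model is wrapped.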