From 0cd25021952bddcf5a364da45dfbefd4a0c77af4 Mon Sep 17 00:00:00 2001 From: Xia Weiwen Date: Wed, 25 Oct 2023 15:28:26 +0800 Subject: [PATCH] llm example run_accuracy.py: load model to meta device for quantization (#2195) * llm example run_accuracy.py: load model to meta device for quantization * Add more print * Print exception * Fix typo ipex._IPEXOnDevice -> ipex.IPEXOnDevice * ipex.IPEXOnDevice -> ipex.OnDevice * Fix typo _from_config -> from_config * Remove mem usage print --- .../llm/single_instance/run_accuracy.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/examples/cpu/inference/python/llm/single_instance/run_accuracy.py b/examples/cpu/inference/python/llm/single_instance/run_accuracy.py index bf71f9d91..adb06d970 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_accuracy.py +++ b/examples/cpu/inference/python/llm/single_instance/run_accuracy.py @@ -124,13 +124,27 @@ def __init__( config, torchscript=with_jit, trust_remote_code=True ) - self.model = model_class[0].from_pretrained( - model_id, - low_cpu_mem_usage=True, - config=self.config, - torch_dtype=load_dtype, - trust_remote_code=True, - ) + if self._dtype == "int8": + try: + with ipex.OnDevice(dtype=torch.float, device="meta"): + self.model = AutoModelForCausalLM.from_config(self.config) + except (RuntimeError, AttributeError) as e: + print('Warning: Loading model to meta device failed:', e) + self.model = model_class[0].from_pretrained( + model_id, + low_cpu_mem_usage=True, + config=self.config, + torch_dtype=load_dtype, + trust_remote_code=True, + ) + else: + self.model = model_class[0].from_pretrained( + model_id, + low_cpu_mem_usage=True, + config=self.config, + torch_dtype=load_dtype, + trust_remote_code=True, + ) self.model = self.model.eval()