diff --git a/examples/cpu/inference/python/llm/single_instance/run_accuracy.py b/examples/cpu/inference/python/llm/single_instance/run_accuracy.py index bf71f9d91..adb06d970 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_accuracy.py +++ b/examples/cpu/inference/python/llm/single_instance/run_accuracy.py @@ -124,13 +124,27 @@ def __init__( config, torchscript=with_jit, trust_remote_code=True ) - self.model = model_class[0].from_pretrained( - model_id, - low_cpu_mem_usage=True, - config=self.config, - torch_dtype=load_dtype, - trust_remote_code=True, - ) + if self._dtype == "int8": + try: + with ipex.OnDevice(dtype=torch.float, device="meta"): + self.model = AutoModelForCausalLM.from_config(self.config) + except (RuntimeError, AttributeError) as e: + print('Warning: Loading model to meta device failed:', e) + self.model = model_class[0].from_pretrained( + model_id, + low_cpu_mem_usage=True, + config=self.config, + torch_dtype=load_dtype, + trust_remote_code=True, + ) + else: + self.model = model_class[0].from_pretrained( + model_id, + low_cpu_mem_usage=True, + config=self.config, + torch_dtype=load_dtype, + trust_remote_code=True, + ) self.model = self.model.eval()