diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index 930a14643..1153f1556 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -204,16 +204,12 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
         for element in ["load_in_8bit", "low_cpu_mem_usage"]:
             if element in kwargs:
                 use_pipeline = False
+        device = None if "device_map" in kwargs else self.device_id
         # build pipeline
         if use_pipeline:
-            if "device_map" in kwargs:
                 hf_pipeline = pipeline(task=task,
                                        model=model_id_or_path,
-                                       **kwargs)
-            else:
-                hf_pipeline = pipeline(task=task,
-                                       model=model_id_or_path,
-                                       device=self.device_id,
+                                       device=device,
                                        **kwargs)
         else:
             tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
@@ -223,7 +219,7 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
             hf_pipeline = pipeline(task=task,
                                    model=model,
                                    tokenizer=tokenizer,
-                                   device=self.device_id)
+                                   device=device)
 
         # wrap specific pipeline to support better ux
         if task == "conversational":