Skip to content

Commit

Permalink
fix huggingface device bugs (#813)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jun 8, 2023
1 parent b28d251 commit 4cda39a
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,12 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
for element in ["load_in_8bit", "low_cpu_mem_usage"]:
if element in kwargs:
use_pipeline = False
device = None if "device_map" in kwargs else self.device_id
# build pipeline
if use_pipeline:
if "device_map" in kwargs:
hf_pipeline = pipeline(task=task,
model=model_id_or_path,
**kwargs)
else:
hf_pipeline = pipeline(task=task,
model=model_id_or_path,
device=self.device_id,
device=device
**kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
Expand All @@ -223,7 +219,7 @@ def get_pipeline(self, task: str, model_id_or_path: str, kwargs):
hf_pipeline = pipeline(task=task,
model=model,
tokenizer=tokenizer,
device=self.device_id)
device=device)

# wrap specific pipeline to support better ux
if task == "conversational":
Expand Down

0 comments on commit 4cda39a

Please sign in to comment.