Skip to content

Commit 0947688

Browse files
committed
device map
1 parent a0f308d commit 0947688

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/pipelines/hf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ class HF_Pipeline(Pipeline):
99
def __init__(self, args: Namespace, device: str = "cpu") -> None:
1010
super().__init__(args)
1111

12-
model_kwargs = {"device_map": "auto"}
12+
model_kwargs = {}
1313

1414
if args.dtype == torch.int8:
1515
model_kwargs["load_in_8bit"] = True
16+
model_kwargs["device_map"] = "auto"
1617
else:
1718
model_kwargs["torch_dtype"] = args.dtype
1819

0 commit comments

Comments
 (0)