We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a0f308d commit 0947688Copy full SHA for 0947688
src/pipelines/hf.py
@@ -9,10 +9,11 @@ class HF_Pipeline(Pipeline):
9
def __init__(self, args: Namespace, device: str = "cpu") -> None:
10
super().__init__(args)
11
12
- model_kwargs = {"device_map": "auto"}
+ model_kwargs = {}
13
14
if args.dtype == torch.int8:
15
model_kwargs["load_in_8bit"] = True
16
+ model_kwargs["device_map"] = "auto"
17
else:
18
model_kwargs["torch_dtype"] = args.dtype
19
0 commit comments