1
- from infinity_emb ._optional_imports import CHECK_TORCH , CHECK_TRANSFORMERS
1
+ from infinity_emb ._optional_imports import CHECK_TORCH , CHECK_TRANSFORMERS , CHECK_XLA
2
2
from infinity_emb .args import EngineArgs
3
3
from infinity_emb .primitives import InferenceEngine , Device , Dtype , DeviceID , LoadingStrategy
4
4
7
7
import torch
8
8
if CHECK_TRANSFORMERS .is_available :
9
9
from transformers import is_torch_npu_available # type: ignore
10
+ from transformers .utils .import_utils import is_torch_xla_available # type: ignore
11
+
12
+ if CHECK_XLA .is_available :
13
+ import torch_xla # type: ignore
10
14
11
15
12
16
def _validate_availale_device_ids (
@@ -35,6 +39,8 @@ def get_loading_strategy_torch(args: EngineArgs) -> LoadingStrategy:
35
39
autodevice = "npu"
36
40
elif torch .backends .mps .is_available ():
37
41
autodevice = "mps"
42
+ elif is_torch_xla_available ():
43
+ autodevice = "xla"
38
44
else :
39
45
autodevice = "cpu"
40
46
else :
@@ -58,6 +64,10 @@ def get_loading_strategy_torch(args: EngineArgs) -> LoadingStrategy:
58
64
elif autodevice == "cpu" :
59
65
# Spawn one CPU process per entry in args.device_id (at least one). This is useful for debugging, but not for performance.
60
66
autodevice_string = ["cpu" ] * max (len (args .device_id ), 1 )
67
+ elif autodevice == "xla" :
68
+ autodevice_string = _validate_availale_device_ids (
69
+ "xla" , list (range (torch_xla .device_count ())), args .device_id
70
+ )
61
71
else :
62
72
raise ValueError (f"Unknown device { autodevice } " )
63
73
0 commit comments