Skip to content

Commit f98ccf4

Browse files
Add TPU support (#629)
* tpu support 1 * change package name * run format * add torch_xla dependency * run poetry lock --no-update * Delete libs/infinity_emb/poetry.lock * Update pyproject.toml * Create poetry.lock --------- Co-authored-by: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
1 parent ff80951 commit f98ccf4

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

libs/infinity_emb/infinity_emb/_optional_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,4 @@ def _raise_error(self) -> None:
7878
CHECK_TRANSFORMERS = OptionalImports("transformers", "torch")
7979
CHECK_TYPER = OptionalImports("typer", "server")
8080
CHECK_UVICORN = OptionalImports("uvicorn", "server")
81+
CHECK_XLA = OptionalImports("torch_xla", "torch_xla")

libs/infinity_emb/infinity_emb/inference/loading_strategy.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from infinity_emb._optional_imports import CHECK_TORCH, CHECK_TRANSFORMERS
1+
from infinity_emb._optional_imports import CHECK_TORCH, CHECK_TRANSFORMERS, CHECK_XLA
22
from infinity_emb.args import EngineArgs
33
from infinity_emb.primitives import InferenceEngine, Device, Dtype, DeviceID, LoadingStrategy
44

@@ -7,6 +7,10 @@
77
import torch
88
if CHECK_TRANSFORMERS.is_available:
99
from transformers import is_torch_npu_available # type: ignore
10+
from transformers.utils.import_utils import is_torch_xla_available # type: ignore
11+
12+
if CHECK_XLA.is_available:
13+
import torch_xla # type: ignore
1014

1115

1216
def _validate_availale_device_ids(
@@ -35,6 +39,8 @@ def get_loading_strategy_torch(args: EngineArgs) -> LoadingStrategy:
3539
autodevice = "npu"
3640
elif torch.backends.mps.is_available():
3741
autodevice = "mps"
42+
elif is_torch_xla_available():
43+
autodevice = "xla"
3844
else:
3945
autodevice = "cpu"
4046
else:
@@ -58,6 +64,10 @@ def get_loading_strategy_torch(args: EngineArgs) -> LoadingStrategy:
5864
elif autodevice == "cpu":
5965
# spawn multiple processes on CPU. This is useful for debugging, but not for performance.
6066
autodevice_string = ["cpu"] * max(len(args.device_id), 1)
67+
elif autodevice == "xla":
68+
autodevice_string = _validate_availale_device_ids(
69+
"xla", list(range(torch_xla.device_count())), args.device_id
70+
)
6171
else:
6272
raise ValueError(f"Unknown device {autodevice}")
6373

libs/infinity_emb/infinity_emb/primitives.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class Device(EnumType):
109109
cuda = "cuda"
110110
mps = "mps"
111111
tensorrt = "tensorrt"
112+
xla = "xla"
112113
auto = "auto"
113114

114115
@staticmethod

0 commit comments

Comments (0)