From 2d1152b52c24c63ffd227ec2e4eb6db6fbaaab49 Mon Sep 17 00:00:00 2001
From: notsyncing
Date: Sun, 28 Jan 2024 21:48:25 +0800
Subject: [PATCH] Fix XPU inference

Passing "xpu" as the device triggers the complaint "Device xpu is not
recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu'
and 'disk'". However, you cannot simply pass 0 as the device either: a
bare integer is treated as a CUDA device index, and torch then
complains that it is not compiled with CUDA enabled.

You will need safetensors >= 0.4.2 if using safetensors files.
---
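Reviewer note: below is a minimal standalone sketch of the device
normalization that the hunks in this patch apply at each call site.
`normalize_device` is a hypothetical helper written only for
illustration; it is not part of this patch or of accelerate's API.

    from accelerate.utils import is_npu_available, is_xpu_available

    def normalize_device(device):
        # `torch.Tensor.to()` interprets a bare integer as a CUDA device
        # index, which fails when torch is built without CUDA support.
        # Map integers explicitly to the accelerator backend present.
        if is_npu_available() and isinstance(device, int):
            return f"npu:{device}"
        if is_xpu_available() and isinstance(device, int):
            return f"xpu:{device}"
        return device

    # On an XPU-only machine: normalize_device(0) -> "xpu:0"
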
 src/accelerate/big_modeling.py     |  5 +++++
 src/accelerate/utils/modeling.py   | 11 ++++++++++-
 src/accelerate/utils/operations.py |  7 ++++++-
 3 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py
index 4e6ab7ef8ec..dc6e08eef59 100644
--- a/src/accelerate/big_modeling.py
+++ b/src/accelerate/big_modeling.py
@@ -37,6 +37,7 @@
     get_balanced_memory,
     infer_auto_device_map,
     is_npu_available,
+    is_xpu_available,
     is_torch_version,
     load_checkpoint_in_model,
     offload_state_dict,
@@ -451,6 +452,8 @@ def wrapper(*args, **kwargs):
         model.to = add_warning(model.to, model)
         if is_npu_available():
             model.npu = add_warning(model.npu, model)
+        elif is_xpu_available():
+            model.xpu = add_warning(model.xpu, model)
         else:
             model.cuda = add_warning(model.cuda, model)
 
@@ -459,6 +462,8 @@ def wrapper(*args, **kwargs):
         # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
         if is_npu_available() and isinstance(device, int):
             device = f"npu:{device}"
+        elif is_xpu_available() and isinstance(device, int):
+            device = f"xpu:{device}"
         if device != "disk":
             model.to(device)
         else:
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index ebaa56250e6..86da0517677 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -367,6 +367,8 @@ def set_module_tensor_to_device(
     # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
     if is_npu_available() and isinstance(device, int):
         device = f"npu:{device}"
+    if is_xpu_available() and isinstance(device, int):
+        device = f"xpu:{device}"
     if value is None:
         new_value = old_value.to(device)
         if dtype is not None and device in ["meta", torch.device("meta")]:
@@ -427,6 +429,8 @@ def set_module_tensor_to_device(
     # clean pre and post foward hook
     if is_npu_available():
         torch.npu.empty_cache()
+    elif is_xpu_available():
+        torch.xpu.empty_cache()
     else:
         torch.cuda.empty_cache()
 
@@ -1351,7 +1355,12 @@ def load_state_dict(checkpoint_file, device_map=None):
             else:
                 progress_bar = None
             for device in devices:
-                with safe_open(checkpoint_file, framework="pt", device=device) as f:
+                target_device = device
+
+                if is_xpu_available() and isinstance(device, int):
+                    target_device = f"xpu:{device}"
+
+                with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
                     for key in device_weights[device]:
                         if progress_bar is not None:
                             progress_bar.set_postfix(dev=device, refresh=False)
diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py
index 4d4dedff447..6df9917a089 100644
--- a/src/accelerate/utils/operations.py
+++ b/src/accelerate/utils/operations.py
@@ -26,7 +26,7 @@
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available
+from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available, is_xpu_available
 
 
 if is_tpu_available(check_device=False):
@@ -171,6 +171,11 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
             # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
             elif device == torch.device("npu"):
                 device = "npu:0"
+        elif is_xpu_available():
+            if isinstance(device, int):
+                device = f"xpu:{device}"
+            elif device == torch.device("xpu"):
+                device = "xpu:0"
         try:
             return tensor.to(device, non_blocking=non_blocking)
         except TypeError:  # .to() doesn't accept non_blocking as kwarg