From 42dcc0b7989de3396bca5803994e4801a95e42ee Mon Sep 17 00:00:00 2001 From: notsyncing Date: Sun, 28 Jan 2024 21:48:25 +0800 Subject: [PATCH] Fix XPU inference Though it will complain about "Device xpu is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'", you cannot just put 0 as the device, or it will treat 0 as a CUDA device and then complain again that torch is not compiled with CUDA enabled. You will need safetensors >= 0.4.2 if using safetensors files. --- src/accelerate/big_modeling.py | 5 +++++ src/accelerate/utils/modeling.py | 22 +++++++++++++++++++++- src/accelerate/utils/operations.py | 13 ++++++++++++- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py index 4e6ab7ef8ec..ecab3070761 100644 --- a/src/accelerate/big_modeling.py +++ b/src/accelerate/big_modeling.py @@ -38,6 +38,7 @@ infer_auto_device_map, is_npu_available, is_torch_version, + is_xpu_available, load_checkpoint_in_model, offload_state_dict, parse_flag_from_env, @@ -451,6 +452,8 @@ def wrapper(*args, **kwargs): model.to = add_warning(model.to, model) if is_npu_available(): model.npu = add_warning(model.npu, model) + elif is_xpu_available(): + model.xpu = add_warning(model.xpu, model) else: model.cuda = add_warning(model.cuda, model) @@ -459,6 +462,8 @@ def wrapper(*args, **kwargs): # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). 
if is_npu_available() and isinstance(device, int): device = f"npu:{device}" + elif is_xpu_available() and isinstance(device, int): + device = f"xpu:{device}" if device != "disk": model.to(device) else: diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index ebaa56250e6..4d46a3a1e85 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -14,6 +14,7 @@ import contextlib import gc +import importlib import inspect import json import logging @@ -24,6 +25,7 @@ from collections import OrderedDict, defaultdict from typing import Dict, List, Optional, Tuple, Union +import packaging import torch import torch.nn as nn @@ -33,6 +35,7 @@ from .imports import is_mps_available, is_npu_available, is_peft_available, is_xpu_available from .offload import load_offloaded_weight, offload_weight, save_offload_index from .tqdm import is_tqdm_available, tqdm +from .versions import compare_versions if is_npu_available(check_device=False): @@ -367,6 +370,8 @@ def set_module_tensor_to_device( # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). 
if is_npu_available() and isinstance(device, int): device = f"npu:{device}" + if is_xpu_available() and isinstance(device, int): + device = f"xpu:{device}" if value is None: new_value = old_value.to(device) if dtype is not None and device in ["meta", torch.device("meta")]: @@ -427,6 +432,8 @@ def set_module_tensor_to_device( # clean pre and post foward hook if is_npu_available(): torch.npu.empty_cache() + elif is_xpu_available(): + torch.xpu.empty_cache() else: torch.cuda.empty_cache() @@ -1351,7 +1358,20 @@ def load_state_dict(checkpoint_file, device_map=None): else: progress_bar = None for device in devices: - with safe_open(checkpoint_file, framework="pt", device=device) as f: + target_device = device + + if is_xpu_available(): + current_safetensors_version = packaging.version.parse(importlib.metadata.version("safetensors")) + + if compare_versions(current_safetensors_version, "<", "0.4.2"): + raise ModuleNotFoundError( + f"You need at least safetensors 0.4.2 for Intel GPU, while you have {current_safetensors_version}" + ) + + if isinstance(device, int): + target_device = f"xpu:{device}" + + with safe_open(checkpoint_file, framework="pt", device=target_device) as f: for key in device_weights[device]: if progress_bar is not None: progress_bar.set_postfix(dev=device, refresh=False) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 4d4dedff447..27b2dc46457 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -26,7 +26,13 @@ from ..state import PartialState from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES from .dataclasses import DistributedType, TensorInformation -from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available +from .imports import ( + is_npu_available, + is_torch_distributed_available, + is_torch_version, + is_tpu_available, + is_xpu_available, +) if is_tpu_available(check_device=False): @@ -171,6 +177,11 @@ def 
send_to_device(tensor, device, non_blocking=False, skip_keys=None): # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)). elif device == torch.device("npu"): device = "npu:0" + elif is_xpu_available(): + if isinstance(device, int): + device = f"xpu:{device}" + elif device == torch.device("xpu"): + device = "xpu:0" try: return tensor.to(device, non_blocking=non_blocking) except TypeError: # .to() doesn't accept non_blocking as kwarg