Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix XPU inference #2383

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/accelerate/big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
infer_auto_device_map,
is_npu_available,
is_torch_version,
is_xpu_available,
load_checkpoint_in_model,
offload_state_dict,
parse_flag_from_env,
Expand Down Expand Up @@ -451,6 +452,8 @@ def wrapper(*args, **kwargs):
model.to = add_warning(model.to, model)
if is_npu_available():
model.npu = add_warning(model.npu, model)
elif is_xpu_available():
model.xpu = add_warning(model.xpu, model)
else:
model.cuda = add_warning(model.cuda, model)

Expand All @@ -459,6 +462,8 @@ def wrapper(*args, **kwargs):
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
if is_npu_available() and isinstance(device, int):
device = f"npu:{device}"
elif is_xpu_available() and isinstance(device, int):
device = f"xpu:{device}"
if device != "disk":
model.to(device)
else:
Expand Down
22 changes: 21 additions & 1 deletion src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import contextlib
import gc
import importlib
import inspect
import json
import logging
Expand All @@ -24,6 +25,7 @@
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Tuple, Union

import packaging
import torch
import torch.nn as nn

Expand All @@ -33,6 +35,7 @@
from .imports import is_mps_available, is_npu_available, is_peft_available, is_xpu_available
from .offload import load_offloaded_weight, offload_weight, save_offload_index
from .tqdm import is_tqdm_available, tqdm
from .versions import compare_versions


if is_npu_available(check_device=False):
Expand Down Expand Up @@ -367,6 +370,8 @@ def set_module_tensor_to_device(
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
if is_npu_available() and isinstance(device, int):
device = f"npu:{device}"
if is_xpu_available() and isinstance(device, int):
device = f"xpu:{device}"
if value is None:
new_value = old_value.to(device)
if dtype is not None and device in ["meta", torch.device("meta")]:
Expand Down Expand Up @@ -427,6 +432,8 @@ def set_module_tensor_to_device(
# clean pre and post foward hook
if is_npu_available():
torch.npu.empty_cache()
elif is_xpu_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

Expand Down Expand Up @@ -1351,7 +1358,20 @@ def load_state_dict(checkpoint_file, device_map=None):
else:
progress_bar = None
for device in devices:
with safe_open(checkpoint_file, framework="pt", device=device) as f:
target_device = device

if is_xpu_available():
current_safetensors_version = packaging.version.parse(importlib.metadata.version("safetensors"))

if compare_versions(current_safetensors_version, "<", "0.4.2"):
raise ModuleNotFoundError(
f"You need at least safetensors 0.4.2 for Intel GPU, while you have {current_safetensors_version}"
)

if isinstance(device, int):
target_device = f"xpu:{device}"

with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
for key in device_weights[device]:
if progress_bar is not None:
progress_bar.set_postfix(dev=device, refresh=False)
Expand Down
13 changes: 12 additions & 1 deletion src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from ..state import PartialState
from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
from .dataclasses import DistributedType, TensorInformation
from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available
from .imports import (
is_npu_available,
is_torch_distributed_available,
is_torch_version,
is_tpu_available,
is_xpu_available,
)


if is_tpu_available(check_device=False):
Expand Down Expand Up @@ -171,6 +177,11 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
# `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
elif device == torch.device("npu"):
device = "npu:0"
elif is_xpu_available():
if isinstance(device, int):
device = f"xpu:{device}"
notsyncing marked this conversation as resolved.
Show resolved Hide resolved
elif device == torch.device("xpu"):
device = "xpu:0"
try:
return tensor.to(device, non_blocking=non_blocking)
except TypeError: # .to() doesn't accept non_blocking as kwarg
Expand Down
Loading