Patch with accelerate xpu #25714

Merged · 22 commits · Sep 5, 2023 · Changes from all commits
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -748,6 +748,7 @@
"is_torch_npu_available",
"is_torch_tpu_available",
"is_torchvision_available",
"is_torch_xpu_available",
"is_vision_available",
"logging",
],
@@ -4814,6 +4815,7 @@
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
is_torch_xpu_available,
is_torchvision_available,
is_vision_available,
logging,
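Since the new symbol is added to the `utils` entries in both hunks, it should be importable directly from the top-level package alongside the existing device checks; a quick sketch, not part of the diff:

```python
# Illustrative only: assumes a build of this branch is installed.
# The check simply returns False when IPEX isn't installed.
from transformers import is_torch_npu_available, is_torch_xpu_available

print("NPU available:", is_torch_npu_available())
print("XPU available:", is_torch_xpu_available())
```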
26 changes: 26 additions & 0 deletions src/transformers/testing_utils.py
@@ -100,6 +100,7 @@
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdynamo_available,
is_torchvision_available,
@@ -624,6 +625,29 @@ def require_torch_multi_npu(test_case):
return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)


def require_torch_xpu(test_case):
"""
Decorator marking a test that requires XPU and IPEX.

These tests are skipped when Intel Extension for PyTorch (IPEX) isn't installed or its version does not match the
current PyTorch version.
"""
return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)


def require_torch_multi_xpu(test_case):
"""
Decorator marking a test that requires a multi-XPU setup with IPEX and at least one XPU device. These tests are
skipped on a machine without IPEX or multiple XPUs.

To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
"""
if not is_torch_xpu_available():
return unittest.skip("test requires IPEX and at least one XPU device")(test_case)

return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)


if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
@@ -641,6 +665,8 @@ def require_torch_multi_npu(test_case):
torch_device = "cuda"
elif _run_third_party_device_tests and is_torch_npu_available():
torch_device = "npu"
elif _run_third_party_device_tests and is_torch_xpu_available():
torch_device = "xpu"
else:
torch_device = "cpu"

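A minimal usage sketch for the new decorators, assuming a build of this branch; the test class below is hypothetical and not part of this PR:

```python
import torch

from transformers.testing_utils import TestCasePlus, require_torch_multi_xpu, require_torch_xpu


@require_torch_xpu
class XpuSmokeTest(TestCasePlus):  # hypothetical example test
    def test_tensor_on_xpu(self):
        # Runs only when IPEX is installed and torch.xpu reports an available device.
        x = torch.ones(2, 2, device="xpu")
        self.assertEqual(x.device.type, "xpu")

    @require_torch_multi_xpu
    def test_needs_multiple_xpus(self):
        # Additionally skipped unless torch.xpu.device_count() > 1.
        self.assertGreater(torch.xpu.device_count(), 1)
```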
37 changes: 31 additions & 6 deletions src/transformers/trainer_utils.py
@@ -38,6 +38,7 @@
is_torch_mps_available,
is_torch_npu_available,
is_torch_tpu_available,
is_torch_xpu_available,
requires_backends,
)

@@ -97,6 +98,8 @@ def set_seed(seed: int):
# ^^ safe to call this function even if cuda is not available
if is_torch_npu_available():
torch.npu.manual_seed_all(seed)
if is_torch_xpu_available():
torch.xpu.manual_seed_all(seed)
if is_tf_available():
tf.random.set_seed(seed)

@@ -420,6 +423,11 @@ def __init__(self, skip_memory_metrics=False):
elif is_torch_mps_available():
import torch

self.torch = torch
self.gpu = {}
elif is_torch_xpu_available():
import torch

self.torch = torch
self.gpu = {}
else:
@@ -472,12 +480,19 @@ def start(self):
gc.collect()

if self.torch is not None:
self.torch.cuda.reset_peak_memory_stats()
self.torch.cuda.empty_cache()
if torch.cuda.is_available():
self.torch.cuda.reset_peak_memory_stats()
self.torch.cuda.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.reset_peak_memory_stats()
self.torch.xpu.empty_cache()

# gpu
if self.torch is not None:
self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
if torch.cuda.is_available():
self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()

# cpu
self.cpu_mem_used_at_start = self.cpu_mem_used()
@@ -501,7 +516,10 @@ def stop(self, stage):
gc.collect()

if self.torch is not None:
self.torch.cuda.empty_cache()
if torch.cuda.is_available():
self.torch.cuda.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.empty_cache()

# concepts:
# - alloc_delta: the difference of allocated memory between the end and the start

# gpu
if self.torch is not None:
self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
if torch.cuda.is_available():
self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
else:
raise ValueError("No available GPU device found!")

self.gpu[self.cur_stage] = {
"begin": self.gpu_mem_used_at_start,
"end": self.gpu_mem_used_now,
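The tracker hunks above all follow the same dispatch order: prefer CUDA when `torch.cuda.is_available()`, otherwise fall back to XPU, otherwise raise. A standalone sketch of that pattern, with a helper name made up for illustration:

```python
import torch

from transformers.utils import is_torch_xpu_available


def accelerator_memory_stats():
    """Hypothetical helper mirroring the tracker's branching: CUDA first, then XPU."""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated()
    if is_torch_xpu_available():
        return torch.xpu.memory_allocated(), torch.xpu.max_memory_allocated()
    raise ValueError("No available GPU device found!")


allocated, peak = accelerator_memory_stats()  # raises on CPU-only machines, as in the tracker
```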
41 changes: 33 additions & 8 deletions src/transformers/training_args.py
@@ -50,6 +50,7 @@
is_torch_npu_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xpu_available,
logging,
requires_backends,
)
@@ -194,9 +195,9 @@ class TrainingArguments:
prediction_loss_only (`bool`, *optional*, defaults to `False`):
When performing evaluation and generating predictions, only returns the loss.
per_device_train_batch_size (`int`, *optional*, defaults to 8):
The batch size per GPU/TPU/MPS/NPU core/CPU for training.
The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.
per_device_eval_batch_size (`int`, *optional*, defaults to 8):
The batch size per GPU/TPU/MPS/NPU core/CPU for evaluation.
The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.
gradient_accumulation_steps (`int`, *optional*, defaults to 1):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.

@@ -1357,11 +1358,20 @@ def __post_init__(self):
if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
# cpu
raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
elif not self.use_cpu and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
# gpu
raise ValueError(
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
)
elif not self.use_cpu:
if torch.cuda.is_available() and not is_torch_bf16_gpu_available():
# gpu
raise ValueError(
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
)
elif not is_torch_xpu_available():
# xpu
from .pytorch_utils import is_torch_greater_or_equal_than_1_12

if not is_torch_greater_or_equal_than_1_12:
raise ValueError(
"Your setup doesn't support bf16/xpu. You need torch>=1.12, using Intel XPU/GPU with IPEX installed"
)

if self.fp16 and self.bf16:
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
@@ -1416,14 +1426,15 @@ def __post_init__(self):
self.framework == "pt"
and is_torch_available()
and (self.device.type != "cuda")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) != "GPU")
and (get_xla_device_type(self.device) != "TPU")
and (self.device.type != "cpu")
and (self.bf16 or self.bf16_full_eval)
):
raise ValueError(
"BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
" (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
" (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX) or CPU/TPU/NeuronCore devices."
)

if self.torchdynamo is not None:
@@ -1779,6 +1790,10 @@ def _setup_devices(self) -> "torch.device":
device = torch.device("cuda", local_rank)
self._n_gpu = 1
torch.cuda.set_device(device)
elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
os.environ["ACCELERATE_USE_XPU"] = "true"
device = torch.device("xpu:0")
self._n_gpu = 1
elif is_sagemaker_dp_enabled():
self.distributed_state = PartialState(_use_sagemaker_dp=True)
self._n_gpu = 1
@@ -1807,6 +1822,12 @@ def _setup_devices(self) -> "torch.device":
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
# Already set _n_gpu
pass
elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
if "ACCELERATE_USE_XPU" not in os.environ:
os.environ["ACCELERATE_USE_XPU"] = "true"
self._n_gpu = torch.xpu.device_count()
device = torch.device("xpu:0")
torch.xpu.set_device(device)
elif self.distributed_state.distributed_type == DistributedType.NO:
if self.use_mps_device:
warnings.warn(
elif self.use_cpu:
device = torch.device("cpu")
self._n_gpu = 0
elif is_torch_xpu_available():
device = torch.device("xpu:0")
torch.xpu.set_device(device)
self._n_gpu = 1
elif is_torch_npu_available():
device = torch.device("npu:0")
torch.npu.set_device(device)
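A short sketch of how the new device selection would surface to a user; the printed values are expectations under the assumption that IPEX and exactly one XPU are present, not part of this diff:

```python
from transformers import TrainingArguments

# Assuming IPEX and a single XPU are visible, device resolution should now
# land on the XPU rather than silently falling back to CPU.
args = TrainingArguments(output_dir="out", per_device_train_batch_size=8)
print(args.device)  # expected: xpu:0 on such a machine
print(args.n_gpu)   # expected: 1
```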
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -172,6 +172,7 @@
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdistx_available,
is_torchdynamo_available,
19 changes: 19 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -528,6 +528,25 @@ def get_major_and_minor_from_version(full_version):
return True


@lru_cache
def is_torch_xpu_available(check_device=False):
"Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
if not is_ipex_available():
return False

import intel_extension_for_pytorch # noqa: F401
import torch

if check_device:
try:
# Will raise a RuntimeError if no XPU is found
_ = torch.xpu.device_count()
return torch.xpu.is_available()
except RuntimeError:
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_bitsandbytes_available():
if not is_torch_available():
return False
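A usage sketch for the two modes of the new check (illustrative only):

```python
from transformers.utils import is_torch_xpu_available

# Default mode: IPEX is importable and torch.xpu reports available.
if is_torch_xpu_available():
    print("XPU backend usable")

# check_device=True additionally tries to enumerate devices and treats a
# RuntimeError as "no XPU present".
if is_torch_xpu_available(check_device=True):
    print("at least one XPU device found")
```

Because the function is wrapped in `@lru_cache`, each `check_device` value is evaluated once per process and then memoized.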
15 changes: 15 additions & 0 deletions tests/trainer/test_trainer_distributed.py
@@ -22,6 +22,7 @@
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
require_torch_multi_xpu,
require_torch_neuroncore,
require_torch_npu,
)
@@ -158,6 +159,20 @@ def test_trainer(self):
# successful return here == success - any errors would have caused an error in the sub-call


@require_torch_multi_xpu
class TestTrainerDistributedXPU(TestCasePlus):
def test_trainer(self):
distributed_args = f"""--nproc_per_node={torch.xpu.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_trainer_distributed.py
""".split()
output_dir = self.get_auto_remove_tmp_dir()
args = f"--output_dir {output_dir}".split()
cmd = ["torchrun"] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call


if __name__ == "__main__":
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
#