diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 5766608ec4155d..5a8e9f93ed917a 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -748,6 +748,7 @@
         "is_torch_npu_available",
         "is_torch_tpu_available",
         "is_torchvision_available",
+        "is_torch_xpu_available",
         "is_vision_available",
         "logging",
     ],
@@ -4814,6 +4815,7 @@
        is_torch_neuroncore_available,
        is_torch_npu_available,
        is_torch_tpu_available,
+       is_torch_xpu_available,
        is_torchvision_available,
        is_vision_available,
        logging,
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 85b947d706aa4a..a7e36322a6b099 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -100,6 +100,7 @@
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
+    is_torch_xpu_available,
     is_torchaudio_available,
     is_torchdynamo_available,
     is_torchvision_available,
@@ -624,6 +625,29 @@ def require_torch_multi_npu(test_case):
     return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
 
 
+def require_torch_xpu(test_case):
+    """
+    Decorator marking a test that requires XPU and IPEX.
+
+    These tests are skipped when Intel Extension for PyTorch (IPEX) isn't installed or when it does not match the
+    current PyTorch version.
+    """
+    return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
+
+
+def require_torch_multi_xpu(test_case):
+    """
+    Decorator marking a test that requires a multi-XPU setup with IPEX and at least two XPU devices. These tests are
+    skipped on a machine without IPEX or multiple XPUs.
+
+    To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
+    """
+    if not is_torch_xpu_available():
+        return unittest.skip("test requires IPEX and at least one XPU device")(test_case)
+
+    return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
+
+
 if is_torch_available():
     # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
     import torch
@@ -641,6 +665,8 @@ def require_torch_multi_npu(test_case):
         torch_device = "cuda"
     elif _run_third_party_device_tests and is_torch_npu_available():
         torch_device = "npu"
+    elif _run_third_party_device_tests and is_torch_xpu_available():
+        torch_device = "xpu"
     else:
         torch_device = "cpu"
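
For context, the new markers compose exactly like the existing `require_torch_gpu`/`require_torch_multi_npu` decorators. Below is a minimal sketch of a test opting into XPU; the test body is illustrative and not part of this patch, and `torch_device` only resolves to "xpu" when the `_run_third_party_device_tests` flag that gates the `elif` above is enabled:

    import torch

    from transformers.testing_utils import require_torch_xpu, torch_device


    @require_torch_xpu
    def test_add_on_xpu():
        # Skipped entirely unless IPEX and an XPU device are available.
        # torch_device may still be "cpu" or "cuda" if third-party device
        # tests are not enabled, in which case the math below still holds.
        x = torch.ones(4, device=torch_device)
        assert (x + x).sum().item() == 8.0
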
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 30571597c235d1..931d0067e99d23 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -38,6 +38,7 @@
     is_torch_mps_available,
     is_torch_npu_available,
     is_torch_tpu_available,
+    is_torch_xpu_available,
     requires_backends,
 )
 
@@ -97,6 +98,8 @@ def set_seed(seed: int):
         # ^^ safe to call this function even if cuda is not available
     if is_torch_npu_available():
         torch.npu.manual_seed_all(seed)
+    if is_torch_xpu_available():
+        torch.xpu.manual_seed_all(seed)
     if is_tf_available():
         tf.random.set_seed(seed)
 
@@ -420,6 +423,11 @@ def __init__(self, skip_memory_metrics=False):
         elif is_torch_mps_available():
             import torch
 
+            self.torch = torch
+            self.gpu = {}
+        elif is_torch_xpu_available():
+            import torch
+
             self.torch = torch
             self.gpu = {}
         else:
@@ -472,12 +480,19 @@ def start(self):
         gc.collect()
 
         if self.torch is not None:
-            self.torch.cuda.reset_peak_memory_stats()
-            self.torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                self.torch.cuda.reset_peak_memory_stats()
+                self.torch.cuda.empty_cache()
+            elif is_torch_xpu_available():
+                self.torch.xpu.reset_peak_memory_stats()
+                self.torch.xpu.empty_cache()
 
         # gpu
         if self.torch is not None:
-            self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
+            if torch.cuda.is_available():
+                self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
+            elif is_torch_xpu_available():
+                self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
 
         # cpu
         self.cpu_mem_used_at_start = self.cpu_mem_used()
@@ -501,7 +516,10 @@ def stop(self, stage):
         gc.collect()
 
         if self.torch is not None:
-            self.torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                self.torch.cuda.empty_cache()
+            elif is_torch_xpu_available():
+                self.torch.xpu.empty_cache()
 
         # concepts:
         # - alloc_delta: the difference of allocated memory between the end and the start
@@ -510,8 +528,15 @@ def stop(self, stage):
 
         # gpu
         if self.torch is not None:
-            self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
-            self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
+            if torch.cuda.is_available():
+                self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
+            elif is_torch_xpu_available():
+                self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
+            else:
+                raise ValueError("No available GPU device found!")
+
             self.gpu[self.cur_stage] = {
                 "begin": self.gpu_mem_used_at_start,
                 "end": self.gpu_mem_used_now,
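
With the changes above, `TrainerMemoryTracker` keeps reporting the `*_mem_gpu_*` metrics on Intel GPUs: CUDA is preferred when present, otherwise the same calls are routed through `torch.xpu`. A standalone sketch of that dispatch order follows; the helper name is hypothetical and only mirrors the precedence the tracker now uses:

    import torch

    from transformers.utils import is_torch_xpu_available


    def accelerator_mem_allocated() -> int:
        # Same precedence as TrainerMemoryTracker: CUDA first, then XPU.
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated()
        elif is_torch_xpu_available():
            return torch.xpu.memory_allocated()
        raise ValueError("No available GPU device found!")
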
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 68458a64b0eb96..a17966928a4fa8 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -50,6 +50,7 @@
     is_torch_npu_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
+    is_torch_xpu_available,
     logging,
     requires_backends,
 )
@@ -194,9 +195,9 @@ class TrainingArguments:
         prediction_loss_only (`bool`, *optional*, defaults to `False`):
             When performing evaluation and generating predictions, only returns the loss.
         per_device_train_batch_size (`int`, *optional*, defaults to 8):
-            The batch size per GPU/TPU/MPS/NPU core/CPU for training.
+            The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.
         per_device_eval_batch_size (`int`, *optional*, defaults to 8):
-            The batch size per GPU/TPU/MPS/NPU core/CPU for evaluation.
+            The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.
         gradient_accumulation_steps (`int`, *optional*, defaults to 1):
             Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
 
@@ -1357,11 +1358,20 @@ def __post_init__(self):
             if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
                 # cpu
                 raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
-            elif not self.use_cpu and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
-                # gpu
-                raise ValueError(
-                    "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
-                )
+            elif not self.use_cpu:
+                if torch.cuda.is_available() and not is_torch_bf16_gpu_available():
+                    # gpu
+                    raise ValueError(
+                        "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
+                    )
+                elif not is_torch_xpu_available():
+                    # xpu
+                    from .pytorch_utils import is_torch_greater_or_equal_than_1_12
+
+                    if not is_torch_greater_or_equal_than_1_12:
+                        raise ValueError(
+                            "Your setup doesn't support bf16/xpu. You need torch>=1.12, using Intel XPU/GPU with IPEX installed"
+                        )
 
         if self.fp16 and self.bf16:
             raise ValueError("At most one of fp16 and bf16 can be True, but not both")
@@ -1416,6 +1426,7 @@ def __post_init__(self):
             self.framework == "pt"
             and is_torch_available()
             and (self.device.type != "cuda")
+            and (self.device.type != "xpu")
             and (get_xla_device_type(self.device) != "GPU")
             and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
@@ -1423,7 +1434,7 @@ def __post_init__(self):
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
+                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX) or CPU/TPU/NeuronCore devices."
             )
 
         if self.torchdynamo is not None:
@@ -1779,6 +1790,10 @@ def _setup_devices(self) -> "torch.device":
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
             torch.cuda.set_device(device)
+        elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
+            os.environ["ACCELERATE_USE_XPU"] = "true"
+            device = torch.device("xpu:0")
+            self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
             self.distributed_state = PartialState(_use_sagemaker_dp=True)
             self._n_gpu = 1
@@ -1807,6 +1822,12 @@ def _setup_devices(self) -> "torch.device":
         elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
             # Already set _n_gpu
             pass
+        elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
+            if "ACCELERATE_USE_XPU" not in os.environ:
+                os.environ["ACCELERATE_USE_XPU"] = "true"
+            self._n_gpu = torch.xpu.device_count()
+            device = torch.device("xpu:0")
+            torch.xpu.set_device(device)
         elif self.distributed_state.distributed_type == DistributedType.NO:
             if self.use_mps_device:
                 warnings.warn(
@@ -1824,6 +1845,10 @@ def _setup_devices(self) -> "torch.device":
             elif self.use_cpu:
                 device = torch.device("cpu")
                 self._n_gpu = 0
+            elif is_torch_xpu_available():
+                device = torch.device("xpu:0")
+                torch.xpu.set_device(device)
+                self._n_gpu = 1
             elif is_torch_npu_available():
                 device = torch.device("npu:0")
                 torch.npu.set_device(device)
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 050ccae9c03d5f..68c39c732e3c35 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -172,6 +172,7 @@
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
+    is_torch_xpu_available,
     is_torchaudio_available,
     is_torchdistx_available,
     is_torchdynamo_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 0045d3345b21be..ae76a78ce21707 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -528,6 +528,25 @@ def get_major_and_minor_from_version(full_version):
     return True
 
 
+@lru_cache
+def is_torch_xpu_available(check_device=False):
+    "Checks if `intel_extension_for_pytorch` is installed and potentially if an XPU is in the environment"
+    if not is_ipex_available():
+        return False
+
+    import intel_extension_for_pytorch  # noqa: F401
+    import torch
+
+    if check_device:
+        try:
+            # Will raise a RuntimeError if no XPU is found
+            _ = torch.xpu.device_count()
+            return torch.xpu.is_available()
+        except RuntimeError:
+            return False
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 def is_bitsandbytes_available():
     if not is_torch_available():
         return False
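
Taken together with the `training_args.py` hunks above, XPU selection should be transparent to end users: no new flag is required. A hedged usage sketch (the printed values depend on the machine; "out" is a placeholder directory):

    from transformers import TrainingArguments
    from transformers.utils import is_torch_xpu_available

    args = TrainingArguments(output_dir="out")
    # On a machine with intel_extension_for_pytorch and a visible XPU this is
    # expected to print "True xpu:0 1"; elsewhere the usual CUDA/CPU fallback applies.
    print(is_torch_xpu_available(), args.device, args.n_gpu)
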
diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py
index 5a7734b8ba161d..8f867cf0beba37 100644
--- a/tests/trainer/test_trainer_distributed.py
+++ b/tests/trainer/test_trainer_distributed.py
@@ -22,6 +22,7 @@
     execute_subprocess_async,
     get_torch_dist_unique_port,
     require_torch_multi_gpu,
+    require_torch_multi_xpu,
     require_torch_neuroncore,
     require_torch_npu,
 )
@@ -158,6 +159,20 @@ def test_trainer(self):
         # successful return here == success - any errors would have caused an error in the sub-call
 
 
+@require_torch_multi_xpu
+class TestTrainerDistributedXPU(TestCasePlus):
+    def test_trainer(self):
+        distributed_args = f"""--nproc_per_node={torch.xpu.device_count()}
+            --master_port={get_torch_dist_unique_port()}
+            {self.test_file_dir}/test_trainer_distributed.py
+        """.split()
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = f"--output_dir {output_dir}".split()
+        cmd = ["torchrun"] + distributed_args + args
+        execute_subprocess_async(cmd, env=self.get_env())
+        # successful return here == success - any errors would have caused an error in the sub-call
+
+
 if __name__ == "__main__":
     # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
     #
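
Note that the new class name does not contain `multi_xpu`, so the `-k "multi_xpu"` filter quoted in the decorator docstring will not select it; an invocation along these lines should (illustrative, run on a machine with more than one XPU):

    $ python -m pytest -sv tests/trainer/test_trainer_distributed.py -k "TestTrainerDistributedXPU"

The test then re-launches this same file under `torchrun --nproc_per_node=<XPU count>`, so a failure in any worker process fails the subprocess call and therefore the test.
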