diff --git a/src/accelerate/test_utils/__init__.py b/src/accelerate/test_utils/__init__.py index 0dbee933ed3..36e78d15487 100644 --- a/src/accelerate/test_utils/__init__.py +++ b/src/accelerate/test_utils/__init__.py @@ -13,6 +13,8 @@ require_multi_gpu, require_multi_xpu, require_non_cpu, + require_non_xpu, + require_npu, require_pippy, require_single_device, require_single_gpu, diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index cdeee5fa995..b7b3e088ea4 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -129,6 +129,20 @@ def require_xpu(test_case): return unittest.skipUnless(is_xpu_available(), "test requires a XPU")(test_case) +def require_non_xpu(test_case): + """ + Decorator marking a test that should be skipped for XPU. + """ + return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case) + + +def require_npu(test_case): + """ + Decorator marking a test that requires NPU. These tests are skipped when there are no NPU available. + """ + return unittest.skipUnless(is_npu_available(), "test require a NPU")(test_case) + + def require_mps(test_case): """ Decorator marking a test that requires MPS backend. These tests are skipped when torch doesn't support `mps` diff --git a/tests/test_kwargs_handlers.py b/tests/test_kwargs_handlers.py index ce817785a55..28c12b79a4c 100644 --- a/tests/test_kwargs_handlers.py +++ b/tests/test_kwargs_handlers.py @@ -21,7 +21,13 @@ from accelerate import Accelerator, DistributedDataParallelKwargs, GradScalerKwargs from accelerate.state import AcceleratorState -from accelerate.test_utils import device_count, execute_subprocess_async, require_multi_device, require_non_cpu +from accelerate.test_utils import ( + device_count, + execute_subprocess_async, + require_multi_device, + require_non_cpu, + require_non_xpu, +) from accelerate.utils import AutocastKwargs, KwargsHandler, TorchDynamoPlugin, clear_environment @@ -41,6 +47,7 @@ def test_kwargs_handler(self): self.assertDictEqual(MockClass(a=2, c=2.25).to_kwargs(), {"a": 2, "c": 2.25}) @require_non_cpu + @require_non_xpu def test_grad_scaler_kwargs(self): # If no defaults are changed, `to_kwargs` returns an empty dict. scaler_handler = GradScalerKwargs(init_scale=1024, growth_factor=2) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index ff67ec0979e..ad3210ecdb9 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -19,7 +19,7 @@ from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.test_utils import require_cpu, require_non_cpu +from accelerate.test_utils import require_cpu, require_non_cpu, require_non_xpu @require_cpu @@ -37,6 +37,7 @@ def test_accelerated_optimizer_pickling(self): @require_non_cpu +@require_non_xpu class OptimizerTester(unittest.TestCase): def test_accelerated_optimizer_step_was_skipped(self): model = torch.nn.Linear(5, 5)