diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index 52317fe6715..567c70499e9 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -53,6 +53,7 @@
     is_torch_xla_available,
     is_torchvision_available,
     is_transformers_available,
+    is_triton_available,
     is_wandb_available,
     is_xpu_available,
     str_to_bool,
@@ -213,7 +214,7 @@ def require_transformers(test_case):
 
 def require_timm(test_case):
     """
-    Decorator marking a test that requires transformers. These tests are skipped when they are not.
+    Decorator marking a test that requires timm. These tests are skipped when they are not.
     """
     return unittest.skipUnless(is_timm_available(), "test requires the timm library")(test_case)
 
@@ -225,6 +226,13 @@ def require_torchvision(test_case):
     return unittest.skipUnless(is_torchvision_available(), "test requires the torchvision library")(test_case)
 
 
+def require_triton(test_case):
+    """
+    Decorator marking a test that requires triton. These tests are skipped when they are not.
+    """
+    return unittest.skipUnless(is_triton_available(), "test requires the triton library")(test_case)
+
+
 def require_schedulefree(test_case):
     """
     Decorator marking a test that requires schedulefree. These tests are skipped when they are not.
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 43061d0f5c9..c21b8a3232a 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -109,6 +109,7 @@
     is_torchvision_available,
     is_transformer_engine_available,
     is_transformers_available,
+    is_triton_available,
     is_wandb_available,
     is_xpu_available,
 )
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 571a07be072..669044c3910 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -248,6 +248,10 @@ def is_timm_available():
     return _is_package_available("timm")
 
 
+def is_triton_available():
+    return _is_package_available("triton")
+
+
 def is_aim_available():
     package_exists = _is_package_available("aim")
     if package_exists:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 30ab53004fc..aa95ed9703f 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -27,12 +27,12 @@
 from accelerate.state import PartialState
 from accelerate.test_utils.testing import (
-    require_cuda,
     require_huggingface_suite,
     require_non_cpu,
     require_non_torch_xla,
     require_torch_min_version,
     require_tpu,
+    require_triton,
     torch_device,
 )
 from accelerate.test_utils.training import RegressionModel
 
@@ -190,15 +190,16 @@ def test_can_undo_fp16_conversion(self):
         model = extract_model_from_parallel(model, keep_fp32_wrapper=False)
         _ = pickle.dumps(model)
 
-    @require_cuda
+    @require_triton
+    @require_non_cpu
     @require_torch_min_version(version="2.0")
     def test_dynamo(self):
         model = RegressionModel()
         model._original_forward = model.forward
-        model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward)
+        model.forward = torch.autocast(device_type=torch_device, dtype=torch.float16)(model.forward)
         model.forward = convert_outputs_to_fp32(model.forward)
         model.forward = torch.compile(model.forward, backend="inductor")
-        inputs = torch.randn(4, 10).cuda()
+        inputs = torch.randn(4, 10).to(torch_device)
         _ = model(inputs)
 
     def test_extract_model(self):
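
For context, a minimal sketch of how the new `require_triton` decorator composes with `require_non_cpu`, `require_torch_min_version`, and `torch_device` in a downstream test, mirroring the pattern the patched `test_dynamo` uses. The test class and model below are hypothetical illustrations, not part of this diff:

```python
# A minimal, hypothetical usage sketch (not part of this diff): the test runs
# only when triton is installed, a non-CPU accelerator is present, and torch
# is at least 2.0 -- the conditions the inductor backend needs on an accelerator.
import unittest

import torch

from accelerate.test_utils.testing import (
    require_non_cpu,
    require_torch_min_version,
    require_triton,
    torch_device,
)


class CompileTester(unittest.TestCase):
    @require_triton
    @require_non_cpu
    @require_torch_min_version(version="2.0")
    def test_compiled_linear(self):
        # Illustrative model; torch.compile lowers it through inductor/triton.
        model = torch.nn.Linear(10, 4).to(torch_device)
        compiled = torch.compile(model, backend="inductor")
        inputs = torch.randn(2, 10).to(torch_device)
        _ = compiled(inputs)
```

Stacking `require_triton` with `require_non_cpu` rather than `require_cuda` keeps the skip logic device-agnostic: the test is gated on what inductor actually needs (triton plus some accelerator) instead of hard-coding CUDA.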