diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py index a1288b4edb3d..ff8b6aa9b41c 100644 --- a/src/diffusers/testing_utils.py +++ b/src/diffusers/testing_utils.py @@ -5,10 +5,15 @@ import torch +from packaging import version + global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" -torch_device = "mps" if torch.backends.mps.is_available() else torch_device +is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12") + +if is_torch_higher_equal_than_1_12: + torch_device = "mps" if torch.backends.mps.is_available() else torch_device def parse_flag_from_env(key, default=False):