From 15c08a18b8b04e65c5d87660b9136ee5b9c54aed Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Thu, 1 Feb 2024 10:56:21 +0100 Subject: [PATCH 1/2] Update testing_utils.py --- src/diffusers/utils/testing_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 86e31eb688cd..87551d1ceb60 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -840,7 +840,7 @@ def _is_torch_fp16_available(device): return True except Exception as e: - if device.type == "cuda": + if device == "cuda": raise ValueError( f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}" ) @@ -860,7 +860,7 @@ def _is_torch_fp64_available(device): return True except Exception as e: - if device.type == "cuda": + if device == "cuda": raise ValueError( f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}" ) From 1630bac667282facebc3e58232cb16974493f7d6 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 2 Feb 2024 09:10:13 +0100 Subject: [PATCH 2/2] Update testing_utils.py --- src/diffusers/utils/testing_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 87551d1ceb60..edbf6f31a833 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -840,7 +840,7 @@ def _is_torch_fp16_available(device): return True except Exception as e: - if device == "cuda": + if device.type == "cuda": raise ValueError( f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}" ) @@ -854,13 +854,15 @@ def _is_torch_fp64_available(device): import torch + device = torch.device(device) + try: x = torch.zeros((2, 2), dtype=torch.float64).to(device) _ = torch.mul(x, x) return True except Exception as e: - if device == "cuda": + if device.type == "cuda": raise ValueError( f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}" )