diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index a73ad4acf3c3..07036a4ee049 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -223,8 +223,10 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T # Non-power of 2 images must be float32 if (W & (W - 1)) != 0 or (H & (H - 1)) != 0: x = x.to(dtype=torch.float32) - # fftn does not support bfloat16 - elif x.dtype == torch.bfloat16: + # fftn does not support bfloat16, and produces the experimental ComplexHalf + # dtype (torch.complex32) when given float16, which is numerically unstable + # and triggers a UserWarning. Upcast any non-float32 dtype to float32. + elif x.dtype != torch.float32: x = x.to(dtype=torch.float32) # FFT diff --git a/tests/others/test_utils.py b/tests/others/test_utils.py index bb0656386394..7b445e3a21bd 100755 --- a/tests/others/test_utils.py +++ b/tests/others/test_utils.py @@ -204,6 +204,49 @@ def test_deprecate_testing_utils_module(self): ), f"Expected deprecation message substring not found, got: {messages}" +class FourierFilterTester(unittest.TestCase): + """Tests for :func:`diffusers.utils.torch_utils.fourier_filter` (FreeU helper).""" + + def _run_without_complexhalf_warning(self, dtype): + import torch + + from diffusers.utils.torch_utils import fourier_filter + + x = torch.randn(1, 4, 32, 32, dtype=dtype) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + out = fourier_filter(x, threshold=1, scale=0.5) + + messages = [str(w.message) for w in caught] + assert not any("ComplexHalf" in m for m in messages), ( + f"Unexpected ComplexHalf warning emitted by fourier_filter: {messages}" + ) + return out + + def test_fourier_filter_float16_no_complexhalf_warning(self): + import torch + + out = self._run_without_complexhalf_warning(torch.float16) + assert out.dtype == torch.float16 + + def test_fourier_filter_bfloat16_no_complexhalf_warning(self): + import torch + + out = self._run_without_complexhalf_warning(torch.bfloat16) + assert out.dtype == torch.bfloat16 + + def test_fourier_filter_preserves_dtype_and_shape(self): + import torch + + from diffusers.utils.torch_utils import fourier_filter + + for dtype in (torch.float32, torch.float16, torch.bfloat16): + x = torch.randn(2, 3, 16, 16, dtype=dtype) + out = fourier_filter(x, threshold=1, scale=0.5) + assert out.dtype == dtype + assert out.shape == x.shape + + # Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py class ExpectationsTester(unittest.TestCase): def test_expectations(self):