Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/diffusers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,10 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
x = x.to(dtype=torch.float32)
# fftn does not support bfloat16
elif x.dtype == torch.bfloat16:
# fftn does not support bfloat16, and produces the experimental ComplexHalf
# dtype (torch.complex32) when given float16, which is numerically unstable
# and triggers a UserWarning. Upcast any non-float32 dtype to float32.
elif x.dtype != torch.float32:
x = x.to(dtype=torch.float32)

# FFT
Expand Down
43 changes: 43 additions & 0 deletions tests/others/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,49 @@ def test_deprecate_testing_utils_module(self):
), f"Expected deprecation message substring not found, got: {messages}"


class FourierFilterTester(unittest.TestCase):
"""Tests for :func:`diffusers.utils.torch_utils.fourier_filter` (FreeU helper)."""

def _run_without_complexhalf_warning(self, dtype):
import torch

from diffusers.utils.torch_utils import fourier_filter

x = torch.randn(1, 4, 32, 32, dtype=dtype)
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
out = fourier_filter(x, threshold=1, scale=0.5)

messages = [str(w.message) for w in caught]
assert not any("ComplexHalf" in m for m in messages), (
f"Unexpected ComplexHalf warning emitted by fourier_filter: {messages}"
)
return out

def test_fourier_filter_float16_no_complexhalf_warning(self):
import torch

out = self._run_without_complexhalf_warning(torch.float16)
assert out.dtype == torch.float16

def test_fourier_filter_bfloat16_no_complexhalf_warning(self):
import torch

out = self._run_without_complexhalf_warning(torch.bfloat16)
assert out.dtype == torch.bfloat16

def test_fourier_filter_preserves_dtype_and_shape(self):
import torch

from diffusers.utils.torch_utils import fourier_filter

for dtype in (torch.float32, torch.float16, torch.bfloat16):
x = torch.randn(2, 3, 16, 16, dtype=dtype)
out = fourier_filter(x, threshold=1, scale=0.5)
assert out.dtype == dtype
assert out.shape == x.shape


# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):
def test_expectations(self):
Expand Down
Loading