✨ [Core] Add FreeU mechanism #5164

Merged: 63 commits, Oct 5, 2023
Changes from 3 commits
Commits
63 commits
eec915a
✨ Added Fourier filter function to upsample blocks
kadirnar Sep 24, 2023
8f2ee6b
🔧 Update Fourier_filter for float16 support
kadirnar Sep 24, 2023
ea33fc5
✨ Added UNetFreeUConfig to UNet model for FreeU adaptation 🛠️
kadirnar Sep 24, 2023
6d77fff
move unet to its original form and add fourier_filter to torch_utils.
sayakpaul Sep 27, 2023
7c4676b
implement freeU enable mechanism
sayakpaul Sep 27, 2023
a152d91
implement disable mechanism
sayakpaul Sep 27, 2023
c0de18d
resolution index.
sayakpaul Sep 27, 2023
a0ced0e
correct resolution idx condition.
sayakpaul Sep 27, 2023
19d5ab0
fix copies.
sayakpaul Sep 27, 2023
f3f9441
no need to use resolution_idx in vae.
sayakpaul Sep 27, 2023
266987c
spell out the kwargs
sayakpaul Sep 27, 2023
83846f5
proper config property
sayakpaul Sep 27, 2023
cb6ebe8
fix attribution setting
sayakpaul Sep 27, 2023
9c8eacd
place unet hasattr properly.
sayakpaul Sep 27, 2023
988ef76
fix: attribute access.
sayakpaul Sep 27, 2023
755a5c8
proper disable
sayakpaul Sep 27, 2023
37b091e
remove validation method.
sayakpaul Sep 27, 2023
4fd4adf
debug
sayakpaul Sep 27, 2023
40291ba
debug
sayakpaul Sep 27, 2023
2e15e94
debug
sayakpaul Sep 27, 2023
1e9c79a
debug
sayakpaul Sep 27, 2023
883fc9b
debug
sayakpaul Sep 27, 2023
ba08f30
debug
sayakpaul Sep 27, 2023
540974a
potential fix.
sayakpaul Sep 27, 2023
d38c251
add: doc.
sayakpaul Sep 27, 2023
785c0a0
fix copies
sayakpaul Sep 27, 2023
466e054
add: tests.
sayakpaul Sep 27, 2023
80e560e
add: support freeU in SDXL.
sayakpaul Sep 28, 2023
670b34b
set default value of resolution idx.
sayakpaul Sep 28, 2023
c5fc938
set default values for resolution_idx.
sayakpaul Sep 28, 2023
05ea56e
fix copies
sayakpaul Sep 28, 2023
93ee867
fix rest.
sayakpaul Sep 28, 2023
535eb59
fix copies
sayakpaul Sep 28, 2023
d277e64
address PR comments.
sayakpaul Sep 28, 2023
bb2d368
run fix-copies
sayakpaul Sep 28, 2023
e95b186
move apply_free_u to utils and other minors.
sayakpaul Sep 30, 2023
902cf7d
introduce support for video (unet3D)
sayakpaul Sep 30, 2023
15b1052
minor ups
sayakpaul Oct 2, 2023
7dcc939
consistent fix-copies.
sayakpaul Oct 2, 2023
3bf28bb
consistent stuff
sayakpaul Oct 2, 2023
64ade67
Merge branch 'main' into add-freeU
sayakpaul Oct 2, 2023
08d61fa
Merge branch 'main' into add-freeU
sayakpaul Oct 2, 2023
d68663c
fix-copies
sayakpaul Oct 2, 2023
0e0af08
add: rest
sayakpaul Oct 2, 2023
8d0a204
add: docs.
sayakpaul Oct 2, 2023
55ad535
fix: tests
sayakpaul Oct 2, 2023
797b4b9
fix: doc path
sayakpaul Oct 2, 2023
86419b9
Merge branch 'main' into add-freeU
sayakpaul Oct 2, 2023
5e27ff7
Merge branch 'main' into add-freeU
patrickvonplaten Oct 4, 2023
2ad4953
Merge branch 'main' into add-freeU
sayakpaul Oct 5, 2023
1a8e5d1
Apply suggestions from code review
sayakpaul Oct 5, 2023
adc9d5c
style up
sayakpaul Oct 5, 2023
c4f99d4
move to techniques.
sayakpaul Oct 5, 2023
518e4b1
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
dee3781
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
fc39d22
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
053f3ed
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
d8ef3a1
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
3da96e2
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
d8c8771
add: slow test for video with freeu
sayakpaul Oct 5, 2023
0d34cf0
add: slow test for video with freeu
sayakpaul Oct 5, 2023
8e72f85
add: slow test for video with freeu
sayakpaul Oct 5, 2023
aa1a061
style
sayakpaul Oct 5, 2023
67 changes: 66 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -52,6 +52,54 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@dataclass
class UNetFreeUConfig:
    enabled: bool = False
    s1: float = 1.0
    s2: float = 1.0
    b1: float = 1.0
    b2: float = 1.0

    def sd21(self):
        self.s1 = 0.9
        self.s2 = 0.2
        self.b1 = 1.1
        self.b2 = 1.2

    def ones(self):
        self.s1 = 1.0
        self.s2 = 1.0
        self.b1 = 1.0
        self.b2 = 1.0

def fourier_filter(x_in, threshold, scale):
    import torch
    from torch.fft import fftn, ifftn, fftshift, ifftshift

    x = x_in
    B, C, H, W = x.shape

    # Non-power-of-2 images must be float32
    if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
        x = x.to(dtype=torch.float32)

    # FFT
    x_freq = fftn(x, dim=(-2, -1))
    x_freq = fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = ifftshift(x_freq, dim=(-2, -1))
    x_filtered = ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.to(dtype=x_in.dtype)
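
As a quick illustration (outside the diff), here is a minimal sketch of calling fourier_filter on a dummy skip-connection tensor; the shape and the scale value are illustrative assumptions, not values taken from the PR.

import torch

# Hypothetical skip-connection features: batch 2, 640 channels, 32x32 spatial.
skip = torch.randn(2, 640, 32, 32)

# Dampen only the lowest-frequency 2x2 band (threshold=1) by a FreeU-style skip factor s1 < 1.
filtered = fourier_filter(skip, threshold=1, scale=0.9)

assert filtered.shape == skip.shape and filtered.dtype == skip.dtype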


@dataclass
class UNet2DConditionOutput(BaseOutput):
@@ -212,11 +260,12 @@ def __init__(
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
        freeu = UNetFreeUConfig()
    ):
        super().__init__()

        self.sample_size = sample_size

        self.freeu = freeu
        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
@@ -1008,6 +1057,22 @@ def forward(
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # FreeU: Free Lunch in Diffusion U-Net: https://arxiv.org/pdf/2309.11497.pdf
            if self.freeu.enabled:
                if sample.shape[1] == 1280:
                    sample[:, :640] *= self.freeu.b1
                    res_samples = tuple(
                        fourier_filter(res_sample, threshold=1, scale=self.freeu.s1)
                        for res_sample in res_samples
                    )

                if sample.shape[1] == 640:
                    sample[:, :320] *= self.freeu.b2
                    res_samples = tuple(
                        fourier_filter(res_sample, threshold=1, scale=self.freeu.s2)
                        for res_sample in res_samples
                    )

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
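
As a rough usage sketch (outside the diff), the config-based design in this revision could presumably be toggled as below; it assumes the UNet exposes the UNetFreeUConfig instance as unet.freeu, as shown above, and later commits in this PR move to explicit enable/disable helpers.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Turn FreeU on and apply the SD 2.1 preset defined in UNetFreeUConfig.sd21()
# (s1=0.9, s2=0.2, b1=1.1, b2=1.2). The attribute path is an assumption based on this diff.
pipe.unet.freeu.enabled = True
pipe.unet.freeu.sd21()

image = pipe("an astronaut riding a horse on mars").images[0]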