Commit b768795 (huggingface/diffusers PR #5164)
https://github.com/huggingface/diffusers/pull/5164
camenduru committed Sep 24, 2023
1 parent 48664d6 commit b768795
Showing 1 changed file with 66 additions and 1 deletion.
67 changes: 66 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -52,6 +52,54 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@dataclass
class UNetFreeUConfig:
    # FreeU scaling factors: b1/b2 boost the backbone features, s1/s2 attenuate
    # the skip-connection features (applied through fourier_filter below).
    enabled: bool = False
    s1: float = 1.0
    s2: float = 1.0
    b1: float = 1.0
    b2: float = 1.0

    def sd21(self):
        # Preset for Stable Diffusion 2.1, following the FreeU paper's recommendations.
        self.s1 = 0.9
        self.s2 = 0.2
        self.b1 = 1.1
        self.b2 = 1.2

    def ones(self):
        # Neutral values: FreeU becomes a no-op even when enabled.
        self.s1 = 1.0
        self.s2 = 1.0
        self.b1 = 1.0
        self.b2 = 1.0
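
A minimal usage sketch for this config (the variable name below is illustrative, not part of the diff):

    freeu_cfg = UNetFreeUConfig(enabled=True)
    freeu_cfg.sd21()    # apply the SD 2.1 preset: s1=0.9, s2=0.2, b1=1.1, b2=1.2
    # freeu_cfg.ones()  # or reset every factor to the neutral 1.0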

def fourier_filter(x_in, threshold, scale):
    """Scale the centered low-frequency band of `x_in`'s 2D spectrum by `scale`."""
    import torch
    from torch.fft import fftn, fftshift, ifftn, ifftshift

    x = x_in
    B, C, H, W = x.shape

    # Half-precision FFTs only support power-of-2 sizes, so upcast otherwise.
    if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
        x = x.to(dtype=torch.float32)

    # FFT, then shift the zero-frequency component to the spectrum center.
    x_freq = fftn(x, dim=(-2, -1))
    x_freq = fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)

    # Scale a (2*threshold) x (2*threshold) window around the spectrum center.
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT back to the spatial domain.
    x_freq = ifftshift(x_freq, dim=(-2, -1))
    x_filtered = ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.to(dtype=x_in.dtype)
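
A quick smoke test for fourier_filter (illustrative, not part of the diff); it sticks to float32 since half-precision FFTs generally require CUDA:

    import torch

    x = torch.randn(1, 4, 64, 64)                  # latent-like batch, float32
    y = fourier_filter(x, threshold=1, scale=0.9)  # damp the 2x2 low-frequency center
    assert y.shape == x.shape and y.dtype == x.dtype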


@dataclass
class UNet2DConditionOutput(BaseOutput):
@@ -212,11 +260,12 @@ def __init__(
    mid_block_only_cross_attention: Optional[bool] = None,
    cross_attention_norm: Optional[str] = None,
    addition_embed_type_num_heads=64,
    freeu: Optional[UNetFreeUConfig] = None,
):
    super().__init__()

    self.sample_size = sample_size

    # Avoid a shared mutable default argument: each instance gets its own config.
    self.freeu = freeu if freeu is not None else UNetFreeUConfig()
    if num_attention_heads is not None:
        raise ValueError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
@@ -1008,6 +1057,22 @@ def forward(
    res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
    down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

    # FreeU: Free Lunch in Diffusion U-Net (https://arxiv.org/pdf/2309.11497.pdf)
    if self.freeu.enabled:
        # First decoder stage of the SD UNet (1280 channels): boost the first
        # half of the backbone channels with b1, damp the skips with s1.
        if sample.shape[1] == 1280:
            sample[:, :640] *= self.freeu.b1
            res_samples = tuple(
                fourier_filter(res_sample, threshold=1, scale=self.freeu.s1)
                for res_sample in res_samples
            )

        # Second decoder stage (640 channels): same treatment with b2/s2.
        if sample.shape[1] == 640:
            sample[:, :320] *= self.freeu.b2
            res_samples = tuple(
                fourier_filter(res_sample, threshold=1, scale=self.freeu.s2)
                for res_sample in res_samples
            )

    # If we have not reached the final block and need to forward the
    # upsample size, we do it here.
    if not is_final_block and forward_upsample_size:
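To exercise the patched model end to end, a sketch assuming a standard diffusers pipeline (the prompt and checkpoint are illustrative; the freeu attribute is the one this diff adds to UNet2DConditionModel):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
    ).to("cuda")
    pipe.unet.freeu.enabled = True  # turn on the FreeU branch in forward()
    pipe.unet.freeu.sd21()          # apply the SD 2.1 preset
    image = pipe("an astronaut riding a horse on the moon").images[0]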
