✨ [Core] Add FreeU mechanism #5164

Merged
merged 63 commits into main from add-freeU on Oct 5, 2023

Changes from 16 commits
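This PR adds FreeU (see https://github.com/ChenyangSi/FreeU), a training-free mechanism that re-weights the UNet decoder at inference time: backbone features are scaled by the factors `b1`/`b2`, while the low-frequency components of the skip-connection features are damped in Fourier space by `s1`/`s2`. As a rough usage sketch of the resulting API (the pipeline-level wrapper is not part of the diff shown below, and the parameter values are the ones recommended by the FreeU authors for SD 1.x, so treat both as assumptions rather than as this PR's final interface):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# FreeU values recommended for SD 1.x (assumption; check the FreeU repo).
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
image = pipe("an astronaut riding a horse on mars").images[0]

pipe.disable_freeu()  # restores the vanilla UNet forward pass
```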
Commits (63)
eec915a
✨ Added Fourier filter function to upsample blocks
kadirnar Sep 24, 2023
8f2ee6b
🔧 Update Fourier_filter for float16 support
kadirnar Sep 24, 2023
ea33fc5
✨ Added UNetFreeUConfig to UNet model for FreeU adaptation 🛠️
kadirnar Sep 24, 2023
6d77fff
move unet to its original form and add fourier_filter to torch_utils.
sayakpaul Sep 27, 2023
7c4676b
implement freeU enable mechanism
sayakpaul Sep 27, 2023
a152d91
implement disable mechanism
sayakpaul Sep 27, 2023
c0de18d
resolution index.
sayakpaul Sep 27, 2023
a0ced0e
correct resolution idx condition.
sayakpaul Sep 27, 2023
19d5ab0
fix copies.
sayakpaul Sep 27, 2023
f3f9441
no need to use resolution_idx in vae.
sayakpaul Sep 27, 2023
266987c
spell out the kwargs
sayakpaul Sep 27, 2023
83846f5
proper config property
sayakpaul Sep 27, 2023
cb6ebe8
fix attribution setting
sayakpaul Sep 27, 2023
9c8eacd
place unet hasattr properly.
sayakpaul Sep 27, 2023
988ef76
fix: attribute access.
sayakpaul Sep 27, 2023
755a5c8
proper disable
sayakpaul Sep 27, 2023
37b091e
remove validation method.
sayakpaul Sep 27, 2023
4fd4adf
debug
sayakpaul Sep 27, 2023
40291ba
debug
sayakpaul Sep 27, 2023
2e15e94
debug
sayakpaul Sep 27, 2023
1e9c79a
debug
sayakpaul Sep 27, 2023
883fc9b
debug
sayakpaul Sep 27, 2023
ba08f30
debug
sayakpaul Sep 27, 2023
540974a
potential fix.
sayakpaul Sep 27, 2023
d38c251
add: doc.
sayakpaul Sep 27, 2023
785c0a0
fix copies
sayakpaul Sep 27, 2023
466e054
add: tests.
sayakpaul Sep 27, 2023
80e560e
add: support freeU in SDXL.
sayakpaul Sep 28, 2023
670b34b
set default value of resolution idx.
sayakpaul Sep 28, 2023
c5fc938
set default values for resolution_idx.
sayakpaul Sep 28, 2023
05ea56e
fix copies
sayakpaul Sep 28, 2023
93ee867
fix rest.
sayakpaul Sep 28, 2023
535eb59
fix copies
sayakpaul Sep 28, 2023
d277e64
address PR comments.
sayakpaul Sep 28, 2023
bb2d368
run fix-copies
sayakpaul Sep 28, 2023
e95b186
move apply_free_u to utils and other minors.
sayakpaul Sep 30, 2023
902cf7d
introduce support for video (unet3D)
sayakpaul Sep 30, 2023
15b1052
minor ups
sayakpaul Oct 2, 2023
7dcc939
consistent fix-copies.
sayakpaul Oct 2, 2023
3bf28bb
consistent stuff
sayakpaul Oct 2, 2023
64ade67
Merge branch 'main' into add-freeU
sayakpaul Oct 2, 2023
08d61fa
Merge branch 'main' into add-freeU
sayakpaul Oct 2, 2023
d68663c
fix-copies
sayakpaul Oct 2, 2023
0e0af08
add: rest
sayakpaul Oct 2, 2023
8d0a204
add: docs.
sayakpaul Oct 2, 2023
55ad535
fix: tests
sayakpaul Oct 2, 2023
797b4b9
fix: doc path
sayakpaul Oct 2, 2023
86419b9
Merge branch 'main' into add-freeU
sayakpaul Oct 2, 2023
5e27ff7
Merge branch 'main' into add-freeU
patrickvonplaten Oct 4, 2023
2ad4953
Merge branch 'main' into add-freeU
sayakpaul Oct 5, 2023
1a8e5d1
Apply suggestions from code review
sayakpaul Oct 5, 2023
adc9d5c
style up
sayakpaul Oct 5, 2023
c4f99d4
move to techniques.
sayakpaul Oct 5, 2023
518e4b1
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
dee3781
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
fc39d22
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
053f3ed
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
d8ef3a1
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
3da96e2
add: slow test for sd freeu.
sayakpaul Oct 5, 2023
d8c8771
add: slow test for video with freeu
sayakpaul Oct 5, 2023
0d34cf0
add: slow test for video with freeu
sayakpaul Oct 5, 2023
8e72f85
add: slow test for video with freeu
sayakpaul Oct 5, 2023
aa1a061
style
sayakpaul Oct 5, 2023
86 changes: 86 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
@@ -19,6 +19,7 @@
from torch import nn

from ..utils import is_torch_version, logging
from ..utils.torch_utils import fourier_filter
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -249,6 +250,7 @@ def get_up_block(
add_upsample,
resnet_eps,
resnet_act_fn,
resolution_idx,
transformer_layers_per_block=1,
num_attention_heads=None,
resnet_groups=None,
@@ -281,6 +283,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -295,6 +298,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -314,6 +318,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -337,6 +342,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -362,6 +368,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
@@ -377,6 +384,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -390,6 +398,7 @@ def get_up_block(
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -402,6 +411,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -415,6 +425,7 @@ def get_up_block(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -430,6 +441,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -441,6 +453,7 @@ def get_up_block(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resolution_idx=resolution_idx,
dropout=dropout,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
@@ -1993,6 +2006,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2075,6 +2089,8 @@ def __init__(
else:
self.upsamplers = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
@@ -2103,6 +2119,7 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
@@ -2181,6 +2198,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(
self,
@@ -2194,11 +2212,34 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)

for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]

# Courtesy:
# https://github.com/ChenyangSi/FreeU
# https://github.com/lyn-rgb/FreeU_Diffusers
            if is_freeu_enabled:
# --------------- FreeU code -----------------------
# Only operate on the first two stages
if self.resolution_idx == 0:
num_half_channels = hidden_states.shape[1] // 2
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * self.b1
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
if self.resolution_idx == 1:
num_half_channels = hidden_states.shape[1] // 2
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * self.b2
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
# ---------------------------------------------------------

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if self.training and self.gradient_checkpointing:
@@ -2252,6 +2293,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2292,12 +2334,36 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)

for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]

# Courtesy:
# https://github.com/ChenyangSi/FreeU
# https://github.com/lyn-rgb/FreeU_Diffusers
if is_freeu_enabled:
# --------------- FreeU code -----------------------
# Only operate on the first two stages
if self.resolution_idx == 0:
                    num_half_channels = hidden_states.shape[1] // 2
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * self.b1
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
if self.resolution_idx == 1:
                    num_half_channels = hidden_states.shape[1] // 2
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * self.b2
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=self.s2)

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

if self.training and self.gradient_checkpointing:
@@ -2331,6 +2397,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2370,6 +2437,8 @@ def __init__(
else:
self.upsamplers = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2386,6 +2455,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2449,6 +2519,8 @@ def __init__(
else:
self.upsamplers = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2469,6 +2541,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2553,6 +2626,8 @@ def __init__(
self.skip_norm = None
self.act = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
@@ -2589,6 +2664,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2651,6 +2727,8 @@ def __init__(
self.skip_norm = None
self.act = None

self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
@@ -2684,6 +2762,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2743,6 +2822,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet in self.resnets:
@@ -2784,6 +2864,7 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2873,6 +2954,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(
self,
@@ -2947,6 +3029,7 @@ def __init__(
in_channels: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 5,
resnet_eps: float = 1e-5,
@@ -2988,6 +3071,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
@@ -3027,6 +3111,7 @@ def __init__(
in_channels: int,
out_channels: int,
temb_channels: int,
resolution_idx: int,
dropout: float = 0.0,
num_layers: int = 4,
resnet_eps: float = 1e-5,
@@ -3104,6 +3189,7 @@ def __init__(
self.upsamplers = None

self.gradient_checkpointing = False
self.resolution_idx = resolution_idx

def forward(
self,
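The `fourier_filter` helper imported at the top of this file lives in `..utils.torch_utils` and its body is not part of this diff. A minimal sketch of the operation the FreeU code above relies on, assuming the behavior of the reference implementations linked in the comments (FFT the skip features, scale the low-frequency box at the center of the shifted spectrum by `scale`, then invert):

```python
import torch
import torch.fft as fft


def fourier_filter(x_in: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
    """Scale the low-frequency components of a (B, C, H, W) feature map by `scale`."""
    x = x_in
    _, _, H, W = x.shape

    # Half-precision FFT only supports power-of-two spatial sizes, so upcast
    # otherwise (this mirrors the "float16 support" fix in the commit history above).
    if (H & (H - 1)) != 0 or (W & (W - 1)) != 0:
        x = x.to(dtype=torch.float32)

    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))  # move the DC component to the center

    # A box of side 2 * threshold around the center selects the low frequencies.
    mask = torch.ones_like(x_freq.real)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask

    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    return fft.ifftn(x_freq, dim=(-2, -1)).real.to(dtype=x_in.dtype)
```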
15 changes: 15 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -542,6 +542,7 @@ def __init__(
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
@@ -731,6 +732,20 @@ def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value

def enable_freeu(self, **kwargs):
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "b1", kwargs["b1"])
setattr(upsample_block, "b2", kwargs["b2"])
setattr(upsample_block, "s1", kwargs["s1"])
setattr(upsample_block, "s2", kwargs["s2"])

def disable_freeu(self):
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    setattr(upsample_block, k, None)

def forward(
self,
sample: torch.FloatTensor,
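Because `enable_freeu` stores the four scalars directly on every up block, and the block `forward` methods above only check for them with `getattr`, FreeU can also be toggled on a bare UNet. A small sketch under that assumption (the checkpoint name is illustrative):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

unet.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2)
assert unet.up_blocks[0].b1 == 1.2  # the scalars now live on each up block

unet.disable_freeu()
assert unet.up_blocks[0].b1 is None  # blocks fall back to the vanilla path
```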
1 change: 1 addition & 0 deletions src/diffusers/models/vae.py
@@ -213,6 +213,7 @@ def __init__(
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
resolution_idx=None,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel