src/diffusers/models/resnet.py (5 additions, 1 deletion)
@@ -459,6 +459,7 @@ def __init__(
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
+        skip_time_act=False,
         time_embedding_norm="default",  # default, scale_shift, ada_group
         kernel=None,
         output_scale_factor=1.0,
@@ -479,6 +480,7 @@ def __init__(
         self.down = down
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
+        self.skip_time_act = skip_time_act

         if groups_out is None:
             groups_out = groups
@@ -570,7 +572,9 @@ def forward(self, input_tensor, temb):
         hidden_states = self.conv1(hidden_states)

         if self.time_emb_proj is not None:
-            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+            if not self.skip_time_act:
+                temb = self.nonlinearity(temb)
+            temb = self.time_emb_proj(temb)[:, :, None, None]

         if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb
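The resnet.py hunk above is the behavioral core of this PR: when `skip_time_act` is set, the time embedding goes into `time_emb_proj` without first passing through the block's activation, which matters when the upstream `TimestepEmbedding` already ends in an activation of its own (cf. the `timestep_post_act` config argument visible below). A minimal sketch of the new control flow, with stand-in modules and made-up sizes:

```python
import torch
import torch.nn as nn

nonlinearity = nn.SiLU()              # stands in for self.nonlinearity
time_emb_proj = nn.Linear(1280, 640)  # stands in for self.time_emb_proj

def project_time_embedding(temb: torch.Tensor, skip_time_act: bool) -> torch.Tensor:
    # New behavior: the activation is optional instead of unconditional.
    if not skip_time_act:
        temb = nonlinearity(temb)
    # Project, then add trailing singleton dims so temb broadcasts over H x W.
    return time_emb_proj(temb)[:, :, None, None]

temb = torch.randn(2, 1280)
print(project_time_embedding(temb, skip_time_act=True).shape)  # torch.Size([2, 640, 1, 1])
```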
src/diffusers/models/unet_2d_blocks.py (27 additions)
@@ -42,6 +42,8 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -68,6 +70,8 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "AttnDownBlock2D":
         return AttnDownBlock2D(
@@ -119,6 +123,8 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -214,6 +220,8 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -241,6 +249,8 @@
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif up_block_type == "CrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -279,6 +289,8 @@
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -562,6 +574,7 @@ def __init__(
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        skip_time_act=False,
     ):
         super().__init__()

@@ -585,6 +598,7 @@ def __init__(
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         ]
         attentions = []
@@ -615,6 +629,7 @@ def __init__(
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )

@@ -1247,6 +1262,7 @@ def __init__(
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_downsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -1265,6 +1281,7 @@
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )

@@ -1284,6 +1301,7 @@
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         down=True,
                     )
                 ]
@@ -1337,6 +1355,7 @@
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_downsample=True,
+        skip_time_act=False,
     ):
         super().__init__()

@@ -1362,6 +1381,7 @@
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
             attentions.append(
@@ -1394,6 +1414,7 @@
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         down=True,
                     )
                 ]
@@ -2237,6 +2258,7 @@ def __init__(
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_upsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -2257,6 +2279,7 @@
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )

@@ -2276,6 +2299,7 @@
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         up=True,
                     )
                 ]
@@ -2329,6 +2353,7 @@
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -2355,6 +2380,7 @@
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
             attentions.append(
@@ -2387,6 +2413,7 @@
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         up=True,
                     )
                 ]
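In unet_2d_blocks.py the change is plumbing: the `get_down_block` / `get_up_block` factories gain `resnet_skip_time_act` and `resnet_out_scale_factor` and forward them to the resnet-style blocks as `skip_time_act` / `output_scale_factor`. A hedged usage sketch of the factory after this change; all sizes are illustrative, and only the block type and argument names come from the diff:

```python
from diffusers.models.unet_2d_blocks import get_down_block

block = get_down_block(
    "ResnetDownsampleBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=None,   # unused by this resnet-only block
    resnet_groups=32,
    resnet_skip_time_act=True,     # forwarded as skip_time_act=True to each ResnetBlock2D
    resnet_out_scale_factor=1.0,   # forwarded as output_scale_factor
)
```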
src/diffusers/models/unet_2d_condition.py (7 additions)
@@ -146,6 +146,8 @@ def __init__(
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: float = 1.0,
         time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -291,6 +293,8 @@ def __init__(
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.down_blocks.append(down_block)

@@ -321,6 +325,7 @@ def __init__(
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -369,6 +374,8 @@ def __init__(
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
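In unet_2d_condition.py the two knobs become model-level config arguments, so they are serialized and restored with the rest of the UNet config. A hedged construction sketch (the block layout and sizes are invented for illustration, not taken from any released checkpoint):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=64,
    block_out_channels=(320, 640),
    down_block_types=("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
    up_block_types=("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
    mid_block_type="UNetMidBlock2DSimpleCrossAttn",
    cross_attention_dim=768,
    resnet_skip_time_act=True,    # skip the activation before time_emb_proj in the resnets
    resnet_out_scale_factor=1.0,  # divisor applied to resnet outputs in the down/up blocks
)
```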
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (10 additions)
@@ -232,6 +232,8 @@ def __init__(
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: float = 1.0,
         time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -382,6 +384,8 @@ def __init__(
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.down_blocks.append(down_block)

@@ -412,6 +416,7 @@ def __init__(
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -460,6 +465,8 @@ def __init__(
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1434,6 +1441,7 @@ def __init__(
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        skip_time_act=False,
     ):
         super().__init__()

@@ -1457,6 +1465,7 @@ def __init__(
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         ]
         attentions = []
@@ -1487,6 +1496,7 @@ def __init__(
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )

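The modeling_text_unet.py hunks mirror the unet_2d_condition.py ones line for line: the Versatile Diffusion text UNet maintains flat copies of the 2D blocks that are kept in sync with the originals via the repository's `# Copied from` convention, so the same two config arguments and the `skip_time_act` plumbing have to be threaded through here as well.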