From bbcba8d8b20fae37b6a75c373c7e257345016dd6 Mon Sep 17 00:00:00 2001
From: William Berman
Date: Sat, 8 Apr 2023 18:12:46 -0700
Subject: [PATCH] resnet skip time activation and output scale factor

---
 src/diffusers/models/resnet.py                 |  6 ++++-
 src/diffusers/models/unet_2d_blocks.py         | 27 +++++++++++++++++++
 src/diffusers/models/unet_2d_condition.py      |  7 +++++
 .../versatile_diffusion/modeling_text_unet.py  | 10 +++++++
 4 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 98f8f19c896a..d9d539959c09 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -459,6 +459,7 @@ def __init__(
         pre_norm=True,
         eps=1e-6,
         non_linearity="swish",
+        skip_time_act=False,
         time_embedding_norm="default",  # default, scale_shift, ada_group
         kernel=None,
         output_scale_factor=1.0,
@@ -479,6 +480,7 @@ def __init__(
         self.down = down
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
+        self.skip_time_act = skip_time_act

         if groups_out is None:
             groups_out = groups
@@ -570,7 +572,9 @@ def forward(self, input_tensor, temb):
         hidden_states = self.conv1(hidden_states)

         if self.time_emb_proj is not None:
-            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+            if not self.skip_time_act:
+                temb = self.nonlinearity(temb)
+            temb = self.time_emb_proj(temb)[:, :, None, None]

         if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 3070351279b8..0aeca6f508d0 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -42,6 +42,8 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -68,6 +70,8 @@
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "AttnDownBlock2D":
         return AttnDownBlock2D(
@@ -119,6 +123,8 @@
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -214,6 +220,8 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -241,6 +249,8 @@
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif up_block_type == "CrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -279,6 +289,8 @@
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            output_scale_factor=resnet_out_scale_factor,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -562,6 +574,7 @@ def __init__(
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        skip_time_act=False,
     ):
         super().__init__()

@@ -585,6 +598,7 @@ def __init__(
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         ]
         attentions = []
@@ -615,6 +629,7 @@ def __init__(
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )

@@ -1247,6 +1262,7 @@ def __init__(
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_downsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -1265,6 +1281,7 @@ def __init__(
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )

@@ -1284,6 +1301,7 @@ def __init__(
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         down=True,
                     )
                 ]
@@ -1337,6 +1355,7 @@ def __init__(
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_downsample=True,
+        skip_time_act=False,
     ):
         super().__init__()

@@ -1362,6 +1381,7 @@ def __init__(
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
             attentions.append(
@@ -1394,6 +1414,7 @@ def __init__(
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         down=True,
                     )
                 ]
@@ -2237,6 +2258,7 @@ def __init__(
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_upsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -2257,6 +2279,7 @@ def __init__(
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )

@@ -2276,6 +2299,7 @@ def __init__(
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         up=True,
                     )
                 ]
@@ -2329,6 +2353,7 @@ def __init__(
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
+        skip_time_act=False,
     ):
         super().__init__()
         resnets = []
@@ -2355,6 +2380,7 @@ def __init__(
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
             attentions.append(
@@ -2387,6 +2413,7 @@ def __init__(
                         non_linearity=resnet_act_fn,
                         output_scale_factor=output_scale_factor,
                         pre_norm=resnet_pre_norm,
+                        skip_time_act=skip_time_act,
                         up=True,
                     )
                 ]
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 4d237286fb32..263304cf5454 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -146,6 +146,8 @@
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -291,6 +293,8 @@
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.down_blocks.append(down_block)

@@ -321,6 +325,7 @@
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -369,6 +374,8 @@
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index deaa709ab319..a2e85043f971 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -232,6 +232,8 @@
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: int = 1.0,
         time_embedding_type: str = "positional",
         timestep_post_act: Optional[str] = None,
         time_cond_proj_dim: Optional[int] = None,
@@ -382,6 +384,8 @@
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.down_blocks.append(down_block)

@@ -412,6 +416,7 @@
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -460,6 +465,8 @@
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1434,6 +1441,7 @@ def __init__(
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
+        skip_time_act=False,
     ):
         super().__init__()

@@ -1457,6 +1465,7 @@ def __init__(
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
+                skip_time_act=skip_time_act,
             )
         ]
         attentions = []
@@ -1487,6 +1496,7 @@ def __init__(
                     non_linearity=resnet_act_fn,
                     output_scale_factor=output_scale_factor,
                     pre_norm=resnet_pre_norm,
+                    skip_time_act=skip_time_act,
                 )
             )
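
Reviewer note, not part of the patch: a minimal usage sketch of the two new config options. The block types below are the ones this change actually wires the flags into (ResnetDownsampleBlock2D / SimpleCrossAttn blocks and the simple cross-attention mid block); the channel counts, sample size, and other values are made-up toy numbers for illustration only.

    # Illustrative sketch only: a tiny UNet2DConditionModel that forwards the new
    # kwargs down to its resnet blocks. Everything except resnet_skip_time_act and
    # resnet_out_scale_factor is an assumed toy configuration, not something this
    # patch adds or requires.
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        block_out_channels=(32, 64),
        down_block_types=("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
        mid_block_type="UNetMidBlock2DSimpleCrossAttn",
        up_block_types=("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
        cross_attention_dim=32,
        attention_head_dim=8,
        # New in this patch: feed temb straight into time_emb_proj (no nonlinearity first)
        # and forward an output scale factor to every resnet block.
        resnet_skip_time_act=True,
        resnet_out_scale_factor=1.0,
    )

With resnet_skip_time_act=True, ResnetBlock2D applies time_emb_proj directly to the time embedding instead of self.nonlinearity(temb) first; resnet_out_scale_factor is passed through as output_scale_factor, which divides each resnet block's output, so 1.0 keeps the previous behaviour.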