From 4678f9c1c6ff1634bf4d41bd53a107a77e0efdf3 Mon Sep 17 00:00:00 2001 From: Mathis Koroglu Date: Tue, 21 May 2024 11:21:45 +0200 Subject: [PATCH 1/3] Motion Model / Adapter versatility - allow to use a different number of layers per block - allow to use a different number of transformer per layers per block - allow a different number of motion attention head per block - use dropout argument in get_down/up_block in 3d blocks --- src/diffusers/models/unets/unet_3d_blocks.py | 127 +++++++++++-- .../models/unets/unet_motion_model.py | 169 +++++++++++++++--- 2 files changed, 259 insertions(+), 37 deletions(-) diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 35a732bdb9ec..0c3e3abe2087 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -58,7 +58,9 @@ def get_down_block( resnet_time_scale_shift: str = "default", temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, + dropout: float = 0.0, ) -> Union[ "DownBlock3D", "CrossAttnDownBlock3D", @@ -79,6 +81,7 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, + dropout=dropout, ) elif down_block_type == "CrossAttnDownBlock3D": if cross_attention_dim is None: @@ -100,6 +103,7 @@ def get_down_block( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, + dropout=dropout, ) if down_block_type == "DownBlockMotion": return DownBlockMotion( @@ -115,6 +119,8 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, + dropout=dropout, ) elif down_block_type == "CrossAttnDownBlockMotion": if cross_attention_dim is None: @@ -139,6 +145,8 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, + dropout=dropout, ) elif down_block_type == "DownBlockSpatioTemporal": # added for SDV @@ -189,7 +197,8 @@ def get_up_block( temporal_num_attention_heads: int = 8, temporal_cross_attention_dim: Optional[int] = None, temporal_max_seq_length: int = 32, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, dropout: float = 0.0, ) -> Union[ "UpBlock3D", @@ -212,6 +221,7 @@ def get_up_block( resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, + dropout=dropout, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: @@ -234,6 +244,7 @@ def get_up_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, + dropout=dropout, ) if up_block_type == "UpBlockMotion": return UpBlockMotion( @@ -250,6 +261,8 @@ def get_up_block( resolution_idx=resolution_idx, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, + 
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, + dropout=dropout, ) elif up_block_type == "CrossAttnUpBlockMotion": if cross_attention_dim is None: @@ -275,6 +288,8 @@ def get_up_block( resolution_idx=resolution_idx, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, + dropout=dropout, ) elif up_block_type == "UpBlockSpatioTemporal": # added for SDV @@ -948,14 +963,31 @@ def __init__( output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, - temporal_num_attention_heads: int = 1, + temporal_num_attention_heads: Union[int, Tuple[int]] = 1, temporal_cross_attention_dim: Optional[int] = None, temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, ): super().__init__() resnets = [] motion_modules = [] + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" + ) + + # support for variable number of attention head per temporal layers + if isinstance(temporal_num_attention_heads, int): + temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers + elif len(temporal_num_attention_heads) != num_layers: + raise ValueError( + f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" + ) + for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( @@ -974,15 +1006,16 @@ def __init__( ) motion_modules.append( TransformerTemporalModel( - num_attention_heads=temporal_num_attention_heads, + num_attention_heads=temporal_num_attention_heads[i], in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=False, activation_fn="geglu", positional_embeddings="sinusoidal", num_positional_embeddings=temporal_max_seq_length, - attention_head_dim=out_channels // temporal_num_attention_heads, + attention_head_dim=out_channels // temporal_num_attention_heads[i], ) ) @@ -1065,7 +1098,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1084,6 +1117,7 @@ def __init__( temporal_cross_attention_dim: Optional[int] = None, temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, ): super().__init__() resnets = [] @@ -1093,6 +1127,22 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + # support for 
variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( @@ -1116,7 +1166,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1141,6 +1191,7 @@ def __init__( TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=False, @@ -1257,7 +1308,7 @@ def __init__( resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1275,6 +1326,7 @@ def __init__( temporal_cross_attention_dim: Optional[int] = None, temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, ): super().__init__() resnets = [] @@ -1284,6 +1336,22 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" + ) + for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -1309,7 +1377,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1333,6 +1401,7 @@ def __init__( TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=False, @@ -1467,11 +1536,20 @@ def __init__( temporal_cross_attention_dim: 
Optional[int] = None, temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, ): super().__init__() resnets = [] motion_modules = [] + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -1495,6 +1573,7 @@ def __init__( TransformerTemporalModel( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], norm_num_groups=temporal_norm_num_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=False, @@ -1596,7 +1675,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1605,13 +1684,14 @@ def __init__( num_attention_heads: int = 1, output_scale_factor: float = 1.0, cross_attention_dim: int = 1280, - dual_cross_attention: float = False, - use_linear_projection: float = False, - upcast_attention: float = False, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, attention_type: str = "default", temporal_num_attention_heads: int = 1, temporal_cross_attention_dim: Optional[int] = None, temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, ): super().__init__() @@ -1619,6 +1699,22 @@ def __init__( self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." 
+ ) + # there is always at least one resnet resnets = [ ResnetBlock2D( @@ -1637,14 +1733,14 @@ def __init__( attentions = [] motion_modules = [] - for _ in range(num_layers): + for i in range(num_layers): if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1682,6 +1778,7 @@ def __init__( num_attention_heads=temporal_num_attention_heads, attention_head_dim=in_channels // temporal_num_attention_heads, in_channels=in_channels, + num_layers=temporal_transformer_layers_per_block[i], norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=False, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index b224d9d73317..a51e51b60ad6 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -57,7 +57,8 @@ def __init__( self, in_channels: int, layers_per_block: int = 2, - num_attention_heads: int = 8, + transformer_layers_per_block: int = 8, + num_attention_heads: Union[int, Tuple[int]] = 8, attention_bias: bool = False, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", @@ -67,10 +68,19 @@ def __init__( super().__init__() self.motion_modules = nn.ModuleList([]) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block + elif len(transformer_layers_per_block) != layers_per_block: + raise ValueError( + f"The number of transformer layers per block must match the number of layers per block, " + f"got {layers_per_block} and {len(transformer_layers_per_block)}" + ) + for i in range(layers_per_block): self.motion_modules.append( TransformerTemporalModel( in_channels=in_channels, + num_layers=transformer_layers_per_block[i], norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, @@ -88,9 +98,11 @@ class MotionAdapter(ModelMixin, ConfigMixin): def __init__( self, block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), - motion_layers_per_block: int = 2, + motion_layers_per_block: Union[int, Tuple[int]] = 2, + motion_transformer_per_layers: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, motion_mid_block_layers_per_block: int = 1, - motion_num_attention_heads: int = 8, + motion_transformer_per_mid_layers: Union[int, Tuple[int]] = 1, + motion_num_attention_heads: Union[int, Tuple[int]] = 8, motion_norm_num_groups: int = 32, motion_max_seq_length: int = 32, use_motion_mid_block: bool = True, @@ -101,11 +113,15 @@ def __init__( Args: block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each UNet block. - motion_layers_per_block (`int`, *optional*, defaults to 2): + motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2): The number of motion layers per UNet block. + motion_transformer_per_layers (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1): + The number of transformer layers to use in each motion layer in each block. motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): The number of motion layers in the middle UNet block. 
- motion_num_attention_heads (`int`, *optional*, defaults to 8): + motion_transformer_per_mid_layers (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer layers to use in each motion layer in the middle block. + motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8): The number of heads to use in each attention layer of the motion module. motion_norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use in each group normalization layer of the motion module. @@ -119,6 +135,33 @@ def __init__( down_blocks = [] up_blocks = [] + if isinstance(motion_layers_per_block, int): + motion_layers_per_block = [motion_layers_per_block] * len(block_out_channels) + elif len(motion_layers_per_block) != len(block_out_channels): + raise ValueError( + f"The number of motion layers per block must match the number of blocks, " + f"got {len(block_out_channels)} and {len(motion_layers_per_block)}" + ) + + if isinstance(motion_transformer_per_layers, int): + motion_transformer_per_layers = [motion_transformer_per_layers] * len(block_out_channels) + + if isinstance(motion_transformer_per_mid_layers, int): + motion_transformer_per_mid_layers = [motion_transformer_per_mid_layers] * motion_mid_block_layers_per_block + elif len(motion_transformer_per_mid_layers) != motion_mid_block_layers_per_block: + raise ValueError( + f"The number of layers per mid block ({motion_mid_block_layers_per_block}) " + f"must match the length of motion_transformer_per_mid_layers ({len(motion_transformer_per_mid_layers)})" + ) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = [motion_num_attention_heads] * len(block_out_channels) + elif len(motion_num_attention_heads) != len(block_out_channels): + raise ValueError( + f"The length of the attention head number tuple in the motion module must match the " + f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}" + ) + if conv_in_channels: # input self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1) @@ -134,9 +177,10 @@ def __init__( cross_attention_dim=None, activation_fn="geglu", attention_bias=False, - num_attention_heads=motion_num_attention_heads, + num_attention_heads=motion_num_attention_heads[i], max_seq_length=motion_max_seq_length, - layers_per_block=motion_layers_per_block, + layers_per_block=motion_layers_per_block[i], + transformer_layers_per_block=motion_transformer_per_layers[i], ) ) @@ -147,15 +191,20 @@ def __init__( cross_attention_dim=None, activation_fn="geglu", attention_bias=False, - num_attention_heads=motion_num_attention_heads, - layers_per_block=motion_mid_block_layers_per_block, + num_attention_heads=motion_num_attention_heads[-1], max_seq_length=motion_max_seq_length, + layers_per_block=motion_mid_block_layers_per_block, + transformer_layers_per_block=motion_transformer_per_mid_layers, ) else: self.mid_block = None reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] + + reversed_motion_layers_per_block = list(reversed(motion_layers_per_block)) + reversed_motion_motion_transformer_per_layers = list(reversed(motion_transformer_per_layers)) + reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) for i, channel in enumerate(reversed_block_out_channels): output_channel = reversed_block_out_channels[i] up_blocks.append( @@ -165,9 +214,10 @@ def __init__( cross_attention_dim=None, activation_fn="geglu", 
attention_bias=False, - num_attention_heads=motion_num_attention_heads, + num_attention_heads=reversed_motion_num_attention_heads[i], max_seq_length=motion_max_seq_length, - layers_per_block=motion_layers_per_block + 1, + layers_per_block=reversed_motion_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_motion_motion_transformer_per_layers[i], ) ) @@ -208,7 +258,7 @@ def __init__( "CrossAttnUpBlockMotion", ), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), - layers_per_block: int = 2, + layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", @@ -216,12 +266,18 @@ def __init__( norm_eps: float = 1e-5, cross_attention_dim: int = 1280, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None, + temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None, + transformer_layers_per_midblock: Optional[Union[int, Tuple[int]]] = None, + temporal_transformer_layers_per_midblock: Optional[Union[int, Tuple[int]]] = 1, use_linear_projection: bool = False, num_attention_heads: Union[int, Tuple[int, ...]] = 8, motion_max_seq_length: int = 32, - motion_num_attention_heads: int = 8, - use_motion_mid_block: int = True, + motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8, + reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None, + use_motion_mid_block: bool = True, + mid_block_layers: int = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, addition_embed_type: Optional[str] = None, @@ -264,6 +320,16 @@ def __init__( if isinstance(layer_number_per_block, list): raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + if ( + isinstance(temporal_transformer_layers_per_block, list) + and reverse_temporal_transformer_layers_per_block is None + ): + for layer_number_per_block in temporal_transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError( + "Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet." 
+ ) + # input conv_in_kernel = 3 conv_out_kernel = 3 @@ -304,6 +370,20 @@ def __init__( if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + if isinstance(reverse_transformer_layers_per_block, int): + reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types) + + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) + + if isinstance(reverse_temporal_transformer_layers_per_block, int): + reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len( + down_block_types + ) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -326,13 +406,19 @@ def __init__( downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, dual_cross_attention=False, - temporal_num_attention_heads=motion_num_attention_heads, + temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, transformer_layers_per_block=transformer_layers_per_block[i], + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], ) self.down_blocks.append(down_block) # mid + if transformer_layers_per_midblock is None: + transformer_layers_per_midblock = ( + transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 + ) + if use_motion_mid_block: self.mid_block = UNetMidBlockCrossAttnMotion( in_channels=block_out_channels[-1], @@ -345,9 +431,11 @@ def __init__( resnet_groups=norm_num_groups, dual_cross_attention=False, use_linear_projection=use_linear_projection, - temporal_num_attention_heads=motion_num_attention_heads, + num_layers=mid_block_layers, + temporal_num_attention_heads=motion_num_attention_heads[-1], temporal_max_seq_length=motion_max_seq_length, - transformer_layers_per_block=transformer_layers_per_block[-1], + transformer_layers_per_block=transformer_layers_per_midblock, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_midblock, ) else: @@ -362,7 +450,8 @@ def __init__( resnet_groups=norm_num_groups, dual_cross_attention=False, use_linear_projection=use_linear_projection, - transformer_layers_per_block=transformer_layers_per_block[-1], + num_layers=mid_block_layers, + transformer_layers_per_block=transformer_layers_per_midblock, ) # count how many layers upsample the images @@ -374,6 +463,8 @@ def __init__( reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + reversed_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block)) + reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): @@ -390,6 +481,16 @@ def __init__( else: add_upsample = False + if reverse_transformer_layers_per_block is not None: + curr_layer_transformer = reverse_transformer_layers_per_block[i] + else: + curr_layer_transformer = reversed_transformer_layers_per_block[i] + + if reverse_temporal_transformer_layers_per_block is not None: 
+ curr_layer_temporal_transformer = reverse_temporal_transformer_layers_per_block[i] + else: + curr_layer_temporal_transformer = reversed_temporal_transformer_layers_per_block[i] + up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, @@ -406,9 +507,10 @@ def __init__( dual_cross_attention=False, resolution_idx=i, use_linear_projection=use_linear_projection, - temporal_num_attention_heads=motion_num_attention_heads, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, - transformer_layers_per_block=reversed_transformer_layers_per_block[i], + transformer_layers_per_block=curr_layer_transformer, + temporal_transformer_layers_per_block=curr_layer_temporal_transformer, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -440,6 +542,24 @@ def from_unet2d( if has_motion_adapter: motion_adapter.to(device=unet.device) + # check compatibility of number of blocks + if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]): + raise ValueError("Incompatible Motion Adapter, got different number of blocks") + + # check layers compatibility for each block + if isinstance(unet.config["layers_per_block"], int): + expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"]) + else: + expanded_layers_per_block = list(unet.config["layers_per_block"]) + if isinstance(motion_adapter.config["motion_layers_per_block"], int): + expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len( + motion_adapter.config["block_out_channels"] + ) + else: + expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"]) + if expanded_layers_per_block != expanded_adapter_layers_per_block: + raise ValueError("Incompatible Motion Adapter, got different number of layers per block") + # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 config = dict(unet.config) config["_class_name"] = cls.__name__ @@ -458,13 +578,18 @@ def from_unet2d( up_blocks.append("CrossAttnUpBlockMotion") else: up_blocks.append("UpBlockMotion") - config["up_block_types"] = up_blocks if has_motion_adapter: config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"] + config["temporal_transformer_layers_per_midblock"] = motion_adapter.config[ + "motion_transformer_per_mid_layers" + ] + config["temporal_transformer_layers_per_block"] = motion_adapter.config["motion_transformer_per_layers"] + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] # For PIA UNets we need to set the number input channels to 9 if motion_adapter.config["conv_in_channels"]: From 9cbbc01b4824ee560f7ebb2d4f941a1b503a6889 Mon Sep 17 00:00:00 2001 From: Mathis Koroglu Date: Wed, 5 Jun 2024 11:43:55 +0200 Subject: [PATCH 2/3] Motion Model added arguments renamed & refactoring --- .../models/unets/unet_motion_model.py | 80 +++++++++---------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index a51e51b60ad6..3fd5a943a2f9 100644 --- 
a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -57,7 +57,7 @@ def __init__( self, in_channels: int, layers_per_block: int = 2, - transformer_layers_per_block: int = 8, + transformer_layers_per_block: Union[int, Tuple[int]] = 8, num_attention_heads: Union[int, Tuple[int]] = 8, attention_bias: bool = False, cross_attention_dim: Optional[int] = None, @@ -99,9 +99,9 @@ def __init__( self, block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), motion_layers_per_block: Union[int, Tuple[int]] = 2, - motion_transformer_per_layers: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, + motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, motion_mid_block_layers_per_block: int = 1, - motion_transformer_per_mid_layers: Union[int, Tuple[int]] = 1, + motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1, motion_num_attention_heads: Union[int, Tuple[int]] = 8, motion_norm_num_groups: int = 32, motion_max_seq_length: int = 32, @@ -115,11 +115,11 @@ def __init__( The tuple of output channels for each UNet block. motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2): The number of motion layers per UNet block. - motion_transformer_per_layers (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1): + motion_transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1): The number of transformer layers to use in each motion layer in each block. motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): The number of motion layers in the middle UNet block. - motion_transformer_per_mid_layers (`int` or `Tuple[int]`, *optional*, defaults to 1): + motion_transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer layers to use in each motion layer in the middle block. motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8): The number of heads to use in each attention layer of the motion module. 
@@ -136,26 +136,28 @@ def __init__( up_blocks = [] if isinstance(motion_layers_per_block, int): - motion_layers_per_block = [motion_layers_per_block] * len(block_out_channels) + motion_layers_per_block = (motion_layers_per_block,) * len(block_out_channels) elif len(motion_layers_per_block) != len(block_out_channels): raise ValueError( f"The number of motion layers per block must match the number of blocks, " f"got {len(block_out_channels)} and {len(motion_layers_per_block)}" ) - if isinstance(motion_transformer_per_layers, int): - motion_transformer_per_layers = [motion_transformer_per_layers] * len(block_out_channels) + if isinstance(motion_transformer_layers_per_block, int): + motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * len(block_out_channels) - if isinstance(motion_transformer_per_mid_layers, int): - motion_transformer_per_mid_layers = [motion_transformer_per_mid_layers] * motion_mid_block_layers_per_block - elif len(motion_transformer_per_mid_layers) != motion_mid_block_layers_per_block: + if isinstance(motion_transformer_layers_per_mid_block, int): + motion_transformer_layers_per_mid_block = ( + motion_transformer_layers_per_mid_block, + ) * motion_mid_block_layers_per_block + elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block: raise ValueError( f"The number of layers per mid block ({motion_mid_block_layers_per_block}) " - f"must match the length of motion_transformer_per_mid_layers ({len(motion_transformer_per_mid_layers)})" + f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})" ) if isinstance(motion_num_attention_heads, int): - motion_num_attention_heads = [motion_num_attention_heads] * len(block_out_channels) + motion_num_attention_heads = (motion_num_attention_heads,) * len(block_out_channels) elif len(motion_num_attention_heads) != len(block_out_channels): raise ValueError( f"The length of the attention head number tuple in the motion module must match the " @@ -180,7 +182,7 @@ def __init__( num_attention_heads=motion_num_attention_heads[i], max_seq_length=motion_max_seq_length, layers_per_block=motion_layers_per_block[i], - transformer_layers_per_block=motion_transformer_per_layers[i], + transformer_layers_per_block=motion_transformer_layers_per_block[i], ) ) @@ -194,7 +196,7 @@ def __init__( num_attention_heads=motion_num_attention_heads[-1], max_seq_length=motion_max_seq_length, layers_per_block=motion_mid_block_layers_per_block, - transformer_layers_per_block=motion_transformer_per_mid_layers, + transformer_layers_per_block=motion_transformer_layers_per_mid_block, ) else: self.mid_block = None @@ -203,7 +205,7 @@ def __init__( output_channel = reversed_block_out_channels[0] reversed_motion_layers_per_block = list(reversed(motion_layers_per_block)) - reversed_motion_motion_transformer_per_layers = list(reversed(motion_transformer_per_layers)) + reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block)) reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) for i, channel in enumerate(reversed_block_out_channels): output_channel = reversed_block_out_channels[i] @@ -217,7 +219,7 @@ def __init__( num_attention_heads=reversed_motion_num_attention_heads[i], max_seq_length=motion_max_seq_length, layers_per_block=reversed_motion_layers_per_block[i] + 1, - transformer_layers_per_block=reversed_motion_motion_transformer_per_layers[i], + 
transformer_layers_per_block=reversed_motion_transformer_layers_per_block[i], ) ) @@ -269,8 +271,8 @@ def __init__( reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None, temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None, - transformer_layers_per_midblock: Optional[Union[int, Tuple[int]]] = None, - temporal_transformer_layers_per_midblock: Optional[Union[int, Tuple[int]]] = 1, + transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, + temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1, use_linear_projection: bool = False, num_attention_heads: Union[int, Tuple[int, ...]] = 8, motion_max_seq_length: int = 32, @@ -414,8 +416,8 @@ def __init__( self.down_blocks.append(down_block) # mid - if transformer_layers_per_midblock is None: - transformer_layers_per_midblock = ( + if transformer_layers_per_mid_block is None: + transformer_layers_per_mid_block = ( transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 ) @@ -434,8 +436,8 @@ def __init__( num_layers=mid_block_layers, temporal_num_attention_heads=motion_num_attention_heads[-1], temporal_max_seq_length=motion_max_seq_length, - transformer_layers_per_block=transformer_layers_per_midblock, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_midblock, + transformer_layers_per_block=transformer_layers_per_mid_block, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block, ) else: @@ -451,7 +453,7 @@ def __init__( dual_cross_attention=False, use_linear_projection=use_linear_projection, num_layers=mid_block_layers, - transformer_layers_per_block=transformer_layers_per_midblock, + transformer_layers_per_block=transformer_layers_per_mid_block, ) # count how many layers upsample the images @@ -462,10 +464,14 @@ def __init__( reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) - reversed_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block)) reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) + if reverse_transformer_layers_per_block is None: + reverse_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + if reverse_temporal_transformer_layers_per_block is None: + reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block)) + output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 @@ -481,16 +487,6 @@ def __init__( else: add_upsample = False - if reverse_transformer_layers_per_block is not None: - curr_layer_transformer = reverse_transformer_layers_per_block[i] - else: - curr_layer_transformer = reversed_transformer_layers_per_block[i] - - if reverse_temporal_transformer_layers_per_block is not None: - curr_layer_temporal_transformer = reverse_temporal_transformer_layers_per_block[i] - else: - curr_layer_temporal_transformer = reversed_temporal_transformer_layers_per_block[i] - up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, @@ -509,8 +505,8 @@ def __init__( 
use_linear_projection=use_linear_projection, temporal_num_attention_heads=reversed_motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, - transformer_layers_per_block=curr_layer_transformer, - temporal_transformer_layers_per_block=curr_layer_temporal_transformer, + transformer_layers_per_block=reverse_transformer_layers_per_block[i], + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -585,10 +581,12 @@ def from_unet2d( config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"] - config["temporal_transformer_layers_per_midblock"] = motion_adapter.config[ - "motion_transformer_per_mid_layers" + config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[ + "motion_transformer_layers_per_mid_block" + ] + config["temporal_transformer_layers_per_block"] = motion_adapter.config[ + "motion_transformer_layers_per_block" ] - config["temporal_transformer_layers_per_block"] = motion_adapter.config["motion_transformer_per_layers"] config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] # For PIA UNets we need to set the number input channels to 9 From cdb3b89a414ebf8e9b1233e253242e551c76a4e2 Mon Sep 17 00:00:00 2001 From: Mathis Koroglu Date: Mon, 10 Jun 2024 15:06:42 +0200 Subject: [PATCH 3/3] Add test for asymmetric UNetMotionModel --- tests/models/unets/test_models_unet_motion.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py index 7d83b07c49fe..53833d6a075b 100644 --- a/tests/models/unets/test_models_unet_motion.py +++ b/tests/models/unets/test_models_unet_motion.py @@ -306,3 +306,36 @@ def test_forward_with_norm_groups(self): self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_asymmetric_motion_model(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["layers_per_block"] = (2, 3) + init_dict["transformer_layers_per_block"] = ((1, 2), (3, 4, 5)) + init_dict["reverse_transformer_layers_per_block"] = ((7, 6, 7, 4), (4, 2, 2)) + + init_dict["temporal_transformer_layers_per_block"] = ((2, 5), (2, 3, 5)) + init_dict["reverse_temporal_transformer_layers_per_block"] = ((5, 4, 3, 4), (3, 2, 2)) + + init_dict["num_attention_heads"] = (2, 4) + init_dict["motion_num_attention_heads"] = (4, 4) + init_dict["reverse_motion_num_attention_heads"] = (2, 2) + + init_dict["use_motion_mid_block"] = True + init_dict["mid_block_layers"] = 2 + init_dict["transformer_layers_per_mid_block"] = (1, 5) + init_dict["temporal_transformer_layers_per_mid_block"] = (2, 4) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
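
Usage sketch (illustrative only, not part of the applied diffs): with the series applied, a MotionAdapter can be configured with per-block motion settings using the argument names introduced by patch 2/3. The values below are placeholders chosen to match a four-block UNet; they are assumptions for illustration, not values taken from the patches.

    # Minimal sketch, assuming the renamed per-block arguments from patch 2/3 are available.
    from diffusers import MotionAdapter

    adapter = MotionAdapter(
        block_out_channels=(320, 640, 1280, 1280),
        # number of motion layers per UNet block (previously a single int for all blocks)
        motion_layers_per_block=(2, 2, 2, 2),
        # number of temporal transformer layers inside each motion layer, per block
        motion_transformer_layers_per_block=(1, 1, 2, 2),
        # number of temporal attention heads, per block
        motion_num_attention_heads=(8, 8, 8, 16),
        motion_mid_block_layers_per_block=1,
        motion_transformer_layers_per_mid_block=1,
        motion_max_seq_length=32,
        use_motion_mid_block=True,
    )

Per the from_unet2d changes in patch 1, wrapping a 2D UNet with UNetMotionModel.from_unet2d(unet, motion_adapter=adapter) now validates that the adapter's block count and per-block layer counts match the UNet before copying the per-block temporal settings into the motion model config. An asymmetric UNetMotionModel can also be constructed directly, as exercised by test_asymmetric_motion_model in patch 3.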