From b8dca9f099f8a8abcac4e43d39750190117e2754 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 11:34:02 +0530 Subject: [PATCH 01/29] Added args, kwargs to ```U --- src/diffusers/models/unet_2d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index d57949976d30..47d674a5e7c5 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -587,7 +587,7 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, *args, **kwargs): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: From 577d1e246f49afaed0913b2806056df5e3c72956 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 11:38:50 +0530 Subject: [PATCH 02/29] Add UNetMidBlock2D as a supported mid block type --- src/diffusers/models/unet_2d_condition.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4039fbfcc67a..14d228b652a1 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -43,6 +43,7 @@ ) from .modeling_utils import ModelMixin from .unet_2d_blocks import ( + UNetMidBlock2D, UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, get_down_block, @@ -500,6 +501,20 @@ def __init__( only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + add_attention=False + ) elif mid_block_type is None: self.mid_block = None else: From 2c8d80435955d82177b121546e4e1afa81f36169 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 11:44:10 +0530 Subject: [PATCH 03/29] Fix extra init input for UNetMidBlock2D, change allowed types for Mid-block init --- src/diffusers/models/unet_2d_condition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 14d228b652a1..b537f7c7f765 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -173,7 +173,7 @@ def __init__( "CrossAttnDownBlock2D", "DownBlock2D", ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + mid_block_type: Optional[str] = ("UNetMidBlock2DCrossAttn", "UNetMidBlock2D"), up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), @@ -512,7 +512,6 @@ def __init__( output_scale_factor=mid_block_scale_factor, resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, 
add_attention=False ) elif mid_block_type is None: From e07c6156f89a81977c4d63c218c314c2d1909163 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 12:48:20 +0530 Subject: [PATCH 04/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b537f7c7f765..38161d945a07 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -185,7 +185,8 @@ def __init__( norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, @@ -267,6 +268,11 @@ def __init__( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) + if isinstance(transformer_layers_per_block, Tuple[Tuple[int]]) and reverse_transformer_layers_per_block is None: + raise ValueError( + "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." + ) + # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( @@ -420,6 +426,8 @@ def __init__( if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) From 8344b2d22f16b8fff193a440ca40548143aa50b9 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 12:51:43 +0530 Subject: [PATCH 05/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 38161d945a07..6faddf98afa0 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -106,10 +106,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + transformer_layers_per_block (`int`, `Tuple[int]`, or Tuple[Tuple[int]] , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : TODO + encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. 
@@ -186,7 +188,7 @@ def __init__( norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, - reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]] = None, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, From 025b1ec2330350da0fb23b669fc21c19d55a83af Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 12:55:11 +0530 Subject: [PATCH 06/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 6faddf98afa0..6fd3e6588925 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -187,7 +187,7 @@ def __init__( norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, @@ -270,7 +270,7 @@ def __init__( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) - if isinstance(transformer_layers_per_block, Tuple[Tuple[int]]) and reverse_transformer_layers_per_block is None: + if isinstance(transformer_layers_per_block, Tuple[Tuple]) and reverse_transformer_layers_per_block is None: raise ValueError( "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." ) From 6230214847f25aea3db73c97c829933ff631242b Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 13:01:02 +0530 Subject: [PATCH 07/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 6fd3e6588925..c0f385aad9f9 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -269,11 +269,13 @@ def __init__( raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) - - if isinstance(transformer_layers_per_block, Tuple[Tuple]) and reverse_transformer_layers_per_block is None: - raise ValueError( - "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." - ) + try: + if isinstance(transformer_layers_per_block, Tuple[Tuple]) and reverse_transformer_layers_per_block is None: + raise ValueError( + "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." 
+ ) + except: + print(type(tuple[tuple])) # input conv_in_padding = (conv_in_kernel - 1) // 2 From 80b891b6830e0a013caa47ec5ecd7ede8c7cdc02 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 13:02:35 +0530 Subject: [PATCH 08/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index c0f385aad9f9..73a820baa621 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -275,7 +275,7 @@ def __init__( "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." ) except: - print(type(tuple[tuple])) + print(type(Tuple[Tuple])) # input conv_in_padding = (conv_in_kernel - 1) // 2 From c4e4d40d7d024db4eae570284ef159da4383b645 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 13:04:00 +0530 Subject: [PATCH 09/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 73a820baa621..838defc94cdd 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -269,13 +269,10 @@ def __init__( raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) - try: - if isinstance(transformer_layers_per_block, Tuple[Tuple]) and reverse_transformer_layers_per_block is None: + if isinstance(transformer_layers_per_block, Tuple[List]) and reverse_transformer_layers_per_block is None: raise ValueError( "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." ) - except: - print(type(Tuple[Tuple])) # input conv_in_padding = (conv_in_kernel - 1) // 2 From 6176fd5a471f3a3a0577403f702b72c201286fb9 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 13:05:53 +0530 Subject: [PATCH 10/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 838defc94cdd..ecce49fb95c8 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -269,7 +269,7 @@ def __init__( raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) - if isinstance(transformer_layers_per_block, Tuple[List]) and reverse_transformer_layers_per_block is None: + if isinstance(transformer_layers_per_block, list) and isinstance(transformer_layers_per_block[0], list) and reverse_transformer_layers_per_block is None: raise ValueError( "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." 
) From 73737f8cab555d16cf1d6212a3c3a8883f8db5fb Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 13:09:45 +0530 Subject: [PATCH 11/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ecce49fb95c8..4b250097b556 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -536,7 +536,7 @@ def __init__( reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) if reverse_transformer_blocks is None else reverse_transformer_layers_per_block only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] From bc53a32e7e4e3d53f4ba39b242310b69a1a10c62 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 14:55:26 +0530 Subject: [PATCH 12/29] Update unet_2d_blocks.py --- src/diffusers/models/unet_2d_blocks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 47d674a5e7c5..defc4a7bb655 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -604,7 +604,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -988,7 +988,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -2137,7 +2137,7 @@ def __init__( resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", From 314730bab66952492ff107f30472c136a03d7993 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 18:02:33 +0530 Subject: [PATCH 13/29] Update unet_2d_blocks.py --- src/diffusers/models/unet_2d_blocks.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index defc4a7bb655..0649e7496aab 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -624,6 +624,11 @@ def __init__( self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + # support for variable transformer layers per block + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block= [transformer_layers_per_block] * num_layers + # there is always at least one 
resnet resnets = [ ResnetBlock2D( @@ -648,7 +653,7 @@ def __init__( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1011,6 +1016,8 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block= [transformer_layers_per_block] * num_layers for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -1034,7 +1041,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -2160,6 +2167,9 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block= [transformer_layers_per_block] * num_layers + for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -2184,7 +2194,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, From 545a0b0fbf86f7310d07929df9713672c03d6f95 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 18:04:17 +0530 Subject: [PATCH 14/29] Update unet_2d_blocks.py --- src/diffusers/models/unet_2d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 0649e7496aab..4ab62179fa9d 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch From bf772686adcaa99964b404a2626ccc4d140f69e8 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 18:06:17 +0530 Subject: [PATCH 15/29] Update unet_2d_condition.py --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4b250097b556..9c908be35ce6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -536,7 +536,7 @@ def __init__( reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) if reverse_transformer_blocks is None else reverse_transformer_layers_per_block + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) if reverse_transformer_layers_per_block is None else reverse_transformer_layers_per_block only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] From 1998c17f111c43b70dfd7edf31e07a709c42b52c Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Sun, 15 Oct 2023 18:20:54 +0530 Subject: [PATCH 16/29] Update unet_2d_blocks.py --- src/diffusers/models/unet_2d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 4ab62179fa9d..29ae8ee33c45 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -646,7 +646,7 @@ def __init__( ] attentions = [] - for _ in range(num_layers): + for i in range(num_layers): if not dual_cross_attention: attentions.append( Transformer2DModel( From 81d682ed5bdfc1a58766f8680faca4aadd30551e Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Mon, 16 Oct 2023 13:53:18 +0530 Subject: [PATCH 17/29] Updated docstring, increased check strictness Updated the docstring for ```UNet2DConditionModel``` to include ```reverse_transformer_layers_per_block``` and updated checking for nested list type ```transformer_layers_per_block``` --- src/diffusers/models/unet_2d_condition.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 9c908be35ce6..2e7498c830fd 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -106,12 +106,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. - transformer_layers_per_block (`int`, `Tuple[int]`, or Tuple[Tuple[int]] , *optional*, defaults to 1): + transformer_layers_per_block (`int`, `Tuple[int]`, or Tuple[Tuple] , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - reverse_transformer_layers_per_block : TODO - + reverse_transformer_layers_per_block : (`Tuple[Tuple]`. *optional*, required if passing transformer_layers_per_block of type `Tuple[Tuple]): + Provides the same functionality as transformer_layer_per_block, but for the up_blocks in the U-Net. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. @@ -269,10 +271,12 @@ def __init__( raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) - if isinstance(transformer_layers_per_block, list) and isinstance(transformer_layers_per_block[0], list) and reverse_transformer_layers_per_block is None: - raise ValueError( + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError( "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." - ) + ) # input conv_in_padding = (conv_in_kernel - 1) // 2 From 93ce7d3d0b353d6f347ad94e8a7592fb74819200 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Tue, 17 Oct 2023 14:09:43 +0530 Subject: [PATCH 18/29] Add basic shape-check test for asymmetrical unets --- tests/models/test_models_unet_2d_condition.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index d8b412aa12d9..019c34891f08 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -606,6 +606,22 @@ def test_pickle(self): assert (sample - sample_copy).abs().max() < 1e-4 + def test_asymmetrical_unet(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + # Add asymmetry to configs + init_dict["transformer_layers_per_block"] = [[3,2],1] + init_dict["reverse_transformer_layers_per_block"] = [[3,4],1] + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + output = model(**inputs_dict).sample + expected_shape = inputs_dict["sample"].shape + + # Check if input and output shapes are the same + self.assertEqual(output.shape, expected_shape , "Input and output shapes do not match") + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase): From e6b937cc19bb156c3d6d258db2ae13c7e5e99617 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:07:48 +0530 Subject: [PATCH 19/29] Update src/diffusers/models/unet_2d_blocks.py Removed blank line Co-authored-by: Sayak Paul --- src/diffusers/models/unet_2d_blocks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 29ae8ee33c45..c624eec0e8ae 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -625,7 +625,6 @@ def 
__init__( resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): transformer_layers_per_block= [transformer_layers_per_block] * num_layers From 0e9f6f484be81e00f43e1517dbfa6f037e16b867 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:10:39 +0530 Subject: [PATCH 20/29] Update unet_2d_condition.py Remove blank space --- src/diffusers/models/unet_2d_condition.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index a2bb53b22604..344c4ad92a38 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -431,8 +431,6 @@ def __init__( if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) From a91d08566c9aa22ea3d078371ef4d65c3bdab5a6 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Tue, 17 Oct 2023 18:01:58 +0530 Subject: [PATCH 21/29] Update unet_2d_condition.py Changed docstring for `mid_block_type` --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 344c4ad92a38..c974afa97a35 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -87,7 +87,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. From d75f2039b26a1cf0802a3d54ac1c66c48b27b1d5 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Tue, 17 Oct 2023 20:23:39 +0530 Subject: [PATCH 22/29] Fixed docstring and wrong default value --- src/diffusers/models/unet_2d_condition.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index c974afa97a35..60cbe0a7bbfe 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -106,12 +106,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. 
- transformer_layers_per_block (`int`, `Tuple[int]`, or Tuple[Tuple] , *optional*, defaults to 1): + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - reverse_transformer_layers_per_block : (`Tuple[Tuple]`. *optional*, required if passing transformer_layers_per_block of type `Tuple[Tuple]): - Provides the same functionality as transformer_layer_per_block, but for the up_blocks in the U-Net. Only relevant for + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling blocks of the U-Net. + Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): @@ -177,7 +178,7 @@ def __init__( "CrossAttnDownBlock2D", "DownBlock2D", ), - mid_block_type: Optional[str] = ("UNetMidBlock2DCrossAttn", "UNetMidBlock2D"), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), From 97c47b603ba586fdaadabb273132d95e34b7dd7f Mon Sep 17 00:00:00 2001 From: Gothos Date: Tue, 17 Oct 2023 22:28:48 +0530 Subject: [PATCH 23/29] Reformat with black --- src/diffusers/models/unet_2d_blocks.py | 8 ++++---- src/diffusers/models/unet_2d_condition.py | 16 +++++++++------- tests/models/test_models_unet_2d_condition.py | 6 +++--- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index c624eec0e8ae..f75eb0a75a69 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -624,9 +624,9 @@ def __init__( self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - # support for variable transformer layers per block + # support for variable transformer layers per block if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block= [transformer_layers_per_block] * num_layers + transformer_layers_per_block = [transformer_layers_per_block] * num_layers # there is always at least one resnet resnets = [ @@ -1016,7 +1016,7 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block= [transformer_layers_per_block] * num_layers + transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -2167,7 +2167,7 @@ def __init__( self.num_attention_heads = num_attention_heads if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block= [transformer_layers_per_block] * num_layers + transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): res_skip_channels = in_channels 
if (i == num_layers - 1) else out_channels diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 60cbe0a7bbfe..daa7a47ead5c 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -110,8 +110,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling blocks of the U-Net. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. @@ -275,9 +275,7 @@ def __init__( if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: for layer_number_per_block in transformer_layers_per_block: if isinstance(layer_number_per_block, list): - raise ValueError( - "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." - ) + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") # input conv_in_padding = (conv_in_kernel - 1) // 2 @@ -524,7 +522,7 @@ def __init__( output_scale_factor=mid_block_scale_factor, resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, - add_attention=False + add_attention=False, ) elif mid_block_type is None: self.mid_block = None @@ -539,7 +537,11 @@ def __init__( reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) if reverse_transformer_layers_per_block is None else reverse_transformer_layers_per_block + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 019c34891f08..0db336a88029 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -609,8 +609,8 @@ def test_pickle(self): def test_asymmetrical_unet(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() # Add asymmetry to configs - init_dict["transformer_layers_per_block"] = [[3,2],1] - init_dict["reverse_transformer_layers_per_block"] = [[3,4],1] + init_dict["transformer_layers_per_block"] = [[3, 2], 1] + init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1] torch.manual_seed(0) model = self.model_class(**init_dict) @@ -620,7 +620,7 @@ def test_asymmetrical_unet(self): expected_shape = 
inputs_dict["sample"].shape # Check if input and output shapes are the same - self.assertEqual(output.shape, expected_shape , "Input and output shapes do not match") + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") @slow From 196ab3e7706b176d33ddfd9513cda626442cae13 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Wed, 18 Oct 2023 09:05:05 +0000 Subject: [PATCH 24/29] Reformat with necessary commands --- src/diffusers/models/unet_2d_condition.py | 10 +-- .../alt_diffusion/pipeline_alt_diffusion.py | 1 + .../pipeline_alt_diffusion_img2img.py | 1 + .../stable_diffusion/convert_from_ckpt.py | 2 +- .../versatile_diffusion/modeling_text_unet.py | 64 +++++++++++++++---- 5 files changed, 58 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index daa7a47ead5c..0b664bebf5de 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -111,8 +111,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling blocks of the U-Net. - Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): @@ -148,9 +148,9 @@ class conditioning with `class_embed_type` equal to `None`. The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. 
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 18518cc3783f..f2e3c457bc62 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -106,6 +106,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index de8f1071d073..1e10c9b04d46 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -134,6 +134,7 @@ class AltDiffusionImg2ImgPipeline( feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index e97f66bbcb24..2fb6b21e63ef 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -324,7 +324,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa if "disable_self_attentions" in unet_params: config["only_cross_attention"] = unet_params.disable_self_attentions - if "num_classes" in unet_params and type(unet_params.num_classes) == int: + if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): config["num_class_embeds"] = unet_params.num_classes if controlnet: diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 717db3bbdb34..80c02ee7e3e9 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -281,7 +281,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): - Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or + Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`): The tuple of upsample blocks to use. @@ -300,10 +300,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. @@ -337,9 +342,9 @@ class conditioning with `class_embed_type` equal to `None`. The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. @@ -384,7 +389,8 @@ def __init__( norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, @@ -475,6 +481,10 @@ def __init__( "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:" f" {layers_per_block}. `down_block_types`: {down_block_types}." 
) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") # input conv_in_padding = (conv_in_kernel - 1) // 2 @@ -710,6 +720,19 @@ def __init__( only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) + elif mid_block_type == "UNetMidBlockFlat": + self.mid_block = UNetMidBlockFlat( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) elif mid_block_type is None: self.mid_block = None else: @@ -723,7 +746,11 @@ def __init__( reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] @@ -1557,7 +1584,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1580,6 +1607,8 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -1603,7 +1632,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1823,7 +1852,7 @@ def __init__( resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1846,6 +1875,9 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -1870,7 +1902,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, 
use_linear_projection=use_linear_projection, @@ -1991,7 +2023,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -2011,6 +2043,10 @@ def __init__( self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + # there is always at least one resnet resnets = [ ResnetBlockFlat( @@ -2028,14 +2064,14 @@ def __init__( ] attentions = [] - for _ in range(num_layers): + for i in range(num_layers): if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, From bbebc2374328b05cdc7749bd378a4ac74e199cd2 Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Wed, 18 Oct 2023 10:38:26 +0000 Subject: [PATCH 25/29] Add UNetMidBlockFlat to versatile_diffusion/modeling_text_unet.py to ensure consistency --- .../versatile_diffusion/modeling_text_unet.py | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 80c02ee7e3e9..e3c4c3fb7e7b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -2015,6 +2015,103 @@ def custom_forward(*inputs): return hidden_states +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + "It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to" + f" `in_channels`: {in_channels}." 
+ ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, *args, **kwargs): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + # Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlatCrossAttn(nn.Module): def __init__( From 0a73108a7da6f756e6394cdf9c954e250691d68e Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Thu, 19 Oct 2023 05:14:43 +0000 Subject: [PATCH 26/29] Removed args, kwargs, use on mid-block type --- src/diffusers/models/unet_2d_blocks.py | 2 +- src/diffusers/models/unet_2d_condition.py | 23 +++++++++++------ .../versatile_diffusion/modeling_text_unet.py | 25 ++++++++++++------- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f75eb0a75a69..13f1fa8b72e8 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -587,7 +587,7 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, *args, **kwargs): + def forward(self, hidden_states, temb=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 0b664bebf5de..68c76b7662ec 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1090,14 +1090,21 @@ def forward( # 4. 
mid if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block( + sample, + emb, + ) + # To support T2I-Adapter-XL if ( is_adapter diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index e3c4c3fb7e7b..3111e6c0465b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1308,14 +1308,21 @@ def forward( # 4. mid if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block( + sample, + emb, + ) + # To support T2I-Adapter-XL if ( is_adapter @@ -2102,7 +2109,7 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, *args, **kwargs): + def forward(self, hidden_states, temb=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: From bfbf85c33c9575ddaea14cbf895d434f2bca360d Mon Sep 17 00:00:00 2001 From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com> Date: Thu, 19 Oct 2023 05:21:09 +0000 Subject: [PATCH 27/29] Make fix-copies --- .../alt_diffusion/pipeline_alt_diffusion.py | 1 + .../pipeline_alt_diffusion_img2img.py | 1 + .../versatile_diffusion/modeling_text_unet.py | 30 +++++++++++++++++++ 3 files changed, 32 insertions(+) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 18518cc3783f..f2e3c457bc62 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -106,6 +106,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index de8f1071d073..1e10c9b04d46 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -134,6 +134,7 @@ class AltDiffusionImg2ImgPipeline( feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 3111e6c0465b..320ddfa9ea76 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -2024,6 +2024,36 @@ def custom_forward(*inputs): # Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlat(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
From 8389a851a4365f3dbf51096340f54ccd9a20f859 Mon Sep 17 00:00:00 2001
From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com>
Date: Fri, 20 Oct 2023 12:22:04 +0530
Subject: [PATCH 28/29] Update src/diffusers/models/unet_2d_condition.py

Wrap into single line

Co-authored-by: Sayak Paul
---
 src/diffusers/models/unet_2d_condition.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 68c76b7662ec..1a242ff165f6 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -1100,10 +1100,7 @@ def forward(
                     encoder_attention_mask=encoder_attention_mask,
                 )
             else:
-                sample = self.mid_block(
-                    sample,
-                    emb,
-                )
+                sample = self.mid_block(sample, emb)
 
         # To support T2I-Adapter-XL
         if (

From 252ca8506b51e3715e5cd5517fbb335f3c4408f3 Mon Sep 17 00:00:00 2001
From: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com>
Date: Fri, 20 Oct 2023 08:49:01 +0000
Subject: [PATCH 29/29] make fix-copies

---
 .../pipelines/versatile_diffusion/modeling_text_unet.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 320ddfa9ea76..63f6e7d63800 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1318,10 +1318,7 @@ def forward(
                     encoder_attention_mask=encoder_attention_mask,
                 )
             else:
-                sample = self.mid_block(
-                    sample,
-                    emb,
-                )
+                sample = self.mid_block(sample, emb)
 
         # To support T2I-Adapter-XL
         if (
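Taken together, the series lets a UNet with a plain, attention-free mid block run end to end. A small self-contained check, with illustrative config values sized to run on CPU (it assumes a diffusers build that includes these patches; none of the numbers are canonical):

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        mid_block_type="UNetMidBlock2D",  # the newly supported resnet-only mid block
        cross_attention_dim=32,
    )
    sample = torch.randn(1, 4, 32, 32)
    encoder_hidden_states = torch.randn(1, 77, 32)
    # UNetMidBlock2D has no has_cross_attention flag, so the forward pass
    # takes the new `self.mid_block(sample, emb)` branch.
    out = unet(sample, timestep=0, encoder_hidden_states=encoder_hidden_states).sample
    assert out.shape == sample.shape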