diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 99f2b2fbdbc8..ebd45c09ae33 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch @@ -634,7 +634,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -654,6 +654,10 @@ def __init__( self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + # there is always at least one resnet resnets = [ ResnetBlock2D( @@ -671,14 +675,14 @@ def __init__( ] attentions = [] - for _ in range(num_layers): + for i in range(num_layers): if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1018,7 +1022,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1041,6 +1045,8 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -1064,7 +1070,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -2167,7 +2173,7 @@ def __init__( resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -2190,6 +2196,9 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -2214,7 +2223,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - 
num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 0ce2e04ad99a..1a242ff165f6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -43,6 +43,7 @@ ) from .modeling_utils import ModelMixin from .unet_2d_blocks import ( + UNetMidBlock2D, UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, get_down_block, @@ -86,7 +87,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. @@ -105,10 +106,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. @@ -142,9 +148,9 @@ class conditioning with `class_embed_type` equal to `None`. The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 
conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. @@ -184,7 +190,8 @@ def __init__( norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, @@ -265,6 +272,10 @@ def __init__( raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") # input conv_in_padding = (conv_in_kernel - 1) // 2 @@ -500,6 +511,19 @@ def __init__( only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) elif mid_block_type is None: self.mid_block = None else: @@ -513,7 +537,11 @@ def __init__( reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] @@ -1062,14 +1090,18 @@ def forward( # 4. 
mid if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + # To support T2I-Adapter-XL if ( is_adapter diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 18518cc3783f..f2e3c457bc62 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -106,6 +106,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index de8f1071d073..1e10c9b04d46 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -134,6 +134,7 @@ class AltDiffusionImg2ImgPipeline( feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ + model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 073af945053b..ffe81ea44a27 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -324,7 +324,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa if "disable_self_attentions" in unet_params: config["only_cross_attention"] = unet_params.disable_self_attentions - if "num_classes" in unet_params and type(unet_params.num_classes) == int: + if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): config["num_class_embeds"] = unet_params.num_classes if controlnet: diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 717db3bbdb34..63f6e7d63800 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -281,7 +281,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): The tuple of downsample blocks to use. 
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): - Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or + Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`): The tuple of upsample blocks to use. @@ -300,10 +300,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. @@ -337,9 +342,9 @@ class conditioning with `class_embed_type` equal to `None`. The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. 
@@ -384,7 +389,8 @@ def __init__( norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, @@ -475,6 +481,10 @@ def __init__( "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:" f" {layers_per_block}. `down_block_types`: {down_block_types}." ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") # input conv_in_padding = (conv_in_kernel - 1) // 2 @@ -710,6 +720,19 @@ def __init__( only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) + elif mid_block_type == "UNetMidBlockFlat": + self.mid_block = UNetMidBlockFlat( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) elif mid_block_type is None: self.mid_block = None else: @@ -723,7 +746,11 @@ def __init__( reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] @@ -1281,14 +1308,18 @@ def forward( # 4. 
mid if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + # To support T2I-Adapter-XL if ( is_adapter @@ -1557,7 +1588,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1580,6 +1611,8 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -1603,7 +1636,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1823,7 +1856,7 @@ def __init__( resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -1846,6 +1879,9 @@ def __init__( self.has_cross_attention = True self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels @@ -1870,7 +1906,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, @@ -1983,6 +2019,133 @@ def custom_forward(*inputs): return hidden_states +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlat(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. 
This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + "It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to" + f" `in_channels`: {in_channels}." 
+ ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + # Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlatCrossAttn(nn.Module): def __init__( @@ -1991,7 +2154,7 @@ def __init__( temb_channels: int, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -2011,6 +2174,10 @@ def __init__( self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + # there is always at least one resnet resnets = [ ResnetBlockFlat( @@ -2028,14 +2195,14 @@ def __init__( ] attentions = [] - for _ in range(num_layers): + for i in range(num_layers): if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, - num_layers=transformer_layers_per_block, + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index d8b412aa12d9..0db336a88029 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -606,6 +606,22 @@ def test_pickle(self): assert (sample - sample_copy).abs().max() < 1e-4 + def test_asymmetrical_unet(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + # Add asymmetry to configs + init_dict["transformer_layers_per_block"] = [[3, 2], 1] + init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1] + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + output = model(**inputs_dict).sample + expected_shape = inputs_dict["sample"].shape + + # Check if input and output shapes are the same + 
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+

@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
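For readers of this diff, here is a minimal usage sketch of the asymmetrical UNet configuration exercised by `test_asymmetrical_unet` above. Only the `transformer_layers_per_block` / `reverse_transformer_layers_per_block` values come from that test; every other constructor argument is an illustrative toy config, not something specified by this patch.

```python
import torch
from diffusers import UNet2DConditionModel

# Toy config: two down blocks, so `transformer_layers_per_block` has two entries.
# A nested list sets a per-layer transformer depth inside that block; a bare int
# is broadcast to every layer of the block.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    transformer_layers_per_block=[[3, 2], 1],
    # Required whenever `transformer_layers_per_block` is nested: the up blocks
    # can no longer be derived by simply reversing the downsampling config.
    reverse_transformer_layers_per_block=[[3, 4], 1],
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([1])
encoder_hidden_states = torch.randn(1, 77, 32)  # last dim matches cross_attention_dim

output = unet(sample, timestep, encoder_hidden_states).sample
assert output.shape == sample.shape
```

Passing a nested `transformer_layers_per_block` without `reverse_transformer_layers_per_block` triggers the `ValueError` added in `UNet2DConditionModel.__init__` above.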
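A second sketch, with the same caveat that the config values are illustrative: the other constructor path this patch adds is `mid_block_type="UNetMidBlock2D"`, which builds the plain residual mid block with `num_layers=0` and `add_attention=False`, and which `forward` now calls with only `(sample, emb)` because the block does not report cross attention.

```python
import torch
from diffusers import UNet2DConditionModel

# Same toy config as the previous sketch, but with the new attention-free mid block.
unet_no_mid_attn = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    mid_block_type="UNetMidBlock2D",  # single ResnetBlock2D, no mid-block attention
)

sample = torch.randn(1, 4, 32, 32)
output = unet_no_mid_attn(sample, torch.tensor([1]), torch.randn(1, 77, 32)).sample
assert output.shape == sample.shape
```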