From 8a25b1f6e076c79d6eb855a093070c5178132c3f Mon Sep 17 00:00:00 2001
From: William Berman
Date: Sat, 8 Apr 2023 19:43:50 -0700
Subject: [PATCH] allow unet varying number of layers per block

---
 src/diffusers/models/unet_2d_condition.py      | 15 ++++++++++++---
 .../versatile_diffusion/modeling_text_unet.py  | 16 +++++++++++++---
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 4d237286fb32..c1d8628c47bb 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -132,7 +132,7 @@ def __init__(
         up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
         only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
-        layers_per_block: int = 2,
+        layers_per_block: Union[int, Tuple[int]] = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
         act_fn: str = "silu",
@@ -184,6 +184,11 @@ def __init__(
                 f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
             )
 
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
         # input
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in = nn.Conv2d(
@@ -258,6 +263,9 @@ def __init__(
         if isinstance(cross_attention_dim, int):
             cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
 
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
         if class_embeddings_concat:
             # The time embeddings are concatenated with the class embeddings. The dimension of the
             # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
@@ -275,7 +283,7 @@ def __init__(
 
             down_block = get_down_block(
                 down_block_type,
-                num_layers=layers_per_block,
+                num_layers=layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=blocks_time_embed_dim,
@@ -333,6 +341,7 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
         only_cross_attention = list(reversed(only_cross_attention))
 
@@ -353,7 +362,7 @@ def __init__(
 
             up_block = get_up_block(
                 up_block_type,
-                num_layers=layers_per_block + 1,
+                num_layers=reversed_layers_per_block[i] + 1,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index deaa709ab319..d490bc21b858 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -218,7 +218,7 @@ def __init__(
         ),
         only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
-        layers_per_block: int = 2,
+        layers_per_block: Union[int, Tuple[int]] = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
         act_fn: str = "silu",
@@ -275,6 +275,12 @@ def __init__(
                 f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
             )
 
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
+                f" {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
         # input
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in = LinearMultiDim(
@@ -349,6 +355,9 @@ def __init__(
         if isinstance(cross_attention_dim, int):
             cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
 
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
         if class_embeddings_concat:
             # The time embeddings are concatenated with the class embeddings. The dimension of the
             # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
@@ -366,7 +375,7 @@ def __init__(
 
             down_block = get_down_block(
                 down_block_type,
-                num_layers=layers_per_block,
+                num_layers=layers_per_block[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=blocks_time_embed_dim,
@@ -424,6 +433,7 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
         only_cross_attention = list(reversed(only_cross_attention))
 
@@ -444,7 +454,7 @@ def __init__(
 
             up_block = get_up_block(
                 up_block_type,
-                num_layers=layers_per_block + 1,
+                num_layers=reversed_layers_per_block[i] + 1,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
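
A minimal usage sketch of what this patch enables, not part of the patch itself: `layers_per_block` may now be a per-block tuple rather than a single int broadcast to every block. The block types, channel widths, and sizes below are illustrative values, and the snippet assumes a diffusers install that includes this change.

import torch
from diffusers import UNet2DConditionModel

# One entry per down block; a plain int still broadcasts to every
# block, preserving the old behavior. A tuple whose length differs
# from `down_block_types` raises the ValueError added in this patch.
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(320, 640, 1280, 1280),
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    layers_per_block=(1, 2, 2, 3),  # varies per block
    cross_attention_dim=768,
)

# The up path mirrors the per-block counts in reverse, with one extra
# resnet per block (num_layers=reversed_layers_per_block[i] + 1 in the
# diff above), so output resolution is unchanged.
sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 768)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
assert out.shape == sample.shape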