Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
Expand Down Expand Up @@ -184,6 +184,11 @@ def __init__(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)

if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)

# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
Expand Down Expand Up @@ -258,6 +263,9 @@ def __init__(
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)

if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
Expand All @@ -275,7 +283,7 @@ def __init__(

down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
Expand Down Expand Up @@ -333,6 +341,7 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))

Expand All @@ -353,7 +362,7 @@ def __init__(

up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
Expand Down
16 changes: 13 additions & 3 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __init__(
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
Expand Down Expand Up @@ -275,6 +275,12 @@ def __init__(
f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
)

if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
f" {layers_per_block}. `down_block_types`: {down_block_types}."
)

# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = LinearMultiDim(
Expand Down Expand Up @@ -349,6 +355,9 @@ def __init__(
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)

if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
Expand All @@ -366,7 +375,7 @@ def __init__(

down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
Expand Down Expand Up @@ -424,6 +433,7 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))

Expand All @@ -444,7 +454,7 @@ def __init__(

up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
Expand Down