127 changes: 112 additions & 15 deletions src/diffusers/models/unets/unet_3d_blocks.py
@@ -58,7 +58,9 @@ def get_down_block(
resnet_time_scale_shift: str = "default",
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
dropout: float = 0.0,
) -> Union[
"DownBlock3D",
"CrossAttnDownBlock3D",
@@ -79,6 +81,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
dropout=dropout,
)
elif down_block_type == "CrossAttnDownBlock3D":
if cross_attention_dim is None:
@@ -100,6 +103,7 @@ def get_down_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
dropout=dropout,
)
if down_block_type == "DownBlockMotion":
return DownBlockMotion(
@@ -115,6 +119,8 @@ def get_down_block(
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif down_block_type == "CrossAttnDownBlockMotion":
if cross_attention_dim is None:
@@ -139,6 +145,8 @@ def get_down_block(
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif down_block_type == "DownBlockSpatioTemporal":
# added for SDV
@@ -189,7 +197,8 @@ def get_up_block(
temporal_num_attention_heads: int = 8,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
dropout: float = 0.0,
) -> Union[
"UpBlock3D",
@@ -212,6 +221,7 @@ def get_up_block(
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
elif up_block_type == "CrossAttnUpBlock3D":
if cross_attention_dim is None:
@@ -234,6 +244,7 @@ def get_up_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
if up_block_type == "UpBlockMotion":
return UpBlockMotion(
@@ -250,6 +261,8 @@ def get_up_block(
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif up_block_type == "CrossAttnUpBlockMotion":
if cross_attention_dim is None:
@@ -275,6 +288,8 @@ def get_up_block(
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
dropout=dropout,
)
elif up_block_type == "UpBlockSpatioTemporal":
# added for SDV
@@ -948,14 +963,31 @@ def __init__(
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
temporal_num_attention_heads: int = 1,
temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []

# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
)

        # support for a variable number of attention heads per temporal layer
if isinstance(temporal_num_attention_heads, int):
temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
elif len(temporal_num_attention_heads) != num_layers:
raise ValueError(
f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
)

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -974,15 +1006,16 @@ def __init__(
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
num_attention_heads=temporal_num_attention_heads[i],
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
attention_head_dim=out_channels // temporal_num_attention_heads[i],
)
)

@@ -1065,7 +1098,7 @@ def __init__(
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -1084,6 +1117,7 @@ def __init__(
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
@@ -1093,6 +1127,22 @@ def __init__(
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads

# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)

# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -1116,7 +1166,7 @@ def __init__(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1141,6 +1191,7 @@ def __init__(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
@@ -1257,7 +1308,7 @@ def __init__(
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -1275,6 +1326,7 @@ def __init__(
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
@@ -1284,6 +1336,22 @@ def __init__(
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads

# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
)

# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
)

for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1309,7 +1377,7 @@ def __init__(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1333,6 +1401,7 @@ def __init__(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
@@ -1467,11 +1536,20 @@ def __init__(
temporal_cross_attention_dim: Optional[int] = None,
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()
resnets = []
motion_modules = []

# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
)

for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1495,6 +1573,7 @@ def __init__(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=temporal_norm_num_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
@@ -1596,7 +1675,7 @@ def __init__(
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -1605,20 +1684,37 @@ def __init__(
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: float = False,
use_linear_projection: float = False,
upcast_attention: float = False,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
temporal_num_attention_heads: int = 1,
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
):
super().__init__()

self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
elif len(transformer_layers_per_block) != num_layers:
raise ValueError(
f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)

# support for variable transformer layers per temporal block
if isinstance(temporal_transformer_layers_per_block, int):
temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
elif len(temporal_transformer_layers_per_block) != num_layers:
raise ValueError(
f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
)

# there is always at least one resnet
resnets = [
ResnetBlock2D(
@@ -1637,14 +1733,14 @@ def __init__(
attentions = []
motion_modules = []

for _ in range(num_layers):
for i in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1682,6 +1778,7 @@ def __init__(
num_attention_heads=temporal_num_attention_heads,
attention_head_dim=in_channels // temporal_num_attention_heads,
in_channels=in_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
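
Note on the recurring pattern: each `__init__` in this diff normalizes `transformer_layers_per_block` and `temporal_transformer_layers_per_block` from an int or a tuple into a per-layer tuple before the layer loop, then indexes it with `[i]`. Below is a minimal standalone sketch of that idiom, assuming only the standard library; the helper name `_broadcast_per_layer` is illustrative and is not part of this PR or of diffusers.

from typing import Tuple, Union


def _broadcast_per_layer(value: Union[int, Tuple[int, ...]], num_layers: int, name: str) -> Tuple[int, ...]:
    # An int is repeated once per layer; a tuple must already have num_layers entries,
    # mirroring the validation added in each __init__ above.
    if isinstance(value, int):
        return (value,) * num_layers
    if len(value) != num_layers:
        raise ValueError(f"`{name}` must be an integer or a tuple of integers of length {num_layers}")
    return tuple(value)


# Example: two layers, with one temporal transformer layer in the first and two in the second.
assert _broadcast_per_layer((1, 2), num_layers=2, name="temporal_transformer_layers_per_block") == (1, 2)
assert _broadcast_per_layer(1, num_layers=3, name="transformer_layers_per_block") == (1, 1, 1)

In the PR itself each block inlines this check rather than sharing a helper, which keeps the motion blocks self-contained at the cost of some repetition.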