11 changes: 10 additions & 1 deletion src/diffusers/models/attention.py
@@ -100,6 +100,7 @@ def __init__(
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
@@ -157,6 +158,7 @@ def __init__(
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
)
for d in range(num_layers)
]
@@ -387,14 +389,17 @@ def __init__(
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
@@ -461,7 +466,11 @@ def forward(self, hidden_states, context=None, timestep=None):
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states) + hidden_states

if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states

# 2. Cross-Attention
norm_hidden_states = (
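Note on the attention.py change: with only_cross_attention=True, the block's first attention layer consumes the encoder context instead of self-attending, and the residual connection is unchanged. A minimal sketch of that routing, using plain torch.nn.MultiheadAttention as a stand-in for the library's CrossAttention (class and argument names here are illustrative, not the diffusers implementation):

```python
import torch
import torch.nn as nn


class FirstAttentionSketch(nn.Module):
    """Illustrative stand-in for the changed attn1 path in BasicTransformerBlock."""

    def __init__(self, dim, context_dim, heads, only_cross_attention=False):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        # kdim/vdim let the same module read either the hidden states or the encoder context
        self.attn1 = nn.MultiheadAttention(
            dim,
            heads,
            kdim=context_dim if only_cross_attention else dim,
            vdim=context_dim if only_cross_attention else dim,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(dim)

    def forward(self, hidden_states, context=None):
        norm_hidden_states = self.norm1(hidden_states)
        if self.only_cross_attention:
            # keys/values come from the encoder context, as in attn1(norm_hidden_states, context)
            out, _ = self.attn1(norm_hidden_states, context, context)
        else:
            # plain self-attention, as in attn1(norm_hidden_states)
            out, _ = self.attn1(norm_hidden_states, norm_hidden_states, norm_hidden_states)
        return out + hidden_states


# hidden states (B, HW, dim) attend over a text context (B, T, context_dim) when the flag is set
block = FirstAttentionSketch(dim=320, context_dim=768, heads=8, only_cross_attention=True)
out = block(torch.randn(2, 64, 320), context=torch.randn(2, 77, 768))
```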
8 changes: 8 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
@@ -34,6 +34,7 @@ def get_down_block(
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -78,6 +79,7 @@ def get_down_block(
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -143,6 +145,7 @@ def get_up_block(
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -174,6 +177,7 @@ def get_up_block(
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -530,6 +534,7 @@ def __init__(
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -564,6 +569,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
@@ -1129,6 +1135,7 @@ def __init__(
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -1165,6 +1172,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
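These unet_2d_blocks.py changes only thread the flag through get_down_block/get_up_block into the cross-attention down/up blocks. The per-block resolution happens in the UNet constructor below; a standalone sketch of that broadcast-and-reverse pattern (hypothetical helper, not a diffusers function):

```python
from typing import List, Tuple, Union


def resolve_only_cross_attention(
    only_cross_attention: Union[bool, Tuple[bool, ...]],
    down_block_types: Tuple[str, ...],
) -> Tuple[List[bool], List[bool]]:
    # A single bool applies to every block; a tuple gives one entry per down block.
    if isinstance(only_cross_attention, bool):
        per_down_block = [only_cross_attention] * len(down_block_types)
    else:
        per_down_block = list(only_cross_attention)
    # The up path mirrors the down path, so the flags are reversed for the up blocks.
    return per_down_block, list(reversed(per_down_block))


down_flags, up_flags = resolve_only_cross_attention(
    (True, False, False, False),
    ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
)
print(down_flags)  # [True, False, False, False]
print(up_flags)    # [False, False, False, True]
```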
19 changes: 19 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -98,6 +98,7 @@ def __init__(
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -109,6 +110,7 @@ def __init__(
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
):
super().__init__()

@@ -124,10 +126,17 @@ def __init__(

self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])

if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)

if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -153,6 +162,7 @@ def __init__(
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.down_blocks.append(down_block)

@@ -177,6 +187,7 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -207,6 +218,7 @@ def __init__(
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -258,6 +270,7 @@ def forward(
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -310,6 +323,12 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)

if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb

# 2. pre-process
sample = self.conv_in(sample)

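For the unet_2d_condition.py side, the new num_class_embeds config adds an nn.Embedding whose output is summed with the timestep embedding, and forward() now requires class_labels whenever that config is set. A minimal sketch of that path in plain PyTorch (dimensions are made up; the real model derives time_embed_dim from block_out_channels):

```python
from typing import Optional

import torch
import torch.nn as nn

num_class_embeds = 10    # hypothetical number of classes
time_embed_dim = 1280    # hypothetical width of the timestep embedding

class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)


def add_class_conditioning(emb: torch.Tensor, class_labels: Optional[torch.Tensor]) -> torch.Tensor:
    # Mirrors the forward() guard above: labels are mandatory once class conditioning is configured.
    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")
    class_emb = class_embedding(class_labels).to(dtype=emb.dtype)
    return emb + class_emb  # class embedding is summed with the timestep embedding


t_emb = torch.randn(4, time_embed_dim)               # stands in for self.time_embedding(t_emb)
labels = torch.randint(0, num_class_embeds, (4,))    # one integer label per sample in the batch
emb = add_class_conditioning(t_emb, labels)          # shape (4, time_embed_dim)
```

The modeling_text_unet.py changes below mirror this file exactly for the flat text UNet.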
23 changes: 23 additions & 0 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -166,6 +166,7 @@ def __init__(
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -177,6 +178,7 @@ def __init__(
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
):
super().__init__()

@@ -192,10 +194,17 @@ def __init__(

self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])

if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)

if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)

@@ -221,6 +230,7 @@ def __init__(
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.down_blocks.append(down_block)

@@ -245,6 +255,7 @@ def __init__(
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -275,6 +286,7 @@ def __init__(
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -326,6 +338,7 @@ def forward(
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -378,6 +391,12 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)

if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb

# 2. pre-process
sample = self.conv_in(sample)

@@ -648,6 +667,7 @@ def __init__(
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -682,6 +702,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
@@ -861,6 +882,7 @@ def __init__(
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -897,6 +919,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else: