39 changes: 28 additions & 11 deletions src/diffusers/models/embeddings.py
@@ -717,7 +717,14 @@ def forward(self, x):
 
 
 class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+    def __init__(
+        self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -726,9 +733,15 @@ def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
         self.pooler = HunyuanDiTAttentionPool(
             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
         )
+
         # Here we use a default learned embedder layer for future extension.
-        self.style_embedder = nn.Embedding(1, embedding_dim)
-        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim
+
         self.extra_embedder = PixArtAlphaTextProjection(
             in_features=extra_in_dim,
             hidden_size=embedding_dim * 4,
@@ -743,16 +756,20 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
         # extra condition1: text
         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
 
-        # extra condition2: image meta size embdding
-        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
-        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
-        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embedding
+            image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
 
-        # extra condition3: style embedding
-        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)
 
-        # Concatenate all extra vectors
-        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)
+
         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
 
         return conditioning
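Reviewer note: a minimal sketch of what the new flag changes, runnable against this branch. With use_style_cond_and_image_meta_size=False, extra_in_dim collapses from 256 * 6 + embedding_dim + pooled_projection_dim down to pooled_projection_dim alone, and forward() never touches image_meta_size or style, so both can be None. The sizes below (embedding_dim=1408, i.e. 16 heads x 88 head dim) are illustrative assumptions, not part of the diff.

import torch
from diffusers.models.embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding

# Illustrative sizes only: 1408 = 16 heads * 88 head dim, as in HunyuanDiT.
emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
    embedding_dim=1408,
    pooled_projection_dim=1024,
    seq_len=256,
    cross_attention_dim=2048,
    use_style_cond_and_image_meta_size=False,  # v1.2-style checkpoints
)

timestep = torch.randint(0, 1000, (2,))
encoder_hidden_states = torch.randn(2, 256, 2048)  # pooled by HunyuanDiTAttentionPool

# With the flag off, the image_meta_size and style branches are skipped,
# so passing None for both is safe; extra_cond is just the pooled text projection.
conditioning = emb(
    timestep,
    encoder_hidden_states,
    image_meta_size=None,
    style=None,
    hidden_dtype=torch.float32,
)
print(conditioning.shape)  # torch.Size([2, 1408])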
4 changes: 4 additions & 0 deletions src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -249,6 +249,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             The length of the clip text embedding.
         text_len_t5 (`int`, *optional*):
             The length of the T5 text embedding.
+        use_style_cond_and_image_meta_size (`bool`, *optional*):
+            Whether or not to use the style condition and image meta size embeddings. `True` for versions <= 1.1, `False` for versions >= 1.2.
     """
 
     @register_to_config
@@ -270,6 +272,7 @@ def __init__(
         pooled_projection_dim: int = 1024,
         text_len: int = 77,
         text_len_t5: int = 256,
+        use_style_cond_and_image_meta_size: bool = True,
     ):
         super().__init__()
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
@@ -301,6 +304,7 @@ def __init__(
             pooled_projection_dim=pooled_projection_dim,
             seq_len=text_len_t5,
             cross_attention_dim=cross_attention_dim_t5,
+            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
         )
 
         # HunyuanDiT Blocks
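At the model level, the flag is just another @register_to_config argument, so it is saved and restored with the model config. A minimal sketch of constructing a v1.2-style transformer with the flag off; in_channels=4 and patch_size=2 mirror the released HunyuanDiT configs but are assumptions here, not part of this diff.

from diffusers import HunyuanDiT2DModel

# Hypothetical instantiation: only the last argument is new in this PR.
model = HunyuanDiT2DModel(
    in_channels=4,
    patch_size=2,
    use_style_cond_and_image_meta_size=False,
)

# @register_to_config records the flag, so it round-trips through
# save_pretrained() / from_pretrained().
print(model.config.use_style_cond_and_image_meta_size)  # False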