[Core] refactor embeddings (huggingface/diffusers#8722)
Status: Closed · +1,351 −1,293 · 24 commits (showing changes from all commits)

Commits
abf67a8 refactor embeddings (sayakpaul)
173e7b0 fix (sayakpaul)
2a1dff8 fix morer (sayakpaul)
aee0d07 Merge branch 'main' into embeddings-refactor (sayakpaul)
cb5d919 resolve conflicts (sayakpaul)
776bc54 Merge branch 'main' into embeddings-refactor (sayakpaul)
acd8461 reflect changes from HunyuanCombinedTimestepTextSizeStyleEmbedding (sayakpaul)
3d42fc7 resolve conflicts (sayakpaul)
f888685 fix (sayakpaul)
28478a6 patch embedding to position (sayakpaul)
b8876b6 fix (sayakpaul)
50932aa style (sayakpaul)
5f9a995 Merge branch 'main' into embeddings-refactor (sayakpaul)
733609a up (sayakpaul)
f94e3fc move hunyuan stuff to image_text (sayakpaul)
dc3ef41 Merge branch 'main' into embeddings-refactor (sayakpaul)
836be38 Merge branch 'main' into embeddings-refactor (sayakpaul)
91b75c6 move labelembedding to others (sayakpaul)
498ec77 move more to combibed. (sayakpaul)
792deca up (sayakpaul)
19c4c8e up (sayakpaul)
48d7a28 fix import (sayakpaul)
a61d541 changes from https://github.com/huggingface/diffusers/pull/8764 (sayakpaul)
59d4db7 Merge branch 'main' into embeddings-refactor (sayakpaul)
This file was deleted (its contents failed to load in the diff view).

New file, +43 lines: the re-export module for the refactored embeddings package.
```python
from .combined import (
    CombinedTimestepLabelEmbeddings,
    CombinedTimestepTextProjEmbeddings,
    HunyuanCombinedTimestepTextSizeStyleEmbedding,
    HunyuanDiTAttentionPool,
    ImageHintTimeEmbedding,
    PixArtAlphaCombinedTimestepSizeEmbeddings,
    TextImageProjection,
    TextImageTimeEmbedding,
)
from .image import (
    ImagePositionalEmbeddings,
    ImageProjection,
    ImageTimeEmbedding,
    IPAdapterFaceIDImageProjection,
    IPAdapterFaceIDPlusImageProjection,
    IPAdapterFullImageProjection,
    IPAdapterPlusImageProjection,
    IPAdapterPlusImageProjectionBlock,
    MultiIPAdapterImageProjection,
)
from .others import (
    GLIGENTextBoundingboxProjection,
    LabelEmbedding,
    get_fourier_embeds_from_boundingbox,
)
from .position import (
    PatchEmbed,
    SinusoidalPositionalEmbedding,
    apply_rotary_emb,
    get_1d_rotary_pos_embed,
    get_1d_sincos_pos_embed_from_grid,
    get_2d_rotary_pos_embed,
    get_2d_rotary_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
    get_2d_sincos_pos_embed_from_grid,
)
from .text import (
    AttentionPooling,
    PixArtAlphaTextProjection,
    TextTimeEmbedding,
)
from .timestep import GaussianFourierProjection, TimestepEmbedding, Timesteps, get_timestep_embedding
```
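Because every class is re-exported at the package root, downstream code can keep pulling names straight from the package instead of from the submodule that now defines them. A minimal sketch, assuming the package is importable as `diffusers.models.embeddings` (the absolute path is not shown in this diff):

```python
import torch

# Assumed import path; the diff above shows only relative imports.
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

# Same construction pattern as the combined embeddings below.
time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=1152)

t = torch.tensor([10, 500])             # two diffusion timesteps
emb = time_embed(time_proj(t).float())  # -> shape (2, 1152)
```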
New file, +262 lines: the combined conditioning embeddings, matching the `.combined` re-exports above.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextImageProjection(nn.Module):
    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        batch_size = text_embeds.shape[0]

        # image
        image_text_embeds = self.image_embeds(image_embeds)
        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)

        # text
        text_embeds = self.text_proj(text_embeds)

        return torch.cat([image_text_embeds, text_embeds], dim=1)


class TextImageTimeEmbedding(nn.Module):
    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        # text
        time_text_embeds = self.text_proj(text_embeds)
        time_text_embeds = self.text_norm(time_text_embeds)

        # image
        time_image_embeds = self.image_proj(image_embeds)

        return time_image_embeds + time_text_embeds


class ImageHintTimeEmbedding(nn.Module):
    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)
        self.input_hint_block = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(256, 4, 3, padding=1),
        )

    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        hint = self.input_hint_block(hint)
        return time_image_embeds, hint


class CombinedTimestepLabelEmbeddings(nn.Module):
    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()
        # Deferred imports: resolved at instantiation time rather than at module import time.
        from .others import LabelEmbedding
        from .timestep import TimestepEmbedding, Timesteps

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        class_labels = self.class_embedder(class_labels)  # (N, D)

        conditioning = timesteps_emb + class_labels  # (N, D)

        return conditioning


class CombinedTimestepTextProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        from .text import PixArtAlphaTextProjection
        from .timestep import TimestepEmbedding, Timesteps

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

        pooled_projections = self.text_embedder(pooled_projection)

        conditioning = timesteps_emb + pooled_projections

        return conditioning


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()
        from .timestep import TimestepEmbedding, Timesteps

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning


class HunyuanDiTAttentionPool(nn.Module):
    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1],
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return x.squeeze(0)


class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim,
        pooled_projection_dim=1024,
        seq_len=256,
        cross_attention_dim=2048,
        use_style_cond_and_image_meta_size=True,
    ):
        super().__init__()
        from .text import PixArtAlphaTextProjection
        from .timestep import TimestepEmbedding, Timesteps

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

        self.pooler = HunyuanDiTAttentionPool(
            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
        )

        # Here we use a default learned embedder layer for future extension.
        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
        if use_style_cond_and_image_meta_size:
            self.style_embedder = nn.Embedding(1, embedding_dim)
            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
        else:
            extra_in_dim = pooled_projection_dim

        self.extra_embedder = PixArtAlphaTextProjection(
            in_features=extra_in_dim,
            hidden_size=embedding_dim * 4,
            out_features=embedding_dim,
            act_fn="silu_fp32",
        )

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, embedding_dim)

        # extra condition 1: text
        pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)

        if self.use_style_cond_and_image_meta_size:
            # extra condition 2: image meta size embedding
            image_meta_size = self.size_proj(image_meta_size.view(-1))
            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)

            # extra condition 3: style embedding
            style_embedding = self.style_embedder(style)  # (N, embedding_dim)

            # Concatenate all extra vectors
            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
        else:
            extra_cond = torch.cat([pooled_projections], dim=1)

        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]

        return conditioning
```
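The combined classes all follow the same pattern: project the timestep into an embedding, embed the extra condition, and sum. A quick smoke test of `CombinedTimestepTextProjEmbeddings` with dummy tensors; the import path is an assumption, while the constructor and forward signatures are taken from the diff above:

```python
import torch

# Assumed import path (not shown in the diff).
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings

embedder = CombinedTimestepTextProjEmbeddings(embedding_dim=1152, pooled_projection_dim=2048)

timestep = torch.tensor([0, 999])          # (N,)
pooled = torch.randn(2, 2048)              # (N, pooled_projection_dim)
conditioning = embedder(timestep, pooled)  # (N, embedding_dim) -> (2, 1152)
```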
Review comment: "This way nothing should break in terms of imports."
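That is, the flattened `__init__.py` keeps existing import statements working even though the classes now live in submodules. A sketch of that invariant, again assuming the package root is `diffusers.models.embeddings`:

```python
# Sanity check of the claim (module paths are assumed): the flat re-export
# and the defining submodule hand back the same class object.
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.embeddings.position import PatchEmbed as SubmodulePatchEmbed

assert PatchEmbed is SubmodulePatchEmbed
```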