1,293 changes: 0 additions & 1,293 deletions src/diffusers/models/embeddings.py

This file was deleted.

43 changes: 43 additions & 0 deletions src/diffusers/models/embeddings/__init__.py
@@ -0,0 +1,43 @@
from .combined import (
Member Author: This way nothing should break in terms of imports.

CombinedTimestepLabelEmbeddings,
CombinedTimestepTextProjEmbeddings,
HunyuanCombinedTimestepTextSizeStyleEmbedding,
HunyuanDiTAttentionPool,
ImageHintTimeEmbedding,
PixArtAlphaCombinedTimestepSizeEmbeddings,
TextImageProjection,
TextImageTimeEmbedding,
)
from .image import (
ImagePositionalEmbeddings,
ImageProjection,
ImageTimeEmbedding,
IPAdapterFaceIDImageProjection,
IPAdapterFaceIDPlusImageProjection,
IPAdapterFullImageProjection,
IPAdapterPlusImageProjection,
IPAdapterPlusImageProjectionBlock,
MultiIPAdapterImageProjection,
)
from .others import (
GLIGENTextBoundingboxProjection,
LabelEmbedding,
get_fourier_embeds_from_boundingbox,
)
from .position import (
PatchEmbed,
SinusoidalPositionalEmbedding,
apply_rotary_emb,
get_1d_rotary_pos_embed,
get_1d_sincos_pos_embed_from_grid,
get_2d_rotary_pos_embed,
get_2d_rotary_pos_embed_from_grid,
get_2d_sincos_pos_embed,
get_2d_sincos_pos_embed_from_grid,
)
from .text import (
AttentionPooling,
PixArtAlphaTextProjection,
TextTimeEmbedding,
)
from .timestep import GaussianFourierProjection, TimestepEmbedding, Timesteps, get_timestep_embedding
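
As the author's comment notes, `embeddings/__init__.py` re-exports every public name, so existing import paths keep resolving. A quick sketch (not part of the diff; names taken from the re-export list above):

# Old flat-module path, unchanged for downstream code:
from diffusers.models.embeddings import PatchEmbed, TimestepEmbedding

# New submodule paths introduced by this PR:
from diffusers.models.embeddings.position import PatchEmbed
from diffusers.models.embeddings.timestep import TimestepEmbedding
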
262 changes: 262 additions & 0 deletions src/diffusers/models/embeddings/combined.py
@@ -0,0 +1,262 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextImageProjection(nn.Module):
def __init__(
self,
text_embed_dim: int = 1024,
image_embed_dim: int = 768,
cross_attention_dim: int = 768,
num_image_text_embeds: int = 10,
):
super().__init__()

self.num_image_text_embeds = num_image_text_embeds
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
batch_size = text_embeds.shape[0]

# image
image_text_embeds = self.image_embeds(image_embeds)
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)

# text
text_embeds = self.text_proj(text_embeds)

return torch.cat([image_text_embeds, text_embeds], dim=1)
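
A minimal shape sketch (illustrative, not part of the diff): the pooled image embedding becomes `num_image_text_embeds` pseudo-tokens that are concatenated ahead of the projected text tokens.

text_embeds = torch.randn(2, 77, 1024)  # (batch, seq_len, text_embed_dim)
image_embeds = torch.randn(2, 768)      # (batch, image_embed_dim)
proj = TextImageProjection()
out = proj(text_embeds, image_embeds)   # (2, 10 + 77, 768): image tokens first, then text
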


class TextImageTimeEmbedding(nn.Module):
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
super().__init__()
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
self.text_norm = nn.LayerNorm(time_embed_dim)
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
# text
time_text_embeds = self.text_proj(text_embeds)
time_text_embeds = self.text_norm(time_text_embeds)

# image
time_image_embeds = self.image_proj(image_embeds)

return time_image_embeds + time_text_embeds


class ImageHintTimeEmbedding(nn.Module):
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
super().__init__()
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
self.image_norm = nn.LayerNorm(time_embed_dim)
self.input_hint_block = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 32, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(32, 32, 3, padding=1),
nn.SiLU(),
nn.Conv2d(32, 96, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(96, 96, 3, padding=1),
nn.SiLU(),
nn.Conv2d(96, 256, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(256, 4, 3, padding=1),
)

def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
# image
time_image_embeds = self.image_proj(image_embeds)
time_image_embeds = self.image_norm(time_image_embeds)
hint = self.input_hint_block(hint)
return time_image_embeds, hint
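
Illustrative shapes, not part of the diff: the three stride-2 convolutions downsample the hint 8x, so a 512x512 RGB hint lands at a 64x64, 4-channel feature map, while the image embedding is projected and normed for the time embedding.

emb = ImageHintTimeEmbedding()
image_embeds = torch.randn(2, 768)
hint = torch.randn(2, 3, 512, 512)
time_image_embeds, hint_out = emb(image_embeds, hint)
# time_image_embeds: (2, 1536), hint_out: (2, 4, 64, 64)
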


class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
from .others import LabelEmbedding
from .timestep import TimestepEmbedding, Timesteps

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

def forward(self, timestep, class_labels, hidden_dtype=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)

class_labels = self.class_embedder(class_labels) # (N, D)

conditioning = timesteps_emb + class_labels # (N, D)

return conditioning
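
A usage sketch (illustrative values, not part of the diff): a sinusoidal timestep embedding summed with a learned class-label embedding, DiT-style.

emb = CombinedTimestepLabelEmbeddings(num_classes=1000, embedding_dim=1152)
timestep = torch.tensor([10, 500])
class_labels = torch.tensor([3, 7])
cond = emb(timestep, class_labels, hidden_dtype=torch.float32)  # (2, 1152)
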


class CombinedTimestepTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()
from .text import PixArtAlphaTextProjection
from .timestep import TimestepEmbedding, Timesteps

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

def forward(self, timestep, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)

pooled_projections = self.text_embedder(pooled_projection)

conditioning = timesteps_emb + pooled_projections

return conditioning
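
Same pattern with a pooled text projection in place of class labels (illustrative dims, not part of the diff):

emb = CombinedTimestepTextProjEmbeddings(embedding_dim=1536, pooled_projection_dim=2048)
timestep = torch.tensor([999, 999])
pooled = torch.randn(2, 2048)
cond = emb(timestep, pooled)  # (2, 1536): timesteps_emb + projected pooled text
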


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.

Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""

def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
super().__init__()
from .timestep import TimestepEmbedding, Timesteps

self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

self.use_additional_conditions = use_additional_conditions
if use_additional_conditions:
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)

if self.use_additional_conditions:
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
else:
conditioning = timesteps_emb

return conditioning
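
A dimension note worth making explicit (sketch, not part of the diff): assuming `resolution` is (N, 2) height/width pairs and `aspect_ratio` is (N, 1), the resolution branch contributes two size embeddings per sample and the aspect-ratio branch one, so the final addition only broadcasts when `3 * size_emb_dim == embedding_dim`.

emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim=1152, size_emb_dim=384, use_additional_conditions=True
)
timestep = torch.tensor([25, 25])
resolution = torch.tensor([[1024.0, 1024.0], [512.0, 768.0]])  # (N, 2): height, width
aspect_ratio = torch.tensor([[1.0], [1.5]])                    # (N, 1)
cond = emb(timestep, resolution, aspect_ratio, batch_size=2, hidden_dtype=torch.float32)
# (2, 1152): timesteps_emb + cat([resolution_emb, aspect_ratio_emb], dim=1)
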


class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads

def forward(self, x):
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1],
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x.squeeze(0)
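
Usage sketch (illustrative, not part of the diff): a CLIP-style attention pool where the mean over the sequence serves as the single query; the sequence length must equal `spacial_dim` for the positional embedding to broadcast.

pool = HunyuanDiTAttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8, output_dim=1024)
x = torch.randn(2, 256, 2048)  # (N, L, C) with L == spacial_dim
pooled = pool(x)               # (2, 1024)
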


class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
def __init__(
self,
embedding_dim,
pooled_projection_dim=1024,
seq_len=256,
cross_attention_dim=2048,
use_style_cond_and_image_meta_size=True,
):
super().__init__()
from .text import PixArtAlphaTextProjection
from .timestep import TimestepEmbedding, Timesteps

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

self.pooler = HunyuanDiTAttentionPool(
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
)

# Here we use a default learned embedder layer for future extension.
self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
if use_style_cond_and_image_meta_size:
self.style_embedder = nn.Embedding(1, embedding_dim)
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
else:
extra_in_dim = pooled_projection_dim

self.extra_embedder = PixArtAlphaTextProjection(
in_features=extra_in_dim,
hidden_size=embedding_dim * 4,
out_features=embedding_dim,
act_fn="silu_fp32",
)

def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, embedding_dim)

# extra condition1: text
pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)

if self.use_style_cond_and_image_meta_size:
# extra condition2: image meta size embedding
image_meta_size = self.size_proj(image_meta_size.view(-1))
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)

# extra condition3: style embedding
style_embedding = self.style_embedder(style) # (N, embedding_dim)

# Concatenate all extra vectors
extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
else:
extra_cond = torch.cat([pooled_projections], dim=1)

conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]

return conditioning
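
End-to-end shape sketch with the defaults above (illustrative inputs, not part of the diff); note that `extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim` matches the concatenation of the three extra conditions.

emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(embedding_dim=1408)
timestep = torch.tensor([50, 50])
encoder_hidden_states = torch.randn(2, 256, 2048)  # pooled to (2, 1024)
image_meta_size = torch.randn(2, 6)                # six scalars -> (2, 6 * 256)
style = torch.zeros(2, dtype=torch.long)           # index into the single learned style embedding
cond = emb(timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=torch.float32)
# (2, 1408)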