From abf67a88ca3d3332a2b72d57f2a96c2e345c893c Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 27 Jun 2024 16:52:50 +0530
Subject: [PATCH 01/15] refactor embeddings

---
 src/diffusers/models/embeddings.py            | 1260 -----------------
 src/diffusers/models/embeddings/__init__.py   |   34 +
 src/diffusers/models/embeddings/combined.py   |  173 +++
 src/diffusers/models/embeddings/image_text.py |  630 +++++++++
 src/diffusers/models/embeddings/others.py     |  245 ++++
 src/diffusers/models/embeddings/position.py   |  140 ++
 src/diffusers/models/embeddings/timestep.py   |  146 ++
 7 files changed, 1368 insertions(+), 1260 deletions(-)
 create mode 100644 src/diffusers/models/embeddings/__init__.py
 create mode 100644 src/diffusers/models/embeddings/combined.py
 create mode 100644 src/diffusers/models/embeddings/image_text.py
 create mode 100644 src/diffusers/models/embeddings/others.py
 create mode 100644 src/diffusers/models/embeddings/position.py
 create mode 100644 src/diffusers/models/embeddings/timestep.py

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index cb64bc61f3e9..196860c9f1c6 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -11,1263 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from ..utils import deprecate
-from .activations import FP32SiLU, get_activation
-from .attention_processor import Attention
-
-
-def get_timestep_embedding(
-    timesteps: torch.Tensor,
-    embedding_dim: int,
-    flip_sin_to_cos: bool = False,
-    downscale_freq_shift: float = 1,
-    scale: float = 1,
-    max_period: int = 10000,
-):
-    """
-    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
-
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
-    embeddings. :return: an [N x dim] Tensor of positional embeddings.
- """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def get_2d_sincos_pos_embed( - embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 -): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - if isinstance(grid_size, int): - grid_size = (grid_size, grid_size) - - grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale - grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding with support for SD3 cropping.""" - - def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - interpolation_scale=1, - pos_embed_type="sincos", - pos_embed_max_size=None, # For SD3 cropping - ): - super().__init__() - - num_patches = (height // patch_size) * (width // patch_size) - self.flatten = flatten - self.layer_norm = layer_norm - self.pos_embed_max_size = pos_embed_max_size - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - if layer_norm: - self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - - self.patch_size = 
patch_size - self.height, self.width = height // patch_size, width // patch_size - self.base_size = height // patch_size - self.interpolation_scale = interpolation_scale - - # Calculate positional embeddings based on max size or default - if pos_embed_max_size: - grid_size = pos_embed_max_size - else: - grid_size = int(num_patches**0.5) - - if pos_embed_type is None: - self.pos_embed = None - elif pos_embed_type == "sincos": - pos_embed = get_2d_sincos_pos_embed( - embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale - ) - persistent = True if pos_embed_max_size else False - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) - else: - raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") - - def cropped_pos_embed(self, height, width): - """Crops positional embeddings for SD3 compatibility.""" - if self.pos_embed_max_size is None: - raise ValueError("`pos_embed_max_size` must be set for cropping.") - - height = height // self.patch_size - width = width // self.patch_size - if height > self.pos_embed_max_size: - raise ValueError( - f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." - ) - if width > self.pos_embed_max_size: - raise ValueError( - f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." - ) - - top = (self.pos_embed_max_size - height) // 2 - left = (self.pos_embed_max_size - width) // 2 - spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) - spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] - spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) - return spatial_pos_embed - - def forward(self, latent): - if self.pos_embed_max_size is not None: - height, width = latent.shape[-2:] - else: - height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size - - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC - if self.layer_norm: - latent = self.norm(latent) - if self.pos_embed is None: - return latent.to(latent.dtype) - # Interpolate or crop positional embeddings as needed - if self.pos_embed_max_size: - pos_embed = self.cropped_pos_embed(height, width) - else: - if self.height != height or self.width != width: - pos_embed = get_2d_sincos_pos_embed( - embed_dim=self.pos_embed.shape[-1], - grid_size=(height, width), - base_size=self.base_size, - interpolation_scale=self.interpolation_scale, - ) - pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) - else: - pos_embed = self.pos_embed - - return (latent + pos_embed).to(latent.dtype) - - -def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): - """ - RoPE for image tokens with 2d structure. - - Args: - embed_dim: (`int`): - The embedding dimension size - crops_coords (`Tuple[int]`) - The top-left and bottom-right coordinates of the crop. - grid_size (`Tuple[int]`): - The grid size of the positional embedding. - use_real (`bool`): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - - Returns: - `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`. 
- """ - start, stop = crops_coords - grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) # [2, W, H] - - grid = grid.reshape([2, 1, *grid.shape[1:]]) - pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) - return pos_embed - - -def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): - assert embed_dim % 4 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) - emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) - - if use_real: - cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) - sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) - return cos, sin - else: - emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) - return emb - - -def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end - index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 - data type. - - Args: - dim (`int`): Dimension of the frequency tensor. - pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar - theta (`float`, *optional*, defaults to 10000.0): - Scaling factor for frequency computation. Defaults to 10000.0. - use_real (`bool`, *optional*): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - - Returns: - `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] - """ - if isinstance(pos, int): - pos = np.arange(pos) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] - t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] - freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] - if use_real: - freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] - return freqs_cos, freqs_sin - else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] - return freqs_cis - - -def apply_rotary_emb( - x: torch.Tensor, - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings - to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are - reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting - tensors contain rotary embeddings and are returned as real tensors. - - Args: - x (`torch.Tensor`): - Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply - freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
- """ - cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = sin[None, None] - cos, sin = cos.to(x.device), sin.to(x.device) - - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - - return out - - -class TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim=None, - sample_proj_bias=True, - ): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) - - if cond_proj_dim is not None: - self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) - else: - self.cond_proj = None - - self.act = get_activation(act_fn) - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) - - if post_act_fn is None: - self.post_act = None - else: - self.post_act = get_activation(post_act_fn) - - def forward(self, sample, condition=None): - if condition is not None: - sample = sample + self.cond_proj(condition) - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - - if self.post_act is not None: - sample = self.post_act(sample) - return sample - - -class Timesteps(nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb - - -class GaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels.""" - - def __init__( - self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False - ): - super().__init__() - self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.log = log - self.flip_sin_to_cos = flip_sin_to_cos - - if set_W_to_weight: - # to delete later - self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - - self.weight = self.W - - def forward(self, x): - if self.log: - x = torch.log(x) - - x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - - if self.flip_sin_to_cos: - out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) - else: - out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - return out - - -class SinusoidalPositionalEmbedding(nn.Module): - """Apply positional information to a sequence of embeddings. - - Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to - them - - Args: - embed_dim: (int): Dimension of the positional embedding. 
- max_seq_length: Maximum sequence length to apply positional embeddings - - """ - - def __init__(self, embed_dim: int, max_seq_length: int = 32): - super().__init__() - position = torch.arange(max_seq_length).unsqueeze(1) - div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) - pe = torch.zeros(1, max_seq_length, embed_dim) - pe[0, :, 0::2] = torch.sin(position * div_term) - pe[0, :, 1::2] = torch.cos(position * div_term) - self.register_buffer("pe", pe) - - def forward(self, x): - _, seq_length, _ = x.shape - x = x + self.pe[:, :seq_length] - return x - - -class ImagePositionalEmbeddings(nn.Module): - """ - Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the - height and width of the latent space. - - For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 - - For VQ-diffusion: - - Output vector embeddings are used as input for the transformer. - - Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. - - Args: - num_embed (`int`): - Number of embeddings for the latent pixels embeddings. - height (`int`): - Height of the latent image i.e. the number of height embeddings. - width (`int`): - Width of the latent image i.e. the number of width embeddings. - embed_dim (`int`): - Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. - """ - - def __init__( - self, - num_embed: int, - height: int, - width: int, - embed_dim: int, - ): - super().__init__() - - self.height = height - self.width = width - self.num_embed = num_embed - self.embed_dim = embed_dim - - self.emb = nn.Embedding(self.num_embed, embed_dim) - self.height_emb = nn.Embedding(self.height, embed_dim) - self.width_emb = nn.Embedding(self.width, embed_dim) - - def forward(self, index): - emb = self.emb(index) - - height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) - - # 1 x H x D -> 1 x H x 1 x D - height_emb = height_emb.unsqueeze(2) - - width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) - - # 1 x W x D -> 1 x 1 x W x D - width_emb = width_emb.unsqueeze(1) - - pos_emb = height_emb + width_emb - - # 1 x H x W x D -> 1 x L xD - pos_emb = pos_emb.view(1, self.height * self.width, -1) - - emb = emb + pos_emb[:, : emb.shape[1], :] - - return emb - - -class LabelEmbedding(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - - Args: - num_classes (`int`): The number of classes. - hidden_size (`int`): The size of the vector embeddings. - dropout_prob (`float`): The probability of dropping a label. - """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. 
- """ - if force_drop_ids is None: - drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob - else: - drop_ids = torch.tensor(force_drop_ids == 1) - labels = torch.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels: torch.LongTensor, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -class TextImageProjection(nn.Module): - def __init__( - self, - text_embed_dim: int = 1024, - image_embed_dim: int = 768, - cross_attention_dim: int = 768, - num_image_text_embeds: int = 10, - ): - super().__init__() - - self.num_image_text_embeds = num_image_text_embeds - self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) - self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - batch_size = text_embeds.shape[0] - - # image - image_text_embeds = self.image_embeds(image_embeds) - image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) - - # text - text_embeds = self.text_proj(text_embeds) - - return torch.cat([image_text_embeds, text_embeds], dim=1) - - -class ImageProjection(nn.Module): - def __init__( - self, - image_embed_dim: int = 768, - cross_attention_dim: int = 768, - num_image_text_embeds: int = 32, - ): - super().__init__() - - self.num_image_text_embeds = num_image_text_embeds - self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) - self.norm = nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds: torch.Tensor): - batch_size = image_embeds.shape[0] - - # image - image_embeds = self.image_embeds(image_embeds) - image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) - image_embeds = self.norm(image_embeds) - return image_embeds - - -class IPAdapterFullImageProjection(nn.Module): - def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): - super().__init__() - from .attention import FeedForward - - self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") - self.norm = nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds: torch.Tensor): - return self.norm(self.ff(image_embeds)) - - -class IPAdapterFaceIDImageProjection(nn.Module): - def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): - super().__init__() - from .attention import FeedForward - - self.num_tokens = num_tokens - self.cross_attention_dim = cross_attention_dim - self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") - self.norm = nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds: torch.Tensor): - x = self.ff(image_embeds) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - return self.norm(x) - - -class CombinedTimestepLabelEmbeddings(nn.Module): - def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): - super().__init__() - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) - - def forward(self, timestep, class_labels, 
hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) - - class_labels = self.class_embedder(class_labels) # (N, D) - - conditioning = timesteps_emb + class_labels # (N, D) - - return conditioning - - -class CombinedTimestepTextProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim): - super().__init__() - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") - - def forward(self, timestep, pooled_projection): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - pooled_projections = self.text_embedder(pooled_projection) - - conditioning = timesteps_emb + pooled_projections - - return conditioning - - -class HunyuanDiTAttentionPool(nn.Module): - # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 - - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.permute(1, 0, 2) # NLC -> LNC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC - x, _ = F.multi_head_attention_forward( - query=x[:1], - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - return x.squeeze(0) - - -class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048): - super().__init__() - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - self.pooler = HunyuanDiTAttentionPool( - seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim - ) - # Here we use a default learned embedder layer for future extension. 
- self.style_embedder = nn.Embedding(1, embedding_dim) - extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim - self.extra_embedder = PixArtAlphaTextProjection( - in_features=extra_in_dim, - hidden_size=embedding_dim * 4, - out_features=embedding_dim, - act_fn="silu_fp32", - ) - - def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) - - # extra condition1: text - pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) - - # extra condition2: image meta size embdding - image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0) - image_meta_size = image_meta_size.to(dtype=hidden_dtype) - image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) - - # extra condition3: style embedding - style_embedding = self.style_embedder(style) # (N, embedding_dim) - - # Concatenate all extra vectors - extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) - conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] - - return conditioning - - -class TextTimeEmbedding(nn.Module): - def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): - super().__init__() - self.norm1 = nn.LayerNorm(encoder_dim) - self.pool = AttentionPooling(num_heads, encoder_dim) - self.proj = nn.Linear(encoder_dim, time_embed_dim) - self.norm2 = nn.LayerNorm(time_embed_dim) - - def forward(self, hidden_states): - hidden_states = self.norm1(hidden_states) - hidden_states = self.pool(hidden_states) - hidden_states = self.proj(hidden_states) - hidden_states = self.norm2(hidden_states) - return hidden_states - - -class TextImageTimeEmbedding(nn.Module): - def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): - super().__init__() - self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) - self.text_norm = nn.LayerNorm(time_embed_dim) - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - # text - time_text_embeds = self.text_proj(text_embeds) - time_text_embeds = self.text_norm(time_text_embeds) - - # image - time_image_embeds = self.image_proj(image_embeds) - - return time_image_embeds + time_text_embeds - - -class ImageTimeEmbedding(nn.Module): - def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): - super().__init__() - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - self.image_norm = nn.LayerNorm(time_embed_dim) - - def forward(self, image_embeds: torch.Tensor): - # image - time_image_embeds = self.image_proj(image_embeds) - time_image_embeds = self.image_norm(time_image_embeds) - return time_image_embeds - - -class ImageHintTimeEmbedding(nn.Module): - def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): - super().__init__() - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - self.image_norm = nn.LayerNorm(time_embed_dim) - self.input_hint_block = nn.Sequential( - nn.Conv2d(3, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 96, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(96, 96, 3, padding=1), - nn.SiLU(), - nn.Conv2d(96, 256, 3, padding=1, stride=2), - nn.SiLU(), - 
nn.Conv2d(256, 4, 3, padding=1), - ) - - def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): - # image - time_image_embeds = self.image_proj(image_embeds) - time_image_embeds = self.image_norm(time_image_embeds) - hint = self.input_hint_block(hint) - return time_image_embeds, hint - - -class AttentionPooling(nn.Module): - # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 - - def __init__(self, num_heads, embed_dim, dtype=None): - super().__init__() - self.dtype = dtype - self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.num_heads = num_heads - self.dim_per_head = embed_dim // self.num_heads - - def forward(self, x): - bs, length, width = x.size() - - def shape(x): - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, -1, self.num_heads, self.dim_per_head) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) - # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) - x = x.transpose(1, 2) - return x - - class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) - x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) - - # (bs*n_heads, class_token_length, dim_per_head) - q = shape(self.q_proj(class_token)) - # (bs*n_heads, length+class_token_length, dim_per_head) - k = shape(self.k_proj(x)) - v = shape(self.v_proj(x)) - - # (bs*n_heads, class_token_length, length+class_token_length): - scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - - # (bs*n_heads, dim_per_head, class_token_length) - a = torch.einsum("bts,bcs->bct", weight, v) - - # (bs, length+1, width) - a = a.reshape(bs, -1, 1).transpose(1, 2) - - return a[:, 0, :] # cls_token - - -def get_fourier_embeds_from_boundingbox(embed_dim, box): - """ - Args: - embed_dim: int - box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline - Returns: - [B x N x embed_dim] tensor of positional embeddings - """ - - batch_size, num_boxes = box.shape[:2] - - emb = 100 ** (torch.arange(embed_dim) / embed_dim) - emb = emb[None, None, None].to(device=box.device, dtype=box.dtype) - emb = emb * box.unsqueeze(-1) - - emb = torch.stack((emb.sin(), emb.cos()), dim=-1) - emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4) - - return emb - - -class GLIGENTextBoundingboxProjection(nn.Module): - def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): - super().__init__() - self.positive_len = positive_len - self.out_dim = out_dim - - self.fourier_embedder_dim = fourier_freqs - self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy - - if isinstance(out_dim, tuple): - out_dim = out_dim[0] - - if feature_type == "text-only": - self.linears = nn.Sequential( - nn.Linear(self.positive_len + self.position_dim, 512), - nn.SiLU(), - nn.Linear(512, 512), - nn.SiLU(), - nn.Linear(512, out_dim), - ) - 
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) - - elif feature_type == "text-image": - self.linears_text = nn.Sequential( - nn.Linear(self.positive_len + self.position_dim, 512), - nn.SiLU(), - nn.Linear(512, 512), - nn.SiLU(), - nn.Linear(512, out_dim), - ) - self.linears_image = nn.Sequential( - nn.Linear(self.positive_len + self.position_dim, 512), - nn.SiLU(), - nn.Linear(512, 512), - nn.SiLU(), - nn.Linear(512, out_dim), - ) - self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) - self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) - - self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) - - def forward( - self, - boxes, - masks, - positive_embeddings=None, - phrases_masks=None, - image_masks=None, - phrases_embeddings=None, - image_embeddings=None, - ): - masks = masks.unsqueeze(-1) - - # embedding position (it may includes padding as placeholder) - xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C - - # learnable null embedding - xyxy_null = self.null_position_feature.view(1, 1, -1) - - # replace padding with learnable null embedding - xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null - - # positionet with text only information - if positive_embeddings is not None: - # learnable null embedding - positive_null = self.null_positive_feature.view(1, 1, -1) - - # replace padding with learnable null embedding - positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null - - objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) - - # positionet with text and image information - else: - phrases_masks = phrases_masks.unsqueeze(-1) - image_masks = image_masks.unsqueeze(-1) - - # learnable null embedding - text_null = self.null_text_feature.view(1, 1, -1) - image_null = self.null_image_feature.view(1, 1, -1) - - # replace padding with learnable null embedding - phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null - image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null - - objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) - objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) - objs = torch.cat([objs_text, objs_image], dim=1) - - return objs - - -class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): - """ - For PixArt-Alpha. 
- - Reference: - https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 - """ - - def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): - super().__init__() - - self.outdim = size_emb_dim - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - self.use_additional_conditions = use_additional_conditions - if use_additional_conditions: - self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) - - def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) - - if self.use_additional_conditions: - resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) - resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) - aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) - aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) - conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) - else: - conditioning = timesteps_emb - - return conditioning - - -class PixArtAlphaTextProjection(nn.Module): - """ - Projects caption embeddings. Also handles dropout for classifier-free guidance. 
- - Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py - """ - - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): - super().__init__() - if out_features is None: - out_features = hidden_size - self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) - if act_fn == "gelu_tanh": - self.act_1 = nn.GELU(approximate="tanh") - elif act_fn == "silu": - self.act_1 = nn.SiLU() - elif act_fn == "silu_fp32": - self.act_1 = FP32SiLU() - else: - raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) - - def forward(self, caption): - hidden_states = self.linear_1(caption) - hidden_states = self.act_1(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - - -class IPAdapterPlusImageProjectionBlock(nn.Module): - def __init__( - self, - embed_dims: int = 768, - dim_head: int = 64, - heads: int = 16, - ffn_ratio: float = 4, - ) -> None: - super().__init__() - from .attention import FeedForward - - self.ln0 = nn.LayerNorm(embed_dims) - self.ln1 = nn.LayerNorm(embed_dims) - self.attn = Attention( - query_dim=embed_dims, - dim_head=dim_head, - heads=heads, - out_bias=False, - ) - self.ff = nn.Sequential( - nn.LayerNorm(embed_dims), - FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), - ) - - def forward(self, x, latents, residual): - encoder_hidden_states = self.ln0(x) - latents = self.ln1(latents) - encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) - latents = self.attn(latents, encoder_hidden_states) + residual - latents = self.ff(latents) + latents - return latents - - -class IPAdapterPlusImageProjection(nn.Module): - """Resampler of IP-Adapter Plus. - - Args: - embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, - that is the same - number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): - The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults - to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. - Defaults to 16. num_queries (int): - The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio - of feedforward network hidden - layer channels. Defaults to 4. - """ - - def __init__( - self, - embed_dims: int = 768, - output_dims: int = 1024, - hidden_dims: int = 1280, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_queries: int = 8, - ffn_ratio: float = 4, - ) -> None: - super().__init__() - self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) - - self.proj_in = nn.Linear(embed_dims, hidden_dims) - - self.proj_out = nn.Linear(hidden_dims, output_dims) - self.norm_out = nn.LayerNorm(output_dims) - - self.layers = nn.ModuleList( - [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Args: - x (torch.Tensor): Input Tensor. - Returns: - torch.Tensor: Output Tensor. 
- """ - latents = self.latents.repeat(x.size(0), 1, 1) - - x = self.proj_in(x) - - for block in self.layers: - residual = latents - latents = block(x, latents, residual) - - latents = self.proj_out(latents) - return self.norm_out(latents) - - -class IPAdapterFaceIDPlusImageProjection(nn.Module): - """FacePerceiverResampler of IP-Adapter Plus. - - Args: - embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, - that is the same - number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): - The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults - to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. - Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. - ffn_ratio (float): The expansion ratio of feedforward network hidden - layer channels. Defaults to 4. - ffproj_ratio (float): The expansion ratio of feedforward network hidden - layer channels (for ID embeddings). Defaults to 4. - """ - - def __init__( - self, - embed_dims: int = 768, - output_dims: int = 768, - hidden_dims: int = 1280, - id_embeddings_dim: int = 512, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_tokens: int = 4, - num_queries: int = 8, - ffn_ratio: float = 4, - ffproj_ratio: int = 2, - ) -> None: - super().__init__() - from .attention import FeedForward - - self.num_tokens = num_tokens - self.embed_dim = embed_dims - self.clip_embeds = None - self.shortcut = False - self.shortcut_scale = 1.0 - - self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) - self.norm = nn.LayerNorm(embed_dims) - - self.proj_in = nn.Linear(hidden_dims, embed_dims) - - self.proj_out = nn.Linear(embed_dims, output_dims) - self.norm_out = nn.LayerNorm(output_dims) - - self.layers = nn.ModuleList( - [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] - ) - - def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Args: - id_embeds (torch.Tensor): Input Tensor (ID embeds). - Returns: - torch.Tensor: Output Tensor. - """ - id_embeds = id_embeds.to(self.clip_embeds.dtype) - id_embeds = self.proj(id_embeds) - id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) - id_embeds = self.norm(id_embeds) - latents = id_embeds - - clip_embeds = self.proj_in(self.clip_embeds) - x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) - - for block in self.layers: - residual = latents - latents = block(x, latents, residual) - - latents = self.proj_out(latents) - out = self.norm_out(latents) - if self.shortcut: - out = id_embeds + self.shortcut_scale * out - return out - - -class MultiIPAdapterImageProjection(nn.Module): - def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): - super().__init__() - self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) - - def forward(self, image_embeds: List[torch.Tensor]): - projected_image_embeds = [] - - # currently, we accept `image_embeds` as - # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] - # 2. 
list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] - if not isinstance(image_embeds, list): - deprecation_message = ( - "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." - " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning." - ) - deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) - image_embeds = [image_embeds.unsqueeze(1)] - - if len(image_embeds) != len(self.image_projection_layers): - raise ValueError( - f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" - ) - - for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): - batch_size, num_images = image_embed.shape[0], image_embed.shape[1] - image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) - image_embed = image_projection_layer(image_embed) - image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) - - projected_image_embeds.append(image_embed) - - return projected_image_embeds diff --git a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py new file mode 100644 index 000000000000..8e91540dd19a --- /dev/null +++ b/src/diffusers/models/embeddings/__init__.py @@ -0,0 +1,34 @@ +from .combined import ( + CombinedTimestepLabelEmbeddings, + CombinedTimestepTextProjEmbeddings, + HunyuanCombinedTimestepTextSizeStyleEmbedding, + PixArtAlphaCombinedTimestepSizeEmbeddings, +) +from .image_text import ( + ImageHintTimeEmbedding, + ImagePositionalEmbeddings, + ImageProjection, + ImageTimeEmbedding, + IPAdapterFaceIDImageProjection, + IPAdapterFaceIDPlusImageProjection, + IPAdapterFullImageProjection, + IPAdapterPlusImageProjection, + IPAdapterPlusImageProjectionBlock, + MultiIPAdapterImageProjection, + PatchEmbed, + PixArtAlphaTextProjection, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + get_1d_sincos_pos_embed_from_grid, + get_2d_sincos_pos_embed, + get_2d_sincos_pos_embed_from_grid, +) +from .position import ( + SinusoidalPositionalEmbedding, + apply_rotary_emb, + get_1d_rotary_pos_embed, + get_2d_rotary_pos_embed, + get_2d_rotary_pos_embed_from_grid, +) +from .timestep import TimestepEmbedding, Timesteps, get_timestep_embedding diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py new file mode 100644 index 000000000000..d056b9268789 --- /dev/null +++ b/src/diffusers/models/embeddings/combined.py @@ -0,0 +1,173 @@ +import math + +import torch +import torch.nn as nn + + +# Copied from diffusers.models.embeddings.timestep.get_timestep_embedding +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + from .others import LabelEmbedding + from .timestep import TimestepEmbedding, Timesteps + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + class_labels = self.class_embedder(class_labels) # (N, D) + + conditioning = timesteps_emb + class_labels # (N, D) + + return conditioning + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + from .image_text import PixArtAlphaTextProjection + from .timestep import TimestepEmbedding, Timesteps + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + For PixArt-Alpha. 
+ + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + from .timestep import TimestepEmbedding, Timesteps + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048): + super().__init__() + from .image_text import PixArtAlphaTextProjection + from .others import HunyuanDiTAttentionPool + from .timestep import TimestepEmbedding, Timesteps + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + # Here we use a default learned embedder layer for future extension. 
+ self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + # extra condition2: image meta size embdding + image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning diff --git a/src/diffusers/models/embeddings/image_text.py b/src/diffusers/models/embeddings/image_text.py new file mode 100644 index 000000000000..decd1d806148 --- /dev/null +++ b/src/diffusers/models/embeddings/image_text.py @@ -0,0 +1,630 @@ +from typing import List, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...utils import deprecate +from ..activations import FP32SiLU + + +def get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = 
np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for SD3 cropping.""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + pos_embed_type="sincos", + pos_embed_max_size=None, # For SD3 cropping + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + + # Calculate positional embeddings based on max size or default + if pos_embed_max_size: + grid_size = pos_embed_max_size + else: + grid_size = int(num_patches**0.5) + + if pos_embed_type is None: + self.pos_embed = None + elif pos_embed_type == "sincos": + pos_embed = get_2d_sincos_pos_embed( + embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + persistent = True if pos_embed_max_size else False + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) + else: + raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") + + def cropped_pos_embed(self, height, width): + """Crops positional embeddings for SD3 compatibility.""" + if self.pos_embed_max_size is None: + raise ValueError("`pos_embed_max_size` must be set for cropping.") + + height = height // self.patch_size + width = width // self.patch_size + if height > self.pos_embed_max_size: + raise ValueError( + f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + if width > self.pos_embed_max_size: + raise ValueError( + f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." 
+ ) + + top = (self.pos_embed_max_size - height) // 2 + left = (self.pos_embed_max_size - width) // 2 + spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward(self, latent): + if self.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.pos_embed is None: + return latent.to(latent.dtype) + # Interpolate or crop positional embeddings as needed + if self.pos_embed_max_size: + pos_embed = self.cropped_pos_embed(height, width) + else: + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + return (latent + pos_embed).to(latent.dtype) + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 
+ """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb + + +class TextImageProjection(nn.Module): + def __init__( + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + batch_size = text_embeds.shape[0] + + # image + image_text_embeds = self.image_embeds(image_embeds) + image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + + # text + text_embeds = self.text_proj(text_embeds) + + return torch.cat([image_text_embeds, text_embeds], dim=1) + + +class ImageProjection(nn.Module): + def __init__( + self, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 32, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + batch_size = image_embeds.shape[0] + + # image + image_embeds = self.image_embeds(image_embeds) + image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + image_embeds = self.norm(image_embeds) + return image_embeds + + +class IPAdapterFullImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): + super().__init__() + from ..attention import FeedForward + + self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + return self.norm(self.ff(image_embeds)) + + +class IPAdapterFaceIDImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from ..attention import FeedForward + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + +class 
PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + elif act_fn == "silu_fp32": + self.act_1 = FP32SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class IPAdapterPlusImageProjectionBlock(nn.Module): + def __init__( + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from ..attention import Attention, FeedForward + + self.ln0 = nn.LayerNorm(embed_dims) + self.ln1 = nn.LayerNorm(embed_dims) + self.attn = Attention( + query_dim=embed_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ) + self.ff = nn.Sequential( + nn.LayerNorm(embed_dims), + FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ) + + def forward(self, x, latents, residual): + encoder_hidden_states = self.ln0(x) + latents = self.ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = self.attn(latents, encoder_hidden_states) + residual + latents = self.ff(latents) + latents + return latents + + +class IPAdapterPlusImageProjection(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_queries (int): + The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio + of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (torch.Tensor): Input Tensor. + Returns: + torch.Tensor: Output Tensor. 
+ """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class IPAdapterFaceIDPlusImageProjection(nn.Module): + """FacePerceiverResampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + ffproj_ratio (float): The expansion ratio of feedforward network hidden + layer channels (for ID embeddings). Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim: int = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens: int = 4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, + ) -> None: + super().__init__() + from ..attention import FeedForward + + self.num_tokens = num_tokens + self.embed_dim = embed_dims + self.clip_embeds = None + self.shortcut = False + self.shortcut_scale = 1.0 + + self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) + self.norm = nn.LayerNorm(embed_dims) + + self.proj_in = nn.Linear(hidden_dims, embed_dims) + + self.proj_out = nn.Linear(embed_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + id_embeds (torch.Tensor): Input Tensor (ID embeds). + Returns: + torch.Tensor: Output Tensor. + """ + id_embeds = id_embeds.to(self.clip_embeds.dtype) + id_embeds = self.proj(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) + id_embeds = self.norm(id_embeds) + latents = id_embeds + + clip_embeds = self.proj_in(self.clip_embeds) + x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + out = self.norm_out(latents) + if self.shortcut: + out = id_embeds + self.shortcut_scale * out + return out + + +class MultiIPAdapterImageProjection(nn.Module): + def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): + super().__init__() + self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) + + def forward(self, image_embeds: List[torch.Tensor]): + projected_image_embeds = [] + + # currently, we accept `image_embeds` as + # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] + # 2. 
list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] + if not isinstance(image_embeds, list): + deprecation_message = ( + "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning." + ) + deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) + image_embeds = [image_embeds.unsqueeze(1)] + + if len(image_embeds) != len(self.image_projection_layers): + raise ValueError( + f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" + ) + + for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): + batch_size, num_images = image_embed.shape[0], image_embed.shape[1] + image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) + image_embed = image_projection_layer(image_embed) + image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) + + projected_image_embeds.append(image_embed) + + return projected_image_embeds + + +class TextTimeEmbedding(nn.Module): + def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): + super().__init__() + from .others import AttentionPooling + + self.norm1 = nn.LayerNorm(encoder_dim) + self.pool = AttentionPooling(num_heads, encoder_dim) + self.proj = nn.Linear(encoder_dim, time_embed_dim) + self.norm2 = nn.LayerNorm(time_embed_dim) + + def forward(self, hidden_states): + hidden_states = self.norm1(hidden_states) + hidden_states = self.pool(hidden_states) + hidden_states = self.proj(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class TextImageTimeEmbedding(nn.Module): + def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) + self.text_norm = nn.LayerNorm(time_embed_dim) + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + # text + time_text_embeds = self.text_proj(text_embeds) + time_text_embeds = self.text_norm(time_text_embeds) + + # image + time_image_embeds = self.image_proj(image_embeds) + + return time_image_embeds + time_text_embeds + + +class ImageTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + + def forward(self, image_embeds: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + return time_image_embeds + + +class ImageHintTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + 
nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(256, 4, 3, padding=1), + ) + + def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + hint = self.input_hint_block(hint) + return time_image_embeds, hint diff --git a/src/diffusers/models/embeddings/others.py b/src/diffusers/models/embeddings/others.py new file mode 100644 index 000000000000..ae5d5cf04738 --- /dev/null +++ b/src/diffusers/models/embeddings/others.py @@ -0,0 +1,245 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def get_fourier_embeds_from_boundingbox(embed_dim, box): + """ + Args: + embed_dim: int + box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline + Returns: + [B x N x embed_dim] tensor of positional embeddings + """ + + batch_size, num_boxes = box.shape[:2] + + emb = 100 ** (torch.arange(embed_dim) / embed_dim) + emb = emb[None, None, None].to(device=box.device, dtype=box.dtype) + emb = emb * box.unsqueeze(-1) + + emb = torch.stack((emb.sin(), emb.cos()), dim=-1) + emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4) + + return emb + + +class GLIGENTextBoundingboxProjection(nn.Module): + def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): + super().__init__() + self.positive_len = positive_len + self.out_dim = out_dim + + self.fourier_embedder_dim = fourier_freqs + self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy + + if isinstance(out_dim, tuple): + out_dim = out_dim[0] + + if feature_type == "text-only": + self.linears = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + elif feature_type == "text-image": + self.linears_text = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.linears_image = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + def forward( + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, + ): + masks = masks.unsqueeze(-1) + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C + + # learnable null embedding + xyxy_null = self.null_position_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + # positionet with text only information + if positive_embeddings is not None: + # learnable null embedding + positive_null = self.null_positive_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null + + objs = 
self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + + # positionet with text and image information + else: + phrases_masks = phrases_masks.unsqueeze(-1) + image_masks = image_masks.unsqueeze(-1) + + # learnable null embedding + text_null = self.null_text_feature.view(1, 1, -1) + image_null = self.null_image_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null + image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null + + objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) + objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) + objs = torch.cat([objs_text, objs_image], dim=1) + + return objs + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.LongTensor, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class HunyuanDiTAttentionPool(nn.Module): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class 
AttentionPooling(nn.Module): + # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 + + def __init__(self, num_heads, embed_dim, dtype=None): + super().__init__() + self.dtype = dtype + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.num_heads = num_heads + self.dim_per_head = embed_dim // self.num_heads + + def forward(self, x): + bs, length, width = x.size() + + def shape(x): + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, -1, self.num_heads, self.dim_per_head) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) + # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) + x = x.transpose(1, 2) + return x + + class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) + x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) + + # (bs*n_heads, class_token_length, dim_per_head) + q = shape(self.q_proj(class_token)) + # (bs*n_heads, length+class_token_length, dim_per_head) + k = shape(self.k_proj(x)) + v = shape(self.v_proj(x)) + + # (bs*n_heads, class_token_length, length+class_token_length): + scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # (bs*n_heads, dim_per_head, class_token_length) + a = torch.einsum("bts,bcs->bct", weight, v) + + # (bs, length+1, width) + a = a.reshape(bs, -1, 1).transpose(1, 2) + + return a[:, 0, :] # cls_token diff --git a/src/diffusers/models/embeddings/position.py b/src/diffusers/models/embeddings/position.py new file mode 100644 index 000000000000..01bcf579dc94 --- /dev/null +++ b/src/diffusers/models/embeddings/position.py @@ -0,0 +1,140 @@ +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + + +def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`. 
+ """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) + emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + if isinstance(pos, int): + pos = np.arange(pos) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] + freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. + max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x diff --git a/src/diffusers/models/embeddings/timestep.py b/src/diffusers/models/embeddings/timestep.py new file mode 100644 index 000000000000..9f5166d0f76e --- /dev/null +++ b/src/diffusers/models/embeddings/timestep.py @@ -0,0 +1,146 @@ +import math +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn + +from ..activations import get_activation + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + self.weight = self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out From 173e7b0037367b0c342dd35f77ab115abdc41a8b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 27 Jun 2024 17:02:15 +0530 Subject: [PATCH 02/15] fix --- src/diffusers/models/embeddings.py | 13 ------------- src/diffusers/models/embeddings/__init__.py | 2 +- 2 files changed, 1 insertion(+), 14 deletions(-) delete mode 100644 
src/diffusers/models/embeddings.py diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py deleted file mode 100644 index 196860c9f1c6..000000000000 --- a/src/diffusers/models/embeddings.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py index 8e91540dd19a..adc3270e5e2d 100644 --- a/src/diffusers/models/embeddings/__init__.py +++ b/src/diffusers/models/embeddings/__init__.py @@ -31,4 +31,4 @@ get_2d_rotary_pos_embed, get_2d_rotary_pos_embed_from_grid, ) -from .timestep import TimestepEmbedding, Timesteps, get_timestep_embedding +from .timestep import GaussianFourierProjection, TimestepEmbedding, Timesteps, get_timestep_embedding From 2a1dff8873f6dbbe0aab3e7d38692d1818d86569 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 27 Jun 2024 17:04:49 +0530 Subject: [PATCH 03/15] fix morer --- src/diffusers/models/embeddings/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py index adc3270e5e2d..cc4c13b08c19 100644 --- a/src/diffusers/models/embeddings/__init__.py +++ b/src/diffusers/models/embeddings/__init__.py @@ -24,6 +24,13 @@ get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_from_grid, ) +from .others import ( + AttentionPooling, + GLIGENTextBoundingboxProjection, + HunyuanDiTAttentionPool, + LabelEmbedding, + get_fourier_embeds_from_boundingbox, +) from .position import ( SinusoidalPositionalEmbedding, apply_rotary_emb, From acd8461eb9b503459973167232a35ef28f8d6b44 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 13:21:49 +0530 Subject: [PATCH 04/15] reflect changes from HunyuanCombinedTimestepTextSizeStyleEmbedding Co-authored-by: Yiyi Xu --- src/diffusers/models/embeddings/combined.py | 50 ++------------------- 1 file changed, 3 insertions(+), 47 deletions(-) diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py index d056b9268789..a2996cb8b4b6 100644 --- a/src/diffusers/models/embeddings/combined.py +++ b/src/diffusers/models/embeddings/combined.py @@ -1,53 +1,7 @@ -import math - import torch import torch.nn as nn -# Copied from diffusers.models.embeddings.timestep.get_timestep_embedding -def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. 
- """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() @@ -138,6 +92,8 @@ def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.pooler = HunyuanDiTAttentionPool( seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim ) @@ -159,7 +115,7 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) # extra condition2: image meta size embdding - image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0) + image_meta_size = self.size_proj(image_meta_size.view(-1)) image_meta_size = image_meta_size.to(dtype=hidden_dtype) image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) From f8886857fea1ab2509c7dd79b0b830c9a9a18762 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 13:37:33 +0530 Subject: [PATCH 05/15] fix --- src/diffusers/models/embeddings/combined.py | 39 +++++++++++++++------ 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py index a2996cb8b4b6..6f0c87ff2be7 100644 --- a/src/diffusers/models/embeddings/combined.py +++ b/src/diffusers/models/embeddings/combined.py @@ -83,7 +83,14 @@ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048): + def __init__( + self, + embedding_dim, + pooled_projection_dim=1024, + seq_len=256, + cross_attention_dim=2048, + use_style_cond_and_image_meta_size=True, + ): super().__init__() from .image_text import PixArtAlphaTextProjection from .others import HunyuanDiTAttentionPool @@ -97,9 +104,15 @@ def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross self.pooler = HunyuanDiTAttentionPool( seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim ) + # Here we use a default learned embedder layer for future extension. 
- self.style_embedder = nn.Embedding(1, embedding_dim) - extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size + if use_style_cond_and_image_meta_size: + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + else: + extra_in_dim = pooled_projection_dim + self.extra_embedder = PixArtAlphaTextProjection( in_features=extra_in_dim, hidden_size=embedding_dim * 4, @@ -114,16 +127,20 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde # extra condition1: text pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) - # extra condition2: image meta size embdding - image_meta_size = self.size_proj(image_meta_size.view(-1)) - image_meta_size = image_meta_size.to(dtype=hidden_dtype) - image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + if self.use_style_cond_and_image_meta_size: + # extra condition2: image meta size embdding + image_meta_size = self.size_proj(image_meta_size.view(-1)) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) - # extra condition3: style embedding - style_embedding = self.style_embedder(style) # (N, embedding_dim) + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + else: + extra_cond = torch.cat([pooled_projections], dim=1) - # Concatenate all extra vectors - extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] return conditioning From 28478a69e5a819cc10ec4f8d97a8be02a2e010b5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 13:40:24 +0530 Subject: [PATCH 06/15] patch embedding to position --- src/diffusers/models/embeddings/__init__.py | 8 +- src/diffusers/models/embeddings/position.py | 163 ++++++++++++++++++++ 2 files changed, 167 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py index cc4c13b08c19..c7ed25039da9 100644 --- a/src/diffusers/models/embeddings/__init__.py +++ b/src/diffusers/models/embeddings/__init__.py @@ -15,14 +15,10 @@ IPAdapterPlusImageProjection, IPAdapterPlusImageProjectionBlock, MultiIPAdapterImageProjection, - PatchEmbed, PixArtAlphaTextProjection, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, - get_1d_sincos_pos_embed_from_grid, - get_2d_sincos_pos_embed, - get_2d_sincos_pos_embed_from_grid, ) from .others import ( AttentionPooling, @@ -32,10 +28,14 @@ get_fourier_embeds_from_boundingbox, ) from .position import ( + PatchEmbed, SinusoidalPositionalEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, + get_1d_sincos_pos_embed_from_grid, get_2d_rotary_pos_embed, get_2d_rotary_pos_embed_from_grid, + get_2d_sincos_pos_embed, + get_2d_sincos_pos_embed_from_grid, ) from .timestep import GaussianFourierProjection, TimestepEmbedding, Timesteps, get_timestep_embedding diff --git a/src/diffusers/models/embeddings/position.py b/src/diffusers/models/embeddings/position.py index 01bcf579dc94..0eb9e856b1f9 100644 --- a/src/diffusers/models/embeddings/position.py +++ b/src/diffusers/models/embeddings/position.py @@ -113,6 +113,169 @@ def apply_rotary_emb( return out +def 
get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for SD3 cropping.""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + pos_embed_type="sincos", + pos_embed_max_size=None, # For SD3 cropping + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + + # Calculate positional embeddings based on max size or default + if pos_embed_max_size: + grid_size = pos_embed_max_size + else: + grid_size = int(num_patches**0.5) + + if pos_embed_type is None: + self.pos_embed = None + elif pos_embed_type == "sincos": + pos_embed = get_2d_sincos_pos_embed( + embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + persistent = True if pos_embed_max_size else False + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) + else: + raise 
ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") + + def cropped_pos_embed(self, height, width): + """Crops positional embeddings for SD3 compatibility.""" + if self.pos_embed_max_size is None: + raise ValueError("`pos_embed_max_size` must be set for cropping.") + + height = height // self.patch_size + width = width // self.patch_size + if height > self.pos_embed_max_size: + raise ValueError( + f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + if width > self.pos_embed_max_size: + raise ValueError( + f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + + top = (self.pos_embed_max_size - height) // 2 + left = (self.pos_embed_max_size - width) // 2 + spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward(self, latent): + if self.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.pos_embed is None: + return latent.to(latent.dtype) + # Interpolate or crop positional embeddings as needed + if self.pos_embed_max_size: + pos_embed = self.cropped_pos_embed(height, width) + else: + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + return (latent + pos_embed).to(latent.dtype) + + class SinusoidalPositionalEmbedding(nn.Module): """Apply positional information to a sequence of embeddings. 
From b8876b6351f7e67fed09f50e2f07a1b25d539af4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 13:41:39 +0530 Subject: [PATCH 07/15] fix --- src/diffusers/models/embeddings/image_text.py | 163 ------------------ 1 file changed, 163 deletions(-) diff --git a/src/diffusers/models/embeddings/image_text.py b/src/diffusers/models/embeddings/image_text.py index decd1d806148..d27067c63379 100644 --- a/src/diffusers/models/embeddings/image_text.py +++ b/src/diffusers/models/embeddings/image_text.py @@ -8,169 +8,6 @@ from ..activations import FP32SiLU -def get_2d_sincos_pos_embed( - embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 -): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - if isinstance(grid_size, int): - grid_size = (grid_size, grid_size) - - grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale - grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding with support for SD3 cropping.""" - - def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - interpolation_scale=1, - pos_embed_type="sincos", - pos_embed_max_size=None, # For SD3 cropping - ): - super().__init__() - - num_patches = (height // patch_size) * (width // patch_size) - self.flatten = flatten - self.layer_norm = layer_norm - self.pos_embed_max_size = pos_embed_max_size - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - if layer_norm: - self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - - self.patch_size = patch_size - self.height, self.width = height // patch_size, width // patch_size - self.base_size = height // patch_size - self.interpolation_scale = interpolation_scale - - # Calculate 
positional embeddings based on max size or default - if pos_embed_max_size: - grid_size = pos_embed_max_size - else: - grid_size = int(num_patches**0.5) - - if pos_embed_type is None: - self.pos_embed = None - elif pos_embed_type == "sincos": - pos_embed = get_2d_sincos_pos_embed( - embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale - ) - persistent = True if pos_embed_max_size else False - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) - else: - raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") - - def cropped_pos_embed(self, height, width): - """Crops positional embeddings for SD3 compatibility.""" - if self.pos_embed_max_size is None: - raise ValueError("`pos_embed_max_size` must be set for cropping.") - - height = height // self.patch_size - width = width // self.patch_size - if height > self.pos_embed_max_size: - raise ValueError( - f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." - ) - if width > self.pos_embed_max_size: - raise ValueError( - f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." - ) - - top = (self.pos_embed_max_size - height) // 2 - left = (self.pos_embed_max_size - width) // 2 - spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) - spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] - spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) - return spatial_pos_embed - - def forward(self, latent): - if self.pos_embed_max_size is not None: - height, width = latent.shape[-2:] - else: - height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size - - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC - if self.layer_norm: - latent = self.norm(latent) - if self.pos_embed is None: - return latent.to(latent.dtype) - # Interpolate or crop positional embeddings as needed - if self.pos_embed_max_size: - pos_embed = self.cropped_pos_embed(height, width) - else: - if self.height != height or self.width != width: - pos_embed = get_2d_sincos_pos_embed( - embed_dim=self.pos_embed.shape[-1], - grid_size=(height, width), - base_size=self.base_size, - interpolation_scale=self.interpolation_scale, - ) - pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) - else: - pos_embed = self.pos_embed - - return (latent + pos_embed).to(latent.dtype) - - class ImagePositionalEmbeddings(nn.Module): """ Converts latent image classes into vector embeddings. 
Sums the vector embeddings with positional embeddings for the From 50932aaa3544341586af7a1449ce79519ba3c54c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 2 Jul 2024 13:42:20 +0530 Subject: [PATCH 08/15] style --- src/diffusers/models/embeddings/image_text.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/embeddings/image_text.py b/src/diffusers/models/embeddings/image_text.py index d27067c63379..49e7538b5a7d 100644 --- a/src/diffusers/models/embeddings/image_text.py +++ b/src/diffusers/models/embeddings/image_text.py @@ -1,6 +1,5 @@ from typing import List, Tuple, Union -import numpy as np import torch import torch.nn as nn From 733609ae8bf40267b41196b7be1265a05d4abf12 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 3 Jul 2024 08:22:06 +0530 Subject: [PATCH 09/15] up --- src/diffusers/models/embeddings/__init__.py | 6 +- src/diffusers/models/embeddings/combined.py | 79 ++++++++++- src/diffusers/models/embeddings/image_text.py | 53 ++++++- src/diffusers/models/embeddings/others.py | 129 ------------------ 4 files changed, 131 insertions(+), 136 deletions(-) diff --git a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py index c7ed25039da9..aebd610bcdee 100644 --- a/src/diffusers/models/embeddings/__init__.py +++ b/src/diffusers/models/embeddings/__init__.py @@ -2,9 +2,12 @@ CombinedTimestepLabelEmbeddings, CombinedTimestepTextProjEmbeddings, HunyuanCombinedTimestepTextSizeStyleEmbedding, + HunyuanDiTAttentionPool, + LabelEmbedding, PixArtAlphaCombinedTimestepSizeEmbeddings, ) from .image_text import ( + AttentionPooling, ImageHintTimeEmbedding, ImagePositionalEmbeddings, ImageProjection, @@ -21,10 +24,7 @@ TextTimeEmbedding, ) from .others import ( - AttentionPooling, GLIGENTextBoundingboxProjection, - HunyuanDiTAttentionPool, - LabelEmbedding, get_fourier_embeds_from_boundingbox, ) from .position import ( diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py index 6f0c87ff2be7..b1c05bc7c4f8 100644 --- a/src/diffusers/models/embeddings/combined.py +++ b/src/diffusers/models/embeddings/combined.py @@ -1,11 +1,47 @@ import torch import torch.nn as nn +import torch.nn.functional as F + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
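Label dropout is what enables classifier-free guidance for class-conditional models: dropped entries are remapped to the extra "null" index num_classes, whose embedding row serves as the unconditional input. A small standalone sketch of that remapping (sizes are placeholders):

    import torch
    import torch.nn as nn

    num_classes, hidden_size, dropout_prob = 10, 32, 0.1
    table = nn.Embedding(num_classes + 1, hidden_size)  # +1 row for the null class
    labels = torch.tensor([3, 7, 1, 9])
    drop = torch.rand(labels.shape[0]) < dropout_prob
    labels = torch.where(drop, torch.full_like(labels, num_classes), labels)
    emb = table(labels)  # (4, 32); dropped rows all share the null embedding
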
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.LongTensor, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() - from .others import LabelEmbedding from .timestep import TimestepEmbedding, Timesteps self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) @@ -82,6 +118,46 @@ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): return conditioning +class HunyuanDiTAttentionPool(nn.Module): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): def __init__( self, @@ -93,7 +169,6 @@ def __init__( ): super().__init__() from .image_text import PixArtAlphaTextProjection - from .others import HunyuanDiTAttentionPool from .timestep import TimestepEmbedding, Timesteps self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) diff --git a/src/diffusers/models/embeddings/image_text.py b/src/diffusers/models/embeddings/image_text.py index 49e7538b5a7d..fab5172defb3 100644 --- a/src/diffusers/models/embeddings/image_text.py +++ b/src/diffusers/models/embeddings/image_text.py @@ -1,3 +1,4 @@ +import math from typing import List, Tuple, Union import torch @@ -386,11 +387,59 @@ def forward(self, image_embeds: List[torch.Tensor]): return projected_image_embeds +class AttentionPooling(nn.Module): + # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 + + def __init__(self, num_heads, embed_dim, dtype=None): + super().__init__() + self.dtype = dtype + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / 
embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.num_heads = num_heads + self.dim_per_head = embed_dim // self.num_heads + + def forward(self, x): + bs, length, width = x.size() + + def shape(x): + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, -1, self.num_heads, self.dim_per_head) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) + # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) + x = x.transpose(1, 2) + return x + + class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) + x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) + + # (bs*n_heads, class_token_length, dim_per_head) + q = shape(self.q_proj(class_token)) + # (bs*n_heads, length+class_token_length, dim_per_head) + k = shape(self.k_proj(x)) + v = shape(self.v_proj(x)) + + # (bs*n_heads, class_token_length, length+class_token_length): + scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # (bs*n_heads, dim_per_head, class_token_length) + a = torch.einsum("bts,bcs->bct", weight, v) + + # (bs, length+1, width) + a = a.reshape(bs, -1, 1).transpose(1, 2) + + return a[:, 0, :] # cls_token + + class TextTimeEmbedding(nn.Module): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() - from .others import AttentionPooling - self.norm1 = nn.LayerNorm(encoder_dim) self.pool = AttentionPooling(num_heads, encoder_dim) self.proj = nn.Linear(encoder_dim, time_embed_dim) diff --git a/src/diffusers/models/embeddings/others.py b/src/diffusers/models/embeddings/others.py index ae5d5cf04738..295cc61eebe6 100644 --- a/src/diffusers/models/embeddings/others.py +++ b/src/diffusers/models/embeddings/others.py @@ -1,8 +1,5 @@ -import math - import torch import torch.nn as nn -import torch.nn.functional as F def get_fourier_embeds_from_boundingbox(embed_dim, box): @@ -117,129 +114,3 @@ def forward( objs = torch.cat([objs_text, objs_image], dim=1) return objs - - -class LabelEmbedding(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - - Args: - num_classes (`int`): The number of classes. - hidden_size (`int`): The size of the vector embeddings. - dropout_prob (`float`): The probability of dropping a label. - """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. 
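The AttentionPooling module added to image_text.py above collapses a whole token sequence into one vector by attending from a mean-pooled query token. A short usage sketch, assuming the re-export shown in the package __init__.py and illustrative shapes:

    import torch
    from diffusers.models.embeddings import AttentionPooling  # re-export assumed per the __init__.py above

    pool = AttentionPooling(num_heads=8, embed_dim=768)
    encoder_states = torch.randn(2, 77, 768)  # (batch, seq_len, width)
    pooled = pool(encoder_states)
    print(pooled.shape)  # torch.Size([2, 768])

TextTimeEmbedding wraps this pooling between layer norms and a projection to the time-embedding width.
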
- """ - if force_drop_ids is None: - drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob - else: - drop_ids = torch.tensor(force_drop_ids == 1) - labels = torch.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels: torch.LongTensor, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -class HunyuanDiTAttentionPool(nn.Module): - # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 - - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.permute(1, 0, 2) # NLC -> LNC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC - x, _ = F.multi_head_attention_forward( - query=x[:1], - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - return x.squeeze(0) - - -class AttentionPooling(nn.Module): - # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 - - def __init__(self, num_heads, embed_dim, dtype=None): - super().__init__() - self.dtype = dtype - self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.num_heads = num_heads - self.dim_per_head = embed_dim // self.num_heads - - def forward(self, x): - bs, length, width = x.size() - - def shape(x): - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, -1, self.num_heads, self.dim_per_head) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) - # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) - x = x.transpose(1, 2) - return x - - class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) - x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) - - # (bs*n_heads, class_token_length, dim_per_head) - q = shape(self.q_proj(class_token)) - # (bs*n_heads, length+class_token_length, dim_per_head) - k = shape(self.k_proj(x)) - v = shape(self.v_proj(x)) - - # (bs*n_heads, class_token_length, 
length+class_token_length): - scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - - # (bs*n_heads, dim_per_head, class_token_length) - a = torch.einsum("bts,bcs->bct", weight, v) - - # (bs, length+1, width) - a = a.reshape(bs, -1, 1).transpose(1, 2) - - return a[:, 0, :] # cls_token From f94e3fc0163c0fb882dccb78cfa7439ac8f718e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 3 Jul 2024 08:52:13 +0530 Subject: [PATCH 10/15] move hunyuan stuff to image_text --- src/diffusers/models/embeddings/__init__.py | 4 +- src/diffusers/models/embeddings/combined.py | 104 ------------------ src/diffusers/models/embeddings/image_text.py | 104 ++++++++++++++++++ 3 files changed, 106 insertions(+), 106 deletions(-) diff --git a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py index aebd610bcdee..7ee962e18a12 100644 --- a/src/diffusers/models/embeddings/__init__.py +++ b/src/diffusers/models/embeddings/__init__.py @@ -1,13 +1,13 @@ from .combined import ( CombinedTimestepLabelEmbeddings, CombinedTimestepTextProjEmbeddings, - HunyuanCombinedTimestepTextSizeStyleEmbedding, - HunyuanDiTAttentionPool, LabelEmbedding, PixArtAlphaCombinedTimestepSizeEmbeddings, ) from .image_text import ( AttentionPooling, + HunyuanCombinedTimestepTextSizeStyleEmbedding, + HunyuanDiTAttentionPool, ImageHintTimeEmbedding, ImagePositionalEmbeddings, ImageProjection, diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py index b1c05bc7c4f8..7607d0af1023 100644 --- a/src/diffusers/models/embeddings/combined.py +++ b/src/diffusers/models/embeddings/combined.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F class LabelEmbedding(nn.Module): @@ -116,106 +115,3 @@ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): conditioning = timesteps_emb return conditioning - - -class HunyuanDiTAttentionPool(nn.Module): - # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 - - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.permute(1, 0, 2) # NLC -> LNC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC - x, _ = F.multi_head_attention_forward( - query=x[:1], - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - return x.squeeze(0) - - -class 
HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): - def __init__( - self, - embedding_dim, - pooled_projection_dim=1024, - seq_len=256, - cross_attention_dim=2048, - use_style_cond_and_image_meta_size=True, - ): - super().__init__() - from .image_text import PixArtAlphaTextProjection - from .timestep import TimestepEmbedding, Timesteps - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - - self.pooler = HunyuanDiTAttentionPool( - seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim - ) - - # Here we use a default learned embedder layer for future extension. - self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size - if use_style_cond_and_image_meta_size: - self.style_embedder = nn.Embedding(1, embedding_dim) - extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim - else: - extra_in_dim = pooled_projection_dim - - self.extra_embedder = PixArtAlphaTextProjection( - in_features=extra_in_dim, - hidden_size=embedding_dim * 4, - out_features=embedding_dim, - act_fn="silu_fp32", - ) - - def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) - - # extra condition1: text - pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) - - if self.use_style_cond_and_image_meta_size: - # extra condition2: image meta size embdding - image_meta_size = self.size_proj(image_meta_size.view(-1)) - image_meta_size = image_meta_size.to(dtype=hidden_dtype) - image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) - - # extra condition3: style embedding - style_embedding = self.style_embedder(style) # (N, embedding_dim) - - # Concatenate all extra vectors - extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) - else: - extra_cond = torch.cat([pooled_projections], dim=1) - - conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] - - return conditioning diff --git a/src/diffusers/models/embeddings/image_text.py b/src/diffusers/models/embeddings/image_text.py index fab5172defb3..eeb55d969bb0 100644 --- a/src/diffusers/models/embeddings/image_text.py +++ b/src/diffusers/models/embeddings/image_text.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from ...utils import deprecate from ..activations import FP32SiLU @@ -513,3 +514,106 @@ def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): time_image_embeds = self.image_norm(time_image_embeds) hint = self.input_hint_block(hint) return time_image_embeds, hint + + +class HunyuanDiTAttentionPool(nn.Module): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, 
x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): + def __init__( + self, + embedding_dim, + pooled_projection_dim=1024, + seq_len=256, + cross_attention_dim=2048, + use_style_cond_and_image_meta_size=True, + ): + super().__init__() + from .image_text import PixArtAlphaTextProjection + from .timestep import TimestepEmbedding, Timesteps + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + + # Here we use a default learned embedder layer for future extension. + self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size + if use_style_cond_and_image_meta_size: + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + else: + extra_in_dim = pooled_projection_dim + + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + if self.use_style_cond_and_image_meta_size: + # extra condition2: image meta size embdding + image_meta_size = self.size_proj(image_meta_size.view(-1)) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + else: + extra_cond = torch.cat([pooled_projections], dim=1) + + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning From 91b75c64061ed2d37220dcc5f9262c43239f9d2d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Jul 2024 07:52:09 +0530 Subject: [PATCH 11/15] move labelembedding to others --- src/diffusers/models/embeddings/__init__.py | 2 +- src/diffusers/models/embeddings/combined.py | 37 +-------------------- src/diffusers/models/embeddings/others.py | 36 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 37 deletions(-) diff --git 
a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py index 7ee962e18a12..1f60f0a884cd 100644 --- a/src/diffusers/models/embeddings/__init__.py +++ b/src/diffusers/models/embeddings/__init__.py @@ -1,7 +1,6 @@ from .combined import ( CombinedTimestepLabelEmbeddings, CombinedTimestepTextProjEmbeddings, - LabelEmbedding, PixArtAlphaCombinedTimestepSizeEmbeddings, ) from .image_text import ( @@ -26,6 +25,7 @@ from .others import ( GLIGENTextBoundingboxProjection, get_fourier_embeds_from_boundingbox, + LabelEmbedding, ) from .position import ( PatchEmbed, diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py index 7607d0af1023..c5f9c1dec2af 100644 --- a/src/diffusers/models/embeddings/combined.py +++ b/src/diffusers/models/embeddings/combined.py @@ -2,46 +2,11 @@ import torch.nn as nn -class LabelEmbedding(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - - Args: - num_classes (`int`): The number of classes. - hidden_size (`int`): The size of the vector embeddings. - dropout_prob (`float`): The probability of dropping a label. - """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. - """ - if force_drop_ids is None: - drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob - else: - drop_ids = torch.tensor(force_drop_ids == 1) - labels = torch.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels: torch.LongTensor, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() from .timestep import TimestepEmbedding, Timesteps + from .others import LabelEmbedding self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) diff --git a/src/diffusers/models/embeddings/others.py b/src/diffusers/models/embeddings/others.py index 295cc61eebe6..cb7edb78eb14 100644 --- a/src/diffusers/models/embeddings/others.py +++ b/src/diffusers/models/embeddings/others.py @@ -114,3 +114,39 @@ def forward( objs = torch.cat([objs_text, objs_image], dim=1) return objs + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. 
+ """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.LongTensor, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings \ No newline at end of file From 498ec778d30c95266104a1cffdc1d134be0dba29 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Jul 2024 07:55:35 +0530 Subject: [PATCH 12/15] move more to combibed. --- src/diffusers/models/embeddings/__init__.py | 6 +- src/diffusers/models/embeddings/combined.py | 76 ++++++++++++++++++ src/diffusers/models/embeddings/image_text.py | 77 ------------------- 3 files changed, 79 insertions(+), 80 deletions(-) diff --git a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py index 1f60f0a884cd..412b7cdce5b7 100644 --- a/src/diffusers/models/embeddings/__init__.py +++ b/src/diffusers/models/embeddings/__init__.py @@ -2,12 +2,14 @@ CombinedTimestepLabelEmbeddings, CombinedTimestepTextProjEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings, + ImageHintTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, ) from .image_text import ( AttentionPooling, HunyuanCombinedTimestepTextSizeStyleEmbedding, HunyuanDiTAttentionPool, - ImageHintTimeEmbedding, ImagePositionalEmbeddings, ImageProjection, ImageTimeEmbedding, @@ -18,8 +20,6 @@ IPAdapterPlusImageProjectionBlock, MultiIPAdapterImageProjection, PixArtAlphaTextProjection, - TextImageProjection, - TextImageTimeEmbedding, TextTimeEmbedding, ) from .others import ( diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py index c5f9c1dec2af..c3074194d1a9 100644 --- a/src/diffusers/models/embeddings/combined.py +++ b/src/diffusers/models/embeddings/combined.py @@ -2,6 +2,82 @@ import torch.nn as nn +class TextImageProjection(nn.Module): + def __init__( + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + batch_size = text_embeds.shape[0] + + # image + image_text_embeds = self.image_embeds(image_embeds) + image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + + # text + text_embeds = self.text_proj(text_embeds) + + return torch.cat([image_text_embeds, text_embeds], dim=1) + + +class TextImageTimeEmbedding(nn.Module): + def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.text_proj = 
nn.Linear(text_embed_dim, time_embed_dim) + self.text_norm = nn.LayerNorm(time_embed_dim) + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + # text + time_text_embeds = self.text_proj(text_embeds) + time_text_embeds = self.text_norm(time_text_embeds) + + # image + time_image_embeds = self.image_proj(image_embeds) + + return time_image_embeds + time_text_embeds + + +class ImageHintTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(256, 4, 3, padding=1), + ) + + def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + hint = self.input_hint_block(hint) + return time_image_embeds, hint + + class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() diff --git a/src/diffusers/models/embeddings/image_text.py b/src/diffusers/models/embeddings/image_text.py index eeb55d969bb0..301a5d400a4a 100644 --- a/src/diffusers/models/embeddings/image_text.py +++ b/src/diffusers/models/embeddings/image_text.py @@ -74,33 +74,6 @@ def forward(self, index): return emb -class TextImageProjection(nn.Module): - def __init__( - self, - text_embed_dim: int = 1024, - image_embed_dim: int = 768, - cross_attention_dim: int = 768, - num_image_text_embeds: int = 10, - ): - super().__init__() - - self.num_image_text_embeds = num_image_text_embeds - self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) - self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - batch_size = text_embeds.shape[0] - - # image - image_text_embeds = self.image_embeds(image_embeds) - image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) - - # text - text_embeds = self.text_proj(text_embeds) - - return torch.cat([image_text_embeds, text_embeds], dim=1) - - class ImageProjection(nn.Module): def __init__( self, @@ -454,24 +427,6 @@ def forward(self, hidden_states): return hidden_states -class TextImageTimeEmbedding(nn.Module): - def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): - super().__init__() - self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) - self.text_norm = nn.LayerNorm(time_embed_dim) - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - # text - time_text_embeds = self.text_proj(text_embeds) - time_text_embeds = self.text_norm(time_text_embeds) - - # image - time_image_embeds = self.image_proj(image_embeds) - - return time_image_embeds + time_text_embeds - - class ImageTimeEmbedding(nn.Module): def __init__(self, image_embed_dim: int 
= 768, time_embed_dim: int = 1536): super().__init__() @@ -484,38 +439,6 @@ def forward(self, image_embeds: torch.Tensor): time_image_embeds = self.image_norm(time_image_embeds) return time_image_embeds - -class ImageHintTimeEmbedding(nn.Module): - def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): - super().__init__() - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - self.image_norm = nn.LayerNorm(time_embed_dim) - self.input_hint_block = nn.Sequential( - nn.Conv2d(3, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 96, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(96, 96, 3, padding=1), - nn.SiLU(), - nn.Conv2d(96, 256, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(256, 4, 3, padding=1), - ) - - def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): - # image - time_image_embeds = self.image_proj(image_embeds) - time_image_embeds = self.image_norm(time_image_embeds) - hint = self.input_hint_block(hint) - return time_image_embeds, hint - - class HunyuanDiTAttentionPool(nn.Module): # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 From 792deca54038540be70e23fb0e20719fca364073 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Jul 2024 08:46:15 +0530 Subject: [PATCH 13/15] up --- src/diffusers/models/embeddings/__init__.py | 18 +- src/diffusers/models/embeddings/combined.py | 110 ++++++++- .../embeddings/{image_text.py => image.py} | 226 +----------------- src/diffusers/models/embeddings/others.py | 2 +- src/diffusers/models/embeddings/text.py | 101 ++++++++ 5 files changed, 232 insertions(+), 225 deletions(-) rename src/diffusers/models/embeddings/{image_text.py => image.py} (59%) create mode 100644 src/diffusers/models/embeddings/text.py diff --git a/src/diffusers/models/embeddings/__init__.py b/src/diffusers/models/embeddings/__init__.py index 412b7cdce5b7..3ab1a6681bc9 100644 --- a/src/diffusers/models/embeddings/__init__.py +++ b/src/diffusers/models/embeddings/__init__.py @@ -1,15 +1,14 @@ from .combined import ( CombinedTimestepLabelEmbeddings, CombinedTimestepTextProjEmbeddings, - PixArtAlphaCombinedTimestepSizeEmbeddings, + HunyuanCombinedTimestepTextSizeStyleEmbedding, + HunyuanDiTAttentionPool, ImageHintTimeEmbedding, + PixArtAlphaCombinedTimestepSizeEmbeddings, TextImageProjection, TextImageTimeEmbedding, ) -from .image_text import ( - AttentionPooling, - HunyuanCombinedTimestepTextSizeStyleEmbedding, - HunyuanDiTAttentionPool, +from .image import ( ImagePositionalEmbeddings, ImageProjection, ImageTimeEmbedding, @@ -19,13 +18,11 @@ IPAdapterPlusImageProjection, IPAdapterPlusImageProjectionBlock, MultiIPAdapterImageProjection, - PixArtAlphaTextProjection, - TextTimeEmbedding, ) from .others import ( GLIGENTextBoundingboxProjection, - get_fourier_embeds_from_boundingbox, LabelEmbedding, + get_fourier_embeds_from_boundingbox, ) from .position import ( PatchEmbed, @@ -38,4 +35,9 @@ get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_from_grid, ) +from .text import ( + AttentionPooling, + PixArtAlphaTextProjection, + TextTimeEmbedding, +) from .timestep import GaussianFourierProjection, TimestepEmbedding, Timesteps, get_timestep_embedding diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py index c3074194d1a9..eab7280dd8dc 100644 --- 
a/src/diffusers/models/embeddings/combined.py +++ b/src/diffusers/models/embeddings/combined.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F class TextImageProjection(nn.Module): @@ -27,7 +28,7 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): text_embeds = self.text_proj(text_embeds) return torch.cat([image_text_embeds, text_embeds], dim=1) - + class TextImageTimeEmbedding(nn.Module): def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): @@ -45,7 +46,7 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): time_image_embeds = self.image_proj(image_embeds) return time_image_embeds + time_text_embeds - + class ImageHintTimeEmbedding(nn.Module): def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): @@ -81,8 +82,8 @@ def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() - from .timestep import TimestepEmbedding, Timesteps from .others import LabelEmbedding + from .timestep import TimestepEmbedding, Timesteps self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) @@ -156,3 +157,106 @@ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): conditioning = timesteps_emb return conditioning + + +class HunyuanDiTAttentionPool(nn.Module): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): + def __init__( + self, + embedding_dim, + pooled_projection_dim=1024, + seq_len=256, + cross_attention_dim=2048, + use_style_cond_and_image_meta_size=True, + ): + super().__init__() + from .text import PixArtAlphaTextProjection + from .timestep import TimestepEmbedding, Timesteps + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.size_proj = 
Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + + # Here we use a default learned embedder layer for future extension. + self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size + if use_style_cond_and_image_meta_size: + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + else: + extra_in_dim = pooled_projection_dim + + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + if self.use_style_cond_and_image_meta_size: + # extra condition2: image meta size embdding + image_meta_size = self.size_proj(image_meta_size.view(-1)) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + else: + extra_cond = torch.cat([pooled_projections], dim=1) + + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning diff --git a/src/diffusers/models/embeddings/image_text.py b/src/diffusers/models/embeddings/image.py similarity index 59% rename from src/diffusers/models/embeddings/image_text.py rename to src/diffusers/models/embeddings/image.py index 301a5d400a4a..5733cd5cabf1 100644 --- a/src/diffusers/models/embeddings/image_text.py +++ b/src/diffusers/models/embeddings/image.py @@ -1,12 +1,9 @@ -import math from typing import List, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F from ...utils import deprecate -from ..activations import FP32SiLU class ImagePositionalEmbeddings(nn.Module): @@ -97,6 +94,19 @@ def forward(self, image_embeds: torch.Tensor): return image_embeds +class ImageTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + + def forward(self, image_embeds: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + return time_image_embeds + + class IPAdapterFullImageProjection(nn.Module): def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): super().__init__() @@ -125,35 +135,6 @@ def forward(self, image_embeds: torch.Tensor): return self.norm(x) -class PixArtAlphaTextProjection(nn.Module): - """ - Projects caption embeddings. Also handles dropout for classifier-free guidance. 
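For orientation, the extra conditioning assembled by HunyuanCombinedTimestepTextSizeStyleEmbedding above is a plain concatenation of pooled text, image-meta-size, and style embeddings that extra_embedder maps back to the timestep width. A rough shape sketch with placeholder dimensions (embedding_dim here is assumed for illustration, not a library default):

    import torch

    batch, embedding_dim, pooled_dim = 2, 1408, 1024    # embedding_dim is a placeholder
    pooled_text = torch.randn(batch, pooled_dim)         # pooler(encoder_hidden_states)
    image_meta = torch.randn(batch, 6 * 256)             # size_proj over six meta values, flattened
    style_emb = torch.randn(batch, embedding_dim)        # style_embedder(style)
    extra_cond = torch.cat([pooled_text, image_meta, style_emb], dim=1)
    print(extra_cond.shape)  # torch.Size([2, 3968]) == 1024 + 1536 + 1408
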
- - Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py - """ - - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): - super().__init__() - if out_features is None: - out_features = hidden_size - self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) - if act_fn == "gelu_tanh": - self.act_1 = nn.GELU(approximate="tanh") - elif act_fn == "silu": - self.act_1 = nn.SiLU() - elif act_fn == "silu_fp32": - self.act_1 = FP32SiLU() - else: - raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) - - def forward(self, caption): - hidden_states = self.linear_1(caption) - hidden_states = self.act_1(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - - class IPAdapterPlusImageProjectionBlock(nn.Module): def __init__( self, @@ -359,184 +340,3 @@ def forward(self, image_embeds: List[torch.Tensor]): projected_image_embeds.append(image_embed) return projected_image_embeds - - -class AttentionPooling(nn.Module): - # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 - - def __init__(self, num_heads, embed_dim, dtype=None): - super().__init__() - self.dtype = dtype - self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.num_heads = num_heads - self.dim_per_head = embed_dim // self.num_heads - - def forward(self, x): - bs, length, width = x.size() - - def shape(x): - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, -1, self.num_heads, self.dim_per_head) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) - # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) - x = x.transpose(1, 2) - return x - - class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) - x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) - - # (bs*n_heads, class_token_length, dim_per_head) - q = shape(self.q_proj(class_token)) - # (bs*n_heads, length+class_token_length, dim_per_head) - k = shape(self.k_proj(x)) - v = shape(self.v_proj(x)) - - # (bs*n_heads, class_token_length, length+class_token_length): - scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - - # (bs*n_heads, dim_per_head, class_token_length) - a = torch.einsum("bts,bcs->bct", weight, v) - - # (bs, length+1, width) - a = a.reshape(bs, -1, 1).transpose(1, 2) - - return a[:, 0, :] # cls_token - - -class TextTimeEmbedding(nn.Module): - def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): - super().__init__() - self.norm1 = nn.LayerNorm(encoder_dim) - self.pool = AttentionPooling(num_heads, encoder_dim) - self.proj = nn.Linear(encoder_dim, time_embed_dim) - self.norm2 = nn.LayerNorm(time_embed_dim) - - def forward(self, hidden_states): - hidden_states = 
self.norm1(hidden_states) - hidden_states = self.pool(hidden_states) - hidden_states = self.proj(hidden_states) - hidden_states = self.norm2(hidden_states) - return hidden_states - - -class ImageTimeEmbedding(nn.Module): - def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): - super().__init__() - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - self.image_norm = nn.LayerNorm(time_embed_dim) - - def forward(self, image_embeds: torch.Tensor): - # image - time_image_embeds = self.image_proj(image_embeds) - time_image_embeds = self.image_norm(time_image_embeds) - return time_image_embeds - -class HunyuanDiTAttentionPool(nn.Module): - # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 - - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.permute(1, 0, 2) # NLC -> LNC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC - x, _ = F.multi_head_attention_forward( - query=x[:1], - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - return x.squeeze(0) - - -class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): - def __init__( - self, - embedding_dim, - pooled_projection_dim=1024, - seq_len=256, - cross_attention_dim=2048, - use_style_cond_and_image_meta_size=True, - ): - super().__init__() - from .image_text import PixArtAlphaTextProjection - from .timestep import TimestepEmbedding, Timesteps - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - - self.pooler = HunyuanDiTAttentionPool( - seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim - ) - - # Here we use a default learned embedder layer for future extension. 
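HunyuanDiTAttentionPool compresses a fixed-length text sequence into a single pooled vector. A brief usage sketch with the sizes the combined embedding wires up by default (seq_len=256, cross_attention_dim=2048, pooled_projection_dim=1024); the import path assumes the re-export shown in the package __init__.py:

    import torch
    from diffusers.models.embeddings import HunyuanDiTAttentionPool

    pool = HunyuanDiTAttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8, output_dim=1024)
    hidden_states = torch.randn(2, 256, 2048)  # (N, L, C) text encoder states
    pooled = pool(hidden_states)
    print(pooled.shape)  # torch.Size([2, 1024])
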
- self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size - if use_style_cond_and_image_meta_size: - self.style_embedder = nn.Embedding(1, embedding_dim) - extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim - else: - extra_in_dim = pooled_projection_dim - - self.extra_embedder = PixArtAlphaTextProjection( - in_features=extra_in_dim, - hidden_size=embedding_dim * 4, - out_features=embedding_dim, - act_fn="silu_fp32", - ) - - def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) - - # extra condition1: text - pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) - - if self.use_style_cond_and_image_meta_size: - # extra condition2: image meta size embdding - image_meta_size = self.size_proj(image_meta_size.view(-1)) - image_meta_size = image_meta_size.to(dtype=hidden_dtype) - image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) - - # extra condition3: style embedding - style_embedding = self.style_embedder(style) # (N, embedding_dim) - - # Concatenate all extra vectors - extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) - else: - extra_cond = torch.cat([pooled_projections], dim=1) - - conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] - - return conditioning diff --git a/src/diffusers/models/embeddings/others.py b/src/diffusers/models/embeddings/others.py index cb7edb78eb14..17dba71dbc7e 100644 --- a/src/diffusers/models/embeddings/others.py +++ b/src/diffusers/models/embeddings/others.py @@ -149,4 +149,4 @@ def forward(self, labels: torch.LongTensor, force_drop_ids=None): if (self.training and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) embeddings = self.embedding_table(labels) - return embeddings \ No newline at end of file + return embeddings diff --git a/src/diffusers/models/embeddings/text.py b/src/diffusers/models/embeddings/text.py new file mode 100644 index 000000000000..ea520d108ac0 --- /dev/null +++ b/src/diffusers/models/embeddings/text.py @@ -0,0 +1,101 @@ +import math + +import torch +import torch.nn as nn + +from ..activations import FP32SiLU + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
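PixArtAlphaTextProjection is a two-layer MLP that maps caption embeddings (e.g. T5 hidden states) into the transformer width. A short usage sketch with illustrative sizes; the import assumes the re-export listed in the package __init__.py:

    import torch
    from diffusers.models.embeddings import PixArtAlphaTextProjection

    proj = PixArtAlphaTextProjection(in_features=4096, hidden_size=1152, act_fn="gelu_tanh")
    caption = torch.randn(2, 120, 4096)  # (batch, tokens, caption_dim)
    out = proj(caption)
    print(out.shape)  # torch.Size([2, 120, 1152])
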
+ + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + elif act_fn == "silu_fp32": + self.act_1 = FP32SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class AttentionPooling(nn.Module): + # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 + + def __init__(self, num_heads, embed_dim, dtype=None): + super().__init__() + self.dtype = dtype + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.num_heads = num_heads + self.dim_per_head = embed_dim // self.num_heads + + def forward(self, x): + bs, length, width = x.size() + + def shape(x): + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, -1, self.num_heads, self.dim_per_head) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) + # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) + x = x.transpose(1, 2) + return x + + class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) + x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) + + # (bs*n_heads, class_token_length, dim_per_head) + q = shape(self.q_proj(class_token)) + # (bs*n_heads, length+class_token_length, dim_per_head) + k = shape(self.k_proj(x)) + v = shape(self.v_proj(x)) + + # (bs*n_heads, class_token_length, length+class_token_length): + scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # (bs*n_heads, dim_per_head, class_token_length) + a = torch.einsum("bts,bcs->bct", weight, v) + + # (bs, length+1, width) + a = a.reshape(bs, -1, 1).transpose(1, 2) + + return a[:, 0, :] # cls_token + + +class TextTimeEmbedding(nn.Module): + def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): + super().__init__() + self.norm1 = nn.LayerNorm(encoder_dim) + self.pool = AttentionPooling(num_heads, encoder_dim) + self.proj = nn.Linear(encoder_dim, time_embed_dim) + self.norm2 = nn.LayerNorm(time_embed_dim) + + def forward(self, hidden_states): + hidden_states = self.norm1(hidden_states) + hidden_states = self.pool(hidden_states) + hidden_states = self.proj(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states From 48d7a2862f0d3ea63f2a5acd521e21bccf5449bc 
Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Jul 2024 09:03:40 +0530 Subject: [PATCH 14/15] fix import --- src/diffusers/models/embeddings/combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings/combined.py b/src/diffusers/models/embeddings/combined.py index eab7280dd8dc..c33e2a14add6 100644 --- a/src/diffusers/models/embeddings/combined.py +++ b/src/diffusers/models/embeddings/combined.py @@ -103,7 +103,7 @@ def forward(self, timestep, class_labels, hidden_dtype=None): class CombinedTimestepTextProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim): super().__init__() - from .image_text import PixArtAlphaTextProjection + from .text import PixArtAlphaTextProjection from .timestep import TimestepEmbedding, Timesteps self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) From a61d541f51051b9a11409b87cfc90fccda2d55f2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Jul 2024 09:04:38 +0530 Subject: [PATCH 15/15] changes from https://github.com/huggingface/diffusers/pull/8764 --- src/diffusers/models/embeddings/timestep.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/embeddings/timestep.py b/src/diffusers/models/embeddings/timestep.py index 9f5166d0f76e..d401119b0cf4 100644 --- a/src/diffusers/models/embeddings/timestep.py +++ b/src/diffusers/models/embeddings/timestep.py @@ -129,9 +129,11 @@ def __init__( if set_W_to_weight: # to delete later + del self.weight self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) self.weight = self.W + del self.W def forward(self, x): if self.log: