From fdc6fd7bd692a7b715943a94b7bbddae02c547e2 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 11 Oct 2024 22:21:42 +0200
Subject: [PATCH] update

---
 src/diffusers/models/embeddings.py | 727 +++++++++++++++++++----------
 1 file changed, 474 insertions(+), 253 deletions(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index c250df29afbe..e48d8925142f 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -78,6 +78,53 @@ def get_timestep_embedding(
     return emb
 
 
+def aryan_get_3d_sincos_pos_embed(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+    device: Optional[torch.device] = None,
+) -> torch.Tensor:
+    r"""
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of the positional embedding. Must be divisible by 4.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of the grid, as (width, height).
+        temporal_size (`int`):
+            The temporal dimension of the grid.
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+    """
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here, w goes first
+    grid = torch.stack(grid).reshape(2, 1, spatial_size[1], spatial_size[0])
+    pos_embed_spatial = aryan_get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+    pos_embed_temporal = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial.unsqueeze(0).repeat(temporal_size, 1, 1)
+
+    # [T, H*W, D // 4]
+    pos_embed_temporal = pos_embed_temporal.unsqueeze(1).repeat(1, spatial_size[0] * spatial_size[1], 1)
+
+    # [T, H*W, D]
+    pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=2)
+    return pos_embed
+
+
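# Usage sketch (illustrative only; the `_example_*` helper is hypothetical, not
# part of the module): any `embed_dim` divisible by 4 works, and the temporal
# slice occupies the first `embed_dim // 4` channels of the output.
def _example_aryan_3d_sincos_usage():
    pos = aryan_get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(8, 8), temporal_size=4)
    assert pos.shape == (4, 8 * 8, 64)  # [T, H*W, D]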
 def get_3d_sincos_pos_embed(
     embed_dim: int,
     spatial_size: Union[int, Tuple[int, int]],
@@ -125,6 +172,29 @@ def get_3d_sincos_pos_embed(
     return pos_embed
 
 
+def aryan_get_2d_sincos_pos_embed(
+    embed_dim: int,
+    grid_size: Union[int, Tuple[int, int]],
+    cls_token: bool = False,
+    extra_tokens: int = 0,
+    interpolation_scale: float = 1.0,
+    base_size: int = 16,
+    device: Optional[torch.device] = None,
+):
+    """
+    grid_size: int of the grid height and width
+    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = torch.arange(grid_size[0], device=device, dtype=torch.float32) / (grid_size[0] / base_size) / interpolation_scale
+    grid_w = torch.arange(grid_size[1], device=device, dtype=torch.float32) / (grid_size[1] / base_size) / interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here, w goes first
+    grid = torch.stack(grid).reshape(2, 1, grid_size[1], grid_size[0])
+
+    pos_embed = aryan_get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+
+    if cls_token and extra_tokens > 0:
+        pos_embed = torch.cat([pos_embed.new_zeros(extra_tokens, embed_dim), pos_embed])
+
+    return pos_embed
+
+
 def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
@@ -147,6 +217,19 @@ def get_2d_sincos_pos_embed(
     return pos_embed
 
 
+def aryan_get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    # use half of dimensions to encode grid_w
+    emb_w = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
+    return emb
+
+
 def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
@@ -159,6 +242,27 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     return emb
 
 
+def aryan_get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = pos[:, None] * omega[None, :]  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)
+    emb_cos = torch.cos(out)
+    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
+
+    return emb
+
+
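# Usage sketch (illustrative only, hypothetical helper): M positions map to an
# (M, D) table with sines in the first D/2 columns and cosines in the last D/2;
# the lowest-frequency column equals sin(pos) because omega[0] == 1.
def _example_aryan_1d_sincos_usage():
    pos = torch.arange(10, dtype=torch.float64)
    emb = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim=8, pos=pos)
    assert emb.shape == (10, 8)
    assert torch.allclose(emb[:, 0], torch.sin(pos))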
np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def aryan_get_3d_rotary_pos_embed( + embed_dim: int, crops_coords: Tuple[int, int], grid_size: Tuple[int, int], temporal_size: int, theta: int = 10000, use_real: bool = True, device: Optional[torch.device] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int, int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int, int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + if use_real is not True: + raise ValueError("`use_real = False` is not currently supported for get_3d_rotary_pos_embed") + + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = torch.linspace(start[0], start[0] + (stop[0] - start[0]) * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32) + grid_w = torch.linspace(start[1], start[1] + (stop[1] - start[1]) * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32) + grid_t = torch.linspace(0, 0 + (temporal_size - 0) * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +def get_3d_rotary_pos_embed( + embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, 
use_real: bool = True +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +def aryan_get_2d_rotary_pos_embed(embed_dim: int, crops_coords: Tuple[int, int], grid_size: Tuple[int, int], use_real: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int, int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`. 
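# Usage sketch (illustrative only; the helper name and crop values are
# hypothetical): with `use_real=True` the 2D RoPE helper above returns a
# (cos, sin) pair, each of shape (H * W, embed_dim).
def _example_aryan_2d_rotary_usage():
    cos, sin = aryan_get_2d_rotary_pos_embed(embed_dim=32, crops_coords=((0, 0), (32, 32)), grid_size=(8, 8))
    assert cos.shape == sin.shape == (8 * 8, 32)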
+ """ + start, stop = crops_coords + grid_h = torch.linspace(start[0], start[0] + (stop[0] - start[0]) * (grid_size[0] - 1) / grid_size[0], grid_size[0], dtype=torch.float32) + grid_w = torch.linspace(start[1], start[1] + (stop[1] - start[1]) * (grid_size[1] - 1) / grid_size[1], grid_size[1], dtype=torch.float32) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here, w goes first + grid = torch.stack(grid) + grid = grid.reshape(2, 1, *grid.shape[1:]) + + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. + """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def aryan_get_2d_rotary_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor, use_real: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, grid[0].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, grid[1].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, grid[0].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, grid[1].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_2d_rotary_pos_embed_lumina(embed_dim: int, len_h: int, len_w: int, linear_factor: float = 1.0, ntk_factor: float = 1.0) -> torch.Tensor: + assert embed_dim % 4 == 0 + + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (H, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (W, D/4) + emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) + emb_w = emb_w.view(1, 
+def get_1d_rotary_pos_embed(
+    dim: int,
+    pos: Union[torch.Tensor, np.ndarray, int],
+    theta: float = 10000.0,
+    use_real: bool = False,
+    linear_factor: float = 1.0,
+    ntk_factor: float = 1.0,
+    repeat_interleave_real: bool = True,
+    freqs_dtype: torch.dtype = torch.float32,  # torch.float32, torch.float64 (flux)
+):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
+    position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values
+    in complex64 data type.
+
+    Args:
+        dim (`int`): Dimension of the frequency tensor.
+        pos (`torch.Tensor` or `np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+        theta (`float`, *optional*, defaults to 10000.0):
+            Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (`bool`, *optional*):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+        linear_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the context extrapolation. Defaults to 1.0.
+        ntk_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+            Otherwise, they are concatenated with themselves.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            the dtype of the frequency tensor.
+    Returns:
+        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+    """
+    assert dim % 2 == 0
+
+    if isinstance(pos, int):
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+
+    theta = theta * ntk_factor
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
+    if use_real and repeat_interleave_real:
+        # flux, hunyuan-dit, cogvideox
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        return freqs_cos, freqs_sin
+    elif use_real:
+        # stable audio
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        # lumina
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        return freqs_cis
+
+
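# Usage sketch (illustrative only, hypothetical helper) tying the two functions
# together: build interleaved (cos, sin) tables with `use_real=True`, then rotate
# a [B, H, S, D] query tensor with `apply_rotary_emb` (defined below). Shape and
# dtype are preserved.
def _example_rotary_roundtrip():
    cos, sin = get_1d_rotary_pos_embed(dim=16, pos=8, use_real=True)  # each [S, D] = [8, 16]
    x = torch.randn(1, 2, 8, 16)  # [B, H, S, D]
    out = apply_rotary_emb(x, (cos, sin))
    assert out.shape == x.shape and out.dtype == x.dtype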
+def apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+    to the given query or key tensor 'x' using the provided frequency tensor 'freqs_cis'. The input tensor is
+    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+    tensor contains rotary embeddings and is returned as a real tensor.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings to. [B, H, S, D]
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        `torch.Tensor`: the query or key tensor with rotary embeddings applied.
+    """
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        cos = cos[None, None]
+        sin = sin[None, None]
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Used for flux, cogvideox, hunyuan-dit
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Used for Stable Audio
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
+
+        return out
+    else:
+        # used for lumina
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(2)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
+
+        return x_out.type_as(x)
 
 
 class PatchEmbed(nn.Module):
@@ -442,253 +910,6 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
     return embeds
 
 
-def get_3d_rotary_pos_embed(
-    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    """
-    RoPE for video tokens with 3D structure.
-
-    Args:
-        embed_dim: (`int`):
-            The embedding dimension size, corresponding to hidden_size_head.
-        crops_coords (`Tuple[int]`):
-            The top-left and bottom-right coordinates of the crop.
-        grid_size (`Tuple[int]`):
-            The grid size of the spatial positional embedding (height, width).
-        temporal_size (`int`):
-            The size of the temporal dimension.
-        theta (`float`):
-            Scaling factor for frequency computation.
-
-    Returns:
-        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
- """ - if use_real is not True: - raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") - start, stop = crops_coords - grid_size_h, grid_size_w = grid_size - grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) - grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) - - # Compute dimensions for each axis - dim_t = embed_dim // 4 - dim_h = embed_dim // 8 * 3 - dim_w = embed_dim // 8 * 3 - - # Temporal frequencies - freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) - # Spatial frequencies for height and width - freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) - freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) - - # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor - def combine_time_height_width(freqs_t, freqs_h, freqs_w): - freqs_t = freqs_t[:, None, None, :].expand( - -1, grid_size_h, grid_size_w, -1 - ) # temporal_size, grid_size_h, grid_size_w, dim_t - freqs_h = freqs_h[None, :, None, :].expand( - temporal_size, -1, grid_size_w, -1 - ) # temporal_size, grid_size_h, grid_size_2, dim_h - freqs_w = freqs_w[None, None, :, :].expand( - temporal_size, grid_size_h, -1, -1 - ) # temporal_size, grid_size_h, grid_size_2, dim_w - - freqs = torch.cat( - [freqs_t, freqs_h, freqs_w], dim=-1 - ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) - freqs = freqs.view( - temporal_size * grid_size_h * grid_size_w, -1 - ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) - return freqs - - t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t - h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h - w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w - cos = combine_time_height_width(t_cos, h_cos, w_cos) - sin = combine_time_height_width(t_sin, h_sin, w_sin) - return cos, sin - - -def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): - """ - RoPE for image tokens with 2d structure. - - Args: - embed_dim: (`int`): - The embedding dimension size - crops_coords (`Tuple[int]`) - The top-left and bottom-right coordinates of the crop. - grid_size (`Tuple[int]`): - The grid size of the positional embedding. - use_real (`bool`): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - - Returns: - `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. 
- """ - start, stop = crops_coords - grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) # [2, W, H] - - grid = grid.reshape([2, 1, *grid.shape[1:]]) - pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) - return pos_embed - - -def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): - assert embed_dim % 4 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_rotary_pos_embed( - embed_dim // 2, grid[0].reshape(-1), use_real=use_real - ) # (H*W, D/2) if use_real else (H*W, D/4) - emb_w = get_1d_rotary_pos_embed( - embed_dim // 2, grid[1].reshape(-1), use_real=use_real - ) # (H*W, D/2) if use_real else (H*W, D/4) - - if use_real: - cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) - sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) - return cos, sin - else: - emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) - return emb - - -def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0): - assert embed_dim % 4 == 0 - - emb_h = get_1d_rotary_pos_embed( - embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor - ) # (H, D/4) - emb_w = get_1d_rotary_pos_embed( - embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor - ) # (W, D/4) - emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) - emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1) - - emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2) - return emb - - -def get_1d_rotary_pos_embed( - dim: int, - pos: Union[np.ndarray, int], - theta: float = 10000.0, - use_real=False, - linear_factor=1.0, - ntk_factor=1.0, - repeat_interleave_real=True, - freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) -): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end - index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 - data type. - - Args: - dim (`int`): Dimension of the frequency tensor. - pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar - theta (`float`, *optional*, defaults to 10000.0): - Scaling factor for frequency computation. Defaults to 10000.0. - use_real (`bool`, *optional*): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - linear_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the context extrapolation. Defaults to 1.0. - ntk_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. - repeat_interleave_real (`bool`, *optional*, defaults to `True`): - If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. - Otherwise, they are concateanted with themselves. - freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): - the dtype of the frequency tensor. - Returns: - `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] - """ - assert dim % 2 == 0 - - if isinstance(pos, int): - pos = torch.arange(pos) - if isinstance(pos, np.ndarray): - pos = torch.from_numpy(pos) # type: ignore # [S] - - theta = theta * ntk_factor - freqs = ( - 1.0 - / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) - / linear_factor - ) # [D/2] - freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] - if use_real and repeat_interleave_real: - # flux, hunyuan-dit, cogvideox - freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] - return freqs_cos, freqs_sin - elif use_real: - # stable audio - freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] - freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] - return freqs_cos, freqs_sin - else: - # lumina - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] - return freqs_cis - - -def apply_rotary_emb( - x: torch.Tensor, - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], - use_real: bool = True, - use_real_unbind_dim: int = -1, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings - to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are - reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting - tensors contain rotary embeddings and are returned as real tensors. - - Args: - x (`torch.Tensor`): - Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply - freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - """ - if use_real: - cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = sin[None, None] - cos, sin = cos.to(x.device), sin.to(x.device) - - if use_real_unbind_dim == -1: - # Used for flux, cogvideox, hunyuan-dit - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - elif use_real_unbind_dim == -2: - # Used for Stable Audio - x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] - x_rotated = torch.cat([-x_imag, x_real], dim=-1) - else: - raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - - return out - else: - # used for lumina - x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) - - return x_out.type_as(x) - - class FluxPosEmbed(nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]):