From 27f81bd54f9182681637d503f55c6b8f483b5390 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 18 Nov 2024 17:30:24 +0100 Subject: [PATCH 01/47] update --- src/diffusers/models/attention_processor.py | 8 ++++++- src/diffusers/models/embeddings.py | 1 - .../models/transformers/transformer_mochi.py | 5 +++- .../pipelines/mochi/pipeline_mochi.py | 23 +++++++++++++++---- 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 772aae7fcd2f..1e3d72e1ea0d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3524,6 +3524,7 @@ def __call__( encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + joint_attention_mask=None, ) -> torch.Tensor: query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -3579,9 +3580,14 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) + query = query * joint_attention_mask[:, None, :, None] + key = key * joint_attention_mask[:, None, :, None] + value = value * joint_attention_mask[:, None, :, None] + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) + hidden_states = hidden_states * joint_attention_mask[:, :, None] hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1 diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 7cbd958e1d6e..789eef48afc0 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -262,7 +262,6 @@ def forward(self, latent): height, width = latent.shape[-2:] else: height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size - latent = self.proj(latent) if self.flatten: latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 8ac8b5dababa..c899b1185ea5 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -128,6 +128,7 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, + joint_attention_mask=None, ) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) @@ -137,11 +138,11 @@ def forward( ) else: norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) - attn_hidden_states, context_attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + joint_attention_mask=joint_attention_mask, ) hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) @@ -324,6 +325,7 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, + joint_attention_mask=None, return_dict: bool = True, ) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape @@ -373,6 +375,7 @@ def custom_forward(*inputs): 
encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, + joint_attention_mask=joint_attention_mask, ) hidden_states = self.norm_out(hidden_states, temb) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 7a9cc41e2dde..d2ca80f9f395 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -17,6 +17,7 @@ import numpy as np import torch +import torch.nn.functional as F from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -245,7 +246,7 @@ def _get_t5_prompt_embeds( f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -258,6 +259,14 @@ def _get_t5_prompt_embeds( return prompt_embeds, prompt_attention_mask + def prepare_joint_attention_mask(self, prompt_attention_mask, latents): + batch_size, channels, latent_frames, latent_height, latent_width = latents.shape + num_latents = latent_frames * latent_height * latent_width + num_visual_tokens = num_latents // (self.transformer.config.patch_size**2) + mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True) + + return mask + # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt def encode_prompt( self, @@ -613,10 +622,6 @@ def __call__( max_sequence_length=max_sequence_length, device=device, ) - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( @@ -630,6 +635,13 @@ def __call__( generator, latents, ) + joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents) + negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0) # 5. 
Prepare timestep # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 @@ -662,6 +674,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, encoder_attention_mask=prompt_attention_mask, + joint_attention_mask=joint_attention_mask, return_dict=False, )[0] From 30dd9f68453bf81e9f2951d96340338b418359e0 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 18 Nov 2024 17:50:51 +0100 Subject: [PATCH 02/47] update --- src/diffusers/models/attention_processor.py | 11 ++++++----- .../models/transformers/transformer_mochi.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1e3d72e1ea0d..61aa8943ccd0 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3524,7 +3524,6 @@ def __call__( encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - joint_attention_mask=None, ) -> torch.Tensor: query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -3580,14 +3579,16 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) - query = query * joint_attention_mask[:, None, :, None] - key = key * joint_attention_mask[:, None, :, None] - value = value * joint_attention_mask[:, None, :, None] + # Zero out tokens based on the attention mask + query = query * attention_mask[:, None, :, None] + key = key * attention_mask[:, None, :, None] + value = value * attention_mask[:, None, :, None] hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states * joint_attention_mask[:, :, None] + # Zero out tokens based on attention mask + hidden_states = hidden_states * attention_mask[:, :, None] hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1 diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index c899b1185ea5..50b84a913d7c 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -142,7 +142,7 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, - joint_attention_mask=joint_attention_mask, + attention_mask=joint_attention_mask, ) hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) From 10275feacdab1a639c050c97f93540e7d21f0cce Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 20 Nov 2024 13:57:41 +0100 Subject: [PATCH 03/47] update --- src/diffusers/models/attention_processor.py | 32 ++++++++++++++++----- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 61aa8943ccd0..1d7448c54b47 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3572,27 +3572,45 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): encoder_value.transpose(1, 2), ) - sequence_length = query.size(2) - encoder_sequence_length = encoder_query.size(2) + batch_size, heads, sequence_length, dim = query.shape 
+ encoder_sequence_length = encoder_query.shape[2] + total_length = sequence_length + encoder_sequence_length query = torch.cat([query, encoder_query], dim=2) key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) # Zero out tokens based on the attention mask - query = query * attention_mask[:, None, :, None] - key = key * attention_mask[:, None, :, None] - value = value * attention_mask[:, None, :, None] + # query = query * attention_mask[:, None, :, None] + # key = key * attention_mask[:, None, :, None] + # value = value * attention_mask[:, None, :, None] + + query = query.view(1, query.size(1), -1, query.size(-1)) + key = key.view(1, key.size(1), -1, key.size(-1)) + value = value.view(1, value.size(1), -1, key.size(-1)) + + select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + __import__('ipdb').set_trace() + + query = torch.index_select(query, 2, select_index) + key = torch.index_select(key, 2, select_index) + value = torch.index_select(value, 2, select_index) hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).squeeze(0) + output = torch.zeros( + batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype + ) + output.scatter_(0, select_index.unsqueeze(1).expand(-1, dim * heads), hidden_states) + hidden_states = output.view(batch_size, total_length, dim * heads) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) # Zero out tokens based on attention mask - hidden_states = hidden_states * attention_mask[:, :, None] + # hidden_states = hidden_states * attention_mask[:, :, None] hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1 ) + __import__("ipdb").set_trace() # linear proj hidden_states = attn.to_out[0](hidden_states) From 79380ca7195de1cb4a8edae322a26c095d0e83ea Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 20 Nov 2024 19:41:08 +0100 Subject: [PATCH 04/47] update --- src/diffusers/models/attention_processor.py | 6 +- .../models/transformers/transformer_mochi.py | 3 +- .../pipelines/mochi/pipeline_mochi.py | 94 ++++++++++--------- 3 files changed, 53 insertions(+), 50 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1d7448c54b47..4e6202db51db 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,6 +16,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch +from torch._prims_common import is_low_precision_dtype import torch.nn.functional as F from torch import nn @@ -3590,7 +3591,6 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): value = value.view(1, value.size(1), -1, key.size(-1)) select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - __import__('ipdb').set_trace() query = torch.index_select(query, 2, select_index) key = torch.index_select(key, 2, select_index) @@ -3604,13 +3604,9 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): output.scatter_(0, select_index.unsqueeze(1).expand(-1, dim * heads), hidden_states) hidden_states = output.view(batch_size, total_length, dim * heads) - # Zero out tokens based on attention mask - # hidden_states = hidden_states * attention_mask[:, :, None] - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1 ) - 
__import__("ipdb").set_trace() # linear proj hidden_states = attn.to_out[0](hidden_states) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 50b84a913d7c..a96461291745 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -377,7 +377,8 @@ def custom_forward(*inputs): image_rotary_emb=image_rotary_emb, joint_attention_mask=joint_attention_mask, ) - + print(f"block_{i} {hidden_states.norm()}") + print(f"block_{i} {encoder_hidden_states.norm()}") hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index d2ca80f9f395..3ca46d1da901 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -21,7 +21,7 @@ from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...models.autoencoders import AutoencoderKL +from ...models.autoencoders import AutoencoderKLMochi from ...models.transformers import MochiTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( @@ -56,7 +56,7 @@ >>> pipe.enable_model_cpu_offload() >>> pipe.enable_vae_tiling() >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k." - >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0] + >>> frames = pipe(prompt, num_inference_steps=50, guidance_scale=3.5).frames[0] >>> export_to_video(frames, "mochi.mp4") ``` """ @@ -164,8 +164,8 @@ class MochiPipeline(DiffusionPipeline): Conditional Transformer architecture to denoise the encoded video latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + vae ([`AutoencoderKLMochi`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 
@@ -184,7 +184,7 @@ class MochiPipeline(DiffusionPipeline): def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, + vae: AutoencoderKLMochi, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, transformer: MochiTransformer3DModel, @@ -198,17 +198,11 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - # TODO: determine these scaling factors from model parameters - self.vae_spatial_scale_factor = 8 - self.vae_temporal_scale_factor = 6 - self.patch_size = 2 - - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_height = 480 - self.default_width = 848 + + self.vae_scale_factor_spatial = vae.spatial_compression_ratio if hasattr(self, "vae") else 8 + self.vae_scale_factor_temporal = vae.temporal_compression_ratio if hasattr(self, "vae") else 6 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -259,14 +253,6 @@ def _get_t5_prompt_embeds( return prompt_embeds, prompt_attention_mask - def prepare_joint_attention_mask(self, prompt_attention_mask, latents): - batch_size, channels, latent_frames, latent_height, latent_width = latents.shape - num_latents = latent_frames * latent_height * latent_width - num_visual_tokens = num_latents // (self.transformer.config.patch_size**2) - mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True) - - return mask - # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt def encode_prompt( self, @@ -433,6 +419,13 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + def prepare_joint_attention_mask(self, prompt_attention_mask, latents): + batch_size, channels, latent_frames, latent_height, latent_width = latents.shape + num_latents = latent_frames * latent_height * latent_width + num_visual_tokens = num_latents // (self.transformer.config.patch_size**2) + mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True) + return mask + def prepare_latents( self, batch_size, @@ -445,9 +438,9 @@ def prepare_latents( generator, latents=None, ): - height = height // self.vae_spatial_scale_factor - width = width // self.vae_spatial_scale_factor - num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 + height = height // self.vae_scale_factor_spatial + width = width // self.vae_scale_factor_spatial + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 shape = (batch_size, num_channels_latents, num_frames, height, width) @@ -487,7 +480,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_frames: int = 19, - num_inference_steps: int = 28, + num_inference_steps: int = 50, timesteps: List[int] = None, guidance_scale: float = 4.5, num_videos_per_prompt: Optional[int] = 1, @@ -510,13 +503,13 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - height (`int`, *optional*, defaults to `self.default_height`): + height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae.spatial_compression_ratio`): The height in pixels of the generated image. This is set to 480 by default for the best results. 
- width (`int`, *optional*, defaults to `self.default_width`): + width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae.spatial_compression_ratio`): The width in pixels of the generated image. This is set to 848 by default for the best results. num_frames (`int`, defaults to `19`): The number of video frames to generate - num_inference_steps (`int`, *optional*, defaults to 50): + num_inference_steps (`int`, *optional*, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): @@ -576,8 +569,8 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or self.default_height - width = width or self.default_width + height = height or 480 # self.transformer.config.sample_height * self.vae_scaling_factor_spatial + width = width or 848 # self.transformer.config.sample_width * self.vae_scaling_factor_spatial # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -622,6 +615,11 @@ def __call__( max_sequence_length=max_sequence_length, device=device, ) + + # if self.do_classifier_free_guidance: + # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + # prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( @@ -635,13 +633,6 @@ def __call__( generator, latents, ) - joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents) - negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents) - - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0) # 5. 
Prepare timestep # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 @@ -649,6 +640,9 @@ def __call__( sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) sigmas = np.array(sigmas) + joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents) + negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents) + timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, @@ -665,11 +659,14 @@ def __call__( if self.interrupt: continue - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + # timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + latent_model_input = latents timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - noise_pred = self.transformer( + noise_pred_text = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, @@ -679,8 +676,17 @@ def __call__( )[0] if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + encoder_attention_mask=negative_prompt_attention_mask, + joint_attention_mask=negative_joint_attention_mask, + return_dict=False, + )[0] noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = noise_pred_text # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype From 21b09979dc1b9b6f9b02e7a63e1eb0ee598f3986 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 22 Nov 2024 13:21:32 +0100 Subject: [PATCH 05/47] update --- .../pipelines/mochi/pipeline_mochi.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 3ca46d1da901..6c792b858796 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -716,23 +716,30 @@ def __call__( if output_type == "latent": video = latents else: - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None - has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) + with torch.autocast("cuda", torch.float32): + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = ( + hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None ) - latents_std = ( - torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) - ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - else: - latents = latents / self.vae.config.scaling_factor 
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, 12, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, 12, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From fcc59d01a99c8547ae4e8dfb942e0fba825c4ffa Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 23 Nov 2024 17:15:18 +0100 Subject: [PATCH 06/47] update --- src/diffusers/models/attention_processor.py | 1 - src/diffusers/models/normalization.py | 6 +- .../models/transformers/transformer_mochi.py | 103 ++++++++++++++---- .../pipelines/mochi/pipeline_mochi.py | 47 ++++---- 4 files changed, 107 insertions(+), 50 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4e6202db51db..245118347c19 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,7 +16,6 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from torch._prims_common import is_low_precision_dtype import torch.nn.functional as F from torch import nn diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 817b3fff2ea6..f74d5b36ac76 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -256,7 +256,9 @@ def forward( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + scale_msa = scale_msa.float() + _hidden_states = self.norm(hidden_states).float() * (1 + scale_msa[:, None]) + hidden_states = _hidden_states.to(hidden_states.dtype) return hidden_states, gate_msa, scale_mlp, gate_mlp @@ -538,7 +540,7 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(self.weight.dtype) hidden_states = hidden_states * self.weight else: - hidden_states = hidden_states.to(input_dtype) + hidden_states = hidden_states # .to(input_dtype) return hidden_states diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index a96461291745..a38e098b8d53 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import numbers from typing import Any, Dict, Optional, Tuple import torch @@ -26,12 +27,50 @@ from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm +from ..normalization import ( + AdaLayerNormContinuous, + LuminaLayerNormContinuous, + MochiRMSNormZero, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class MochiRMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states, scale=None): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + if scale is not None: + hidden_states = hidden_states * scale + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + @maybe_allow_in_graph class MochiTransformerBlock(nn.Module): r""" @@ -103,11 +142,11 @@ def __init__( ) # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True - self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm2 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm2_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) - self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm3 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm3_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) self.ff_context = None @@ -119,8 +158,8 @@ def __init__( bias=False, ) - self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm4 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm4_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) def forward( self, @@ -137,7 +176,9 @@ def forward( encoder_hidden_states, temb ) else: - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb).to( + encoder_hidden_states.dtype + ) attn_hidden_states, context_attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, @@ -145,20 +186,35 @@ def forward( attention_mask=joint_attention_mask, ) - hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) - norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + # hidden_states = hidden_states + self.norm2(attn_hidden_states) * 
torch.tanh(gate_msa).unsqueeze(1) + # norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + # ff_output = self.ff(norm_hidden_states) + # hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1) + + hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)) + norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float())) ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1) + hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)) if not self.context_pre_only: + # encoder_hidden_states = encoder_hidden_states + self.norm2_context( + # context_attn_hidden_states + # ) * torch.tanh(enc_gate_msa).unsqueeze(1) + # norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) + # context_ff_output = self.ff_context(norm_encoder_hidden_states) + # encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh( + # enc_gate_mlp + # ).unsqueeze(1) encoder_hidden_states = encoder_hidden_states + self.norm2_context( - context_attn_hidden_states - ) * torch.tanh(enc_gate_msa).unsqueeze(1) - norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) + context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1) + ) + norm_encoder_hidden_states = self.norm3_context( + encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float()) + ) context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh( - enc_gate_mlp - ).unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + self.norm4_context( + context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1) + ) return hidden_states, encoder_hidden_states @@ -309,7 +365,11 @@ def __init__( ) self.norm_out = AdaLayerNormContinuous( - inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm" + inner_dim, + inner_dim, + elementwise_affine=False, + eps=1e-6, + norm_type="layer_norm", ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) @@ -335,7 +395,10 @@ def forward( post_patch_width = width // p temb, encoder_hidden_states = self.time_embed( - timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype + timestep, + encoder_hidden_states, + encoder_attention_mask, + hidden_dtype=hidden_states.dtype, ) hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) @@ -377,8 +440,6 @@ def custom_forward(*inputs): image_rotary_emb=image_rotary_emb, joint_attention_mask=joint_attention_mask, ) - print(f"block_{i} {hidden_states.norm()}") - print(f"block_{i} {encoder_hidden_states.norm()}") hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 6c792b858796..ab289a579bd0 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -335,7 +335,12 @@ def encode_prompt( dtype=dtype, ) - return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) def 
check_inputs( self, @@ -596,7 +601,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # 3. Prepare text embeddings ( prompt_embeds, @@ -615,7 +619,6 @@ def __call__( max_sequence_length=max_sequence_length, device=device, ) - # if self.do_classifier_free_guidance: # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) @@ -712,34 +715,26 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - if output_type == "latent": video = latents else: - with torch.autocast("cuda", torch.float32): - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - has_latents_mean = ( - hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) ) - has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = ( - torch.tensor(self.vae.config.latents_std) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - else: - latents = latents / self.vae.config.scaling_factor + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From 1782d0241ae5ea002fee6507ae0c27e81938e9db Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Nov 2024 08:11:11 +0100 Subject: [PATCH 07/47] update --- src/diffusers/models/attention_processor.py | 14 +- src/diffusers/models/normalization.py | 5 +- .../models/transformers/transformer_mochi.py | 163 ++++++++++++------ 3 files changed, 120 insertions(+), 62 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 245118347c19..e953c3e65622 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,6 +16,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch +from torch._higher_order_ops.flex_attention import sdpa_dense import torch.nn.functional as F from torch import nn @@ -3554,11 +3555,11 @@ def __call__( if image_rotary_emb is not None: def apply_rotary_emb(x, freqs_cos, freqs_sin): - x_even = x[..., 0::2].float() - x_odd = x[..., 1::2].float() + x_even = x[..., 0::2] + x_odd = x[..., 1::2] - cos = (x_even * freqs_cos - 
x_odd * freqs_sin).to(x.dtype) - sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) + cos = (x_even * freqs_cos.float() - x_odd * freqs_sin.float()).to(x.dtype) + sin = (x_even * freqs_sin.float() + x_odd * freqs_cos.float()).to(x.dtype) return torch.stack([cos, sin], dim=-1).flatten(-2) @@ -3595,7 +3596,10 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): key = torch.index_select(key, 2, select_index) value = torch.index_select(value, 2, select_index) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + from torch.nn.attention import SDPBackend, sdpa_kernel + with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]): + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).squeeze(0) output = torch.zeros( batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index f74d5b36ac76..1f0d16a2e210 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -532,15 +532,14 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True): def forward(self, hidden_states): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps) if self.weight is not None: # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) hidden_states = hidden_states * self.weight - else: - hidden_states = hidden_states # .to(input_dtype) + hidden_states = hidden_states.to(input_dtype) return hidden_states diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index a38e098b8d53..154c4b8fd5b8 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numbers +from operator import ipow from typing import Any, Dict, Optional, Tuple import torch +from torch._prims_common import is_low_precision_dtype import torch.nn as nn +from transformers.tokenization_utils_base import import_protobuf_decode_error from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging @@ -30,45 +32,107 @@ from ..normalization import ( AdaLayerNormContinuous, LuminaLayerNormContinuous, - MochiRMSNormZero, ) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +logger = logging.get_logger(__name__) # pylint: disable=invalid-n -class MochiRMSNorm(nn.Module): +class FP32ModulatedRMSNorm(nn.Module): def __init__(self, dim, eps: float, elementwise_affine: bool = True): super().__init__() self.eps = eps - if isinstance(dim, numbers.Integral): - dim = (dim,) - - self.dim = torch.Size(dim) - - if elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim)) - else: - self.weight = None - def forward(self, hidden_states, scale=None): - input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps) + if scale is not None: hidden_states = hidden_states * scale - if self.weight is not None: - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - hidden_states = hidden_states * self.weight + return hidden_states + + +class MochiLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) else: - hidden_states = hidden_states.to(input_dtype) + raise ValueError(f"unknown norm_type {norm_type}") - return hidden_states + self.linear_2 = None + if out_dim is not None: + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + output_dtype = x.dtype + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x, (1 + scale.unsqueeze(1).float())) + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x.to(output_dtype) + + +class MochiRMSNormZero(nn.Module): + r""" + Adaptive RMS Norm used in Mochi. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + def __init__( + self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, hidden_dim) + self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, emb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states_dtype = hidden_states.dtype + + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + + hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].float())) + hidden_states = hidden_states.to(hidden_states_dtype) + + return hidden_states, gate_msa, scale_mlp, gate_mlp @maybe_allow_in_graph @@ -115,7 +179,7 @@ def __init__( if not context_pre_only: self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) else: - self.norm1_context = LuminaLayerNormContinuous( + self.norm1_context = MochiLayerNormContinuous( embedding_dim=pooled_projection_dim, conditioning_embedding_dim=dim, eps=eps, @@ -137,16 +201,16 @@ def __init__( out_context_dim=pooled_projection_dim, context_pre_only=context_pre_only, processor=MochiAttnProcessor2_0(), - eps=eps, + eps=1e-5, elementwise_affine=True, ) # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True - self.norm2 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm2_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm2 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm2_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) - self.norm3 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm3_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm3 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm3_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) self.ff_context = None @@ -158,8 +222,8 @@ def __init__( bias=False, ) - self.norm4 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm4_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm4 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm4_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) def forward( self, @@ -176,9 +240,8 @@ def forward( encoder_hidden_states, temb ) else: - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb).to( - encoder_hidden_states.dtype - ) + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + attn_hidden_states, context_attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, @@ -186,35 +249,26 @@ def forward( attention_mask=joint_attention_mask, ) - # hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) - # norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) - # ff_output = self.ff(norm_hidden_states) - # hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1) - - hidden_states = hidden_states + self.norm2(attn_hidden_states, 
torch.tanh(gate_msa).unsqueeze(1)) - norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float())) + hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)).to( + hidden_states.dtype + ) + norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float())).to(hidden_states.dtype) ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)) + hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)).to( + hidden_states.dtype + ) if not self.context_pre_only: - # encoder_hidden_states = encoder_hidden_states + self.norm2_context( - # context_attn_hidden_states - # ) * torch.tanh(enc_gate_msa).unsqueeze(1) - # norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) - # context_ff_output = self.ff_context(norm_encoder_hidden_states) - # encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh( - # enc_gate_mlp - # ).unsqueeze(1) encoder_hidden_states = encoder_hidden_states + self.norm2_context( context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1) - ) + ).to(encoder_hidden_states.dtype) norm_encoder_hidden_states = self.norm3_context( encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float()) - ) + ).to(encoder_hidden_states.dtype) context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + self.norm4_context( context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1) - ) + ).to(encoder_hidden_states.dtype) return hidden_states, encoder_hidden_states @@ -259,7 +313,8 @@ def _get_positions( return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float()) + with torch.autocast("cuda", enabled=False): + freqs = torch.einsum("nd,dhf->nhf", pos.to(freqs), freqs) freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin From 66a5f59ca1099e0981cd9da4fa2b95c7a3e4ff31 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Nov 2024 08:25:54 +0100 Subject: [PATCH 08/47] update --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index ab289a579bd0..2a2e3fe453ef 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -687,13 +687,16 @@ def __call__( joint_attention_mask=negative_joint_attention_mask, return_dict=False, )[0] - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond.float() + self.guidance_scale * ( + noise_pred_text.float() - noise_pred_uncond.float() + ) else: noise_pred = noise_pred_text # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] + latents = latents.to(latents_dtype) if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): From 3ffa711db19bfa3c8c73895b0de4fa6afce67da8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Nov 2024 08:34:33 +0100 Subject: [PATCH 09/47] update --- .../pipelines/mochi/pipeline_mochi.py | 37 
++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 2a2e3fe453ef..ec8119a3c332 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -601,24 +601,25 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # 3. Prepare text embeddings - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - max_sequence_length=max_sequence_length, - device=device, - ) + with torch.autocast("cuda", torch.float32): + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) # if self.do_classifier_free_guidance: # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) From dded24364ccf0d678e61e09dc49ba4642f981208 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Nov 2024 08:50:27 +0100 Subject: [PATCH 10/47] update --- .../pipelines/mochi/pipeline_mochi.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index ec8119a3c332..1197f48c0053 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -722,23 +722,30 @@ def __call__( if output_type == "latent": video = latents else: - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None - has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) + with torch.autocast("cuda", torch.float32): + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = ( + hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None ) - latents_std = ( - torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) - ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - else: - latents = latents / self.vae.config.scaling_factor + has_latents_std = hasattr(self.vae.config, "latents_std") and 
self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, 12, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, 12, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From d99234feac59ae596af7c28bbc3b3ad5896ca86a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Nov 2024 15:35:24 +0100 Subject: [PATCH 11/47] update --- .../pipelines/mochi/pipeline_mochi.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 1197f48c0053..515f5235ed28 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -601,25 +601,24 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - with torch.autocast("cuda", torch.float32): - # 3. Prepare text embeddings - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - max_sequence_length=max_sequence_length, - device=device, - ) + # 3. 
Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) # if self.do_classifier_free_guidance: # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) From 8b9d5b63ae5b28689dde45b95dac44c8a790d173 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Nov 2024 16:40:09 +0100 Subject: [PATCH 12/47] update --- src/diffusers/models/attention_processor.py | 2 +- .../models/transformers/transformer_mochi.py | 182 ++++++++++++------ 2 files changed, 122 insertions(+), 62 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e953c3e65622..21e8dfe028fe 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,7 +16,6 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from torch._higher_order_ops.flex_attention import sdpa_dense import torch.nn.functional as F from torch import nn @@ -3597,6 +3596,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): value = torch.index_select(value, 2, select_index) from torch.nn.attention import SDPBackend, sdpa_kernel + with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]): hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 154c4b8fd5b8..52c316455439 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -13,44 +13,71 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from operator import ipow from typing import Any, Dict, Optional, Tuple import torch -from torch._prims_common import is_low_precision_dtype import torch.nn as nn -from transformers.tokenization_utils_base import import_protobuf_decode_error from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import Attention, MochiAttnProcessor2_0 +from ..attention_processor import MochiAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import ( AdaLayerNormContinuous, - LuminaLayerNormContinuous, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-n -class FP32ModulatedRMSNorm(nn.Module): - def __init__(self, dim, eps: float, elementwise_affine: bool = True): +class MochiModulatedRMSNorm(nn.Module): + def __init__(self, eps: float): super().__init__() self.eps = eps def forward(self, hidden_states, scale=None): + hidden_states_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps) if scale is not None: hidden_states = hidden_states * scale + hidden_states = hidden_states.to(hidden_states_dtype) + + return hidden_states + + +class MochiRMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine=True): + super().__init__() + + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + hidden_states_dtype = hidden_states.dtype + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + + hidden_states = hidden_states.to(hidden_states_dtype) + return hidden_states @@ -59,49 +86,28 @@ def __init__( self, embedding_dim: int, conditioning_embedding_dim: int, - # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters - # because the output is immediately scaled and shifted by the projected conditioning embeddings. - # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. - # However, this is how it was implemented in the original code, and it's rather likely you should - # set `elementwise_affine` to False. 
- elementwise_affine=True, eps=1e-5, bias=True, - norm_type="layer_norm", - out_dim: Optional[int] = None, ): super().__init__() # AdaLN self.silu = nn.SiLU() self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) - - if norm_type == "layer_norm": - self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) - elif norm_type == "rms_norm": - self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) - else: - raise ValueError(f"unknown norm_type {norm_type}") - - self.linear_2 = None - if out_dim is not None: - self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) + self.norm = MochiModulatedRMSNorm(eps=eps) def forward( self, x: torch.Tensor, conditioning_embedding: torch.Tensor, ) -> torch.Tensor: - output_dtype = x.dtype - # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) - emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) - scale = emb - x = self.norm(x, (1 + scale.unsqueeze(1).float())) + input_dtype = x.dtype - if self.linear_2 is not None: - x = self.linear_2(x) + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32))) - return x.to(output_dtype) + return x.to(input_dtype) class MochiRMSNormZero(nn.Module): @@ -119,7 +125,7 @@ def __init__( self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, hidden_dim) - self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + self.norm = MochiModulatedRMSNorm(eps=eps) def forward( self, hidden_states: torch.Tensor, emb: torch.Tensor @@ -129,12 +135,76 @@ def forward( emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].float())) + hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].to(torch.float32))) hidden_states = hidden_states.to(hidden_states_dtype) return hidden_states, gate_msa, scale_mlp, gate_mlp +class MochiAttention(nn.Module): + def __init__( + self, + query_dim: int, + processor: Optional["MochiAttnProcessor2_0"], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_dim: int = None, + out_context_dim: int = None, + out_bias: bool = True, + context_pre_only: bool = False, + eps: float = 1e-5, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim else query_dim + self.context_pre_only = context_pre_only + + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.norm_q = MochiRMSNorm(dim_head, eps) + self.norm_k = MochiRMSNorm(dim_head, eps) + self.norm_added_q = MochiRMSNorm(dim_head, eps) + self.norm_added_k = MochiRMSNorm(dim_head, eps) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + if 
self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + @maybe_allow_in_graph class MochiTransformerBlock(nn.Module): r""" @@ -183,18 +253,13 @@ def __init__( embedding_dim=pooled_projection_dim, conditioning_embedding_dim=dim, eps=eps, - elementwise_affine=False, - norm_type="rms_norm", - out_dim=None, ) - self.attn1 = Attention( + self.attn1 = MochiAttention( query_dim=dim, - cross_attention_dim=None, heads=num_attention_heads, dim_head=attention_head_dim, bias=False, - qk_norm=qk_norm, added_kv_proj_dim=pooled_projection_dim, added_proj_bias=False, out_dim=dim, @@ -202,15 +267,14 @@ def __init__( context_pre_only=context_pre_only, processor=MochiAttnProcessor2_0(), eps=1e-5, - elementwise_affine=True, ) # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True - self.norm2 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm2_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm2 = MochiModulatedRMSNorm(eps=eps) + self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None - self.norm3 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm3_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm3 = MochiModulatedRMSNorm(eps) + self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) self.ff_context = None @@ -222,8 +286,8 @@ def __init__( bias=False, ) - self.norm4 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False) - self.norm4_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + self.norm4 = MochiModulatedRMSNorm(eps=eps) + self.norm4_context = MochiModulatedRMSNorm(eps=eps) def forward( self, @@ -249,26 +313,22 @@ def forward( attention_mask=joint_attention_mask, ) - hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)).to( - hidden_states.dtype - ) - norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float())).to(hidden_states.dtype) + hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)) + norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32))) ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)).to( - hidden_states.dtype - ) + hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)) if not self.context_pre_only: encoder_hidden_states = encoder_hidden_states + self.norm2_context( context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1) - ).to(encoder_hidden_states.dtype) 
+ ) norm_encoder_hidden_states = self.norm3_context( - encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float()) - ).to(encoder_hidden_states.dtype) + encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32)) + ) context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + self.norm4_context( context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1) - ).to(encoder_hidden_states.dtype) + ) return hidden_states, encoder_hidden_states From 2cfca5e0d2dd2a092a1228eee0ded7e832b3a7d0 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Nov 2024 07:07:01 +0100 Subject: [PATCH 13/47] update --- src/diffusers/models/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 1f0d16a2e210..aa15e0a4f8d7 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -532,7 +532,7 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True): def forward(self, hidden_states): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) if self.weight is not None: # convert into half-precision if necessary From 900feadbc9b692f095ffc166de6b054b2f5be691 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Nov 2024 09:12:17 +0100 Subject: [PATCH 14/47] update --- src/diffusers/models/attention_processor.py | 43 ++--- .../pipelines/mochi/pipeline_mochi.py | 149 ++++++++---------- 2 files changed, 80 insertions(+), 112 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21e8dfe028fe..61aa8943ccd0 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3554,11 +3554,11 @@ def __call__( if image_rotary_emb is not None: def apply_rotary_emb(x, freqs_cos, freqs_sin): - x_even = x[..., 0::2] - x_odd = x[..., 1::2] + x_even = x[..., 0::2].float() + x_odd = x[..., 1::2].float() - cos = (x_even * freqs_cos.float() - x_odd * freqs_sin.float()).to(x.dtype) - sin = (x_even * freqs_sin.float() + x_odd * freqs_cos.float()).to(x.dtype) + cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) + sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) return torch.stack([cos, sin], dim=-1).flatten(-2) @@ -3572,40 +3572,23 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): encoder_value.transpose(1, 2), ) - batch_size, heads, sequence_length, dim = query.shape - encoder_sequence_length = encoder_query.shape[2] - total_length = sequence_length + encoder_sequence_length + sequence_length = query.size(2) + encoder_sequence_length = encoder_query.size(2) query = torch.cat([query, encoder_query], dim=2) key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) # Zero out tokens based on the attention mask - # query = query * attention_mask[:, None, :, None] - # key = key * attention_mask[:, None, :, None] - # value = value * attention_mask[:, None, :, None] + query = query * attention_mask[:, None, :, None] + key = key * attention_mask[:, None, :, None] + value = value * attention_mask[:, None, :, None] - query = query.view(1, query.size(1), -1, query.size(-1)) - key = key.view(1, key.size(1), -1, key.size(-1)) - value = value.view(1, value.size(1), -1, 
key.size(-1)) - - select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - - query = torch.index_select(query, 2, select_index) - key = torch.index_select(key, 2, select_index) - value = torch.index_select(value, 2, select_index) - - from torch.nn.attention import SDPBackend, sdpa_kernel - - with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]): - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).squeeze(0) - output = torch.zeros( - batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype - ) - output.scatter_(0, select_index.unsqueeze(1).expand(-1, dim * heads), hidden_states) - hidden_states = output.view(batch_size, total_length, dim * heads) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + # Zero out tokens based on attention mask + hidden_states = hidden_states * attention_mask[:, :, None] hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1 diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 515f5235ed28..bcc1eeb612d6 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -21,7 +21,7 @@ from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...models.autoencoders import AutoencoderKLMochi +from ...models.autoencoders import AutoencoderKL from ...models.transformers import MochiTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( @@ -56,7 +56,7 @@ >>> pipe.enable_model_cpu_offload() >>> pipe.enable_vae_tiling() >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k." - >>> frames = pipe(prompt, num_inference_steps=50, guidance_scale=3.5).frames[0] + >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0] >>> export_to_video(frames, "mochi.mp4") ``` """ @@ -164,8 +164,8 @@ class MochiPipeline(DiffusionPipeline): Conditional Transformer architecture to denoise the encoded video latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKLMochi`]): - Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 
@@ -184,7 +184,7 @@ class MochiPipeline(DiffusionPipeline): def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKLMochi, + vae: AutoencoderKL, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, transformer: MochiTransformer3DModel, @@ -198,11 +198,17 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - - self.vae_scale_factor_spatial = vae.spatial_compression_ratio if hasattr(self, "vae") else 8 - self.vae_scale_factor_temporal = vae.temporal_compression_ratio if hasattr(self, "vae") else 6 - - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + # TODO: determine these scaling factors from model parameters + self.vae_spatial_scale_factor = 8 + self.vae_temporal_scale_factor = 6 + self.patch_size = 2 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_height = 480 + self.default_width = 848 # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -253,6 +259,14 @@ def _get_t5_prompt_embeds( return prompt_embeds, prompt_attention_mask + def prepare_joint_attention_mask(self, prompt_attention_mask, latents): + batch_size, channels, latent_frames, latent_height, latent_width = latents.shape + num_latents = latent_frames * latent_height * latent_width + num_visual_tokens = num_latents // (self.transformer.config.patch_size**2) + mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True) + + return mask + # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt def encode_prompt( self, @@ -335,12 +349,7 @@ def encode_prompt( dtype=dtype, ) - return ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask def check_inputs( self, @@ -424,13 +433,6 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() - def prepare_joint_attention_mask(self, prompt_attention_mask, latents): - batch_size, channels, latent_frames, latent_height, latent_width = latents.shape - num_latents = latent_frames * latent_height * latent_width - num_visual_tokens = num_latents // (self.transformer.config.patch_size**2) - mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True) - return mask - def prepare_latents( self, batch_size, @@ -443,9 +445,9 @@ def prepare_latents( generator, latents=None, ): - height = height // self.vae_scale_factor_spatial - width = width // self.vae_scale_factor_spatial - num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + height = height // self.vae_spatial_scale_factor + width = width // self.vae_spatial_scale_factor + num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 shape = (batch_size, num_channels_latents, num_frames, height, width) @@ -485,7 +487,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_frames: int = 19, - num_inference_steps: int = 50, + num_inference_steps: int = 64, timesteps: List[int] = None, guidance_scale: float = 4.5, num_videos_per_prompt: Optional[int] = 1, @@ -508,13 +510,13 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 
instead. - height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae.spatial_compression_ratio`): + height (`int`, *optional*, defaults to `self.default_height`): The height in pixels of the generated image. This is set to 480 by default for the best results. - width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae.spatial_compression_ratio`): + width (`int`, *optional*, defaults to `self.default_width`): The width in pixels of the generated image. This is set to 848 by default for the best results. num_frames (`int`, defaults to `19`): The number of video frames to generate - num_inference_steps (`int`, *optional*, defaults to `50`): + num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): @@ -574,8 +576,8 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or 480 # self.transformer.config.sample_height * self.vae_scaling_factor_spatial - width = width or 848 # self.transformer.config.sample_width * self.vae_scaling_factor_spatial + height = height or self.default_height + width = width or self.default_width # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -601,6 +603,7 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + # 3. Prepare text embeddings ( prompt_embeds, @@ -619,10 +622,6 @@ def __call__( max_sequence_length=max_sequence_length, device=device, ) - # if self.do_classifier_free_guidance: - # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - # prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( @@ -636,6 +635,13 @@ def __call__( generator, latents, ) + joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents) + negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0) # 5. 
Prepare timestep # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 @@ -643,9 +649,6 @@ def __call__( sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) sigmas = np.array(sigmas) - joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents) - negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents) - timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, @@ -662,14 +665,11 @@ def __call__( if self.interrupt: continue - # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - # timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - - latent_model_input = latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - noise_pred_text = self.transformer( + noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, @@ -677,25 +677,16 @@ def __call__( joint_attention_mask=joint_attention_mask, return_dict=False, )[0] + # Mochi CFG + Sampling runs in FP32 + noise_pred = noise_pred.to(torch.float32) if self.do_classifier_free_guidance: - noise_pred_uncond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=negative_prompt_embeds, - timestep=timestep, - encoder_attention_mask=negative_prompt_attention_mask, - joint_attention_mask=negative_joint_attention_mask, - return_dict=False, - )[0] - noise_pred = noise_pred_uncond.float() + self.guidance_scale * ( - noise_pred_text.float() - noise_pred_uncond.float() - ) - else: - noise_pred = noise_pred_text + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0] latents = latents.to(latents_dtype) if latents.dtype != latents_dtype: @@ -718,33 +709,27 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + if output_type == "latent": video = latents else: - with torch.autocast("cuda", torch.float32): - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - has_latents_mean = ( - hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) ) - has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, 12, 1, 1, 1) 
- .to(latents.device, latents.dtype) - ) - latents_std = ( - torch.tensor(self.vae.config.latents_std) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - else: - latents = latents / self.vae.config.scaling_factor - - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From 0b09231c76b37c5d37ae591a52ba2c48f83688ec Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Nov 2024 10:02:54 +0100 Subject: [PATCH 15/47] update --- src/diffusers/models/transformers/transformer_mochi.py | 3 +-- src/diffusers/pipelines/mochi/pipeline_mochi.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 52c316455439..6ad120e600f4 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -373,8 +373,7 @@ def _get_positions( return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - with torch.autocast("cuda", enabled=False): - freqs = torch.einsum("nd,dhf->nhf", pos.to(freqs), freqs) + freqs = torch.einsum("nd,dhf->nhf", pos.to(freqs), freqs) freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index bcc1eeb612d6..d66e852ad944 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -210,7 +210,6 @@ def __init__( self.default_height = 480 self.default_width = 848 - # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -233,6 +232,7 @@ def _get_t5_prompt_embeds( add_special_tokens=True, return_tensors="pt", ) + text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask prompt_attention_mask = prompt_attention_mask.bool().to(device) From 883f5c8ef440a305ced310c92835072c25fbe8c8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Nov 2024 10:30:34 +0100 Subject: [PATCH 16/47] update --- src/diffusers/models/transformers/transformer_mochi.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 6ad120e600f4..c800342ed6d0 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -373,7 +373,9 @@ def _get_positions( return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - freqs = torch.einsum("nd,dhf->nhf", pos.to(freqs), freqs) + # Always run ROPE freqs computation in FP32 + with torch.set_default_dtype(torch.float32): + freqs = 
torch.einsum("nd,dhf->nhf", pos, freqs) freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin From 59c9f5d9faf0318b89b1892adc736734124d5088 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Nov 2024 10:44:52 +0100 Subject: [PATCH 17/47] update --- src/diffusers/models/transformers/transformer_mochi.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index c800342ed6d0..0c431ab2ebed 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -373,9 +373,10 @@ def _get_positions( return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - # Always run ROPE freqs computation in FP32 - with torch.set_default_dtype(torch.float32): - freqs = torch.einsum("nd,dhf->nhf", pos, freqs) + with torch.autocast(freqs.device.type, enabled=False): + # Always run ROPE freqs computation in FP32 + freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) + freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin From f3fefaecadbef15bb622282f22e66817cb5dac05 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Nov 2024 12:30:16 +0100 Subject: [PATCH 18/47] update --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index d66e852ad944..7b96abcfd4df 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -636,11 +636,12 @@ def __call__( latents, ) joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents) - negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents) joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0) # 5. 
Prepare timestep From 8a5d03b90300d5bb12ec7ce18c1a2728aa2bb15a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Nov 2024 18:58:12 +0100 Subject: [PATCH 19/47] update --- src/diffusers/models/attention_processor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 61aa8943ccd0..fa69f8809d49 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3579,12 +3579,12 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) - # Zero out tokens based on the attention mask - query = query * attention_mask[:, None, :, None] - key = key * attention_mask[:, None, :, None] - value = value * attention_mask[:, None, :, None] + attention_mask = (1.0 - attention_mask.to(hidden_states.dtype)) * -10000 + attention_mask = attention_mask.unsqueeze(1) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, attn_mask=attention_mask, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) # Zero out tokens based on attention mask From b7464e58280a1316dd32ed6e4809e1dc90d500be Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 01:55:03 +0100 Subject: [PATCH 20/47] update --- src/diffusers/models/attention_processor.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index fa69f8809d49..4f6973295716 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,6 +16,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch +from torch._prims_common import validate_strides import torch.nn.functional as F from torch import nn @@ -3574,16 +3575,30 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): sequence_length = query.size(2) encoder_sequence_length = encoder_query.size(2) + total_length = sequence_length + encoder_sequence_length query = torch.cat([query, encoder_query], dim=2) key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) - attention_mask = (1.0 - attention_mask.to(hidden_states.dtype)) * -10000 - attention_mask = attention_mask.unsqueeze(1) + batch_size, _, _, _ = query.shape + """ + torch.zeros(batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype) + for idx in range(batch_size): + mask = attention_mask[idx] + valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + valid_query = torch.index_select(query[idx], 2, valid_token_indices) + valid_key = torch.index_select(key[idx], 2, valid_token_indices) + valid_value = torch.index_select(value[idx], 2, valid_token_indices) + + attn_output = F.scaled_dot_product_attention( + valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False + ) + """ hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, attn_mask=attention_mask, is_causal=False + query, key, value, dropout_p=0.0, attn_mask=attention_mask.unsqueeze(1).unsqueeze(2), is_causal=False ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) From fb4e1753564243e1f7d18d2a4c0f060b03747e94 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: 
Wed, 27 Nov 2024 03:47:21 +0100 Subject: [PATCH 21/47] update --- src/diffusers/models/attention_processor.py | 26 +++++++++------------ 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4f6973295716..c1f21bb251eb 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -18,7 +18,7 @@ import torch from torch._prims_common import validate_strides import torch.nn.functional as F -from torch import nn +from torch import nn, unsqueeze from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, logging @@ -3581,29 +3581,25 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) - batch_size, _, _, _ = query.shape - """ - torch.zeros(batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype) + batch_size, heads, _, dim = query.shape + attn_outputs = [] for idx in range(batch_size): - mask = attention_mask[idx] + mask = attention_mask[idx].unsqueeze(0) valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() - valid_query = torch.index_select(query[idx], 2, valid_token_indices) - valid_key = torch.index_select(key[idx], 2, valid_token_indices) - valid_value = torch.index_select(value[idx], 2, valid_token_indices) + valid_query = torch.index_select(query[idx].unsqueeze(0), 2, valid_token_indices) + valid_key = torch.index_select(key[idx].unsqueeze(0), 2, valid_token_indices) + valid_value = torch.index_select(value[idx].unsqueeze(0), 2, valid_token_indices) attn_output = F.scaled_dot_product_attention( valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False ) - """ - - hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, attn_mask=attention_mask.unsqueeze(1).unsqueeze(2), is_causal=False - ) + valid_sequence_length = attn_output.size(2) + attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) + attn_outputs.append(attn_output) + hidden_states = torch.cat(attn_outputs, dim=0) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - # Zero out tokens based on attention mask - hidden_states = hidden_states * attention_mask[:, :, None] hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1 From 61001c8f8f7777460b86f5358189bb184251d444 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 03:51:46 +0100 Subject: [PATCH 22/47] update --- src/diffusers/models/attention_processor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c1f21bb251eb..2d044f269725 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,9 +16,8 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from torch._prims_common import validate_strides import torch.nn.functional as F -from torch import nn, unsqueeze +from torch import nn from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, logging From 0fdef41d660e49f9c27f01ebab2363c611c0b61e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 04:48:05 +0100 Subject: [PATCH 23/47] update --- src/diffusers/models/attention_processor.py | 20 +++++++++---------- 
.../models/transformers/transformer_mochi.py | 7 +++---- .../pipelines/mochi/pipeline_mochi.py | 13 ------------ 3 files changed, 13 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2d044f269725..7e7e1a54c4e1 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3522,7 +3522,7 @@ def __call__( attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: query = attn.to_q(hidden_states) @@ -3576,19 +3576,19 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): encoder_sequence_length = encoder_query.size(2) total_length = sequence_length + encoder_sequence_length - query = torch.cat([query, encoder_query], dim=2) - key = torch.cat([key, encoder_key], dim=2) - value = torch.cat([value, encoder_value], dim=2) - batch_size, heads, _, dim = query.shape attn_outputs = [] for idx in range(batch_size): - mask = attention_mask[idx].unsqueeze(0) - valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + mask = attention_mask[idx][None, :] + valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + valid_encoder_query = torch.index_select(encoder_query[idx][None, :], 2, valid_prompt_token_indices) + valid_encoder_key = torch.index_select(encoder_key[idx][None, :], 2, valid_prompt_token_indices) + valid_encoder_value = torch.index_select(encoder_value[idx][None, :], 2, valid_prompt_token_indices) - valid_query = torch.index_select(query[idx].unsqueeze(0), 2, valid_token_indices) - valid_key = torch.index_select(key[idx].unsqueeze(0), 2, valid_token_indices) - valid_value = torch.index_select(value[idx].unsqueeze(0), 2, valid_token_indices) + valid_query = torch.cat([query[idx][None, :], valid_encoder_query], dim=2) + valid_key = torch.cat([key[idx][None, :], valid_encoder_key], dim=2) + valid_value = torch.cat([value[idx][None, :], valid_encoder_value], dim=2) attn_output = F.scaled_dot_product_attention( valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 0c431ab2ebed..fb346a70ba4d 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -294,8 +294,8 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, + encoder_attention_mask: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, - joint_attention_mask=None, ) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) @@ -310,7 +310,7 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, - attention_mask=joint_attention_mask, + attention_mask=encoder_attention_mask, ) hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)) @@ -502,7 +502,6 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - joint_attention_mask=None, return_dict: bool = True, ) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape @@ -554,8 +553,8 @@ def 
custom_forward(*inputs): hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, + encoder_attention_mask=encoder_attention_mask, image_rotary_emb=image_rotary_emb, - joint_attention_mask=joint_attention_mask, ) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 7b96abcfd4df..951691e8c37a 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -17,7 +17,6 @@ import numpy as np import torch -import torch.nn.functional as F from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -259,14 +258,6 @@ def _get_t5_prompt_embeds( return prompt_embeds, prompt_attention_mask - def prepare_joint_attention_mask(self, prompt_attention_mask, latents): - batch_size, channels, latent_frames, latent_height, latent_width = latents.shape - num_latents = latent_frames * latent_height * latent_width - num_visual_tokens = num_latents // (self.transformer.config.patch_size**2) - mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True) - - return mask - # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt def encode_prompt( self, @@ -641,9 +632,6 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents) - joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0) - # 5. 
Prepare timestep # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 threshold_noise = 0.025 @@ -675,7 +663,6 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, encoder_attention_mask=prompt_attention_mask, - joint_attention_mask=joint_attention_mask, return_dict=False, )[0] # Mochi CFG + Sampling runs in FP32 From e6fe9f1a09619d7918b0243f5be170a40f263d49 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 06:10:24 +0100 Subject: [PATCH 24/47] update --- .../pipelines/mochi/pipeline_mochi.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 951691e8c37a..572cf87172a7 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -626,7 +626,6 @@ def __call__( generator, latents, ) - joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -701,23 +700,30 @@ def __call__( if output_type == "latent": video = latents else: - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None - has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) + with torch.autocast("cuda", torch.float32): + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = ( + hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None ) - latents_std = ( - torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) - ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - else: - latents = latents / self.vae.config.scaling_factor - - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, 12, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, 12, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From c17cef75befa6eef02d80f6a9d467aeeac558dd5 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 06:25:52 +0100 Subject: [PATCH 25/47] update --- src/diffusers/models/transformers/__init__.py | 22 - .../transformers/auraflow_transformer_2d.py | 544 ----------------- .../transformers/cogvideox_transformer_3d.py | 507 --------------- .../models/transformers/dit_transformer_2d.py | 240 
-------- .../transformers/dual_transformer_2d.py | 156 ----- .../transformers/hunyuan_transformer_2d.py | 578 ------------------ .../transformers/latte_transformer_3d.py | 327 ---------- .../models/transformers/lumina_nextdit2d.py | 340 ----------- .../transformers/pixart_transformer_2d.py | 445 -------------- .../models/transformers/prior_transformer.py | 380 ------------ .../transformers/stable_audio_transformer.py | 458 -------------- .../transformers/t5_film_transformer.py | 436 ------------- .../models/transformers/transformer_2d.py | 566 ----------------- .../transformers/transformer_allegro.py | 422 ------------- .../transformers/transformer_cogview3plus.py | 386 ------------ .../models/transformers/transformer_flux.py | 577 ----------------- .../models/transformers/transformer_mochi.py | 568 ----------------- .../models/transformers/transformer_sd3.py | 373 ----------- .../transformers/transformer_temporal.py | 381 ------------ .../pipelines/mochi/pipeline_mochi.py | 64 +- 20 files changed, 32 insertions(+), 7738 deletions(-) delete mode 100644 src/diffusers/models/transformers/__init__.py delete mode 100644 src/diffusers/models/transformers/auraflow_transformer_2d.py delete mode 100644 src/diffusers/models/transformers/cogvideox_transformer_3d.py delete mode 100644 src/diffusers/models/transformers/dit_transformer_2d.py delete mode 100644 src/diffusers/models/transformers/dual_transformer_2d.py delete mode 100644 src/diffusers/models/transformers/hunyuan_transformer_2d.py delete mode 100644 src/diffusers/models/transformers/latte_transformer_3d.py delete mode 100644 src/diffusers/models/transformers/lumina_nextdit2d.py delete mode 100644 src/diffusers/models/transformers/pixart_transformer_2d.py delete mode 100644 src/diffusers/models/transformers/prior_transformer.py delete mode 100644 src/diffusers/models/transformers/stable_audio_transformer.py delete mode 100644 src/diffusers/models/transformers/t5_film_transformer.py delete mode 100644 src/diffusers/models/transformers/transformer_2d.py delete mode 100644 src/diffusers/models/transformers/transformer_allegro.py delete mode 100644 src/diffusers/models/transformers/transformer_cogview3plus.py delete mode 100644 src/diffusers/models/transformers/transformer_flux.py delete mode 100644 src/diffusers/models/transformers/transformer_mochi.py delete mode 100644 src/diffusers/models/transformers/transformer_sd3.py delete mode 100644 src/diffusers/models/transformers/transformer_temporal.py diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py deleted file mode 100644 index a2c087d708a4..000000000000 --- a/src/diffusers/models/transformers/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from ...utils import is_torch_available - - -if is_torch_available(): - from .auraflow_transformer_2d import AuraFlowTransformer2DModel - from .cogvideox_transformer_3d import CogVideoXTransformer3DModel - from .dit_transformer_2d import DiTTransformer2DModel - from .dual_transformer_2d import DualTransformer2DModel - from .hunyuan_transformer_2d import HunyuanDiT2DModel - from .latte_transformer_3d import LatteTransformer3DModel - from .lumina_nextdit2d import LuminaNextDiT2DModel - from .pixart_transformer_2d import PixArtTransformer2DModel - from .prior_transformer import PriorTransformer - from .stable_audio_transformer import StableAudioDiTModel - from .t5_film_transformer import T5FilmDecoder - from .transformer_2d import Transformer2DModel - from .transformer_allegro import 
AllegroTransformer3DModel - from .transformer_cogview3plus import CogView3PlusTransformer2DModel - from .transformer_flux import FluxTransformer2DModel - from .transformer_mochi import MochiTransformer3DModel - from .transformer_sd3 import SD3Transformer2DModel - from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py deleted file mode 100644 index b3f29e6b6224..000000000000 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ /dev/null @@ -1,544 +0,0 @@ -# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Any, Dict, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_processor import ( - Attention, - AttentionProcessor, - AuraFlowAttnProcessor2_0, - FusedAuraFlowAttnProcessor2_0, -) -from ..embeddings import TimestepEmbedding, Timesteps -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormZero, FP32LayerNorm - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Taken from the original aura flow inference code. -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - -# Aura Flow patch embed doesn't use convs for projections. -# Additionally, it uses learned positional embeddings. -class AuraFlowPatchEmbed(nn.Module): - def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - pos_embed_max_size=None, - ): - super().__init__() - - self.num_patches = (height // patch_size) * (width // patch_size) - self.pos_embed_max_size = pos_embed_max_size - - self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) - self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1) - - self.patch_size = patch_size - self.height, self.width = height // patch_size, width // patch_size - self.base_size = height // patch_size - - def pe_selection_index_based_on_dim(self, h, w): - # select subset of positional embedding based on H, W, where H, W is size of latent - # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected - # because original input are in flattened format, we have to flatten this 2d grid as well. 
- h_p, w_p = h // self.patch_size, w // self.patch_size - original_pe_indexes = torch.arange(self.pos_embed.shape[1]) - h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5) - original_pe_indexes = original_pe_indexes.view(h_max, w_max) - starth = h_max // 2 - h_p // 2 - endh = starth + h_p - startw = w_max // 2 - w_p // 2 - endw = startw + w_p - original_pe_indexes = original_pe_indexes[starth:endh, startw:endw] - return original_pe_indexes.flatten() - - def forward(self, latent): - batch_size, num_channels, height, width = latent.size() - latent = latent.view( - batch_size, - num_channels, - height // self.patch_size, - self.patch_size, - width // self.patch_size, - self.patch_size, - ) - latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) - latent = self.proj(latent) - pe_index = self.pe_selection_index_based_on_dim(height, width) - return latent + self.pos_embed[:, pe_index] - - -# Taken from the original Aura flow inference code. -# Our feedforward only has GELU but Aura uses SiLU. -class AuraFlowFeedForward(nn.Module): - def __init__(self, dim, hidden_dim=None) -> None: - super().__init__() - if hidden_dim is None: - hidden_dim = 4 * dim - - final_hidden_dim = int(2 * hidden_dim / 3) - final_hidden_dim = find_multiple(final_hidden_dim, 256) - - self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False) - self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False) - self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = F.silu(self.linear_1(x)) * self.linear_2(x) - x = self.out_projection(x) - return x - - -class AuraFlowPreFinalBlock(nn.Module): - def __init__(self, embedding_dim: int, conditioning_embedding_dim: int): - super().__init__() - - self.silu = nn.SiLU() - self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False) - - def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: - emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) - scale, shift = torch.chunk(emb, 2, dim=1) - x = x * (1 + scale)[:, None, :] + shift[:, None, :] - return x - - -@maybe_allow_in_graph -class AuraFlowSingleTransformerBlock(nn.Module): - """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT.""" - - def __init__(self, dim, num_attention_heads, attention_head_dim): - super().__init__() - - self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") - - processor = AuraFlowAttnProcessor2_0() - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - qk_norm="fp32_layer_norm", - out_dim=dim, - bias=False, - out_bias=False, - processor=processor, - ) - - self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) - self.ff = AuraFlowFeedForward(dim, dim * 4) - - def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor): - residual = hidden_states - - # Norm + Projection. - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - - # Attention. - attn_output = self.attn(hidden_states=norm_hidden_states) - - # Process attention outputs for the `hidden_states`. 
- hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) - hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ff_output = self.ff(hidden_states) - hidden_states = gate_mlp.unsqueeze(1) * ff_output - hidden_states = residual + hidden_states - - return hidden_states - - -@maybe_allow_in_graph -class AuraFlowJointTransformerBlock(nn.Module): - r""" - Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): - - * QK Norm in the attention blocks - * No bias in the attention blocks - * Most LayerNorms are in FP32 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - is_last (`bool`): Boolean to determine if this is the last block in the model. - """ - - def __init__(self, dim, num_attention_heads, attention_head_dim): - super().__init__() - - self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") - self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") - - processor = AuraFlowAttnProcessor2_0() - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - added_kv_proj_dim=dim, - added_proj_bias=False, - dim_head=attention_head_dim, - heads=num_attention_heads, - qk_norm="fp32_layer_norm", - out_dim=dim, - bias=False, - out_bias=False, - processor=processor, - context_pre_only=False, - ) - - self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) - self.ff = AuraFlowFeedForward(dim, dim * 4) - self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False) - self.ff_context = AuraFlowFeedForward(dim, dim * 4) - - def forward( - self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor - ): - residual = hidden_states - residual_context = encoder_hidden_states - - # Norm + Projection. - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb - ) - - # Attention. - attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states - ) - - # Process attention outputs for the `hidden_states`. - hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) - hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states) - hidden_states = residual + hidden_states - - # Process attention outputs for the `encoder_hidden_states`. - encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output) - encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states) - encoder_hidden_states = residual_context + encoder_hidden_states - - return encoder_hidden_states, hidden_states - - -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): - r""" - A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). - - Parameters: - sample_size (`int`): The width of the latent images. This is fixed during training since - it is used to learn a number of position embeddings. 
- patch_size (`int`): Patch size to turn the input data into small patches. - in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. - num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use. - num_single_dit_layers (`int`, *optional*, defaults to 4): - The number of layers of Transformer blocks to use. These blocks use concatenated image and text - representations. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. - joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. - out_channels (`int`, defaults to 16): Number of output channels. - pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents. - """ - - _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: int = 64, - patch_size: int = 2, - in_channels: int = 4, - num_mmdit_layers: int = 4, - num_single_dit_layers: int = 32, - attention_head_dim: int = 256, - num_attention_heads: int = 12, - joint_attention_dim: int = 2048, - caption_projection_dim: int = 3072, - out_channels: int = 4, - pos_embed_max_size: int = 1024, - ): - super().__init__() - default_out_channels = in_channels - self.out_channels = out_channels if out_channels is not None else default_out_channels - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - - self.pos_embed = AuraFlowPatchEmbed( - height=self.config.sample_size, - width=self.config.sample_size, - patch_size=self.config.patch_size, - in_channels=self.config.in_channels, - embed_dim=self.inner_dim, - pos_embed_max_size=pos_embed_max_size, - ) - - self.context_embedder = nn.Linear( - self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False - ) - self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True) - self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) - - self.joint_transformer_blocks = nn.ModuleList( - [ - AuraFlowJointTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for i in range(self.config.num_mmdit_layers) - ] - ) - self.single_transformer_blocks = nn.ModuleList( - [ - AuraFlowSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for _ in range(self.config.num_single_dit_layers) - ] - ) - - self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) - - # https://arxiv.org/abs/2309.16588 - # prevents artifacts in the attention maps - self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02) - - self.gradient_checkpointing = False - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention 
processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedAuraFlowAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. 
- - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - timestep: torch.LongTensor = None, - return_dict: bool = True, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - height, width = hidden_states.shape[-2:] - - # Apply patch embedding, timestep embedding, and project the caption embeddings. - hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. - temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) - temb = self.time_step_proj(temb) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - encoder_hidden_states = torch.cat( - [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 - ) - - # MMDiT blocks. - for index_block, block in enumerate(self.joint_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) - - # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) - if len(self.single_transformer_blocks) > 0: - encoder_seq_len = encoder_hidden_states.size(1) - combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - for index_block, block in enumerate(self.single_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - combined_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - combined_hidden_states, - temb, - **ckpt_kwargs, - ) - - else: - combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb) - - hidden_states = combined_hidden_states[:, encoder_seq_len:] - - hidden_states = self.norm_out(hidden_states, temb) - hidden_states = self.proj_out(hidden_states) - - # unpatchify - patch_size = self.config.patch_size - out_channels = self.config.out_channels - height = height // patch_size - width = width // patch_size - - hidden_states = hidden_states.reshape( - shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) - ) - - if not return_dict: - return (output,) - - return 
Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py deleted file mode 100644 index 01c54ef090bd..000000000000 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, Optional, Tuple, Union - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, FeedForward -from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@maybe_allow_in_graph -class CogVideoXBlock(nn.Module): - r""" - Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. - - Parameters: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - time_embed_dim (`int`): - The number of channels in timestep embedding. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to be used in feed-forward. - attention_bias (`bool`, defaults to `False`): - Whether or not to use bias in attention projection layers. - qk_norm (`bool`, defaults to `True`): - Whether or not to use normalization after query and key projections in Attention. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_eps (`float`, defaults to `1e-5`): - Epsilon value for normalization layers. - final_dropout (`bool` defaults to `False`): - Whether to apply a final dropout after the last feed-forward layer. - ff_inner_dim (`int`, *optional*, defaults to `None`): - Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. - ff_bias (`bool`, defaults to `True`): - Whether or not to use bias in Feed-forward layer. - attention_out_bias (`bool`, defaults to `True`): - Whether or not to use bias in Attention output projection layer. 
- """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - time_embed_dim: int, - dropout: float = 0.0, - activation_fn: str = "gelu-approximate", - attention_bias: bool = False, - qk_norm: bool = True, - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - final_dropout: bool = True, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - ): - super().__init__() - - # 1. Self Attention - self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - - self.attn1 = Attention( - query_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - qk_norm="layer_norm" if qk_norm else None, - eps=1e-6, - bias=attention_bias, - out_bias=attention_out_bias, - processor=CogVideoXAttnProcessor2_0(), - ) - - # 2. Feed Forward - self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( - hidden_states, encoder_hidden_states, temb - ) - - # attention - attn_hidden_states, attn_encoder_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - ) - - hidden_states = hidden_states + gate_msa * attn_hidden_states - encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states - - # norm & modulate - norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( - hidden_states, encoder_hidden_states, temb - ) - - # feed-forward - norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - ff_output = self.ff(norm_hidden_states) - - hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] - encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] - - return hidden_states, encoder_hidden_states - - -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): - """ - A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). - - Parameters: - num_attention_heads (`int`, defaults to `30`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`, defaults to `64`): - The number of channels in each head. - in_channels (`int`, defaults to `16`): - The number of channels in the input. - out_channels (`int`, *optional*, defaults to `16`): - The number of channels in the output. - flip_sin_to_cos (`bool`, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - time_embed_dim (`int`, defaults to `512`): - Output dimension of timestep embeddings. - text_embed_dim (`int`, defaults to `4096`): - Input dimension of text embeddings from the text encoder. - num_layers (`int`, defaults to `30`): - The number of layers of Transformer blocks to use. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. 
- attention_bias (`bool`, defaults to `True`): - Whether or not to use bias in the attention projection layers. - sample_width (`int`, defaults to `90`): - The width of the input latents. - sample_height (`int`, defaults to `60`): - The height of the input latents. - sample_frames (`int`, defaults to `49`): - The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 - instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, - but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with - K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). - patch_size (`int`, defaults to `2`): - The size of the patches to use in the patch embedding layer. - temporal_compression_ratio (`int`, defaults to `4`): - The compression ratio across the temporal dimension. See documentation for `sample_frames`. - max_text_seq_length (`int`, defaults to `226`): - The maximum sequence length of the input text embeddings. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to use in feed-forward. - timestep_activation_fn (`str`, defaults to `"silu"`): - Activation function to use when generating the timestep embeddings. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, defaults to `1e-5`): - The epsilon value to use in normalization layers. - spatial_interpolation_scale (`float`, defaults to `1.875`): - Scaling factor to apply in 3D positional embeddings across spatial dimensions. - temporal_interpolation_scale (`float`, defaults to `1.0`): - Scaling factor to apply in 3D positional embeddings across temporal dimensions. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - num_attention_heads: int = 30, - attention_head_dim: int = 64, - in_channels: int = 16, - out_channels: Optional[int] = 16, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - time_embed_dim: int = 512, - text_embed_dim: int = 4096, - num_layers: int = 30, - dropout: float = 0.0, - attention_bias: bool = True, - sample_width: int = 90, - sample_height: int = 60, - sample_frames: int = 49, - patch_size: int = 2, - temporal_compression_ratio: int = 4, - max_text_seq_length: int = 226, - activation_fn: str = "gelu-approximate", - timestep_activation_fn: str = "silu", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - spatial_interpolation_scale: float = 1.875, - temporal_interpolation_scale: float = 1.0, - use_rotary_positional_embeddings: bool = False, - use_learned_positional_embeddings: bool = False, - ): - super().__init__() - inner_dim = num_attention_heads * attention_head_dim - - if not use_rotary_positional_embeddings and use_learned_positional_embeddings: - raise ValueError( - "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional " - "embeddings. If you're using a custom model and/or believe this should be supported, please open an " - "issue at https://github.com/huggingface/diffusers/issues." - ) - - # 1. 
Patch embedding - self.patch_embed = CogVideoXPatchEmbed( - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - text_embed_dim=text_embed_dim, - bias=True, - sample_width=sample_width, - sample_height=sample_height, - sample_frames=sample_frames, - temporal_compression_ratio=temporal_compression_ratio, - max_text_seq_length=max_text_seq_length, - spatial_interpolation_scale=spatial_interpolation_scale, - temporal_interpolation_scale=temporal_interpolation_scale, - use_positional_embeddings=not use_rotary_positional_embeddings, - use_learned_positional_embeddings=use_learned_positional_embeddings, - ) - self.embedding_dropout = nn.Dropout(dropout) - - # 2. Time embeddings - self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) - self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - - # 3. Define spatio-temporal transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - CogVideoXBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - time_embed_dim=time_embed_dim, - dropout=dropout, - activation_fn=activation_fn, - attention_bias=attention_bias, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for _ in range(num_layers) - ] - ) - self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) - - # 4. Output blocks - self.norm_out = AdaLayerNorm( - embedding_dim=time_embed_dim, - output_dim=2 * inner_dim, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - chunk_dim=1, - ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) - - self.gradient_checkpointing = False - - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. 
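[Editor's note, not part of the patch: the `set_attn_processor` docstring above accepts either a single processor or a dict. A hedged usage sketch; `model` is a placeholder for any diffusers model exposing `attn_processors` / `set_attn_processor`, not something defined in this patch:]

from diffusers.models.attention_processor import AttnProcessor2_0

# Form 1: a single processor instance, applied to every Attention layer.
model.set_attn_processor(AttnProcessor2_0())

# Form 2: a dict keyed by the same paths that `attn_processors` returns,
# the recommended form when individual processors hold trainable weights.
processors = {name: AttnProcessor2_0() for name in model.attn_processors.keys()}
model.set_attn_processor(processors)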
- - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - timestep_cond: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ): - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - - batch_size, num_frames, channels, height, width = hidden_states.shape - - # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - - # 2. 
Patch embedding - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - hidden_states = self.embedding_dropout(hidden_states) - - text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - # 3. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - ) - - if not self.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = self.norm_final(hidden_states) - else: - # CogVideoX-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = self.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] - - # 4. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 5. Unpatchify - # Note: we use `-1` instead of `channels`: - # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) - # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py deleted file mode 100644 index f787c5279499..000000000000 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
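[Editor's note, not part of the patch: a shape-only sketch of the "5. Unpatchify" step in the CogVideoX forward above, with small made-up sizes; it only verifies that the reshape/permute/flatten sequence maps per-patch projections back to a (batch, frames, channels, height, width) latent:]

import torch

B, F, H, W, C, p = 1, 2, 8, 8, 4, 2                     # illustrative sizes only
x = torch.randn(B, F, (H // p) * (W // p), p * p * C)   # stand-in for the proj_out output

x = x.reshape(B, F, H // p, W // p, -1, p, p)           # (B, F, h, w, C, p, p)
x = x.permute(0, 1, 4, 2, 5, 3, 6)                      # (B, F, C, h, p, w, p)
x = x.flatten(5, 6).flatten(3, 4)                       # (B, F, C, H, W)
assert x.shape == (B, F, C, H, W)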
-from typing import Any, Dict, Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging -from ..attention import BasicTransformerBlock -from ..embeddings import PatchEmbed -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class DiTTransformer2DModel(ModelMixin, ConfigMixin): - r""" - A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748). - - Parameters: - num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (int, optional, defaults to 72): The number of channels in each head. - in_channels (int, defaults to 4): The number of channels in the input. - out_channels (int, optional): - The number of channels in the output. Specify this parameter if the output channel number differs from the - input. - num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. - dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. - norm_num_groups (int, optional, defaults to 32): - Number of groups for group normalization within Transformer blocks. - attention_bias (bool, optional, defaults to True): - Configure if the Transformer blocks' attention should contain a bias parameter. - sample_size (int, defaults to 32): - The width of the latent images. This parameter is fixed during training. - patch_size (int, defaults to 2): - Size of the patches the model processes, relevant for architectures working on non-sequential data. - activation_fn (str, optional, defaults to "gelu-approximate"): - Activation function to use in feed-forward networks within Transformer blocks. - num_embeds_ada_norm (int, optional, defaults to 1000): - Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during - inference. - upcast_attention (bool, optional, defaults to False): - If true, upcasts the attention mechanism dimensions for potentially improved performance. - norm_type (str, optional, defaults to "ada_norm_zero"): - Specifies the type of normalization used, can be 'ada_norm_zero'. - norm_elementwise_affine (bool, optional, defaults to False): - If true, enables element-wise affine parameters in the normalization layers. - norm_eps (float, optional, defaults to 1e-5): - A small constant added to the denominator in normalization layers to prevent division by zero. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 72, - in_channels: int = 4, - out_channels: Optional[int] = None, - num_layers: int = 28, - dropout: float = 0.0, - norm_num_groups: int = 32, - attention_bias: bool = True, - sample_size: int = 32, - patch_size: int = 2, - activation_fn: str = "gelu-approximate", - num_embeds_ada_norm: Optional[int] = 1000, - upcast_attention: bool = False, - norm_type: str = "ada_norm_zero", - norm_elementwise_affine: bool = False, - norm_eps: float = 1e-5, - ): - super().__init__() - - # Validate inputs. - if norm_type != "ada_norm_zero": - raise NotImplementedError( - f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." 
- ) - elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None: - raise ValueError( - f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." - ) - - # Set some common variables used across the board. - self.attention_head_dim = attention_head_dim - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.out_channels = in_channels if out_channels is None else out_channels - self.gradient_checkpointing = False - - # 2. Initialize the position embedding and transformer blocks. - self.height = self.config.sample_size - self.width = self.config.sample_size - - self.patch_size = self.config.patch_size - self.pos_embed = PatchEmbed( - height=self.config.sample_size, - width=self.config.sample_size, - patch_size=self.config.patch_size, - in_channels=self.config.in_channels, - embed_dim=self.inner_dim, - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - self.inner_dim, - self.config.num_attention_heads, - self.config.attention_head_dim, - dropout=self.config.dropout, - activation_fn=self.config.activation_fn, - num_embeds_ada_norm=self.config.num_embeds_ada_norm, - attention_bias=self.config.attention_bias, - upcast_attention=self.config.upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=self.config.norm_elementwise_affine, - norm_eps=self.config.norm_eps, - ) - for _ in range(self.config.num_layers) - ] - ) - - # 3. Output blocks. - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) - self.proj_out_2 = nn.Linear( - self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels - ) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - timestep: Optional[torch.LongTensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - return_dict: bool = True, - ): - """ - The [`DiTTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): - Input `hidden_states`. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - cross_attention_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - # 1. 
Input - height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size - hidden_states = self.pos_embed(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - None, - None, - timestep, - cross_attention_kwargs, - class_labels, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - - # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/dual_transformer_2d.py b/src/diffusers/models/transformers/dual_transformer_2d.py deleted file mode 100644 index 1c48c4e3db79..000000000000 --- a/src/diffusers/models/transformers/dual_transformer_2d.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -from torch import nn - -from ..modeling_outputs import Transformer2DModelOutput -from .transformer_2d import Transformer2DModel - - -class DualTransformer2DModel(nn.Module): - """ - Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 
- dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. - """ - - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - ): - super().__init__() - self.transformers = nn.ModuleList( - [ - Transformer2DModel( - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - in_channels=in_channels, - num_layers=num_layers, - dropout=dropout, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attention_bias=attention_bias, - sample_size=sample_size, - num_vector_embeds=num_vector_embeds, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - ) - for _ in range(2) - ] - ) - - # Variables that can be set by a pipeline: - - # The ratio of transformer1 to transformer2's output states to be combined during inference - self.mix_ratio = 0.5 - - # The shape of `encoder_hidden_states` is expected to be - # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` - self.condition_lengths = [77, 257] - - # Which transformer to use to encode which condition. - # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` - self.transformer_index_for_condition = [1, 0] - - def forward( - self, - hidden_states, - encoder_hidden_states, - timestep=None, - attention_mask=None, - cross_attention_kwargs=None, - return_dict: bool = True, - ): - """ - Args: - hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states. - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 
- attention_mask (`torch.Tensor`, *optional*): - Optional attention mask to be applied in Attention. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. - """ - input_states = hidden_states - - encoded_states = [] - tokens_start = 0 - # attention_mask is not used yet - for i in range(2): - # for each of the two transformers, pass the corresponding condition tokens - condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] - transformer_index = self.transformer_index_for_condition[i] - encoded_state = self.transformers[transformer_index]( - input_states, - encoder_hidden_states=condition_state, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=False, - )[0] - encoded_states.append(encoded_state - input_states) - tokens_start += self.condition_lengths[i] - - output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) - output_states = output_states + input_states - - if not return_dict: - return (output_states,) - - return Transformer2DModelOutput(sample=output_states) diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py deleted file mode 100644 index 7f3dab220aaa..000000000000 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ /dev/null @@ -1,578 +0,0 @@ -# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
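[Editor's note, not part of the patch: a minimal sketch of the blending step in `DualTransformer2DModel.forward` above. Each transformer's residual (its output minus the shared input) is weighted by `mix_ratio` before the input is added back; the tensors below are random stand-ins, not real transformer outputs:]

import torch

input_states = torch.randn(1, 4, 16, 16)        # shared latent input
residual_0 = torch.randn_like(input_states)     # transformers[1](tokens[:77]) - input
residual_1 = torch.randn_like(input_states)     # transformers[0](tokens[77:]) - input
mix_ratio = 0.5

output_states = residual_0 * mix_ratio + residual_1 * (1 - mix_ratio) + input_states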
-from typing import Dict, Optional, Union - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 -from ..embeddings import ( - HunyuanCombinedTimestepTextSizeStyleEmbedding, - PatchEmbed, - PixArtAlphaTextProjection, -) -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, FP32LayerNorm - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class AdaLayerNormShift(nn.Module): - r""" - Norm layer modified to incorporate timestep embeddings. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the embeddings dictionary. - """ - - def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6): - super().__init__() - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_dim) - self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps) - - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype)) - x = self.norm(x) + shift.unsqueeze(dim=1) - return x - - -@maybe_allow_in_graph -class HunyuanDiTBlock(nn.Module): - r""" - Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and - QKNorm - - Parameters: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of headsto use for multi-head attention. - cross_attention_dim (`int`,*optional*): - The size of the encoder_hidden_states vector for cross attention. - dropout(`float`, *optional*, defaults to 0.0): - The dropout probability to use. - activation_fn (`str`,*optional*, defaults to `"geglu"`): - Activation function to be used in feed-forward. . - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_eps (`float`, *optional*, defaults to 1e-6): - A small constant added to the denominator in normalization layers to prevent division by zero. - final_dropout (`bool` *optional*, defaults to False): - Whether to apply a final dropout after the last feed-forward layer. - ff_inner_dim (`int`, *optional*): - The size of the hidden layer in the feed-forward block. Defaults to `None`. - ff_bias (`bool`, *optional*, defaults to `True`): - Whether to use bias in the feed-forward block. - skip (`bool`, *optional*, defaults to `False`): - Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks. - qk_norm (`bool`, *optional*, defaults to `True`): - Whether to use normalization in QK calculation. Defaults to `True`. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - cross_attention_dim: int = 1024, - dropout=0.0, - activation_fn: str = "geglu", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-6, - final_dropout: bool = False, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - skip: bool = False, - qk_norm: bool = True, - ): - super().__init__() - - # Define 3 blocks. Each block has its own normalization layer. - # NOTE: when new version comes, check norm2 and norm 3 - # 1. 
Self-Attn - self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn1 = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=dim // num_attention_heads, - heads=num_attention_heads, - qk_norm="layer_norm" if qk_norm else None, - eps=1e-6, - bias=True, - processor=HunyuanAttnProcessor2_0(), - ) - - # 2. Cross-Attn - self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - dim_head=dim // num_attention_heads, - heads=num_attention_heads, - qk_norm="layer_norm" if qk_norm else None, - eps=1e-6, - bias=True, - processor=HunyuanAttnProcessor2_0(), - ) - # 3. Feed-forward - self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) - - self.ff = FeedForward( - dim, - dropout=dropout, ### 0.0 - activation_fn=activation_fn, ### approx GeLU - final_dropout=final_dropout, ### 0.0 - inner_dim=ff_inner_dim, ### int(dim * mlp_ratio) - bias=ff_bias, - ) - - # 4. Skip Connection - if skip: - self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True) - self.skip_linear = nn.Linear(2 * dim, dim) - else: - self.skip_linear = None - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb=None, - skip=None, - ) -> torch.Tensor: - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Long Skip Connection - if self.skip_linear is not None: - cat = torch.cat([hidden_states, skip], dim=-1) - cat = self.skip_norm(cat) - hidden_states = self.skip_linear(cat) - - # 1. Self-Attention - norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct - attn_output = self.attn1( - norm_hidden_states, - image_rotary_emb=image_rotary_emb, - ) - hidden_states = hidden_states + attn_output - - # 2. Cross-Attention - hidden_states = hidden_states + self.attn2( - self.norm2(hidden_states), - encoder_hidden_states=encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - ) - - # FFN Layer ### TODO: switch norm2 and norm3 in the state dict - mlp_inputs = self.norm3(hidden_states) - hidden_states = hidden_states + self.ff(mlp_inputs) - - return hidden_states - - -class HunyuanDiT2DModel(ModelMixin, ConfigMixin): - """ - HunYuanDiT: Diffusion model with a Transformer backbone. - - Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): - The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): - The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - patch_size (`int`, *optional*): - The size of the patch to use for the input. - activation_fn (`str`, *optional*, defaults to `"geglu"`): - Activation function to use in feed-forward. - sample_size (`int`, *optional*): - The width of the latent images. 
This is fixed during training since it is used to learn a number of - position embeddings. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability to use. - cross_attention_dim (`int`, *optional*): - The number of dimension in the clip text embedding. - hidden_size (`int`, *optional*): - The size of hidden layer in the conditioning embedding layers. - num_layers (`int`, *optional*, defaults to 1): - The number of layers of Transformer blocks to use. - mlp_ratio (`float`, *optional*, defaults to 4.0): - The ratio of the hidden layer size to the input size. - learn_sigma (`bool`, *optional*, defaults to `True`): - Whether to predict variance. - cross_attention_dim_t5 (`int`, *optional*): - The number dimensions in t5 text embedding. - pooled_projection_dim (`int`, *optional*): - The size of the pooled projection. - text_len (`int`, *optional*): - The length of the clip text embedding. - text_len_t5 (`int`, *optional*): - The length of the T5 text embedding. - use_style_cond_and_image_meta_size (`bool`, *optional*): - Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "gelu-approximate", - sample_size=32, - hidden_size=1152, - num_layers: int = 28, - mlp_ratio: float = 4.0, - learn_sigma: bool = True, - cross_attention_dim: int = 1024, - norm_type: str = "layer_norm", - cross_attention_dim_t5: int = 2048, - pooled_projection_dim: int = 1024, - text_len: int = 77, - text_len_t5: int = 256, - use_style_cond_and_image_meta_size: bool = True, - ): - super().__init__() - self.out_channels = in_channels * 2 if learn_sigma else in_channels - self.num_heads = num_attention_heads - self.inner_dim = num_attention_heads * attention_head_dim - - self.text_embedder = PixArtAlphaTextProjection( - in_features=cross_attention_dim_t5, - hidden_size=cross_attention_dim_t5 * 4, - out_features=cross_attention_dim, - act_fn="silu_fp32", - ) - - self.text_embedding_padding = nn.Parameter( - torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32) - ) - - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - in_channels=in_channels, - embed_dim=hidden_size, - patch_size=patch_size, - pos_embed_type=None, - ) - - self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding( - hidden_size, - pooled_projection_dim=pooled_projection_dim, - seq_len=text_len_t5, - cross_attention_dim=cross_attention_dim_t5, - use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size, - ) - - # HunyuanDiT Blocks - self.blocks = nn.ModuleList( - [ - HunyuanDiTBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - activation_fn=activation_fn, - ff_inner_dim=int(self.inner_dim * mlp_ratio), - cross_attention_dim=cross_attention_dim, - qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. 
- skip=layer > num_layers // 2, - ) - for layer in range(num_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedHunyuanAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - self.set_attn_processor(HunyuanAttnProcessor2_0()) - - def forward( - self, - hidden_states, - timestep, - encoder_hidden_states=None, - text_embedding_mask=None, - encoder_hidden_states_t5=None, - text_embedding_mask_t5=None, - image_meta_size=None, - style=None, - image_rotary_emb=None, - controlnet_block_samples=None, - return_dict=True, - ): - """ - The [`HunyuanDiT2DModel`] forward method. - - Args: - hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): - The input tensor. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. - encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. This is the output of `BertModel`. - text_embedding_mask: torch.Tensor - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output - of `BertModel`. - encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. - text_embedding_mask_t5: torch.Tensor - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output - of T5 Text Encoder. - image_meta_size (torch.Tensor): - Conditional embedding indicate the image sizes - style: torch.Tensor: - Conditional embedding indicate the style - image_rotary_emb (`torch.Tensor`): - The image rotary embeddings to apply on query and key tensors during attention calculation. - return_dict: bool - Whether to return a dictionary. 
- """ - - height, width = hidden_states.shape[-2:] - - hidden_states = self.pos_embed(hidden_states) - - temb = self.time_extra_emb( - timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype - ) # [B, D] - - # text projection - batch_size, sequence_length, _ = encoder_hidden_states_t5.shape - encoder_hidden_states_t5 = self.text_embedder( - encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) - ) - encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1) - - encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1) - text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1) - text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() - - encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding) - - skips = [] - for layer, block in enumerate(self.blocks): - if layer > self.config.num_layers // 2: - if controlnet_block_samples is not None: - skip = skips.pop() + controlnet_block_samples.pop() - else: - skip = skips.pop() - hidden_states = block( - hidden_states, - temb=temb, - encoder_hidden_states=encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - skip=skip, - ) # (N, L, D) - else: - hidden_states = block( - hidden_states, - temb=temb, - encoder_hidden_states=encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - ) # (N, L, D) - - if layer < (self.config.num_layers // 2 - 1): - skips.append(hidden_states) - - if controlnet_block_samples is not None and len(controlnet_block_samples) != 0: - raise ValueError("The number of controls is not equal to the number of skip connections.") - - # final layer - hidden_states = self.norm_out(hidden_states, temb.to(torch.float32)) - hidden_states = self.proj_out(hidden_states) - # (N, L, patch_size ** 2 * out_channels) - - # unpatchify: (N, out_channels, H, W) - patch_size = self.pos_embed.patch_size - height = height // patch_size - width = width // patch_size - - hidden_states = hidden_states.reshape( - shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) - ) - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - """ - Sets the attention processor to use [feed forward - chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). - - Parameters: - chunk_size (`int`, *optional*): - The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually - over each tensor of dim=`dim`. - dim (`int`, *optional*, defaults to `0`): - The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) - or dim=1 (sequence length). 
- """ - if dim not in [0, 1]: - raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") - - # By default chunk size is 1 - chunk_size = chunk_size or 1 - - def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): - if hasattr(module, "set_chunk_feed_forward"): - module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - - for child in module.children(): - fn_recursive_feed_forward(child, chunk_size, dim) - - for module in self.children(): - fn_recursive_feed_forward(module, chunk_size, dim) - - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking - def disable_forward_chunking(self): - def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): - if hasattr(module, "set_chunk_feed_forward"): - module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - - for child in module.children(): - fn_recursive_feed_forward(child, chunk_size, dim) - - for module in self.children(): - fn_recursive_feed_forward(module, None, 0) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py deleted file mode 100644 index 7e2b1273687d..000000000000 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid -from ..attention import BasicTransformerBlock -from ..embeddings import PatchEmbed -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle - - -class LatteTransformer3DModel(ModelMixin, ConfigMixin): - _supports_gradient_checkpointing = True - - """ - A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code: - https://github.com/Vchitect/Latte - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input. - out_channels (`int`, *optional*): - The number of channels in the output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. 
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - patch_size (`int`, *optional*): - The size of the patches to use in the patch embedding layer. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. During inference, you can denoise for up to but not more steps than - `num_embeds_ada_norm`. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. - caption_channels (`int`, *optional*): - The number of channels in the caption embeddings. - video_length (`int`, *optional*): - The number of frames in the video-like data. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: int = 64, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - caption_channels: int = None, - video_length: int = 16, - ): - super().__init__() - inner_dim = num_attention_heads * attention_head_dim - - # 1. Define input layers - self.height = sample_size - self.width = sample_size - - interpolation_scale = self.config.sample_size // 64 - interpolation_scale = max(interpolation_scale, 1) - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - interpolation_scale=interpolation_scale, - ) - - # 2. Define spatial transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for d in range(num_layers) - ] - ) - - # 3. Define temporal transformers blocks - self.temporal_transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=None, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for d in range(num_layers) - ] - ) - - # 4. 
Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - - # 5. Latte other blocks. - self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - - # define temporal positional embedding - temp_pos_embed = get_1d_sincos_pos_embed_from_grid( - inner_dim, torch.arange(0, video_length).unsqueeze(1) - ) # 1152 hidden size - self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) - - self.gradient_checkpointing = False - - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - timestep: Optional[torch.LongTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - enable_temporal_attentions: bool = True, - return_dict: bool = True, - ): - """ - The [`LatteTransformer3DModel`] forward method. - - Args: - hidden_states shape `(batch size, channel, num_frame, height, width)`: - Input `hidden_states`. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batcheight, sequence_length)` True = keep, False = discard. - * Bias `(batcheight, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - enable_temporal_attentions: - (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. 
- """ - - # Reshape hidden states - batch_size, channels, num_frame, height, width = hidden_states.shape - # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) - - # Input - height, width = ( - hidden_states.shape[-2] // self.config.patch_size, - hidden_states.shape[-1] // self.config.patch_size, - ) - num_patches = height * width - - hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings - - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype - ) - - # Prepare text embeddings for spatial block - # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size - encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 - encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view( - -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] - ) - - # Prepare timesteps for spatial and temporal block - timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1]) - timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1]) - - # Spatial and temporal transformer blocks - for i, (spatial_block, temp_block) in enumerate( - zip(self.transformer_blocks, self.temporal_transformer_blocks) - ): - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( - spatial_block, - hidden_states, - None, # attention_mask - encoder_hidden_states_spatial, - encoder_attention_mask, - timestep_spatial, - None, # cross_attention_kwargs - None, # class_labels - use_reentrant=False, - ) - else: - hidden_states = spatial_block( - hidden_states, - None, # attention_mask - encoder_hidden_states_spatial, - encoder_attention_mask, - timestep_spatial, - None, # cross_attention_kwargs - None, # class_labels - ) - - if enable_temporal_attentions: - # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size - hidden_states = hidden_states.reshape( - batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] - ).permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) - - if i == 0 and num_frame > 1: - hidden_states = hidden_states + self.temp_pos_embed - - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( - temp_block, - hidden_states, - None, # attention_mask - None, # encoder_hidden_states - None, # encoder_attention_mask - timestep_temp, - None, # cross_attention_kwargs - None, # class_labels - use_reentrant=False, - ) - else: - hidden_states = temp_block( - hidden_states, - None, # attention_mask - None, # encoder_hidden_states - None, # encoder_attention_mask - timestep_temp, - None, # cross_attention_kwargs - None, # class_labels - ) - - # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size - hidden_states = hidden_states.reshape( - batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] - ).permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) - - embedded_timestep = 
embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1]) - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - - # unpatchify - if self.adaln_single is None: - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) - ) - output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute( - 0, 2, 1, 3, 4 - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py deleted file mode 100644 index d4f5b4658542..000000000000 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, Optional - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging -from ..attention import LuminaFeedForward -from ..attention_processor import Attention, LuminaAttnProcessor2_0 -from ..embeddings import ( - LuminaCombinedTimestepCaptionEmbedding, - LuminaPatchEmbed, -) -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class LuminaNextDiTBlock(nn.Module): - """ - A LuminaNextDiTBlock for LuminaNextDiT2DModel. - - Parameters: - dim (`int`): Embedding dimension of the input features. - num_attention_heads (`int`): Number of attention heads. - num_kv_heads (`int`): - Number of attention heads in key and value features (if using GQA), or set to None for the same as query. - multiple_of (`int`): The number of multiple of ffn layer. - ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension. - norm_eps (`float`): The eps for norm layer. - qk_norm (`bool`): normalization for query and key. - cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states. 
- norm_elementwise_affine (`bool`, *optional*, defaults to True), - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - num_kv_heads: int, - multiple_of: int, - ffn_dim_multiplier: float, - norm_eps: float, - qk_norm: bool, - cross_attention_dim: int, - norm_elementwise_affine: bool = True, - ) -> None: - super().__init__() - self.head_dim = dim // num_attention_heads - - self.gate = nn.Parameter(torch.zeros([num_attention_heads])) - - # Self-attention - self.attn1 = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=dim // num_attention_heads, - qk_norm="layer_norm_across_heads" if qk_norm else None, - heads=num_attention_heads, - kv_heads=num_kv_heads, - eps=1e-5, - bias=False, - out_bias=False, - processor=LuminaAttnProcessor2_0(), - ) - self.attn1.to_out = nn.Identity() - - # Cross-attention - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - dim_head=dim // num_attention_heads, - qk_norm="layer_norm_across_heads" if qk_norm else None, - heads=num_attention_heads, - kv_heads=num_kv_heads, - eps=1e-5, - bias=False, - out_bias=False, - processor=LuminaAttnProcessor2_0(), - ) - - self.feed_forward = LuminaFeedForward( - dim=dim, - inner_dim=4 * dim, - multiple_of=multiple_of, - ffn_dim_multiplier=ffn_dim_multiplier, - ) - - self.norm1 = LuminaRMSNormZero( - embedding_dim=dim, - norm_eps=norm_eps, - norm_elementwise_affine=norm_elementwise_affine, - ) - self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - - self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - - self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - image_rotary_emb: torch.Tensor, - encoder_hidden_states: torch.Tensor, - encoder_mask: torch.Tensor, - temb: torch.Tensor, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - """ - Perform a forward pass through the LuminaNextDiTBlock. - - Parameters: - hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock. - attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. - image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. - encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder. - encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask. - temb (`torch.Tensor`): Timestep embedding with text prompt embedding. - cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention. 
- """ - residual = hidden_states - - # Self-attention - norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - self_attn_output = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_hidden_states, - attention_mask=attention_mask, - query_rotary_emb=image_rotary_emb, - key_rotary_emb=image_rotary_emb, - **cross_attention_kwargs, - ) - - # Cross-attention - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) - cross_attn_output = self.attn2( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - attention_mask=encoder_mask, - query_rotary_emb=image_rotary_emb, - key_rotary_emb=None, - **cross_attention_kwargs, - ) - cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1) - mixed_attn_output = self_attn_output + cross_attn_output - mixed_attn_output = mixed_attn_output.flatten(-2) - # linear proj - hidden_states = self.attn2.to_out[0](mixed_attn_output) - - hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states) - - mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) - - hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) - - return hidden_states - - -class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): - """ - LuminaNextDiT: Diffusion model with a Transformer backbone. - - Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. - - Parameters: - sample_size (`int`): The width of the latent images. This is fixed during training since - it is used to learn a number of position embeddings. - patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2): - The size of each patch in the image. This parameter defines the resolution of patches fed into the model. - in_channels (`int`, *optional*, defaults to 4): - The number of input channels for the model. Typically, this matches the number of channels in the input - images. - hidden_size (`int`, *optional*, defaults to 4096): - The dimensionality of the hidden layers in the model. This parameter determines the width of the model's - hidden representations. - num_layers (`int`, *optional*, default to 32): - The number of layers in the model. This defines the depth of the neural network. - num_attention_heads (`int`, *optional*, defaults to 32): - The number of attention heads in each attention layer. This parameter specifies how many separate attention - mechanisms are used. - num_kv_heads (`int`, *optional*, defaults to 8): - The number of key-value heads in the attention mechanism, if different from the number of attention heads. - If None, it defaults to num_attention_heads. - multiple_of (`int`, *optional*, defaults to 256): - A factor that the hidden size should be a multiple of. This can help optimize certain hardware - configurations. - ffn_dim_multiplier (`float`, *optional*): - A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on - the model configuration. - norm_eps (`float`, *optional*, defaults to 1e-5): - A small value added to the denominator for numerical stability in normalization layers. - learn_sigma (`bool`, *optional*, defaults to True): - Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in - predictions. - qk_norm (`bool`, *optional*, defaults to True): - Indicates if the queries and keys in the attention mechanism should be normalized. 
- cross_attention_dim (`int`, *optional*, defaults to 2048): - The dimensionality of the text embeddings. This parameter defines the size of the text representations used - in the model. - scaling_factor (`float`, *optional*, defaults to 1.0): - A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the - overall scale of the model's operations. - """ - - @register_to_config - def __init__( - self, - sample_size: int = 128, - patch_size: Optional[int] = 2, - in_channels: Optional[int] = 4, - hidden_size: Optional[int] = 2304, - num_layers: Optional[int] = 32, - num_attention_heads: Optional[int] = 32, - num_kv_heads: Optional[int] = None, - multiple_of: Optional[int] = 256, - ffn_dim_multiplier: Optional[float] = None, - norm_eps: Optional[float] = 1e-5, - learn_sigma: Optional[bool] = True, - qk_norm: Optional[bool] = True, - cross_attention_dim: Optional[int] = 2048, - scaling_factor: Optional[float] = 1.0, - ) -> None: - super().__init__() - self.sample_size = sample_size - self.patch_size = patch_size - self.in_channels = in_channels - self.out_channels = in_channels * 2 if learn_sigma else in_channels - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.head_dim = hidden_size // num_attention_heads - self.scaling_factor = scaling_factor - - self.patch_embedder = LuminaPatchEmbed( - patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True - ) - - self.pad_token = nn.Parameter(torch.empty(hidden_size)) - - self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding( - hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim - ) - - self.layers = nn.ModuleList( - [ - LuminaNextDiTBlock( - hidden_size, - num_attention_heads, - num_kv_heads, - multiple_of, - ffn_dim_multiplier, - norm_eps, - qk_norm, - cross_attention_dim, - ) - for _ in range(num_layers) - ] - ) - self.norm_out = LuminaLayerNormContinuous( - embedding_dim=hidden_size, - conditioning_embedding_dim=min(hidden_size, 1024), - elementwise_affine=False, - eps=1e-6, - bias=True, - out_dim=patch_size * patch_size * self.out_channels, - ) - # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels) - - assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4" - - def forward( - self, - hidden_states: torch.Tensor, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - encoder_mask: torch.Tensor, - image_rotary_emb: torch.Tensor, - cross_attention_kwargs: Dict[str, Any] = None, - return_dict=True, - ) -> torch.Tensor: - """ - Forward pass of LuminaNextDiT. - - Parameters: - hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W). - timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,). - encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D). - encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). 
- """ - hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb) - image_rotary_emb = image_rotary_emb.to(hidden_states.device) - - temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask) - - encoder_mask = encoder_mask.bool() - for layer in self.layers: - hidden_states = layer( - hidden_states, - mask, - image_rotary_emb, - encoder_hidden_states, - encoder_mask, - temb=temb, - cross_attention_kwargs=cross_attention_kwargs, - ) - - hidden_states = self.norm_out(hidden_states, temb) - - # unpatchify - height_tokens = width_tokens = self.patch_size - height, width = img_size[0] - batch_size = hidden_states.size(0) - sequence_length = (height // height_tokens) * (width // width_tokens) - hidden_states = hidden_states[:, :sequence_length].view( - batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels - ) - output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py deleted file mode 100644 index 7f145edf16fb..000000000000 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Dict, Optional, Union - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging -from ..attention import BasicTransformerBlock -from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class PixArtTransformer2DModel(ModelMixin, ConfigMixin): - r""" - A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426, - https://arxiv.org/abs/2403.04692). - - Parameters: - num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (int, optional, defaults to 72): The number of channels in each head. - in_channels (int, defaults to 4): The number of channels in the input. - out_channels (int, optional): - The number of channels in the output. Specify this parameter if the output channel number differs from the - input. - num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. - dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. 
- norm_num_groups (int, optional, defaults to 32): - Number of groups for group normalization within Transformer blocks. - cross_attention_dim (int, optional): - The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension. - attention_bias (bool, optional, defaults to True): - Configure if the Transformer blocks' attention should contain a bias parameter. - sample_size (int, defaults to 128): - The width of the latent images. This parameter is fixed during training. - patch_size (int, defaults to 2): - Size of the patches the model processes, relevant for architectures working on non-sequential data. - activation_fn (str, optional, defaults to "gelu-approximate"): - Activation function to use in feed-forward networks within Transformer blocks. - num_embeds_ada_norm (int, optional, defaults to 1000): - Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during - inference. - upcast_attention (bool, optional, defaults to False): - If true, upcasts the attention mechanism dimensions for potentially improved performance. - norm_type (str, optional, defaults to "ada_norm_zero"): - Specifies the type of normalization used, can be 'ada_norm_zero'. - norm_elementwise_affine (bool, optional, defaults to False): - If true, enables element-wise affine parameters in the normalization layers. - norm_eps (float, optional, defaults to 1e-6): - A small constant added to the denominator in normalization layers to prevent division by zero. - interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings. - use_additional_conditions (bool, optional): If we're using additional conditions as inputs. - attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used. - caption_channels (int, optional, defaults to None): - Number of channels to use for projecting the caption embeddings. - use_linear_projection (bool, optional, defaults to False): - Deprecated argument. Will be removed in a future version. - num_vector_embeds (bool, optional, defaults to False): - Deprecated argument. Will be removed in a future version. - """ - - _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 72, - in_channels: int = 4, - out_channels: Optional[int] = 8, - num_layers: int = 28, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = 1152, - attention_bias: bool = True, - sample_size: int = 128, - patch_size: int = 2, - activation_fn: str = "gelu-approximate", - num_embeds_ada_norm: Optional[int] = 1000, - upcast_attention: bool = False, - norm_type: str = "ada_norm_single", - norm_elementwise_affine: bool = False, - norm_eps: float = 1e-6, - interpolation_scale: Optional[int] = None, - use_additional_conditions: Optional[bool] = None, - caption_channels: Optional[int] = None, - attention_type: Optional[str] = "default", - ): - super().__init__() - - # Validate inputs. - if norm_type != "ada_norm_single": - raise NotImplementedError( - f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." - ) - elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None: - raise ValueError( - f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." 
- ) - - # Set some common variables used across the board. - self.attention_head_dim = attention_head_dim - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.out_channels = in_channels if out_channels is None else out_channels - if use_additional_conditions is None: - if sample_size == 128: - use_additional_conditions = True - else: - use_additional_conditions = False - self.use_additional_conditions = use_additional_conditions - - self.gradient_checkpointing = False - - # 2. Initialize the position embedding and transformer blocks. - self.height = self.config.sample_size - self.width = self.config.sample_size - - interpolation_scale = ( - self.config.interpolation_scale - if self.config.interpolation_scale is not None - else max(self.config.sample_size // 64, 1) - ) - self.pos_embed = PatchEmbed( - height=self.config.sample_size, - width=self.config.sample_size, - patch_size=self.config.patch_size, - in_channels=self.config.in_channels, - embed_dim=self.inner_dim, - interpolation_scale=interpolation_scale, - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - self.inner_dim, - self.config.num_attention_heads, - self.config.attention_head_dim, - dropout=self.config.dropout, - cross_attention_dim=self.config.cross_attention_dim, - activation_fn=self.config.activation_fn, - num_embeds_ada_norm=self.config.num_embeds_ada_norm, - attention_bias=self.config.attention_bias, - upcast_attention=self.config.upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=self.config.norm_elementwise_affine, - norm_eps=self.config.norm_eps, - attention_type=self.config.attention_type, - ) - for _ in range(self.config.num_layers) - ] - ) - - # 3. Output blocks. - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) - self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) - - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=self.use_additional_conditions - ) - self.caption_projection = None - if self.config.caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.config.caption_channels, hidden_size=self.inner_dim - ) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - - Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model. - """ - self.set_attn_processor(AttnProcessor()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. 
- - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - The [`PixArtTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep (`torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs. - cross_attention_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - attention_mask ( `torch.Tensor`, *optional*): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if self.use_additional_conditions and added_cond_kwargs is None: - raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") - - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) - if attention_mask is not None and attention_mask.ndim == 2: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 1. Input - batch_size = hidden_states.shape[0] - height, width = ( - hidden_states.shape[-2] // self.config.patch_size, - hidden_states.shape[-1] // self.config.patch_size, - ) - hidden_states = self.pos_embed(hidden_states) - - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) - - # 2. Blocks - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - timestep, - cross_attention_kwargs, - None, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=None, - ) - - # 3. 
Output - shift, scale = ( - self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) - ).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.squeeze(1) - - # unpatchify - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py deleted file mode 100644 index fdb67384ff5e..000000000000 --- a/src/diffusers/models/transformers/prior_transformer.py +++ /dev/null @@ -1,380 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, Optional, Union - -import torch -import torch.nn.functional as F -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from ...utils import BaseOutput -from ..attention import BasicTransformerBlock -from ..attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, -) -from ..embeddings import TimestepEmbedding, Timesteps -from ..modeling_utils import ModelMixin - - -@dataclass -class PriorTransformerOutput(BaseOutput): - """ - The output of [`PriorTransformer`]. - - Args: - predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`): - The predicted CLIP image embedding conditioned on the CLIP text embedding input. - """ - - predicted_image_embedding: torch.Tensor - - -class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): - """ - A Prior Transformer model. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. - embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` - num_embeddings (`int`, *optional*, defaults to 77): - The number of embeddings of the model input `hidden_states` - additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the - projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + - additional_embeddings`. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - time_embed_act_fn (`str`, *optional*, defaults to 'silu'): - The activation function to use to create timestep embeddings. - norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before - passing to Transformer blocks. Set it to `None` if normalization is not needed. - embedding_proj_norm_type (`str`, *optional*, defaults to None): - The normalization layer to apply on the input `proj_embedding`. 
Set it to `None` if normalization is not - needed. - encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): - The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if - `encoder_hidden_states` is `None`. - added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. - Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot - product between the text embedding and image embedding as proposed in the unclip paper - https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended. - time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings. - If None, will be set to `num_attention_heads * attention_head_dim` - embedding_proj_dim (`int`, *optional*, default to None): - The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. - clip_embed_dim (`int`, *optional*, default to None): - The dimension of the output. If None, will be set to `embedding_dim`. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 32, - attention_head_dim: int = 64, - num_layers: int = 20, - embedding_dim: int = 768, - num_embeddings=77, - additional_embeddings=4, - dropout: float = 0.0, - time_embed_act_fn: str = "silu", - norm_in_type: Optional[str] = None, # layer - embedding_proj_norm_type: Optional[str] = None, # layer - encoder_hid_proj_type: Optional[str] = "linear", # linear - added_emb_type: Optional[str] = "prd", # prd - time_embed_dim: Optional[int] = None, - embedding_proj_dim: Optional[int] = None, - clip_embed_dim: Optional[int] = None, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.additional_embeddings = additional_embeddings - - time_embed_dim = time_embed_dim or inner_dim - embedding_proj_dim = embedding_proj_dim or embedding_dim - clip_embed_dim = clip_embed_dim or embedding_dim - - self.time_proj = Timesteps(inner_dim, True, 0) - self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) - - self.proj_in = nn.Linear(embedding_dim, inner_dim) - - if embedding_proj_norm_type is None: - self.embedding_proj_norm = None - elif embedding_proj_norm_type == "layer": - self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) - else: - raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") - - self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) - - if encoder_hid_proj_type is None: - self.encoder_hidden_states_proj = None - elif encoder_hid_proj_type == "linear": - self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) - else: - raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") - - self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) - - if added_emb_type == "prd": - self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) - elif added_emb_type is None: - self.prd_embedding = None - else: - raise ValueError( - f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." 
- ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - activation_fn="gelu", - attention_bias=True, - ) - for d in range(num_layers) - ] - ) - - if norm_in_type == "layer": - self.norm_in = nn.LayerNorm(inner_dim) - elif norm_in_type is None: - self.norm_in = None - else: - raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") - - self.norm_out = nn.LayerNorm(inner_dim) - - self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) - - causal_attention_mask = torch.full( - [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 - ) - causal_attention_mask.triu_(1) - causal_attention_mask = causal_attention_mask[None, ...] - self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) - - self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) - self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - def forward( - self, - hidden_states, - timestep: Union[torch.Tensor, float, int], - proj_embedding: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - return_dict: bool = True, - ): - """ - The [`PriorTransformer`] forward method. - - Args: - hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`): - The currently predicted image embeddings. - timestep (`torch.LongTensor`): - Current denoising step. - proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`): - Projected embedding vector the denoising process is conditioned on. - encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`): - Hidden states of the text embeddings the denoising process is conditioned on. - attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): - Text mask for the text embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformers.prior_transformer.PriorTransformerOutput`] instead of - a plain tuple. - - Returns: - [`~models.transformers.prior_transformer.PriorTransformerOutput`] or `tuple`: - If return_dict is True, a [`~models.transformers.prior_transformer.PriorTransformerOutput`] is - returned, otherwise a tuple is returned where the first element is the sample tensor. - """ - batch_size = hidden_states.shape[0] - - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(hidden_states.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) - - timesteps_projected = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might be fp16, so we need to cast here. 
- timesteps_projected = timesteps_projected.to(dtype=self.dtype) - time_embeddings = self.time_embedding(timesteps_projected) - - if self.embedding_proj_norm is not None: - proj_embedding = self.embedding_proj_norm(proj_embedding) - - proj_embeddings = self.embedding_proj(proj_embedding) - if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: - encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) - elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: - raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") - - hidden_states = self.proj_in(hidden_states) - - positional_embeddings = self.positional_embedding.to(hidden_states.dtype) - - additional_embeds = [] - additional_embeddings_len = 0 - - if encoder_hidden_states is not None: - additional_embeds.append(encoder_hidden_states) - additional_embeddings_len += encoder_hidden_states.shape[1] - - if len(proj_embeddings.shape) == 2: - proj_embeddings = proj_embeddings[:, None, :] - - if len(hidden_states.shape) == 2: - hidden_states = hidden_states[:, None, :] - - additional_embeds = additional_embeds + [ - proj_embeddings, - time_embeddings[:, None, :], - hidden_states, - ] - - if self.prd_embedding is not None: - prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) - additional_embeds.append(prd_embedding) - - hidden_states = torch.cat( - additional_embeds, - dim=1, - ) - - # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens - additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 - if positional_embeddings.shape[1] < hidden_states.shape[1]: - positional_embeddings = F.pad( - positional_embeddings, - ( - 0, - 0, - additional_embeddings_len, - self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, - ), - value=0.0, - ) - - hidden_states = hidden_states + positional_embeddings - - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) - attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) - attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) - - if self.norm_in is not None: - hidden_states = self.norm_in(hidden_states) - - for block in self.transformer_blocks: - hidden_states = block(hidden_states, attention_mask=attention_mask) - - hidden_states = self.norm_out(hidden_states) - - if self.prd_embedding is not None: - hidden_states = hidden_states[:, -1] - else: - hidden_states = hidden_states[:, additional_embeddings_len:] - - predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) - - if not return_dict: - return (predicted_image_embedding,) - - return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) - - def post_process_latents(self, prior_latents): - prior_latents = (prior_latents * self.clip_std) + self.clip_mean - return prior_latents diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py deleted file mode 100644 index d687dbabf317..000000000000 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ /dev/null @@ -1,458 +0,0 @@ -# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Any, Dict, Optional, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.attention import FeedForward -from ...models.attention_processor import ( - Attention, - AttentionProcessor, - StableAudioAttnProcessor2_0, -) -from ...models.modeling_utils import ModelMixin -from ...models.transformers.transformer_2d import Transformer2DModelOutput -from ...utils import is_torch_version, logging -from ...utils.torch_utils import maybe_allow_in_graph - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class StableAudioGaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels.""" - - # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__ - def __init__( - self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False - ): - super().__init__() - self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.log = log - self.flip_sin_to_cos = flip_sin_to_cos - - if set_W_to_weight: - # to delete later - del self.weight - self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.weight = self.W - del self.W - - def forward(self, x): - if self.log: - x = torch.log(x) - - x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :] - - if self.flip_sin_to_cos: - out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) - else: - out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - return out - - -@maybe_allow_in_graph -class StableAudioDiTBlock(nn.Module): - r""" - Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip - connection and QKNorm - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for the query states. - num_key_value_attention_heads (`int`): The number of heads to use for the key and value states. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - num_key_value_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - upcast_attention: bool = False, - norm_eps: float = 1e-5, - ff_inner_dim: Optional[int] = None, - ): - super().__init__() - # Define 3 blocks. Each block has its own normalization layer. - # 1. 
Self-Attn - self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=False, - upcast_attention=upcast_attention, - out_bias=False, - processor=StableAudioAttnProcessor2_0(), - ) - - # 2. Cross-Attn - self.norm2 = nn.LayerNorm(dim, norm_eps, True) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - kv_heads=num_key_value_attention_heads, - dropout=dropout, - bias=False, - upcast_attention=upcast_attention, - out_bias=False, - processor=StableAudioAttnProcessor2_0(), - ) # is self-attn if encoder_hidden_states is none - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, norm_eps, True) - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn="swiglu", - final_dropout=False, - inner_dim=ff_inner_dim, - bias=True, - ) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - rotary_embedding: Optional[torch.FloatTensor] = None, - ) -> torch.Tensor: - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - attn_output = self.attn1( - norm_hidden_states, - attention_mask=attention_mask, - rotary_emb=rotary_embedding, - ) - - hidden_states = attn_output + hidden_states - - # 2. Cross-Attention - norm_hidden_states = self.norm2(hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - ff_output = self.ff(norm_hidden_states) - - hidden_states = ff_output + hidden_states - - return hidden_states - - -class StableAudioDiTModel(ModelMixin, ConfigMixin): - """ - The Diffusion Transformer model introduced in Stable Audio. - - Reference: https://github.com/Stability-AI/stable-audio-tools - - Parameters: - sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. - in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. - num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. - num_key_value_attention_heads (`int`, *optional*, defaults to 12): - The number of heads to use for the key and value states. - out_channels (`int`, defaults to 64): Number of output channels. - cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. - time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. - global_states_input_dim ( `int`, *optional*, defaults to 1536): - Input dimension of the global hidden states projection. 
- cross_attention_input_dim ( `int`, *optional*, defaults to 768): - Input dimension of the cross-attention projection - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: int = 1024, - in_channels: int = 64, - num_layers: int = 24, - attention_head_dim: int = 64, - num_attention_heads: int = 24, - num_key_value_attention_heads: int = 12, - out_channels: int = 64, - cross_attention_dim: int = 768, - time_proj_dim: int = 256, - global_states_input_dim: int = 1536, - cross_attention_input_dim: int = 768, - ): - super().__init__() - self.sample_size = sample_size - self.out_channels = out_channels - self.inner_dim = num_attention_heads * attention_head_dim - - self.time_proj = StableAudioGaussianFourierProjection( - embedding_size=time_proj_dim // 2, - flip_sin_to_cos=True, - log=False, - set_W_to_weight=False, - ) - - self.timestep_proj = nn.Sequential( - nn.Linear(time_proj_dim, self.inner_dim, bias=True), - nn.SiLU(), - nn.Linear(self.inner_dim, self.inner_dim, bias=True), - ) - - self.global_proj = nn.Sequential( - nn.Linear(global_states_input_dim, self.inner_dim, bias=False), - nn.SiLU(), - nn.Linear(self.inner_dim, self.inner_dim, bias=False), - ) - - self.cross_attention_proj = nn.Sequential( - nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), - nn.SiLU(), - nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), - ) - - self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) - self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) - - self.transformer_blocks = nn.ModuleList( - [ - StableAudioDiTBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - num_key_value_attention_heads=num_key_value_attention_heads, - attention_head_dim=attention_head_dim, - cross_attention_dim=cross_attention_dim, - ) - for i in range(num_layers) - ] - ) - - self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) - self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) - - self.gradient_checkpointing = False - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. 
- - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - self.set_attn_processor(StableAudioAttnProcessor2_0()) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.FloatTensor, - timestep: torch.LongTensor = None, - encoder_hidden_states: torch.FloatTensor = None, - global_hidden_states: torch.FloatTensor = None, - rotary_embedding: torch.FloatTensor = None, - return_dict: bool = True, - attention_mask: Optional[torch.LongTensor] = None, - encoder_attention_mask: Optional[torch.LongTensor] = None, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`StableAudioDiTModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`): - Input `hidden_states`. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`): - Global embeddings that will be prepended to the hidden states. - rotary_embedding (`torch.Tensor`): - The rotary embeddings to apply on query and key tensors during attention calculation. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): - Mask to avoid performing attention on padding token indices, formed by concatenating the attention - masks - for the two text encoders together. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): - Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating - the attention masks - for the two text encoders together. 
Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) - global_hidden_states = self.global_proj(global_hidden_states) - time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype))) - - global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) - - hidden_states = self.preprocess_conv(hidden_states) + hidden_states - # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim) - hidden_states = hidden_states.transpose(1, 2) - - hidden_states = self.proj_in(hidden_states) - - # prepend global states to hidden states - hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2) - if attention_mask is not None: - prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool) - attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) - - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - attention_mask, - cross_attention_hidden_states, - encoder_attention_mask, - rotary_embedding, - **ckpt_kwargs, - ) - - else: - hidden_states = block( - hidden_states=hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=cross_attention_hidden_states, - encoder_attention_mask=encoder_attention_mask, - rotary_embedding=rotary_embedding, - ) - - hidden_states = self.proj_out(hidden_states) - - # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length) - # remove prepend length that has been added by global hidden states - hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] - hidden_states = self.postprocess_conv(hidden_states) + hidden_states - - if not return_dict: - return (hidden_states,) - - return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/models/transformers/t5_film_transformer.py b/src/diffusers/models/transformers/t5_film_transformer.py deleted file mode 100644 index 1dea37a25910..000000000000 --- a/src/diffusers/models/transformers/t5_film_transformer.py +++ /dev/null @@ -1,436 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
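The file removed next, t5_film_transformer.py, conditions each decoder layer through FiLM (feature-wise linear modulation): a conditioning embedding is projected to a per-channel scale and shift that modulate the normalized hidden states as x * (1 + scale) + shift, as its T5FiLMLayer shows further down. A minimal, hedged sketch of that pattern only (module and tensor names here are illustrative, not the library API):

import torch
from torch import nn

class FiLM(nn.Module):
    """Feature-wise linear modulation: project a conditioning vector to a
    per-channel (scale, shift) pair and apply x * (1 + scale) + shift."""

    def __init__(self, cond_dim: int, hidden_dim: int):
        super().__init__()
        # One linear layer produces both scale and shift, mirroring the deleted T5FiLMLayer.
        self.scale_bias = nn.Linear(cond_dim, hidden_dim * 2, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale, shift = self.scale_bias(cond).chunk(2, dim=-1)
        return x * (1 + scale) + shift

# Usage: the conditioning embedding broadcasts over the sequence dimension.
film = FiLM(cond_dim=3072, hidden_dim=768)
hidden = torch.randn(2, 16, 768)   # (batch, seq_len, hidden_dim)
cond = torch.randn(2, 1, 3072)     # (batch, 1, cond_dim)
print(film(hidden, cond).shape)    # torch.Size([2, 16, 768])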
-import math -from typing import Optional, Tuple - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ..attention_processor import Attention -from ..embeddings import get_timestep_embedding -from ..modeling_utils import ModelMixin - - -class T5FilmDecoder(ModelMixin, ConfigMixin): - r""" - T5 style decoder with FiLM conditioning. - - Args: - input_dims (`int`, *optional*, defaults to `128`): - The number of input dimensions. - targets_length (`int`, *optional*, defaults to `256`): - The length of the targets. - d_model (`int`, *optional*, defaults to `768`): - Size of the input hidden states. - num_layers (`int`, *optional*, defaults to `12`): - The number of `DecoderLayer`'s to use. - num_heads (`int`, *optional*, defaults to `12`): - The number of attention heads to use. - d_kv (`int`, *optional*, defaults to `64`): - Size of the key-value projection vectors. - d_ff (`int`, *optional*, defaults to `2048`): - The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s. - dropout_rate (`float`, *optional*, defaults to `0.1`): - Dropout probability. - """ - - @register_to_config - def __init__( - self, - input_dims: int = 128, - targets_length: int = 256, - max_decoder_noise_time: float = 2000.0, - d_model: int = 768, - num_layers: int = 12, - num_heads: int = 12, - d_kv: int = 64, - d_ff: int = 2048, - dropout_rate: float = 0.1, - ): - super().__init__() - - self.conditioning_emb = nn.Sequential( - nn.Linear(d_model, d_model * 4, bias=False), - nn.SiLU(), - nn.Linear(d_model * 4, d_model * 4, bias=False), - nn.SiLU(), - ) - - self.position_encoding = nn.Embedding(targets_length, d_model) - self.position_encoding.weight.requires_grad = False - - self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) - - self.dropout = nn.Dropout(p=dropout_rate) - - self.decoders = nn.ModuleList() - for lyr_num in range(num_layers): - # FiLM conditional T5 decoder - lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) - self.decoders.append(lyr) - - self.decoder_norm = T5LayerNorm(d_model) - - self.post_dropout = nn.Dropout(p=dropout_rate) - self.spec_out = nn.Linear(d_model, input_dims, bias=False) - - def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor: - mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) - return mask.unsqueeze(-3) - - def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): - batch, _, _ = decoder_input_tokens.shape - assert decoder_noise_time.shape == (batch,) - - # decoder_noise_time is in [0, 1), so rescale to expected timing range. - time_steps = get_timestep_embedding( - decoder_noise_time * self.config.max_decoder_noise_time, - embedding_dim=self.config.d_model, - max_period=self.config.max_decoder_noise_time, - ).to(dtype=self.dtype) - - conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) - - assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) - - seq_length = decoder_input_tokens.shape[1] - - # If we want to use relative positions for audio context, we can just offset - # this sequence by the length of encodings_and_masks. 
- decoder_positions = torch.broadcast_to( - torch.arange(seq_length, device=decoder_input_tokens.device), - (batch, seq_length), - ) - - position_encodings = self.position_encoding(decoder_positions) - - inputs = self.continuous_inputs_projection(decoder_input_tokens) - inputs += position_encodings - y = self.dropout(inputs) - - # decoder: No padding present. - decoder_mask = torch.ones( - decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype - ) - - # Translate encoding masks to encoder-decoder masks. - encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] - - # cross attend style: concat encodings - encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) - encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) - - for lyr in self.decoders: - y = lyr( - y, - conditioning_emb=conditioning_emb, - encoder_hidden_states=encoded, - encoder_attention_mask=encoder_decoder_mask, - )[0] - - y = self.decoder_norm(y) - y = self.post_dropout(y) - - spec_out = self.spec_out(y) - return spec_out - - -class DecoderLayer(nn.Module): - r""" - T5 decoder layer. - - Args: - d_model (`int`): - Size of the input hidden states. - d_kv (`int`): - Size of the key-value projection vectors. - num_heads (`int`): - Number of attention heads. - d_ff (`int`): - Size of the intermediate feed-forward layer. - dropout_rate (`float`): - Dropout probability. - layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): - A small value used for numerical stability to avoid dividing by zero. - """ - - def __init__( - self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6 - ): - super().__init__() - self.layer = nn.ModuleList() - - # cond self attention: layer 0 - self.layer.append( - T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) - ) - - # cross attention: layer 1 - self.layer.append( - T5LayerCrossAttention( - d_model=d_model, - d_kv=d_kv, - num_heads=num_heads, - dropout_rate=dropout_rate, - layer_norm_epsilon=layer_norm_epsilon, - ) - ) - - # Film Cond MLP + dropout: last layer - self.layer.append( - T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) - ) - - def forward( - self, - hidden_states: torch.Tensor, - conditioning_emb: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - encoder_decoder_position_bias=None, - ) -> Tuple[torch.Tensor]: - hidden_states = self.layer[0]( - hidden_states, - conditioning_emb=conditioning_emb, - attention_mask=attention_mask, - ) - - if encoder_hidden_states is not None: - encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( - encoder_hidden_states.dtype - ) - - hidden_states = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_extended_attention_mask, - ) - - # Apply Film Conditional Feed Forward layer - hidden_states = self.layer[-1](hidden_states, conditioning_emb) - - return (hidden_states,) - - -class T5LayerSelfAttentionCond(nn.Module): - r""" - T5 style self-attention layer with conditioning. - - Args: - d_model (`int`): - Size of the input hidden states. - d_kv (`int`): - Size of the key-value projection vectors. 
- num_heads (`int`): - Number of attention heads. - dropout_rate (`float`): - Dropout probability. - """ - - def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float): - super().__init__() - self.layer_norm = T5LayerNorm(d_model) - self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) - self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) - self.dropout = nn.Dropout(dropout_rate) - - def forward( - self, - hidden_states: torch.Tensor, - conditioning_emb: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # pre_self_attention_layer_norm - normed_hidden_states = self.layer_norm(hidden_states) - - if conditioning_emb is not None: - normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) - - # Self-attention block - attention_output = self.attention(normed_hidden_states) - - hidden_states = hidden_states + self.dropout(attention_output) - - return hidden_states - - -class T5LayerCrossAttention(nn.Module): - r""" - T5 style cross-attention layer. - - Args: - d_model (`int`): - Size of the input hidden states. - d_kv (`int`): - Size of the key-value projection vectors. - num_heads (`int`): - Number of attention heads. - dropout_rate (`float`): - Dropout probability. - layer_norm_epsilon (`float`): - A small value used for numerical stability to avoid dividing by zero. - """ - - def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float): - super().__init__() - self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) - self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) - self.dropout = nn.Dropout(dropout_rate) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.attention( - normed_hidden_states, - encoder_hidden_states=key_value_states, - attention_mask=attention_mask.squeeze(1), - ) - layer_output = hidden_states + self.dropout(attention_output) - return layer_output - - -class T5LayerFFCond(nn.Module): - r""" - T5 style feed-forward conditional layer. - - Args: - d_model (`int`): - Size of the input hidden states. - d_ff (`int`): - Size of the intermediate feed-forward layer. - dropout_rate (`float`): - Dropout probability. - layer_norm_epsilon (`float`): - A small value used for numerical stability to avoid dividing by zero. 
- """ - - def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float): - super().__init__() - self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) - self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) - self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) - self.dropout = nn.Dropout(dropout_rate) - - def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor: - forwarded_states = self.layer_norm(hidden_states) - if conditioning_emb is not None: - forwarded_states = self.film(forwarded_states, conditioning_emb) - - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states - - -class T5DenseGatedActDense(nn.Module): - r""" - T5 style feed-forward layer with gated activations and dropout. - - Args: - d_model (`int`): - Size of the input hidden states. - d_ff (`int`): - Size of the intermediate feed-forward layer. - dropout_rate (`float`): - Dropout probability. - """ - - def __init__(self, d_model: int, d_ff: int, dropout_rate: float): - super().__init__() - self.wi_0 = nn.Linear(d_model, d_ff, bias=False) - self.wi_1 = nn.Linear(d_model, d_ff, bias=False) - self.wo = nn.Linear(d_ff, d_model, bias=False) - self.dropout = nn.Dropout(dropout_rate) - self.act = NewGELUActivation() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - hidden_states = self.wo(hidden_states) - return hidden_states - - -class T5LayerNorm(nn.Module): - r""" - T5 style layer normalization module. - - Args: - hidden_size (`int`): - Size of the input hidden states. - eps (`float`, `optional`, defaults to `1e-6`): - A small value used for numerical stability to avoid dividing by zero. - """ - - def __init__(self, hidden_size: int, eps: float = 1e-6): - """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class NewGELUActivation(nn.Module): - """ - Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see - the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) - - -class T5FiLMLayer(nn.Module): - """ - T5 style FiLM Layer. 
- - Args: - in_features (`int`): - Number of input features. - out_features (`int`): - Number of output features. - """ - - def __init__(self, in_features: int, out_features: int): - super().__init__() - self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) - - def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor: - emb = self.scale_bias(conditioning_emb) - scale, shift = torch.chunk(emb, 2, -1) - x = x * (1 + scale) + shift - return x diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py deleted file mode 100644 index e208a1c10ed4..000000000000 --- a/src/diffusers/models/transformers/transformer_2d.py +++ /dev/null @@ -1,566 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Dict, Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from ...configuration_utils import LegacyConfigMixin, register_to_config -from ...utils import deprecate, is_torch_version, logging -from ..attention import BasicTransformerBlock -from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import LegacyModelMixin -from ..normalization import AdaLayerNormSingle - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class Transformer2DModelOutput(Transformer2DModelOutput): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead." - deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) - - -class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): - """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 
- Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. - """ - - _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock"] - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - attention_type: str = "default", - caption_channels: int = None, - interpolation_scale: float = None, - use_additional_conditions: Optional[bool] = None, - ): - super().__init__() - - # Validate inputs. - if patch_size is not None: - if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: - raise NotImplementedError( - f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." - ) - elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None: - raise ValueError( - f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." - ) - - # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. 
Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - - # Set some common variables used across the board. - self.use_linear_projection = use_linear_projection - self.interpolation_scale = interpolation_scale - self.caption_channels = caption_channels - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - self.gradient_checkpointing = False - - if use_additional_conditions is None: - if norm_type == "ada_norm_single" and sample_size == 128: - use_additional_conditions = True - else: - use_additional_conditions = False - self.use_additional_conditions = use_additional_conditions - - # 2. Initialize the right blocks. - # These functions follow a common structure: - # a. Initialize the input blocks. b. Initialize the transformer blocks. - # c. Initialize the output blocks and other projection blocks when necessary. 
- if self.is_input_continuous: - self._init_continuous_input(norm_type=norm_type) - elif self.is_input_vectorized: - self._init_vectorized_inputs(norm_type=norm_type) - elif self.is_input_patches: - self._init_patched_inputs(norm_type=norm_type) - - def _init_continuous_input(self, norm_type): - self.norm = torch.nn.GroupNorm( - num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True - ) - if self.use_linear_projection: - self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim) - else: - self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - self.inner_dim, - self.config.num_attention_heads, - self.config.attention_head_dim, - dropout=self.config.dropout, - cross_attention_dim=self.config.cross_attention_dim, - activation_fn=self.config.activation_fn, - num_embeds_ada_norm=self.config.num_embeds_ada_norm, - attention_bias=self.config.attention_bias, - only_cross_attention=self.config.only_cross_attention, - double_self_attention=self.config.double_self_attention, - upcast_attention=self.config.upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=self.config.norm_elementwise_affine, - norm_eps=self.config.norm_eps, - attention_type=self.config.attention_type, - ) - for _ in range(self.config.num_layers) - ] - ) - - if self.use_linear_projection: - self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels) - else: - self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0) - - def _init_vectorized_inputs(self, norm_type): - assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" - assert ( - self.config.num_vector_embeds is not None - ), "Transformer2DModel over discrete input must provide num_embed" - - self.height = self.config.sample_size - self.width = self.config.sample_size - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - self.inner_dim, - self.config.num_attention_heads, - self.config.attention_head_dim, - dropout=self.config.dropout, - cross_attention_dim=self.config.cross_attention_dim, - activation_fn=self.config.activation_fn, - num_embeds_ada_norm=self.config.num_embeds_ada_norm, - attention_bias=self.config.attention_bias, - only_cross_attention=self.config.only_cross_attention, - double_self_attention=self.config.double_self_attention, - upcast_attention=self.config.upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=self.config.norm_elementwise_affine, - norm_eps=self.config.norm_eps, - attention_type=self.config.attention_type, - ) - for _ in range(self.config.num_layers) - ] - ) - - self.norm_out = nn.LayerNorm(self.inner_dim) - self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1) - - def _init_patched_inputs(self, norm_type): - assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size" - - self.height = self.config.sample_size - self.width = self.config.sample_size - - self.patch_size = self.config.patch_size - interpolation_scale = ( - self.config.interpolation_scale - if self.config.interpolation_scale is not None - else max(self.config.sample_size // 64, 1) - ) 
- self.pos_embed = PatchEmbed( - height=self.config.sample_size, - width=self.config.sample_size, - patch_size=self.config.patch_size, - in_channels=self.in_channels, - embed_dim=self.inner_dim, - interpolation_scale=interpolation_scale, - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - self.inner_dim, - self.config.num_attention_heads, - self.config.attention_head_dim, - dropout=self.config.dropout, - cross_attention_dim=self.config.cross_attention_dim, - activation_fn=self.config.activation_fn, - num_embeds_ada_norm=self.config.num_embeds_ada_norm, - attention_bias=self.config.attention_bias, - only_cross_attention=self.config.only_cross_attention, - double_self_attention=self.config.double_self_attention, - upcast_attention=self.config.upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=self.config.norm_elementwise_affine, - norm_eps=self.config.norm_eps, - attention_type=self.config.attention_type, - ) - for _ in range(self.config.num_layers) - ] - ) - - if self.config.norm_type != "ada_norm_single": - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) - self.proj_out_2 = nn.Linear( - self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels - ) - elif self.config.norm_type == "ada_norm_single": - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) - self.proj_out = nn.Linear( - self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels - ) - - # PixArt-Alpha blocks. - self.adaln_single = None - if self.config.norm_type == "ada_norm_single": - # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use - # additional conditions until we find better name - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=self.use_additional_conditions - ) - - self.caption_projection = None - if self.caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, hidden_size=self.inner_dim - ) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - class_labels: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): - Input `hidden_states`. - encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - cross_attention_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - attention_mask ( `torch.Tensor`, *optional*): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned, - otherwise a `tuple` where the first element is the sample tensor. - """ - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None and attention_mask.ndim == 2: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 1. 
Input - if self.is_input_continuous: - batch_size, _, height, width = hidden_states.shape - residual = hidden_states - hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size - hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, added_cond_kwargs - ) - - # 2. Blocks - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - timestep, - cross_attention_kwargs, - class_labels, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - if self.is_input_continuous: - output = self._get_output_for_continuous_inputs( - hidden_states=hidden_states, - residual=residual, - batch_size=batch_size, - height=height, - width=width, - inner_dim=inner_dim, - ) - elif self.is_input_vectorized: - output = self._get_output_for_vectorized_inputs(hidden_states) - elif self.is_input_patches: - output = self._get_output_for_patched_inputs( - hidden_states=hidden_states, - timestep=timestep, - class_labels=class_labels, - embedded_timestep=embedded_timestep, - height=height, - width=width, - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) - - def _operate_on_continuous_inputs(self, hidden_states): - batch, _, height, width = hidden_states.shape - hidden_states = self.norm(hidden_states) - - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) - - return hidden_states, inner_dim - - def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs): - batch_size = hidden_states.shape[0] - hidden_states = self.pos_embed(hidden_states) - embedded_timestep = None - - if self.adaln_single is not None: - if self.use_additional_conditions and added_cond_kwargs is None: - raise ValueError( - "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." 
- ) - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype - ) - - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) - - return hidden_states, encoder_hidden_states, timestep, embedded_timestep - - def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): - if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - ) - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - ) - - output = hidden_states + residual - return output - - def _get_output_for_vectorized_inputs(self, hidden_states): - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - return output - - def _get_output_for_patched_inputs( - self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None - ): - if self.config.norm_type != "ada_norm_single": - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - elif self.config.norm_type == "ada_norm_single": - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.squeeze(1) - - # unpatchify - if self.adaln_single is None: - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) - ) - return output diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py deleted file mode 100644 index fe9c7290b063..000000000000 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ /dev/null @@ -1,422 +0,0 @@ -# Copyright 2024 The RhymesAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
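# --- Illustrative aside (standalone sketch, not part of the deleted file above): the long
# comment in the Transformer2DModel.forward hunk explains how a 2-D keep/discard attention
# mask is converted into an additive bias before it reaches the attention scores. A minimal
# reconstruction of that idea, assuming PyTorch and arbitrary example shapes (in the library
# the exact broadcasting is handled by the attention processor), could look like this:
import torch

batch, heads, query_tokens, key_tokens = 2, 8, 10, 6
attention_mask = torch.ones(batch, key_tokens)
attention_mask[:, 4:] = 0  # pretend the last two key tokens are padding (0 = discard)

# keep = +0.0, discard = -10000.0; a large negative constant rather than -inf keeps the
# bias finite in half precision. Singleton dims let it broadcast over the score tensor.
bias = (1 - attention_mask) * -10000.0
bias = bias[:, None, None, :]  # [batch, 1, 1, key_tokens]

scores = torch.randn(batch, heads, query_tokens, key_tokens)
probs = (scores + bias).softmax(dim=-1)  # discarded keys end up with ~zero probability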
- -from typing import Any, Dict, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import AllegroAttnProcessor2_0, Attention -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle - - -logger = logging.get_logger(__name__) - - -@maybe_allow_in_graph -class AllegroTransformerBlock(nn.Module): - r""" - Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model. - - Args: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - cross_attention_dim (`int`, defaults to `2304`): - The dimension of the cross attention features. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to be used in feed-forward. - attention_bias (`bool`, defaults to `False`): - Whether or not to use bias in attention projection layers. - only_cross_attention (`bool`, defaults to `False`): - norm_elementwise_affine (`bool`, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_eps (`float`, defaults to `1e-5`): - Epsilon value for normalization layers. - final_dropout (`bool` defaults to `False`): - Whether to apply a final dropout after the last feed-forward layer. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - attention_bias: bool = False, - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - ): - super().__init__() - - # 1. Self Attention - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=None, - processor=AllegroAttnProcessor2_0(), - ) - - # 2. Cross Attention - self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - processor=AllegroAttnProcessor2_0(), - ) - - # 3. Feed Forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - ) - - # 4. Scale-shift - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - temb: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb=None, - ) -> torch.Tensor: - # 0. 
Self-Attention - batch_size = hidden_states.shape[0] - - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - norm_hidden_states = norm_hidden_states.squeeze(1) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=None, - attention_mask=attention_mask, - image_rotary_emb=image_rotary_emb, - ) - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 1. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = hidden_states - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - image_rotary_emb=None, - ) - hidden_states = attn_output + hidden_states - - # 2. Feed-forward - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - - # TODO(aryan): maybe following line is not required - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - return hidden_states - - -class AllegroTransformer3DModel(ModelMixin, ConfigMixin): - _supports_gradient_checkpointing = True - - """ - A 3D Transformer model for video-like data. - - Args: - patch_size (`int`, defaults to `2`): - The size of spatial patches to use in the patch embedding layer. - patch_size_t (`int`, defaults to `1`): - The size of temporal patches to use in the patch embedding layer. - num_attention_heads (`int`, defaults to `24`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`, defaults to `96`): - The number of channels in each head. - in_channels (`int`, defaults to `4`): - The number of channels in the input. - out_channels (`int`, *optional*, defaults to `4`): - The number of channels in the output. - num_layers (`int`, defaults to `32`): - The number of layers of Transformer blocks to use. - dropout (`float`, defaults to `0.0`): - The dropout probability to use. - cross_attention_dim (`int`, defaults to `2304`): - The dimension of the cross attention features. - attention_bias (`bool`, defaults to `True`): - Whether or not to use bias in the attention projection layers. - sample_height (`int`, defaults to `90`): - The height of the input latents. - sample_width (`int`, defaults to `160`): - The width of the input latents. - sample_frames (`int`, defaults to `22`): - The number of frames in the input latents. - activation_fn (`str`, defaults to `"gelu-approximate"`): - Activation function to use in feed-forward. - norm_elementwise_affine (`bool`, defaults to `False`): - Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, defaults to `1e-6`): - The epsilon value to use in normalization layers. - caption_channels (`int`, defaults to `4096`): - Number of channels to use for projecting the caption embeddings. - interpolation_scale_h (`float`, defaults to `2.0`): - Scaling factor to apply in 3D positional embeddings across height dimension. - interpolation_scale_w (`float`, defaults to `2.0`): - Scaling factor to apply in 3D positional embeddings across width dimension. 
- interpolation_scale_t (`float`, defaults to `2.2`): - Scaling factor to apply in 3D positional embeddings across time dimension. - """ - - @register_to_config - def __init__( - self, - patch_size: int = 2, - patch_size_t: int = 1, - num_attention_heads: int = 24, - attention_head_dim: int = 96, - in_channels: int = 4, - out_channels: int = 4, - num_layers: int = 32, - dropout: float = 0.0, - cross_attention_dim: int = 2304, - attention_bias: bool = True, - sample_height: int = 90, - sample_width: int = 160, - sample_frames: int = 22, - activation_fn: str = "gelu-approximate", - norm_elementwise_affine: bool = False, - norm_eps: float = 1e-6, - caption_channels: int = 4096, - interpolation_scale_h: float = 2.0, - interpolation_scale_w: float = 2.0, - interpolation_scale_t: float = 2.2, - ): - super().__init__() - - self.inner_dim = num_attention_heads * attention_head_dim - - interpolation_scale_t = ( - interpolation_scale_t - if interpolation_scale_t is not None - else ((sample_frames - 1) // 16 + 1) - if sample_frames % 2 == 1 - else sample_frames // 16 - ) - interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30 - interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40 - - # 1. Patch embedding - self.pos_embed = PatchEmbed( - height=sample_height, - width=sample_width, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=self.inner_dim, - pos_embed_type=None, - ) - - # 2. Transformer blocks - self.transformer_blocks = nn.ModuleList( - [ - AllegroTransformerBlock( - self.inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for _ in range(num_layers) - ] - ) - - # 3. Output projection & norm - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) - - # 4. Timestep embeddings - self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) - - # 5. Caption projection - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim) - - self.gradient_checkpointing = False - - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - return_dict: bool = True, - ): - batch_size, num_channels, num_frames, height, width = hidden_states.shape - p_t = self.config.patch_size_t - p = self.config.patch_size - - post_patch_num_frames = num_frames // p_t - post_patch_height = height // p - post_patch_width = width // p - - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
- # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - attention_mask_vid, attention_mask_img = None, None - if attention_mask is not None and attention_mask.ndim == 4: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - # b, frame+use_image_num, h, w -> a video with images - # b, 1, h, w -> only images - attention_mask = attention_mask.to(hidden_states.dtype) - attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width] - - if attention_mask.numel() > 0: - attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width] - attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p)) - attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1) - - attention_mask = ( - (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None - ) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 1. Timestep embeddings - timestep, embedded_timestep = self.adaln_single( - timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype - ) - - # 2. Patch embeddings - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - hidden_states = self.pos_embed(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) - - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) - - # 3. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - # TODO(aryan): Implement gradient checkpointing - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - timestep, - attention_mask, - encoder_attention_mask, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=timestep, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - image_rotary_emb=image_rotary_emb, - ) - - # 4. Output normalization & projection - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.squeeze(1) - - # 5.
Unpatchify - hidden_states = hidden_states.reshape( - batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1 - ) - hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) - output = hidden_states.reshape(batch_size, -1, num_frames, height, width) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py deleted file mode 100644 index 94d852f6df4b..000000000000 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Any, Dict, Union - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.attention import FeedForward -from ...models.attention_processor import ( - Attention, - AttentionProcessor, - CogVideoXAttnProcessor2_0, -) -from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous -from ...utils import is_torch_version, logging -from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed -from ..modeling_outputs import Transformer2DModelOutput -from ..normalization import CogView3PlusAdaLayerNormZeroTextImage - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class CogView3PlusTransformerBlock(nn.Module): - r""" - Transformer block used in [CogView](https://github.com/THUDM/CogView3) model. - - Args: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - time_embed_dim (`int`): - The number of channels in timestep embedding. 
- """ - - def __init__( - self, - dim: int = 2560, - num_attention_heads: int = 64, - attention_head_dim: int = 40, - time_embed_dim: int = 512, - ): - super().__init__() - - self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - out_dim=dim, - bias=True, - qk_norm="layer_norm", - elementwise_affine=False, - eps=1e-6, - processor=CogVideoXAttnProcessor2_0(), - ) - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) - - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - emb: torch.Tensor, - ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) - - # norm & modulate - ( - norm_hidden_states, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - norm_encoder_hidden_states, - c_gate_msa, - c_shift_mlp, - c_scale_mlp, - c_gate_mlp, - ) = self.norm1(hidden_states, encoder_hidden_states, emb) - - # attention - attn_hidden_states, attn_encoder_hidden_states = self.attn1( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states - ) - - hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states - encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states - - # norm & modulate - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - - # feed-forward - norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) - ff_output = self.ff(norm_hidden_states) - - hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] - - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - return hidden_states, encoder_hidden_states - - -class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): - r""" - The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay - Diffusion](https://huggingface.co/papers/2403.05121). - - Args: - patch_size (`int`, defaults to `2`): - The size of the patches to use in the patch embedding layer. - in_channels (`int`, defaults to `16`): - The number of channels in the input. - num_layers (`int`, defaults to `30`): - The number of layers of Transformer blocks to use. - attention_head_dim (`int`, defaults to `40`): - The number of channels in each head. - num_attention_heads (`int`, defaults to `64`): - The number of heads to use for multi-head attention. - out_channels (`int`, defaults to `16`): - The number of channels in the output. - text_embed_dim (`int`, defaults to `4096`): - Input dimension of text embeddings from the text encoder. - time_embed_dim (`int`, defaults to `512`): - Output dimension of timestep embeddings. 
- condition_dim (`int`, defaults to `256`): - The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, - crop_coords). - pos_embed_max_size (`int`, defaults to `128`): - The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added - to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 - means that the maximum supported height and width for image generation is `128 * vae_scale_factor * - patch_size => 128 * 8 * 2 => 2048`. - sample_size (`int`, defaults to `128`): - The base resolution of input latents. If height/width is not provided during generation, this value is used - to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024` - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - num_layers: int = 30, - attention_head_dim: int = 40, - num_attention_heads: int = 64, - out_channels: int = 16, - text_embed_dim: int = 4096, - time_embed_dim: int = 512, - condition_dim: int = 256, - pos_embed_max_size: int = 128, - sample_size: int = 128, - ): - super().__init__() - self.out_channels = out_channels - self.inner_dim = num_attention_heads * attention_head_dim - - # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords - # Each of these are sincos embeddings of shape 2 * condition_dim - self.pooled_projection_dim = 3 * 2 * condition_dim - - self.patch_embed = CogView3PlusPatchEmbed( - in_channels=in_channels, - hidden_size=self.inner_dim, - patch_size=patch_size, - text_hidden_size=text_embed_dim, - pos_embed_max_size=pos_embed_max_size, - ) - - self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings( - embedding_dim=time_embed_dim, - condition_dim=condition_dim, - pooled_projection_dim=self.pooled_projection_dim, - timesteps_dim=self.inner_dim, - ) - - self.transformer_blocks = nn.ModuleList( - [ - CogView3PlusTransformerBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - time_embed_dim=time_embed_dim, - ) - for _ in range(num_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous( - embedding_dim=self.inner_dim, - conditioning_embedding_dim=time_embed_dim, - elementwise_affine=False, - eps=1e-6, - ) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - - self.gradient_checkpointing = False - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: torch.LongTensor, - original_size: torch.Tensor, - target_size: torch.Tensor, - crop_coords: torch.Tensor, - return_dict: bool = True, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: - """ - The [`CogView3PlusTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.Tensor`): - Input `hidden_states` of shape `(batch size, channel, height, width)`. - encoder_hidden_states (`torch.Tensor`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape - `(batch_size, sequence_len, text_embed_dim)` - timestep (`torch.LongTensor`): - Used to indicate denoising step. - original_size (`torch.Tensor`): - CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - target_size (`torch.Tensor`): - CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
- crop_coords (`torch.Tensor`): - CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]: - The denoised latents using provided inputs as conditioning. - """ - height, width = hidden_states.shape[-2:] - text_seq_length = encoder_hidden_states.shape[1] - - hidden_states = self.patch_embed( - hidden_states, encoder_hidden_states - ) # takes care of adding positional embeddings too. - emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) - - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=emb, - ) - - hidden_states = self.norm_out(hidden_states, emb) - hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels) - - # unpatchify - patch_size = self.config.patch_size - height = height // patch_size - width = width // patch_size - - hidden_states = hidden_states.reshape( - shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size) - ) - hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py deleted file mode 100644 index 0ad3be866019..000000000000 --- a/src/diffusers/models/transformers/transformer_flux.py +++ /dev/null @@ -1,577 +0,0 @@ -# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
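# --- Illustrative aside (standalone sketch, not part of the deleted file above): the
# patched-input models in this patch all end with an "unpatchify" step that folds per-token
# projections of shape (batch, height*width, patch*patch*channels) back into a latent of
# shape (batch, channels, height*patch, width*patch). A minimal sketch with made-up sizes,
# following the axis ordering used in the Transformer2DModel hunk (CogView3 orders the
# channel and patch axes slightly differently, but the reshape/einsum idea is the same):
import torch

batch, height, width, patch, channels = 1, 4, 4, 2, 16
tokens = torch.randn(batch, height * width, patch * patch * channels)

x = tokens.reshape(batch, height, width, patch, patch, channels)
x = torch.einsum("nhwpqc->nchpwq", x)  # channels first, interleave patch rows/columns
latent = x.reshape(batch, channels, height * patch, width * patch)
assert latent.shape == (1, 16, 8, 8)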
- - -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...models.attention import FeedForward -from ...models.attention_processor import ( - Attention, - AttentionProcessor, - FluxAttnProcessor2_0, - FluxAttnProcessor2_0_NPU, - FusedFluxAttnProcessor2_0, -) -from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from ...utils.import_utils import is_torch_npu_available -from ...utils.torch_utils import maybe_allow_in_graph -from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed -from ..modeling_outputs import Transformer2DModelOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@maybe_allow_in_graph -class FluxSingleTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): - super().__init__() - self.mlp_hidden_dim = int(dim * mlp_ratio) - - self.norm = AdaLayerNormZeroSingle(dim) - self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) - self.act_mlp = nn.GELU(approximate="tanh") - self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - - if is_torch_npu_available(): - processor = FluxAttnProcessor2_0_NPU() - else: - processor = FluxAttnProcessor2_0() - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - bias=True, - processor=processor, - qk_norm="rms_norm", - eps=1e-6, - pre_only=True, - ) - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - image_rotary_emb=None, - joint_attention_kwargs=None, - ): - residual = hidden_states - norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) - joint_attention_kwargs = joint_attention_kwargs or {} - attn_output = self.attn( - hidden_states=norm_hidden_states, - image_rotary_emb=image_rotary_emb, - **joint_attention_kwargs, - ) - - hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) - gate = gate.unsqueeze(1) - hidden_states = gate * self.proj_out(hidden_states) - hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) - - return hidden_states - - -@maybe_allow_in_graph -class FluxTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. 
- num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - - def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): - super().__init__() - - self.norm1 = AdaLayerNormZero(dim) - - self.norm1_context = AdaLayerNormZero(dim) - - if hasattr(F, "scaled_dot_product_attention"): - processor = FluxAttnProcessor2_0() - else: - raise ValueError( - "The current PyTorch version does not support the `scaled_dot_product_attention` function." - ) - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=False, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=eps, - ) - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - image_rotary_emb=None, - joint_attention_kwargs=None, - ): - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb - ) - joint_attention_kwargs = joint_attention_kwargs or {} - # Attention. - attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - **joint_attention_kwargs, - ) - - # Process attention outputs for the `hidden_states`. - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output - - # Process attention outputs for the `encoder_hidden_states`. - - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - - return encoder_hidden_states, hidden_states - - -class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): - """ - The Transformer model introduced in Flux. 
- - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Parameters: - patch_size (`int`): Patch size to turn the input data into small patches. - in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. - num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. - num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. - joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. - guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. - """ - - _supports_gradient_checkpointing = True - _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - - @register_to_config - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: Tuple[int] = (16, 56, 56), - ): - super().__init__() - self.out_channels = in_channels - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - - text_time_guidance_cls = ( - CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings - ) - self.time_text_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim - ) - - self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) - self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for i in range(self.config.num_layers) - ] - ) - - self.single_transformer_blocks = nn.ModuleList( - [ - FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - ) - for i in range(self.config.num_single_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - - self.gradient_checkpointing = False - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. 
- """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedFluxAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. 
- - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - pooled_projections: torch.Tensor = None, - timestep: torch.LongTensor = None, - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_block_samples=None, - controlnet_single_block_samples=None, - return_dict: bool = True, - controlnet_blocks_repeat: bool = False, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`FluxTransformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - hidden_states = self.x_embedder(hidden_states) - - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 - else: - guidance = None - temb = ( - self.time_text_embed(timestep, pooled_projections) - if guidance is None - else self.time_text_embed(timestep, guidance, pooled_projections) - ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if txt_ids.ndim == 3: - logger.warning( - "Passing `txt_ids` 3d torch.Tensor is deprecated." - "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - txt_ids = txt_ids[0] - if img_ids.ndim == 3: - logger.warning( - "Passing `img_ids` 3d torch.Tensor is deprecated." 
- "Please remove the batch dimension and pass it as a 2d torch Tensor" - ) - img_ids = img_ids[0] - - ids = torch.cat((txt_ids, img_ids), dim=0) - image_rotary_emb = self.pos_embed(ids) - - for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) - - # controlnet residual - if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) - interval_control = int(np.ceil(interval_control)) - # For Xlabs ControlNet. - if controlnet_blocks_repeat: - hidden_states = ( - hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] - ) - else: - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - for index_block, block in enumerate(self.single_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) - - else: - hidden_states = block( - hidden_states=hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, - ) - - # controlnet residual - if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) - interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] - + controlnet_single_block_samples[index_block // interval_control] - ) - - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] - - hidden_states = self.norm_out(hidden_states, temb) - output = self.proj_out(hidden_states) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py deleted file mode 100644 index fb346a70ba4d..000000000000 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ /dev/null @@ -1,568 +0,0 @@ -# Copyright 2024 The Genmo team and The HuggingFace Team. 
-# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, Optional, Tuple - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import MochiAttnProcessor2_0 -from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import ( - AdaLayerNormContinuous, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-n - - -class MochiModulatedRMSNorm(nn.Module): - def __init__(self, eps: float): - super().__init__() - - self.eps = eps - - def forward(self, hidden_states, scale=None): - hidden_states_dtype = hidden_states.dtype - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps) - - if scale is not None: - hidden_states = hidden_states * scale - - hidden_states = hidden_states.to(hidden_states_dtype) - - return hidden_states - - -class MochiRMSNorm(nn.Module): - def __init__(self, dim, eps: float, elementwise_affine=True): - super().__init__() - - self.eps = eps - if elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim)) - else: - self.weight = None - - def forward(self, hidden_states): - hidden_states_dtype = hidden_states.dtype - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps) - - if self.weight is not None: - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - hidden_states = hidden_states * self.weight - - hidden_states = hidden_states.to(hidden_states_dtype) - - return hidden_states - - -class MochiLayerNormContinuous(nn.Module): - def __init__( - self, - embedding_dim: int, - conditioning_embedding_dim: int, - eps=1e-5, - bias=True, - ): - super().__init__() - - # AdaLN - self.silu = nn.SiLU() - self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) - self.norm = MochiModulatedRMSNorm(eps=eps) - - def forward( - self, - x: torch.Tensor, - conditioning_embedding: torch.Tensor, - ) -> torch.Tensor: - input_dtype = x.dtype - - # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) - scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) - x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32))) - - return x.to(input_dtype) - - -class MochiRMSNormZero(nn.Module): - r""" - Adaptive RMS Norm used in Mochi. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. 
- """ - - def __init__( - self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False - ) -> None: - super().__init__() - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, hidden_dim) - self.norm = MochiModulatedRMSNorm(eps=eps) - - def forward( - self, hidden_states: torch.Tensor, emb: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - hidden_states_dtype = hidden_states.dtype - - emb = self.linear(self.silu(emb)) - scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - - hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].to(torch.float32))) - hidden_states = hidden_states.to(hidden_states_dtype) - - return hidden_states, gate_msa, scale_mlp, gate_mlp - - -class MochiAttention(nn.Module): - def __init__( - self, - query_dim: int, - processor: Optional["MochiAttnProcessor2_0"], - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - out_dim: int = None, - out_context_dim: int = None, - out_bias: bool = True, - context_pre_only: bool = False, - eps: float = 1e-5, - ): - super().__init__() - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim else query_dim - self.context_pre_only = context_pre_only - - self.heads = out_dim // dim_head if out_dim is not None else heads - - self.norm_q = MochiRMSNorm(dim_head, eps) - self.norm_k = MochiRMSNorm(dim_head, eps) - self.norm_added_q = MochiRMSNorm(dim_head, eps) - self.norm_added_k = MochiRMSNorm(dim_head, eps) - - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - - if not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) - - self.processor = processor - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ): - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **kwargs, - ) - - -@maybe_allow_in_graph -class MochiTransformerBlock(nn.Module): - r""" - Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). - - Args: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - qk_norm (`str`, defaults to `"rms_norm"`): - The normalization layer to use. - activation_fn (`str`, defaults to `"swiglu"`): - Activation function to use in feed-forward. - context_pre_only (`bool`, defaults to `False`): - Whether or not to process context-related conditions with additional layers. 
- eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - pooled_projection_dim: int, - qk_norm: str = "rms_norm", - activation_fn: str = "swiglu", - context_pre_only: bool = False, - eps: float = 1e-6, - ) -> None: - super().__init__() - - self.context_pre_only = context_pre_only - self.ff_inner_dim = (4 * dim * 2) // 3 - self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 - - self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) - - if not context_pre_only: - self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) - else: - self.norm1_context = MochiLayerNormContinuous( - embedding_dim=pooled_projection_dim, - conditioning_embedding_dim=dim, - eps=eps, - ) - - self.attn1 = MochiAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - bias=False, - added_kv_proj_dim=pooled_projection_dim, - added_proj_bias=False, - out_dim=dim, - out_context_dim=pooled_projection_dim, - context_pre_only=context_pre_only, - processor=MochiAttnProcessor2_0(), - eps=1e-5, - ) - - # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True - self.norm2 = MochiModulatedRMSNorm(eps=eps) - self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None - - self.norm3 = MochiModulatedRMSNorm(eps) - self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None - - self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) - self.ff_context = None - if not context_pre_only: - self.ff_context = FeedForward( - pooled_projection_dim, - inner_dim=self.ff_context_inner_dim, - activation_fn=activation_fn, - bias=False, - ) - - self.norm4 = MochiModulatedRMSNorm(eps=eps) - self.norm4_context = MochiModulatedRMSNorm(eps=eps) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - encoder_attention_mask: torch.Tensor, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - - if not self.context_pre_only: - norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context( - encoder_hidden_states, temb - ) - else: - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) - - attn_hidden_states, context_attn_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - attention_mask=encoder_attention_mask, - ) - - hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)) - norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32))) - ff_output = self.ff(norm_hidden_states) - hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)) - - if not self.context_pre_only: - encoder_hidden_states = encoder_hidden_states + self.norm2_context( - context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1) - ) - norm_encoder_hidden_states = self.norm3_context( - encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32)) - ) - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + 
self.norm4_context( - context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1) - ) - - return hidden_states, encoder_hidden_states - - -class MochiRoPE(nn.Module): - r""" - RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). - - Args: - base_height (`int`, defaults to `192`): - Base height used to compute interpolation scale for rotary positional embeddings. - base_width (`int`, defaults to `192`): - Base width used to compute interpolation scale for rotary positional embeddings. - """ - - def __init__(self, base_height: int = 192, base_width: int = 192) -> None: - super().__init__() - - self.target_area = base_height * base_width - - def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: - edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) - return (edges[:-1] + edges[1:]) / 2 - - def _get_positions( - self, - num_frames: int, - height: int, - width: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: - scale = (self.target_area / (height * width)) ** 0.5 - - t = torch.arange(num_frames, device=device, dtype=dtype) - h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) - w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) - - grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") - - positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) - return positions - - def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - with torch.autocast(freqs.device.type, enabled=False): - # Always run ROPE freqs computation in FP32 - freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) - - freqs_cos = torch.cos(freqs) - freqs_sin = torch.sin(freqs) - return freqs_cos, freqs_sin - - def forward( - self, - pos_frequencies: torch.Tensor, - num_frames: int, - height: int, - width: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - pos = self._get_positions(num_frames, height, width, device, dtype) - rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) - return rope_cos, rope_sin - - -@maybe_allow_in_graph -class MochiTransformer3DModel(ModelMixin, ConfigMixin): - r""" - A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). - - Args: - patch_size (`int`, defaults to `2`): - The size of the patches to use in the patch embedding layer. - num_attention_heads (`int`, defaults to `24`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`, defaults to `128`): - The number of channels in each head. - num_layers (`int`, defaults to `48`): - The number of layers of Transformer blocks to use. - in_channels (`int`, defaults to `12`): - The number of channels in the input. - out_channels (`int`, *optional*, defaults to `None`): - The number of channels in the output. - qk_norm (`str`, defaults to `"rms_norm"`): - The normalization layer to use. - text_embed_dim (`int`, defaults to `4096`): - Input dimension of text embeddings from the text encoder. - time_embed_dim (`int`, defaults to `256`): - Output dimension of timestep embeddings. - activation_fn (`str`, defaults to `"swiglu"`): - Activation function to use in feed-forward. - max_sequence_length (`int`, defaults to `256`): - The maximum sequence length of text embeddings supported. 
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - patch_size: int = 2, - num_attention_heads: int = 24, - attention_head_dim: int = 128, - num_layers: int = 48, - pooled_projection_dim: int = 1536, - in_channels: int = 12, - out_channels: Optional[int] = None, - qk_norm: str = "rms_norm", - text_embed_dim: int = 4096, - time_embed_dim: int = 256, - activation_fn: str = "swiglu", - max_sequence_length: int = 256, - ) -> None: - super().__init__() - - inner_dim = num_attention_heads * attention_head_dim - out_channels = out_channels or in_channels - - self.patch_embed = PatchEmbed( - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - pos_embed_type=None, - ) - - self.time_embed = MochiCombinedTimestepCaptionEmbedding( - embedding_dim=inner_dim, - pooled_projection_dim=pooled_projection_dim, - text_embed_dim=text_embed_dim, - time_embed_dim=time_embed_dim, - num_attention_heads=8, - ) - - self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0)) - self.rope = MochiRoPE() - - self.transformer_blocks = nn.ModuleList( - [ - MochiTransformerBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - pooled_projection_dim=pooled_projection_dim, - qk_norm=qk_norm, - activation_fn=activation_fn, - context_pre_only=i == num_layers - 1, - ) - for i in range(num_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous( - inner_dim, - inner_dim, - elementwise_affine=False, - eps=1e-6, - norm_type="layer_norm", - ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) - - self.gradient_checkpointing = False - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: torch.LongTensor, - encoder_attention_mask: torch.Tensor, - return_dict: bool = True, - ) -> torch.Tensor: - batch_size, num_channels, num_frames, height, width = hidden_states.shape - p = self.config.patch_size - - post_patch_height = height // p - post_patch_width = width // p - - temb, encoder_hidden_states = self.time_embed( - timestep, - encoder_hidden_states, - encoder_attention_mask, - hidden_dtype=hidden_states.dtype, - ) - - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - hidden_states = self.patch_embed(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) - - image_rotary_emb = self.rope( - self.pos_frequencies, - num_frames, - post_patch_height, - post_patch_width, - device=hidden_states.device, - dtype=torch.float32, - ) - - for i, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - encoder_attention_mask=encoder_attention_mask, - image_rotary_emb=image_rotary_emb, - ) 
- hidden_states = self.norm_out(hidden_states, temb) - hidden_states = self.proj_out(hidden_states) - - hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) - hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) - output = hidden_states.reshape(batch_size, -1, num_frames, height, width) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py deleted file mode 100644 index f39a102c7256..000000000000 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...models.attention import JointTransformerBlock -from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 -from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed -from ..modeling_outputs import Transformer2DModelOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): - """ - The Transformer model introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - sample_size (`int`): The width of the latent images. This is fixed during training since - it is used to learn a number of position embeddings. - patch_size (`int`): Patch size to turn the input data into small patches. - in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. - num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. - pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. - out_channels (`int`, defaults to 16): Number of output channels. 
- - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: int = 128, - patch_size: int = 2, - in_channels: int = 16, - num_layers: int = 18, - attention_head_dim: int = 64, - num_attention_heads: int = 18, - joint_attention_dim: int = 4096, - caption_projection_dim: int = 1152, - pooled_projection_dim: int = 2048, - out_channels: int = 16, - pos_embed_max_size: int = 96, - dual_attention_layers: Tuple[ - int, ... - ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 - qk_norm: Optional[str] = None, - ): - super().__init__() - default_out_channels = in_channels - self.out_channels = out_channels if out_channels is not None else default_out_channels - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - - self.pos_embed = PatchEmbed( - height=self.config.sample_size, - width=self.config.sample_size, - patch_size=self.config.patch_size, - in_channels=self.config.in_channels, - embed_dim=self.inner_dim, - pos_embed_max_size=pos_embed_max_size, # hard-code for now. - ) - self.time_text_embed = CombinedTimestepTextProjEmbeddings( - embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim - ) - self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) - - # `attention_head_dim` is doubled to account for the mixing. - # It needs to crafted when we get the actual checkpoints. - self.transformer_blocks = nn.ModuleList( - [ - JointTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.config.num_attention_heads, - attention_head_dim=self.config.attention_head_dim, - context_pre_only=i == num_layers - 1, - qk_norm=qk_norm, - use_dual_attention=True if i in dual_attention_layers else False, - ) - for i in range(self.config.num_layers) - ] - ) - - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - - self.gradient_checkpointing = False - - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - """ - Sets the attention processor to use [feed forward - chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). - - Parameters: - chunk_size (`int`, *optional*): - The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually - over each tensor of dim=`dim`. - dim (`int`, *optional*, defaults to `0`): - The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) - or dim=1 (sequence length). 
- """ - if dim not in [0, 1]: - raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") - - # By default chunk size is 1 - chunk_size = chunk_size or 1 - - def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): - if hasattr(module, "set_chunk_feed_forward"): - module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - - for child in module.children(): - fn_recursive_feed_forward(child, chunk_size, dim) - - for module in self.children(): - fn_recursive_feed_forward(module, chunk_size, dim) - - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking - def disable_forward_chunking(self): - def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): - if hasattr(module, "set_chunk_feed_forward"): - module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - - for child in module.children(): - fn_recursive_feed_forward(child, chunk_size, dim) - - for module in self.children(): - fn_recursive_feed_forward(module, None, 0) - - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedJointAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - pooled_projections: torch.FloatTensor = None, - timestep: torch.LongTensor = None, - block_controlnet_hidden_states: List = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: - """ - The [`SD3Transformer2DModel`] forward method. - - Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): - Input `hidden_states`. - encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): - Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. - timestep ( `torch.LongTensor`): - Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. 
- - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) - - height, width = hidden_states.shape[-2:] - - hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. - temb = self.time_text_embed(timestep, pooled_projections) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - **ckpt_kwargs, - ) - - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb - ) - - # controlnet residual - if block_controlnet_hidden_states is not None and block.context_pre_only is False: - interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) - hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] - - hidden_states = self.norm_out(hidden_states, temb) - hidden_states = self.proj_out(hidden_states) - - # unpatchify - patch_size = self.config.patch_size - height = height // patch_size - width = width // patch_size - - hidden_states = hidden_states.reshape( - shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) - ) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py deleted file mode 100644 index 6ca42b9745fd..000000000000 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ /dev/null @@ -1,381 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
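A small worked example (made-up sizes, for illustration only) of the ControlNet residual indexing in the SD3 forward above, where `interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)`:

# Hypothetical sizes: 24 transformer blocks receiving 6 ControlNet residuals.
num_blocks = 24
num_controlnet_samples = 6
interval_control = num_blocks // num_controlnet_samples  # 4

# index_block -> residual index, mirroring `index_block // interval_control`
mapping = [index_block // interval_control for index_block in range(num_blocks)]
# [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]
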
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, Optional - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput -from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock -from ..embeddings import TimestepEmbedding, Timesteps -from ..modeling_utils import ModelMixin -from ..resnet import AlphaBlender - - -@dataclass -class TransformerTemporalModelOutput(BaseOutput): - """ - The output of [`TransformerTemporalModel`]. - - Args: - sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. - """ - - sample: torch.Tensor - - -class TransformerTemporalModel(ModelMixin, ConfigMixin): - """ - A Transformer model for video-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlock` attention should contain a bias parameter. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - activation_fn (`str`, *optional*, defaults to `"geglu"`): - Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported - activation functions. - norm_elementwise_affine (`bool`, *optional*): - Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. - double_self_attention (`bool`, *optional*): - Configure if each `TransformerBlock` should contain two self-attention layers. - positional_embeddings: (`str`, *optional*): - The type of positional embeddings to apply to the sequence input before passing use. - num_positional_embeddings: (`int`, *optional*): - The maximum length of the sequence over which to apply positional embeddings. 
- """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - activation_fn: str = "geglu", - norm_elementwise_affine: bool = True, - double_self_attention: bool = True, - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Linear(in_channels, inner_dim) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - double_self_attention=double_self_attention, - norm_elementwise_affine=norm_elementwise_affine, - positional_embeddings=positional_embeddings, - num_positional_embeddings=num_positional_embeddings, - ) - for d in range(num_layers) - ] - ) - - self.proj_out = nn.Linear(inner_dim, in_channels) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.LongTensor] = None, - timestep: Optional[torch.LongTensor] = None, - class_labels: torch.LongTensor = None, - num_frames: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> TransformerTemporalModelOutput: - """ - The [`TransformerTemporal`] forward method. - - Args: - hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): - Input hidden_states. - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - num_frames (`int`, *optional*, defaults to 1): - The number of frames to be processed per batch. This is used to reshape the hidden states. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] - instead of a plain tuple. 
- - Returns: - [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: - If `return_dict` is True, an - [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - # 1. Input - batch_frames, channel, height, width = hidden_states.shape - batch_size = batch_frames // num_frames - - residual = hidden_states - - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) - hidden_states = hidden_states.permute(0, 2, 1, 3, 4) - - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) - - hidden_states = self.proj_in(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states[None, None, :] - .reshape(batch_size, height, width, num_frames, channel) - .permute(0, 3, 4, 1, 2) - .contiguous() - ) - hidden_states = hidden_states.reshape(batch_frames, channel, height, width) - - output = hidden_states + residual - - if not return_dict: - return (output,) - - return TransformerTemporalModelOutput(sample=output) - - -class TransformerSpatioTemporalModel(nn.Module): - """ - A Transformer model for video-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - out_channels (`int`, *optional*): - The number of channels in the output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - """ - - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: int = 320, - out_channels: Optional[int] = None, - num_layers: int = 1, - cross_attention_dim: Optional[int] = None, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - - inner_dim = num_attention_heads * attention_head_dim - self.inner_dim = inner_dim - - # 2. Define input layers - self.in_channels = in_channels - self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) - self.proj_in = nn.Linear(in_channels, inner_dim) - - # 3. 
Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - cross_attention_dim=cross_attention_dim, - ) - for d in range(num_layers) - ] - ) - - time_mix_inner_dim = inner_dim - self.temporal_transformer_blocks = nn.ModuleList( - [ - TemporalBasicTransformerBlock( - inner_dim, - time_mix_inner_dim, - num_attention_heads, - attention_head_dim, - cross_attention_dim=cross_attention_dim, - ) - for _ in range(num_layers) - ] - ) - - time_embed_dim = in_channels * 4 - self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) - self.time_proj = Timesteps(in_channels, True, 0) - self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - # TODO: should use out_channels for continuous projections - self.proj_out = nn.Linear(inner_dim, in_channels) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - image_only_indicator: Optional[torch.Tensor] = None, - return_dict: bool = True, - ): - """ - Args: - hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): - Input hidden_states. - num_frames (`int`): - The number of frames to be processed per batch. This is used to reshape the hidden states. - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): - A tensor indicating whether the input contains only images. 1 indicates that the input contains only - images, 0 indicates that the input contains video frames. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] - instead of a plain tuple. - - Returns: - [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: - If `return_dict` is True, an - [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. - """ - # 1. 
Input - batch_frames, _, height, width = hidden_states.shape - num_frames = image_only_indicator.shape[-1] - batch_size = batch_frames // num_frames - - time_context = encoder_hidden_states - time_context_first_timestep = time_context[None, :].reshape( - batch_size, num_frames, -1, time_context.shape[-1] - )[:, 0] - time_context = time_context_first_timestep[:, None].broadcast_to( - batch_size, height * width, time_context.shape[-2], time_context.shape[-1] - ) - time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1]) - - residual = hidden_states - - hidden_states = self.norm(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) - - num_frames_emb = torch.arange(num_frames, device=hidden_states.device) - num_frames_emb = num_frames_emb.repeat(batch_size, 1) - num_frames_emb = num_frames_emb.reshape(-1) - t_emb = self.time_proj(num_frames_emb) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - - emb = self.time_pos_embed(t_emb) - emb = emb[:, None, :] - - # 2. Blocks - for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( - block, - hidden_states, - None, - encoder_hidden_states, - None, - use_reentrant=False, - ) - else: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - hidden_states_mix = hidden_states - hidden_states_mix = hidden_states_mix + emb - - hidden_states_mix = temporal_block( - hidden_states_mix, - num_frames=num_frames, - encoder_hidden_states=time_context, - ) - hidden_states = self.time_mixer( - x_spatial=hidden_states, - x_temporal=hidden_states_mix, - image_only_indicator=image_only_indicator, - ) - - # 3. Output - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - - if not return_dict: - return (output,) - - return TransformerTemporalModelOutput(sample=output) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 572cf87172a7..c815b13e6d19 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -594,38 +594,38 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - - # 3. Prepare text embeddings - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - max_sequence_length=max_sequence_length, - device=device, - ) - # 4. 
Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - prompt_embeds.dtype, - device, - generator, - latents, - ) + with torch.autocast("cuda", torch.float32): + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) From 0e8f20db4652c01c6d407eec8535d5bf749d238b Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 06:27:38 +0100 Subject: [PATCH 26/47] update --- src/diffusers/models/transformers/__init__.py | 22 + .../transformers/auraflow_transformer_2d.py | 544 +++++++++++++++++ .../transformers/cogvideox_transformer_3d.py | 507 +++++++++++++++ .../models/transformers/dit_transformer_2d.py | 240 ++++++++ .../transformers/dual_transformer_2d.py | 156 +++++ .../transformers/hunyuan_transformer_2d.py | 578 ++++++++++++++++++ .../transformers/latte_transformer_3d.py | 327 ++++++++++ .../models/transformers/lumina_nextdit2d.py | 340 +++++++++++ .../transformers/pixart_transformer_2d.py | 445 ++++++++++++++ .../models/transformers/prior_transformer.py | 380 ++++++++++++ .../transformers/stable_audio_transformer.py | 458 ++++++++++++++ .../transformers/t5_film_transformer.py | 436 +++++++++++++ .../models/transformers/transformer_2d.py | 566 +++++++++++++++++ .../transformers/transformer_allegro.py | 422 +++++++++++++ .../transformers/transformer_cogview3plus.py | 386 ++++++++++++ .../models/transformers/transformer_flux.py | 577 +++++++++++++++++ .../models/transformers/transformer_mochi.py | 568 +++++++++++++++++ .../models/transformers/transformer_sd3.py | 373 +++++++++++ .../transformers/transformer_temporal.py | 381 ++++++++++++ 19 files changed, 7706 insertions(+) create mode 100644 src/diffusers/models/transformers/__init__.py create mode 100644 src/diffusers/models/transformers/auraflow_transformer_2d.py create mode 100644 src/diffusers/models/transformers/cogvideox_transformer_3d.py create mode 100644 src/diffusers/models/transformers/dit_transformer_2d.py create mode 100644 src/diffusers/models/transformers/dual_transformer_2d.py create mode 100644 src/diffusers/models/transformers/hunyuan_transformer_2d.py create mode 100644 src/diffusers/models/transformers/latte_transformer_3d.py create mode 100644 src/diffusers/models/transformers/lumina_nextdit2d.py create mode 100644 src/diffusers/models/transformers/pixart_transformer_2d.py create mode 100644 src/diffusers/models/transformers/prior_transformer.py create mode 100644 src/diffusers/models/transformers/stable_audio_transformer.py create mode 100644 
src/diffusers/models/transformers/t5_film_transformer.py create mode 100644 src/diffusers/models/transformers/transformer_2d.py create mode 100644 src/diffusers/models/transformers/transformer_allegro.py create mode 100644 src/diffusers/models/transformers/transformer_cogview3plus.py create mode 100644 src/diffusers/models/transformers/transformer_flux.py create mode 100644 src/diffusers/models/transformers/transformer_mochi.py create mode 100644 src/diffusers/models/transformers/transformer_sd3.py create mode 100644 src/diffusers/models/transformers/transformer_temporal.py diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py new file mode 100644 index 000000000000..a2c087d708a4 --- /dev/null +++ b/src/diffusers/models/transformers/__init__.py @@ -0,0 +1,22 @@ +from ...utils import is_torch_available + + +if is_torch_available(): + from .auraflow_transformer_2d import AuraFlowTransformer2DModel + from .cogvideox_transformer_3d import CogVideoXTransformer3DModel + from .dit_transformer_2d import DiTTransformer2DModel + from .dual_transformer_2d import DualTransformer2DModel + from .hunyuan_transformer_2d import HunyuanDiT2DModel + from .latte_transformer_3d import LatteTransformer3DModel + from .lumina_nextdit2d import LuminaNextDiT2DModel + from .pixart_transformer_2d import PixArtTransformer2DModel + from .prior_transformer import PriorTransformer + from .stable_audio_transformer import StableAudioDiTModel + from .t5_film_transformer import T5FilmDecoder + from .transformer_2d import Transformer2DModel + from .transformer_allegro import AllegroTransformer3DModel + from .transformer_cogview3plus import CogView3PlusTransformer2DModel + from .transformer_flux import FluxTransformer2DModel + from .transformer_mochi import MochiTransformer3DModel + from .transformer_sd3 import SD3Transformer2DModel + from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py new file mode 100644 index 000000000000..b3f29e6b6224 --- /dev/null +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -0,0 +1,544 @@ +# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
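Note on the new `__init__.py` hunk above: it only re-exports the transformer classes behind an `is_torch_available()` guard, so all of the models added in this patch become importable from `diffusers.models.transformers`. A rough usage sketch follows (sizes are illustrative, not the shipped defaults; `DiTTransformer2DModel` and its `forward` signature appear later in this patch, and this snippet has not been run against this exact revision):

from diffusers.utils import is_torch_available

if is_torch_available():
    import torch
    from diffusers.models.transformers import DiTTransformer2DModel

    # Tiny, illustrative configuration; the shipped defaults are much larger
    # (28 layers, 16 heads, head dim 72 -- see dit_transformer_2d.py below).
    model = DiTTransformer2DModel(
        num_attention_heads=2,
        attention_head_dim=8,
        in_channels=4,
        num_layers=1,
        sample_size=8,
        patch_size=2,
    )
    latents = torch.randn(1, 4, 8, 8)          # (batch, channels, height, width)
    timestep = torch.tensor([10])
    class_labels = torch.tensor([0])           # required by the ada_norm_zero blocks
    sample = model(latents, timestep=timestep, class_labels=class_labels).sample
    print(sample.shape)                        # expected: torch.Size([1, 4, 8, 8])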
+ + +from typing import Any, Dict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_processor import ( + Attention, + AttentionProcessor, + AuraFlowAttnProcessor2_0, + FusedAuraFlowAttnProcessor2_0, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormZero, FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Taken from the original aura flow inference code. +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +# Aura Flow patch embed doesn't use convs for projections. +# Additionally, it uses learned positional embeddings. +class AuraFlowPatchEmbed(nn.Module): + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + pos_embed_max_size=None, + ): + super().__init__() + + self.num_patches = (height // patch_size) * (width // patch_size) + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) + self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1) + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + + def pe_selection_index_based_on_dim(self, h, w): + # select subset of positional embedding based on H, W, where H, W is size of latent + # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected + # because original input are in flattened format, we have to flatten this 2d grid as well. + h_p, w_p = h // self.patch_size, w // self.patch_size + original_pe_indexes = torch.arange(self.pos_embed.shape[1]) + h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5) + original_pe_indexes = original_pe_indexes.view(h_max, w_max) + starth = h_max // 2 - h_p // 2 + endh = starth + h_p + startw = w_max // 2 - w_p // 2 + endw = startw + w_p + original_pe_indexes = original_pe_indexes[starth:endh, startw:endw] + return original_pe_indexes.flatten() + + def forward(self, latent): + batch_size, num_channels, height, width = latent.size() + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + pe_index = self.pe_selection_index_based_on_dim(height, width) + return latent + self.pos_embed[:, pe_index] + + +# Taken from the original Aura flow inference code. +# Our feedforward only has GELU but Aura uses SiLU. 
+class AuraFlowFeedForward(nn.Module): + def __init__(self, dim, hidden_dim=None) -> None: + super().__init__() + if hidden_dim is None: + hidden_dim = 4 * dim + + final_hidden_dim = int(2 * hidden_dim / 3) + final_hidden_dim = find_multiple(final_hidden_dim, 256) + + self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False) + self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False) + self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.linear_1(x)) * self.linear_2(x) + x = self.out_projection(x) + return x + + +class AuraFlowPreFinalBlock(nn.Module): + def __init__(self, embedding_dim: int, conditioning_embedding_dim: int): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = x * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +@maybe_allow_in_graph +class AuraFlowSingleTransformerBlock(nn.Module): + """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT.""" + + def __init__(self, dim, num_attention_heads, attention_head_dim): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + + processor = AuraFlowAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="fp32_layer_norm", + out_dim=dim, + bias=False, + out_bias=False, + processor=processor, + ) + + self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff = AuraFlowFeedForward(dim, dim * 4) + + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor): + residual = hidden_states + + # Norm + Projection. + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + # Attention. + attn_output = self.attn(hidden_states=norm_hidden_states) + + # Process attention outputs for the `hidden_states`. + hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) + hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(hidden_states) + hidden_states = gate_mlp.unsqueeze(1) * ff_output + hidden_states = residual + hidden_states + + return hidden_states + + +@maybe_allow_in_graph +class AuraFlowJointTransformerBlock(nn.Module): + r""" + Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive): + + * QK Norm in the attention blocks + * No bias in the attention blocks + * Most LayerNorms are in FP32 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + is_last (`bool`): Boolean to determine if this is the last block in the model. 
+ """ + + def __init__(self, dim, num_attention_heads, attention_head_dim): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm") + + processor = AuraFlowAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + added_proj_bias=False, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="fp32_layer_norm", + out_dim=dim, + bias=False, + out_bias=False, + processor=processor, + context_pre_only=False, + ) + + self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff = AuraFlowFeedForward(dim, dim * 4) + self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False) + self.ff_context = AuraFlowFeedForward(dim, dim * 4) + + def forward( + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + ): + residual = hidden_states + residual_context = encoder_hidden_states + + # Norm + Projection. + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + # Process attention outputs for the `hidden_states`. + hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) + hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states) + hidden_states = residual + hidden_states + + # Process attention outputs for the `encoder_hidden_states`. + encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output) + encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states) + encoder_hidden_states = residual_context + encoder_hidden_states + + return encoder_hidden_states, hidden_states + + +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use. + num_single_dit_layers (`int`, *optional*, defaults to 4): + The number of layers of Transformer blocks to use. These blocks use concatenated image and text + representations. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. + out_channels (`int`, defaults to 16): Number of output channels. 
+ pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents. + """ + + _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 64, + patch_size: int = 2, + in_channels: int = 4, + num_mmdit_layers: int = 4, + num_single_dit_layers: int = 32, + attention_head_dim: int = 256, + num_attention_heads: int = 12, + joint_attention_dim: int = 2048, + caption_projection_dim: int = 3072, + out_channels: int = 4, + pos_embed_max_size: int = 1024, + ): + super().__init__() + default_out_channels = in_channels + self.out_channels = out_channels if out_channels is not None else default_out_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = AuraFlowPatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.context_embedder = nn.Linear( + self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False + ) + self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True) + self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) + + self.joint_transformer_blocks = nn.ModuleList( + [ + AuraFlowJointTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_mmdit_layers) + ] + ) + self.single_transformer_blocks = nn.ModuleList( + [ + AuraFlowSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for _ in range(self.config.num_single_dit_layers) + ] + ) + + self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + # https://arxiv.org/abs/2309.16588 + # prevents artifacts in the attention maps + self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. 
+ + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAuraFlowAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + height, width = hidden_states.shape[-2:] + + # Apply patch embedding, timestep embedding, and project the caption embeddings. + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) + temb = self.time_step_proj(temb) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + encoder_hidden_states = torch.cat( + [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 + ) + + # MMDiT blocks. 
+ for index_block, block in enumerate(self.joint_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) + if len(self.single_transformer_blocks) > 0: + encoder_seq_len = encoder_hidden_states.size(1) + combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + combined_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + combined_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb) + + hidden_states = combined_hidden_states[:, encoder_seq_len:] + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + out_channels = self.config.out_channels + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py new file mode 100644 index 000000000000..01c54ef090bd --- /dev/null +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -0,0 +1,507 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
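Before the CogVideoX file, a short standalone sketch of the patchify/unpatchify bookkeeping used by AuraFlow above (the view/permute in `AuraFlowPatchEmbed.forward` and the reshape/einsum at the end of `AuraFlowTransformer2DModel.forward`). It checks shapes only; the two ends use different intra-patch orderings, which is harmless because the learned `proj` / `proj_out` linears sit between them:

import torch

def patchify(latent, p):
    # Mirrors AuraFlowPatchEmbed.forward: (B, C, H, W) -> (B, (H/p)*(W/p), p*p*C).
    b, c, h, w = latent.shape
    latent = latent.view(b, c, h // p, p, w // p, p)
    latent = latent.permute(0, 2, 4, 1, 3, 5)        # (B, H/p, W/p, C, p, p)
    return latent.flatten(-3).flatten(1, 2)          # fold the patch dims, then the grid dims

def unpatchify(tokens, h, w, p, c):
    # Mirrors the unpatchify at the end of AuraFlowTransformer2DModel.forward:
    # (B, (H/p)*(W/p), p*p*C) -> (B, C, H, W).
    b = tokens.shape[0]
    tokens = tokens.reshape(b, h // p, w // p, p, p, c)
    tokens = torch.einsum("nhwpqc->nchpwq", tokens)
    return tokens.reshape(b, c, h, w)

x = torch.randn(1, 4, 8, 8)                # illustrative latent, patch_size=2
tokens = patchify(x, p=2)
assert tokens.shape == (1, 16, 16)         # 4x4 grid of patches, 2*2*4 values per token
y = unpatchify(tokens, h=8, w=8, p=2, c=4)
assert y.shape == x.shape                  # shapes round-trip; values are permuted within each
                                           # patch because the intra-patch layouts differ, which
                                           # the learned proj / proj_out layers absorb in the model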
+ +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import Attention, FeedForward +from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class CogVideoXBlock(nn.Module): + r""" + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float = 0.0, + activation_fn: str = "gelu-approximate", + attention_bias: bool = False, + qk_norm: bool = True, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = True, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=attention_bias, + out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0(), + ) + + # 2. 
Feed Forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_width (`int`, defaults to `90`): + The width of the input latents. + sample_height (`int`, defaults to `60`): + The height of the input latents. + sample_frames (`int`, defaults to `49`): + The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49 + instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings, + but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with + K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1). + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. 
+ temporal_compression_ratio (`int`, defaults to `4`): + The compression ratio across the temporal dimension. See documentation for `sample_frames`. + max_text_seq_length (`int`, defaults to `226`): + The maximum sequence length of the input text embeddings. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + timestep_activation_fn (`str`, defaults to `"silu"`): + Activation function to use when generating the timestep embeddings. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-5`): + The epsilon value to use in normalization layers. + spatial_interpolation_scale (`float`, defaults to `1.875`): + Scaling factor to apply in 3D positional embeddings across spatial dimensions. + temporal_interpolation_scale (`float`, defaults to `1.0`): + Scaling factor to apply in 3D positional embeddings across temporal dimensions. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + if not use_rotary_positional_embeddings and use_learned_positional_embeddings: + raise ValueError( + "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional " + "embeddings. If you're using a custom model and/or believe this should be supported, please open an " + "issue at https://github.com/huggingface/diffusers/issues." + ) + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + # 2. Time embeddings + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + + # 3. 
Define spatio-temporal transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) + + # 4. Output blocks + self.norm_out = AdaLayerNorm( + embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + chunk_dim=1, + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. 
For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # 3. 
Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py new file mode 100644 index 000000000000..f787c5279499 --- /dev/null +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -0,0 +1,240 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ..attention import BasicTransformerBlock +from ..embeddings import PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DiTTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748). 
+ + Parameters: + num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (int, optional, defaults to 72): The number of channels in each head. + in_channels (int, defaults to 4): The number of channels in the input. + out_channels (int, optional): + The number of channels in the output. Specify this parameter if the output channel number differs from the + input. + num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. + dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. + norm_num_groups (int, optional, defaults to 32): + Number of groups for group normalization within Transformer blocks. + attention_bias (bool, optional, defaults to True): + Configure if the Transformer blocks' attention should contain a bias parameter. + sample_size (int, defaults to 32): + The width of the latent images. This parameter is fixed during training. + patch_size (int, defaults to 2): + Size of the patches the model processes, relevant for architectures working on non-sequential data. + activation_fn (str, optional, defaults to "gelu-approximate"): + Activation function to use in feed-forward networks within Transformer blocks. + num_embeds_ada_norm (int, optional, defaults to 1000): + Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during + inference. + upcast_attention (bool, optional, defaults to False): + If true, upcasts the attention mechanism dimensions for potentially improved performance. + norm_type (str, optional, defaults to "ada_norm_zero"): + Specifies the type of normalization used, can be 'ada_norm_zero'. + norm_elementwise_affine (bool, optional, defaults to False): + If true, enables element-wise affine parameters in the normalization layers. + norm_eps (float, optional, defaults to 1e-5): + A small constant added to the denominator in normalization layers to prevent division by zero. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 72, + in_channels: int = 4, + out_channels: Optional[int] = None, + num_layers: int = 28, + dropout: float = 0.0, + norm_num_groups: int = 32, + attention_bias: bool = True, + sample_size: int = 32, + patch_size: int = 2, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: Optional[int] = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm_zero", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + ): + super().__init__() + + # Validate inputs. + if norm_type != "ada_norm_zero": + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # Set some common variables used across the board. + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.out_channels = in_channels if out_channels is None else out_channels + self.gradient_checkpointing = False + + # 2. Initialize the position embedding and transformer blocks. 
+ self.height = self.config.sample_size + self.width = self.config.sample_size + + self.patch_size = self.config.patch_size + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + ) + for _ in range(self.config.num_layers) + ] + ) + + # 3. Output blocks. + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + return_dict: bool = True, + ): + """ + The [`DiTTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states = self.pos_embed(hidden_states) + + # 2. 
Blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + None, + None, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/dual_transformer_2d.py b/src/diffusers/models/transformers/dual_transformer_2d.py new file mode 100644 index 000000000000..1c48c4e3db79 --- /dev/null +++ b/src/diffusers/models/transformers/dual_transformer_2d.py @@ -0,0 +1,156 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from torch import nn + +from ..modeling_outputs import Transformer2DModelOutput +from .transformer_2d import Transformer2DModel + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward( + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.Tensor`, *optional*): + Optional attention mask to be applied in Attention. 
+ cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py new file mode 100644 index 000000000000..7f3dab220aaa --- /dev/null +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -0,0 +1,578 @@ +# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 +from ..embeddings import ( + HunyuanCombinedTimestepTextSizeStyleEmbedding, + PatchEmbed, + PixArtAlphaTextProjection, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class AdaLayerNormShift(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. 
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
+        super().__init__()
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, embedding_dim)
+        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
+        x = self.norm(x) + shift.unsqueeze(dim=1)
+        return x
+
+
+@maybe_allow_in_graph
+class HunyuanDiTBlock(nn.Module):
+    r"""
+    Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allows skip connections
+    and QK normalization.
+
+    Parameters:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        cross_attention_dim (`int`, *optional*):
+            The size of the encoder_hidden_states vector for cross attention.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to be used in feed-forward.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, *optional*, defaults to 1e-6):
+            A small constant added to the denominator in normalization layers to prevent division by zero.
+        final_dropout (`bool`, *optional*, defaults to False):
+            Whether to apply a final dropout after the last feed-forward layer.
+        ff_inner_dim (`int`, *optional*):
+            The size of the hidden layer in the feed-forward block. Defaults to `None`.
+        ff_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use bias in the feed-forward block.
+        skip (`bool`, *optional*, defaults to `False`):
+            Whether to use a skip connection. Defaults to `False` for down-blocks and mid-blocks.
+        qk_norm (`bool`, *optional*, defaults to `True`):
+            Whether to use normalization in the QK calculation. Defaults to `True`.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        cross_attention_dim: int = 1024,
+        dropout=0.0,
+        activation_fn: str = "geglu",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-6,
+        final_dropout: bool = False,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        skip: bool = False,
+        qk_norm: bool = True,
+    ):
+        super().__init__()
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # NOTE: when a new version comes, check norm2 and norm3
+        # 1. Self-Attn
+        self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )
+
+        # 2. Cross-Attn
+        self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            dim_head=dim // num_attention_heads,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=True,
+            processor=HunyuanAttnProcessor2_0(),
+        )
+        # 3.
Feed-forward + self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, ### 0.0 + activation_fn=activation_fn, ### approx GeLU + final_dropout=final_dropout, ### 0.0 + inner_dim=ff_inner_dim, ### int(dim * mlp_ratio) + bias=ff_bias, + ) + + # 4. Skip Connection + if skip: + self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True) + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb=None, + skip=None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([hidden_states, skip], dim=-1) + cat = self.skip_norm(cat) + hidden_states = self.skip_linear(cat) + + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct + attn_output = self.attn1( + norm_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + attn_output + + # 2. Cross-Attention + hidden_states = hidden_states + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + # FFN Layer ### TODO: switch norm2 and norm3 in the state dict + mlp_inputs = self.norm3(hidden_states) + hidden_states = hidden_states + self.ff(mlp_inputs) + + return hidden_states + + +class HunyuanDiT2DModel(ModelMixin, ConfigMixin): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): + The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + patch_size (`int`, *optional*): + The size of the patch to use for the input. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. + sample_size (`int`, *optional*): + The width of the latent images. This is fixed during training since it is used to learn a number of + position embeddings. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The number of dimension in the clip text embedding. + hidden_size (`int`, *optional*): + The size of hidden layer in the conditioning embedding layers. + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden layer size to the input size. + learn_sigma (`bool`, *optional*, defaults to `True`): + Whether to predict variance. + cross_attention_dim_t5 (`int`, *optional*): + The number dimensions in t5 text embedding. 
+ pooled_projection_dim (`int`, *optional*): + The size of the pooled projection. + text_len (`int`, *optional*): + The length of the clip text embedding. + text_len_t5 (`int`, *optional*): + The length of the T5 text embedding. + use_style_cond_and_image_meta_size (`bool`, *optional*): + Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "gelu-approximate", + sample_size=32, + hidden_size=1152, + num_layers: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = True, + cross_attention_dim: int = 1024, + norm_type: str = "layer_norm", + cross_attention_dim_t5: int = 2048, + pooled_projection_dim: int = 1024, + text_len: int = 77, + text_len_t5: int = 256, + use_style_cond_and_image_meta_size: bool = True, + ): + super().__init__() + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.num_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + + self.text_embedder = PixArtAlphaTextProjection( + in_features=cross_attention_dim_t5, + hidden_size=cross_attention_dim_t5 * 4, + out_features=cross_attention_dim, + act_fn="silu_fp32", + ) + + self.text_embedding_padding = nn.Parameter( + torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32) + ) + + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + in_channels=in_channels, + embed_dim=hidden_size, + patch_size=patch_size, + pos_embed_type=None, + ) + + self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding( + hidden_size, + pooled_projection_dim=pooled_projection_dim, + seq_len=text_len_t5, + cross_attention_dim=cross_attention_dim_t5, + use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size, + ) + + # HunyuanDiT Blocks + self.blocks = nn.ModuleList( + [ + HunyuanDiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + activation_fn=activation_fn, + ff_inner_dim=int(self.inner_dim * mlp_ratio), + cross_attention_dim=cross_attention_dim, + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + skip=layer > num_layers // 2, + ) + for layer in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. 
+ + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedHunyuanAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
+ """ + self.set_attn_processor(HunyuanAttnProcessor2_0()) + + def forward( + self, + hidden_states, + timestep, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + image_rotary_emb=None, + controlnet_block_samples=None, + return_dict=True, + ): + """ + The [`HunyuanDiT2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of `BertModel`. + text_embedding_mask: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of `BertModel`. + encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder. + text_embedding_mask_t5: torch.Tensor + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output + of T5 Text Encoder. + image_meta_size (torch.Tensor): + Conditional embedding indicate the image sizes + style: torch.Tensor: + Conditional embedding indicate the style + image_rotary_emb (`torch.Tensor`): + The image rotary embeddings to apply on query and key tensors during attention calculation. + return_dict: bool + Whether to return a dictionary. + """ + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) + + temb = self.time_extra_emb( + timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype + ) # [B, D] + + # text projection + batch_size, sequence_length, _ = encoder_hidden_states_t5.shape + encoder_hidden_states_t5 = self.text_embedder( + encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) + ) + encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1) + + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1) + text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1) + text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() + + encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding) + + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.config.num_layers // 2: + if controlnet_block_samples is not None: + skip = skips.pop() + controlnet_block_samples.pop() + else: + skip = skips.pop() + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + skip=skip, + ) # (N, L, D) + else: + hidden_states = block( + hidden_states, + temb=temb, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) # (N, L, D) + + if layer < (self.config.num_layers // 2 - 1): + skips.append(hidden_states) + + if controlnet_block_samples is not None and len(controlnet_block_samples) != 0: + raise ValueError("The number of controls is not equal to the number of skip connections.") + + # final layer + hidden_states = self.norm_out(hidden_states, temb.to(torch.float32)) + hidden_states = self.proj_out(hidden_states) + # (N, L, patch_size ** 2 * out_channels) + 
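The unpatchify step just below rebuilds a latent image from the token sequence emitted by `proj_out`: every token carries a `patch_size x patch_size` pixel block for each output channel. A minimal stand-alone sketch of the same reshape/einsum, with sizes that are illustrative assumptions rather than values from any checkpoint:

    import torch

    # Illustrative sizes: 2 samples, a 16x16 grid of tokens, patch_size 2, 8 output channels.
    N, h, w, p, c = 2, 16, 16, 2, 8
    tokens = torch.randn(N, h * w, p * p * c)      # what proj_out emits: (N, L, patch_size**2 * out_channels)

    x = tokens.reshape(N, h, w, p, p, c)           # split the sequence back into a grid of patches
    x = torch.einsum("nhwpqc->nchpwq", x)          # channels first, patch pixels placed next to their grid cell
    image = x.reshape(N, c, h * p, w * p)          # (2, 8, 32, 32) latent-space image
    assert image.shape == (N, c, h * p, w * p)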
+ # unpatchify: (N, out_channels, H, W) + patch_size = self.pos_embed.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py new file mode 100644 index 000000000000..7e2b1273687d --- /dev/null +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -0,0 +1,327 @@ +# Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
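The `enable_forward_chunking`/`disable_forward_chunking` helpers defined at the end of the Hunyuan transformer above only record a chunk size and dimension on each block; during the forward pass the feed-forward is then applied to slices of the chosen dimension instead of the whole token tensor, trading a little speed for lower peak memory. A minimal stand-alone sketch of the idea (the module, sizes, and chunk count are illustrative assumptions, not the library's internal implementation):

    import torch
    import torch.nn as nn

    ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))   # stand-in feed-forward
    x = torch.randn(2, 4096, 64)                                            # (batch, tokens, dim)

    full = ff(x)                                                            # materializes a (2, 4096, 256) activation
    chunked = torch.cat([ff(part) for part in x.chunk(8, dim=1)], dim=1)    # eight slices of 512 tokens each
    assert torch.allclose(full, chunked, atol=1e-6)                         # same result, smaller intermediates

On the model itself this would be turned on with something like `transformer.enable_forward_chunking(chunk_size=512, dim=1)`; the chunk size here is an assumption and only affects the speed/memory trade-off, not the output.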
+from typing import Optional + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid +from ..attention import BasicTransformerBlock +from ..embeddings import PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +class LatteTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code: + https://github.com/Vchitect/Latte + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input. + out_channels (`int`, *optional*): + The number of channels in the output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + patch_size (`int`, *optional*): + The size of the patches to use in the patch embedding layer. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. During inference, you can denoise for up to but not more steps than + `num_embeds_ada_norm`. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. + caption_channels (`int`, *optional*): + The number of channels in the caption embeddings. + video_length (`int`, *optional*): + The number of frames in the video-like data. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: int = 64, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + caption_channels: int = None, + video_length: int = 16, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + # 1. 
Define input layers + self.height = sample_size + self.width = sample_size + + interpolation_scale = self.config.sample_size // 64 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 2. Define spatial transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for d in range(num_layers) + ] + ) + + # 3. Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. Latte other blocks. + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + # define temporal positional embedding + temp_pos_embed = get_1d_sincos_pos_embed_from_grid( + inner_dim, torch.arange(0, video_length).unsqueeze(1) + ) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`LatteTransformer3DModel`] forward method. + + Args: + hidden_states shape `(batch size, channel, num_frame, height, width)`: + Input `hidden_states`. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batcheight, sequence_length)` True = keep, False = discard. + * Bias `(batcheight, 1, sequence_length)` 0 = keep, -10000 = discard. 
+ + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + enable_temporal_attentions: + (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + # Reshape hidden states + batch_size, channels, num_frame, height, width = hidden_states.shape + # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) + + # Input + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, + ) + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # Prepare text embeddings for spatial block + # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 + encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view( + -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] + ) + + # Prepare timesteps for spatial and temporal block + timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1]) + timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1]) + + # Spatial and temporal transformer blocks + for i, (spatial_block, temp_block) in enumerate( + zip(self.transformer_blocks, self.temporal_transformer_blocks) + ): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + None, # attention_mask + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + None, # cross_attention_kwargs + None, # class_labels + use_reentrant=False, + ) + else: + hidden_states = spatial_block( + hidden_states, + None, # attention_mask + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + None, # cross_attention_kwargs + None, # class_labels + ) + + if enable_temporal_attentions: + # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size + hidden_states = hidden_states.reshape( + batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] + ).permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) + + if i == 0 and num_frame > 1: + hidden_states = hidden_states + self.temp_pos_embed + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + None, # cross_attention_kwargs + None, # class_labels + use_reentrant=False, + 
) + else: + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + None, # cross_attention_kwargs + None, # class_labels + ) + + # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size + hidden_states = hidden_states.reshape( + batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] + ).permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) + + embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1]) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) + ) + output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute( + 0, 2, 1, 3, 4 + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py new file mode 100644 index 000000000000..d4f5b4658542 --- /dev/null +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -0,0 +1,340 @@ +# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import LuminaFeedForward +from ..attention_processor import Attention, LuminaAttnProcessor2_0 +from ..embeddings import ( + LuminaCombinedTimestepCaptionEmbedding, + LuminaPatchEmbed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LuminaNextDiTBlock(nn.Module): + """ + A LuminaNextDiTBlock for LuminaNextDiT2DModel. + + Parameters: + dim (`int`): Embedding dimension of the input features. + num_attention_heads (`int`): Number of attention heads. + num_kv_heads (`int`): + Number of attention heads in key and value features (if using GQA), or set to None for the same as query. 
+        multiple_of (`int`): The multiple-of factor for the feed-forward hidden dimension.
+        ffn_dim_multiplier (`float`): The multiplier factor for the feed-forward layer dimension.
+        norm_eps (`float`): The epsilon used by the normalization layers.
+        qk_norm (`bool`): Whether to apply normalization to the query and key.
+        cross_attention_dim (`int`): The embedding dimension of the text prompt hidden_states used for cross attention.
+        norm_elementwise_affine (`bool`, *optional*, defaults to True):
+            Whether to use learnable elementwise affine parameters in the normalization layers.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        num_kv_heads: int,
+        multiple_of: int,
+        ffn_dim_multiplier: float,
+        norm_eps: float,
+        qk_norm: bool,
+        cross_attention_dim: int,
+        norm_elementwise_affine: bool = True,
+    ) -> None:
+        super().__init__()
+        self.head_dim = dim // num_attention_heads
+
+        self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
+
+        # Self-attention
+        self.attn1 = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=dim // num_attention_heads,
+            qk_norm="layer_norm_across_heads" if qk_norm else None,
+            heads=num_attention_heads,
+            kv_heads=num_kv_heads,
+            eps=1e-5,
+            bias=False,
+            out_bias=False,
+            processor=LuminaAttnProcessor2_0(),
+        )
+        self.attn1.to_out = nn.Identity()
+
+        # Cross-attention
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            dim_head=dim // num_attention_heads,
+            qk_norm="layer_norm_across_heads" if qk_norm else None,
+            heads=num_attention_heads,
+            kv_heads=num_kv_heads,
+            eps=1e-5,
+            bias=False,
+            out_bias=False,
+            processor=LuminaAttnProcessor2_0(),
+        )
+
+        self.feed_forward = LuminaFeedForward(
+            dim=dim,
+            inner_dim=4 * dim,
+            multiple_of=multiple_of,
+            ffn_dim_multiplier=ffn_dim_multiplier,
+        )
+
+        self.norm1 = LuminaRMSNormZero(
+            embedding_dim=dim,
+            norm_eps=norm_eps,
+            norm_elementwise_affine=norm_elementwise_affine,
+        )
+        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+        self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+        self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        temb: torch.Tensor,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Perform a forward pass through the LuminaNextDiTBlock.
+
+        Parameters:
+            hidden_states (`torch.Tensor`): The input hidden_states for the LuminaNextDiTBlock.
+            attention_mask (`torch.Tensor`): The attention mask corresponding to `hidden_states`.
+            image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
+            encoder_hidden_states (`torch.Tensor`): The text prompt hidden_states produced by the Gemma encoder.
+            encoder_mask (`torch.Tensor`): The attention mask for the text prompt hidden_states.
+            temb (`torch.Tensor`): Timestep embedding combined with the text prompt embedding.
+            cross_attention_kwargs (`Dict[str, Any]`): Additional kwargs for the attention processors.
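In the block body that follows, the cross-attention branch is scaled by a learned per-head `tanh` gate before being added to the self-attention branch; because the gate is zero-initialized, the block starts out ignoring the text tokens and learns how much of them to mix in. A small stand-alone sketch of that gating (head count and shapes are illustrative assumptions):

    import torch

    heads, batch, tokens, head_dim = 4, 2, 16, 8
    gate = torch.zeros(heads)                                     # zero-initialized learned parameter
    self_attn = torch.randn(batch, tokens, heads, head_dim)
    cross_attn = torch.randn(batch, tokens, heads, head_dim)

    mixed = self_attn + cross_attn * gate.tanh().view(1, 1, -1, 1)
    assert torch.allclose(mixed, self_attn)                       # tanh(0) == 0, so text is ignored at init
    mixed = mixed.flatten(-2)                                     # (batch, tokens, heads * head_dim) before to_out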
+ """ + residual = hidden_states + + # Self-attention + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + self_attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + query_rotary_emb=image_rotary_emb, + key_rotary_emb=image_rotary_emb, + **cross_attention_kwargs, + ) + + # Cross-attention + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + cross_attn_output = self.attn2( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=encoder_mask, + query_rotary_emb=image_rotary_emb, + key_rotary_emb=None, + **cross_attention_kwargs, + ) + cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1) + mixed_attn_output = self_attn_output + cross_attn_output + mixed_attn_output = mixed_attn_output.flatten(-2) + # linear proj + hidden_states = self.attn2.to_out[0](mixed_attn_output) + + hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states) + + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) + + return hidden_states + + +class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): + """ + LuminaNextDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2): + The size of each patch in the image. This parameter defines the resolution of patches fed into the model. + in_channels (`int`, *optional*, defaults to 4): + The number of input channels for the model. Typically, this matches the number of channels in the input + images. + hidden_size (`int`, *optional*, defaults to 4096): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + num_layers (`int`, *optional*, default to 32): + The number of layers in the model. This defines the depth of the neural network. + num_attention_heads (`int`, *optional*, defaults to 32): + The number of attention heads in each attention layer. This parameter specifies how many separate attention + mechanisms are used. + num_kv_heads (`int`, *optional*, defaults to 8): + The number of key-value heads in the attention mechanism, if different from the number of attention heads. + If None, it defaults to num_attention_heads. + multiple_of (`int`, *optional*, defaults to 256): + A factor that the hidden size should be a multiple of. This can help optimize certain hardware + configurations. + ffn_dim_multiplier (`float`, *optional*): + A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on + the model configuration. + norm_eps (`float`, *optional*, defaults to 1e-5): + A small value added to the denominator for numerical stability in normalization layers. + learn_sigma (`bool`, *optional*, defaults to True): + Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in + predictions. + qk_norm (`bool`, *optional*, defaults to True): + Indicates if the queries and keys in the attention mechanism should be normalized. 
+ cross_attention_dim (`int`, *optional*, defaults to 2048): + The dimensionality of the text embeddings. This parameter defines the size of the text representations used + in the model. + scaling_factor (`float`, *optional*, defaults to 1.0): + A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the + overall scale of the model's operations. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: Optional[int] = 2, + in_channels: Optional[int] = 4, + hidden_size: Optional[int] = 2304, + num_layers: Optional[int] = 32, + num_attention_heads: Optional[int] = 32, + num_kv_heads: Optional[int] = None, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: Optional[float] = 1e-5, + learn_sigma: Optional[bool] = True, + qk_norm: Optional[bool] = True, + cross_attention_dim: Optional[int] = 2048, + scaling_factor: Optional[float] = 1.0, + ) -> None: + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.scaling_factor = scaling_factor + + self.patch_embedder = LuminaPatchEmbed( + patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True + ) + + self.pad_token = nn.Parameter(torch.empty(hidden_size)) + + self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding( + hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim + ) + + self.layers = nn.ModuleList( + [ + LuminaNextDiTBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels) + + assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4" + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + cross_attention_kwargs: Dict[str, Any] = None, + return_dict=True, + ) -> torch.Tensor: + """ + Forward pass of LuminaNextDiT. + + Parameters: + hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W). + timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,). + encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D). + encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). 
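One detail worth keeping in mind when consuming this model's output: with `learn_sigma=True` the returned sample has `out_channels = 2 * in_channels`, the usual DiT convention of stacking the noise prediction on top of a predicted variance. A hedged sketch of how a caller typically splits the two halves (the shapes are illustrative assumptions; whether the variance half is actually used depends on the scheduler):

    import torch

    in_channels = 4
    sample = torch.randn(1, 2 * in_channels, 128, 128)    # stand-in for Transformer2DModelOutput.sample
    noise_pred, variance_pred = sample.chunk(2, dim=1)    # first half: predicted noise; second half: variance
    assert noise_pred.shape == (1, in_channels, 128, 128)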
+ """ + hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb) + image_rotary_emb = image_rotary_emb.to(hidden_states.device) + + temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask) + + encoder_mask = encoder_mask.bool() + for layer in self.layers: + hidden_states = layer( + hidden_states, + mask, + image_rotary_emb, + encoder_hidden_states, + encoder_mask, + temb=temb, + cross_attention_kwargs=cross_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + + # unpatchify + height_tokens = width_tokens = self.patch_size + height, width = img_size[0] + batch_size = hidden_states.size(0) + sequence_length = (height // height_tokens) * (width // width_tokens) + hidden_states = hidden_states[:, :sequence_length].view( + batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels + ) + output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py new file mode 100644 index 000000000000..7f145edf16fb --- /dev/null +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -0,0 +1,445 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ..attention import BasicTransformerBlock +from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PixArtTransformer2DModel(ModelMixin, ConfigMixin): + r""" + A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426, + https://arxiv.org/abs/2403.04692). + + Parameters: + num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (int, optional, defaults to 72): The number of channels in each head. + in_channels (int, defaults to 4): The number of channels in the input. + out_channels (int, optional): + The number of channels in the output. Specify this parameter if the output channel number differs from the + input. + num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. + dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. 
+ norm_num_groups (int, optional, defaults to 32): + Number of groups for group normalization within Transformer blocks. + cross_attention_dim (int, optional): + The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension. + attention_bias (bool, optional, defaults to True): + Configure if the Transformer blocks' attention should contain a bias parameter. + sample_size (int, defaults to 128): + The width of the latent images. This parameter is fixed during training. + patch_size (int, defaults to 2): + Size of the patches the model processes, relevant for architectures working on non-sequential data. + activation_fn (str, optional, defaults to "gelu-approximate"): + Activation function to use in feed-forward networks within Transformer blocks. + num_embeds_ada_norm (int, optional, defaults to 1000): + Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during + inference. + upcast_attention (bool, optional, defaults to False): + If true, upcasts the attention mechanism dimensions for potentially improved performance. + norm_type (str, optional, defaults to "ada_norm_zero"): + Specifies the type of normalization used, can be 'ada_norm_zero'. + norm_elementwise_affine (bool, optional, defaults to False): + If true, enables element-wise affine parameters in the normalization layers. + norm_eps (float, optional, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. + interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings. + use_additional_conditions (bool, optional): If we're using additional conditions as inputs. + attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used. + caption_channels (int, optional, defaults to None): + Number of channels to use for projecting the caption embeddings. + use_linear_projection (bool, optional, defaults to False): + Deprecated argument. Will be removed in a future version. + num_vector_embeds (bool, optional, defaults to False): + Deprecated argument. Will be removed in a future version. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 72, + in_channels: int = 4, + out_channels: Optional[int] = 8, + num_layers: int = 28, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = 1152, + attention_bias: bool = True, + sample_size: int = 128, + patch_size: int = 2, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: Optional[int] = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm_single", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + interpolation_scale: Optional[int] = None, + use_additional_conditions: Optional[bool] = None, + caption_channels: Optional[int] = None, + attention_type: Optional[str] = "default", + ): + super().__init__() + + # Validate inputs. + if norm_type != "ada_norm_single": + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." 
+ ) + + # Set some common variables used across the board. + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.out_channels = in_channels if out_channels is None else out_channels + if use_additional_conditions is None: + if sample_size == 128: + use_additional_conditions = True + else: + use_additional_conditions = False + self.use_additional_conditions = use_additional_conditions + + self.gradient_checkpointing = False + + # 2. Initialize the position embedding and transformer blocks. + self.height = self.config.sample_size + self.width = self.config.sample_size + + interpolation_scale = ( + self.config.interpolation_scale + if self.config.interpolation_scale is not None + else max(self.config.sample_size // 64, 1) + ) + self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + # 3. Output blocks. + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) + + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + self.caption_projection = None + if self.config.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.config.caption_channels, hidden_size=self.inner_dim + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + + Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model. + """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`PixArtTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep (`torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size = hidden_states.shape[0] + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, + ) + hidden_states = self.pos_embed(hidden_states) + + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + # 2. Blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + None, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=None, + ) + + # 3. 
Output + shift, scale = ( + self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) + ).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py new file mode 100644 index 000000000000..fdb67384ff5e --- /dev/null +++ b/src/diffusers/models/transformers/prior_transformer.py @@ -0,0 +1,380 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin + + +@dataclass +class PriorTransformerOutput(BaseOutput): + """ + The output of [`PriorTransformer`]. + + Args: + predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`): + The predicted CLIP image embedding conditioned on the CLIP text embedding input. + """ + + predicted_image_embedding: torch.Tensor + + +class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + """ + A Prior Transformer model. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` + num_embeddings (`int`, *optional*, defaults to 77): + The number of embeddings of the model input `hidden_states` + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + + additional_embeddings`. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + time_embed_act_fn (`str`, *optional*, defaults to 'silu'): + The activation function to use to create timestep embeddings. + norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before + passing to Transformer blocks. Set it to `None` if normalization is not needed. + embedding_proj_norm_type (`str`, *optional*, defaults to None): + The normalization layer to apply on the input `proj_embedding`. 
Set it to `None` if normalization is not + needed. + encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): + The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if + `encoder_hidden_states` is `None`. + added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. + Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot + product between the text embedding and image embedding as proposed in the unclip paper + https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended. + time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings. + If None, will be set to `num_attention_heads * attention_head_dim` + embedding_proj_dim (`int`, *optional*, default to None): + The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. + clip_embed_dim (`int`, *optional*, default to None): + The dimension of the output. If None, will be set to `embedding_dim`. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 32, + attention_head_dim: int = 64, + num_layers: int = 20, + embedding_dim: int = 768, + num_embeddings=77, + additional_embeddings=4, + dropout: float = 0.0, + time_embed_act_fn: str = "silu", + norm_in_type: Optional[str] = None, # layer + embedding_proj_norm_type: Optional[str] = None, # layer + encoder_hid_proj_type: Optional[str] = "linear", # linear + added_emb_type: Optional[str] = "prd", # prd + time_embed_dim: Optional[int] = None, + embedding_proj_dim: Optional[int] = None, + clip_embed_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.additional_embeddings = additional_embeddings + + time_embed_dim = time_embed_dim or inner_dim + embedding_proj_dim = embedding_proj_dim or embedding_dim + clip_embed_dim = clip_embed_dim or embedding_dim + + self.time_proj = Timesteps(inner_dim, True, 0) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) + + self.proj_in = nn.Linear(embedding_dim, inner_dim) + + if embedding_proj_norm_type is None: + self.embedding_proj_norm = None + elif embedding_proj_norm_type == "layer": + self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) + else: + raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") + + self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) + + if encoder_hid_proj_type is None: + self.encoder_hidden_states_proj = None + elif encoder_hid_proj_type == "linear": + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + else: + raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") + + self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) + + if added_emb_type == "prd": + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + elif added_emb_type is None: + self.prd_embedding = None + else: + raise ValueError( + f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." 
+ ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn="gelu", + attention_bias=True, + ) + for d in range(num_layers) + ] + ) + + if norm_in_type == "layer": + self.norm_in = nn.LayerNorm(inner_dim) + elif norm_in_type is None: + self.norm_in = None + else: + raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") + + self.norm_out = nn.LayerNorm(inner_dim) + + self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) + + causal_attention_mask = torch.full( + [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 + ) + causal_attention_mask.triu_(1) + causal_attention_mask = causal_attention_mask[None, ...] + self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) + + self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def forward( + self, + hidden_states, + timestep: Union[torch.Tensor, float, int], + proj_embedding: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + return_dict: bool = True, + ): + """ + The [`PriorTransformer`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`): + The currently predicted image embeddings. + timestep (`torch.LongTensor`): + Current denoising step. + proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`): + Projected embedding vector the denoising process is conditioned on. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`): + Hidden states of the text embeddings the denoising process is conditioned on. + attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): + Text mask for the text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformers.prior_transformer.PriorTransformerOutput`] instead of + a plain tuple. + + Returns: + [`~models.transformers.prior_transformer.PriorTransformerOutput`] or `tuple`: + If return_dict is True, a [`~models.transformers.prior_transformer.PriorTransformerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + batch_size = hidden_states.shape[0] + + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) + + timesteps_projected = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. 
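+        # Shape note: `timesteps_projected` is the (batch_size, inner_dim) float32 sinusoidal embedding produced
+        # by `self.time_proj`; it is cast to the module dtype before being passed through the `TimestepEmbedding`
+        # MLP below.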
+ timesteps_projected = timesteps_projected.to(dtype=self.dtype) + time_embeddings = self.time_embedding(timesteps_projected) + + if self.embedding_proj_norm is not None: + proj_embedding = self.embedding_proj_norm(proj_embedding) + + proj_embeddings = self.embedding_proj(proj_embedding) + if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") + + hidden_states = self.proj_in(hidden_states) + + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) + + additional_embeds = [] + additional_embeddings_len = 0 + + if encoder_hidden_states is not None: + additional_embeds.append(encoder_hidden_states) + additional_embeddings_len += encoder_hidden_states.shape[1] + + if len(proj_embeddings.shape) == 2: + proj_embeddings = proj_embeddings[:, None, :] + + if len(hidden_states.shape) == 2: + hidden_states = hidden_states[:, None, :] + + additional_embeds = additional_embeds + [ + proj_embeddings, + time_embeddings[:, None, :], + hidden_states, + ] + + if self.prd_embedding is not None: + prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + additional_embeds.append(prd_embedding) + + hidden_states = torch.cat( + additional_embeds, + dim=1, + ) + + # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens + additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 + if positional_embeddings.shape[1] < hidden_states.shape[1]: + positional_embeddings = F.pad( + positional_embeddings, + ( + 0, + 0, + additional_embeddings_len, + self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, + ), + value=0.0, + ) + + hidden_states = hidden_states + positional_embeddings + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) + attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) + attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) + + if self.norm_in is not None: + hidden_states = self.norm_in(hidden_states) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + + hidden_states = self.norm_out(hidden_states) + + if self.prd_embedding is not None: + hidden_states = hidden_states[:, -1] + else: + hidden_states = hidden_states[:, additional_embeddings_len:] + + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) + + if not return_dict: + return (predicted_image_embedding,) + + return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) + + def post_process_latents(self, prior_latents): + prior_latents = (prior_latents * self.clip_std) + self.clip_mean + return prior_latents diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py new file mode 100644 index 000000000000..d687dbabf317 --- /dev/null +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -0,0 +1,458 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import FeedForward +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + StableAudioAttnProcessor2_0, +) +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_2d import Transformer2DModelOutput +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__ + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :] + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +@maybe_allow_in_graph +class StableAudioDiTBlock(nn.Module): + r""" + Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip + connection and QKNorm + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for the query states. + num_key_value_attention_heads (`int`): The number of heads to use for the key and value states. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ): + super().__init__() + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) + + # 2. Cross-Attn + self.norm2 = nn.LayerNorm(dim, norm_eps, True) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_key_value_attention_heads, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, True) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn="swiglu", + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + rotary_embedding: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + rotary_emb=rotary_embedding, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class StableAudioDiTModel(ModelMixin, ConfigMixin): + """ + The Diffusion Transformer model introduced in Stable Audio. + + Reference: https://github.com/Stability-AI/stable-audio-tools + + Parameters: + sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. + in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. + num_key_value_attention_heads (`int`, *optional*, defaults to 12): + The number of heads to use for the key and value states. + out_channels (`int`, defaults to 64): Number of output channels. + cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. + time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. + global_states_input_dim ( `int`, *optional*, defaults to 1536): + Input dimension of the global hidden states projection. 
+ cross_attention_input_dim ( `int`, *optional*, defaults to 768): + Input dimension of the cross-attention projection + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + self.sample_size = sample_size + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.time_proj = StableAudioGaussianFourierProjection( + embedding_size=time_proj_dim // 2, + flip_sin_to_cos=True, + log=False, + set_W_to_weight=False, + ) + + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for i in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) + self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. 
+ + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(StableAudioAttnProcessor2_0()) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + timestep: torch.LongTensor = None, + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, + return_dict: bool = True, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`StableAudioDiTModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`): + Input `hidden_states`. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`): + Global embeddings that will be prepended to the hidden states. + rotary_embedding (`torch.Tensor`): + The rotary embeddings to apply on query and key tensors during attention calculation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention + masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating + the attention masks + for the two text encoders together. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + global_hidden_states = self.global_proj(global_hidden_states) + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype))) + + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.proj_in(hidden_states) + + # prepend global states to hidden states + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2) + if attention_mask is not None: + prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + cross_attention_hidden_states, + encoder_attention_mask, + rotary_embedding, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=cross_attention_hidden_states, + encoder_attention_mask=encoder_attention_mask, + rotary_embedding=rotary_embedding, + ) + + hidden_states = self.proj_out(hidden_states) + + # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length) + # remove prepend length that has been added by global hidden states + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/models/transformers/t5_film_transformer.py b/src/diffusers/models/transformers/t5_film_transformer.py new file mode 100644 index 000000000000..1dea37a25910 --- /dev/null +++ b/src/diffusers/models/transformers/t5_film_transformer.py @@ -0,0 +1,436 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
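+# Illustrative usage sketch (assumes the `T5FilmDecoder` defined below is exported as `diffusers.T5FilmDecoder`;
+# shapes follow the constructor defaults except for the reduced `num_layers`):
+#
+# >>> import torch
+# >>> from diffusers import T5FilmDecoder
+# >>> decoder = T5FilmDecoder(input_dims=128, targets_length=256, d_model=768, num_layers=2, num_heads=12, d_kv=64)
+# >>> encodings = torch.randn(1, 77, 768)   # encoder outputs already projected to d_model
+# >>> mask = torch.ones(1, 77)
+# >>> tokens = torch.randn(1, 256, 128)     # (batch, targets_length, input_dims)
+# >>> noise_time = torch.rand(1)            # expected in [0, 1)
+# >>> spec = decoder(
+# ...     encodings_and_masks=[(encodings, mask)], decoder_input_tokens=tokens, decoder_noise_time=noise_time
+# ... )
+# >>> spec.shape
+# torch.Size([1, 256, 128])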
+import math +from typing import Optional, Tuple + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..attention_processor import Attention +from ..embeddings import get_timestep_embedding +from ..modeling_utils import ModelMixin + + +class T5FilmDecoder(ModelMixin, ConfigMixin): + r""" + T5 style decoder with FiLM conditioning. + + Args: + input_dims (`int`, *optional*, defaults to `128`): + The number of input dimensions. + targets_length (`int`, *optional*, defaults to `256`): + The length of the targets. + d_model (`int`, *optional*, defaults to `768`): + Size of the input hidden states. + num_layers (`int`, *optional*, defaults to `12`): + The number of `DecoderLayer`'s to use. + num_heads (`int`, *optional*, defaults to `12`): + The number of attention heads to use. + d_kv (`int`, *optional*, defaults to `64`): + Size of the key-value projection vectors. + d_ff (`int`, *optional*, defaults to `2048`): + The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s. + dropout_rate (`float`, *optional*, defaults to `0.1`): + Dropout probability. + """ + + @register_to_config + def __init__( + self, + input_dims: int = 128, + targets_length: int = 256, + max_decoder_noise_time: float = 2000.0, + d_model: int = 768, + num_layers: int = 12, + num_heads: int = 12, + d_kv: int = 64, + d_ff: int = 2048, + dropout_rate: float = 0.1, + ): + super().__init__() + + self.conditioning_emb = nn.Sequential( + nn.Linear(d_model, d_model * 4, bias=False), + nn.SiLU(), + nn.Linear(d_model * 4, d_model * 4, bias=False), + nn.SiLU(), + ) + + self.position_encoding = nn.Embedding(targets_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) + + self.dropout = nn.Dropout(p=dropout_rate) + + self.decoders = nn.ModuleList() + for lyr_num in range(num_layers): + # FiLM conditional T5 decoder + lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) + self.decoders.append(lyr) + + self.decoder_norm = T5LayerNorm(d_model) + + self.post_dropout = nn.Dropout(p=dropout_rate) + self.spec_out = nn.Linear(d_model, input_dims, bias=False) + + def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor: + mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) + return mask.unsqueeze(-3) + + def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): + batch, _, _ = decoder_input_tokens.shape + assert decoder_noise_time.shape == (batch,) + + # decoder_noise_time is in [0, 1), so rescale to expected timing range. + time_steps = get_timestep_embedding( + decoder_noise_time * self.config.max_decoder_noise_time, + embedding_dim=self.config.d_model, + max_period=self.config.max_decoder_noise_time, + ).to(dtype=self.dtype) + + conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) + + assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) + + seq_length = decoder_input_tokens.shape[1] + + # If we want to use relative positions for audio context, we can just offset + # this sequence by the length of encodings_and_masks. 
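+        # Shape note: `decoder_positions` below is a (batch, seq_length) tensor of indices 0..seq_length-1; the
+        # lookup in the frozen `position_encoding` table yields (batch, seq_length, d_model) absolute position
+        # encodings that are added to the projected continuous inputs.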
+ decoder_positions = torch.broadcast_to( + torch.arange(seq_length, device=decoder_input_tokens.device), + (batch, seq_length), + ) + + position_encodings = self.position_encoding(decoder_positions) + + inputs = self.continuous_inputs_projection(decoder_input_tokens) + inputs += position_encodings + y = self.dropout(inputs) + + # decoder: No padding present. + decoder_mask = torch.ones( + decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype + ) + + # Translate encoding masks to encoder-decoder masks. + encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] + + # cross attend style: concat encodings + encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) + encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) + + for lyr in self.decoders: + y = lyr( + y, + conditioning_emb=conditioning_emb, + encoder_hidden_states=encoded, + encoder_attention_mask=encoder_decoder_mask, + )[0] + + y = self.decoder_norm(y) + y = self.post_dropout(y) + + spec_out = self.spec_out(y) + return spec_out + + +class DecoderLayer(nn.Module): + r""" + T5 decoder layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__( + self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6 + ): + super().__init__() + self.layer = nn.ModuleList() + + # cond self attention: layer 0 + self.layer.append( + T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) + ) + + # cross attention: layer 1 + self.layer.append( + T5LayerCrossAttention( + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + dropout_rate=dropout_rate, + layer_norm_epsilon=layer_norm_epsilon, + ) + ) + + # Film Cond MLP + dropout: last layer + self.layer.append( + T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) + ) + + def forward( + self, + hidden_states: torch.Tensor, + conditioning_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_decoder_position_bias=None, + ) -> Tuple[torch.Tensor]: + hidden_states = self.layer[0]( + hidden_states, + conditioning_emb=conditioning_emb, + attention_mask=attention_mask, + ) + + if encoder_hidden_states is not None: + encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( + encoder_hidden_states.dtype + ) + + hidden_states = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_extended_attention_mask, + ) + + # Apply Film Conditional Feed Forward layer + hidden_states = self.layer[-1](hidden_states, conditioning_emb) + + return (hidden_states,) + + +class T5LayerSelfAttentionCond(nn.Module): + r""" + T5 style self-attention layer with conditioning. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. 
+ num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float): + super().__init__() + self.layer_norm = T5LayerNorm(d_model) + self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + conditioning_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # pre_self_attention_layer_norm + normed_hidden_states = self.layer_norm(hidden_states) + + if conditioning_emb is not None: + normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) + + # Self-attention block + attention_output = self.attention(normed_hidden_states) + + hidden_states = hidden_states + self.dropout(attention_output) + + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + r""" + T5 style cross-attention layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float): + super().__init__() + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + encoder_hidden_states=key_value_states, + attention_mask=attention_mask.squeeze(1), + ) + layer_output = hidden_states + self.dropout(attention_output) + return layer_output + + +class T5LayerFFCond(nn.Module): + r""" + T5 style feed-forward conditional layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. 
+ """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) + self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor: + forwarded_states = self.layer_norm(hidden_states) + if conditioning_emb is not None: + forwarded_states = self.film(forwarded_states, conditioning_emb) + + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + r""" + T5 style feed-forward layer with gated activations and dropout. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float): + super().__init__() + self.wi_0 = nn.Linear(d_model, d_ff, bias=False) + self.wi_1 = nn.Linear(d_model, d_ff, bias=False) + self.wo = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout_rate) + self.act = NewGELUActivation() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerNorm(nn.Module): + r""" + T5 style layer normalization module. + + Args: + hidden_size (`int`): + Size of the input hidden states. + eps (`float`, `optional`, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class T5FiLMLayer(nn.Module): + """ + T5 style FiLM Layer. 
+ + Args: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + """ + + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) + + def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor: + emb = self.scale_bias(conditioning_emb) + scale, shift = torch.chunk(emb, 2, -1) + x = x * (1 + scale) + shift + return x diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py new file mode 100644 index 000000000000..e208a1c10ed4 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -0,0 +1,566 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import LegacyConfigMixin, register_to_config +from ...utils import deprecate, is_torch_version, logging +from ..attention import BasicTransformerBlock +from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import LegacyModelMixin +from ..normalization import AdaLayerNormSingle + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Transformer2DModelOutput(Transformer2DModelOutput): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead." + deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 
+ Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + interpolation_scale: float = None, + use_additional_conditions: Optional[bool] = None, + ): + super().__init__() + + # Validate inputs. + if patch_size is not None: + if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. 
Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+            )
+
+        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+            deprecation_message = (
+                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
+                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+            )
+            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+            norm_type = "ada_norm"
+
+        # Set some common variables used across the board.
+        self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+
+        if use_additional_conditions is None:
+            if norm_type == "ada_norm_single" and sample_size == 128:
+                use_additional_conditions = True
+            else:
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions
+
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
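The three `is_input_*` flags computed above are meant to be mutually exclusive views of the config, and the dispatch below picks exactly one initialization path. As a compact reference, a standalone sketch of the same resolution logic (hypothetical helper name and values, not part of this patch):

    def resolve_input_mode(in_channels, num_vector_embeds, patch_size):
        # Mirrors the flag logic in Transformer2DModel.__init__ above.
        is_continuous = (in_channels is not None) and (patch_size is None)
        is_vectorized = num_vector_embeds is not None
        is_patches = (in_channels is not None) and (patch_size is not None)
        if [is_continuous, is_vectorized, is_patches].count(True) != 1:
            raise ValueError("Configure exactly one of continuous, vectorized, or patched input.")
        return "continuous" if is_continuous else ("vectorized" if is_vectorized else "patches")

    # e.g. a PixArt-style patched config resolves to the patched path:
    assert resolve_input_mode(in_channels=4, num_vector_embeds=None, patch_size=2) == "patches"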
+ if self.is_input_continuous: + self._init_continuous_input(norm_type=norm_type) + elif self.is_input_vectorized: + self._init_vectorized_inputs(norm_type=norm_type) + elif self.is_input_patches: + self._init_patched_inputs(norm_type=norm_type) + + def _init_continuous_input(self, norm_type): + self.norm = torch.nn.GroupNorm( + num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True + ) + if self.use_linear_projection: + self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim) + else: + self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.use_linear_projection: + self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels) + else: + self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0) + + def _init_vectorized_inputs(self, norm_type): + assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert ( + self.config.num_vector_embeds is not None + ), "Transformer2DModel over discrete input must provide num_embed" + + self.height = self.config.sample_size + self.width = self.config.sample_size + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + self.norm_out = nn.LayerNorm(self.inner_dim) + self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1) + + def _init_patched_inputs(self, norm_type): + assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = self.config.sample_size + self.width = self.config.sample_size + + self.patch_size = self.config.patch_size + interpolation_scale = ( + self.config.interpolation_scale + if self.config.interpolation_scale is not None + else max(self.config.sample_size // 64, 1) + ) 
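The fallback above ties the positional-embedding interpolation scale for patched inputs to the latent resolution. Worked through for two hypothetical sample sizes:

    # Default when `interpolation_scale` is not set in the config.
    assert max(128 // 64, 1) == 2  # e.g. sample_size=128 (1024px-class latents at VAE factor 8)
    assert max(64 // 64, 1) == 1   # e.g. sample_size=64 (512px-class latents)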
+ self.pos_embed = PatchEmbed( + height=self.config.sample_size, + width=self.config.sample_size, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.config.norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + elif self.config.norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear( + self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels + ) + + # PixArt-Alpha blocks. + self.adaln_single = None + if self.config.norm_type == "ada_norm_single": + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + + self.caption_projection = None + if self.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, hidden_size=self.inner_dim + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. + """ + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. 
Input + if self.is_input_continuous: + batch_size, _, height, width = hidden_states.shape + residual = hidden_states + hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs + ) + + # 2. Blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + output = self._get_output_for_continuous_inputs( + hidden_states=hidden_states, + residual=residual, + batch_size=batch_size, + height=height, + width=width, + inner_dim=inner_dim, + ) + elif self.is_input_vectorized: + output = self._get_output_for_vectorized_inputs(hidden_states) + elif self.is_input_patches: + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep, + class_labels=class_labels, + embedded_timestep=embedded_timestep, + height=height, + width=width, + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def _operate_on_continuous_inputs(self, hidden_states): + batch, _, height, width = hidden_states.shape + hidden_states = self.norm(hidden_states) + + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + return hidden_states, inner_dim + + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs): + batch_size = hidden_states.shape[0] + hidden_states = self.pos_embed(hidden_states) + embedded_timestep = None + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." 
+ ) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + return hidden_states, encoder_hidden_states, timestep, embedded_timestep + + def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + return output + + def _get_output_for_vectorized_inputs(self, hidden_states): + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + return output + + def _get_output_for_patched_inputs( + self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None + ): + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + return output diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py new file mode 100644 index 000000000000..fe9c7290b063 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -0,0 +1,422 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
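The patched-input output path of `Transformer2DModel` above, as well as the video transformers added below, finishes with the same unpatchify step: per-token predictions of size `patch_size * patch_size * out_channels` are folded back into a channels-first image grid. A small shape check under hypothetical sizes (not part of this patch):

    import torch

    # Hypothetical: a 2x2 grid of 2x2 patches with 4 output channels.
    batch, height, width, p, c = 1, 2, 2, 2, 4
    tokens = torch.randn(batch, height * width, p * p * c)

    x = tokens.reshape(batch, height, width, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)  # same permutation as _get_output_for_patched_inputs
    out = x.reshape(batch, c, height * p, width * p)
    assert out.shape == (1, 4, 4, 4)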
+ +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import AllegroAttnProcessor2_0, Attention +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class AllegroTransformerBlock(nn.Module): + r""" + Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model. + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + only_cross_attention (`bool`, defaults to `False`): + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + attention_bias: bool = False, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + processor=AllegroAttnProcessor2_0(), + ) + + # 2. Cross Attention + self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + processor=AllegroAttnProcessor2_0(), + ) + + # 3. Feed Forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + ) + + # 4. Scale-shift + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb=None, + ) -> torch.Tensor: + # 0. 
Self-Attention + batch_size = hidden_states.shape[0] + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = hidden_states + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + image_rotary_emb=None, + ) + hidden_states = attn_output + hidden_states + + # 2. Feed-forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + + # TODO(aryan): maybe following line is not required + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class AllegroTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 3D Transformer model for video-like data. + + Args: + patch_size (`int`, defaults to `2`): + The size of spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `96`): + The number of channels in each head. + in_channels (`int`, defaults to `4`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `4`): + The number of channels in the output. + num_layers (`int`, defaults to `32`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_height (`int`, defaults to `90`): + The height of the input latents. + sample_width (`int`, defaults to `160`): + The width of the input latents. + sample_frames (`int`, defaults to `22`): + The number of frames in the input latents. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + norm_elementwise_affine (`bool`, defaults to `False`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-6`): + The epsilon value to use in normalization layers. + caption_channels (`int`, defaults to `4096`): + Number of channels to use for projecting the caption embeddings. + interpolation_scale_h (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across height dimension. + interpolation_scale_w (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across width dimension. 
+ interpolation_scale_t (`float`, defaults to `2.2`): + Scaling factor to apply in 3D positional embeddings across time dimension. + """ + + @register_to_config + def __init__( + self, + patch_size: int = 2, + patch_size_t: int = 1, + num_attention_heads: int = 24, + attention_head_dim: int = 96, + in_channels: int = 4, + out_channels: int = 4, + num_layers: int = 32, + dropout: float = 0.0, + cross_attention_dim: int = 2304, + attention_bias: bool = True, + sample_height: int = 90, + sample_width: int = 160, + sample_frames: int = 22, + activation_fn: str = "gelu-approximate", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 4096, + interpolation_scale_h: float = 2.0, + interpolation_scale_w: float = 2.0, + interpolation_scale_t: float = 2.2, + ): + super().__init__() + + self.inner_dim = num_attention_heads * attention_head_dim + + interpolation_scale_t = ( + interpolation_scale_t + if interpolation_scale_t is not None + else ((sample_frames - 1) // 16 + 1) + if sample_frames % 2 == 1 + else sample_frames // 16 + ) + interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30 + interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40 + + # 1. Patch embedding + self.pos_embed = PatchEmbed( + height=sample_height, + width=sample_width, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + pos_embed_type=None, + ) + + # 2. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + AllegroTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + + # 3. Output projection & norm + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) + + # 4. Timestep embeddings + self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) + + # 5. Caption projection + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_dict: bool = True, + ): + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t = self.config.patch_size_t + p = self.config.patch_size + + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch,                    1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        attention_mask_vid, attention_mask_img = None, None
+        if attention_mask is not None and attention_mask.ndim == 4:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
+            # b, frame+use_image_num, h, w -> a video with images
+            # b, 1, h, w -> only images
+            attention_mask = attention_mask.to(hidden_states.dtype)
+            attention_mask = attention_mask[:, :num_frames]  # [batch_size, num_frames, height, width]
+
+            if attention_mask.numel() > 0:
+                attention_mask = attention_mask.unsqueeze(1)  # [batch_size, 1, num_frames, height, width]
+                attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
+                attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)
+
+            attention_mask = (
+                (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
+            )
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Timestep embeddings
+        timestep, embedded_timestep = self.adaln_single(
+            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        # 2. Patch embeddings
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+        hidden_states = self.pos_embed(hidden_states)
+        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
+
+        # 3. Transformer blocks
+        for i, block in enumerate(self.transformer_blocks):
+            # TODO(aryan): Implement gradient checkpointing
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep,
+                    attention_mask,
+                    encoder_attention_mask,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=timestep,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    image_rotary_emb=image_rotary_emb,
+                )
+
+        # 4. Output normalization & projection
+        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+
+        # Modulation
+        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.squeeze(1)
+
+        # 5.
Unpatchify + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py new file mode 100644 index 000000000000..94d852f6df4b --- /dev/null +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -0,0 +1,386 @@ +# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import FeedForward +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + CogVideoXAttnProcessor2_0, +) +from ...models.modeling_utils import ModelMixin +from ...models.normalization import AdaLayerNormContinuous +from ...utils import is_torch_version, logging +from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..normalization import CogView3PlusAdaLayerNormZeroTextImage + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogView3PlusTransformerBlock(nn.Module): + r""" + Transformer block used in [CogView](https://github.com/THUDM/CogView3) model. + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. 
+ """ + + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ): + super().__init__() + + self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + out_dim=dim, + bias=True, + qk_norm="layer_norm", + elementwise_affine=False, + eps=1e-6, + processor=CogVideoXAttnProcessor2_0(), + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + emb: torch.Tensor, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + # norm & modulate + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1(hidden_states, encoder_hidden_states, emb) + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states + + # norm & modulate + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # feed-forward + norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length] + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + return hidden_states, encoder_hidden_states + + +class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): + r""" + The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay + Diffusion](https://huggingface.co/papers/2403.05121). + + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + attention_head_dim (`int`, defaults to `40`): + The number of channels in each head. + num_attention_heads (`int`, defaults to `64`): + The number of heads to use for multi-head attention. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. 
+ condition_dim (`int`, defaults to `256`): + The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, + crop_coords). + pos_embed_max_size (`int`, defaults to `128`): + The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added + to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 + means that the maximum supported height and width for image generation is `128 * vae_scale_factor * + patch_size => 128 * 8 * 2 => 2048`. + sample_size (`int`, defaults to `128`): + The base resolution of input latents. If height/width is not provided during generation, this value is used + to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024` + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + out_channels: int = 16, + text_embed_dim: int = 4096, + time_embed_dim: int = 512, + condition_dim: int = 256, + pos_embed_max_size: int = 128, + sample_size: int = 128, + ): + super().__init__() + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + self.pooled_projection_dim = 3 * 2 * condition_dim + + self.patch_embed = CogView3PlusPatchEmbed( + in_channels=in_channels, + hidden_size=self.inner_dim, + patch_size=patch_size, + text_hidden_size=text_embed_dim, + pos_embed_max_size=pos_embed_max_size, + ) + + self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=self.pooled_projection_dim, + timesteps_dim=self.inner_dim, + ) + + self.transformer_blocks = nn.ModuleList( + [ + CogView3PlusTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + embedding_dim=self.inner_dim, + conditioning_embedding_dim=time_embed_dim, + elementwise_affine=False, + eps=1e-6, + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`CogView3PlusTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor`): + Input `hidden_states` of shape `(batch size, channel, height, width)`. + encoder_hidden_states (`torch.Tensor`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape + `(batch_size, sequence_len, text_embed_dim)` + timestep (`torch.LongTensor`): + Used to indicate denoising step. + original_size (`torch.Tensor`): + CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`torch.Tensor`): + CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
+ crop_coords (`torch.Tensor`): + CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]: + The denoised latents using provided inputs as conditioning. + """ + height, width = hidden_states.shape[-2:] + text_seq_length = encoder_hidden_states.shape[1] + + hidden_states = self.patch_embed( + hidden_states, encoder_hidden_states + ) # takes care of adding positional embeddings too. + emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) + + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=emb, + ) + + hidden_states = self.norm_out(hidden_states, emb) + hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size) + ) + hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py new file mode 100644 index 000000000000..0ad3be866019 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -0,0 +1,577 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
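Every transformer touched by this patch (the 2D model, Allegro, CogView3Plus, and Flux below) wires gradient checkpointing the same way: the block call is wrapped in a closure so `torch.utils.checkpoint.checkpoint` can re-run it during the backward pass instead of caching its activations. A stripped-down, standalone sketch of that pattern (toy module, not part of this patch):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ToyBlock(nn.Module):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return F.gelu(self.proj(hidden_states))

    def create_custom_forward(module):
        # Closure keeps the block's call convention intact for checkpoint().
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    block = ToyBlock()
    hidden_states = torch.randn(2, 4, 8, requires_grad=True)

    # Activations inside the block are recomputed during backward rather than stored.
    out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, use_reentrant=False)
    out.sum().backward()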
+ + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention import FeedForward +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + FluxAttnProcessor2_0, + FluxAttnProcessor2_0_NPU, + FusedFluxAttnProcessor2_0, +) +from ...models.modeling_utils import ModelMixin +from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.import_utils import is_torch_npu_available +from ...utils.torch_utils import maybe_allow_in_graph +from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class FluxSingleTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + if is_torch_npu_available(): + processor = FluxAttnProcessor2_0_NPU() + else: + processor = FluxAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + joint_attention_kwargs=None, + ): + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +@maybe_allow_in_graph +class FluxTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. 
+ num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + + self.norm1_context = AdaLayerNormZero(dim) + + if hasattr(F, "scaled_dot_product_attention"): + processor = FluxAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + image_rotary_emb=None, + joint_attention_kwargs=None, + ): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Flux. 
+ + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim + ) + + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedFluxAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. 
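+
+        Example (an illustrative round-trip; `transformer` stands for an instance of this model):
+
+        ```python
+        >>> transformer.fuse_qkv_projections()
+        >>> # ... run inference with fused query/key/value projections ...
+        >>> transformer.unfuse_qkv_projections()
+        ```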
+ + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py new file mode 100644 index 000000000000..fb346a70ba4d --- /dev/null +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -0,0 +1,568 @@ +# Copyright 2024 The Genmo team and The HuggingFace Team. 
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import is_torch_version, logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import MochiAttnProcessor2_0
+from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import (
+    AdaLayerNormContinuous,
+)
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class MochiModulatedRMSNorm(nn.Module):
+    def __init__(self, eps: float):
+        super().__init__()
+
+        self.eps = eps
+
+    def forward(self, hidden_states, scale=None):
+        hidden_states_dtype = hidden_states.dtype
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
+
+        if scale is not None:
+            hidden_states = hidden_states * scale
+
+        hidden_states = hidden_states.to(hidden_states_dtype)
+
+        return hidden_states
+
+
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine=True):
+        super().__init__()
+
+        self.eps = eps
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        hidden_states_dtype = hidden_states.dtype
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states * self.weight
+
+        hidden_states = hidden_states.to(hidden_states_dtype)
+
+        return hidden_states
+
+
+class MochiLayerNormContinuous(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        eps=1e-5,
+        bias=True,
+    ):
+        super().__init__()
+
+        # AdaLN
+        self.silu = nn.SiLU()
+        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+        self.norm = MochiModulatedRMSNorm(eps=eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        conditioning_embedding: torch.Tensor,
+    ) -> torch.Tensor:
+        input_dtype = x.dtype
+
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+        scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+        x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
+
+        return x.to(input_dtype)
+
+
+class MochiRMSNormZero(nn.Module):
+    r"""
+    Adaptive RMS Norm used in Mochi.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
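+        hidden_dim (`int`): The output size of the modulation projection, chunked into scale and gate terms for the
+            attention and feed-forward branches.
+        eps (`float`, defaults to `1e-5`): Epsilon value for the RMS normalization.
+        elementwise_affine (`bool`, defaults to `False`): Unused by this module; the underlying modulated RMS norm
+            has no learnable weight.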
+ """ + + def __init__( + self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, hidden_dim) + self.norm = MochiModulatedRMSNorm(eps=eps) + + def forward( + self, hidden_states: torch.Tensor, emb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states_dtype = hidden_states.dtype + + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + + hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].to(torch.float32))) + hidden_states = hidden_states.to(hidden_states_dtype) + + return hidden_states, gate_msa, scale_mlp, gate_mlp + + +class MochiAttention(nn.Module): + def __init__( + self, + query_dim: int, + processor: Optional["MochiAttnProcessor2_0"], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_dim: int = None, + out_context_dim: int = None, + out_bias: bool = True, + context_pre_only: bool = False, + eps: float = 1e-5, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim else query_dim + self.context_pre_only = context_pre_only + + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.norm_q = MochiRMSNorm(dim_head, eps) + self.norm_k = MochiRMSNorm(dim_head, eps) + self.norm_added_q = MochiRMSNorm(dim_head, eps) + self.norm_added_k = MochiRMSNorm(dim_head, eps) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + +@maybe_allow_in_graph +class MochiTransformerBlock(nn.Module): + r""" + Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"swiglu"`): + Activation function to use in feed-forward. + context_pre_only (`bool`, defaults to `False`): + Whether or not to process context-related conditions with additional layers. 
+ eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + pooled_projection_dim: int, + qk_norm: str = "rms_norm", + activation_fn: str = "swiglu", + context_pre_only: bool = False, + eps: float = 1e-6, + ) -> None: + super().__init__() + + self.context_pre_only = context_pre_only + self.ff_inner_dim = (4 * dim * 2) // 3 + self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 + + self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) + + if not context_pre_only: + self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) + else: + self.norm1_context = MochiLayerNormContinuous( + embedding_dim=pooled_projection_dim, + conditioning_embedding_dim=dim, + eps=eps, + ) + + self.attn1 = MochiAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=False, + added_kv_proj_dim=pooled_projection_dim, + added_proj_bias=False, + out_dim=dim, + out_context_dim=pooled_projection_dim, + context_pre_only=context_pre_only, + processor=MochiAttnProcessor2_0(), + eps=1e-5, + ) + + # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True + self.norm2 = MochiModulatedRMSNorm(eps=eps) + self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None + + self.norm3 = MochiModulatedRMSNorm(eps) + self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None + + self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) + self.ff_context = None + if not context_pre_only: + self.ff_context = FeedForward( + pooled_projection_dim, + inner_dim=self.ff_context_inner_dim, + activation_fn=activation_fn, + bias=False, + ) + + self.norm4 = MochiModulatedRMSNorm(eps=eps) + self.norm4_context = MochiModulatedRMSNorm(eps=eps) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + encoder_attention_mask: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + + if not self.context_pre_only: + norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context( + encoder_hidden_states, temb + ) + else: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + + attn_hidden_states, context_attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=encoder_attention_mask, + ) + + hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)) + norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32))) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)) + + if not self.context_pre_only: + encoder_hidden_states = encoder_hidden_states + self.norm2_context( + context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1) + ) + norm_encoder_hidden_states = self.norm3_context( + encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32)) + ) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + 
self.norm4_context( + context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1) + ) + + return hidden_states, encoder_hidden_states + + +class MochiRoPE(nn.Module): + r""" + RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + base_height (`int`, defaults to `192`): + Base height used to compute interpolation scale for rotary positional embeddings. + base_width (`int`, defaults to `192`): + Base width used to compute interpolation scale for rotary positional embeddings. + """ + + def __init__(self, base_height: int = 192, base_width: int = 192) -> None: + super().__init__() + + self.target_area = base_height * base_width + + def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: + edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) + return (edges[:-1] + edges[1:]) / 2 + + def _get_positions( + self, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + scale = (self.target_area / (height * width)) ** 0.5 + + t = torch.arange(num_frames, device=device, dtype=dtype) + h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) + w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) + + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) + return positions + + def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + with torch.autocast(freqs.device.type, enabled=False): + # Always run ROPE freqs computation in FP32 + freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) + + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + def forward( + self, + pos_frequencies: torch.Tensor, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + pos = self._get_positions(num_frames, height, width, device, dtype) + rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) + return rope_cos, rope_sin + + +@maybe_allow_in_graph +class MochiTransformer3DModel(ModelMixin, ConfigMixin): + r""" + A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `48`): + The number of layers of Transformer blocks to use. + in_channels (`int`, defaults to `12`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `256`): + Output dimension of timestep embeddings. + activation_fn (`str`, defaults to `"swiglu"`): + Activation function to use in feed-forward. + max_sequence_length (`int`, defaults to `256`): + The maximum sequence length of text embeddings supported. 
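+
+    Example:
+        A shape-level sketch of a single forward pass. The tensors `latents`, `text_embeds`, `text_mask`, and
+        `timesteps` are placeholders assumed to be prepared by the caller (e.g. by the Mochi pipeline); `transformer`
+        stands for an instance of this model.
+
+        ```python
+        >>> # latents: (batch, in_channels, num_frames, height, width), height/width divisible by patch_size
+        >>> # text_embeds: (batch, seq_len, text_embed_dim); text_mask: (batch, seq_len); timesteps: (batch,)
+        >>> sample = transformer(
+        ...     hidden_states=latents,
+        ...     encoder_hidden_states=text_embeds,
+        ...     timestep=timesteps,
+        ...     encoder_attention_mask=text_mask,
+        ... ).sample
+        >>> sample.shape  # (batch, out_channels, num_frames, height, width)
+        ```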
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 48, + pooled_projection_dim: int = 1536, + in_channels: int = 12, + out_channels: Optional[int] = None, + qk_norm: str = "rms_norm", + text_embed_dim: int = 4096, + time_embed_dim: int = 256, + activation_fn: str = "swiglu", + max_sequence_length: int = 256, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + pos_embed_type=None, + ) + + self.time_embed = MochiCombinedTimestepCaptionEmbedding( + embedding_dim=inner_dim, + pooled_projection_dim=pooled_projection_dim, + text_embed_dim=text_embed_dim, + time_embed_dim=time_embed_dim, + num_attention_heads=8, + ) + + self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0)) + self.rope = MochiRoPE() + + self.transformer_blocks = nn.ModuleList( + [ + MochiTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + pooled_projection_dim=pooled_projection_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + context_pre_only=i == num_layers - 1, + ) + for i in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + inner_dim, + inner_dim, + elementwise_affine=False, + eps=1e-6, + norm_type="layer_norm", + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + return_dict: bool = True, + ) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = self.config.patch_size + + post_patch_height = height // p + post_patch_width = width // p + + temb, encoder_hidden_states = self.time_embed( + timestep, + encoder_hidden_states, + encoder_attention_mask, + hidden_dtype=hidden_states.dtype, + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + image_rotary_emb = self.rope( + self.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) + + for i, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + encoder_attention_mask=encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + ) 
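+
+        # All blocks have been applied; the steps below map tokens back to the latent video: `norm_out` applies the
+        # final adaptive norm conditioned on `temb`, `proj_out` projects each token to a
+        # `patch_size * patch_size * out_channels` patch, and the reshape/permute unpatchify the token sequence into
+        # `(batch_size, out_channels, num_frames, height, width)`.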
+ hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) + hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py new file mode 100644 index 000000000000..f39a102c7256 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -0,0 +1,373 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention import JointTransformerBlock +from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 +from ...models.modeling_utils import ModelMixin +from ...models.normalization import AdaLayerNormContinuous +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + """ + The Transformer model introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + out_channels (`int`, defaults to 16): Number of output channels. 
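+        pos_embed_max_size (`int`, defaults to 96): Maximum size, in patches, of the positional embedding grid.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`): Indices of the transformer blocks that enable
+            dual attention; `()` for SD 3.0 and `(0, 1, ..., 12)` for SD 3.5.
+        qk_norm (`str`, *optional*): Normalization applied to the query and key projections in the attention
+            layers, e.g. `"rms_norm"`.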
+
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 128,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        num_layers: int = 18,
+        attention_head_dim: int = 64,
+        num_attention_heads: int = 18,
+        joint_attention_dim: int = 4096,
+        caption_projection_dim: int = 1152,
+        pooled_projection_dim: int = 2048,
+        out_channels: int = 16,
+        pos_embed_max_size: int = 96,
+        dual_attention_layers: Tuple[
+            int, ...
+        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+        qk_norm: Optional[str] = None,
+    ):
+        super().__init__()
+        default_out_channels = in_channels
+        self.out_channels = out_channels if out_channels is not None else default_out_channels
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.inner_dim,
+            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
+        )
+        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+        )
+        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+
+        # `attention_head_dim` is doubled to account for the mixing.
+        # It needs to be crafted when we get the actual checkpoints.
+        self.transformer_blocks = nn.ModuleList(
+            [
+                JointTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=self.config.num_attention_heads,
+                    attention_head_dim=self.config.attention_head_dim,
+                    context_pre_only=i == num_layers - 1,
+                    qk_norm=qk_norm,
+                    use_dual_attention=True if i in dual_attention_layers else False,
+                )
+                for i in range(self.config.num_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+        self.gradient_checkpointing = False
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Enables [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) for the feed-forward layers.
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run the feed-forward layer
+                individually over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
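+
+        Example (illustrative; `transformer` stands for an instance of this model):
+
+        ```python
+        >>> transformer.enable_forward_chunking(chunk_size=1, dim=0)
+        >>> # ... run inference with chunked feed-forward layers ...
+        >>> transformer.disable_forward_chunking()
+        ```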
+ """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedJointAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + block_controlnet_hidden_states: List = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. 
+ + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + ) + + # controlnet residual + if block_controlnet_hidden_states is not None and block.context_pre_only is False: + interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) + hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py new file mode 100644 index 000000000000..6ca42b9745fd --- /dev/null +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -0,0 +1,381 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..resnet import AlphaBlender + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + The output of [`TransformerTemporalModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.Tensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. 
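+        out_channels (`int`, *optional*):
+            The number of channels in the output (the output projection currently maps back to `in_channels`).
+        norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the `GroupNorm` applied to the input.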
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.LongTensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: torch.LongTensor = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> TransformerTemporalModelOutput: + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] + instead of a plain tuple. 
+ + Returns: + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) + + +class TransformerSpatioTemporalModel(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + out_channels (`int`, *optional*): + The number of channels in the output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 320, + out_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + + # 2. Define input layers + self.in_channels = in_channels + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. 
Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for d in range(num_layers) + ] + ) + + time_mix_inner_dim = inner_dim + self.temporal_transformer_blocks = nn.ModuleList( + [ + TemporalBasicTransformerBlock( + inner_dim, + time_mix_inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + time_embed_dim = in_channels * 4 + self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) + self.time_proj = Timesteps(in_channels, True, 0) + self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + # TODO: should use out_channels for continuous projections + self.proj_out = nn.Linear(inner_dim, in_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + Args: + hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): + Input hidden_states. + num_frames (`int`): + The number of frames to be processed per batch. This is used to reshape the hidden states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): + A tensor indicating whether the input contains only images. 1 indicates that the input contains only + images, 0 indicates that the input contains video frames. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] + instead of a plain tuple. + + Returns: + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an + [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. 
Input + batch_frames, _, height, width = hidden_states.shape + num_frames = image_only_indicator.shape[-1] + batch_size = batch_frames // num_frames + + time_context = encoder_hidden_states + time_context_first_timestep = time_context[None, :].reshape( + batch_size, num_frames, -1, time_context.shape[-1] + )[:, 0] + time_context = time_context_first_timestep[:, None].broadcast_to( + batch_size, height * width, time_context.shape[-2], time_context.shape[-1] + ) + time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1]) + + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + num_frames_emb = torch.arange(num_frames, device=hidden_states.device) + num_frames_emb = num_frames_emb.repeat(batch_size, 1) + num_frames_emb = num_frames_emb.reshape(-1) + t_emb = self.time_proj(num_frames_emb) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + # 2. Blocks + for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + None, + encoder_hidden_states, + None, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states_mix = hidden_states + hidden_states_mix = hidden_states_mix + emb + + hidden_states_mix = temporal_block( + hidden_states_mix, + num_frames=num_frames, + encoder_hidden_states=time_context, + ) + hidden_states = self.time_mixer( + x_spatial=hidden_states, + x_temporal=hidden_states_mix, + image_only_indicator=image_only_indicator, + ) + + # 3. 
Output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) From 6e2011aa7d6b02827de91aa7176e8b4887643891 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 06:56:06 +0100 Subject: [PATCH 27/47] update --- .../pipelines/mochi/pipeline_mochi.py | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index c815b13e6d19..e6eb4d04cf98 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -700,30 +700,23 @@ def __call__( if output_type == "latent": video = latents else: - with torch.autocast("cuda", torch.float32): - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - has_latents_mean = ( - hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) ) - has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None - if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = ( - torch.tensor(self.vae.config.latents_std) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean - else: - latents = latents / self.vae.config.scaling_factor - - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From 9c5eb368c4981018117aebc24ad25cb0ff974899 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 07:30:18 +0100 Subject: [PATCH 28/47] update --- .../pipelines/mochi/pipeline_mochi.py | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index e6eb4d04cf98..98d52d095e3e 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -252,6 +252,7 @@ def _get_t5_prompt_embeds( _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, 
seq_len, -1) + prompt_embeds = prompt_embeds.to(torch.float32) prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) @@ -450,7 +451,7 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) return latents @property @@ -594,38 +595,37 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - with torch.autocast("cuda", torch.float32): - # 3. Prepare text embeddings - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - max_sequence_length=max_sequence_length, - device=device, - ) - # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - prompt_embeds.dtype, - device, - generator, - latents, - ) + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) From d759516b2df9a0a3b65f66ad24099b7773aec7a8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 07:51:05 +0100 Subject: [PATCH 29/47] update --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 98d52d095e3e..5d696c65558e 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -252,7 +252,6 @@ def _get_t5_prompt_embeds( _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - prompt_embeds = prompt_embeds.to(torch.float32) prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) @@ -451,7 +450,7 @@ def prepare_latents( f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents @property From 7854bde9010761f7fce42d1445b25378d812056e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 07:55:35 +0100 Subject: [PATCH 30/47] update --- src/diffusers/models/normalization.py | 29 --------------------------- 1 file changed, 29 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index aa15e0a4f8d7..ba6146b1c751 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -234,35 +234,6 @@ def forward( return x, gate_msa, scale_mlp, gate_mlp -class MochiRMSNormZero(nn.Module): - r""" - Adaptive RMS Norm used in Mochi. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - """ - - def __init__( - self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False - ) -> None: - super().__init__() - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, hidden_dim) - self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) - - def forward( - self, hidden_states: torch.Tensor, emb: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - emb = self.linear(self.silu(emb)) - scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - scale_msa = scale_msa.float() - _hidden_states = self.norm(hidden_states).float() * (1 + scale_msa[:, None]) - hidden_states = _hidden_states.to(hidden_states.dtype) - - return hidden_states, gate_msa, scale_mlp, gate_mlp - - class AdaLayerNormSingle(nn.Module): r""" Norm layer adaptive layer norm single (adaLN-single). From 2881f2f98694db8094db03e4626549de37434f90 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 07:58:19 +0100 Subject: [PATCH 31/47] update --- src/diffusers/models/attention_processor.py | 105 ----------------- .../models/transformers/transformer_mochi.py | 107 +++++++++++++++++- 2 files changed, 106 insertions(+), 106 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7e7e1a54c4e1..c818959c278f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3510,111 +3510,6 @@ def __call__( return hidden_states -class MochiAttnProcessor2_0: - """Attention processor used in Mochi.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - query = query.unflatten(2, (attn.heads, -1)) - key = key.unflatten(2, (attn.heads, -1)) - value = value.unflatten(2, (attn.heads, -1)) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - - encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) - encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) - - if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) - if attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) - - if image_rotary_emb is not None: - - def apply_rotary_emb(x, freqs_cos, freqs_sin): - x_even = x[..., 0::2].float() - x_odd = x[..., 1::2].float() - - cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) - sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) - - return torch.stack([cos, sin], dim=-1).flatten(-2) - - query = apply_rotary_emb(query, *image_rotary_emb) - key = apply_rotary_emb(key, *image_rotary_emb) - - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - encoder_query, encoder_key, encoder_value = ( - encoder_query.transpose(1, 2), - encoder_key.transpose(1, 2), - encoder_value.transpose(1, 2), - ) - - sequence_length = query.size(2) - encoder_sequence_length = encoder_query.size(2) - total_length = sequence_length + encoder_sequence_length - - batch_size, heads, _, dim = query.shape - attn_outputs = [] - for idx in range(batch_size): - mask = attention_mask[idx][None, :] - valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() - - valid_encoder_query = torch.index_select(encoder_query[idx][None, :], 2, valid_prompt_token_indices) - valid_encoder_key = torch.index_select(encoder_key[idx][None, :], 2, valid_prompt_token_indices) - valid_encoder_value = torch.index_select(encoder_value[idx][None, :], 2, valid_prompt_token_indices) - - valid_query = torch.cat([query[idx][None, :], valid_encoder_query], dim=2) - valid_key = torch.cat([key[idx][None, :], valid_encoder_key], dim=2) - valid_value = torch.cat([value[idx][None, :], valid_encoder_value], dim=2) - - attn_output = F.scaled_dot_product_attention( - valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False - ) - valid_sequence_length = attn_output.size(2) - attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) - attn_outputs.append(attn_output) - - hidden_states = torch.cat(attn_outputs, dim=0) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length, encoder_sequence_length), dim=1 - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if hasattr(attn, "to_add_out"): - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - - class 
FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index fb346a70ba4d..0f5d2e71f6bd 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -17,12 +17,12 @@ import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import MochiAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -205,6 +205,111 @@ def forward( ) +class MochiAttnProcessor2_0: + """Attention processor used in Mochi.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: MochiAttention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + if image_rotary_emb is not None: + + def apply_rotary_emb(x, freqs_cos, freqs_sin): + x_even = x[..., 0::2].float() + x_odd = x[..., 1::2].float() + + cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) + sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) + + return torch.stack([cos, sin], dim=-1).flatten(-2) + + query = apply_rotary_emb(query, *image_rotary_emb) + key = apply_rotary_emb(key, *image_rotary_emb) + + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + encoder_query, encoder_key, encoder_value = ( + encoder_query.transpose(1, 2), + encoder_key.transpose(1, 2), + encoder_value.transpose(1, 2), + ) + + sequence_length = query.size(2) + encoder_sequence_length = encoder_query.size(2) + total_length = sequence_length + encoder_sequence_length + + batch_size, heads, _, dim = query.shape + attn_outputs = [] + for idx in range(batch_size): + mask = attention_mask[idx][None, :] + valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + valid_encoder_query = torch.index_select(encoder_query[idx][None, :], 2, valid_prompt_token_indices) + valid_encoder_key = 
torch.index_select(encoder_key[idx][None, :], 2, valid_prompt_token_indices) + valid_encoder_value = torch.index_select(encoder_value[idx][None, :], 2, valid_prompt_token_indices) + + valid_query = torch.cat([query[idx][None, :], valid_encoder_query], dim=2) + valid_key = torch.cat([key[idx][None, :], valid_encoder_key], dim=2) + valid_value = torch.cat([value[idx][None, :], valid_encoder_value], dim=2) + + attn_output = F.scaled_dot_product_attention( + valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False + ) + valid_sequence_length = attn_output.size(2) + attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) + attn_outputs.append(attn_output) + + hidden_states = torch.cat(attn_outputs, dim=0) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if hasattr(attn, "to_add_out"): + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + @maybe_allow_in_graph class MochiTransformerBlock(nn.Module): r""" From 7854061ebd4c54fc21e8d130dfe1094e837a2b7e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 27 Nov 2024 07:59:46 +0100 Subject: [PATCH 32/47] update --- src/diffusers/models/transformers/transformer_mochi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 0f5d2e71f6bd..9839f37cc5ce 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -145,7 +145,7 @@ class MochiAttention(nn.Module): def __init__( self, query_dim: int, - processor: Optional["MochiAttnProcessor2_0"], + processor: "MochiAttnProcessor2_0", heads: int = 8, dim_head: int = 64, dropout: float = 0.0, @@ -214,7 +214,7 @@ def __init__(self): def __call__( self, - attn: MochiAttention, + attn: "MochiAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, From ba9c1850e8767d2958fc518b7c8281bc0fa5ca77 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 29 Nov 2024 08:54:56 +0100 Subject: [PATCH 33/47] update --- .../models/transformers/transformer_mochi.py | 6 +-- .../pipelines/mochi/pipeline_mochi.py | 37 ++++++++++--------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 9839f37cc5ce..f83ae66efa01 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from transformers.tokenization_utils_base import import_protobuf_decode_error from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging @@ -478,9 +479,8 @@ def _get_positions( return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - with torch.autocast(freqs.device.type, enabled=False): - # Always run ROPE freqs computation in FP32 - freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) + # Always run ROPE freqs computation in FP32 + 
freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 5d696c65558e..09e88f2f7144 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -594,24 +594,25 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # 3. Prepare text embeddings - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - max_sequence_length=max_sequence_length, - device=device, - ) + with torch.autocast("cuda", torch.float32): + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( From 53dbc37ea6d334c31ba1c5085b6b1e5b035c02ec Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 29 Nov 2024 09:57:57 +0100 Subject: [PATCH 34/47] update --- .../models/transformers/transformer_mochi.py | 1 - .../pipelines/mochi/pipeline_mochi.py | 37 +++++++++---------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index f83ae66efa01..459be67ae37e 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -18,7 +18,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers.tokenization_utils_base import import_protobuf_decode_error from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 09e88f2f7144..5d696c65558e 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -594,25 +594,24 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - with torch.autocast("cuda", torch.float32): - # 3. 
Prepare text embeddings - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - max_sequence_length=max_sequence_length, - device=device, - ) + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( From 77f9d1905a28ba778a8adc80941ff3aa58b7e49e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 29 Nov 2024 13:00:46 +0100 Subject: [PATCH 35/47] update --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 5d696c65558e..f024193a5830 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -235,6 +235,9 @@ def _get_t5_prompt_embeds( text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask prompt_attention_mask = prompt_attention_mask.bool().to(device) + if prompt == "" or prompt[-1] == "": + text_input_ids = torch.zeros_like(text_input_ids, device=device) + prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids @@ -450,7 +453,8 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) + latents = latents.to(dtype) return latents @property From dc96890d7b1c4263f74db2936683e2982c47674a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 29 Nov 2024 13:49:44 +0100 Subject: [PATCH 36/47] update --- src/diffusers/models/attention_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cfa2ecb55e8e..81a3190011ab 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -4992,7 +4992,6 @@ def __init__(self): PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, - MochiAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, From ae57913fbbc0aebd86195b53866b5940d0f535ea Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 29 Nov 2024 16:58:43 +0100 Subject: [PATCH 37/47] update --- .../models/transformers/transformer_mochi.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 902c07e167b4..4a4c0dec6a06 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -278,13 +278,13 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): mask = attention_mask[idx][None, :] valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() - valid_encoder_query = torch.index_select(encoder_query[idx][None, :], 2, valid_prompt_token_indices) - valid_encoder_key = torch.index_select(encoder_key[idx][None, :], 2, valid_prompt_token_indices) - valid_encoder_value = torch.index_select(encoder_value[idx][None, :], 2, valid_prompt_token_indices) + valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :] + valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :] + valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :] - valid_query = torch.cat([query[idx][None, :], valid_encoder_query], dim=2) - valid_key = torch.cat([key[idx][None, :], valid_encoder_key], dim=2) - valid_value = torch.cat([value[idx][None, :], valid_encoder_value], dim=2) + valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2) + valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2) + valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2) attn_output = F.scaled_dot_product_attention( valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False @@ -666,6 +666,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, temb, + encoder_attention_mask, image_rotary_emb, **ckpt_kwargs, ) From 7626a34362a71cffffedc74f5a1661d06bd37ead Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 30 Nov 2024 10:52:18 +0100 Subject: [PATCH 38/47] update --- .../models/transformers/transformer_mochi.py | 44 ++++--------------- 1 file changed, 8 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 4a4c0dec6a06..e6513feec22a 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ 
-27,9 +27,7 @@ from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import ( - AdaLayerNormContinuous, -) +from ..normalization import AdaLayerNormContinuous, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-n @@ -40,12 +38,13 @@ def __init__(self, eps: float): super().__init__() self.eps = eps + self.norm = RMSNorm(0, eps, False) def forward(self, hidden_states, scale=None): hidden_states_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps) + hidden_states = self.norm(hidden_states) if scale is not None: hidden_states = hidden_states * scale @@ -55,33 +54,6 @@ def forward(self, hidden_states, scale=None): return hidden_states -class MochiRMSNorm(nn.Module): - def __init__(self, dim, eps: float, elementwise_affine=True): - super().__init__() - - self.eps = eps - if elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim)) - else: - self.weight = None - - def forward(self, hidden_states): - hidden_states_dtype = hidden_states.dtype - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps) - - if self.weight is not None: - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - hidden_states = hidden_states * self.weight - - hidden_states = hidden_states.to(hidden_states_dtype) - - return hidden_states - - class MochiLayerNormContinuous(nn.Module): def __init__( self, @@ -167,10 +139,10 @@ def __init__( self.heads = out_dim // dim_head if out_dim is not None else heads - self.norm_q = MochiRMSNorm(dim_head, eps) - self.norm_k = MochiRMSNorm(dim_head, eps) - self.norm_added_q = MochiRMSNorm(dim_head, eps) - self.norm_added_k = MochiRMSNorm(dim_head, eps) + self.norm_q = RMSNorm(dim_head, eps, True) + self.norm_k = RMSNorm(dim_head, eps, True) + self.norm_added_q = RMSNorm(dim_head, eps, True) + self.norm_added_k = RMSNorm(dim_head, eps, True) self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) From c39886ac136452afb66d9b7a2d2092f4ddf6fb7a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 30 Nov 2024 11:06:08 +0100 Subject: [PATCH 39/47] update --- src/diffusers/models/transformers/transformer_mochi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index e6513feec22a..159cdffe7e5c 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -98,7 +98,7 @@ def __init__( self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, hidden_dim) - self.norm = MochiModulatedRMSNorm(eps=eps) + self.norm = RMSNorm(0, eps, False) def forward( self, hidden_states: torch.Tensor, emb: torch.Tensor @@ -108,7 +108,7 @@ def forward( emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].to(torch.float32))) + hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + 
scale_msa[:, None].to(torch.float32)) hidden_states = hidden_states.to(hidden_states_dtype) return hidden_states, gate_msa, scale_mlp, gate_mlp From bbc58926cc34db6c11ebfd5e581779d66e284cf3 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 7 Dec 2024 04:45:42 +0100 Subject: [PATCH 40/47] update --- src/diffusers/models/embeddings.py | 7 ++++--- src/diffusers/models/transformers/transformer_mochi.py | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c44b110473c1..e19d2d9af1a3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1594,12 +1594,13 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch. Returns: pooled: (B, D) tensor of pooled tokens. """ + input_dtype = x.dtype assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. - mask = mask[:, :, None].to(dtype=x.dtype) + mask = mask[:, :, None].to(dtype=torch.float32) mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) - pooled = (x * mask).sum(dim=1, keepdim=keepdim) - return pooled + pooled = (x.to(torch.float32) * mask).sum(dim=1, keepdim=keepdim) + return pooled.to(input_dtype) def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: r""" diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 159cdffe7e5c..405024524aac 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -451,8 +451,9 @@ def _get_positions( return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - # Always run ROPE freqs computation in FP32 - freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) + with torch.autocast(self.device.type, torch.float32): + # Always run ROPE freqs computation in FP32 + freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) From 3c70b54117877a71690600adf9b39bff6403c176 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 7 Dec 2024 04:47:59 +0100 Subject: [PATCH 41/47] update --- src/diffusers/models/transformers/transformer_mochi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 405024524aac..06ef60db77a3 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -451,7 +451,7 @@ def _get_positions( return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: - with torch.autocast(self.device.type, torch.float32): + with torch.autocast(freqs.device.type, torch.float32): # Always run ROPE freqs computation in FP32 freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32)) From 11ce6b8791ab64a628aa5588194e0b8dd7ea09e1 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 7 Dec 2024 08:48:14 +0100 Subject: [PATCH 42/47] update --- src/diffusers/models/embeddings.py | 4 +- .../models/transformers/transformer_mochi.py | 37 +++++++++++++++++-- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py 
index e19d2d9af1a3..2119eeeb2ff8 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1597,9 +1597,9 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch. input_dtype = x.dtype assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. - mask = mask[:, :, None].to(dtype=torch.float32) + mask = mask[:, :, None].to(dtype=x.dtype) mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) - pooled = (x.to(torch.float32) * mask).sum(dim=1, keepdim=keepdim) + pooled = (x * mask).sum(dim=1, keepdim=keepdim) return pooled.to(input_dtype) def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 06ef60db77a3..2e5306babb74 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numbers from typing import Any, Dict, Optional, Tuple import torch @@ -54,6 +55,34 @@ def forward(self, hidden_states, scale=None): return hidden_states +class MochiRMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + hidden_states = hidden_states * self.weight + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + class MochiLayerNormContinuous(nn.Module): def __init__( self, @@ -139,10 +168,10 @@ def __init__( self.heads = out_dim // dim_head if out_dim is not None else heads - self.norm_q = RMSNorm(dim_head, eps, True) - self.norm_k = RMSNorm(dim_head, eps, True) - self.norm_added_q = RMSNorm(dim_head, eps, True) - self.norm_added_k = RMSNorm(dim_head, eps, True) + self.norm_q = MochiRMSNorm(dim_head, eps, True) + self.norm_k = MochiRMSNorm(dim_head, eps, True) + self.norm_added_q = MochiRMSNorm(dim_head, eps, True) + self.norm_added_k = MochiRMSNorm(dim_head, eps, True) self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) From cc7b91d27ba922f06d56bd90b8281f57c2b343b4 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 7 Dec 2024 17:54:03 +0100 Subject: [PATCH 43/47] update --- src/diffusers/models/embeddings.py | 3 +-- src/diffusers/models/transformers/transformer_mochi.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 2119eeeb2ff8..c44b110473c1 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1594,13 +1594,12 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch. Returns: pooled: (B, D) tensor of pooled tokens. """ - input_dtype = x.dtype assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. 
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. mask = mask[:, :, None].to(dtype=x.dtype) mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) pooled = (x * mask).sum(dim=1, keepdim=keepdim) - return pooled.to(input_dtype) + return pooled def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: r""" diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 2e5306babb74..620c16441998 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -136,7 +136,6 @@ def forward( emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32)) hidden_states = hidden_states.to(hidden_states_dtype) From 2a6b82d0475b3f5575da5d04d33b47c13c4f1879 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 16 Dec 2024 09:11:47 +0100 Subject: [PATCH 44/47] update --- src/diffusers/models/attention_processor.py | 171 +++++++++++++++ src/diffusers/models/normalization.py | 30 +++ .../models/transformers/transformer_mochi.py | 200 +----------------- 3 files changed, 202 insertions(+), 199 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2ffe325fe664..f918be308539 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -880,6 +880,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.processor(self, hidden_states) +class MochiAttention(nn.Module): + def __init__( + self, + query_dim: int, + added_kv_proj_dim: int, + processor: "MochiAttnProcessor2_0", + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_proj_bias: bool = True, + out_dim: Optional[int] = None, + out_context_dim: Optional[int] = None, + out_bias: bool = True, + context_pre_only: bool = False, + eps: float = 1e-5, + ): + super().__init__() + from .normalization import MochiRMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim else query_dim + self.context_pre_only = context_pre_only + + self.heads = out_dim // dim_head if out_dim is not None else heads + + self.norm_q = MochiRMSNorm(dim_head, eps, True) + self.norm_k = MochiRMSNorm(dim_head, eps, True) + self.norm_added_q = MochiRMSNorm(dim_head, eps, True) + self.norm_added_k = MochiRMSNorm(dim_head, eps, True) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) + + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + + self.processor = processor + + def forward( + self, + 
hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + + +class MochiAttnProcessor2_0: + """Attention processor used in Mochi.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "MochiAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + if image_rotary_emb is not None: + + def apply_rotary_emb(x, freqs_cos, freqs_sin): + x_even = x[..., 0::2].float() + x_odd = x[..., 1::2].float() + + cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) + sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) + + return torch.stack([cos, sin], dim=-1).flatten(-2) + + query = apply_rotary_emb(query, *image_rotary_emb) + key = apply_rotary_emb(key, *image_rotary_emb) + + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + encoder_query, encoder_key, encoder_value = ( + encoder_query.transpose(1, 2), + encoder_key.transpose(1, 2), + encoder_value.transpose(1, 2), + ) + + sequence_length = query.size(2) + encoder_sequence_length = encoder_query.size(2) + total_length = sequence_length + encoder_sequence_length + + batch_size, heads, _, dim = query.shape + attn_outputs = [] + for idx in range(batch_size): + mask = attention_mask[idx][None, :] + valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :] + valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :] + valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :] + + valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2) + valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2) + valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2) + + attn_output = F.scaled_dot_product_attention( + valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False + ) + valid_sequence_length = attn_output.size(2) + attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) + attn_outputs.append(attn_output) + + hidden_states = torch.cat(attn_outputs, dim=0) + 
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if hasattr(attn, "to_add_out"): + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class AttnProcessor: r""" Default processor for performing attention-related computations. diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index c9d1038357c4..fe3823e32acf 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -522,6 +522,36 @@ def forward(self, hidden_states): return hidden_states +# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported +# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013 +class MochiRMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + hidden_states = hidden_states * self.weight + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + class GlobalResponseNorm(nn.Module): # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 def __init__(self, dim): diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 620c16441998..2ea3c05c6b56 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -13,18 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numbers from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn -import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward +from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -55,34 +54,6 @@ def forward(self, hidden_states, scale=None): return hidden_states -class MochiRMSNorm(nn.Module): - def __init__(self, dim, eps: float, elementwise_affine: bool = True): - super().__init__() - - self.eps = eps - - if isinstance(dim, numbers.Integral): - dim = (dim,) - - self.dim = torch.Size(dim) - - if elementwise_affine: - self.weight = nn.Parameter(torch.ones(dim)) - else: - self.weight = None - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - if self.weight is not None: - hidden_states = hidden_states * self.weight - hidden_states = hidden_states.to(input_dtype) - - return hidden_states - - class MochiLayerNormContinuous(nn.Module): def __init__( self, @@ -142,175 +113,6 @@ def forward( return hidden_states, gate_msa, scale_mlp, gate_mlp -class MochiAttention(nn.Module): - def __init__( - self, - query_dim: int, - processor: "MochiAttnProcessor2_0", - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - out_dim: int = None, - out_context_dim: int = None, - out_bias: bool = True, - context_pre_only: bool = False, - eps: float = 1e-5, - ): - super().__init__() - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim else query_dim - self.context_pre_only = context_pre_only - - self.heads = out_dim // dim_head if out_dim is not None else heads - - self.norm_q = MochiRMSNorm(dim_head, eps, True) - self.norm_k = MochiRMSNorm(dim_head, eps, True) - self.norm_added_q = MochiRMSNorm(dim_head, eps, True) - self.norm_added_k = MochiRMSNorm(dim_head, eps, True) - - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - - if not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) - - self.processor = processor - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, 
- **kwargs, - ): - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **kwargs, - ) - - -class MochiAttnProcessor2_0: - """Attention processor used in Mochi.""" - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: "MochiAttention", - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - query = query.unflatten(2, (attn.heads, -1)) - key = key.unflatten(2, (attn.heads, -1)) - value = value.unflatten(2, (attn.heads, -1)) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - - encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) - encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) - - if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) - if attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) - - if image_rotary_emb is not None: - - def apply_rotary_emb(x, freqs_cos, freqs_sin): - x_even = x[..., 0::2].float() - x_odd = x[..., 1::2].float() - - cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) - sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) - - return torch.stack([cos, sin], dim=-1).flatten(-2) - - query = apply_rotary_emb(query, *image_rotary_emb) - key = apply_rotary_emb(key, *image_rotary_emb) - - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - encoder_query, encoder_key, encoder_value = ( - encoder_query.transpose(1, 2), - encoder_key.transpose(1, 2), - encoder_value.transpose(1, 2), - ) - - sequence_length = query.size(2) - encoder_sequence_length = encoder_query.size(2) - total_length = sequence_length + encoder_sequence_length - - batch_size, heads, _, dim = query.shape - attn_outputs = [] - for idx in range(batch_size): - mask = attention_mask[idx][None, :] - valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() - - valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :] - valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :] - valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :] - - valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2) - valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2) - valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2) - - attn_output = F.scaled_dot_product_attention( - valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False - ) - valid_sequence_length = attn_output.size(2) - attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) - attn_outputs.append(attn_output) - - hidden_states = torch.cat(attn_outputs, dim=0) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - 
(sequence_length, encoder_sequence_length), dim=1 - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if hasattr(attn, "to_add_out"): - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - - @maybe_allow_in_graph class MochiTransformerBlock(nn.Module): r""" From cbbc54b05014f582e2d530514910c8169a12c21f Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 16 Dec 2024 09:22:16 +0100 Subject: [PATCH 45/47] update --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 1 - src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 72b95fea1ce1..b08cf290ad9d 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -198,7 +198,6 @@ def __init__( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 ) - # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 25ed635a3d17..961a9425cfe2 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -221,7 +221,6 @@ def __init__( self.default_width = 704 self.default_frames = 121 - # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, From b75db11b9d1535af696724f45432969eb8cf23b3 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 16 Dec 2024 13:59:15 +0100 Subject: [PATCH 46/47] update --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a6f4a675c519..33dc390629a1 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5735,6 +5735,7 @@ def __call__( AttnProcessorNPU, AttnProcessor2_0, MochiVaeAttnProcessor2_0, + MochiAttnProcessor2_0, StableAudioAttnProcessor2_0, HunyuanAttnProcessor2_0, FusedHunyuanAttnProcessor2_0, From 50c5607e9608935c59a09d78608a2a2cb9f17da7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Dec 2024 16:16:22 +0530 Subject: [PATCH 47/47] Update src/diffusers/models/transformers/transformer_mochi.py Co-authored-by: Aryan --- src/diffusers/models/transformers/transformer_mochi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 2ea3c05c6b56..fe72dc56883e 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -30,7 +30,7 @@ from ..normalization import AdaLayerNormContinuous, RMSNorm -logger = logging.get_logger(__name__) # pylint: disable=invalid-n +logger = logging.get_logger(__name__) # pylint: disable=invalid-name class MochiModulatedRMSNorm(nn.Module):
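
A note for reviewers on the dtype handling in the final patches: the MochiRMSNorm module that ends up in src/diffusers/models/normalization.py computes the variance in float32 and casts the result back to the input dtype, which is the behaviour the accompanying TODO says can be swapped for the regular RMSNorm once `_keep_in_fp32_modules` works with sharded checkpoints. The standalone sketch below mirrors that forward pass outside of diffusers so the cast behaviour can be sanity-checked in isolation; the tensor shapes and the bfloat16 input are illustrative assumptions, not values taken from these patches.

    import torch
    import torch.nn as nn


    class MochiRMSNorm(nn.Module):
        # Mirrors the module added in normalization.py: the variance is computed
        # in float32 and the normalized output is cast back to the input dtype.
        def __init__(self, dim: int, eps: float, elementwise_affine: bool = True):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            input_dtype = hidden_states.dtype
            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
            if self.weight is not None:
                hidden_states = hidden_states * self.weight
            return hidden_states.to(input_dtype)


    # Illustrative usage with assumed shapes (batch=2, heads=24, seq=16, dim_head=128).
    norm_q = MochiRMSNorm(dim=128, eps=1e-5, elementwise_affine=True)
    query = torch.randn(2, 24, 16, 128, dtype=torch.bfloat16)
    print(norm_q(query).dtype)  # torch.bfloat16: the output is cast back to the input dtype
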