From 7fa2bde6b204e7432a8281f1f0de297c05234c2d Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 24 Aug 2024 03:07:34 +0200 Subject: [PATCH 1/8] remove frame limit in cogvideox --- .../transformers/cogvideox_transformer_3d.py | 34 +++---------- .../pipelines/cogvideo/pipeline_cogvideox.py | 51 +++++++++++++++---- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index c8d4b1896346..054ab4bef64f 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -23,7 +23,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -239,33 +239,15 @@ def __init__( super().__init__() inner_dim = num_attention_heads * attention_head_dim - post_patch_height = sample_height // patch_size - post_patch_width = sample_width // patch_size - post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 - self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames - # 1. Patch embedding self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) self.embedding_dropout = nn.Dropout(dropout) - # 2. 3D positional embeddings - spatial_pos_embedding = get_3d_sincos_pos_embed( - inner_dim, - (post_patch_width, post_patch_height), - post_time_compression_frames, - spatial_interpolation_scale, - temporal_interpolation_scale, - ) - spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) - pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) - pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) - self.register_buffer("pos_embedding", pos_embedding, persistent=False) - - # 3. Time embeddings + # 2. Time embeddings self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - # 4. Define spatio-temporal transformers blocks + # 3. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( [ CogVideoXBlock( @@ -284,7 +266,7 @@ def __init__( ) self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) - # 5. Output blocks + # 4. Output blocks self.norm_out = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=2 * inner_dim, @@ -405,6 +387,7 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, + positional_emb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ): @@ -426,12 +409,11 @@ def forward( # 3. 
Position embedding text_seq_length = encoder_hidden_states.shape[1] if not self.config.use_rotary_positional_embeddings: - seq_length = height * width * num_frames // (self.config.patch_size**2) - - pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] + video_seq_length = height * width * num_frames // (self.config.patch_size**2) + pos_embeds = positional_emb[:, : text_seq_length + video_seq_length] hidden_states = hidden_states + pos_embeds - hidden_states = self.embedding_dropout(hidden_states) + hidden_states = self.embedding_dropout(hidden_states) encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index e100c1f11e20..66e722fd4a7c 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel -from ...models.embeddings import get_3d_rotary_pos_embed +from ...models.embeddings import get_3d_rotary_pos_embed, get_3d_sincos_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import BaseOutput, logging, replace_example_docstring @@ -443,6 +443,36 @@ def unfuse_qkv_projections(self) -> None: self.transformer.unfuse_qkv_projections() self.fusing_transformer = False + def _prepare_normal_positional_embeddings( + self, + height, + width: int, + num_frames: int, + text_seq_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + inner_dim = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + pos_embedding = get_3d_sincos_pos_embed( + inner_dim, + (grid_width, grid_height), + num_frames, + self.transformer.config.spatial_interpolation_scale, + self.transformer.config.temporal_interpolation_scale, + ) + pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1).unsqueeze(0) + + text_embedding_pad = torch.zeros((1, text_seq_length, inner_dim), requires_grad=False) + print(grid_height, grid_width, num_frames) + print(pos_embedding.shape, text_embedding_pad.shape) + pos_embedding = torch.cat([text_embedding_pad, pos_embedding], dim=1) + + pos_embedding = pos_embedding.to(device=device, dtype=dtype) + return pos_embedding + def _prepare_rotary_positional_embeddings( self, height: int, @@ -585,11 +615,6 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - if num_frames > 49: - raise ValueError( - "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." - ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -660,8 +685,15 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Create rotary embeds if required - image_rotary_emb = ( + # 7. 
Create positional and rotary embeds as required + positional_embeds = ( + self._prepare_normal_positional_embeddings( + height, width, latents.size(1), max_sequence_length, device, prompt_embeds.dtype + ) + if not self.transformer.config.use_rotary_positional_embeddings + else None + ) + rotary_embeds = ( self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) if self.transformer.config.use_rotary_positional_embeddings else None @@ -688,7 +720,8 @@ def __call__( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, - image_rotary_emb=image_rotary_emb, + positional_emb=positional_embeds, + image_rotary_emb=rotary_embeds, return_dict=False, )[0] noise_pred = noise_pred.float() From 22311d1c05a09cec91ded11798710d31e23fc7b2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 24 Aug 2024 03:09:25 +0200 Subject: [PATCH 2/8] remove debug prints --- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 66e722fd4a7c..6ae146a03bf5 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -466,8 +466,6 @@ def _prepare_normal_positional_embeddings( pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1).unsqueeze(0) text_embedding_pad = torch.zeros((1, text_seq_length, inner_dim), requires_grad=False) - print(grid_height, grid_width, num_frames) - print(pos_embedding.shape, text_embedding_pad.shape) pos_embedding = torch.cat([text_embedding_pad, pos_embedding], dim=1) pos_embedding = pos_embedding.to(device=device, dtype=dtype) From f8f03a1e3406ab67922196491e0eae945e61906b Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 26 Aug 2024 03:20:18 +0530 Subject: [PATCH 3/8] Update src/diffusers/models/transformers/cogvideox_transformer_3d.py --- src/diffusers/models/transformers/cogvideox_transformer_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 054ab4bef64f..f36cfce87e8a 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -408,7 +408,7 @@ def forward( # 3. 
Position embedding text_seq_length = encoder_hidden_states.shape[1] - if not self.config.use_rotary_positional_embeddings: + if not self.config.use_rotary_positional_embeddings and positional_emb is not None: video_seq_length = height * width * num_frames // (self.config.patch_size**2) pos_embeds = positional_emb[:, : text_seq_length + video_seq_length] hidden_states = hidden_states + pos_embeds From 92a2f7e6521565feea2c77a33f7ffde49d715759 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 27 Aug 2024 12:11:04 +0200 Subject: [PATCH 4/8] revert pipeline; remove frame limitation --- .../pipelines/cogvideo/pipeline_cogvideox.py | 44 ++----------------- 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 34aa534afe57..5add60591aed 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel -from ...models.embeddings import get_3d_rotary_pos_embed, get_3d_sincos_pos_embed +from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import BaseOutput, logging, replace_example_docstring @@ -443,34 +443,6 @@ def unfuse_qkv_projections(self) -> None: self.transformer.unfuse_qkv_projections() self.fusing_transformer = False - def _prepare_normal_positional_embeddings( - self, - height, - width: int, - num_frames: int, - text_seq_length: int, - device: torch.device, - dtype: torch.dtype, - ) -> torch.Tensor: - inner_dim = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim - grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - - pos_embedding = get_3d_sincos_pos_embed( - inner_dim, - (grid_width, grid_height), - num_frames, - self.transformer.config.spatial_interpolation_scale, - self.transformer.config.temporal_interpolation_scale, - ) - pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1).unsqueeze(0) - - text_embedding_pad = torch.zeros((1, text_seq_length, inner_dim), requires_grad=False) - pos_embedding = torch.cat([text_embedding_pad, pos_embedding], dim=1) - - pos_embedding = pos_embedding.to(device=device, dtype=dtype) - return pos_embedding - def _prepare_rotary_positional_embeddings( self, height: int, @@ -682,15 +654,8 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7. Create positional and rotary embeds as required - positional_embeds = ( - self._prepare_normal_positional_embeddings( - height, width, latents.size(1), max_sequence_length, device, prompt_embeds.dtype - ) - if not self.transformer.config.use_rotary_positional_embeddings - else None - ) - rotary_embeds = ( + # 7. 
Create rotary embeds if required + image_rotary_emb = ( self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) if self.transformer.config.use_rotary_positional_embeddings else None @@ -717,8 +682,7 @@ def __call__( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, - positional_emb=positional_embeds, - image_rotary_emb=rotary_embeds, + image_rotary_emb=image_rotary_emb, return_dict=False, )[0] noise_pred = noise_pred.float() From 392f726aae26b2628af4100d4c5b641cfb50552a Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 27 Aug 2024 12:11:55 +0200 Subject: [PATCH 5/8] revert transformer changes --- .../transformers/cogvideox_transformer_3d.py | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index f36cfce87e8a..71b393025a4f 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -23,7 +23,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -239,15 +239,33 @@ def __init__( super().__init__() inner_dim = num_attention_heads * attention_head_dim + post_patch_height = sample_height // patch_size + post_patch_width = sample_width // patch_size + post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 + self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames + # 1. Patch embedding self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) self.embedding_dropout = nn.Dropout(dropout) - # 2. Time embeddings + # 2. 3D positional embeddings + spatial_pos_embedding = get_3d_sincos_pos_embed( + inner_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + spatial_interpolation_scale, + temporal_interpolation_scale, + ) + spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) + pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) + pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) + self.register_buffer("pos_embedding", pos_embedding, persistent=False) + + # 3. Time embeddings self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - # 3. Define spatio-temporal transformers blocks + # 4. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( [ CogVideoXBlock( @@ -266,7 +284,7 @@ def __init__( ) self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) - # 4. Output blocks + # 5. 
Output blocks self.norm_out = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=2 * inner_dim, @@ -387,7 +405,6 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, - positional_emb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, return_dict: bool = True, ): @@ -408,11 +425,12 @@ def forward( # 3. Position embedding text_seq_length = encoder_hidden_states.shape[1] - if not self.config.use_rotary_positional_embeddings and positional_emb is not None: - video_seq_length = height * width * num_frames // (self.config.patch_size**2) - pos_embeds = positional_emb[:, : text_seq_length + video_seq_length] - hidden_states = hidden_states + pos_embeds + if not self.config.use_rotary_positional_embeddings: + seq_length = height * width * num_frames // (self.config.patch_size**2) + pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] + hidden_states = hidden_states + pos_embeds + hidden_states = self.embedding_dropout(hidden_states) encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] From 431ad60c1d5a8eef3557583db7bdaf06067839c8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 27 Aug 2024 12:44:45 +0200 Subject: [PATCH 6/8] address review comments --- src/diffusers/models/embeddings.py | 47 +++++++++++++++ .../transformers/cogvideox_transformer_3d.py | 57 ++++++++----------- 2 files changed, 70 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index dcb9528cb1a0..0d110f09b785 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -342,15 +342,58 @@ def __init__( embed_dim: int = 1920, text_embed_dim: int = 4096, bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_positional_embeddings: bool = True, ) -> None: super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias ) self.text_proj = nn.Linear(text_embed_dim, embed_dim) + if use_positional_embeddings: + pos_embedding = self._create_positional_embeddings() + self.register_buffer("pos_embedding", pos_embedding, persistent=False) + + def _create_positional_embeddings(self) -> torch.Tensor: + post_patch_height = self.sample_height // self.patch_size + post_patch_width = self.sample_width // self.patch_size + post_time_compression_frames = (self.sample_frames - 1) // self.temporal_compression_ratio + 1 + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.embed_dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + 
self.temporal_interpolation_scale, + ) + pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + joint_pos_embedding = torch.zeros( + 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding) + + return joint_pos_embedding + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): r""" Args: @@ -371,6 +414,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): embeds = torch.cat( [text_embeds, image_embeds], dim=1 ).contiguous() # [batch, seq_length + num_frames x height x width, channels] + + if self.use_positional_embeddings: + embeds = embeds + self.pos_embedding[:, : embeds.size(1)] + return embeds diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 71b393025a4f..b6ba407104d5 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -23,7 +23,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 -from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed +from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -239,33 +239,29 @@ def __init__( super().__init__() inner_dim = num_attention_heads * attention_head_dim - post_patch_height = sample_height // patch_size - post_patch_width = sample_width // patch_size - post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 - self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames - # 1. Patch embedding - self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) - self.embedding_dropout = nn.Dropout(dropout) - - # 2. 3D positional embeddings - spatial_pos_embedding = get_3d_sincos_pos_embed( - inner_dim, - (post_patch_width, post_patch_height), - post_time_compression_frames, - spatial_interpolation_scale, - temporal_interpolation_scale, + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + text_embed_dim=text_embed_dim, + bias=True, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, ) - spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) - pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) - pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) - self.register_buffer("pos_embedding", pos_embedding, persistent=False) + self.embedding_dropout = nn.Dropout(dropout) - # 3. Time embeddings + # 2. 
Time embeddings self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) - # 4. Define spatio-temporal transformers blocks + # 3. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( [ CogVideoXBlock( @@ -284,7 +280,7 @@ def __init__( ) self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) - # 5. Output blocks + # 4. Output blocks self.norm_out = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=2 * inner_dim, @@ -422,20 +418,13 @@ def forward( # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) - # 3. Position embedding text_seq_length = encoder_hidden_states.shape[1] - if not self.config.use_rotary_positional_embeddings: - seq_length = height * width * num_frames // (self.config.patch_size**2) - - pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] - hidden_states = hidden_states + pos_embeds - - hidden_states = self.embedding_dropout(hidden_states) encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] - # 4. Transformer blocks + # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: @@ -471,11 +460,11 @@ def custom_forward(*inputs): hidden_states = self.norm_final(hidden_states) hidden_states = hidden_states[:, text_seq_length:] - # 5. Final block + # 4. Final block hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.proj_out(hidden_states) - # 6. Unpatchify + # 5. Unpatchify p = self.config.patch_size output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) From 555ed913cf6e8b2440e49ab90de6c1e25ff53ed5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 27 Aug 2024 12:56:47 +0200 Subject: [PATCH 7/8] add error message --- src/diffusers/models/embeddings.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0d110f09b785..b7a1ac8e1887 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -416,6 +416,14 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): ).contiguous() # [batch, seq_length + num_frames x height x width, channels] if self.use_positional_embeddings: + if embeds.size(1) > self.pos_embedding.size(1): + raise ValueError( + "You are trying to generate at a resolution, or higher number of frames, or longer maximum prompt " + "sequence length than what is supported in the `CogVideoXTransformer3D` model. 
In order to generate " + "at the resolution/num_frames you desire, configure the transformer initialization attributes " + "`sample_height`, `sample_width`, `sample_frames` and `max_text_seq_length` appropriately when " + "initializing your model either with `.from_pretrained(...)` or `.from_config(...)`" + ) embeds = embeds + self.pos_embedding[:, : embeds.size(1)] return embeds From b3b9ecca504b736a51260b70293f4fc5440cc3cb Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 30 Aug 2024 05:10:41 +0200 Subject: [PATCH 8/8] apply suggestions from review --- src/diffusers/models/embeddings.py | 31 ++++++++++--------- .../pipelines/cogvideo/pipeline_cogvideox.py | 5 +++ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b7a1ac8e1887..5236548cdf32 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -370,13 +370,13 @@ def __init__( self.text_proj = nn.Linear(text_embed_dim, embed_dim) if use_positional_embeddings: - pos_embedding = self._create_positional_embeddings() + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) self.register_buffer("pos_embedding", pos_embedding, persistent=False) - def _create_positional_embeddings(self) -> torch.Tensor: - post_patch_height = self.sample_height // self.patch_size - post_patch_width = self.sample_width // self.patch_size - post_time_compression_frames = (self.sample_frames - 1) // self.temporal_compression_ratio + 1 + def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 num_patches = post_patch_height * post_patch_width * post_time_compression_frames pos_embedding = get_3d_sincos_pos_embed( @@ -416,15 +416,18 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): ).contiguous() # [batch, seq_length + num_frames x height x width, channels] if self.use_positional_embeddings: - if embeds.size(1) > self.pos_embedding.size(1): - raise ValueError( - "You are trying to generate at a resolution, or higher number of frames, or longer maximum prompt " - "sequence length than what is supported in the `CogVideoXTransformer3D` model. 
In order to generate " - "at the resolution/num_frames you desire, configure the transformer initialization attributes " - "`sample_height`, `sample_width`, `sample_frames` and `max_text_seq_length` appropriately when " - "initializing your model either with `.from_pretrained(...)` or `.from_config(...)`" - ) - embeds = embeds + self.pos_embedding[:, : embeds.size(1)] + pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + if ( + self.sample_height != height + or self.sample_width != width + or self.sample_frames != pre_time_compression_frames + ): + pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) + pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) + else: + pos_embedding = self.pos_embedding + + embeds = embeds + pos_embedding return embeds diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 5add60591aed..11f491e49532 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -584,6 +584,11 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
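Note on the sizing logic this series changes: the patch moves CogVideoX's 3D sin-cos positional embedding out of a fixed `pos_embedding` buffer on `CogVideoXTransformer3DModel` and into `CogVideoXPatchEmbed`, and the final commit regenerates that table on the fly whenever the requested height, width, or frame count differs from the configured `sample_height`/`sample_width`/`sample_frames`. The sketch below is a minimal, self-contained illustration of the sequence-length arithmetic involved — it is not code from the patch. It mirrors the defaults visible in the `CogVideoXPatchEmbed` signature above (sample_height=60, sample_width=90, sample_frames=49, temporal_compression_ratio=4, max_text_seq_length=226) and assumes patch_size=2, which is not stated in this diff.

def joint_seq_length(sample_height, sample_width, sample_frames,
                     patch_size=2, temporal_compression_ratio=4,
                     max_text_seq_length=226):
    # Frames are compressed temporally by the VAE; each latent frame is then
    # split into (patch_size x patch_size) spatial patches.
    post_patch_height = sample_height // patch_size
    post_patch_width = sample_width // patch_size
    post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
    num_patches = post_patch_height * post_patch_width * post_time_compression_frames
    # The joint positional table covers the text tokens followed by all video patches.
    return max_text_seq_length + num_patches

print(joint_seq_length(60, 90, 49))  # 17776 positions -> exactly what the old static buffer held
print(joint_seq_length(60, 90, 81))  # 28576 positions -> would overflow that static buffer

Because `_get_positional_embeddings` is now called with the actual latent grid at forward time, larger resolutions or longer clips no longer depend on the size of a precomputed buffer; only the explicit `num_frames > 49` guard shown in the last hunk still limits frame count at the pipeline level.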