 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -239,33 +239,29 @@ def __init__(
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
-        post_patch_height = sample_height // patch_size
-        post_patch_width = sample_width // patch_size
-        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
-        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
-        self.embedding_dropout = nn.Dropout(dropout)
-
-        # 2. 3D positional embeddings
-        spatial_pos_embedding = get_3d_sincos_pos_embed(
-            inner_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            spatial_interpolation_scale,
-            temporal_interpolation_scale,
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
         )
-        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
-        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+        self.embedding_dropout = nn.Dropout(dropout)
 
-        # 3. Time embeddings
+        # 2. Time embeddings
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
-        # 4. Define spatio-temporal transformers blocks
+        # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 CogVideoXBlock(
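A minimal sketch, assuming the deleted buffer construction now lives inside `CogVideoXPatchEmbed` behind the `use_positional_embeddings` flag passed above; the patch embed's actual internals are not part of this diff. The standalone helper name below is illustrative only and simply restates the removed logic.

```python
import torch

from diffusers.models.embeddings import get_3d_sincos_pos_embed  # same helper the old code imported


def build_pos_embedding(
    embed_dim,
    sample_width,
    sample_height,
    sample_frames,
    patch_size,
    temporal_compression_ratio,
    max_text_seq_length,
    spatial_interpolation_scale,
    temporal_interpolation_scale,
):
    # Grid sizes after spatial patchification and temporal compression.
    post_patch_height = sample_height // patch_size
    post_patch_width = sample_width // patch_size
    post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
    num_patches = post_patch_height * post_patch_width * post_time_compression_frames

    # 3D sincos table over (time, height, width), flattened into one sequence of patch tokens.
    spatial_pos_embedding = get_3d_sincos_pos_embed(
        embed_dim,
        (post_patch_width, post_patch_height),
        post_time_compression_frames,
        spatial_interpolation_scale,
        temporal_interpolation_scale,
    )
    spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)

    # Text positions stay zero; only the patch positions receive sincos values.
    pos_embedding = torch.zeros(1, max_text_seq_length + num_patches, embed_dim, requires_grad=False)
    pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
    return pos_embedding
```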
@@ -284,7 +280,7 @@ def __init__(
         )
         self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
 
-        # 5. Output blocks
+        # 4. Output blocks
         self.norm_out = AdaLayerNorm(
             embedding_dim=time_embed_dim,
             output_dim=2 * inner_dim,
@@ -422,20 +418,13 @@ def forward(
 
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
 
-        # 3. Position embedding
         text_seq_length = encoder_hidden_states.shape[1]
-        if not self.config.use_rotary_positional_embeddings:
-            seq_length = height * width * num_frames // (self.config.patch_size**2)
-
-            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
-            hidden_states = hidden_states + pos_embeds
-            hidden_states = self.embedding_dropout(hidden_states)
-
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 4. Transformer blocks
+        # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
 
@@ -471,11 +460,11 @@ def custom_forward(*inputs):
         hidden_states = self.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 5. Final block
+        # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
         hidden_states = self.proj_out(hidden_states)
 
-        # 6. Unpatchify
+        # 5. Unpatchify
         p = self.config.patch_size
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
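As a quick, shape-only sanity check of the unpatchify reshape/permute in the last hunk, the snippet below runs the same two lines on a toy tensor; the sizes are arbitrary illustrative values, not model defaults.

```python
import torch

# Toy unpatchify check: values are random, sizes arbitrary.
batch_size, num_frames, height, width, channels, p = 1, 2, 8, 8, 4, 2
hidden_states = torch.randn(batch_size, num_frames * (height // p) * (width // p), channels * p * p)

output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

print(output.shape)  # torch.Size([1, 2, 4, 8, 8]) -> (batch, frames, channels, height, width)
```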