 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from torch.nn.utils.rnn import pad_sequence

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
@@ -355,6 +356,7 @@ def __init__(

         self.rope_theta = rope_theta
         self.t_scale = t_scale
+        self.gradient_checkpointing = False

         assert len(all_patch_size) == len(all_f_patch_size)

@@ -579,29 +581,18 @@ def forward(
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

-        pad_tensor = torch.zeros((1, self.dim), dtype=x[0].dtype, device=device)
-        freqs_pad_tensor = torch.zeros(
-            (1, self.dim // self.n_heads // 2),
-            dtype=x_freqs_cis[0].dtype,
-            device=device,
-        )
-        x_attn_mask = torch.ones((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(x, x_freqs_cis)):
-            seq_len = x_item_seqlens[i]
-            pad_len = x_max_item_seqlen - seq_len
-            x[i] = torch.cat([item, pad_tensor.repeat(pad_len, 1)])
-            x_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
-            x_attn_mask[i, seq_len:] = 0
-        x = torch.stack(x)
-        x_freqs_cis = torch.stack(x_freqs_cis)
-
-        for layer in self.noise_refiner:
-            x = layer(
-                x,
-                x_attn_mask,
-                x_freqs_cis,
-                adaln_input,
-            )
+        x = pad_sequence(x, batch_first=True, padding_value=0.0)
+        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(x_item_seqlens):
+            x_attn_mask[i, :seq_len] = 1
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.noise_refiner:
+                x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
+        else:
+            for layer in self.noise_refiner:
+                x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)

         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
@@ -614,29 +605,18 @@ def forward(
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
         cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))

-        # Reuse padding tensors (convert dtype if needed)
-        cap_pad_tensor = pad_tensor.to(cap_feats[0].dtype) if pad_tensor.dtype != cap_feats[0].dtype else pad_tensor
-        cap_freqs_pad_tensor = (
-            freqs_pad_tensor.to(cap_freqs_cis[0].dtype)
-            if freqs_pad_tensor.dtype != cap_freqs_cis[0].dtype
-            else freqs_pad_tensor
-        )
-        cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
-            seq_len = cap_item_seqlens[i]
-            pad_len = cap_max_item_seqlen - seq_len
-            cap_feats[i] = torch.cat([item, cap_pad_tensor.repeat(pad_len, 1)])
-            cap_freqs_cis[i] = torch.cat([freqs_item, cap_freqs_pad_tensor.repeat(pad_len, 1)])
-            cap_attn_mask[i, seq_len:] = 0
-        cap_feats = torch.stack(cap_feats)
-        cap_freqs_cis = torch.stack(cap_freqs_cis)
-
-        for layer in self.context_refiner:
-            cap_feats = layer(
-                cap_feats,
-                cap_attn_mask,
-                cap_freqs_cis,
-            )
+        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
+        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(cap_item_seqlens):
+            cap_attn_mask[i, :seq_len] = 1
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.context_refiner:
+                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
+        else:
+            for layer in self.context_refiner:
+                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

         # unified
         unified = []
@@ -650,29 +630,18 @@ def forward(
         assert unified_item_seqlens == [len(_) for _ in unified]
         unified_max_item_seqlen = max(unified_item_seqlens)

-        unified_pad_tensor = pad_tensor.to(unified[0].dtype) if pad_tensor.dtype != unified[0].dtype else pad_tensor
-        unified_freqs_pad_tensor = (
-            freqs_pad_tensor.to(unified_freqs_cis[0].dtype)
-            if freqs_pad_tensor.dtype != unified_freqs_cis[0].dtype
-            else freqs_pad_tensor
-        )
-        unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
-            seq_len = unified_item_seqlens[i]
-            pad_len = unified_max_item_seqlen - seq_len
-            unified[i] = torch.cat([item, unified_pad_tensor.repeat(pad_len, 1)])
-            unified_freqs_cis[i] = torch.cat([freqs_item, unified_freqs_pad_tensor.repeat(pad_len, 1)])
-            unified_attn_mask[i, seq_len:] = 0
-        unified = torch.stack(unified)
-        unified_freqs_cis = torch.stack(unified_freqs_cis)
-
-        for layer in self.layers:
-            unified = layer(
-                unified,
-                unified_attn_mask,
-                unified_freqs_cis,
-                adaln_input,
-            )
+        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
+        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(unified_item_seqlens):
+            unified_attn_mask[i, :seq_len] = 1
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.layers:
+                unified = self._gradient_checkpointing_func(layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+        else:
+            for layer in self.layers:
+                unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)

         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
         unified = list(unified.unbind(dim=0))
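
The refactor above replaces the hand-rolled pad_tensor/repeat/stack bookkeeping (and its dtype-converted copies in the caption and unified branches) with torch.nn.utils.rnn.pad_sequence plus a boolean mask. A minimal, self-contained sketch of that padding-plus-mask pattern, using made-up sequence lengths and a hidden size of 8 rather than the model's real dimensions:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Three variable-length sequences; lengths and hidden size are illustrative only.
    items = [torch.randn(n, 8) for n in (3, 5, 2)]
    item_seqlens = [item.shape[0] for item in items]
    max_seqlen = max(item_seqlens)

    # Right-pad every sequence to the batch maximum in one call.
    padded = pad_sequence(items, batch_first=True, padding_value=0.0)  # (3, max_seqlen, 8)

    # Boolean mask: True marks real tokens, False marks padding.
    attn_mask = torch.zeros((len(items), max_seqlen), dtype=torch.bool)
    for i, seq_len in enumerate(item_seqlens):
        attn_mask[i, :seq_len] = 1

    assert padded.shape == (len(items), max_seqlen, 8)
    assert attn_mask.sum().item() == sum(item_seqlens)

Because pad_sequence pads with a scalar value, no per-branch pad tensor or dtype conversion is needed, and the padded positions are masked out of attention regardless.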
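The new gradient_checkpointing flag and _gradient_checkpointing_func calls appear to follow the usual diffusers ModelMixin convention, where enable_gradient_checkpointing() flips the flag and installs a wrapper around torch.utils.checkpoint. A hedged sketch of that per-layer dispatch under that assumption; run_layers and checkpoint_fn are illustrative names, not part of the model:

    import functools
    import torch
    from torch.utils.checkpoint import checkpoint

    # Stand-in for self._gradient_checkpointing_func (assumed here to be a
    # non-reentrant torch.utils.checkpoint wrapper, as diffusers typically installs).
    checkpoint_fn = functools.partial(checkpoint, use_reentrant=False)

    def run_layers(layers, hidden_states, attn_mask, freqs_cis, adaln_input, gradient_checkpointing):
        # Recompute activations during backward only when gradients are being tracked,
        # mirroring the `torch.is_grad_enabled() and self.gradient_checkpointing` guard above.
        if torch.is_grad_enabled() and gradient_checkpointing:
            for layer in layers:
                hidden_states = checkpoint_fn(layer, hidden_states, attn_mask, freqs_cis, adaln_input)
        else:
            for layer in layers:
                hidden_states = layer(hidden_states, attn_mask, freqs_cis, adaln_input)
        return hidden_states

Under torch.no_grad() (e.g. during sampling) torch.is_grad_enabled() is False, so the plain path runs and enabling checkpointing does not slow inference.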