@@ -90,6 +90,12 @@ class ZSingleStreamAttnProcessor:
     _attention_backend = None
     _parallel_config = None
 
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
+
     def __call__(
         self,
         attn: Attention,
@@ -493,7 +499,6 @@ def patchify_and_embed(
 
         image_ori_len = len(image)
         image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
-        # padded_pos_ids
 
         image_ori_pos_ids = self.create_coordinate_grid(
             size=(F_tokens, H_tokens, W_tokens),
@@ -574,11 +579,7 @@ def forward(
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
 
-        pad_tensor = torch.zeros(
-            (1, self.dim),
-            dtype=x[0].dtype,
-            device=device,
-        )
+        pad_tensor = torch.zeros((1, self.dim), dtype=x[0].dtype, device=device)
         freqs_pad_tensor = torch.zeros(
             (1, self.dim // self.n_heads // 2),
             dtype=x_freqs_cis[0].dtype,
@@ -613,22 +614,19 @@ def forward(
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
         cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
 
-        pad_tensor = torch.zeros(
-            (1, self.dim),
-            dtype=cap_feats[0].dtype,
-            device=device,
-        )
-        freqs_pad_tensor = torch.zeros(
-            (1, self.dim // self.n_heads // 2),
-            dtype=cap_freqs_cis[0].dtype,
-            device=device,
+        # Reuse padding tensors (convert dtype if needed)
+        cap_pad_tensor = pad_tensor.to(cap_feats[0].dtype) if pad_tensor.dtype != cap_feats[0].dtype else pad_tensor
+        cap_freqs_pad_tensor = (
+            freqs_pad_tensor.to(cap_freqs_cis[0].dtype)
+            if freqs_pad_tensor.dtype != cap_freqs_cis[0].dtype
+            else freqs_pad_tensor
         )
         cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
         for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
             seq_len = cap_item_seqlens[i]
             pad_len = cap_max_item_seqlen - seq_len
-            cap_feats[i] = torch.cat([item, pad_tensor.repeat(pad_len, 1)])
-            cap_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
+            cap_feats[i] = torch.cat([item, cap_pad_tensor.repeat(pad_len, 1)])
+            cap_freqs_cis[i] = torch.cat([freqs_item, cap_freqs_pad_tensor.repeat(pad_len, 1)])
             cap_attn_mask[i, seq_len:] = 0
         cap_feats = torch.stack(cap_feats)
         cap_freqs_cis = torch.stack(cap_freqs_cis)
@@ -652,22 +650,18 @@ def forward(
         assert unified_item_seqlens == [len(_) for _ in unified]
         unified_max_item_seqlen = max(unified_item_seqlens)
 
-        pad_tensor = torch.zeros(
-            (1, self.dim),
-            dtype=unified[0].dtype,
-            device=device,
-        )
-        freqs_pad_tensor = torch.zeros(
-            (1, self.dim // self.n_heads // 2),
-            dtype=unified_freqs_cis[0].dtype,
-            device=device,
+        unified_pad_tensor = pad_tensor.to(unified[0].dtype) if pad_tensor.dtype != unified[0].dtype else pad_tensor
+        unified_freqs_pad_tensor = (
+            freqs_pad_tensor.to(unified_freqs_cis[0].dtype)
+            if freqs_pad_tensor.dtype != unified_freqs_cis[0].dtype
+            else freqs_pad_tensor
         )
         unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
         for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
             seq_len = unified_item_seqlens[i]
             pad_len = unified_max_item_seqlen - seq_len
-            unified[i] = torch.cat([item, pad_tensor.repeat(pad_len, 1)])
-            unified_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
+            unified[i] = torch.cat([item, unified_pad_tensor.repeat(pad_len, 1)])
+            unified_freqs_cis[i] = torch.cat([freqs_item, unified_freqs_pad_tensor.repeat(pad_len, 1)])
             unified_attn_mask[i, seq_len:] = 0
         unified = torch.stack(unified)
         unified_freqs_cis = torch.stack(unified_freqs_cis)
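
Taken together, the `forward` hunks replace three separate `torch.zeros` allocations with a single `pad_tensor` / `freqs_pad_tensor` pair that later branches reuse, converting the dtype only when it actually differs. A minimal standalone sketch of that padding pattern, with made-up sizes and names standing in for `self.dim` and the per-item sequences (not the model's actual code):

```python
import torch

dim = 8  # stand-in for self.dim
device = "cpu"

# Items of different lengths, like the lists produced by split() in the diff above.
x = [torch.randn(3, dim, dtype=torch.float16), torch.randn(5, dim, dtype=torch.float16)]

# Allocate the padding row once, from the first item's dtype.
pad_tensor = torch.zeros((1, dim), dtype=x[0].dtype, device=device)

# A later branch with a different dtype reuses the same row via .to() instead of re-allocating.
cap_feats = [torch.randn(4, dim, dtype=torch.float32)]
cap_pad_tensor = pad_tensor.to(cap_feats[0].dtype) if pad_tensor.dtype != cap_feats[0].dtype else pad_tensor
print(cap_pad_tensor.dtype)  # torch.float32

# Right-pad every item to the longest length, then stack into a batch.
max_len = max(item.shape[0] for item in x)
padded = [torch.cat([item, pad_tensor.repeat(max_len - item.shape[0], 1)]) for item in x]
print(torch.stack(padded).shape)  # torch.Size([2, 5, 8])
```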