Use correct mask for packed inputs in Qwen-VL #44157
zucchini-nlp wants to merge 7 commits into huggingface:main
Conversation
run-slow: qwen2_vl, qwen2_5_vl

This comment contains models: ["models/qwen2_5_vl", "models/qwen2_vl"]

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

run-slow: ernie4_5_vl_moe, glm4v, glm4v_moe, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, video_llama_3
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,

Unused arg, probably a bad copy from other models. We only use position_embeddings.

Oops, very likely yes 😓 I started from qwen vl and opted to make it a bit cleaner back then, iirc.
This comment contains models: ["models/ernie4_5_vl_moe", "models/glm4v", "models/glm4v_moe", "models/glm_ocr", "models/paddleocr_vl", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_omni_moe", "models/qwen3_vl", "models/qwen3_vl_moe", "models/video_llama_3"]

[For maintainers] Suggested jobs to run (before merge) run-slow: ernie4_5_vl_moe, glm4v, glm4v_moe, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, video_llama_3
vasqu left a comment
Some comments from my side; definitely the right way to go! We should aim to support this natively in our mask API (without vmap, i.e. without the and mask fn), otherwise we lose quite a lot of perf.
]
attn_output = torch.cat(attn_outputs, dim=1)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()  # FA-kwargs
attn_output, _ = attention_interface(

Just noticing it now, but we never collected the attention weights on the vision side 👀

In qwen-vl we did collect them, as a list of 4D tensors per layer. Probably it wasn't standard and wasn't copied everywhere.
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,

Suggested change: remove the is_causal=False, line.

I know this comes from previous implementations, but we really shouldn't pass this manually ourselves; we should rely on self.is_causal instead.
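A minimal sketch of the suggestion above, with a hypothetical toy module (not the actual transformers code): the causality flag lives on the module, and call sites read it from self instead of hard-coding a literal.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyVisionAttention(nn.Module):
    """Hypothetical toy module: the causality flag is a module attribute,
    so no call site ever hard-codes is_causal=False."""

    is_causal = False  # vision attention is bidirectional

    def forward(self, q, k, v):
        # The flag is read from self rather than passed as a literal.
        return F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)

attn = ToyVisionAttention()
q = k = v = torch.randn(1, 2, 4, 8)
out = attn(q, k, v)  # same shape as the inputs
```

If a causal variant is ever needed, only the class attribute changes, and every call site follows automatically.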
config=self.config,
inputs_embeds=hidden_states[None, ...],
attention_mask=None,
and_mask_function=packed_sequence_mask_function(packed_sequence),
and_mask_function is quite expensive at runtime. I'd prefer if we could natively integrate it into our mask API instead, so as not to use vmap (should work OOB).

Hmm, what do you mean by "natively"? I could prepare the mask by looping over cu_seqlens and un-masking each block, and keep it as a small fn in the model file. In that case we don't use create_bidirectional_mask.
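The loop-over-cu_seqlens idea can be sketched as follows. This is a hypothetical helper, not the actual transformers API: it builds a block-diagonal bidirectional mask directly from the cumulative sequence lengths, with no vmap involved.

```python
import torch

def packed_block_diagonal_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Build a (seq_len, seq_len) boolean mask where tokens attend only
    within their own packed segment, by un-masking one square block per
    segment instead of evaluating a mask function per index pair."""
    seq_len = int(cu_seqlens[-1])
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        mask[start:end, start:end] = True  # un-mask this segment's block
    return mask

# Two packed sequences of lengths 2 and 3:
mask = packed_block_diagonal_mask(torch.tensor([0, 2, 5]))
```

Positions 0 and 1 can attend to each other, positions 2..4 can attend to each other, and nothing crosses the segment boundary.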
If you check the new mask API, we force vmap when we pass and/or mask functions:
transformers/src/transformers/masking_utils.py
Lines 1003 to 1017 in efdcbc7

Natively in this case would mean either a new kwarg that we can use to control packed sequences as well (instead of pos ids), or a new function similar to how sliding window is extended:
transformers/src/transformers/masking_utils.py
Line 1154 in efdcbc7

The problem with and/or masks is that we can have no idea about their functionality, and vmapping is powerful enough to support almost anything, so we are kinda forced to vmap for compatibility/safety.
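For contrast, an and-style mask function in the index-wise form the mask API vmaps over might look like this (hedged sketch; the name packed_sequence_mask_function comes from the diff above, but this body is an assumption, with packed_sequence assumed to hold one segment id per position):

```python
import torch

def packed_sequence_mask_function(packed_sequence: torch.Tensor):
    """Return an index-wise predicate: q_idx and kv_idx may attend to each
    other only if they carry the same segment id. Predicates of this shape
    are evaluated per (batch, head, q, kv) index, which is why the mask API
    vmaps them and why they cost more than a precomputed block mask."""
    def inner(batch_idx, head_idx, q_idx, kv_idx):
        return packed_sequence[batch_idx, q_idx] == packed_sequence[batch_idx, kv_idx]
    return inner

# One batch with two packed segments: ids [0, 0] and [1, 1, 1].
seg_ids = torch.tensor([[0, 0, 1, 1, 1]])
mask_fn = packed_sequence_mask_function(seg_ids)
```

Here mask_fn(0, 0, 0, 1) is true (same segment) while mask_fn(0, 0, 1, 2) is false (boundary crossed), matching the block-diagonal structure.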
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()  # FA-kwargs
Unsure how/where max_seqlen would be called down the line, but I remember a lot of .item() calls to match the FA signature (e.g. for compile); so it might be smarter to compute it once, before entering each attention module.

Yeah, might do it in the VisionModel once. IIRC FA wouldn't prepare the max length for us, so we have to pass it explicitly.
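Computing the FA kwargs once at the model level could look like this (hedged sketch; the variable names are illustrative, and .item() converts the tensor to the Python int that FlashAttention's varlen signature expects):

```python
import torch

# Segment boundaries for three packed sequences of lengths 4, 5 and 2.
cu_seqlens = torch.tensor([0, 4, 9, 11], dtype=torch.int32)

# Compute once before the layer loop, not inside every attention module.
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]   # per-segment lengths
max_seqlen = seq_lens.max().item()            # Python int for the FA signature

fa_kwargs = {
    "cu_seq_lens_q": cu_seqlens,
    "cu_seq_lens_k": cu_seqlens,
    "max_length_q": max_seqlen,
    "max_length_k": max_seqlen,
}
```

Every attention layer then receives the same precomputed kwargs, avoiding a repeated .max().item() (and the associated graph break under compile) per layer.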
What does this PR do?
As per title, gets rid of the if/else per attn implementation.