
Use correct mask for packed inputs in Qwen-VL #44157

Open
zucchini-nlp wants to merge 7 commits into huggingface:main from zucchini-nlp:qwen-packed-mask

Conversation

@zucchini-nlp
Member

What does this PR do?

As per title, gets rid of if/else per attn implementation
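As a hedged sketch of the idea (the helper name packed_sequence_mask is hypothetical, not the PR's actual code): a single block-diagonal boolean mask derived from cu_seqlens can serve every attention backend, so no per-implementation branching is needed.

```python
import torch

def packed_sequence_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: build a block-diagonal boolean mask that is
    True where query and key tokens belong to the same packed segment."""
    total_len = int(cu_seqlens[-1])
    # Segment id for every flattened token position.
    seg_ids = torch.searchsorted(cu_seqlens[1:], torch.arange(total_len), right=True)
    # Tokens may only attend within their own segment.
    return seg_ids[:, None] == seg_ids[None, :]

# Two segments of lengths 3 and 2 packed into one sequence of length 5.
mask = packed_sequence_mask(torch.tensor([0, 3, 5]))
```

The same boolean mask can then be handed to eager, SDPA, or flex attention, instead of branching on the attention implementation.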

@zucchini-nlp
Member Author

run-slow: qwen2_vl, qwen2_5_vl

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/qwen2_5_vl", "models/qwen2_vl"]
quantizations: []

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp
Member Author

run-slow: ernie4_5_vl_moe, glm4v, glm4v_moe, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, video_llama_3

hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
Member Author


Unused arg, probably a bad copy from other models. We only use position_embeddings.

Contributor


Oops, very likely yes 😓 I started from Qwen-VL and opted to make it a bit cleaner back then, IIRC.

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 71eed1a0 workflow commit (merge commit)
PR cd08613c branch commit (from PR)
main 35324377 base commit (on main)

⚠️ No tests being reported (jobs are skipped or cancelled)!

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/ernie4_5_vl_moe", "models/glm4v", "models/glm4v_moe", "models/glm_ocr", "models/paddleocr_vl", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_omni_moe", "models/qwen3_vl", "models/qwen3_vl_moe", "models/video_llama_3"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 03104e90 workflow commit (merge commit)
PR 19b52db6 branch commit (from PR)
main 35324377 base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: ernie4_5_vl_moe, glm4v, glm4v_moe, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, video_llama_3

Contributor

@vasqu left a comment


Some comments from my side; definitely the right way to go! We should aim to natively support this in our mask API (without vmap, i.e. the and-mask fn), otherwise we lose quite a lot of perf.

]
attn_output = torch.cat(attn_outputs, dim=1)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() # FA-kwargs
attn_output, _ = attention_interface(
Contributor


Just noticing it now, but we never collected the attention weights of the vision side 👀

Member Author


In Qwen-VL we did collect them, as a list of 4D tensors per layer. It probably wasn't standard and was not copied everywhere.

cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
Contributor


Suggested change: remove is_causal=False,

I know this comes from previous implementations, but we really shouldn't pass this manually ourselves; we should rely on self.is_causal.

config=self.config,
inputs_embeds=hidden_states[None, ...],
attention_mask=None,
and_mask_function=packed_sequence_mask_function(packed_sequence),
Contributor


and_mask_function is quite expensive at runtime. I'd prefer if we could natively integrate it into our mask API instead, and not use vmap (should work OOB).

Member Author


Hmm, what do you mean by "natively"? I could prepare the mask by looping over cu_seqlens and un-masking each block, and keep it as a small fn in the model file. In that case we don't use create_bidirectional_mask.
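The loop-over-cu_seqlens idea could be sketched as a small standalone helper: start from a fully masked additive bias and un-mask each diagonal block, with no vmap involved. The name block_diagonal_mask is illustrative, not actual transformers code.

```python
import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Illustrative helper: additive attention bias that is 0 inside each
    packed segment and -inf (finfo min) everywhere else."""
    total_len = int(cu_seqlens[-1])
    mask = torch.full((total_len, total_len), torch.finfo(dtype).min, dtype=dtype)
    # Un-mask each block on the diagonal, one segment at a time.
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        mask[start:end, start:end] = 0.0
    return mask

# Two segments of lengths 3 and 2 packed into one sequence of length 5.
mask = block_diagonal_mask(torch.tensor([0, 3, 5]))
```

The loop runs once per segment on the host, so the per-element vmap cost discussed above is avoided entirely.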

Contributor


If you check the new mask API, we force vmap when we pass and/or mask functions

# Allow slight deviations from the base mask
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
# padding mask, etc) as the resulting mask may otherwise not be correct!
if or_mask_function is not None:
    if not _is_torch_greater_or_equal_than_2_6:
        raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
    mask_factory_function = or_masks(mask_factory_function, or_mask_function)
    allow_is_bidirectional_skip = False
    use_vmap = True
if and_mask_function is not None:
    if not _is_torch_greater_or_equal_than_2_6:
        raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
    mask_factory_function = and_masks(mask_factory_function, and_mask_function)
    allow_is_bidirectional_skip = False
    use_vmap = True

"Natively" in this case would mean either a new kwarg that we can use to control packed sequences as well (instead of position ids), or a new function, similar to how sliding window is extended:

def create_bidirectional_sliding_window_mask(

The problem with and/or masks is that we can have no idea about their functionality, and vmapping is powerful enough to support almost anything, so we are kind of forced to vmap for compatibility/safety.

for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() # FA-kwargs
Contributor


Unsure how/where max_seqlen would be called down the line, but I remember a lot of .item() calls to match the FA signature (e.g. for compile); so it might be smarter to call it once, before entering each attention module.

Member Author


Yeah, might do it in the VisionModel once. IIRC FA wouldn't prepare the max length for us, so we have to pass it explicitly.
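Computing the FA kwargs once, up front, might look like the following hedged sketch (variable names are illustrative): the max segment length is derived from cu_seqlens a single time, with .item() converting it to the Python int flash-attn expects, and is then forwarded to every layer.

```python
import torch

# Boundaries of three packed segments of lengths 3, 2, and 6.
cu_seqlens = torch.tensor([0, 3, 5, 11])

# Computed once before the layer loop; .item() yields the Python int
# that flash-attn's varlen kwargs expect, avoiding a per-layer sync.
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())

# max_seqlen can now be passed to every attention module as
# max_length_q / max_length_k without recomputing it per layer.
```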


Labels

None yet

Projects

None yet


3 participants