[FIX]: CLIP support for flash-attention-3 #41673
Conversation
| if "flash" in self.config._attn_implementation: | ||
| self.is_causal = causal_attention_mask is not None |
Wow, first time seeing this. In general it's not a good idea to dynamically change self attributes in the forward call. This can have unwanted side effects, especially if the attention implementation is changed dynamically after loading with `model.set_attn_implementation("eager")`.
cc @vasqu for attention interfaces, I remember there was a discussion on self.is_causal somewhere. Do we have a suggested way to define is_causal based on the inputs? For CLIP, the mask is causal when the attention is used with text and bidirectional otherwise.
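For illustration, a minimal sketch of deriving the flag from the inputs on each call instead of writing it to the module (hypothetical toy code, not the actual CLIP attention; the function name and tensor shapes are made up):

```python
import torch
import torch.nn.functional as F

def toy_clip_attention(query, key, value, causal_attention_mask=None):
    # Derive causality from the inputs on every call instead of mutating
    # self.is_causal inside forward(): the text tower builds a causal mask
    # upstream, the vision tower passes None and stays bidirectional.
    is_causal = causal_attention_mask is not None
    if is_causal:
        # Mask-free fast paths (FA/SDPA) can rebuild the causal pattern themselves.
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)
    return F.scaled_dot_product_attention(query, key, value, is_causal=False)

# Usage: a text-like call is causal, a vision-like call is bidirectional.
q = k = v = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
causal = torch.full((16, 16), float("-inf")).triu(1)
text_out = toy_clip_attention(q, k, v, causal_attention_mask=causal)
vision_out = toy_clip_attention(q, k, v)
```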
On the models I was involved with, I've never seen dynamic changes to the attribute 👀 We need this for FA and SDPA attention (SDPA is less important when a mask is always provided).
I think there is no good alternative for now. I thought about allowing is_causal as a kwarg that takes precedence over the attribute - that's already possible in flash attn, but would need adjustments in SDPA. It could already be done here, I guess; the SDPA case seems to be ignored either way 😓
FA logic (transformers/src/transformers/integrations/flash_attention.py, lines 66 to 69 at 38fdad1):
```python
# Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
is_causal = kwargs.pop("is_causal", None)
if is_causal is None:
    is_causal = module.is_causal
```
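For comparison, a rough sketch of what the analogous precedence rule could look like on the SDPA side (this is not the actual sdpa_attention.py; the signature and the mask guard are assumptions):

```python
import torch.nn.functional as F

def sdpa_attention_forward(module, query, key, value, attention_mask=None, is_causal=None, **kwargs):
    # Sketch only: let an explicit is_causal kwarg take precedence over the
    # module attribute, mirroring the flash-attention logic quoted above.
    if is_causal is None:
        is_causal = getattr(module, "is_causal", False)
    # SDPA rejects an explicit mask combined with is_causal=True, so only keep
    # the causal hint when no mask is passed.
    if attention_mask is not None:
        is_causal = False
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )
    return attn_output, None
```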
> allowing is_causal as kwarg to take precedence over the attribute

Personally I prefer this more, as it gives us the freedom to re-use a single attention layer for multimodal models like CLIP.
Yeah, totally fair. I'll open a PR for SDPA to handle this. In the meantime, we can already adjust this here as well: set it as a local variable and pass it through.
@akashpalla can you fix this part and instead try to pass an is_causal=True/False arg to the attention function?
```diff
  causal_attention_mask = _create_4d_causal_attention_mask(
      input_shape, hidden_states.dtype, device=hidden_states.device
  )

  # expand attention_mask
- if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
+ if attention_mask is not None and "flash" not in self.config._attn_implementation:
      # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
      attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
```
I'd love to see the masking utils used here, for example `from ...masking_utils import create_causal_mask, create_bidirectional_mask`. They'll handle attention implementations internally.
I'm not super familiar with attention/masking, so I tried hacking together a solution based on patterns I saw in other models, but it resulted in a null causal_attention_mask for SDPA + FA and broke some tests:
```python
cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
causal_attention_mask = create_causal_mask(
    config=self.config,
    input_embeds=hidden_states,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,
    position_ids=position_ids,
)
if attention_mask is not None:
    attention_mask = create_bidirectional_mask(
        config=self.config,
        input_embeds=hidden_states,
        attention_mask=attention_mask,
    )
```
If there’s a better way to handle this, I’d love some guidance — or if you prefer, this could also be addressed in a separate PR.
> resulted in a null causal_attention_mask for SDPA + FA and broke some tests
Yeah, the null is actually intended; in that case we fall back to the FA kernel and the masking is controlled by the is_causal attr. So imo the issue is that is_causal is not set to the correct value. Can you check in the tests what the value is now and what the mask looked like before this PR?
Same for the FA2 tests, it must be the mask.
This is a bit weird; if you look closer, _create_4d_causal_attention_mask is a very old mask API. If I see it correctly, what currently happens here is that a plain causal attention mask with padding is created by adding the two masks up: _create_4d_causal_attention_mask creates the causal mask, and the other call creates the padding mask.
Since this relies on a 4D mask, FA is not supported there (also marked within CLIP accordingly for anything text-related).
What should happen instead:
- use the new causal mask API create_causal_mask and generate only one mask
- pass is_causal=True to the layers below
- remove the overcomplicated FA logic and the mask additions

This should also enable FA support in text-based situations (see the sketch below). I'm not sure if this is out of scope / too hard here; I would try looking into it myself.
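A rough sketch of that direction for the text tower, assuming the surrounding variables from the text model's forward (hidden_states, attention_mask, position_ids) and the create_causal_mask signature shown in the snippet above; not verified against main:

```python
import torch
from transformers.masking_utils import create_causal_mask  # new mask API

# Build a single mask; padding is folded in here, so no manual mask addition.
cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
attention_mask = create_causal_mask(
    config=self.config,
    input_embeds=hidden_states,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,
    position_ids=position_ids,
)
# causal_attention_mask and the separate _prepare_4d_attention_mask call would
# go away; the encoder layers receive this single mask and pass is_causal=True
# on the mask-free FA path.
```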
Also, we need to merge with main for SDPA to work with the is_causal kwarg.
[For maintainers] Suggested jobs to run (before merge): run-slow: clip, metaclip_2
```diff
- is_causal=self.is_causal,
+ is_causal=is_causal,
```
We can do `is_causal=causal_attention_mask is not None or self.is_text_attention` for all attention cases. Right now it's probably not differentiating correctly between vision and text attention, because with the new masking utility the presence of a causal mask is no longer a marker.
I'd be careful with inline conditionals, they have caused trouble with torch.compile in the past 😓
Sorry about this, but #41750 will likely supersede this PR. Things got a bit more complicated than initially anticipated 😓
Closing in favor of #41750. Thanks for the quick fix!
What does this PR do?
Fixes #41668, enabling CLIP to work with Flash Attention 3.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker @Cyrilvallez