
Conversation


@akashpalla commented Oct 16, 2025

What does this PR do?

Fixes #41668 by enabling CLIP to work with Flash Attention 3.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @Cyrilvallez

akashpalla changed the title CLIP support for FA3 → CLIP support for flash-attention-3 on Oct 16, 2025
akashpalla marked this pull request as ready for review October 16, 2025 20:31
akashpalla changed the title CLIP support for flash-attention-3 → [FIX]: CLIP support for flash-attention-3 on Oct 16, 2025
Comment on lines 322 to 323
if "flash" in self.config._attn_implementation:
self.is_causal = causal_attention_mask is not None
Member

Wow, first time seeing this. In general it's not a good idea to dynamically change self attributes in the forward call. This can have unwanted side effects, especially if the attention implementation is changed dynamically after loading with model.set_attn_implementation("eager").

cc @vasqu for attention interfaces; I remember there was a discussion on self.is_causal somewhere. Do we have a suggested way to define is_causal based on the inputs? For CLIP we can have a causal mask when the attention is used with text, and otherwise it's bidirectional.

Contributor

On the models I was involved with, I have never seen dynamic changes to the attribute 👀 We need this for FA and SDPA attention (SDPA is less important when a mask is always provided).

I think there is no good alternative for now. I thought about allowing is_causal as a kwarg that takes precedence over the attribute: this is already possible in flash attention, though it would need adjustments in SDPA. It could already be done here, I guess; the SDPA case seems to be ignored either way 😓

Contributor

FA logic:

# Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
is_causal = kwargs.pop("is_causal", None)
if is_causal is None:
is_causal = module.is_causal

Member

allowing is_causal as kwarg to take precedence over the attribute

Personally I prefer this, as it gives us the freedom to re-use a single attention layer for multimodal models like CLIP.

Contributor

Yeah, totally fair. I'll open a PR for SDPA to handle this. In the meantime, we can already adjust this here as well: set it as a local variable and pass it along.

Member

@akashpalla can you fix this part and instead try to pass an is_causal=True/False arg to the attention function?
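
(Illustrative sketch, not the PR's actual code: a toy SDPA-based function showing the requested pattern of deriving is_causal locally from the inputs and passing it into the attention call, instead of mutating self.is_causal in forward. The shapes and the clip_like_attention helper are made up for the example.)

# Illustrative toy, not CLIP's real attention: decide is_causal per call.
import torch
import torch.nn.functional as F

def clip_like_attention(query, key, value, causal_attention_mask=None, attention_mask=None):
    # CLIP's text encoder builds a causal mask upstream -> run causally;
    # the vision encoder passes no causal mask -> bidirectional attention.
    is_causal = causal_attention_mask is not None
    if is_causal and attention_mask is not None:
        # Merging padding into the causal mask is what the masking utils
        # discussed further down handle; left out of this toy on purpose.
        raise NotImplementedError("combine the padding and causal masks here")
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,   # padding-only mask, if any
        is_causal=is_causal,        # decided locally, no module state touched
    )

q = k = v = torch.randn(2, 8, 16, 64)                      # (batch, heads, seq, head_dim)
causal_mask = torch.full((16, 16), float("-inf")).triu(1)  # stand-in for the text encoder's mask
text_out = clip_like_attention(q, k, v, causal_attention_mask=causal_mask)  # causal path
vision_out = clip_like_attention(q, k, v)                                   # bidirectional path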

Comment on lines 609 to 616
causal_attention_mask = _create_4d_causal_attention_mask(
input_shape, hidden_states.dtype, device=hidden_states.device
)

# expand attention_mask
- if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
+ if attention_mask is not None and "flash" not in self.config._attn_implementation:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
Member

I'd love to see the masking utils used here, for example from ...masking_utils import create_causal_mask, create_bidirectional_mask. They'll handle the attention implementations internally.

Author

I'm not super familiar with attention/masking, so I tried hacking together a solution based on patterns I saw in other models, but this resulted in a null causal_attention_mask for SDPA + FA and broke some tests:

cache_position = torch.arange(
    hidden_states.shape[1], device=hidden_states.device
)

causal_attention_mask = create_causal_mask(
    config=self.config,
    input_embeds=hidden_states,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,
    position_ids=position_ids,
)

if attention_mask is not None:
    attention_mask = create_bidirectional_mask(
        config=self.config,
        input_embeds=hidden_states,
        attention_mask=attention_mask,
    )
    

If there’s a better way to handle this, I’d love some guidance — or if you prefer, this could also be addressed in a separate PR.

Member

resulted in a null causal_attention_mask for SDPA + FA and broke some tests

Yeah, the null is actually intended; in that case we fall back to the FA kernel and the masking is controlled by the is_causal attribute. So imo the issue is that is_causal is not set to the correct value. Can you check in the tests what the value is now and what the mask looked like before this PR?

Same for the FA2 tests; it must be the mask.

Contributor

@vasqu commented Oct 20, 2025

This is a bit weird; if you look closer, _create_4d_causal_attention_mask is a very old mask API. If I see it correctly, what currently happens here is that a normal causal attention mask with padding is built by adding the two masks up: _create_4d_causal_attention_mask creates the causal mask, and the other call creates the padding mask.

Since this relies on a 4D mask, FA is not supported there (and anything text-related is marked accordingly within CLIP).

What should happen instead: use the new causal mask API create_causal_mask and generate only one mask --> pass is_causal=True to the layers below --> remove the overcomplicated FA logic and the mask-addition logic. This should also enable FA support in text-based situations. I'm not sure if this is out of scope / too hard here; I'll try looking into it myself.
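
(Rough, self-contained sketch of that direction, assuming a recent transformers where masking_utils exposes create_causal_mask with the signature used in the author's snippet above; this is not the code that landed in #41750.)

# Sketch: one create_causal_mask call replaces adding a 4D causal mask and a
# 4D padding mask together; flash backends get None here and rely on is_causal=True.
import torch
from transformers import CLIPTextConfig
from transformers.masking_utils import create_causal_mask

config = CLIPTextConfig()
config._attn_implementation = "sdpa"   # with a flash implementation this typically returns None

batch, seq = 2, 6
hidden_states = torch.randn(batch, seq, config.hidden_size)
padding_mask = torch.ones(batch, seq, dtype=torch.long)
padding_mask[1, -2:] = 0               # pad the last two tokens of the second sample

mask = create_causal_mask(
    config=config,
    input_embeds=hidden_states,
    attention_mask=padding_mask,
    cache_position=torch.arange(seq),
    past_key_values=None,
)
print(type(mask), getattr(mask, "shape", None))

The text layers below would then be called with is_causal=True, which is what lets the FA path work without any 4D mask.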

Contributor

Also, we need to merge with main for SDPA to work with the is_causal kwarg.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: clip, metaclip_2

Comment on lines -340 to +341
- is_causal=self.is_causal,
+ is_causal=is_causal,
Member

We can do is_causal=causal_attention_mask is not None or self.is_text_attention for all attention cases. Right now it probably isn't differentiating correctly between vision and text attention, because with the new masking utility the presence of a causal mask is no longer a marker.

Contributor

I would be careful with inline conditionals; they have caused trouble with torch.compile in the past 😓
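
(One compile-friendly alternative, shown only as an illustrative toy rather than anything from this PR or #41750: bake is_causal into the module at construction time, since CLIP's text encoder is always causal and its vision encoder never is, so forward contains no data-dependent conditional at all.)

# Toy module, not CLIP's real attention: is_causal is fixed at __init__,
# so forward() has no data-dependent branch for torch.compile to trace.
import torch
from torch import nn
import torch.nn.functional as F

class ToyAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, is_causal: bool):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)
        self.is_causal = is_causal        # set once, never mutated in forward

    def forward(self, x):
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        o = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        return self.out(o.transpose(1, 2).reshape(b, s, -1))

text_attn = ToyAttention(64, 8, is_causal=True)     # CLIP-text-like: causal
vision_attn = ToyAttention(64, 8, is_causal=False)  # CLIP-vision-like: bidirectional
_ = text_attn(torch.randn(2, 10, 64))

The trade-off is that the shared attention class then needs the flag threaded in from the text/vision configs instead of inferring it from the masks at runtime.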

Contributor

vasqu commented Oct 20, 2025

Sorry about this but #41750 will likely supersede this PR. Things got a bit more complicated than initially anticipated 😓

@akashpalla
Author

Closing in favor of #41750. Thanks for the quick fix.

akashpalla closed this on Oct 20, 2025
Successfully merging this pull request may close these issues:

CLIP incompatible with Flash Attention 3