
Remove unsupported modeling flags#950

Merged
dacorvo merged 11 commits into main from remove_unsupported_modeling_flags on Sep 5, 2025
Conversation

dacorvo (Collaborator) commented Sep 3, 2025

What does this PR do?

This removes some optimized code paths that are not supported with the current models (mainly Llama and its variants) and/or deployment configurations (Trainium instances).

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

dacorvo force-pushed the remove_unsupported_modeling_flags branch from 2265176 to 63110ce on September 3, 2025 14:36
This kernel is not supported on Trainium 1, and its integration is likely to change anyway when we support Trainium 2.
dacorvo force-pushed the remove_unsupported_modeling_flags branch from 63110ce to a13a4a2 on September 5, 2025 09:48
tengomucho (Collaborator) left a comment

LGTM

Note that flash attention is still used for prefill whenever it is relevant.
It is never set to anything other than 2, and it can be overridden in the model compiler flags anyway.
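The prefill-only behavior the reviewer describes can be sketched roughly as follows. `FlashAttentionStrategy.NONE` appears in the PR diff, but the `select_flash_attention_strategy` function, its `threshold` parameter, and the `SHARDED_KERNEL` value are hypothetical illustrations, not code from this repository:

```python
from enum import Enum


class FlashAttentionStrategy(Enum):
    NONE = "none"
    SHARDED_KERNEL = "sharded"


def select_flash_attention_strategy(
    is_prefill: bool, seq_len: int, threshold: int = 4096
) -> FlashAttentionStrategy:
    """Pick an attention kernel for the current phase.

    Flash attention only pays off when processing a long prompt at once
    (prefill); token-by-token decoding falls back to the regular path.
    """
    if is_prefill and seq_len >= threshold:
        return FlashAttentionStrategy.SHARDED_KERNEL
    return FlashAttentionStrategy.NONE
```

A gate like this also makes the reviewer's second point concrete: since the choice is computed per phase, a single compiler-flag override can force the strategy without any per-model setting.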
dacorvo (Collaborator, Author) commented Sep 5, 2025

@tengomucho thank you for the review, but I realized that while removing flash decoding I had also inadvertently removed flash attention for prefill. Will force-push.

dacorvo force-pushed the remove_unsupported_modeling_flags branch from a13a4a2 to 3523b40 on September 5, 2025 14:02

return FlashAttentionStrategy.NONE

def compute_for_flash_decoding(self, Q, K, V, past_key_value, attention_mask, active_mask) -> Tensor:
A collaborator commented:
I see, you need to restore these bits

dacorvo merged commit e5a0faf into main on Sep 5, 2025
8 of 9 checks passed
dacorvo deleted the remove_unsupported_modeling_flags branch on September 5, 2025 15:18

3 participants