
Consider using PyTorch 2.0 version of FlashAttention (remove dependency on flash-attn) #103

Closed
Sciumo opened this issue May 11, 2023 · 2 comments

Sciumo commented May 11, 2023

Per: Dao-AILab/flash-attention#203

> If you're using PyTorch 2.0, then FlashAttention is already available through torch.nn.functional.scaled_dot_product_attention.
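
For reference, this is roughly what the PyTorch 2.0 path looks like (a minimal sketch, assuming a CUDA device and fp16 inputs; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# With no attn_mask, scaled_dot_product_attention can dispatch to the
# fused FlashAttention kernel on supported hardware/dtypes.
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)  # (batch, heads, seq, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Optionally restrict dispatch to the flash backend to check it is usable;
# the call errors if no enabled backend supports these inputs.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```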

The flash-attn project has build problems for many people.
Would you consider using the PyTorch 2.0 equivalent of FlashAttention instead?

vchiley self-assigned this May 11, 2023

vchiley commented May 11, 2023

For MPT we need to be able to use causal=True, and we also need attn_mask (aka attn_bias) to support ALiBi.

The variant exposed in the scaled_dot_product_attention docs does not allow both. From the docs:
[Screenshot of the scaled_dot_product_attention documentation, noting that an error is raised if both attn_mask and is_causal are set.]
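
Concretely, a rough sketch of the constraint (assuming PyTorch 2.0; attn_with_alibi and alibi_bias are hypothetical names for illustration). The ALiBi bias and the causal structure can be folded into a single additive attn_mask, but passing an arbitrary float mask means the call will not use the FlashAttention backend:

```python
import torch
import torch.nn.functional as F

def attn_with_alibi(q, k, v, alibi_bias):
    # q, k, v: (batch, n_heads, seq, head_dim)
    # alibi_bias: precomputed (1, n_heads, seq, seq) float tensor of ALiBi slopes * distances
    seq_len = q.size(-2)

    # What we'd like, but PyTorch 2.0 errors when both attn_mask and is_causal are set:
    # out = F.scaled_dot_product_attention(q, k, v, attn_mask=alibi_bias, is_causal=True)

    # Workaround: fold the causal structure into the additive mask ourselves.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    mask = alibi_bias.masked_fill(~causal, float("-inf"))

    # An arbitrary float mask falls back to the math / memory-efficient backends,
    # not the FlashAttention kernel.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```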


Sciumo commented May 11, 2023

So this issue needs to propagate upstream to PyTorch.
Until then, I'll work with flash-attn as is.
Thanks.

Sciumo closed this as completed May 11, 2023
bmosaicml pushed a commit that referenced this issue Jun 6, 2023
Single-line typo fix