Does the input sharding match exact optimization of long sequence? #3

Closed
guanzhchen opened this issue Apr 6, 2024 · 2 comments

@guanzhchen

Thanks for your exciting work!

I noticed that the extract_local function seems to split an input sequence of length L into chunks of L/world_size tokens. Are the parameters optimized (via backward) on each chunk rather than on the whole long sequence? Have you checked whether this introduces any approximation error, or is the optimization equivalent to training on the full length?
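
For reference, here is a minimal sketch of the kind of splitting I mean (the function name, signature, and exact chunking scheme are my guess for illustration, not the repository's actual code):

```python
# Minimal sketch, assuming a zigzag-style split: the sequence is cut into
# 2 * world_size chunks and each rank keeps chunk `rank` plus its mirror,
# so every rank holds L / world_size tokens with a balanced causal workload.
import torch

def extract_local_zigzag(value: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim)

# Example: 16 tokens on 4 ranks -> rank 0 holds chunks 0 and 7 (tokens 0, 1, 14, 15).
x = torch.arange(16).view(1, 16)
print(extract_local_zigzag(x, rank=0, world_size=4))  # tensor([[ 0,  1, 14, 15]])
```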

@jzhang38
Owner

jzhang38 commented Apr 6, 2024

Are the parameters optimized (via backward) on each chunk rather than on the whole long sequence?

The whole sequence.

Have you checked whether this introduces any approximation error, or is the optimization equivalent to training on the full length?

See the correctness test in ring-flash-attention, which compares the sharded attention outputs and gradients against a full-sequence reference:
https://github.com/zhuzilin/ring-flash-attention/blob/55ff66fd35f329dfcc24ce7a448bfdd532865966/test/test_zigzag_ring_flash_attn_func.py#L121
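
For intuition, a rough sketch of the sort of check that test performs, reusing the illustrative extract_local_zigzag above (run with torchrun, one GPU per rank; the exact signatures, shapes, and tolerances here are assumptions, not the test's actual code):

```python
import torch
import torch.distributed as dist
from flash_attn import flash_attn_func                    # full-sequence reference
from ring_flash_attn import zigzag_ring_flash_attn_func   # sharded ring version

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

B, L, H, D = 1, 4096, 8, 64
q = torch.randn(B, L, H, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
for t in (q, k, v):
    dist.broadcast(t, src=0)          # every rank sees the same full sequence
q.requires_grad_(); k.requires_grad_(); v.requires_grad_()

# Reference: attention and gradients over the whole sequence on one device.
flash_attn_func(q, k, v, causal=True).sum().backward()
dq_ref = extract_local_zigzag(q.grad, rank, world_size)

# Ring version: each rank only materializes its L / world_size shard, but the
# ring communication during forward/backward still covers the full sequence.
local_q = extract_local_zigzag(q.detach().clone(), rank, world_size).requires_grad_()
local_k = extract_local_zigzag(k.detach().clone(), rank, world_size).requires_grad_()
local_v = extract_local_zigzag(v.detach().clone(), rank, world_size).requires_grad_()
zigzag_ring_flash_attn_func(local_q, local_k, local_v, causal=True).sum().backward()

# If the sharded implementation is exact, each rank's gradient shard matches
# the corresponding shard of the full-sequence gradient (up to fp tolerance),
# i.e. the optimization target is the whole sequence, not the local chunk.
assert torch.allclose(local_q.grad, dq_ref, atol=1e-2, rtol=1e-2)
```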

@guanzhchen
Author

That makes sense! Thank you!
