Describe the bug
Does the current implementation of FlashAttention account for the cases of sequence parallelism?
For example, here self.core_attention_flash is called, but the q, k, v passed into flash attention do not appear to be all-gathered (correct me if I'm wrong!!). That means attention is only computed within each local chunk of the sequence (with total seq_len L and a sequence-parallel size of 4, attention would only cover L/4 tokens), which would cause issues in the trained models (i.e., the model might stop attending to content more than L/4 tokens back).
Their original implementation of ParallelAttention does not have this issue since they perform all-gather in the forward pass and reduce-scatter in the backward pass; see this issue for details.
To Reproduce
N/A
Expected behavior
We should perform an all-gather across the sequence-parallel dimension before calling flash attention, and a reduce-scatter in the backward pass (just like the ParallelAttention implementation).
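For reference, here is a minimal sketch (not Megatron's actual code) of the pattern described above: all-gather the sequence-parallel shards in the forward pass and reduce-scatter the gradient back to shards in the backward pass. It assumes torch.distributed is already initialized, that tensors use a sequence-first [seq, batch, hidden] layout, and that the sequence dimension is sharded across `group`; the class name is hypothetical.

```python
import torch
import torch.distributed as dist


class GatherSeqParallel(torch.autograd.Function):
    """All-gather along the sharded sequence dim in forward,
    reduce-scatter the gradient back to local shards in backward."""

    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        world_size = dist.get_world_size(group)
        # Output holds the full sequence: world_size * local_seq_len along dim 0.
        out = x.new_empty((x.shape[0] * world_size, *x.shape[1:]))
        dist.all_gather_into_tensor(out, x.contiguous(), group=group)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        group = ctx.group
        world_size = dist.get_world_size(group)
        # Each rank keeps (and sums) only its own sequence shard of the gradient.
        grad_in = grad_out.new_empty(
            (grad_out.shape[0] // world_size, *grad_out.shape[1:])
        )
        dist.reduce_scatter_tensor(grad_in, grad_out.contiguous(), group=group)
        return grad_in, None
```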
Stack trace/logs
N/A
Never mind, I dug in deeper with some interactive debugging and found that the all-gather happens implicitly in self.query_key_value, a ColumnParallelLinear that takes care of the gather. The current implementation should be fine. :)
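To make the resolution concrete, here is a rough sketch (hypothetical, not Megatron's actual ColumnParallelLinear) of why no extra gather is needed before flash attention: when sequence parallelism is enabled, the column-parallel linear gathers its input along the sequence dimension before the matmul, so the q/k/v it produces already span the full sequence.

```python
import torch
import torch.distributed as dist


def column_parallel_linear_fwd(x_shard, weight_shard, group):
    """x_shard: [local_seq, batch, hidden] shard of the full sequence.
    weight_shard: this rank's column slice of the weight, [out_per_rank, hidden]."""
    world_size = dist.get_world_size(group)
    # Gather the sequence shards so every rank sees the full sequence.
    x_full = x_shard.new_empty((x_shard.shape[0] * world_size, *x_shard.shape[1:]))
    dist.all_gather_into_tensor(x_full, x_shard.contiguous(), group=group)
    # Each rank computes its slice of the output features over the FULL sequence,
    # so the q/k/v handed to flash attention are not sequence-sharded.
    return torch.matmul(x_full, weight_shard.t())
```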