Add Ulysses DistributedAttention compatibility (#5525)
The `DistributedAttention` in DeepSpeed-Ulysses is compatible with the training code in
[Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/model/transformer.py#L811)
because it accepts only positional arguments. However, it is not compatible with the
common scenario of passing keyword arguments, such as the following call when using
Flash Attention:
```python
ulysses_attn = DistributedAttention(
    local_attention=flash_attn_func,
    sequence_process_group=None,
    scatter_idx=2,
    gather_idx=1,
)

attn_output = ulysses_attn(
    query_states,
    key_states,
    value_states,
    dropout,
    softmax_scale,
    causal=causal,
)
```
Therefore, a `**kwargs` parameter has been added to `forward`, increasing compatibility
with more local attention implementations while keeping the code changes minimal.
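
For illustration, here is a minimal, self-contained sketch of the pass-through pattern
(`TinyDistributedAttention` and `toy_local_attn` are hypothetical simplifications; the
all-to-all communication of the real `DistributedAttention` is omitted):

```python
import torch

def toy_local_attn(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
    # Hypothetical stand-in for flash_attn_func: only reports its arguments.
    print(f"dropout_p={dropout_p}, softmax_scale={softmax_scale}, causal={causal}")
    return q  # the actual attention computation is omitted for brevity

class TinyDistributedAttention(torch.nn.Module):
    # Sketch of the argument-forwarding pattern only; no sequence parallelism here.
    def __init__(self, local_attention):
        super().__init__()
        self.local_attn = local_attention

    def forward(self, query, key, value, *args, **kwargs):
        # With only *args in the signature, a keyword argument such as
        # causal=True raises a TypeError; **kwargs forwards it through.
        return self.local_attn(query, key, value, *args, **kwargs)

q = k = v = torch.randn(1, 8, 2, 4)
attn = TinyDistributedAttention(toy_local_attn)
out = attn(q, k, v, 0.0, None, causal=True)  # keyword argument reaches toy_local_attn
```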

Co-authored-by: Kwen-Chen <2133949025@qq.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
4 people committed May 22, 2024
1 parent 995ba11 commit f86824b
Showing 1 changed file with 2 additions and 2 deletions: `deepspeed/sequence/layer.py`
```diff
@@ -81,7 +81,7 @@ def __init__(
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
 
-    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
         """ forward
         Arguments:
@@ -101,7 +101,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
         value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
 
         #out shape : e.g., [s:h/p:]
-        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)
+        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
 
         output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
```
