From f86824be81b508a365c36ec0e900e92cce34d8e4 Mon Sep 17 00:00:00 2001
From: Kun Chen <76744114+Kwen-Chen@users.noreply.github.com>
Date: Thu, 23 May 2024 05:52:39 +0800
Subject: [PATCH] Add Ulysses DistributedAttention compatibility (#5525)

The `DistributedAttention` in DeepSpeed-Ulysses is compatible with the training code in
[Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/model/transformer.py#L811)
only because that code passes nothing but positional arguments. It is not compatible with the
frequently used pattern of passing keyword arguments to the local attention, such as the
following scenario when using Flash Attention:

```python
ulysses_attn = DistributedAttention(local_attention=flash_attn_func,
                                    sequence_process_group=None,
                                    scatter_idx=2,
                                    gather_idx=1)

attn_output = ulysses_attn(
    query_states,
    key_states,
    value_states,
    dropout,
    softmax_scale,
    causal=causal,
)
```

Therefore, a `**kwargs` parameter has been added to `forward` so that keyword arguments are
forwarded to the local attention, increasing compatibility with more local attention
implementations while keeping the code modification minimal.

Co-authored-by: Kwen-Chen <2133949025@qq.com>
Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/sequence/layer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py
index e1dbff87f4ec..a876596fb7bb 100644
--- a/deepspeed/sequence/layer.py
+++ b/deepspeed/sequence/layer.py
@@ -81,7 +81,7 @@ def __init__(
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
 
-    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
         """ forward
 
         Arguments:
@@ -101,7 +101,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tens
         value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
 
         #out shape : e.g., [s:h/p:]
-        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)
+        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
 
         output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
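
As a rough illustration of the argument-forwarding pattern this patch enables (this is not DeepSpeed code: the `_SeqAllToAll` scatter/gather step is omitted, and `LocalAttnWrapper`, `toy_attention`, and the tensor layout are hypothetical names used only for this sketch):

```python
import torch

# Minimal sketch of the *args/**kwargs pass-through added to
# DistributedAttention.forward. NOT DeepSpeed code: the _SeqAllToAll
# communication step is omitted and all names here are illustrative.
class LocalAttnWrapper(torch.nn.Module):

    def __init__(self, local_attn):
        super().__init__()
        self.local_attn = local_attn

    def forward(self, query, key, value, *args, **kwargs):
        # Keyword arguments (e.g. causal=...) now reach the local attention.
        return self.local_attn(query, key, value, *args, **kwargs)


def toy_attention(q, k, v, softmax_scale=None, causal=False):
    # Plain scaled dot-product attention standing in for flash_attn_func.
    scale = softmax_scale if softmax_scale is not None else q.shape[-1]**-0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if causal:
        mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)


attn = LocalAttnWrapper(toy_attention)
q = k = v = torch.randn(2, 4, 8, 16)  # [batch, heads, seq, head_dim], toy layout
out = attn(q, k, v, causal=True)      # keyword argument is forwarded through
```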