Merge pull request PaddlePaddle#43 from eric-tc-wong/patch-1
Update flash_attention.py
tridao committed Sep 6, 2022
2 parents 19d1261 + b410d14 commit 04fb198
Showing 1 changed file with 1 addition and 1 deletion.
flash_attn/flash_attention.py (1 addition, 1 deletion)
@@ -107,7 +107,7 @@ def forward(self, x, key_padding_mask=None, need_weights=False):
             query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
                                           h=self.num_heads).unbind(dim=2)
             query, key = self.rotary_emb(query, key, seq_dimension=-3)
-            qkv = torch.stack([query, key, value], dim=2)
+            qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2)
         else:
             qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
         context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
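
The one-line change casts query and key back to the dtype of the layer input x after the rotary embedding is applied. The commit message does not state the motivation, but the change appears to guard against a dtype mismatch: if the rotary embedding returns tensors in a different dtype than x (for example float32 when x is float16), the stacked qkv tensor would no longer match what the fused attention path expects. A minimal, self-contained sketch of the pattern is below; rotary_emb_stub is a hypothetical stand-in for the module's rotary embedding, assumed here to upcast its outputs to float32.

    import torch

    # Hypothetical stand-in for self.rotary_emb: assumed here to compute in
    # float32 and return float32 tensors even for half-precision inputs.
    def rotary_emb_stub(q, k):
        return q.float(), k.float()

    x_dtype = torch.float16                    # dtype of the layer input x
    b, s, h, d = 2, 16, 4, 8                   # batch, seqlen, heads, head_dim
    query = torch.randn(b, s, h, d, dtype=x_dtype)
    key = torch.randn(b, s, h, d, dtype=x_dtype)
    value = torch.randn(b, s, h, d, dtype=x_dtype)

    query, key = rotary_emb_stub(query, key)   # query/key are now float32

    # Casting back to x's dtype (the one-line change in this commit) keeps
    # all three tensors consistent before they are stacked and handed to
    # the fused attention kernel, which expects fp16/bf16 inputs.
    qkv = torch.stack([query.type(x_dtype), key.type(x_dtype), value], dim=2)
    print(qkv.shape, qkv.dtype)                # torch.Size([2, 16, 3, 4, 8]) torch.float16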
