move attention calculation with mask to fp32
zhtmike committed Mar 12, 2024
1 parent 9574946 commit ae4bc9a
Showing 1 changed file with 3 additions and 3 deletions.
mindcv/models/navit.py: 6 changes (3 additions, 3 deletions)
```diff
@@ -135,12 +135,12 @@ def construct(self, x, context=None, token_mask=None):
         attn = self.q_matmul_k(q, k)
         attn = ops.mul(attn, self.scale)

+        # fp32 for softmax
+        attn = attn.to(ms.float32)
         if token_mask is not None:
             token_mask = ops.unsqueeze(token_mask, 1)
             attn = ops.masked_fill(attn, ~token_mask, -ms.numpy.inf)

-        dtype = attn.dtype
-        attn = ops.softmax(attn.to(ms.float32), axis=-1).to(dtype)
+        attn = ops.softmax(attn, axis=-1).to(v.dtype)
         attn = self.attn_drop(attn)

         out = self.attn_matmul_v(attn, v)
```
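For context on why the cast now happens before the mask fill: with fp16 logits, writing -inf into masked positions and exponentiating inside softmax overflows easily and can produce NaNs, so the commit casts the attention logits to fp32 first, applies the mask and softmax in fp32, and only then casts back to the value dtype. Below is a minimal, self-contained sketch of that pattern; the function name, tensor shapes, and toy inputs are illustrative assumptions, not code from navit.py, and dropout is omitted.

```python
# Hedged sketch of the fp32 masked-softmax pattern after this commit.
# All names, shapes, and inputs below are illustrative assumptions.
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(mode=ms.PYNATIVE_MODE)  # run ops eagerly for this toy check


def masked_attention_fp32(q, k, v, token_mask, scale):
    """Scaled dot-product attention with the mask fill and softmax done in fp32."""
    q_matmul_k = ops.BatchMatMul(transpose_b=True)
    attn_matmul_v = ops.BatchMatMul()

    attn = ops.mul(q_matmul_k(q, k), scale)        # (b, h, n, n) logits

    # Cast BEFORE the -inf fill: fp16 softmax over -inf-filled logits is the
    # overflow/NaN-prone step this commit moves to fp32.
    attn = attn.to(ms.float32)
    if token_mask is not None:
        token_mask = ops.unsqueeze(token_mask, 1)  # (b, 1, n, n), broadcast over heads
        # The commit uses -ms.numpy.inf; float("-inf") is equivalent here.
        attn = ops.masked_fill(attn, ~token_mask, float("-inf"))

    # Softmax in fp32, then cast back to the value dtype for the weighted sum.
    attn = ops.softmax(attn, axis=-1).to(v.dtype)
    return attn_matmul_v(attn, v)


# Toy usage: batch=1, heads=2, tokens=4, head_dim=8; the last token is padding.
# (Use ms.float32 inputs instead if your backend lacks fp16 matmul support.)
b, h, n, d = 1, 2, 4, 8
q = Tensor(np.random.randn(b, h, n, d), ms.float16)
k = Tensor(np.random.randn(b, h, n, d), ms.float16)
v = Tensor(np.random.randn(b, h, n, d), ms.float16)
mask = Tensor(np.array([[[True, True, True, False]] * n]), ms.bool_)  # (b, n, n)
out = masked_attention_fp32(q, k, v, mask, scale=d ** -0.5)
print(out.shape, out.dtype)  # (1, 2, 4, 8) Float16
```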
