Skip to content

Commit

Permalink
Revert "Error (also in original) model, scaling only q matrix not qk.…
Browse files Browse the repository at this point in the history
…T dot product (qk.T/sqrt(dim_per_head))" (#22444)

Revert "Error (also in original) model, scaling only q matrix not qk.T dot product (qk.T/sqrt(dim_per_head)) (#21627)"

This reverts commit bad8300.
  • Loading branch information
sgugger committed Mar 29, 2023
1 parent 5e3b19a commit 4277b3d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/transformers/models/flaubert/modeling_flaubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def unshape(x):
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)

scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head) # (bs, n_heads, qlen, klen)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/xlm/modeling_xlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def unshape(x):
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)

scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head) # (bs, n_heads, qlen, klen)
q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen)

Expand Down

0 comments on commit 4277b3d

Please sign in to comment.