diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index dbd16f0b7e6a..248dbfcbbbd7 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -222,9 +222,8 @@ def unshape(x): k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) scores = self.attention_scores_matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) - scores = scores / math.sqrt(dim_per_head) # (bs, n_heads, q_length, k_length) - mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)