From 564785e2dea6e116fe073e2a1d0016796a60b8a7 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Tue, 3 May 2022 11:29:32 -0400
Subject: [PATCH] Revert "Fix distilbert scaling (#43)"

This reverts commit b84a90ac34dc739bd8c6779c7faac5c98257932b.
---
 src/transformers/models/distilbert/modeling_distilbert.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index dbd16f0b7e6a..248dbfcbbbd7 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -222,9 +222,8 @@ def unshape(x):
         k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
 
+        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
         scores = self.attention_scores_matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
-        scores = scores / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, k_length)
-
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
         scores = scores.masked_fill(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
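
Note for reviewers: this revert moves the 1/sqrt(dim_per_head) scaling off the
attention scores and back onto q, before the matmul. The two placements are
mathematically equivalent, since (q / sqrt(d)) @ k^T == (q @ k^T) / sqrt(d).
Below is a minimal sketch of that equivalence; it assumes
attention_scores_matmul behaves like a plain batched torch.matmul (in this
fork it may be a quantization-aware wrapper, in which case the placement can
affect numerics).

    import math

    import torch

    # Shapes mirror the comments in the diff.
    bs, n_heads, q_length, k_length, dim_per_head = 2, 4, 5, 5, 16
    q = torch.randn(bs, n_heads, q_length, dim_per_head)
    k = torch.randn(bs, n_heads, k_length, dim_per_head)

    # Scaling applied to q before the matmul (the behavior this patch restores).
    # torch.matmul stands in for attention_scores_matmul here (assumption).
    scores_pre = torch.matmul(q / math.sqrt(dim_per_head), k.transpose(2, 3))

    # Scaling applied to the scores after the matmul (the reverted #43 behavior).
    scores_post = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)

    # Equal up to floating-point rounding; they only diverge under schemes that
    # quantize the matmul inputs or outputs.
    print(torch.allclose(scores_pre, scores_post, atol=1e-6))  # True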