diff --git a/src/transformers/models/distilbert/modeling_tf_distilbert.py b/src/transformers/models/distilbert/modeling_tf_distilbert.py index e997d2a72cc3..172194d1925d 100644 --- a/src/transformers/models/distilbert/modeling_tf_distilbert.py +++ b/src/transformers/models/distilbert/modeling_tf_distilbert.py @@ -170,7 +170,7 @@ def call(self, query, key, value, mask, head_mask, output_attentions, training=F k_length = shape_list(key)[1] # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' # assert key.size() == value.size() - dim_per_head = tf.math.divide(self.dim, self.n_heads) + dim_per_head = int(self.dim / self.n_heads) dim_per_head = tf.cast(dim_per_head, dtype=tf.int32) mask_reshape = [bs, 1, 1, k_length]