Fix weight tying in TF-ESM (#22839)
Fix weight tying in ESM
Rocketknight1 committed Apr 20, 2023
1 parent 3b61d28 commit 6dc0a84
Showing 2 changed files with 35 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/transformers/models/esm/modeling_tf_esm.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """ PyTorch ESM model."""
 
+import os
 from typing import Optional, Tuple, Union
 
 import numpy as np
@@ -1102,6 +1103,11 @@ def __init__(self, config):
 
         self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
         self.lm_head = TFEsmLMHead(config, name="lm_head")
+        if config.tie_word_embeddings:
+            # Ensure word embeddings are built so that we actually have something to tie
+            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
+                self.esm.embeddings.word_embeddings.build((None, None))
+            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
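
For orientation, a minimal standalone sketch of the tying pattern used in the hunk above: build the embedding layer eagerly so its weight exists, then alias that same variable as the decoder matrix. This is illustrative only; the class and names are invented, and the name_scope wrapper from the real code (which appears to exist so the built weight keeps its expected name for checkpoint loading) is omitted.

import tensorflow as tf


class TiedLMHeadDemo(tf.keras.layers.Layer):
    """Toy layer whose output projection shares the input embedding matrix."""

    def __init__(self, vocab_size=100, hidden_size=32, **kwargs):
        super().__init__(**kwargs)
        self.embeddings = tf.keras.layers.Embedding(vocab_size, hidden_size, name="word_embeddings")
        # Build eagerly so the (vocab_size, hidden_size) weight exists before we alias it.
        self.embeddings.build((None, None))
        self.decoder = self.embeddings.weights[0]  # same tf.Variable, not a copy

    def call(self, hidden_states):
        # Project hidden states back to vocabulary logits with the shared matrix.
        return tf.matmul(hidden_states, self.decoder, transpose_b=True)


layer = TiedLMHeadDemo()
print(layer(tf.random.normal((2, 5, 32))).shape)       # (2, 5, 100)
print(layer.decoder is layer.embeddings.weights[0])    # True: the weights are tied
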
@@ -1211,18 +1217,22 @@ def __init__(self, config, name=None):
 
         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
 
-        self.decoder = Dense(
-            config.vocab_size,
-            use_bias=False,
-            kernel_initializer=get_initializer(config.initializer_range),
-            name="decoder",
-        )
+        self.decoder = None
         self.config = config
 
     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
+        if not self.config.tie_word_embeddings:
+            if self.decoder is not None:
+                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
+            self.decoder = self.add_weight(
+                "decoder.weight",
+                shape=(self.config.hidden_size, self.config.vocab_size),
+                initializer=get_initializer(self.config.initializer_range),
+                trainable=True,
+            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
 
     def get_bias(self):
@@ -1234,8 +1244,7 @@ def call(self, features):
         x = self.layer_norm(x)
 
         # project back to size of vocabulary with bias
-        x = self.decoder(x)
-        x = x + self.bias
+        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
         return x
 
 
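A similarly hedged sketch of the untied path: creating the projection inside build() via add_weight registers the variable under the layer's own scope, which, per the comments in the diff, is what keeps cross-loading against the PyTorch decoder weight working. The shape, weight name, and initializer below are illustrative choices kept self-consistent with the matmul, not necessarily the library's.

import tensorflow as tf


class UntiedLMHeadDemo(tf.keras.layers.Layer):
    """Toy LM head with its own decoder matrix, created lazily in build()."""

    def __init__(self, vocab_size=100, hidden_size=32, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def build(self, input_shape):
        super().build(input_shape)
        # Created in build() so the variable is named under this layer's scope.
        self.decoder = self.add_weight(
            "decoder_weight",  # illustrative name
            shape=(self.vocab_size, self.hidden_size),
            initializer="glorot_uniform",
            trainable=True,
        )
        self.bias = self.add_weight("bias", shape=(self.vocab_size,), initializer="zeros", trainable=True)

    def call(self, features):
        # Same projection as the tied case, but with a weight owned by this layer.
        return tf.matmul(features, self.decoder, transpose_b=True) + self.bias


head = UntiedLMHeadDemo()
print(head(tf.random.normal((2, 5, 32))).shape)   # (2, 5, 100)
print([w.name for w in head.weights])             # decoder and bias live under this layer's scope
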
18 changes: 18 additions & 0 deletions tests/models/esm/test_modeling_tf_esm.py
@@ -262,6 +262,24 @@ def test_resize_token_embeddings(self):
     def test_save_load_after_resize_token_embeddings(self):
         pass
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            if model_class is TFEsmForMaskedLM:
+                # Output embedding test differs from the main test because they're a matrix, not a layer
+                name = model.get_bias()
+                assert isinstance(name, dict)
+                for k, v in name.items():
+                    assert isinstance(v, tf.Variable)
+            else:
+                x = model.get_output_embeddings()
+                assert x is None
+                name = model.get_bias()
+                assert name is None
+
 
 @require_tf
 class TFEsmModelIntegrationTest(unittest.TestCase):
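
To exercise just the new test from a local transformers checkout (path taken from the diff above), something like the following should work:

# Runs only the new attribute test; assumes test dependencies (TensorFlow, pytest) are installed.
import pytest

pytest.main(["tests/models/esm/test_modeling_tf_esm.py", "-k", "test_model_common_attributes", "-v"])
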
