diff --git a/keras_nlp/layers/token_and_position_embedding.py b/keras_nlp/layers/token_and_position_embedding.py
index e2607ee651..cd21d40c18 100644
--- a/keras_nlp/layers/token_and_position_embedding.py
+++ b/keras_nlp/layers/token_and_position_embedding.py
@@ -80,15 +80,22 @@ def __init__(
         self.vocabulary_size = int(vocabulary_size)
         self.sequence_length = int(sequence_length)
         self.embedding_dim = int(embedding_dim)
+        self.embeddings_initializer = keras.initializers.get(
+            embeddings_initializer
+        )
         self.token_embedding = keras.layers.Embedding(
             vocabulary_size,
             embedding_dim,
-            embeddings_initializer=embeddings_initializer,
+            embeddings_initializer=self.embeddings_initializer,
             mask_zero=mask_zero,
+            name="token_embedding"
+            + str(keras.backend.get_uid("token_embedding")),
         )
         self.position_embedding = keras_nlp.layers.PositionEmbedding(
             sequence_length=sequence_length,
-            initializer=embeddings_initializer,
+            initializer=self.embeddings_initializer,
+            name="position_embedding"
+            + str(keras.backend.get_uid("position_embedding")),
         )
         self.supports_masking = self.token_embedding.supports_masking

@@ -100,7 +107,7 @@ def get_config(self):
                 "sequence_length": self.sequence_length,
                 "embedding_dim": self.embedding_dim,
                 "embeddings_initializer": keras.initializers.serialize(
-                    self.token_embedding.embeddings_initializer
+                    self.embeddings_initializer
                 ),
                 "mask_zero": self.token_embedding.mask_zero,
             },
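
The diff does three things: it resolves the user-supplied initializer once with `keras.initializers.get()` and stores it on `self`, it reuses that stored object for both sublayers while giving each a unique name via `keras.backend.get_uid()`, and it serializes the stored initializer (rather than the sublayer's copy) in `get_config()`. Below is a minimal, self-contained sketch of that pattern, not the keras_nlp layer itself; the class name `TokenEmbeddingSketch` is hypothetical, and it assumes a standard Keras / tf.keras install where `keras.initializers.get`, `keras.initializers.serialize`, and `keras.backend.get_uid` are available.

```python
import keras


class TokenEmbeddingSketch(keras.layers.Layer):
    """Hypothetical layer illustrating the store-then-serialize initializer pattern."""

    def __init__(
        self,
        vocabulary_size,
        embedding_dim,
        embeddings_initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size
        self.embedding_dim = embedding_dim
        # Resolve string / dict / instance into an Initializer object once,
        # so it can be reused and serialized later.
        self.embeddings_initializer = keras.initializers.get(
            embeddings_initializer
        )
        # get_uid() appends a per-prefix counter, so two instances of this
        # layer in the same model do not collide on the sublayer name.
        self.token_embedding = keras.layers.Embedding(
            vocabulary_size,
            embedding_dim,
            embeddings_initializer=self.embeddings_initializer,
            name="token_embedding"
            + str(keras.backend.get_uid("token_embedding")),
        )

    def call(self, inputs):
        return self.token_embedding(inputs)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "embedding_dim": self.embedding_dim,
                # Serialize the initializer stored on self, not the one
                # attached to the sublayer.
                "embeddings_initializer": keras.initializers.serialize(
                    self.embeddings_initializer
                ),
            }
        )
        return config


# Round-trip check: get_config() now emits a serializable initializer dict,
# which keras.initializers.get() in __init__ accepts when reconstructing.
layer = TokenEmbeddingSketch(100, 8, embeddings_initializer="he_normal")
restored = TokenEmbeddingSketch.from_config(layer.get_config())
```

The design point mirrored from the diff: serializing `self.embeddings_initializer` instead of `self.token_embedding.embeddings_initializer` keeps `get_config()` tied to what the user actually passed in, and the `get_uid()`-suffixed names keep sublayer names unique across multiple instantiations.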