From 0fd3b20d71a8f95d3cec3757ab8f252e6ba8e3f0 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Tue, 24 May 2022 00:33:46 -0700
Subject: [PATCH 1/3] Minor fixes to token and pos embedding

---
 keras_nlp/layers/token_and_position_embedding.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/keras_nlp/layers/token_and_position_embedding.py b/keras_nlp/layers/token_and_position_embedding.py
index e2607ee651..019520aa33 100644
--- a/keras_nlp/layers/token_and_position_embedding.py
+++ b/keras_nlp/layers/token_and_position_embedding.py
@@ -80,15 +80,20 @@ def __init__(
         self.vocabulary_size = int(vocabulary_size)
         self.sequence_length = int(sequence_length)
         self.embedding_dim = int(embedding_dim)
+        self.embeddings_initializer = keras.initializers.get(
+            embeddings_initializer
+        )
         self.token_embedding = keras.layers.Embedding(
             vocabulary_size,
             embedding_dim,
-            embeddings_initializer=embeddings_initializer,
+            embeddings_initializer=self.embeddings_initializer,
             mask_zero=mask_zero,
+            name="token_embedding",
         )
         self.position_embedding = keras_nlp.layers.PositionEmbedding(
             sequence_length=sequence_length,
-            initializer=embeddings_initializer,
+            initializer=self.embeddings_initializer,
+            name="position_embedding",
         )
         self.supports_masking = self.token_embedding.supports_masking
 
@@ -100,7 +105,7 @@ def get_config(self):
                 "sequence_length": self.sequence_length,
                 "embedding_dim": self.embedding_dim,
                 "embeddings_initializer": keras.initializers.serialize(
-                    self.token_embedding.embeddings_initializer
+                    self.embeddings_initializer
                 ),
                 "mask_zero": self.token_embedding.mask_zero,
             },

From 17c6450d555aedbfaf01c097d81c89bbeac761e1 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Tue, 24 May 2022 10:30:20 -0700
Subject: [PATCH 2/3] fix names

---
 keras_nlp/layers/token_and_position_embedding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/layers/token_and_position_embedding.py b/keras_nlp/layers/token_and_position_embedding.py
index 019520aa33..7e7ef7a5c2 100644
--- a/keras_nlp/layers/token_and_position_embedding.py
+++ b/keras_nlp/layers/token_and_position_embedding.py
@@ -88,12 +88,12 @@ def __init__(
             embedding_dim,
             embeddings_initializer=self.embeddings_initializer,
             mask_zero=mask_zero,
-            name="token_embedding",
+            name=keras.backend.get_uid("token_embedding"),
         )
         self.position_embedding = keras_nlp.layers.PositionEmbedding(
             sequence_length=sequence_length,
             initializer=self.embeddings_initializer,
-            name="position_embedding",
+            name=keras.backend.get_uid("position_embedding"),
         )
         self.supports_masking = self.token_embedding.supports_masking
 

From b5d11db59581ae4a45a43f5a2024b9c50d9c189f Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Tue, 24 May 2022 10:38:27 -0700
Subject: [PATCH 3/3] fix get_uid

---
 keras_nlp/layers/token_and_position_embedding.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/layers/token_and_position_embedding.py b/keras_nlp/layers/token_and_position_embedding.py
index 7e7ef7a5c2..cd21d40c18 100644
--- a/keras_nlp/layers/token_and_position_embedding.py
+++ b/keras_nlp/layers/token_and_position_embedding.py
@@ -88,12 +88,14 @@ def __init__(
             embedding_dim,
             embeddings_initializer=self.embeddings_initializer,
             mask_zero=mask_zero,
-            name=keras.backend.get_uid("token_embedding"),
+            name="token_embedding"
+            + str(keras.backend.get_uid("token_embedding")),
         )
         self.position_embedding = keras_nlp.layers.PositionEmbedding(
             sequence_length=sequence_length,
             initializer=self.embeddings_initializer,
-            name=keras.backend.get_uid("position_embedding"),
+            name="position_embedding"
+            + str(keras.backend.get_uid("position_embedding")),
         )
         self.supports_masking = self.token_embedding.supports_masking
 
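
Note: keras.backend.get_uid(prefix) returns a per-prefix integer counter, not a string. That is the bug patch 2 introduces (a Keras layer name must be a string) and patch 3 fixes by appending the stringified counter to the prefix. Below is a minimal standalone sketch of the two API patterns the series relies on, assuming tf.keras in a fresh Python session (the printed UID values hold only if get_uid has not been called with these prefixes before):

    from tensorflow import keras

    # get_uid keeps one integer counter per string prefix and increments
    # it on every call, so repeated layer construction never collides.
    print(keras.backend.get_uid("token_embedding"))     # 1
    print(keras.backend.get_uid("token_embedding"))     # 2
    print(keras.backend.get_uid("position_embedding"))  # 1

    # The naming pattern from patch 3: prefix plus stringified counter.
    name = "token_embedding" + str(keras.backend.get_uid("token_embedding"))
    print(name)  # token_embedding3

    # The pattern from patch 1: normalize the constructor argument once
    # with keras.initializers.get(), reuse the resulting object for both
    # sublayers, and serialize that same object in get_config().
    init = keras.initializers.get("glorot_uniform")
    print(keras.initializers.serialize(init))
    # {'class_name': 'GlorotUniform', 'config': {'seed': None}}

This also explains the get_config change in patch 1: serializing the stored self.embeddings_initializer, rather than reaching into the sublayer, keeps save/load round-trips working even when the constructor received a plain string identifier.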