diff --git a/keras_hub/src/models/llama/llama_backbone.py b/keras_hub/src/models/llama/llama_backbone.py
index f4d3dd8730..a654bdf267 100644
--- a/keras_hub/src/models/llama/llama_backbone.py
+++ b/keras_hub/src/models/llama/llama_backbone.py
@@ -90,13 +90,14 @@ def __init__(
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
+        tie_word_embeddings=False,
         **kwargs,
     ):
         # === Layers ===
         self.token_embedding = ReversibleEmbedding(
             input_dim=vocabulary_size,
             output_dim=hidden_dim,
-            tie_weights=False,
+            tie_weights=tie_word_embeddings,
             embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
             dtype=dtype,
             name="token_embedding",
@@ -155,6 +156,7 @@ def __init__(
         self.rope_scaling_factor = rope_scaling_factor
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
+        self.tie_word_embeddings = tie_word_embeddings
 
     def get_config(self):
         config = super().get_config()
diff --git a/keras_hub/src/utils/transformers/convert_llama3.py b/keras_hub/src/utils/transformers/convert_llama3.py
index 1207e5bc78..08e982e862 100644
--- a/keras_hub/src/utils/transformers/convert_llama3.py
+++ b/keras_hub/src/utils/transformers/convert_llama3.py
@@ -14,6 +14,7 @@ def convert_backbone_config(transformers_config):
         "hidden_dim": transformers_config["hidden_size"],
         "intermediate_dim": transformers_config["intermediate_size"],
         "num_key_value_heads": transformers_config["num_key_value_heads"],
+        "tie_word_embeddings": transformers_config["tie_word_embeddings"],
     }
 
 
@@ -22,12 +23,15 @@ def convert_weights(backbone, loader, transformers_config):
         keras_variable=backbone.get_layer("token_embedding").embeddings,
         hf_weight_key="model.embed_tokens.weight",
     )
-    loader.port_weight(
-        keras_variable=backbone.get_layer("token_embedding").reverse_embeddings,
-        hf_weight_key="lm_head.weight",
-        # rearrange_pattern="b a -> a b",
-        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
-    )
+    if not backbone.tie_word_embeddings:
+        loader.port_weight(
+            keras_variable=backbone.get_layer(
+                "token_embedding"
+            ).reverse_embeddings,
+            hf_weight_key="lm_head.weight",
+            # rearrange_pattern="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
 
     def transpose_and_reshape(x, shape):
         return np.reshape(np.transpose(x), shape)
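
Usage sketch (not part of the diff): a minimal illustration of how the new tie_word_embeddings flag is expected to behave, assuming the public keras_hub.models.LlamaBackbone export; the hyperparameter values below are illustrative only, not taken from a real checkpoint.

# Minimal sketch, assuming keras_hub.models exposes LlamaBackbone.
from keras_hub.models import LlamaBackbone

backbone = LlamaBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    tie_word_embeddings=True,  # reuse the input embedding matrix as the LM head
)

# With tie_weights=True, ReversibleEmbedding reuses the (transposed) embedding
# matrix in reverse mode, which is why the converter above skips porting
# "lm_head.weight" for checkpoints with tied embeddings.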