diff --git a/keras_nlp/src/models/gemma/gemma_backbone.py b/keras_nlp/src/models/gemma/gemma_backbone.py
index 9a6ed5ac8a..c6bca59bc1 100644
--- a/keras_nlp/src/models/gemma/gemma_backbone.py
+++ b/keras_nlp/src/models/gemma/gemma_backbone.py
@@ -54,9 +54,9 @@ class GemmaBackbone(Backbone):
         layer_norm_epsilon: float. The epsilon value user for every layer norm
             in the transformer model.
         dropout: float. Dropout probability for the Transformer encoder.
-        query_head_dim_normalize: boolean. Whether to normalize attention with
-            head dimension or hidden_dim/num_query_heads. Gemma2 uses the
-            second option. Defaults to True.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to True.
         use_post_ffw_norm: boolean. Whether to normalize after the feedforward
             block. Defaults to False.
         use_post_attention_norm: boolean. Whether to normalize after the attention
diff --git a/keras_nlp/src/utils/transformers/convert_gemma.py b/keras_nlp/src/utils/transformers/convert_gemma.py
index 53894193ca..ef2cd093c5 100644
--- a/keras_nlp/src/utils/transformers/convert_gemma.py
+++ b/keras_nlp/src/utils/transformers/convert_gemma.py
@@ -59,7 +59,10 @@ def load_gemma_backbone(cls, preset, load_weights):
         "hidden_dim": transformers_config["hidden_size"],
         "intermediate_dim": transformers_config["intermediate_size"] * 2,
         "head_dim": transformers_config["head_dim"],
-        "query_head_dim_normalize": False,
+        "query_head_dim_normalize": (
+            transformers_config["head_dim"]
+            == transformers_config["query_pre_attn_scalar"]
+        ),
         "use_post_ffw_norm": True,
         "use_post_attention_norm": True,
         "final_logit_soft_cap": transformers_config[
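
For context, a minimal sketch (not part of the patch) of what `query_head_dim_normalize` controls and why the converter can now infer it from the Hugging Face config. It assumes the Hugging Face Gemma 2 implementation scales queries by `query_pre_attn_scalar ** -0.5`; the helper names below are hypothetical and only mirror the config keys visible in the diff.

```python
def query_scale(head_dim, hidden_dim, num_query_heads, query_head_dim_normalize):
    """Illustrative only: the two query scaling choices the docstring describes."""
    if query_head_dim_normalize:
        # `True`: normalize the query with the per-head dimension.
        return head_dim**-0.5
    # `False`: normalize the query with hidden_dim / num_query_heads instead.
    return (hidden_dim / num_query_heads) ** -0.5


def infer_query_head_dim_normalize(transformers_config):
    """Mirrors the new conversion logic: if the HF scalar equals head_dim,
    the checkpoint was trained with head_dim normalization; otherwise it
    uses the hidden_dim / num_query_heads variant."""
    return (
        transformers_config["head_dim"]
        == transformers_config["query_pre_attn_scalar"]
    )
```

This is why hard-coding `False` was wrong for checkpoints whose `query_pre_attn_scalar` happens to equal `head_dim`; comparing the two fields recovers the correct flag per checkpoint.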