diff --git a/keras_nlp/src/models/gemma/gemma_decoder_block.py b/keras_nlp/src/models/gemma/gemma_decoder_block.py index b05e1d3f0c..860e6a93a3 100644 --- a/keras_nlp/src/models/gemma/gemma_decoder_block.py +++ b/keras_nlp/src/models/gemma/gemma_decoder_block.py @@ -68,7 +68,7 @@ def __init__( self.post_attention_norm = RMSNormalization( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, - name="pre_attention_norm", + name="post_attention_norm", ) self.attention = CachedGemmaAttention(