Porting Gemma 2 transformers checkpoint #1678
Conversation
Thanks! This all looks good to me. Let's add a test.
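As a starting point, the key-selection logic could be covered by something like the sketch below (hf_pre_ffw_norm_key is a hypothetical helper introduced here only to mirror the model_type branch quoted later in the thread; the real converter and its tests may be structured differently):

import pytest


def hf_pre_ffw_norm_key(model_type, i):
    # Hypothetical helper: Gemma 1 checkpoints name the pre-FFW norm
    # post_attention_layernorm, while Gemma 2 has a dedicated
    # pre_feedforward_layernorm.
    if model_type == "gemma":
        return f"model.layers.{i}.post_attention_layernorm.weight"
    if model_type == "gemma2":
        return f"model.layers.{i}.pre_feedforward_layernorm.weight"
    raise ValueError(f"Unexpected model_type: {model_type}")


@pytest.mark.parametrize(
    "model_type,i,expected",
    [
        ("gemma", 0, "model.layers.0.post_attention_layernorm.weight"),
        ("gemma2", 3, "model.layers.3.pre_feedforward_layernorm.weight"),
    ],
)
def test_pre_ffw_norm_key_selection(model_type, i, expected):
    assert hf_pre_ffw_norm_key(model_type, i) == expected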
@mattdangerw @grasskin this PR is ready for review! Note: the KerasNLP Gemma 2 model works only on the JAX backend for the time being. Also, thanks to the Hugging Face team (Matt et al.) for providing me with compute to work on this model.
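For anyone trying the ported model, the JAX backend is selected via the KERAS_BACKEND environment variable before Keras is imported; a minimal sketch, assuming a standard Keras 3 / KerasNLP install:

import os

# Pick the JAX backend before Keras / KerasNLP are imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras_nlp  # noqa: E402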
if transformers_config["model_type"] == "gemma":
    port_weight(
        keras_variable=decoder_layer.pre_ffw_norm.variables[0],
        hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
    )
elif transformers_config["model_type"] == "gemma2":
    port_weight(
        keras_variable=decoder_layer.pre_ffw_norm.variables[0],
        hf_weight_key=f"model.layers.{i}.pre_feedforward_layernorm.weight",
    )
This was done in order to align the Gemma 1 and Gemma 2 checkpoints.
I am open to better ways to go about it.
Thanks!
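One possible cleanup, sketched below purely as an illustration (it reuses port_weight, transformers_config, decoder_layer, and i from the quoted diff and is not part of this PR), would be to pull the per-model-type key names into a small lookup table so the port_weight call is written once:

# Hypothetical refactor: map each model_type to its Hugging Face name for the
# pre-FFW norm, then keep a single port_weight call.
PRE_FFW_NORM_KEYS = {
    "gemma": "model.layers.{i}.post_attention_layernorm.weight",
    "gemma2": "model.layers.{i}.pre_feedforward_layernorm.weight",
}

model_type = transformers_config["model_type"]
port_weight(
    keras_variable=decoder_layer.pre_ffw_norm.variables[0],
    hf_weight_key=PRE_FFW_NORM_KEYS[model_type].format(i=i),
)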
Porting Gemma 2 transformers checkpoints in KerasNLP