-
Notifications
You must be signed in to change notification settings - Fork 301
Port bart transformer checkpoint #1783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,372 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import numpy as np | ||
|
||
from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE | ||
from keras_nlp.src.utils.preset_utils import get_file | ||
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup | ||
from keras_nlp.src.utils.preset_utils import load_config | ||
from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader | ||
|
||
|
||
def convert_backbone_config(transformers_config): | ||
return { | ||
"vocabulary_size": transformers_config["vocab_size"], | ||
"num_layers": transformers_config["num_hidden_layers"], | ||
"num_heads": transformers_config["encoder_attention_heads"], | ||
"hidden_dim": transformers_config["d_model"], | ||
"intermediate_dim": transformers_config["encoder_ffn_dim"], | ||
"dropout": transformers_config["dropout"], | ||
"max_sequence_length": transformers_config["max_position_embeddings"], | ||
} | ||
|
||
|
||
def convert_weights(backbone, loader): | ||
# Embeddings | ||
loader.port_weight( | ||
keras_variable=backbone.token_embedding.embeddings, | ||
hf_weight_key="shared.weight", | ||
) | ||
|
||
# Encoder blocks | ||
for index in range(backbone.num_layers): | ||
encoder_layer = backbone.encoder_transformer_layers[index] | ||
encoder_self_attention = encoder_layer._self_attention_layer | ||
hf_encoder_prefix = f"encoder.layers.{index}" | ||
|
||
# Norm layers | ||
loader.port_weight( | ||
keras_variable=encoder_layer._self_attention_layer_norm.gamma, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn_layer_norm.weight", | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_layer._self_attention_layer_norm.beta, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn_layer_norm.bias", | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_layer._feedforward_layer_norm.gamma, | ||
hf_weight_key=f"{hf_encoder_prefix}.final_layer_norm.weight", | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_layer._feedforward_layer_norm.beta, | ||
hf_weight_key=f"{hf_encoder_prefix}.final_layer_norm.bias", | ||
) | ||
|
||
# Self Attention layers | ||
# Query | ||
loader.port_weight( | ||
keras_variable=encoder_self_attention.query_dense.kernel, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn.q_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_self_attention.query_dense.bias, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn.q_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Key | ||
loader.port_weight( | ||
keras_variable=encoder_self_attention.key_dense.kernel, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn.k_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_self_attention.key_dense.bias, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn.k_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Value | ||
loader.port_weight( | ||
keras_variable=encoder_self_attention.value_dense.kernel, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn.v_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_self_attention.value_dense.bias, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn.v_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Output | ||
loader.port_weight( | ||
keras_variable=encoder_self_attention.output_dense.kernel, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn.out_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_self_attention.output_dense.bias, | ||
hf_weight_key=f"{hf_encoder_prefix}.self_attn.out_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# MLP layers | ||
loader.port_weight( | ||
keras_variable=encoder_layer._feedforward_intermediate_dense.kernel, | ||
hf_weight_key=f"{hf_encoder_prefix}.fc1.weight", | ||
hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_layer._feedforward_intermediate_dense.bias, | ||
hf_weight_key=f"{hf_encoder_prefix}.fc1.bias", | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_layer._feedforward_output_dense.kernel, | ||
hf_weight_key=f"{hf_encoder_prefix}.fc2.weight", | ||
hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), | ||
) | ||
loader.port_weight( | ||
keras_variable=encoder_layer._feedforward_output_dense.bias, | ||
hf_weight_key=f"{hf_encoder_prefix}.fc2.bias", | ||
) | ||
|
||
# Decoder blocks | ||
for index in range(backbone.num_layers): | ||
decoder_layer = backbone.decoder_transformer_layers[index] | ||
decoder_self_attention = decoder_layer._self_attention_layer | ||
decoder_cross_attention = decoder_layer._cross_attention_layer | ||
hf_decoder_prefix = f"decoder.layers.{index}" | ||
|
||
# Norm layers | ||
loader.port_weight( | ||
keras_variable=decoder_layer._self_attention_layer_norm.gamma, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn_layer_norm.weight", | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_layer._self_attention_layer_norm.beta, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn_layer_norm.bias", | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_layer._feedforward_layer_norm.gamma, | ||
hf_weight_key=f"{hf_decoder_prefix}.final_layer_norm.weight", | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_layer._feedforward_layer_norm.beta, | ||
hf_weight_key=f"{hf_decoder_prefix}.final_layer_norm.bias", | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_layer._cross_attention_layer_norm.gamma, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn_layer_norm.weight", | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_layer._cross_attention_layer_norm.beta, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn_layer_norm.bias", | ||
) | ||
|
||
# Self Attention layers | ||
# Query | ||
loader.port_weight( | ||
keras_variable=decoder_self_attention.query_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn.q_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_self_attention.query_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn.q_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Key | ||
loader.port_weight( | ||
keras_variable=decoder_self_attention.key_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn.k_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_self_attention.key_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn.k_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Value | ||
loader.port_weight( | ||
keras_variable=decoder_self_attention.value_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn.v_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_self_attention.value_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn.v_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Output | ||
loader.port_weight( | ||
keras_variable=decoder_self_attention.output_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn.out_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_self_attention.output_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.self_attn.out_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# MLP layers | ||
loader.port_weight( | ||
keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.fc1.weight", | ||
hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_layer._feedforward_intermediate_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.fc1.bias", | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_layer._feedforward_output_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.fc2.weight", | ||
hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_layer._feedforward_output_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.fc2.bias", | ||
) | ||
|
||
# Cross Attention Layers | ||
# Query | ||
loader.port_weight( | ||
keras_variable=decoder_cross_attention.query_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn.q_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_cross_attention.query_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn.q_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Key | ||
loader.port_weight( | ||
keras_variable=decoder_cross_attention.key_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn.k_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_cross_attention.key_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn.k_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Value | ||
loader.port_weight( | ||
keras_variable=decoder_cross_attention.value_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn.v_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_cross_attention.value_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn.v_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Output | ||
loader.port_weight( | ||
keras_variable=decoder_cross_attention.output_dense.kernel, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn.out_proj.weight", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
loader.port_weight( | ||
keras_variable=decoder_cross_attention.output_dense.bias, | ||
hf_weight_key=f"{hf_decoder_prefix}.encoder_attn.out_proj.bias", | ||
hook_fn=lambda hf_tensor, keras_shape: np.reshape( | ||
np.transpose(hf_tensor), keras_shape | ||
), | ||
) | ||
|
||
# Normalization | ||
loader.port_weight( | ||
keras_variable=backbone.encoder_embeddings_layer_norm.gamma, | ||
hf_weight_key="encoder.layernorm_embedding.weight", | ||
) | ||
loader.port_weight( | ||
keras_variable=backbone.encoder_embeddings_layer_norm.beta, | ||
hf_weight_key="encoder.layernorm_embedding.bias", | ||
) | ||
loader.port_weight( | ||
keras_variable=backbone.decoder_embeddings_layer_norm.gamma, | ||
hf_weight_key="decoder.layernorm_embedding.weight", | ||
) | ||
loader.port_weight( | ||
keras_variable=backbone.decoder_embeddings_layer_norm.beta, | ||
hf_weight_key="decoder.layernorm_embedding.bias", | ||
) | ||
|
||
return backbone | ||
|
||
|
||
def load_bart_backbone(cls, preset, load_weights): | ||
transformers_config = load_config(preset, HF_CONFIG_FILE) | ||
keras_config = convert_backbone_config(transformers_config) | ||
backbone = cls(**keras_config) | ||
if load_weights: | ||
jax_memory_cleanup(backbone) | ||
with SafetensorLoader(preset) as loader: | ||
convert_weights(backbone, loader) | ||
return backbone | ||
|
||
|
||
def load_bart_tokenizer(cls, preset): | ||
vocab_file = get_file(preset, "vocab.json") | ||
merges_file = get_file(preset, "merges.txt") | ||
return cls( | ||
vocabulary=vocab_file, | ||
merges=merges_file, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.