6 changes: 6 additions & 0 deletions keras_nlp/src/utils/transformers/convert.py
@@ -18,6 +18,8 @@
from keras_nlp.src.utils.transformers.convert_bert import load_bert_tokenizer
from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_backbone
from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_tokenizer
from keras_nlp.src.utils.transformers.convert_gpt2 import load_gpt2_backbone
from keras_nlp.src.utils.transformers.convert_gpt2 import load_gpt2_tokenizer
from keras_nlp.src.utils.transformers.convert_llama3 import load_llama3_backbone
from keras_nlp.src.utils.transformers.convert_llama3 import (
    load_llama3_tokenizer,
@@ -52,6 +54,8 @@ def load_transformers_backbone(cls, preset, load_weights):
        return load_llama3_backbone(cls, preset, load_weights)
    if cls.__name__ == "PaliGemmaBackbone":
        return load_pali_gemma_backbone(cls, preset, load_weights)
    if cls.__name__ == "GPT2Backbone":
        return load_gpt2_backbone(cls, preset, load_weights)
    raise ValueError(
        f"{cls} has not been ported from the Hugging Face format yet. "
        "Please check Hugging Face Hub for the Keras model. "
@@ -79,6 +83,8 @@ def load_transformers_tokenizer(cls, preset):
        return load_llama3_tokenizer(cls, preset)
    if cls.__name__ == "PaliGemmaTokenizer":
        return load_pali_gemma_tokenizer(cls, preset)
    if cls.__name__ == "GPT2Tokenizer":
        return load_gpt2_tokenizer(cls, preset)
    raise ValueError(
        f"{cls} has not been ported from the Hugging Face format yet. "
        "Please check Hugging Face Hub for the Keras model. "
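With the GPT2 branches registered in `convert.py` above, a Hugging Face GPT-2 checkpoint can be pulled through the normal preset API. A minimal usage sketch — the `hf://openai-community/gpt2` handle is the one exercised by the new test below; treat the exact call pattern as illustrative rather than part of this diff:

```python
# Minimal sketch: the "hf://" scheme routes through load_transformers_backbone /
# load_transformers_tokenizer, which now dispatch to the GPT-2 converters.
import keras_nlp

backbone = keras_nlp.models.GPT2Backbone.from_preset("hf://openai-community/gpt2")
tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("hf://openai-community/gpt2")

token_ids = tokenizer(["The quick brown fox"])
```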
186 changes: 186 additions & 0 deletions keras_nlp/src/utils/transformers/convert_gpt2.py
@@ -0,0 +1,186 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
from keras_nlp.src.utils.preset_utils import load_config
from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader


def convert_backbone_config(transformers_config):
    return {
        "vocabulary_size": transformers_config["vocab_size"],
        "num_layers": transformers_config["n_layer"],
        "num_heads": transformers_config["n_head"],
        "hidden_dim": transformers_config["n_embd"],
        "intermediate_dim": transformers_config["n_embd"] * 4,
        "dropout": transformers_config["resid_pdrop"],
        "max_sequence_length": transformers_config["n_positions"],
    }


def convert_weights(backbone, loader, transformers_config):
    # Embeddings
    loader.port_weight(
        keras_variable=backbone.token_embedding.embeddings,
        hf_weight_key="wte.weight",
    )
    loader.port_weight(
        keras_variable=backbone.position_embedding.position_embeddings,
        hf_weight_key="wpe.weight",
    )

    # Attention blocks
    for index in range(backbone.num_layers):
        decoder_layer = backbone.transformer_layers[index]

        # Norm layers
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer_norm.gamma,
            hf_weight_key=f"h.{index}.ln_1.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer_norm.beta,
            hf_weight_key=f"h.{index}.ln_1.bias",
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_layer_norm.gamma,
            hf_weight_key=f"h.{index}.ln_2.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_layer_norm.beta,
            hf_weight_key=f"h.{index}.ln_2.bias",
        )

        # Attention layers
        n_embd = transformers_config["n_embd"]

        # Query
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.query_dense.kernel,
            hf_weight_key=f"h.{index}.attn.c_attn.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[:, :n_embd], keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.query_dense.bias,
            hf_weight_key=f"h.{index}.attn.c_attn.bias",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[:n_embd], keras_shape
            ),
        )

        # Key
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.key_dense.kernel,
            hf_weight_key=f"h.{index}.attn.c_attn.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[:, n_embd : 2 * n_embd], keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.key_dense.bias,
            hf_weight_key=f"h.{index}.attn.c_attn.bias",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[n_embd : 2 * n_embd], keras_shape
            ),
        )

        # Value
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.value_dense.kernel,
            hf_weight_key=f"h.{index}.attn.c_attn.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[:, 2 * n_embd :], keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.value_dense.bias,
            hf_weight_key=f"h.{index}.attn.c_attn.bias",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[2 * n_embd :], keras_shape
            ),
        )

        # Output
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.output_dense.kernel,
            hf_weight_key=f"h.{index}.attn.c_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor, keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.output_dense.bias,
            hf_weight_key=f"h.{index}.attn.c_proj.bias",
        )

        # MLP layers
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
            hf_weight_key=f"h.{index}.mlp.c_fc.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor, keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_intermediate_dense.bias,
            hf_weight_key=f"h.{index}.mlp.c_fc.bias",
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_output_dense.kernel,
            hf_weight_key=f"h.{index}.mlp.c_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor, keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_output_dense.bias,
            hf_weight_key=f"h.{index}.mlp.c_proj.bias",
        )

    # Normalization
    loader.port_weight(
        keras_variable=backbone.layer_norm.gamma,
        hf_weight_key="ln_f.weight",
    )
    loader.port_weight(
        keras_variable=backbone.layer_norm.beta,
        hf_weight_key="ln_f.bias",
    )

    return backbone


def load_gpt2_backbone(cls, preset, load_weights):
    transformers_config = load_config(preset, HF_CONFIG_FILE)
    keras_config = convert_backbone_config(transformers_config)
    backbone = cls(**keras_config)
    if load_weights:
        jax_memory_cleanup(backbone)
        with SafetensorLoader(preset) as loader:
            convert_weights(backbone, loader, transformers_config)
    return backbone


def load_gpt2_tokenizer(cls, preset):
    vocab_file = get_file(preset, "vocab.json")
    merges_file = get_file(preset, "merges.txt")
    return cls(
        vocabulary=vocab_file,
        merges=merges_file,
    )
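The only non-obvious mapping in `convert_gpt2.py` is the attention block: Hugging Face GPT-2 keeps query, key, and value as one fused `c_attn` Conv1D kernel of shape `(n_embd, 3 * n_embd)` (already input-major, so no transpose is needed), and the `hook_fn`s above slice it into thirds and reshape each slice to the `(hidden_dim, num_heads, head_dim)` layout of the Keras attention kernels. A standalone NumPy sketch of that split, using GPT-2 small's dimensions as an assumed example:

```python
# A minimal NumPy sketch of the fused-QKV split performed by the hook_fns in
# convert_weights. Shapes assume GPT-2 small (n_embd=768, n_head=12); the random
# tensors stand in for the real "h.{i}.attn.c_attn" checkpoint entries.
import numpy as np

n_embd, num_heads = 768, 12
head_dim = n_embd // num_heads

# HF GPT-2 uses a Conv1D layer whose kernel is stored as (n_embd, 3 * n_embd),
# i.e. the query, key, and value projections concatenated along the output axis.
c_attn_weight = np.random.rand(n_embd, 3 * n_embd).astype("float32")
c_attn_bias = np.random.rand(3 * n_embd).astype("float32")

# Slice each third and reshape to the (hidden_dim, num_heads, head_dim) layout
# expected by the Keras attention layer's per-head dense kernels.
query_kernel = c_attn_weight[:, :n_embd].reshape(n_embd, num_heads, head_dim)
key_kernel = c_attn_weight[:, n_embd : 2 * n_embd].reshape(n_embd, num_heads, head_dim)
value_kernel = c_attn_weight[:, 2 * n_embd :].reshape(n_embd, num_heads, head_dim)
query_bias = c_attn_bias[:n_embd].reshape(num_heads, head_dim)

assert query_kernel.shape == (768, 12, 64)
assert query_bias.shape == (12, 64)
```

The output and MLP projections only need the plain `np.reshape` because their Conv1D kernels are also stored input-major, which already matches the orientation of the Keras kernels.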
27 changes: 27 additions & 0 deletions keras_nlp/src/utils/transformers/convert_gpt2_test.py
@@ -0,0 +1,27 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from keras_nlp.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_nlp.src.tests.test_case import TestCase


class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_tiny_preset(self):
        model = GPT2CausalLM.from_preset("hf://openai-community/gpt2")
        prompt = "What is your favorite condiment?"
        model.generate([prompt], max_length=15)

    # TODO: compare numerics with huggingface model
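One way the TODO above could eventually be addressed is to compare the ported backbone's hidden states against the reference `transformers` model on identical token ids. A rough, hypothetical sketch — the `AutoModel` usage, the tolerance, and comparing at the backbone level rather than the LM head are all assumptions, not part of this PR:

```python
# Hypothetical numerics check: compare KerasNLP's converted GPT-2 backbone
# against the Hugging Face reference on the same token ids.
import keras
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

import keras_nlp

hf_model = AutoModel.from_pretrained("openai-community/gpt2")
hf_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
keras_backbone = keras_nlp.models.GPT2Backbone.from_preset(
    "hf://openai-community/gpt2"
)

inputs = hf_tokenizer("What is your favorite condiment?", return_tensors="pt")
with torch.no_grad():
    hf_hidden = hf_model(**inputs).last_hidden_state.numpy()

token_ids = inputs["input_ids"].numpy()
keras_hidden = keras_backbone(
    {
        "token_ids": token_ids,
        "padding_mask": np.ones_like(token_ids, dtype=bool),
    }
)

# Both outputs are post-final-layer-norm hidden states; the tolerance is a guess.
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(keras_hidden), hf_hidden, atol=1e-3, rtol=1e-3
)
```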
34 changes: 32 additions & 2 deletions keras_nlp/src/utils/transformers/safetensor_utils.py
@@ -42,12 +42,41 @@ def __init__(self, preset):
        else:
            self.safetensor_config = None
        self.safetensor_files = {}
        self.prefix = None

    def get_prefixed_key(self, hf_weight_key, dict_like):
        """
        Determine and return a prefixed key for a given hf weight key.

        This method checks whether the checkpoint's weight keys share a common
        prefix (such as `transformer.` in some GPT-2 exports) and caches it
        for future lookups.

        Args:
            hf_weight_key (str): The hf weight key to check for a prefix.
            dict_like (object): An object whose `keys()` method returns the
                available safetensor weight keys.

        Returns:
            str: The full key, including the prefix (if any).
        """
        if self.prefix is not None:
            return self.prefix + hf_weight_key

        for full_key in dict_like.keys():
            if full_key.endswith(hf_weight_key) and full_key != hf_weight_key:
                self.prefix = full_key[: -len(hf_weight_key)]
                return full_key

        self.prefix = ""
        return hf_weight_key

    def get_tensor(self, hf_weight_key):
        if self.safetensor_config is None:
            fname = SAFETENSOR_FILE
        else:
            fname = self.safetensor_config["weight_map"][hf_weight_key]
            full_key = self.get_prefixed_key(
                hf_weight_key, self.safetensor_config["weight_map"]
            )
            fname = self.safetensor_config["weight_map"][full_key]

        if fname in self.safetensor_files:
            file = self.safetensor_files[fname]
@@ -58,7 +87,8 @@ def get_tensor(self, hf_weight_key):
            )
            self.safetensor_files[fname] = file

        return file.get_tensor(hf_weight_key)
        full_key = self.get_prefixed_key(hf_weight_key, file)
        return file.get_tensor(full_key)

    def port_weight(self, keras_variable, hf_weight_key, hook_fn=None):
        hf_tensor = self.get_tensor(hf_weight_key)
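The prefix handling added to `SafetensorLoader` exists because Hugging Face checkpoints are not consistent about wrapper prefixes: the converters address weights as, for example, `wte.weight`, while checkpoints exported from `GPT2LMHeadModel` typically store the same tensor as `transformer.wte.weight`. A toy illustration of how `get_prefixed_key` detects and caches that prefix — the plain dict stands in for a safetensors file handle, since only `keys()` is used:

```python
# Toy illustration of SafetensorLoader.get_prefixed_key's prefix detection.
# The logic is reproduced in a standalone class purely for the demo.
class _PrefixResolver:
    def __init__(self):
        self.prefix = None

    def get_prefixed_key(self, hf_weight_key, dict_like):
        if self.prefix is not None:
            return self.prefix + hf_weight_key
        for full_key in dict_like.keys():
            if full_key.endswith(hf_weight_key) and full_key != hf_weight_key:
                self.prefix = full_key[: -len(hf_weight_key)]
                return full_key
        self.prefix = ""
        return hf_weight_key


weights = {"transformer.wte.weight": ..., "transformer.wpe.weight": ...}
resolver = _PrefixResolver()
assert resolver.get_prefixed_key("wte.weight", weights) == "transformer.wte.weight"
# The prefix is cached after the first lookup, so later calls skip the key scan.
assert resolver.prefix == "transformer."
assert resolver.get_prefixed_key("wpe.weight", weights) == "transformer.wpe.weight"
```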