diff --git a/keras_nlp/src/utils/transformers/convert.py b/keras_nlp/src/utils/transformers/convert.py
index 9befb880ea..522a534e93 100644
--- a/keras_nlp/src/utils/transformers/convert.py
+++ b/keras_nlp/src/utils/transformers/convert.py
@@ -36,6 +36,12 @@
 from keras_nlp.src.utils.transformers.convert_llama3 import (
     load_llama3_tokenizer,
 )
+from keras_nlp.src.utils.transformers.convert_mistral import (
+    load_mistral_backbone,
+)
+from keras_nlp.src.utils.transformers.convert_mistral import (
+    load_mistral_tokenizer,
+)
 from keras_nlp.src.utils.transformers.convert_pali_gemma import (
     load_pali_gemma_backbone,
 )
@@ -74,6 +80,8 @@ def load_transformers_backbone(cls, preset, load_weights):
         return load_albert_backbone(cls, preset, load_weights)
     if cls.__name__ == "BartBackbone":
         return load_bart_backbone(cls, preset, load_weights)
+    if cls.__name__ == "MistralBackbone":
+        return load_mistral_backbone(cls, preset, load_weights)
     raise ValueError(
         f"{cls} has not been ported from the Hugging Face format yet. "
         "Please check Hugging Face Hub for the Keras model. "
@@ -109,6 +117,8 @@ def load_transformers_tokenizer(cls, preset):
         return load_albert_tokenizer(cls, preset)
     if cls.__name__ == "BartTokenizer":
         return load_bart_tokenizer(cls, preset)
+    if cls.__name__ == "MistralTokenizer":
+        return load_mistral_tokenizer(cls, preset)
     raise ValueError(
         f"{cls} has not been ported from the Hugging Face format yet. "
         "Please check Hugging Face Hub for the Keras model. "
diff --git a/keras_nlp/src/utils/transformers/convert_mistral.py b/keras_nlp/src/utils/transformers/convert_mistral.py
new file mode 100644
index 0000000000..5a8b989a4a
--- /dev/null
+++ b/keras_nlp/src/utils/transformers/convert_mistral.py
@@ -0,0 +1,143 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
+from keras_nlp.src.utils.preset_utils import get_file
+from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
+from keras_nlp.src.utils.preset_utils import load_config
+from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader
+
+
+def convert_backbone_config(transformers_config):
+    return {
+        "vocabulary_size": transformers_config["vocab_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "num_query_heads": transformers_config["num_attention_heads"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "num_key_value_heads": transformers_config["num_key_value_heads"],
+        "rope_max_wavelength": transformers_config["rope_theta"],
+        "layer_norm_epsilon": transformers_config["rms_norm_eps"],
+        "sliding_window": transformers_config["sliding_window"],
+    }
+
+
+def convert_weights(backbone, loader):
+    # Embeddings
+    loader.port_weight(
+        keras_variable=backbone.token_embedding.embeddings,
+        hf_weight_key="model.embed_tokens.weight",
+        hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
+    )
+    loader.port_weight(
+        keras_variable=backbone.token_embedding.reverse_embeddings,
+        hf_weight_key="lm_head.weight",
+        hook_fn=lambda hf_tensor, _: np.transpose(
+            hf_tensor.astype(np.float16), axes=(1, 0)
+        ),
+    )
+
+    # Attention blocks
+    for index in range(backbone.num_layers):
+        decoder_layer = backbone.transformer_layers[index]
+
+        # Norm layers
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layernorm.scale,
+            hf_weight_key=f"model.layers.{index}.input_layernorm.weight",
+            hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_layernorm.scale,
+            hf_weight_key=f"model.layers.{index}.post_attention_layernorm.weight",
+            hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
+        )
+
+        # Attention layers
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
+            hf_weight_key=f"model.layers.{index}.self_attn.q_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                np.transpose(hf_tensor.astype(np.float16)), keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
+            hf_weight_key=f"model.layers.{index}.self_attn.k_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                np.transpose(hf_tensor.astype(np.float16)), keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._value_dense.kernel,
+            hf_weight_key=f"model.layers.{index}.self_attn.v_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                np.transpose(hf_tensor.astype(np.float16)), keras_shape
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._output_dense.kernel,
+            hf_weight_key=f"model.layers.{index}.self_attn.o_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                np.transpose(hf_tensor.astype(np.float16)), keras_shape
+            ),
+        )
+
+        # MLP layers
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_gate_dense.kernel,
+            hf_weight_key=f"model.layers.{index}.mlp.gate_proj.weight",
+            hook_fn=lambda hf_tensor, _: np.transpose(
+                hf_tensor.astype(np.float16), axes=(1, 0)
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
+            hf_weight_key=f"model.layers.{index}.mlp.up_proj.weight",
+            hook_fn=lambda hf_tensor, _: np.transpose(
+                hf_tensor.astype(np.float16), axes=(1, 0)
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_output_dense.kernel,
+            hf_weight_key=f"model.layers.{index}.mlp.down_proj.weight",
+            hook_fn=lambda hf_tensor, _: np.transpose(
+                hf_tensor.astype(np.float16), axes=(1, 0)
+            ),
+        )
+
+    # Normalization
+    loader.port_weight(
+        keras_variable=backbone.layer_norm.scale,
+        hf_weight_key="model.norm.weight",
+        hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
+    )
+
+    return backbone
+
+
+def load_mistral_backbone(cls, preset, load_weights):
+    transformers_config = load_config(preset, HF_CONFIG_FILE)
+    keras_config = convert_backbone_config(transformers_config)
+    backbone = cls(**keras_config)
+    if load_weights:
+        jax_memory_cleanup(backbone)
+        with SafetensorLoader(preset) as loader:
+            convert_weights(backbone, loader)
+    return backbone
+
+
+def load_mistral_tokenizer(cls, preset):
+    return cls(get_file(preset, "tokenizer.model"))
diff --git a/keras_nlp/src/utils/transformers/convert_mistral_test.py b/keras_nlp/src/utils/transformers/convert_mistral_test.py
new file mode 100644
index 0000000000..82ac9eccc4
--- /dev/null
+++ b/keras_nlp/src/utils/transformers/convert_mistral_test.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from keras_nlp.src.models.mistral.mistral_causal_lm import MistralCausalLM
+from keras_nlp.src.tests.test_case import TestCase
+
+
+class TestTask(TestCase):
+    @pytest.mark.large
+    def test_convert_tiny_preset(self):
+        model = MistralCausalLM.from_preset("hf://cosmo3769/tiny-mistral-test")
+        prompt = "What is your favorite condiment?"
+        model.generate([prompt], max_length=15)
+
+        # TODO: compare numerics with huggingface model
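To make the config key mapping concrete, here is what `convert_backbone_config` produces for a full-size Mistral checkpoint. The input dict below reproduces the published `config.json` of `mistralai/Mistral-7B-v0.1` (trimmed to the keys the converter reads); the values are quoted from memory of the Hub config, so treat them as illustrative rather than authoritative:

```python
from keras_nlp.src.utils.transformers.convert_mistral import (
    convert_backbone_config,
)

# Keys the converter reads, with Mistral-7B-v0.1's published values.
hf_config = {
    "vocab_size": 32000,
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_key_value_heads": 8,
    "rope_theta": 10000.0,
    "rms_norm_eps": 1e-05,
    "sliding_window": 4096,
}

print(convert_backbone_config(hf_config))
# {'vocabulary_size': 32000, 'num_layers': 32, 'num_query_heads': 32,
#  'hidden_dim': 4096, 'intermediate_dim': 14336, 'num_key_value_heads': 8,
#  'rope_max_wavelength': 10000.0, 'layer_norm_epsilon': 1e-05,
#  'sliding_window': 4096}
```

The mapping is a pure rename: `MistralBackbone` is constructed directly from these kwargs in `load_mistral_backbone` before any weights are loaded.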
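The test leaves the numerics comparison as a TODO. Below is a rough sketch of what that check could look like: feed identical token ids to the original Hugging Face model and to the converted backbone, so only the ported weights are exercised. It assumes `torch` and `transformers` are installed, and relies on the fact that the converter ports `lm_head.weight` into `token_embedding.reverse_embeddings`, so `token_embedding(..., reverse=True)` plays the role of the HF `lm_head`. The tolerance is illustrative, since the converter casts weights to float16:

```python
import numpy as np
import torch
from keras import ops
from transformers import AutoModelForCausalLM, AutoTokenizer

from keras_nlp.src.models.mistral.mistral_causal_lm import MistralCausalLM

checkpoint = "cosmo3769/tiny-mistral-test"  # same tiny model as the test above
keras_model = MistralCausalLM.from_preset(f"hf://{checkpoint}")
hf_model = AutoModelForCausalLM.from_pretrained(checkpoint)
hf_model.eval()
hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Tokenize once with the HF tokenizer and feed both models the same ids,
# so the comparison isolates the weight conversion from tokenization.
token_ids = hf_tokenizer(
    "What is your favorite condiment?", return_tensors="np"
)["input_ids"]
padding_mask = np.ones_like(token_ids, dtype=bool)

with torch.no_grad():
    hf_logits = hf_model(torch.tensor(token_ids)).logits.numpy()

# The backbone returns hidden states; reverse embedding gives logits,
# mirroring the lm_head.weight -> reverse_embeddings port in the converter.
backbone = keras_model.backbone
hidden_states = backbone({"token_ids": token_ids, "padding_mask": padding_mask})
keras_logits = ops.convert_to_numpy(
    backbone.token_embedding(hidden_states, reverse=True)
)

# Loose tolerance: weights pass through float16 during conversion.
np.testing.assert_allclose(keras_logits, hf_logits, atol=1e-2, rtol=1e-2)
```

Comparing per-position logits this way is stricter than the smoke test's `generate` call, and would catch transposition or head-reshape mistakes in the q/k/v/o ports that greedy decoding on a tiny model might mask.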