Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
"""Convert PaliGemma `.npz` checkpoints into KerasHub presets.

Example usage:

python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-mix-224.npz \
--image_size=224 --checkpoint_name=pali_gemma_3b_mix_224
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-mix-448.npz \
--image_size=448 --checkpoint_name=pali_gemma_3b_mix_448
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-224.npz \
--image_size=224 --checkpoint_name=pali_gemma_3b_224
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-448.npz \
--image_size=448 --checkpoint_name=pali_gemma_3b_448
python tools/checkpoint_conversion/convert_pali_gemma_checkpoints.py \
--weights_path=paligemma-3b-pt-896.npz \
--image_size=896 --checkpoint_name=pali_gemma_3b_896
"""

import argparse
import os

Expand All @@ -15,6 +33,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)

os.environ["KERAS_BACKEND"] = "jax"

Expand Down Expand Up @@ -308,15 +329,27 @@ def main(args):
pali_gemma_backbone_config = {
"vit_num_layers": 27,
"vit_hidden_dim": 1152,
"vocabulary_size": 257152,
"image_size": args.image_size,
"num_layers": 18,
"num_query_heads": 8,
"num_key_value_heads": 1,
"hidden_dim": 2048,
"intermediate_dim": 32768,
"head_dim": 256,
"vit_patch_size": 14,
"vit_num_heads": 16,
}
pg_image_converter = PaliGemmaImageConverter(
image_size=(args.image_size, args.image_size),
scale=1.0 / 127.5,
offset=-1,
)
tokenizer = PaliGemmaTokenizer(
proto="vocabulary.spm",
)
pg_presprocessor = PaliGemmaCausalLMPreprocessor(
image_converter=pg_image_converter
tokenizer=tokenizer, image_converter=pg_image_converter
)
pg_backbone = PaliGemmaBackbone(**pali_gemma_backbone_config)
keras_model = PaliGemmaCausalLM(
Expand All @@ -325,8 +358,10 @@ def main(args):
# The weights file can be downloaded from Kaggle or supplied as a local path.
weights = np.load(args.weights_path)
jax_weights = get_weights_as_numpy(weights, **pali_gemma_backbone_config)
keras_model = convert_pali_gemma_weights(
keras_model, jax_weights["params"], **pali_gemma_backbone_config
keras_model.backbone = convert_pali_gemma_weights(
keras_model.backbone,
jax_weights["params"],
**pali_gemma_backbone_config,
)
# Save the converted model as a preset named by --checkpoint_name.
keras_model.save_to_preset(args.checkpoint_name)
Expand Down
Loading