109 changes: 100 additions & 9 deletions tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py
@@ -7,9 +7,14 @@
The `vocabulary.spm` is from here:
https://www.kaggle.com/models/keras/paligemma/

The official repo is here:
https://github.com/google-research/big_vision

Setup:

```shell
git clone --quiet --branch=main --depth=1 git@github.com:google-research/big_vision.git

pip install kaggle
export KAGGLE_USERNAME=...
export KAGGLE_KEY=...
@@ -21,13 +26,15 @@
python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224
python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --weights_path ./path/to/weights.npz
python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --proto_path ./path/to/vocabulary.spm
python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --upload_uri kaggle://divyasss/hongyu_sharing/keras/pali_gemma2_3b_pt_224
python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --upload_uri kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_224
```
"""

import functools
import io
import os
import pathlib
import sys

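# Silence TensorFlow's C++ logging and select the JAX backend; both variables
# must be set before Keras is imported.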
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["KERAS_BACKEND"] = "jax"
@@ -224,7 +231,14 @@ def recover_dtype(a):


def convert_tokenizer(proto_path):
return keras_hub.models.PaliGemmaTokenizer(proto=proto_path)
    try:
        tokenizer = keras_hub.models.PaliGemmaTokenizer(proto=proto_path)
    except Exception as e:
        raise FileNotFoundError(
            f"There is no proto file at proto_path={proto_path}. You can "
            "download it from https://www.kaggle.com/models/keras/paligemma/"
        ) from e
return tokenizer


def convert_image_converter(image_size):
@@ -424,7 +438,13 @@ def convert_weights(keras_model, weights):
return keras_model


def validate_output(keras_model, keras_tokenizer, keras_image_converter):
def validate_output(
preset,
keras_model,
keras_tokenizer,
keras_image_converter,
big_vision_weights_path,
):
def read_image(url):
contents = io.BytesIO(requests.get(url).content)
image = PIL.Image.open(contents)
@@ -437,7 +457,7 @@ def read_image(url):
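    # Standard PaliGemma demo image: a cow standing on a beach.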
image = read_image(
"https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png"
)
prompt = "answer en where is the cow standing?\n"
prompt = "describe en\n"
max_length = 32
preprocessor = keras_hub.models.PaliGemmaCausalLMPreprocessor(
tokenizer=keras_tokenizer, image_converter=keras_image_converter
@@ -452,7 +472,72 @@ def read_image(url):
print("🔶 Prompt:", prompt.replace("\n", ""))
print("🔶 KerasHub output:", keras_output)

# TODO: Verify numerics with JAX model.
try:
# Try using the official `big_vision` repo to validate the decoded
# output.
sys.path.append("big_vision")

        import jax  # for jax.devices(); `import jax.numpy` alone does not bind `jax`
        import jax.numpy as jnp
import ml_collections
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

variant = preset.split("_")[2].lower() # 3b, 10b, 28b
if "b" not in variant:
raise ValueError("🔶 Failed to parse the variant from the `preset`")
gemma2_variant_mapping = {"3b": "2b", "10b": "9b", "28b": "27b"}
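        # Preset names count total parameters (vision tower + LLM); big_vision
        # keys its configs by the Gemma 2 LLM size, hence the mapping above.
        # The config mirrors the official PaliGemma 2 architecture: a
        # So400m/14 SigLIP image encoder (no pooling, scanned layers,
        # bfloat16 matmuls) plus the mapped Gemma 2 LLM with PaliGemma's
        # 257,152-token vocabulary.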
big_vision_config = ml_collections.FrozenConfigDict(
{
"llm": {
"variant": f"gemma2_{gemma2_variant_mapping.get(variant)}",
"vocab_size": 257_152,
},
"img": {
"variant": "So400m/14",
"pool_type": "none",
"scan": True,
"dtype_mm": "bfloat16",
},
}
)

big_vision_model = paligemma.Model(**big_vision_config)
big_vision_params = paligemma.load(
None, str(big_vision_weights_path), big_vision_config
)
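        # big_vision ships inference separately from the model definition;
        # `decode` runs its prediction loop, stopping at the tokenizer's EOS.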
decode_fn = predict_fns.get_all(big_vision_model)["decode"]
decode = functools.partial(
decode_fn,
devices=jax.devices(),
eos_token=preprocessor.tokenizer.end_token_id,
)

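        # Feed big_vision the exact tensors produced by the KerasHub
        # preprocessor so any output mismatch points at the converted weights
        # rather than the input pipeline.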
preprocessed = preprocessor.generate_preprocess(
{"images": image, "prompts": prompt}, sequence_length=max_length
)
images = jnp.expand_dims(preprocessed["images"], axis=0)
token_ids = jnp.expand_dims(preprocessed["token_ids"], axis=0)
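        # PrefixLM masks: `mask_input` marks real (non-padding) tokens,
        # `mask_ar=0` gives the whole prefix bidirectional attention, and
        # `_mask` marks the example itself as valid.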
big_vision_output = decode(
{"params": big_vision_params},
{
"image": images,
"text": token_ids,
"mask_input": jnp.greater(token_ids, 0).astype("int32"),
"mask_ar": jnp.zeros_like(token_ids),
"_mask": jnp.array(True),
},
max_decode_len=max_length,
)
big_vision_output = big_vision_output[0]
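        # Detokenize the decoded ids for a side-by-side comparison with the
        # KerasHub output.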
big_vision_output = preprocessor.generate_postprocess(
{
"token_ids": big_vision_output,
"padding_mask": jnp.ones_like(big_vision_output).astype("bool"),
}
)
print("🔶 big_vision output:", big_vision_output)
except Exception as e:
print(f"🔶 big_vision could not be run. Error: {e}")


def main(_):
@@ -464,7 +549,7 @@ def main(_):
keras.config.set_floatx("bfloat16")
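    # Convert in bfloat16 to match the dtype of the released checkpoints.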

if FLAGS.weights_path is not None:
weights_path = pathlib.Path(FLAGS.weights_path)
big_vision_weights_path = pathlib.Path(FLAGS.weights_path)
else:
presets = PRESET_MAP.keys()
if preset not in presets:
@@ -481,9 +566,9 @@
f"Found too many files in {model_dir}. Expected only one file. "
f"Recevied: {files}"
)
weights_path = files[0]
big_vision_weights_path = files[0]

weights = np.load(weights_path, allow_pickle=False)
weights = np.load(big_vision_weights_path, allow_pickle=False)
weights = format_weights(weights)
image_size = int(preset.split("_")[-1])
print("✅ JAX model weights loaded")
@@ -497,7 +582,13 @@
del weights
print("✅ Weights converted")

validate_output(keras_model, keras_tokenizer, keras_image_converter)
validate_output(
preset,
keras_model,
keras_tokenizer,
keras_image_converter,
big_vision_weights_path,
)
print("✅ Output validated")

keras_model.save_to_preset(preset)