From 11786fcfa5c2bb3cc73faaee7eaeef83376dd2b2 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:11:19 +0800 Subject: [PATCH] Add numeric check to PaliGemma2 conversion script --- .../convert_pali_gemma2_checkpoints.py | 109 ++++++++++++++++-- 1 file changed, 100 insertions(+), 9 deletions(-) diff --git a/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py index 596fca6063..c992cc98f3 100644 --- a/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py @@ -7,9 +7,14 @@ The `vocabulary.spm` is from here: https://www.kaggle.com/models/keras/paligemma/ +The official repo is here: +https://github.com/google-research/big_vision + Setup: ```shell +git clone --quiet --branch=main --depth=1 git@github.com:google-research/big_vision.git + pip install kaggle export KAGGLE_USERNAME=... export KAGGLE_KEY=... @@ -21,13 +26,15 @@ python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --weights_path ./path/to/weights.npz python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --proto_path ./path/to/vocabulary.spm -python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --upload_uri kaggle://divyasss/hongyu_sharing/keras/pali_gemma2_3b_pt_224 +python -m tools.checkpoint_conversion.convert_pali_gemma2_checkpoints --preset pali_gemma2_3b_pt_224 --upload_uri kaggle://keras/paligemma2/keras/pali_gemma2_3b_pt_224 ``` """ +import functools import io import os import pathlib +import sys os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["KERAS_BACKEND"] = "jax" @@ -224,7 +231,14 @@ def recover_dtype(a): def convert_tokenizer(proto_path): - return keras_hub.models.PaliGemmaTokenizer(proto=proto_path) + try: + tokenizer = keras_hub.models.PaliGemmaTokenizer(proto=proto_path) + except Exception: + raise FileNotFoundError( + f"There is no proto file at proto_path={proto_path}. You can " + "download it from https://www.kaggle.com/models/keras/paligemma/" + ) + return tokenizer def convert_image_converter(image_size): @@ -424,7 +438,13 @@ def convert_weights(keras_model, weights): return keras_model -def validate_output(keras_model, keras_tokenizer, keras_image_converter): +def validate_output( + preset, + keras_model, + keras_tokenizer, + keras_image_converter, + big_vision_weights_path, +): def read_image(url): contents = io.BytesIO(requests.get(url).content) image = PIL.Image.open(contents) @@ -437,7 +457,7 @@ def read_image(url): image = read_image( "https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png" ) - prompt = "answer en where is the cow standing?\n" + prompt = "describe en\n" max_length = 32 preprocessor = keras_hub.models.PaliGemmaCausalLMPreprocessor( tokenizer=keras_tokenizer, image_converter=keras_image_converter @@ -452,7 +472,72 @@ def read_image(url): print("🔶 Prompt:", prompt.replace("\n", "")) print("🔶 KerasHub output:", keras_output) - # TODO: Verify numerics with JAX model. + try: + # Try using the official `big_vision` repo to validate the decoded + # output. + sys.path.append("big_vision") + + import jax.numpy as jnp + import ml_collections + from big_vision.models.proj.paligemma import paligemma + from big_vision.trainers.proj.paligemma import predict_fns + + variant = preset.split("_")[2].lower() # 3b, 10b, 28b + if "b" not in variant: + raise ValueError("🔶 Failed to parse the variant from the `preset`") + gemma2_variant_mapping = {"3b": "2b", "10b": "9b", "28b": "27b"} + big_vision_config = ml_collections.FrozenConfigDict( + { + "llm": { + "variant": f"gemma2_{gemma2_variant_mapping.get(variant)}", + "vocab_size": 257_152, + }, + "img": { + "variant": "So400m/14", + "pool_type": "none", + "scan": True, + "dtype_mm": "bfloat16", + }, + } + ) + + big_vision_model = paligemma.Model(**big_vision_config) + big_vision_params = paligemma.load( + None, str(big_vision_weights_path), big_vision_config + ) + decode_fn = predict_fns.get_all(big_vision_model)["decode"] + decode = functools.partial( + decode_fn, + devices=jax.devices(), + eos_token=preprocessor.tokenizer.end_token_id, + ) + + preprocessed = preprocessor.generate_preprocess( + {"images": image, "prompts": prompt}, sequence_length=max_length + ) + images = jnp.expand_dims(preprocessed["images"], axis=0) + token_ids = jnp.expand_dims(preprocessed["token_ids"], axis=0) + big_vision_output = decode( + {"params": big_vision_params}, + { + "image": images, + "text": token_ids, + "mask_input": jnp.greater(token_ids, 0).astype("int32"), + "mask_ar": jnp.zeros_like(token_ids), + "_mask": jnp.array(True), + }, + max_decode_len=max_length, + ) + big_vision_output = big_vision_output[0] + big_vision_output = preprocessor.generate_postprocess( + { + "token_ids": big_vision_output, + "padding_mask": jnp.ones_like(big_vision_output).astype("bool"), + } + ) + print("🔶 big_vision output:", big_vision_output) + except Exception as e: + print(f"🔶 big_vision could not be run. Error: {e}") def main(_): @@ -464,7 +549,7 @@ def main(_): keras.config.set_floatx("bfloat16") if FLAGS.weights_path is not None: - weights_path = pathlib.Path(FLAGS.weights_path) + big_vision_weights_path = pathlib.Path(FLAGS.weights_path) else: presets = PRESET_MAP.keys() if preset not in presets: @@ -481,9 +566,9 @@ def main(_): f"Found too many files in {model_dir}. Expected only one file. " f"Recevied: {files}" ) - weights_path = files[0] + big_vision_weights_path = files[0] - weights = np.load(weights_path, allow_pickle=False) + weights = np.load(big_vision_weights_path, allow_pickle=False) weights = format_weights(weights) image_size = int(preset.split("_")[-1]) print("✅ JAX model weights loaded") @@ -497,7 +582,13 @@ def main(_): del weights print("✅ Weights converted") - validate_output(keras_model, keras_tokenizer, keras_image_converter) + validate_output( + preset, + keras_model, + keras_tokenizer, + keras_image_converter, + big_vision_weights_path, + ) print("✅ Output validated") keras_model.save_to_preset(preset)