Conversion of gemma-2-2b-it model to TensorFlow Lite #5570

@shubham0204

Description

Have I written custom code (as opposed to using a stock example script provided in MediaPipe)

None

OS Platform and Distribution

Google Colab (Linux) Ubuntu 22.04.3 LTS

MediaPipe Tasks SDK version

0.10.14

Task name (e.g. Image classification, Gesture recognition etc.)

LLM Inference

Programming Language and version (e.g. C++, Python, Java)

Python

Describe the actual behaviour

The converter.convert_checkpoint method throws an AssertionError (see the complete logs below).

Describe the expected behaviour

The gemma-2-2b-it model should be converted to a TFLite model (for CPU).

Standalone code/steps you may have used to try to get what you need

from huggingface_hub import hf_hub_download
import os
import mediapipe as mp
from mediapipe.tasks.python.genai import converter

# Download the tokenizer files and safetensors shards for gemma-2-2b-it.
REPO_ID = "google/gemma-2-2b-it"
FILENAMES = [
    "tokenizer.json",
    "tokenizer_config.json",
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]
os.environ['HF_TOKEN'] = "<token>"
for filename in FILENAMES:
    hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir="./gemma-2-2b-it")

# Convert the safetensors checkpoint to a TFLite model for the CPU backend.
config = converter.ConversionConfig(
    input_ckpt="/content/gemma-2-2b-it",
    ckpt_format='safetensors',
    model_type='GEMMA_2B',
    backend="cpu",
    output_dir="/content/intermediate/gemma-2-2b-it/",
    combine_file_only=False,
    vocab_model_file="/content/gemma-2-2b-it",
    output_tflite_file="/content/converted_models/gemma-2-2b-it-cpu"
)
converter.convert_checkpoint(config)

Other info / Complete Logs

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-10-ae16540c09c6> in <cell line: 14>()
     12     output_tflite_file="/content/converted_models/gemma-2-2b-it-cpu"
     13 )
---> 14 converter.convert_checkpoint(config)

3 frames
/usr/local/lib/python3.10/dist-packages/mediapipe/tasks/python/genai/converter/quantization_util.py in quantize_tensor(var, axis, factor, sym, number_bits, use_fp, add_scale_eps, optimization_on_bound, p_value, per_channel, block_size)
    352   """
    353   # TODO: support jnp.float8_e5m2
--> 354   assert number_bits == 8 or number_bits == 4 , f"Number bits {number_bits}"
    355   jnp_var = jnp.asarray(var)
    356   # When using sub-channel, the contracting dim is split into a sub-channel
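For reference, the guard that fires is the bit-width check in quantize_tensor, which only accepts 8-bit or 4-bit quantization. A minimal standalone sketch of that check (check_number_bits is a hypothetical helper name, not part of the MediaPipe API) reproduces the same AssertionError when any other width reaches it:

```python
def check_number_bits(number_bits: int) -> None:
    # Same guard as in quantization_util.quantize_tensor:
    # only 8-bit and 4-bit quantization are supported.
    assert number_bits == 8 or number_bits == 4, f"Number bits {number_bits}"

check_number_bits(8)  # passes
check_number_bits(4)  # passes
try:
    check_number_bits(16)  # any other width fails the assertion
except AssertionError as e:
    print(e)  # prints the interpolated message: Number bits 16
```

This suggests the converter is reaching the quantization step with a number_bits value outside {8, 4} for this checkpoint, rather than rejecting the configuration earlier with a clearer error.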

Labels

os:linux-non-arm (Issues on linux distributions which run on x86-64 architecture. Does not include ARM devices.)
platform:python (MediaPipe Python issues)
task:LLM inference (Issues related to MediaPipe LLM Inference Gen AI setup)
type:feature (Enhancement in the New Functionality or Request for a New Solution)