# Export FunctionGemma to ONNX

This notebook converts the fine-tuned FunctionGemma model to ONNX format for use with Transformers.js in the browser.

**Based on:** [Google's ONNX conversion notebook](https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Demos/Emoji-Gemma-on-Web/resources/Convert_Gemma_3_270M_to_ONNX.ipynb)

## Prerequisites
- Complete the fine-tuning notebook first (`finetune_functiongemma.ipynb`)
- The fine-tuned model must be validated and uploaded to Hugging Face

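## Steps
1. Set up the environment
2. Authenticate with Hugging Face
3. Download the conversion script
4. Convert the model to ONNX (FP16)
5. Verify the output structure
6. Test the converted model
7. Upload to Hugging Face
8. Update the browser code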

## 1. Setup Environment

In [None]:
%pip install transformers==4.56.1 onnx==1.19.0 onnx_ir==0.1.7 onnxruntime==1.22.1 numpy==2.3.2 huggingface_hub

In [None]:
# Restart runtime after installing
# import os
# os.kill(os.getpid(), 9)
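The conversion depends on the pinned versions above, so it can help to confirm they were actually installed. Below is a minimal sketch that uses only the standard library's `importlib.metadata`. `check_pinned_versions` and `PINS` are illustrative names and are not part of the original notebook. The helper reads installed package metadata from disk, so it confirms the install succeeded but does not show whether the runtime has been restarted.

```python
from importlib import metadata
from typing import Dict

# Pins copied from the %pip install line above.
PINS = {
    "transformers": "4.56.1",
    "onnx": "1.19.0",
    "onnx_ir": "0.1.7",
    "onnxruntime": "1.22.1",
    "numpy": "2.3.2",
}


def check_pinned_versions(pins: Dict[str, str]) -> Dict[str, str]:
    """Return {package: problem} for each package that is missing or has a different version."""
    problems = {}
    for package, expected in pins.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            problems[package] = "not installed"
            continue
        if installed != expected:
            problems[package] = f"installed {installed}, expected {expected}"
    return problems
```

After the restart, run `print(check_pinned_versions(PINS) or "All pins match")` to list any package whose version differs from its pin.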

## 2. Hugging Face Authentication

In [None]:
import os
from google.colab import userdata
from huggingface_hub import login

hf_token = userdata.get('HF_TOKEN')
login(hf_token)
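`google.colab.userdata` exists only inside Colab. If you run the notebook elsewhere, you can fall back to the `HF_TOKEN` environment variable to keep this step working. The sketch below assumes that setup. `get_hf_token` is a hypothetical helper and is not a `huggingface_hub` API.

```python
import os
from typing import Optional


def get_hf_token(secret_name: str = "HF_TOKEN") -> Optional[str]:
    """Return a Hugging Face token from Colab secrets, else from an environment variable."""
    try:
        # Only importable inside Google Colab.
        from google.colab import userdata

        token = userdata.get(secret_name)
        if token:
            return token
    except Exception:
        # Not running in Colab, or the secret is missing or not shared with this notebook.
        pass
    # Fallback for local Jupyter or CI: read the environment variable of the same name.
    return os.environ.get(secret_name)
```

Call `login(get_hf_token())` in place of the two lines above. If the helper returns `None`, set `HF_TOKEN` first.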

## 3. Download Conversion Script

The `build_gemma.py` script by Xenova handles the conversion and quantization specifically for Gemma 3 models.

In [None]:
!wget https://gist.githubusercontent.com/xenova/a219dbf3c7da7edd5dbb05f92410d7bd/raw/45f4c5a5227c1123efebe1e36d060672ee685a8e/build_gemma.py

print("Conversion script downloaded!")

## 4. Convert Model to ONNX

Specify your fine-tuned model repository on Hugging Face.

**Note:** We export only FP16 (half-precision floating point), which gives the best balance between size (~544MB) and accuracy for this model. Q4 quantization caused hallucinations with the fine-tuned model.
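The ~544MB figure follows from simple arithmetic: weight storage is roughly the parameter count times the bytes per parameter. The sketch below is a back-of-the-envelope estimate with three assumptions:
- The model has about 270M parameters, as in the Gemma 3 270M model targeted by the linked conversion notebook.
- Graph metadata is ignored.
- The per-block scales that real 4-bit formats store are ignored.

```python
def estimated_weight_size_mb(num_params: float, bits_per_param: float) -> float:
    """Rough weight storage in megabytes (10**6 bytes); ignores metadata and quantization scales."""
    return num_params * bits_per_param / 8 / 1e6


# Assumption: roughly 270M parameters.
NUM_PARAMS = 270e6

for label, bits in [("fp32", 32), ("fp16", 16), ("q4", 4)]:
    print(f"{label}: ~{estimated_weight_size_mb(NUM_PARAMS, bits):.0f} MB")
```

The estimates are about 1080 MB for FP32, 540 MB for FP16, and 135 MB for a pure 4-bit encoding. Real Q4 exports may be larger because they store scales and may keep some tensors at higher precision.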

In [None]:
# Your fine-tuned model from HuggingFace
MODEL_AUTHOR = "harlley"  #@param {type:"string"}
MODEL_NAME = "functiongemma-square-color"  #@param {type:"string"}

REPO_ID = f"{MODEL_AUTHOR}/{MODEL_NAME}"
SAVE_PATH = f"/content/{MODEL_NAME}-ONNX"

print(f"Converting model: {REPO_ID}")
print(f"Output path: {SAVE_PATH}")

In [None]:
# Run the conversion (FP16 only - Q4 causes hallucinations with fine-tuned models)
!python build_gemma.py \
    --model_name {REPO_ID} \
    --output {SAVE_PATH} \
    -p fp16

print(f"\nONNX model saved to {SAVE_PATH}")

## 5. Verify Output Structure

The output should have this structure:
```
model-ONNX/
├── config.json
├── tokenizer.json
├── tokenizer_config.json
└── onnx/
    ├── model_fp16.onnx
    └── model_fp16.onnx_data
```
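The cells below check the root files and list the `.onnx` files, but they do not confirm that the external weight file `model_fp16.onnx_data` is present. Without that file, the `.onnx` graph cannot be loaded. The minimal sketch below checks every path in the tree above. `missing_export_files` is a hypothetical helper, and the expected list assumes the script writes external data as shown.

```python
import os
from typing import List, Sequence

# Paths from the expected tree above, relative to SAVE_PATH.
EXPECTED_FILES = [
    "config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    os.path.join("onnx", "model_fp16.onnx"),
    os.path.join("onnx", "model_fp16.onnx_data"),
]


def missing_export_files(save_path: str, expected: Sequence[str] = EXPECTED_FILES) -> List[str]:
    """Return the expected relative paths that are not regular files under save_path."""
    return [rel for rel in expected if not os.path.isfile(os.path.join(save_path, rel))]
```

After the conversion, `missing_export_files(SAVE_PATH)` should return an empty list.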

In [None]:
import os

print("Output structure:")
print("=" * 50)

for root, dirs, files in os.walk(SAVE_PATH):
    level = root.replace(SAVE_PATH, '').count(os.sep)
    indent = ' ' * 2 * level
    print(f"{indent}{os.path.basename(root)}/")
    subindent = ' ' * 2 * (level + 1)
    for file in files:
        filepath = os.path.join(root, file)
        size_mb = os.path.getsize(filepath) / (1024 * 1024)
        print(f"{subindent}{file} ({size_mb:.1f} MB)")

In [None]:
# Verify critical files exist at root
critical_files = ['config.json', 'tokenizer.json', 'tokenizer_config.json']

print("Checking critical files at root...")
for f in critical_files:
    path = os.path.join(SAVE_PATH, f)
    exists = os.path.exists(path)
    status = "OK" if exists else "MISSING"
    print(f"  [{status}] {f}")

# Check ONNX files
onnx_dir = os.path.join(SAVE_PATH, 'onnx')
if os.path.exists(onnx_dir):
    onnx_files = [f for f in os.listdir(onnx_dir) if f.endswith('.onnx')]
    print(f"\nONNX files: {len(onnx_files)}")
    for f in onnx_files:
        print(f"  - {f}")
else:
    print("\nWARNING: onnx/ directory not found!")

## 6. Test Converted Model

Test the ONNX model using ONNX Runtime before uploading.
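`run_inference` (defined below) drives the decoder one step at a time:
- The first call feeds the whole prompt with an empty (zero-length) key/value cache.
- Each later call feeds only the newly chosen token.
- At each step, the attention mask grows by one and the position ID advances by one.

The sketch below shows that bookkeeping without a model. `fake_decoder` is a hypothetical stand-in for `decoder_session.run`, and a single integer array stands in for the per-layer key/value tensors.

```python
import numpy as np

VOCAB_SIZE = 16


def fake_decoder(input_ids, attention_mask, position_ids, past):
    """Stand-in for decoder_session.run: checks the bookkeeping, then predicts (last token + 1)."""
    n_past = past.shape[-1]
    n_new = input_ids.shape[-1]
    # The mask must cover the cached tokens plus the tokens fed this step.
    assert attention_mask.shape[-1] == n_past + n_new
    # Positions of the new tokens continue right after the cached ones.
    assert (position_ids == np.arange(n_past, n_past + n_new)).all()
    logits = np.zeros((input_ids.shape[0], n_new, VOCAB_SIZE), dtype=np.float32)
    next_ids = (input_ids[:, -1] + 1) % VOCAB_SIZE
    logits[np.arange(input_ids.shape[0]), -1, next_ids] = 1.0
    # "Present" cache = past cache extended by this step's tokens.
    present = np.concatenate([past, input_ids], axis=-1)
    return logits, present


def greedy_generate(prompt_ids, max_new_tokens, stop_ids):
    """Greedy decoding with the same mask/position/cache updates as run_inference."""
    input_ids = prompt_ids
    attention_mask = np.ones_like(input_ids)
    position_ids = np.tile(np.arange(input_ids.shape[-1]), (input_ids.shape[0], 1))
    # Empty cache, like the zero-length past_key_values tensors.
    past = np.zeros((input_ids.shape[0], 0), dtype=np.int64)
    generated = []
    for _ in range(max_new_tokens):
        logits, past = fake_decoder(input_ids, attention_mask, position_ids, past)
        # Only the new token is fed on the next step.
        input_ids = logits[:, -1].argmax(-1, keepdims=True)
        attention_mask = np.concatenate([attention_mask, np.ones_like(input_ids)], axis=-1)
        position_ids = position_ids[:, -1:] + 1
        generated.append(int(input_ids[0, 0]))
        if generated[-1] in stop_ids:
            break
    return generated
```

For example, `greedy_generate(np.array([[3, 4, 5]]), 10, {9})` returns `[6, 7, 8, 9]`: the stand-in predicts "previous token + 1", and generation stops at the stop ID. The assertions inside `fake_decoder` fail if the mask or position IDs fall out of step with the cache.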

In [None]:
from transformers import AutoConfig, AutoTokenizer
import onnxruntime
import numpy as np

# Load config and tokenizer
config = AutoConfig.from_pretrained(SAVE_PATH)
tokenizer = AutoTokenizer.from_pretrained(SAVE_PATH)

# FP16 model (best balance of size and accuracy)
MODEL_FILE = "onnx/model_fp16.onnx"

model_path = f"{SAVE_PATH}/{MODEL_FILE}"
print(f"Loading model: {model_path}")

decoder_session = onnxruntime.InferenceSession(model_path)
print("Model loaded!")
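The next cell hard-codes both declarations in `SYSTEM_PROMPT`. If you add functions later, generating that text from Python data avoids formatting slips. The sketch below defines two hypothetical helpers, `format_declaration` and `build_system_prompt`, that reproduce this notebook's prompt text exactly. They are not a FunctionGemma or Transformers API. Whatever format you use here must match the one the model was fine-tuned on.

```python
from typing import Dict, List


def format_declaration(name: str, description: str, params: Dict[str, Dict[str, object]]) -> str:
    """Render one declaration in the same text format as SYSTEM_PROMPT."""
    rendered = []
    for param_name, spec in params.items():
        # Booleans are written lowercase (required:true); other values verbatim.
        fields = ",".join(
            f"{key}:{str(value).lower() if isinstance(value, bool) else value}"
            for key, value in spec.items()
        )
        rendered.append(f"{param_name}:{{{fields}}}")
    return (
        "<start_function_declaration>\n"
        f"name:{name}\n"
        f"description:{description}\n"
        f"parameters:{{{','.join(rendered)}}}\n"
        "<end_function_declaration>"
    )


def build_system_prompt(declarations: List[str]) -> str:
    """Prefix the declarations with the instruction line used in this notebook."""
    header = "You are a model that can do function calling with the following functions"
    return header + "\n\n" + "\n".join(declarations)


PROMPT_FROM_DICTS = build_system_prompt([
    format_declaration(
        "set_square_color",
        "Sets the color of the square to a specified color",
        {"color": {"type": "string", "description": "The color to set the square to", "required": True}},
    ),
    format_declaration("get_square_color", "Gets the current color of the square", {}),
])
```

Once the next cell has run, `assert PROMPT_FROM_DICTS == SYSTEM_PROMPT` confirms the two are identical.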

In [None]:
# System prompt for FunctionGemma
SYSTEM_PROMPT = """You are a model that can do function calling with the following functions

<start_function_declaration>
name:set_square_color
description:Sets the color of the square to a specified color
parameters:{color:{type:string,description:The color to set the square to,required:true}}
<end_function_declaration>
<start_function_declaration>
name:get_square_color
description:Gets the current color of the square
parameters:{}
<end_function_declaration>"""

def run_inference(user_input, max_new_tokens=64):
    """Run inference on the ONNX model."""
    # Config values
    num_key_value_heads = config.num_key_value_heads
    head_dim = config.head_dim
    num_hidden_layers = config.num_hidden_layers
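    # Stop tokens: <eos> plus Gemma's <end_of_turn>, which closes a model turn
    # (e.g. when the model answers in plain text instead of calling a function)
    eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]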
    
    # Get token IDs for stop sequence
    end_function_token_ids = tokenizer.encode("<end_function_call>", add_special_tokens=False)
    
    # Prepare inputs
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_input},
    ]
    
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="np"
    )
    
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    batch_size = input_ids.shape[0]
    
    # Use float16 for FP16 model
    kv_dtype = np.float16
    
    past_key_values = {
        f'past_key_values.{layer}.{kv}': np.zeros([batch_size, num_key_value_heads, 0, head_dim], dtype=kv_dtype)
        for layer in range(num_hidden_layers)
        for kv in ('key', 'value')
    }
    
    position_ids = np.tile(np.arange(0, input_ids.shape[-1]), (batch_size, 1))
    
    # Generation loop
    generated_tokens = np.array([[]], dtype=np.int64)
    
    for i in range(max_new_tokens):
        logits, *present_key_values = decoder_session.run(None, dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **past_key_values,
        ))
        
        # Get next token
        input_ids = logits[:, -1].argmax(-1, keepdims=True)
        attention_mask = np.concatenate([attention_mask, np.ones_like(input_ids, dtype=np.int64)], axis=-1)
        position_ids = position_ids[:, -1:] + 1
        
        # Update KV cache
        for j, key in enumerate(past_key_values):
            past_key_values[key] = present_key_values[j]
        
        generated_tokens = np.concatenate([generated_tokens, input_ids], axis=-1)
        
        # Stop at EOS or end_function_call
        if np.isin(input_ids, eos_token_id).any() or np.isin(input_ids, end_function_token_ids).any():
            break
    
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
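The stop check above uses `np.isin(input_ids, end_function_token_ids)`, which stops as soon as any single ID from the encoded `<end_function_call>` string appears. That check is exact when the marker encodes to one token, which you can confirm with `len(end_function_token_ids) == 1`. If the marker encodes to several pieces, a suffix match on the generated IDs is safer. Below is a minimal sketch; `ends_with_sequence` is a hypothetical helper.

```python
from typing import Sequence


def ends_with_sequence(generated: Sequence[int], stop_ids: Sequence[int]) -> bool:
    """True if the generated token IDs end with the complete stop sequence."""
    n = len(stop_ids)
    # An empty stop sequence never triggers; otherwise compare the last n IDs.
    return n > 0 and len(generated) >= n and list(generated[-n:]) == list(stop_ids)
```

Inside the loop, `ends_with_sequence(generated_tokens[0].tolist(), end_function_token_ids)` would replace the `np.isin(...)` test for the function-call marker. This assumes batch size 1, as in `run_inference`.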

In [None]:
# Test the ONNX model
import re

test_inputs = [
    "change the color to blue",
    "what color is the square?",
    "make it red",
    "tell me the current color",
    "set to green",
    "color?"
]

# Expected function call patterns
SET_COLOR_PATTERN = r"<start_function_call>call:set_square_color\{color:<escape>\w+<escape>\}<end_function_call>"
GET_COLOR_PATTERN = r"<start_function_call>call:get_square_color\{\}<end_function_call>"

print("=" * 70)
print("ONNX MODEL VALIDATION")
print(f"Model: {MODEL_FILE}")
print("=" * 70)

success_count = 0
for test_input in test_inputs:
    output = run_inference(test_input)
    
    # Check if output matches expected function call format
    is_set_color = re.search(SET_COLOR_PATTERN, output)
    is_get_color = re.search(GET_COLOR_PATTERN, output)
    is_valid = bool(is_set_color or is_get_color)
    
    status = "OK" if is_valid else "FAIL"
    if is_valid:
        success_count += 1
    
    print(f"\n[{status}] Input: {test_input}")
    print(f"      Output: {output}")

print("\n" + "=" * 70)
print(f"RESULT: {success_count}/{len(test_inputs)} passed")
print("=" * 70)

if success_count < len(test_inputs) * 0.8:
    print("\nWARNING: fewer than 80% of the prompts produced a well-formed function call!")
    print("Check the fine-tuned model before uploading.")
else:
    print("\nAt least 80% of the prompts produced a well-formed function call.")
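The check above only verifies the call format. For example, `make it red` would pass even if the model called `get_square_color` or set the color to blue. A stricter sketch parses the function name and arguments and compares them with the intended call. `parse_function_call` and `EXPECTED_CALLS` are illustrative names. The expected calls are an assumption about what each test prompt should trigger.

```python
import re
from typing import Dict, Optional, Tuple

# Same call format as SET_COLOR_PATTERN / GET_COLOR_PATTERN, with capture groups.
CALL_PATTERN = re.compile(r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>")
ARG_PATTERN = re.compile(r"(\w+):<escape>(.*?)<escape>")

# Assumed intended call for each test prompt.
EXPECTED_CALLS = {
    "change the color to blue": ("set_square_color", {"color": "blue"}),
    "what color is the square?": ("get_square_color", {}),
    "make it red": ("set_square_color", {"color": "red"}),
    "tell me the current color": ("get_square_color", {}),
    "set to green": ("set_square_color", {"color": "green"}),
    "color?": ("get_square_color", {}),
}


def parse_function_call(output: str) -> Optional[Tuple[str, Dict[str, str]]]:
    """Return (function_name, {arg: value}) for the first call in output, or None."""
    match = CALL_PATTERN.search(output)
    if match is None:
        return None
    name, raw_args = match.groups()
    return name, dict(ARG_PATTERN.findall(raw_args))
```

To count exact matches, run `sum(parse_function_call(run_inference(p)) == expected for p, expected in EXPECTED_CALLS.items())`. If the model capitalizes colors, lowercase the argument values before comparing.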

## 7. Upload to Hugging Face

In [None]:
import huggingface_hub
from huggingface_hub import whoami

username = whoami()['name']

ONNX_MODEL_NAME = "functiongemma-square-color-ONNX"  #@param {type:"string"}
HF_REPO_ID = f"{username}/{ONNX_MODEL_NAME}"

print(f"Uploading to: {HF_REPO_ID}")

# Create repo if needed
huggingface_hub.create_repo(HF_REPO_ID, exist_ok=True)

# Upload entire folder
repo_url = huggingface_hub.upload_folder(
    folder_path=SAVE_PATH,
    repo_id=HF_REPO_ID,
    repo_type="model",
    commit_message=f"Upload ONNX model for {ONNX_MODEL_NAME}"
)

print(f"\nUploaded to: {repo_url}")
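To confirm the upload is complete, compare the local export with the repository's file listing. In the minimal sketch below, `local_files` and `missing_on_hub` are hypothetical helpers. The remote listing would come from `huggingface_hub.list_repo_files`.

```python
import os
from typing import Iterable, List


def local_files(folder: str) -> List[str]:
    """Sorted relative paths of all files under folder, using "/" like Hub paths."""
    paths = []
    for root, _dirs, files in os.walk(folder):
        for name in files:
            rel = os.path.relpath(os.path.join(root, name), folder)
            paths.append(rel.replace(os.sep, "/"))
    return sorted(paths)


def missing_on_hub(folder: str, remote_files: Iterable[str]) -> List[str]:
    """Local files that do not appear in the repository's file listing."""
    remote = set(remote_files)
    return [path for path in local_files(folder) if path not in remote]
```

After the upload, `missing_on_hub(SAVE_PATH, huggingface_hub.list_repo_files(HF_REPO_ID))` should return an empty list.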

## 8. Update Browser Code

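After uploading, update `src/worker.ts` in your project, replacing the model ID with your own repository:

```typescript
import { AutoModelForCausalLM } from "@huggingface/transformers";

const MODEL_ID = "harlley/functiongemma-square-color-ONNX";

const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
  dtype: "fp16", // FP16 for accurate function calling
  device: "webgpu",
  progress_callback: onProgress, // your existing progress handler
});
```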

## Summary

Your FunctionGemma ONNX model is now ready for browser deployment!

**Files uploaded:**
- `config.json` - Model configuration
- `tokenizer.json` - Tokenizer
- `tokenizer_config.json` - Tokenizer config with chat template
- `onnx/model_fp16.onnx` and `onnx/model_fp16.onnx_data` - FP16 model graph and weights (~544MB combined)

**Why FP16?**
- Q4 quantization caused hallucinations with this fine-tuned model
- FP16 stores each weight in 2 bytes instead of FP32's 4, halving the size while preserving accuracy
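FP16 keeps 10 mantissa bits per weight, while a 4-bit format has at most 16 levels per scale. The synthetic illustration below shows the resulting round-trip error. It uses random normal values as stand-in weights and a simplified symmetric 4-bit scheme with one scale per tensor. Real Q4 exports typically use per-block scales, so this likely overstates their error. It also does not reproduce the hallucination finding above.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for a weight tensor (not real model weights).
weights = rng.normal(0.0, 0.02, size=10_000).astype(np.float32)

# FP16: cast down and back up, as storing the weights in half precision would.
fp16_roundtrip = weights.astype(np.float16).astype(np.float32)

# Simplified symmetric 4-bit quantization with a single scale for the whole tensor.
scale = np.abs(weights).max() / 7.0
q4_codes = np.clip(np.round(weights / scale), -8, 7)
q4_roundtrip = (q4_codes * scale).astype(np.float32)

fp16_error = float(np.abs(weights - fp16_roundtrip).mean())
q4_error = float(np.abs(weights - q4_roundtrip).mean())
print(f"mean abs error  fp16: {fp16_error:.2e}  q4: {q4_error:.2e}")
```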

**Next steps:**
1. Update `src/worker.ts` with your model ID and `dtype: "fp16"`
2. Run `npm run dev` to test
3. Verify function calls work in the browser