# Fine-tune FunctionGemma for Square Color Function Calling

This notebook fine-tunes FunctionGemma with QLoRA so that it maps natural-language requests to two functions, `set_square_color` and `get_square_color`.

**Base Model:** `google/functiongemma-270m-it`

**Based on:** [Google's Emoji Gemma notebook](https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Demos/Emoji-Gemma-on-Web/resources/Fine_tune_Gemma_3_270M_for_emoji_generation.ipynb)

## Steps:
1. Setup environment
2. Load the square color dataset
3. Format data for FunctionGemma
4. Fine-tune with QLoRA
5. **Validate model outputs** (CRITICAL)
6. Save merged model
7. Upload to the Hugging Face Hub

## 1. Setup Environment

In [None]:
%pip install torch tensorboard
%pip install -U transformers trl datasets accelerate evaluate sentencepiece bitsandbytes protobuf==3.20.3 peft

In [None]:
# Restart runtime after installing packages
# import os
# os.kill(os.getpid(), 9)

## 2. Hugging Face Authentication

1. Accept license on [model page](http://huggingface.co/google/functiongemma-270m-it)
2. Get [Access Token](https://huggingface.co/settings/tokens) with 'Write' access
3. Create Colab secret: `HF_TOKEN`

In [None]:
from google.colab import userdata
from huggingface_hub import login

hf_token = userdata.get('HF_TOKEN')
login(hf_token)

## 3. Load Dataset

Upload the `square_color_dataset.json` file to Colab or mount Google Drive.

In [None]:
import json

# Option 1: Upload file to Colab
# from google.colab import files
# uploaded = files.upload()

# Option 2: Use file path directly
DATASET_PATH = "/content/square_color_dataset.json"  #@param {type:"string"}

with open(DATASET_PATH, 'r') as f:
    raw_dataset = json.load(f)

print(f"Loaded {len(raw_dataset)} examples")
print(f"\nExample entry:")
print(json.dumps(raw_dataset[0], indent=2))
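Before formatting, it can help to confirm that every entry has the three fields `format_sample` reads in step 5 (`user_content`, `tool_name`, `tool_arguments`). `format_function_call_output` falls back to an empty string when `color` is missing, which would silently produce an empty target value. The sketch below assumes `tool_arguments` is a JSON-encoded string for `set_square_color`, as in the formatter's test call. The three sample entries are hypothetical stand-ins for `raw_dataset`.

```python
import json

# Fields read by format_sample(); tool_arguments is assumed to be a JSON string.
REQUIRED_FIELDS = ("user_content", "tool_name", "tool_arguments")
KNOWN_TOOLS = {"set_square_color", "get_square_color"}

def find_bad_entries(entries):
    """Return (index, reason) pairs for entries that would break or degrade formatting."""
    problems = []
    for i, entry in enumerate(entries):
        missing = [f for f in REQUIRED_FIELDS if f not in entry]
        if missing:
            problems.append((i, f"missing fields: {missing}"))
        elif entry["tool_name"] not in KNOWN_TOOLS:
            problems.append((i, f"unknown tool: {entry['tool_name']}"))
        elif entry["tool_name"] == "set_square_color":
            # Mirrors format_function_call_output, which parses the arguments as JSON
            try:
                args = json.loads(entry["tool_arguments"])
            except (json.JSONDecodeError, TypeError):
                problems.append((i, "tool_arguments is not a JSON string"))
                continue
            if not isinstance(args, dict) or not args.get("color"):
                problems.append((i, "set_square_color without a color"))
    return problems

# Hypothetical entries in the assumed schema; in the notebook, pass raw_dataset instead.
sample_entries = [
    {"user_content": "make it red", "tool_name": "set_square_color",
     "tool_arguments": '{"color": "red"}'},
    {"user_content": "what color is it?", "tool_name": "get_square_color",
     "tool_arguments": "{}"},
    {"user_content": "paint it", "tool_name": "set_square_color",
     "tool_arguments": "{}"},
]
print(find_bad_entries(sample_entries))
```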

## 4. Define FunctionGemma Format

FunctionGemma uses specific tokens for function calling:
- `<start_function_call>` / `<end_function_call>`
- `<escape>` to wrap string values

Format: `<start_function_call>call:function_name{param:<escape>value<escape>}<end_function_call>`

In [None]:
# System prompt required by FunctionGemma
SYSTEM_PROMPT = """You are a model that can do function calling with the following functions

<start_function_declaration>
name:set_square_color
description:Sets the color of the square to a specified color
parameters:{color:{type:string,description:The color to set the square to,required:true}}
<end_function_declaration>
<start_function_declaration>
name:get_square_color
description:Gets the current color of the square
parameters:{}
<end_function_declaration>"""

def format_function_call_output(tool_name: str, tool_arguments: str) -> str:
    """Format the expected output in FunctionGemma format."""
    if tool_name == "set_square_color":
        args = json.loads(tool_arguments)
        color = args.get("color", "")
        return f"<start_function_call>call:set_square_color{{color:<escape>{color}<escape>}}<end_function_call>"
    elif tool_name == "get_square_color":
        return "<start_function_call>call:get_square_color{}<end_function_call>"
    else:
        raise ValueError(f"Unknown tool: {tool_name}")

# Test the format
test_output = format_function_call_output("set_square_color", '{"color": "blue"}')
print(f"Example output format: {test_output}")
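Scoring generations later is easier with an inverse of `format_function_call_output`. The sketch below is a minimal regex version. It assumes the simplified format used in this notebook: one call per output, with string values wrapped in `<escape>`. Unescaped values are ignored. `parse_function_call` is a hypothetical helper, not a library API.

```python
import re

# Matches <start_function_call>call:NAME{BODY}<end_function_call>
CALL_RE = re.compile(r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>", re.DOTALL)
# Matches key:<escape>value<escape> pairs inside BODY
ARG_RE = re.compile(r"(\w+):<escape>(.*?)<escape>")

def parse_function_call(text):
    """Return (function_name, args_dict) for the first call in text, or None."""
    match = CALL_RE.search(text)
    if match is None:
        return None
    name, body = match.groups()
    return name, dict(ARG_RE.findall(body))

print(parse_function_call(
    "<start_function_call>call:set_square_color{color:<escape>blue<escape>}<end_function_call>"
))
print(parse_function_call("<start_function_call>call:get_square_color{}<end_function_call>"))
print(parse_function_call("I cannot help with that."))
```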

## 5. Format Training Dataset

In [None]:
from datasets import Dataset

def format_sample(sample):
    """Convert raw sample to training format."""
    expected_output = format_function_call_output(
        sample["tool_name"],
        sample["tool_arguments"]
    )
    
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": sample["user_content"]},
            {"role": "assistant", "content": expected_output}
        ]
    }

# Convert to HuggingFace Dataset
formatted_data = [format_sample(s) for s in raw_dataset]
dataset = Dataset.from_list(formatted_data)

# Split into train/test
dataset_splits = dataset.train_test_split(test_size=0.1, shuffle=True, seed=42)

print(f"Training samples: {len(dataset_splits['train'])}")
print(f"Test samples: {len(dataset_splits['test'])}")
print(f"\nExample formatted sample:")
print(json.dumps(dataset_splits['train'][0], indent=2))
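`train_test_split` shuffles but does not stratify by default. With a small dataset, one split can therefore end up with few examples of one function. The sketch below is a quick stdlib check of the function mix per split. It reads the call name back from the last (assistant) turn, assuming the `messages` layout produced by `format_sample`. The two-row `toy_split` is hypothetical. In the notebook, pass `dataset_splits['train']` or `dataset_splits['test']` instead, since iterating a `datasets.Dataset` yields rows as dicts.

```python
import re
from collections import Counter

NAME_RE = re.compile(r"call:(\w+)\{")

def function_mix(split):
    """Count target function names across a split of formatted samples."""
    counts = Counter()
    for sample in split:
        target = sample["messages"][-1]["content"]  # last turn is the assistant's call
        match = NAME_RE.search(target)
        counts[match.group(1) if match else "<no call>"] += 1
    return counts

# Hypothetical formatted samples (system turn omitted for brevity)
toy_split = [
    {"messages": [{"role": "user", "content": "make it red"},
                  {"role": "assistant", "content": "<start_function_call>call:set_square_color{color:<escape>red<escape>}<end_function_call>"}]},
    {"messages": [{"role": "user", "content": "color?"},
                  {"role": "assistant", "content": "<start_function_call>call:get_square_color{}<end_function_call>"}]},
]
print(function_mix(toy_split))
```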

## 6. Load Base Model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the FunctionGemma model (specialized for function calling)
BASE_MODEL = "google/functiongemma-270m-it"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    attn_implementation="eager",
    torch_dtype=torch.bfloat16
)

print(f"Model loaded on: {base_model.device}")
print(f"Model dtype: {base_model.dtype}")
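A rough weight-memory estimate shows what the 4-bit loading in section 8 can save on a model this small. The sketch makes three assumptions:

- The published Gemma 3 270M shapes (about 168M embedding parameters and about 100M in the decoder's linear layers). Verify these with `base_model.config`.
- bitsandbytes quantizes only the linear layers.
- Everything else stays in bfloat16.

Norm weights, activations, optimizer state, and quantization constants are ignored.

```python
# Assumed shapes for a ~270M Gemma 3-style model (verify with base_model.config).
HIDDEN, MLP, LAYERS, VOCAB = 640, 2048, 18, 262_144
Q_OUT, KV_OUT = 4 * 256, 1 * 256  # 4 query heads, 1 key/value head, head_dim 256

embed_params = VOCAB * HIDDEN  # tied with lm_head, so counted once
attn_params = HIDDEN * Q_OUT + 2 * HIDDEN * KV_OUT + Q_OUT * HIDDEN  # q, k, v, o
mlp_params = 3 * HIDDEN * MLP  # gate, up, down projections
linear_params = LAYERS * (attn_params + mlp_params)

BF16_BYTES, NF4_BYTES = 2, 0.5
bf16_mb = (embed_params + linear_params) * BF16_BYTES / 1e6
qlora_mb = (embed_params * BF16_BYTES + linear_params * NF4_BYTES) / 1e6

print(f"embedding params: {embed_params:,}  linear params: {linear_params:,}")
print(f"all bf16: ~{bf16_mb:.0f} MB, 4-bit linear layers: ~{qlora_mb:.0f} MB")
```

Under these assumptions, the embedding table dominates. Quantizing the linear layers therefore trims weight memory by roughly a quarter, not the three quarters a fully 4-bit model would save.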

## 7. Test Base Model (Before Fine-tuning)

In [None]:
from transformers import pipeline

pipe = pipeline("text-generation", model=base_model, tokenizer=tokenizer)

test_inputs = [
    "change the color to blue",
    "what color is the square?",
    "make it red",
    "tell me the current color"
]

print("=" * 60)
print("BASE MODEL OUTPUT (before fine-tuning)")
print("=" * 60)

for test_input in test_inputs:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": test_input}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = pipe(prompt, max_new_tokens=64, do_sample=False)
    model_output = output[0]['generated_text'][len(prompt):].strip()
    
    print(f"\nInput: {test_input}")
    print(f"Output: {model_output}")

## 8. Configure QLoRA Fine-tuning

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig

ADAPTER_PATH = "/content/functiongemma-adapters"

# Quantization config for memory efficiency
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# LoRA config for parameter-efficient fine-tuning
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"]
)

# Training config
training_args = SFTConfig(
    output_dir=ADAPTER_PATH,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    lr_scheduler_type="constant",
    max_length=512,  # room for the system prompt's function declarations plus the call
    gradient_checkpointing=False,
    packing=False,
    optim="adamw_torch_fused",
    report_to="tensorboard",
    weight_decay=0.01,
)

# Reload model with quantization
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation='eager'
)
model.config.pad_token_id = tokenizer.pad_token_id

print("Training configured!")
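A quick count shows how small the LoRA update is under `r=16`, and what `lora_alpha=32` implies. Each adapted layer gains r·(in + out) parameters, and its update is scaled by lora_alpha / r = 2. The sketch makes two assumptions:

- The published Gemma 3 270M shapes: hidden size 640, 18 layers, 4 query heads and 1 key/value head of size 256, MLP width 2048, and vocabulary 262,144. Check these against `model.config`.
- PEFT's `all-linear` selects the seven attention and MLP projections per layer and skips the output head.

```python
# Assumed Gemma 3 270M dimensions; verify against model.config.
HIDDEN, HEAD_DIM, N_Q_HEADS, N_KV_HEADS = 640, 256, 4, 1
MLP, LAYERS, VOCAB = 2048, 18, 262_144
R, ALPHA = 16, 32  # from lora_config above

# (in_features, out_features) of each linear layer in one decoder block
block_linears = {
    "q_proj": (HIDDEN, N_Q_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_Q_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}

def lora_param_count(shapes, r):
    """LoRA adds A (r x in) and B (out x r) per layer: r * (in + out) parameters."""
    return sum(r * (d_in + d_out) for d_in, d_out in shapes)

lora_total = LAYERS * lora_param_count(block_linears.values(), R)
embedding_copy = VOCAB * HIDDEN  # size of one fully trainable embed_tokens/lm_head copy

print(f"LoRA scaling alpha/r: {ALPHA / R}")
print(f"LoRA adapter params: {lora_total:,}")
print(f"One embedding-sized trainable copy: {embedding_copy:,}")
```

Under these assumptions, the adapters add about 3.8M parameters. However, `modules_to_save=["lm_head", "embed_tokens"]` makes those modules fully trainable, so embedding-sized copies dominate the trainable count. Once the trainer in section 9 is built, `trainer.model.print_trainable_parameters()` reports the actual figure.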

## 9. Train the Model

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_splits['train'],
    eval_dataset=dataset_splits['test'],
    peft_config=lora_config,
    processing_class=tokenizer,  # saved alongside the adapters; step 11 reloads it from ADAPTER_PATH
)

print("Starting training...")
trainer.train()
trainer.save_model(ADAPTER_PATH)

print(f"\nLoRA adapters saved to {ADAPTER_PATH}")
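The number of optimizer steps behind each epoch follows from the split size and `training_args`. The sketch assumes a single GPU and no gradient accumulation (the config sets neither), and that the Trainer keeps the last partial batch, which is its default. The 900-example figure is hypothetical; substitute `len(dataset_splits['train'])`.

```python
import math

def optimizer_steps(n_train, batch_size=4, epochs=3, n_devices=1):
    """Optimizer steps when every batch, including the last partial one, is one step."""
    steps_per_epoch = math.ceil(n_train / (batch_size * n_devices))
    return steps_per_epoch * epochs

print(optimizer_steps(900))  # hypothetical 900 training examples
```

With `logging_strategy="epoch"`, each training-loss point in the plot below is the average loss over one epoch's worth of these steps.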

## 10. Plot Training Results

In [None]:
import matplotlib.pyplot as plt

log_history = trainer.state.log_history

train_losses = [log["loss"] for log in log_history if "loss" in log]
epoch_train = [log["epoch"] for log in log_history if "loss" in log]
eval_losses = [log["eval_loss"] for log in log_history if "eval_loss" in log]
epoch_eval = [log["epoch"] for log in log_history if "eval_loss" in log]

plt.figure(figsize=(10, 6))
plt.plot(epoch_train, train_losses, label="Training Loss", marker='o')
plt.plot(epoch_eval, eval_losses, label="Validation Loss", marker='s')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("FunctionGemma Fine-tuning: Training vs Validation Loss")
plt.legend()
plt.grid(True)
plt.show()
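`trainer.state.log_history` is a list of dicts. With epoch-level logging and evaluation, it holds one training entry and one evaluation entry per epoch. It also ends with a summary entry whose loss key is `train_loss`, which the `"loss" in log` filter above skips.

The sketch below uses a hand-written, hypothetical history with made-up numbers. It shows the same filtering, plus a simple overfitting signal: the epoch with the lowest validation loss.

```python
def split_losses(log_history):
    """Return ([(epoch, train_loss)], [(epoch, eval_loss)]) from Trainer logs."""
    train = [(log["epoch"], log["loss"]) for log in log_history if "loss" in log]
    evals = [(log["epoch"], log["eval_loss"]) for log in log_history if "eval_loss" in log]
    return train, evals

def best_eval_epoch(evals):
    """Epoch with the lowest validation loss."""
    return min(evals, key=lambda pair: pair[1])[0]

# Hypothetical history; real entries carry extra keys (step, learning_rate, ...).
fake_history = [
    {"epoch": 1.0, "loss": 1.20}, {"epoch": 1.0, "eval_loss": 0.40},
    {"epoch": 2.0, "loss": 0.30}, {"epoch": 2.0, "eval_loss": 0.25},
    {"epoch": 3.0, "loss": 0.10}, {"epoch": 3.0, "eval_loss": 0.28},
    {"epoch": 3.0, "train_loss": 0.53},  # end-of-training summary entry
]
train, evals = split_losses(fake_history)
print(train)
print(evals)
print("best epoch:", best_eval_epoch(evals))
```

If validation loss rises in later epochs, one option is to add `load_best_model_at_end=True` and `metric_for_best_model="eval_loss"` to the `SFTConfig`. The matching epoch save and eval strategies are already in place.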

## 11. Merge LoRA Adapters

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

MERGED_MODEL_PATH = "/content/functiongemma-merged"

# Load base model (without quantization for merging)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)

# Load and merge LoRA adapters
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model = model.merge_and_unload()

# Save merged model
model.save_pretrained(MERGED_MODEL_PATH)
tokenizer.save_pretrained(MERGED_MODEL_PATH)

print(f"Merged model saved to {MERGED_MODEL_PATH}")
print(f"Vocabulary size: {model.config.vocab_size}")
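`merge_and_unload` folds each adapter into its base weight. For a LoRA pair with down-projection A (r × in) and up-projection B (out × r), the merged weight is W + (lora_alpha / r) · B A.

The plain-Python sketch below uses tiny hypothetical integer matrices and the scaling 32 / 16 = 2 from `lora_config`. It checks that the merged layer produces the same output as running the base layer and the adapter separately. Modules listed in `modules_to_save` are not merged this way; their trained copies replace the originals whole.

```python
def matmul(a, b):
    """Multiply matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def add(a, b, scale=1):
    """Element-wise a + scale * b."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

SCALE = 32 // 16  # lora_alpha / r

W = [[1, 0, 2], [0, 1, -1]]  # hypothetical base weight: out=2, in=3
A = [[1, 1, 0]]              # down-projection with r=1: r x in
B = [[2], [-1]]              # up-projection: out x r
x = [[3], [1], [2]]          # input column vector

unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), SCALE)  # base path + adapter path
merged_W = add(W, matmul(B, A), SCALE)                        # W + scale * B A
merged = matmul(merged_W, x)
print(unmerged, merged)  # [[23], [-9]] [[23], [-9]]
```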

## 12. CRITICAL: Validate Fine-tuned Model

**IMPORTANT**: Test the model BEFORE exporting to ONNX!

If the merged model produces malformed output here, the problem lies in fine-tuning or merging, not in the ONNX export.

In [None]:
from transformers import pipeline

# Load the merged model
merged_model = AutoModelForCausalLM.from_pretrained(MERGED_MODEL_PATH, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_PATH)
pipe = pipeline("text-generation", model=merged_model, tokenizer=tokenizer)

test_inputs = [
    "change the color to blue",
    "what color is the square?",
    "make it red",
    "tell me the current color",
    "set to green",
    "color?",
    "I want purple",
    "gimme yellow"
]

print("=" * 70)
print("FINE-TUNED MODEL VALIDATION")
print("=" * 70)

success_count = 0
for test_input in test_inputs:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": test_input}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = pipe(prompt, max_new_tokens=64, do_sample=False)
    model_output = output[0]['generated_text'][len(prompt):].strip()
    
    # Format check only: confirms the call delimiters, not the function name or arguments
    is_valid = "<start_function_call>" in model_output and "<end_function_call>" in model_output
    status = "OK" if is_valid else "FAIL"
    if is_valid:
        success_count += 1
    
    print(f"\n[{status}] Input: {test_input}")
    print(f"      Output: {model_output}")

print("\n" + "=" * 70)
print(f"VALIDATION RESULT: {success_count}/{len(test_inputs)} passed")
print("=" * 70)

if success_count < len(test_inputs) * 0.8:
    print("\nWARNING: Fewer than 80% of outputs contain a well-formed function call.")
    print("Do NOT proceed to ONNX export until this is fixed.")
else:
    print("\nModel is generating well-formed function calls. Ready for ONNX export!")
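The check above only confirms that the call delimiters are present, so a call to `get_square_color` for "make it red" would still pass. A stricter sketch compares each output with the exact string `format_function_call_output` would produce. It assumes greedy decoding reproduces the training format exactly, including lowercase color names. The labels and model outputs below are hypothetical. In the notebook, collect each `model_output` into a list inside the validation loop.

```python
def expected_call(tool_name, color=None):
    """Rebuild the target string in the same format as format_function_call_output."""
    if tool_name == "set_square_color":
        return f"<start_function_call>call:set_square_color{{color:<escape>{color}<escape>}}<end_function_call>"
    return "<start_function_call>call:get_square_color{}<end_function_call>"

def exact_match_rate(outputs, expected):
    """Fraction of outputs equal to their expected call (surrounding whitespace ignored)."""
    if len(outputs) != len(expected):
        raise ValueError("outputs and expected must have the same length")
    hits = sum(out.strip() == exp for out, exp in zip(outputs, expected))
    return hits / len(expected)

# Hypothetical labels for two of the test inputs, and hypothetical model outputs
expected = [expected_call("set_square_color", "blue"), expected_call("get_square_color")]
outputs = [
    "<start_function_call>call:set_square_color{color:<escape>blue<escape>}<end_function_call>",
    "<start_function_call>call:set_square_color{color:<escape>red<escape>}<end_function_call>",
]
print(exact_match_rate(outputs, expected))  # 0.5
```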

## 13. Upload to Hugging Face Hub

In [None]:
from huggingface_hub import ModelCard, whoami

MODEL_NAME = "functiongemma-square-color"  #@param {type:"string"}

username = whoami()['name']
hf_repo_id = f"{username}/{MODEL_NAME}"

# Push model and tokenizer (push_to_hub creates the repo if it does not exist yet)
merged_model.push_to_hub(hf_repo_id, commit_message="Upload fine-tuned FunctionGemma")
tokenizer.push_to_hub(hf_repo_id)

# Create model card
card_content = f"""---
base_model: google/functiongemma-270m-it
tags:
- text-generation
- function-calling
- gemma
---

# FunctionGemma Square Color

A fine-tuned FunctionGemma model for square color function calling.

## Functions
- `set_square_color(color: string)` - Sets the square color
- `get_square_color()` - Gets the current color

## Usage
The model was trained with a system prompt that declares both functions; pass the same prompt at inference.

```python
from transformers import pipeline

pipe = pipeline("text-generation", model="{hf_repo_id}")
messages = [
    {{"role": "system", "content": SYSTEM_PROMPT}},  # the function-declaration prompt used in training
    {{"role": "user", "content": "change the color to blue"}},
]
output = pipe(messages, max_new_tokens=64, do_sample=False)
```
"""
card = ModelCard(card_content)
card.push_to_hub(hf_repo_id)

print(f"Model uploaded to: https://huggingface.co/{hf_repo_id}")

## Next Steps

If the model validation passed, proceed to:

1. **Export to ONNX** using `export_to_onnx.ipynb`
2. Upload the ONNX model to Hugging Face
3. Test in the browser with Transformers.js