Copyright 2025 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table align="left">
  <td>
      <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/FunctionGemma/%5BFunctionGemma%5DFinetune_FunctionGemma_270M_for_Mobile_Actions_with_Tunix.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

# Finetune FunctionGemma 270M for Mobile Actions using Tunix

This notebook LoRA-finetunes FunctionGemma to turn user requests into mobile actions, using the [Google Tunix](https://github.com/google/tunix) library on a Google TPU v5e-1 (available on free-tier Google Colab). It is adapted from this [notebook](https://github.com/google-gemini/gemma-cookbook/blob/main/FunctionGemma/%5BFunctionGemma%5DFinetune_FunctionGemma_270M_for_Mobile_Actions_with_Hugging_Face.ipynb).

[Tunix](https://github.com/google/tunix) is a lightweight, JAX-native LLM post-training library built for scale and efficiency. It supports a wide range of techniques, including SFT, LoRA, RL, and DPO, and it is easy to use and works well in practice.

## Install dependencies

In [None]:
import importlib

if importlib.util.find_spec("tunix") is None:
  print("Required packages not found. Running full installation...")
  %pip install -q kagglehub
  %pip install -q safetensors
  %pip install -q tensorflow
  %pip install -q tensorflow_datasets
  %pip install -q tensorboardX
  %pip install -q transformers
  %pip install -q grain
  %pip install -q datasets
  %pip install -q wandb
  %pip install -q git+https://github.com/jax-ml/jax
  %pip install -q git+https://github.com/google/tunix
  %pip install -q git+https://github.com/google/qwix
  %pip uninstall -q flax -y
  %pip install -q git+https://github.com/google/flax
  %pip install -q 'numpy>2'
  %pip install -U transformers==4.57.1

## Restart Colab runtime

Restart Colab runtime for the newly-installed libraries to take effect.

## Imports

In [None]:
import os
import json
import logging
import re
import shutil
import functools
import pandas as pd
import numpy as np
import wandb
import jax
import jax.numpy as jnp
from flax import nnx
import optax
from huggingface_hub import snapshot_download, hf_hub_download
from datasets import load_dataset
from transformers import AutoTokenizer

# Tunix imports
from tunix.models.gemma3 import params as gemma_params
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.generate import sampler as sampler_lib
from tunix.sft import peft_trainer
from tunix.sft import metrics_logger
from tunix.sft import utils
from tunix.generate import tokenizer_adapter as tokenizer_lib
import qwix



## Setup

Set up some configs and constants.

In [None]:
# Setup logging
# INFO level on the root logger; `logger` is this notebook's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
# Hugging Face repo ids of the base model and of the finetuning dataset.
MODEL_ID = "google/functiongemma-270m-it"
DATASET_ID = "google/mobile-actions"
# Root directory for training logs, checkpoints and the merged export.
OUTPUT_DIR = os.path.abspath("./mobile-actions-tunix")
# Number of examples per training/validation batch.
BATCH_SIZE = 2 # Set to 8 or bigger on TPU v6e-1
# Multiplies the number of training batches to obtain the trainer's max_steps.
NUM_EPOCHS = 1
# Initial value of the cosine-decay learning-rate schedule.
LEARNING_RATE = 1e-4
# Fixed token length every training example is truncated/padded to; also used
# as the maximum number of generation steps when sampling.
MAX_LENGTH = 1024
# Validation is run every this many training steps.
EVAL_EVERY_N_STEPS = 50
# LoRA hyper-parameters, used both when applying LoRA and when exporting.
LORA_RANK = 8
LORA_ALPHA = 16

# Disabled XLA flag override, kept for reference only.
# os.environ['XLA_FLAGS'] = "--xla_cpu_multi_thread_eigen=false" # --xla_interpreter_thread_pool_size=1"



from google.colab import userdata

# `userdata.get` raises (it does not return None) when the secret does not
# exist or when this notebook has not been granted access to it. Guard the
# lookup so the "disabled" fallback below is actually reachable, and read the
# secret only once.
try:
    wandb_api_key = userdata.get('WANDB_API_KEY')
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    wandb_api_key = None

if wandb_api_key:
    # Log metrics to a real Weights & Biases project.
    wandb.login(key=wandb_api_key)
    wandb.init(project="functiongemma-mobile-actions")
else:
    # No usable key: keep the wandb API callable but make it a no-op.
    logger.warning("WANDB_API_KEY not found. Initializing wandb in disabled mode.")
    wandb.init(mode="disabled")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mwindmaple[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Helper function to extract function calls

In [None]:
def extract_function_call(model_output):
    """Parses FunctionGemma-style function calls out of raw model text.

    Every `<start_function_call>...<end_function_call>` span whose content
    starts with `call:` and contains a `{` is turned into
    `{"function": {"name": ..., "arguments": {...}}}`. Only arguments written
    as `key:<escape>value<escape>` are captured, and their values are kept as
    strings.

    Args:
        model_output: Text generated by the model.

    Returns:
        A list of parsed calls in order of appearance; empty if none are found.
    """
    parsed_calls = []
    spans = re.findall(r"<start_function_call>(.*?)<end_function_call>", model_output, re.DOTALL)

    for span in spans:
        if not span.strip().startswith("call:"):
            continue

        # Everything before the first "{" is the call header, the rest the args.
        header, brace, remainder = span.partition("{")
        if not brace:
            # No argument block at all: the span is skipped.
            continue

        name = header.replace("call:", "").strip()

        body = remainder.strip()
        if body.endswith("}"):
            # Drop the brace that closes the argument block.
            body = body[:-1]

        pairs = re.finditer(r"(?P<key>[^:,]*?):<escape>(?P<value>.*?)<escape>", body, re.DOTALL)
        # Later duplicates of a key overwrite earlier ones.
        arguments = {pair.group("key").strip(): pair.group("value") for pair in pairs}

        if name:
            parsed_calls.append({"function": {"name": name, "arguments": arguments}})

    return parsed_calls

## Download the model and dataset

In [None]:
print("Downloading model and dataset...")

# Fetch the model snapshot (skipping *.pth files) and the raw JSONL dataset file.
local_model_path = snapshot_download(repo_id=MODEL_ID, ignore_patterns=["*.pth"])
data_file = hf_hub_download(
    repo_id=DATASET_ID,
    filename="dataset.jsonl",
    repo_type="dataset",
)

# The file is loaded with the "text" builder, so each row keeps one raw JSON
# line under the 'text' key; rows are shuffled with a fixed seed.
dataset = load_dataset("text", data_files=data_file, encoding="utf-8")["train"].shuffle(seed=42)


def _split_of(row):
    """Returns the 'metadata' field of a row's JSON payload."""
    return json.loads(row['text'])['metadata']


# The 'metadata' field is compared against 'train' / 'eval' to split the rows.
train_data = dataset.filter(lambda row: _split_of(row) == 'train')
full_eval = dataset.filter(lambda row: _split_of(row) == 'eval')

# The same eval rows are used for accuracy evaluation and for validation loss.
eval_data_for_acc = full_eval
val_data_for_loss = full_eval

Downloading model and dataset...


Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

## Prepare tokenizer, model and sampler

In [None]:
# Load the tokenizer that ships with the downloaded model snapshot.
tokenizer = AutoTokenizer.from_pretrained(local_model_path, fix_mistral_regex=True)

# Padding is needed for fixed-length batches; fall back to EOS when the
# tokenizer does not define a dedicated pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def get_model_config(config_path):
    """Returns the Tunix model config preset for Gemma 3 270M.

    Args:
        config_path: Path to the checkpoint's `config.json`. Accepted for
            interface compatibility but not read; the preset is hard-coded.

    Returns:
        The object produced by `gemma_lib.ModelConfig.gemma3_270m()`.
    """
    return gemma_lib.ModelConfig.gemma3_270m()

# NOTE: get_model_config ignores the path it is given and returns a fixed preset.
config_path = os.path.join(local_model_path, "config.json")
model_config = get_model_config(config_path)

# Number of JAX devices visible to this runtime (whatever their type).
NUM_TPUS = len(jax.devices())
# Mesh spec as [shape, axis_names]: every device is placed on the "tp" axis and
# the "fsdp" axis has size 1. Both branches give the same shape when
# NUM_TPUS == 1.
MESH = [(1, NUM_TPUS), ("fsdp", "tp")] if NUM_TPUS > 1 else [(1, 1), ("fsdp", "tp")]
mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))

with mesh:
    # Build the base model from the downloaded safetensors checkpoint.
    base_model = params_safetensors_lib.create_model_from_safe_tensors(local_model_path, model_config, mesh)
    # LoRA is applied to every module whose path matches this regex.
    # NOTE(review): presumably these are the attention q/kv projections and the
    # MLP gate/up/down projections - confirm against the Tunix Gemma3 module names.
    lora_provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
        rank=LORA_RANK, alpha=LORA_ALPHA,
    )
    # Example inputs supplied by the model itself and forwarded to qwix.
    model_input = base_model.get_model_input()
    model = qwix.apply_lora_to_model(base_model, lora_provider, rngs=nnx.Rngs(0), **model_input)
    # Constrain the full model state to its own partition specs and write the
    # result back into `model`.
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)
    print("LoRA applied and sharded.")

# The sampler wraps the LoRA model; it is used for both the demo and the
# evaluation, before and after training.
sampler = sampler_lib.Sampler(
    transformer=model, tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(cache_size=4096, num_layers=model_config.num_layers, num_kv_heads=model_config.num_kv_heads, head_dim=model_config.head_dim)
)
# Token ids passed as `eos_tokens` to the sampler to stop generation.
# NOTE(review): 1, 106 and 50 are hard-coded vocabulary ids - presumably
# end-of-sequence / end-of-turn / end-of-function-call markers; TODO confirm
# against the tokenizer.
STOP_IDS = [1, 106, 50, tokenizer.eos_token_id]

LoRA applied and sharded.


## Run a single sample and full evaluation before finetuning

We run a single sample to see how FunctionGemma performs. It does OK: it calls the right tool, but the body parameter is not very good ('<escape>The body of the email.<escape>').

In [None]:
def run_sample(dataset, label, idx=42):
    """Generates a completion for one dataset row and prints it next to the target.

    The prompt is the chat template applied to every message except the last
    one (with the generation prompt appended). The expected output is the text
    that the full conversation's template adds after that prompt.

    Args:
        dataset: Indexable dataset whose rows hold a JSON string under 'text'
            with 'messages' and 'tools' entries.
        label: Title printed above the sample.
        idx: Index of the row to run. Defaults to 42, the row that was
            previously hard-coded.
    """
    print(f"\n--- {label} ---")
    sample_text = dataset[idx]['text']
    template_inputs = json.loads(sample_text)

    prompt = tokenizer.apply_chat_template(
        template_inputs['messages'][:-1],
        tools=template_inputs['tools'],
        tokenize=False,
        add_generation_prompt=True)

    prompt_and_completion = tokenizer.apply_chat_template(
        template_inputs['messages'],
        tools=template_inputs['tools'],
        tokenize=False,
        add_generation_prompt=False)

    # The rendered conversation starts with the prompt, so the remainder is the
    # target completion.
    expected_output = prompt_and_completion[len(prompt):]

    print(f"\n\033[1mInput prompt\033[0m   : {prompt}")
    print(f"\n\033[1mExpected output\033[0m: {expected_output}")

    # Generate the output
    out = sampler([prompt], max_generation_steps=MAX_LENGTH, eos_tokens=STOP_IDS)
    actual_output = out.text[0]
    print(f"\n\033[1mActual output\033[0m  : {actual_output}\n")

run_sample(train_data, "Pre-train Demo")


--- Pre-train Demo ---

[1mInput prompt[0m   : <bos><start_of_turn>developer
Current date and time given in YYYY-MM-DDTHH:MM:SS format: 2024-03-16T02:02:17
Day of week is Saturday
You are a model that can do function calling with the following functions<start_function_declaration>declaration:turn_on_flashlight{description:<escape>Turns the flashlight on.<escape>,parameters:{type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:show_map{description:<escape>Shows a location on the map.<escape>,parameters:{properties:{query:{description:<escape>The location to search for. May be the name of a place, a business, or an address.<escape>,type:<escape>STRING<escape>}},required:[<escape>query<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:turn_off_flashlight{description:<escape>Turns the flashlight off.<escape>,parameters:{type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_de

We also run the full evaluation on the eval dataset. FunctionGemma achieves ~64% accuracy. Not bad, but we can do better.

In [None]:
def run_eval(data_subset, label):
    """Measures exact-match function-calling accuracy over a dataset.

    For every example the model is prompted with all messages but the last,
    the generated text is parsed with `extract_function_call`, and the result
    is compared with the `tool_calls` of the last message. An example is
    correct only when both the ordered list of function names and the ordered
    list of argument dicts match exactly.

    Examples whose generation raises are reported and counted as incorrect, so
    the denominator is always the number of examples seen and failures cannot
    inflate the score.

    Args:
        data_subset: Sized iterable of rows holding a JSON string under 'text'
            with 'messages' and 'tools' entries.
        label: Name printed in the progress and summary lines.

    Returns:
        Accuracy as a float in [0, 1]; 0 when `data_subset` is empty.
    """
    print(f"--- {label} ---")
    correct_count, total_count, error_count = 0, 0, 0
    for i, example in enumerate(data_subset):
        orig_data = json.loads(example['text'])
        messages = orig_data['messages']
        prompt = tokenizer.apply_chat_template(messages[:-1], tools=orig_data['tools'], tokenize=False, add_generation_prompt=True)
        total_count += 1
        try:
            out = sampler([prompt], max_generation_steps=MAX_LENGTH, eos_tokens=STOP_IDS)
            model_output = out.text[0]
        except Exception as e:
            # Best effort: keep evaluating, but count this example as a miss.
            print(f"Error: {e}")
            error_count += 1
            model_output = None

        if model_output is not None:
            output_fc = extract_function_call(model_output)
            # The target is the message that was held out of the prompt above.
            target_fc = messages[-1].get('tool_calls') or []
            target_names = [fc['function']['name'] for fc in target_fc]
            output_names = [fc['function']['name'] for fc in output_fc]
            # Dict equality ignores key order, so no sorting is needed.
            target_args = [dict(fc['function']['arguments'] or {}) for fc in target_fc]
            output_args = [dict(fc['function']['arguments'] or {}) for fc in output_fc]
            if (target_names == output_names) and (target_args == output_args):
                correct_count += 1

        if (i+1) % 50 == 0:
            print(f"Processed {i+1}/{len(data_subset)} - Accuracy: {correct_count/total_count:.2%}")
    if error_count:
        print(f"Generation failed for {error_count} example(s); they are counted as incorrect.")
    acc = correct_count/total_count if total_count > 0 else 0
    print(f"Final {label} Accuracy: {acc:.2%}")
    return acc

run_eval(eval_data_for_acc, "Pre-train Eval")

--- Pre-train Eval ---
Processed 50/961 - Accuracy: 64.00%
Processed 100/961 - Accuracy: 65.00%
Processed 150/961 - Accuracy: 62.67%
Processed 200/961 - Accuracy: 60.50%
Processed 250/961 - Accuracy: 61.60%
Processed 300/961 - Accuracy: 62.67%
Processed 350/961 - Accuracy: 62.86%
Processed 400/961 - Accuracy: 63.75%
Processed 450/961 - Accuracy: 64.22%
Processed 500/961 - Accuracy: 64.40%
Processed 550/961 - Accuracy: 64.55%
Processed 600/961 - Accuracy: 64.17%
Processed 650/961 - Accuracy: 63.69%
Processed 700/961 - Accuracy: 63.86%
Processed 750/961 - Accuracy: 64.53%
Processed 800/961 - Accuracy: 64.12%
Processed 850/961 - Accuracy: 64.47%
Processed 900/961 - Accuracy: 64.78%
Processed 950/961 - Accuracy: 64.21%
Final Pre-train Eval Accuracy: 64.00%


0.6399583766909469

## Finetune the model

Tunix has certain expectations on the input data, so we create a `CustomDataset` for Tunix and prepare the dataset accordingly.

In [None]:
class CustomDataset:
    """Iterable view of raw dataset rows as fixed-length Tunix training inputs.

    Each row must hold a JSON string under 'text' with 'messages' and 'tools'
    entries. Iterating renders the chat template, tokenizes it, truncates/pads
    to `max_length` tokens and yields one `peft_trainer.TrainingInput` per row
    whose mask is 1 only on the completion tokens (those after the prompt).
    """

    def __init__(self, data, tokenizer, max_length=1024):
        """Stores the rows, the tokenizer and the fixed sequence length."""
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        """Number of underlying rows."""
        return len(self.data)

    def __iter__(self):
        """Yields one unbatched `TrainingInput` per row, in row order."""
        tok = self.tokenizer
        seq_len = self.max_length

        for row in self.data:
            record = json.loads(row['text'])
            messages, tools = record['messages'], record['tools']

            # Full conversation, and the same conversation without its last
            # message but with the generation prompt appended.
            full_text = tok.apply_chat_template(
                messages, tools=tools, tokenize=False, add_generation_prompt=False
            )
            prompt_text = tok.apply_chat_template(
                messages[:-1], tools=tools, tokenize=False, add_generation_prompt=True
            )

            encoded_full = tok(full_text, add_special_tokens=False)
            encoded_prompt = tok(prompt_text, add_special_tokens=False)

            # Keep at most `seq_len` tokens of the full conversation.
            token_ids = encoded_full['input_ids'][:seq_len]
            n_tokens = len(token_ids)
            n_prompt = len(encoded_prompt['input_ids'])

            # Right-pad the token ids with the pad id.
            padded = np.full((seq_len,), tok.pad_token_id, dtype=np.int32)
            padded[:n_tokens] = token_ids

            # Mark only the tokens that come after the prompt; the mask stays
            # all-zero when truncation removed the whole completion.
            completion_mask = np.zeros((seq_len,), dtype=np.int32)
            if n_prompt < n_tokens:
                completion_mask[n_prompt:n_tokens] = 1

            yield peft_trainer.TrainingInput(
                input_tokens=jnp.array(padded, dtype=jnp.int32),
                input_mask=jnp.array(completion_mask, dtype=jnp.int32)
            )

def data_generator(split_data, batch_size, drop_remainder=True):
    """Yields batched `peft_trainer.TrainingInput`s built from raw dataset rows.

    Rows are converted by `CustomDataset` (using the global `tokenizer` and
    `MAX_LENGTH`) and stacked along a new leading batch axis.

    Args:
        split_data: Rows to convert; see `CustomDataset` for the expected format.
        batch_size: Number of examples per yielded batch.
        drop_remainder: When True (the default, matching the previous
            behaviour) a trailing batch with fewer than `batch_size` examples
            is discarded so that every batch has the same shape. Set to False
            to also yield that smaller last batch.

    Yields:
        `peft_trainer.TrainingInput` whose `input_tokens` and `input_mask` have
        a leading batch dimension.
    """
    def _collate(tokens, masks):
        # The per-example arrays are already JAX arrays, so stack them with
        # jnp instead of round-tripping through NumPy.
        return peft_trainer.TrainingInput(input_tokens=jnp.stack(tokens), input_mask=jnp.stack(masks))

    dataset_obj = CustomDataset(split_data, tokenizer, MAX_LENGTH)
    batch_tokens, batch_masks = [], []
    for item in dataset_obj:
        batch_tokens.append(item.input_tokens)
        batch_masks.append(item.input_mask)
        if len(batch_tokens) == batch_size:
            yield _collate(batch_tokens, batch_masks)
            batch_tokens, batch_masks = [], []

    # Whatever is left over is a partial batch.
    if batch_tokens and not drop_remainder:
        yield _collate(batch_tokens, batch_masks)

print("Preparing training data...")
# Materialise all batches up front so that their count is known (it is used to
# derive the number of training steps).
train_batches = list(data_generator(train_data, BATCH_SIZE))
val_batches = list(data_generator(val_data_for_loss, BATCH_SIZE))

Preparing training data...


Now we kick off the finetuning. Tunix integrates seamlessly with TensorBoard and Weights & Biases, so we can visualize the training progress.

In [None]:
def gen_model_input_fn(x: peft_trainer.TrainingInput):
    """Builds the model's keyword inputs from a batch of training inputs.

    Positions and the causal attention mask are derived from the non-padding
    tokens, i.e. every token whose id differs from the tokenizer's pad id.

    Args:
        x: Batch providing `input_tokens` and `input_mask`.

    Returns:
        Dict with 'input_tokens', 'input_mask', 'positions' and 'attention_mask'.
    """
    tokens = x.input_tokens
    is_real_token = tokens != tokenizer.pad_token_id
    return dict(
        input_tokens=tokens,
        input_mask=x.input_mask,
        positions=utils.build_positions_from_mask(is_real_token),
        attention_mask=utils.make_causal_attn_mask(is_real_token),
    )

print("Starting Training...")
# One step per training batch, times the number of epochs.
# NOTE(review): `train_batches` is a plain list handed to the trainer once -
# confirm the trainer re-iterates it when NUM_EPOCHS > 1.
max_steps = len(train_batches) * NUM_EPOCHS
# Cosine decay from LEARNING_RATE over the whole run.
lr_schedule = optax.cosine_decay_schedule(init_value=LEARNING_RATE, decay_steps=max_steps)
# Metrics are written under OUTPUT_DIR/logs.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir=os.path.join(OUTPUT_DIR, "logs"), flush_every_n_steps=10
)
# Checkpoints are written under OUTPUT_DIR/ckpts.
training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=max_steps,
    checkpoint_root_directory=os.path.join(OUTPUT_DIR, "ckpts"),
    metrics_logging_options=metrics_logging_options,
)
# AdamW driven by the schedule above; gen_model_input_fn turns each batch into
# the model's keyword inputs.
trainer = peft_trainer.PeftTrainer(model, optax.adamw(lr_schedule), training_config).with_gen_model_input_fn(gen_model_input_fn)

# Train inside the same mesh context the model was created and sharded in.
with mesh:
    trainer.train(train_batches, val_batches)
print("Training Complete.")

Starting Training...




Training:   0%|          | 0/4346 [00:00<?, ?step/s]



0,1
eval/loss,█▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/perplexity,█▇▇▅▆▄▃▄▃▃▃▅▃▃▂▂▂▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/step_time_sec,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/steps_per_sec,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
jax/checkpoint/write/blocking_gbytes_per_sec,▁
jax/checkpoint/write/gbytes,▁
jax/checkpoint/write/gbytes_per_sec,▁
jax/core/compile/backend_compile_duration,▁
jax/core/compile/jaxpr_to_mlir_module_duration,▁
jax/core/compile/jaxpr_trace_duration,▁

0,1
eval/loss,0.01168
eval/perplexity,1.01174
eval/step_time_sec,0
eval/steps_per_sec,1000000000.0
jax/checkpoint/write/blocking_gbytes_per_sec,0.00366
jax/checkpoint/write/gbytes,0.00584
jax/checkpoint/write/gbytes_per_sec,0.00283
jax/core/compile/backend_compile_duration,1768267480.1987
jax/core/compile/jaxpr_to_mlir_module_duration,1768267479.54601
jax/core/compile/jaxpr_trace_duration,1768267606.53934


Training Complete.


## Post-train evaluation

Now we run the same test sample again, and this time the response is better (body is now 'Don't forget to finalize your quarterly goals before the meeting.').

In [None]:
# Re-initialize wandb in disabled mode if no run is active any more
# (presumably because the trainer finished the run when training ended).
# This is meant to prevent "You must call wandb.init() before wandb.log()"
# errors in the sampling/evaluation cells that follow.
# NOTE(review): run_sample and run_eval do not call wandb themselves - confirm
# which library call needs an active run.
if wandb.run is None:
    wandb.init(mode="disabled")

# Same sample (and same default row) as the pre-train demo, for comparison.
run_sample(train_data, "Post-train Demo")


--- Post-train Demo ---

[1mInput prompt[0m   : <bos><start_of_turn>developer
Current date and time given in YYYY-MM-DDTHH:MM:SS format: 2024-03-16T02:02:17
Day of week is Saturday
You are a model that can do function calling with the following functions<start_function_declaration>declaration:turn_on_flashlight{description:<escape>Turns the flashlight on.<escape>,parameters:{type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:show_map{description:<escape>Shows a location on the map.<escape>,parameters:{properties:{query:{description:<escape>The location to search for. May be the name of a place, a business, or an address.<escape>,type:<escape>STRING<escape>}},required:[<escape>query<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:turn_off_flashlight{description:<escape>Turns the flashlight off.<escape>,parameters:{type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_d

And the accuracy reaches ~89% after just one epoch of finetuning.

In [None]:
# Same eval set as the pre-train evaluation; the sampler now uses the trained LoRA weights.
run_eval(eval_data_for_acc, "Post-train Eval")

--- Post-train Eval ---
Processed 50/961 - Accuracy: 84.00%
Processed 100/961 - Accuracy: 87.00%
Processed 150/961 - Accuracy: 86.67%
Processed 200/961 - Accuracy: 85.50%
Processed 250/961 - Accuracy: 85.60%
Processed 300/961 - Accuracy: 87.33%
Processed 350/961 - Accuracy: 86.86%
Processed 400/961 - Accuracy: 87.25%
Processed 450/961 - Accuracy: 87.56%
Processed 500/961 - Accuracy: 88.40%
Processed 550/961 - Accuracy: 88.00%
Processed 600/961 - Accuracy: 88.33%
Processed 650/961 - Accuracy: 88.00%
Processed 700/961 - Accuracy: 88.29%
Processed 750/961 - Accuracy: 88.93%
Processed 800/961 - Accuracy: 89.12%
Processed 850/961 - Accuracy: 89.41%
Processed 900/961 - Accuracy: 89.44%
Processed 950/961 - Accuracy: 89.05%
Final Post-train Eval Accuracy: 89.18%


0.8917793964620188

## Export to safetensors

After finetuning, we can merge the adapter and export the model to safetensors, which allows us to convert it for ODML deployment.

In [None]:
# The merged model is written to OUTPUT_DIR/merged.
merged_output_dir = os.path.join(OUTPUT_DIR, "merged")
print(f"Saving merged LoRA model to {merged_output_dir}")
# As the helper's name says, the LoRA weights of `model` are merged and saved
# as safetensors. It is given the original checkpoint directory and the same
# rank/alpha that were used when LoRA was applied.
gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=merged_output_dir,
    lora_model=model,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
)
print("Model Exported Successfully.")

Saving merged LoRA model to /content/mobile-actions-tunix/merged
Model Exported Successfully.


## Summary

Congratulations! You have successfully finetuned the FunctionGemma model.