# FunctionGemma Fine-tuning


In [1]:
# =============================================================================
# Install dependencies (use Colab's pre-installed torch)
# =============================================================================
# Don't reinstall torch - Colab has optimized version pre-installed

!pip install -q transformers==4.57.3 datasets accelerate evaluate trl==0.26.2 protobuf sentencepiece
!pip install -q huggingface_hub tensorboard

print("\nDependencies installed!")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m54.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m518.9/518.9 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h
Dependencies installed!


## 2. HuggingFace Authentication


In [2]:
# =============================================================================
# Authenticate with HuggingFace using Colab Secrets
# =============================================================================
from google.colab import userdata
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


## 3. Define Tools and Prepare Dataset


In [5]:
import json
from datasets import Dataset
from transformers.utils import get_json_schema

# =============================================================================
# STEP 4.1: Define Python functions for JSON Schema generation
# =============================================================================
# These functions are NOT executed - they're only used for JSON Schema generation.
# get_json_schema() reads the function name, docstring, and type hints,
# and creates a JSON Schema in OpenAI function calling format.
def pagamento(valor: float, metodo_pagamento: str):
    """
    Tool de pagamento

    Args:
        valor: valor numérico da transação em reais
        metodo_pagamento: método utilizado para pagamento (ex: "pix", "debito", "credito")

    Returns:
        dict: Dicionário contendo mensagem de confirmação e dados de pagamento
    """
    return {
        "mensagem": "Pagamento realizado com sucesso",
        "dados": {
            "valor": valor,
            "metodo_pagamento": metodo_pagamento
        }
    }


TOOLS = [
    get_json_schema(pagamento)
]

print("Tools defined:")
for tool in TOOLS:
    print(f"   - {tool['function']['name']}: {tool['function']['description'][:50]}...")


Tools defined:
   - pagamento: Tool de pagamento...


### 4.2 Convert Dataset to Google FunctionGemma Format

We use the **official Google FunctionGemma format** as documented at:
https://huggingface.co/google/functiongemma-270m-it

We do NOT use HuggingFace's `apply_chat_template` because it produces a different format
that breaks compatibility with the base model.

**Google FunctionGemma format:**

**Input (prompt):**
```
<start_of_turn>developer
You are a model that can do function calling with the following functions
<start_function_declaration>declaration:function_name{description:<escape>...<escape>,parameters:{...}}<end_function_declaration>
<end_of_turn>
<start_of_turn>user
make it red
<end_of_turn>
<start_of_turn>model
```

**Output (completion):**
```
<start_function_call>call:function_name{param:<escape>value<escape>}<end_function_call>
```

**Key format elements:**
- `<escape>` tokens wrap all string values (not JSON quotes)
- `call:` prefix before function name in output
- Parameters use `{param:<escape>value<escape>}` format

In [8]:
# =============================================================================
# STEP 4.2: Convert dataset to Google FunctionGemma format
# =============================================================================
# CRITICAL: We manually create prompts in the EXACT format Flutter uses!
# Do NOT use apply_chat_template - it produces a different format.

import json
from datasets import Dataset

# FunctionGemma special tokens (same as Flutter uses)
START_TURN = "<start_of_turn>"
END_TURN = "<end_of_turn>"
START_DECL = "<start_function_declaration>"
END_DECL = "<end_function_declaration>"
START_CALL = "<start_function_call>"
END_CALL = "<end_function_call>"
ESCAPE = "<escape>"

# =============================================================================
# Function declaration adapted for pagamentos (PIX / CREDITO / DEBITO)
# =============================================================================
FUNCTION_DECLARATIONS = f"""
{START_DECL}declaration:pagamento{{
  description:{ESCAPE}Realiza um pagamento{ESCAPE},
  parameters:{{
    type:{ESCAPE}OBJECT{ESCAPE},
    properties:{{
      valor:{{
        description:{ESCAPE}Valor do pagamento em reais{ESCAPE},
        type:{ESCAPE}NUMBER{ESCAPE}
      }},
      metodo_pagamento:{{
        description:{ESCAPE}Método do pagamento (pix, credito ou debito){ESCAPE},
        type:{ESCAPE}STRING{ESCAPE}
      }}
    }},
    required:[
      {ESCAPE}valor{ESCAPE},
      {ESCAPE}metodo_pagamento{ESCAPE}
    ]
  }}
}}{END_DECL}
"""

# =============================================================================
# System prompt for training
# =============================================================================
SYSTEM_PROMPT = f"""{START_TURN}developer
You are a model that can do function calling with the following functions
{FUNCTION_DECLARATIONS}
{END_TURN}
"""

# =============================================================================
# Convert a single example into Google FunctionGemma format
# =============================================================================
def create_training_example(sample):
    """
    Input: {"user_content": "passa quinhentos reais e cinquenta centavos no credito",
            "tool_name": "pagamento",
            "tool_arguments": {"valor": 500.5, "metodo_pagamento": "credito"}}

    Output text for training:
    <start_of_turn>developer
    You are a model...
    <end_of_turn>
    <start_of_turn>user
    passa quinhentos reais e cinquenta centavos no credito
    <end_of_turn>
    <start_of_turn>model
    <start_function_call>call:pagamento{valor:<escape>500.5<escape>,metodo_pagamento:<escape>credito<escape>}<end_function_call>
    """
    user_content = sample["user_content"]
    tool_name = sample["tool_name"]
    tool_args = sample["tool_arguments"]  # assume já é dict

    # Build prompt (input)
    prompt = f"""{SYSTEM_PROMPT}{START_TURN}user
{user_content}
{END_TURN}
{START_TURN}model
"""

    # Build completion (output) in Google format: {param:<escape>value<escape>}
    params_str = ",".join([f"{k}:{ESCAPE}{v}{ESCAPE}" for k, v in tool_args.items()])
    completion = f"{START_CALL}call:{tool_name}{{{params_str}}}{END_CALL}"

    # Full training text = prompt + completion
    return {"text": prompt + completion}

# =============================================================================
# Load raw dataset
# =============================================================================
raw_data = []

with open('/content/drive/MyDrive/dataset/machine_actions_large.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue  # pula linhas vazias
        try:
            raw_data.append(json.loads(line))
        except json.JSONDecodeError as e:
            print(f"Erro na linha: {line}")
            raise e

print(f"Loaded {len(raw_data)} raw examples")


dataset = Dataset.from_list(raw_data)
dataset = dataset.map(create_training_example, remove_columns=dataset.features)

# Split into train/test (80%/20%)
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)

print(f"\nDataset prepared:")
print(f"   Train: {len(dataset['train'])} examples")
print(f"   Test:  {len(dataset['test'])} examples")

# Show sample
print(f"\n{'='*60}")
print("Sample training example:")
print("="*60)
print(dataset['train'][0]['text'][:800])
print("...")


Loaded 9663 raw examples


Map:   0%|          | 0/9663 [00:00<?, ? examples/s]


Dataset prepared:
   Train: 7730 examples
   Test:  1933 examples

Sample training example:
<start_of_turn>developer
You are a model that can do function calling with the following functions

<start_function_declaration>declaration:pagamento{
  description:<escape>Realiza um pagamento<escape>,
  parameters:{
    type:<escape>OBJECT<escape>,
    properties:{
      valor:{
        description:<escape>Valor do pagamento em reais<escape>,
        type:<escape>NUMBER<escape>
      },
      metodo_pagamento:{
        description:<escape>Método do pagamento (pix, credito ou debito)<escape>,
        type:<escape>STRING<escape>
      }
    },
    required:[
      <escape>valor<escape>,
      <escape>metodo_pagamento<escape>
    ]
  }
}<end_function_declaration>

<end_of_turn>
<start_of_turn>user
passa cento e cinquenta e sete reais e vinte centavos no credito
<end_of_turn>
<start_of_turn>
...


## 5. Load Base Model

**Model:** `google/functiongemma-270m-it`
- 270M parameters (compact, designed for on-device)
- Instruction-tuned (it) - already trained to follow instructions
- Specialized for function calling

**Loading parameters:**
- `torch_dtype=bfloat16` - 16-bit weights to save memory (~540MB instead of ~1GB)
- `device_map="auto"` - automatically load to GPU
- `attn_implementation="eager"` - without FlashAttention (for compatibility)

In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# =============================================================================
# Load FunctionGemma base model
# =============================================================================
BASE_MODEL = "google/functiongemma-270m-it"

print(f"Loading {BASE_MODEL}...")
print("   (Downloads ~540MB on first run, then uses cache)")

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,      # 16-bit to save VRAM
    device_map="auto",                # Automatically load to GPU
    attn_implementation="eager"       # Without FlashAttention for compatibility
)

# Tokenizer converts text to tokens and back
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

print(f"\nModel loaded!")
print(f"   Parameters: {model.num_parameters():,}")
print(f"   Memory: ~{model.num_parameters() * 2 / 1e9:.1f} GB (bfloat16)")
print(f"   Device: {model.device}")

Loading google/functiongemma-270m-it...
   (Downloads ~540MB on first run, then uses cache)


config.json:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/536M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/176 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/63.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/706 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/13.8k [00:00<?, ?B/s]


Model loaded!
   Parameters: 268,098,176
   Memory: ~0.5 GB (bfloat16)
   Device: cuda:0


## 6. Configure Training

**Hyperparameters from official Google FunctionGemma cookbook:**

| Parameter | Value | Explanation |
|-----------|-------|-------------|
| `num_train_epochs` | 5 | Extended training for enum support (320 examples) |
| `learning_rate` | 5e-5 | Learning rate (conservative for fine-tuning) |
| `lr_scheduler_type` | cosine | Smoothly decreases LR towards end of training |
| `gradient_accumulation_steps` | 8 | Gradient accumulation (effective batch = 32) |
| `max_length` | 1024 | Maximum sequence length in tokens |
| `bf16` | True | 16-bit training to save memory |

**Why these values:**
- LR 5e-5 - prevents "forgetting" base knowledge
- Cosine scheduler - smooth LR decay improves convergence
- Gradient accumulation 8 - simulates large batch without running out of memory

In [16]:
from trl import SFTConfig, SFTTrainer

# Output directory
OUTPUT_DIR = "functiongemma-machine-demo"

# =============================================================================
# Training configuration (based on official Google FunctionGemma cookbook)
# https://github.com/google-gemini/gemma-cookbook/blob/main/FunctionGemma/
# =============================================================================
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,

    # Dataset field with pre-formatted Google FunctionGemma format
    dataset_text_field="text",          # Use our pre-formatted text, NOT apply_chat_template

    # Training params (Google official uses 2 epochs, we use 5 for enum support)
    max_length=1024,                    # Max sequence length in tokens
    packing=False,                      # Don't pack multiple examples into one sequence
    num_train_epochs=5,                 # Extended training for enum support (320 examples)
    per_device_train_batch_size=4,      # Batch size per GPU
    per_device_eval_batch_size=4,       # Eval batch size
    gradient_accumulation_steps=8,      # Effective batch size: 4 * 8 = 32

    # Optimizer (Google official params)
    learning_rate=5e-5,                 # Google official: 1e-5 (more conservative than 5e-5)
    lr_scheduler_type="cosine",         # Google official: cosine decay
    optim="adamw_torch_fused",          # Fused AdamW for faster training
    warmup_ratio=0.1,                   # 10% warmup steps

    # Logging and checkpoints
    logging_steps=10,                   # Log every 10 steps
    eval_strategy="epoch",              # Evaluate after each epoch
    save_strategy="epoch",              # Save checkpoint after each epoch

    # Memory optimization
    gradient_checkpointing=False,       # Trade compute for memory (enable if OOM)
    bf16=True,                          # Use bfloat16 for training

    # Output
    report_to="tensorboard",            # Log to TensorBoard
    push_to_hub=False,                  # Set to True to upload to HuggingFace
)

print("Training configuration (Google official params):")
print(f"   Epochs: {training_args.num_train_epochs}")
print(f"   Batch size: {training_args.per_device_train_batch_size}")
print(f"   Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"   Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"   Learning rate: {training_args.learning_rate}")
print(f"   LR scheduler: {training_args.lr_scheduler_type}")
print(f"   Max length: {training_args.max_length}")
print(f"   Dataset field: {training_args.dataset_text_field}")

Training configuration (Google official params):
   Epochs: 5
   Batch size: 4
   Gradient accumulation: 8
   Effective batch size: 32
   Learning rate: 5e-05
   LR scheduler: SchedulerType.COSINE
   Max length: 1024
   Dataset field: text


## 7. Start Training

**What happens:**
1. `SFTTrainer` uses the pre-formatted `text` field from our dataset
2. Model learns to predict the correct function call for each user message
3. We do NOT use `apply_chat_template` - data is already in correct Google format!

**Training time:** ~5 minutes on A100 GPU (for ~300 examples)

**Monitoring:**
- `loss` should decrease
- `eval_loss` should not increase (otherwise overfitting)

In [15]:
# =============================================================================
# Create SFTTrainer and start training
# =============================================================================
# dataset_text_field="text" is configured in SFTConfig above.
# This tells SFTTrainer to use our pre-formatted text directly,
# without applying HuggingFace's chat template (which uses wrong format).

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    processing_class=tokenizer,  # TRL 0.26.2: use processing_class, not tokenizer
)

print("Starting training...")
print(f"   Train examples: {len(dataset['train'])}")
print(f"   Eval examples: {len(dataset['test'])}")
print(f"   Format: Google FunctionGemma (manual)")
print(f"   Estimated time: ~5 minutes on A100")
print("-" * 50)

# Train!
train_result = trainer.train()

print("\n" + "=" * 50)
print("Training complete!")
print(f"   Final loss: {train_result.training_loss:.4f}")

Adding EOS to train dataset:   0%|          | 0/7730 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/7730 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/7730 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/1933 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/1933 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/1933 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.


Starting training...
   Train examples: 7730
   Eval examples: 1933
   Format: Google FunctionGemma (manual)
   Estimated time: ~5 minutes on A100
--------------------------------------------------


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.59 GiB. GPU 0 has a total capacity of 14.74 GiB of which 1.24 GiB is free. Process 2595 has 13.50 GiB memory in use. Of the allocated memory 12.87 GiB is allocated by PyTorch, and 506.30 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## 8. Save Model

**Files saved:**
- `model.safetensors` - model weights (~540MB)
- `config.json` - architecture configuration
- `tokenizer.json`, `tokenizer_config.json` - tokenizer
- `special_tokens_map.json` - special tokens

**Format:** SafeTensors (safe, no pickle)

In [None]:
# =============================================================================
# Save the fine-tuned model to Google Drive
# =============================================================================
from google.colab import drive

FINAL_MODEL_DIR = f"{OUTPUT_DIR}-final"
DRIVE_MODEL_DIR = f"/content/drive/MyDrive/function-gemma-tuned/{FINAL_MODEL_DIR}"

# Save model weights and config
trainer.save_model(FINAL_MODEL_DIR)

# Save tokenizer (needed for inference)
tokenizer.save_pretrained(FINAL_MODEL_DIR)

print(f"Model saved locally to {FINAL_MODEL_DIR}/")

# Copy to Google Drive
!cp -r {FINAL_MODEL_DIR} /content/drive/MyDrive/

print(f"\nModel copied to Google Drive: {DRIVE_MODEL_DIR}/")
print("You can now use this in the conversion notebook!")
!ls -la {DRIVE_MODEL_DIR}/

## 9. Test Fine-tuned Model


In [None]:
# =============================================================================
# Test the fine-tuned model on new prompts
# =============================================================================
# CRITICAL: Use the same Google format as training (not apply_chat_template!)

test_prompts = [
    "faz um pix de vinte cinco",
    "passa vinte no credito",
    "faz 60 no debito"
]

print("Testing fine-tuned model:")
print("=" * 60)

for prompt in test_prompts:
    # Create prompt in SAME format as training (Google FunctionGemma)
    input_text = f"""{SYSTEM_PROMPT}{START_TURN}user
{prompt}
{END_TURN}
{START_TURN}model
"""

    # Tokenize and send to GPU
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Generate response
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

    # Decode only new tokens (without prompt)
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:],
        skip_special_tokens=False
    )

    print(f"\nUser: {prompt}")
    print(f"Model: {response.strip()}")

    # Verify format
    if "<start_function_call>call:" in response:
        print("   ✅ Correct format!")
    else:
        print("   ⚠️  Unexpected format")
    print("-" * 60)