# 🎨 Fine-tuning FunctionGemma for Square Color Control

This notebook demonstrates how to fine-tune FunctionGemma to recognize color control commands.

**Author:** [Your Name]
**Portfolio:** AI Engineering

## Objectives
1. Train the model to call `set_square_color` when the user wants to change the color
2. Train the model to call `get_square_color` when the user asks about the current color
3. Support various natural language command styles

## 📦 1. Setup and Installation

In [None]:
# Install dependencies
%pip install -q torch tensorboard
%pip install -q transformers datasets accelerate evaluate trl protobuf sentencepiece

# If running on Ampere+ GPU (A100, L4), uncomment:
# %pip install -q flash-attn

In [None]:
# Login to Hugging Face Hub
# The training configuration later in this notebook sets push_to_hub=True and the
# final cell calls trainer.push_to_hub(), so an authenticated session is required.
from huggingface_hub import login

# If using Colab secrets:
# from google.colab import userdata
# login(token=userdata.get('HF_TOKEN'))

# Or interactive login:
login()

In [None]:
# Configuration
# Checkpoint loaded both for the baseline evaluation and for the fine-tuning run.
BASE_MODEL = "google/functiongemma-270m-it"
# Passed to SFTConfig(output_dir=...) in the fine-tuning section.
OUTPUT_DIR = "functiongemma-square-color"  # Model name on your HF Hub
LEARNING_RATE = 5e-5  # used with lr_scheduler_type="constant" in the training config
NUM_EPOCHS = 8  # forwarded as num_train_epochs
BATCH_SIZE = 4  # forwarded as per_device_train_batch_size

## 📊 2. Prepare Dataset

In [None]:
import json
from datasets import Dataset
from transformers.utils import get_json_schema

# Tool definitions
# NOTE: this function is handed to get_json_schema() when TOOLS is built, which
# derives the tool schema from the signature and the docstring. The docstring text
# is therefore part of the tool definition given to the model -- change its wording
# (or the Args section) only deliberately.
def set_square_color(color: str) -> str:
    """
    Sets the color of the square displayed on the screen.
    
    Args:
        color: The color to set, e.g. red, blue, green
    """
    return f"Color set to {color}"

# NOTE: this function is handed to get_json_schema() when TOOLS is built, so the
# docstring below doubles as the tool description given to the model. It takes no
# arguments; the return value here is a fixed placeholder string.
def get_square_color() -> str:
    """
    Returns the current color of the square.
    Use this when the user asks about the current color.
    """
    return "Current color"

# Derive one JSON schema per tool function (SET first, then GET)
TOOLS = [get_json_schema(tool_fn) for tool_fn in (set_square_color, get_square_color)]

# Show the generated schemas for a quick sanity check
print("Tool schemas:", json.dumps(TOOLS, indent=2), sep="\n")

In [None]:
# Load training dataset from file.
# JSON text is UTF-8; the encoding is passed explicitly so the file is decoded the
# same way on every platform instead of depending on the locale default.
with open("dataset/square_color_dataset.json", "r", encoding="utf-8") as f:
    square_color_dataset = json.load(f)

# Each record is expected to carry "user_content", "tool_name" and "tool_arguments".
print(f"Total examples: {len(square_color_dataset)}")
print(f"  - SET: {sum(x['tool_name'] == 'set_square_color' for x in square_color_dataset)}")
print(f"  - GET: {sum(x['tool_name'] == 'get_square_color' for x in square_color_dataset)}")

# Preview first few examples
print("\nFirst 3 examples:")
for i, sample in enumerate(square_color_dataset[:3], start=1):
    print(f"  {i}. \"{sample['user_content']}\" ‚Üí {sample['tool_name']}")

In [None]:
# Convert to conversation format
SYSTEM_PROMPT = "You are a model that can do function calling with the following functions"


def create_conversation(sample):
    """Build one chat-format training example from a raw dataset record.

    Args:
        sample: Mapping with "user_content" (the user's request), "tool_name"
            (the function the assistant should call) and "tool_arguments"
            (a JSON-encoded string holding the call arguments).

    Returns:
        A dict with "messages" (developer, user and assistant turns, the
        assistant turn containing a single function tool call) and "tools"
        (the shared TOOLS schema list).
    """
    developer_turn = {"role": "developer", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": sample["user_content"]}

    # The arguments are stored as a JSON string and decoded into a dict here.
    expected_call = {
        "type": "function",
        "function": {
            "name": sample["tool_name"],
            "arguments": json.loads(sample["tool_arguments"]),
        },
    }
    assistant_turn = {"role": "assistant", "tool_calls": [expected_call]}

    return {
        "messages": [developer_turn, user_turn, assistant_turn],
        "tools": TOOLS,
    }

# Create dataset
dataset = Dataset.from_list(square_color_dataset)
# Replace every raw column with the chat-format fields returned by create_conversation
# ("messages" and "tools"); rows are mapped one at a time (batched=False).
dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)

# Split 80/20
# A fixed seed keeps the shuffled split reproducible between runs.
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)

# Report the resulting split sizes.
print(f"Train: {len(dataset['train'])} examples")
print(f"Test: {len(dataset['test'])} examples")

In [None]:
# Visualize an example
# Dumps the first training row (its "messages" and "tools" fields) as indented JSON
# to check the structure produced by create_conversation.
print("Formatted conversation example:")
print(json.dumps(dataset["train"][0], indent=2))

## 🤖 3. Load Model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# BitsAndBytesConfig has to be imported explicitly; it is not pulled in by the
# AutoTokenizer / AutoModelForCausalLM import above.
from transformers import BitsAndBytesConfig

# 4-bit quantization for the baseline copy of the model.
# Requires the `bitsandbytes` package to be installed.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="eager",
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Quick summary of where and how the model was loaded
print(f"Device: {model.device}")
print(f"DType: {model.dtype}")
print(f"Parameters: {model.num_parameters():,}")

In [None]:
# Render the first training example through the chat template as plain text
# (no tokenization, no trailing generation prompt) to inspect the final prompt.
debug_msg = tokenizer.apply_chat_template(
    dataset["train"][0]["messages"],
    tokenize=False,
    add_generation_prompt=False,
    tools=dataset["train"][0]["tools"],
)

print("=== Formatted prompt ===", debug_msg, sep="\n")

## 🧪 3.5. Pre-Training Evaluation (Baseline)

Before fine-tuning, let's evaluate the base model to establish a baseline. This helps us measure the actual improvement from fine-tuning.

In [None]:
def _score_response(sample, response):
    """Score one decoded model response against the expected tool call.

    Returns a ``(tool_correct, args_correct)`` pair. Matching is substring based:
    the tool counts as correct when its name appears anywhere in ``response``;
    for ``set_square_color`` the expected ``color`` value must appear as well,
    while ``get_square_color`` needs no arguments.
    """
    expected_tool = sample["tool_name"]
    tool_correct = expected_tool in response

    args_correct = False
    if tool_correct:
        if expected_tool == "set_square_color":
            expected_args = json.loads(sample["tool_arguments"])
            args_correct = expected_args.get("color", "") in response
        elif expected_tool == "get_square_color":
            args_correct = True  # No args needed

    return tool_correct, args_correct


def evaluate_model(model, tokenizer, test_samples, tools, system_prompt, verbose=True):
    """
    Evaluate model on a set of test samples.
    Returns accuracy metrics and detailed results.

    Args:
        model: Causal LM exposing ``.device`` and ``.generate(...)``.
        tokenizer: Tokenizer exposing ``apply_chat_template(...)`` and ``decode(...)``.
        test_samples: Sequence of dicts with ``user_content``, ``tool_name`` and
            ``tool_arguments`` (a JSON string). May be empty.
        tools: Tool schemas passed to the chat template.
        system_prompt: Content of the developer turn placed before the user turn.
        verbose: When True, print the two accuracy summary lines.

    Returns:
        Dict with ``total``, ``correct``, ``correct_tool``, ``correct_args``,
        ``details`` (one entry per sample), ``tool_accuracy`` and
        ``full_accuracy`` (percentages; 0.0 when there are no samples).
    """
    results = {
        "total": len(test_samples),
        "correct": 0,
        "correct_tool": 0,
        "correct_args": 0,
        "details": []
    }

    for sample in test_samples:
        messages = [
            {"role": "developer", "content": system_prompt},
            {"role": "user", "content": sample["user_content"]},
        ]

        inputs = tokenizer.apply_chat_template(
            messages,
            tools=tools,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device)

        # Greedy decoding so the baseline and fine-tuned runs are comparable.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
            )

        # Keep only the newly generated tokens; special tokens are preserved
        # in the decoded text.
        input_length = inputs['input_ids'].shape[1]
        response = tokenizer.decode(output[0][input_length:], skip_special_tokens=False)

        tool_correct, args_correct = _score_response(sample, response)

        if tool_correct:
            results["correct_tool"] += 1
        if tool_correct and args_correct:
            results["correct"] += 1
            results["correct_args"] += 1

        results["details"].append({
            "input": sample["user_content"],
            "expected_tool": sample["tool_name"],
            "expected_args": sample["tool_arguments"],
            "response": response,
            "tool_correct": tool_correct,
            "args_correct": args_correct
        })

    # Guard against an empty evaluation set: report 0.0% instead of dividing by zero.
    total = results["total"]
    results["tool_accuracy"] = results["correct_tool"] / total * 100 if total else 0.0
    results["full_accuracy"] = results["correct"] / total * 100 if total else 0.0

    if verbose:
        print(f"Tool Accuracy: {results['correct_tool']}/{results['total']} ({results['tool_accuracy']:.1f}%)")
        print(f"Full Accuracy (tool + args): {results['correct']}/{results['total']} ({results['full_accuracy']:.1f}%)")

    return results

In [None]:
# Create evaluation test set (sample up to 5 SET + 5 GET) from held-out data only
import random

random.seed(42)  # For reproducibility

# Only prompts from the held-out split are eligible. Sampling from the full dataset
# would mostly pick examples the model is later trained on, which would inflate the
# post-training accuracy. (A prompt whose text occurs in both splits is still
# eligible, and the held-out split is also used for checkpoint selection.)
held_out_prompts = {
    message["content"]
    for example in dataset["test"]
    for message in example["messages"]
    if message["role"] == "user"
}
held_out_samples = [s for s in square_color_dataset if s["user_content"] in held_out_prompts]

set_samples = [s for s in held_out_samples if s["tool_name"] == "set_square_color"]
get_samples = [s for s in held_out_samples if s["tool_name"] == "get_square_color"]

# min(...) keeps random.sample from raising when a class has fewer than 5 examples.
eval_test_cases = (
    random.sample(set_samples, min(5, len(set_samples)))
    + random.sample(get_samples, min(5, len(get_samples)))
)

print("=" * 50)
print("PRE-TRAINING EVALUATION (Baseline)")
print("=" * 50)
print(f"\nEvaluating base model on {len(eval_test_cases)} test cases...\n")

baseline_results = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    test_samples=eval_test_cases,
    tools=TOOLS,
    system_prompt=SYSTEM_PROMPT
)

# Show some example outputs
print("\n--- Sample Outputs (Base Model) ---")
for detail in baseline_results["details"][:4]:
    status = "‚úÖ" if detail["tool_correct"] else "‚ùå"
    print(f"\n{status} Input: {detail['input']}")
    print(f"   Expected: {detail['expected_tool']}")
    print(f"   Output: {detail['response'][:200]}...")

## 🔥 4. Fine-tuning

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

print("Reloading model for fine-tuning (without quantization)...")

# Drop the baseline model and release cached GPU memory before loading again
del model
torch.cuda.empty_cache()

# Fresh bfloat16 copy of the base checkpoint, this time without a quantization config
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    attn_implementation="eager",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Same load summary as for the baseline model, followed by a ready marker
print(
    f"Device: {model.device}",
    f"DType: {model.dtype}",
    f"Parameters: {model.num_parameters():,}",
    "Ready for fine-tuning!",
    sep="\n",
)

In [None]:
from trl import SFTConfig, SFTTrainer

# The mixed-precision flags below follow whatever dtype the model was loaded in
torch_dtype = model.dtype

# Training configuration
args = SFTConfig(
    # Output and reporting
    output_dir=OUTPUT_DIR,
    push_to_hub=True,
    report_to="tensorboard",
    logging_steps=1,
    # Data handling
    max_length=512,
    packing=False,
    # Optimisation
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="constant",
    optim="adamw_torch_fused",
    gradient_checkpointing=False,
    fp16=(torch_dtype == torch.float16),
    bf16=(torch_dtype == torch.bfloat16),
    # Evaluate and checkpoint every epoch, then restore the lowest-eval-loss checkpoint
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

# Create trainer
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

print("Trainer created successfully!")

In [None]:
# 🚀 Start training!
# Runs the full training loop configured in the SFTConfig above.
print("Starting fine-tuning...")
trainer.train()

print("\n‚úÖ Training complete!")

In [None]:
# Save final model
# Called without a path; the message below assumes it is written to the trainer's
# configured output directory (OUTPUT_DIR) -- TODO confirm.
trainer.save_model()
print(f"Model saved to: {OUTPUT_DIR}")

## 📈 5. Visualize Results

In [None]:
import matplotlib.pyplot as plt

# Pull the logged metrics recorded by the trainer
log_history = trainer.state.log_history

# Training-loss entries carry a "loss" key, evaluation entries an "eval_loss" key
train_logs = [log for log in log_history if "loss" in log]
eval_logs = [log for log in log_history if "eval_loss" in log]

train_losses = [log["loss"] for log in train_logs]
epoch_train = [log["epoch"] for log in train_logs]
eval_losses = [log["eval_loss"] for log in eval_logs]
epoch_eval = [log["epoch"] for log in eval_logs]

# Draw both curves on a single set of axes
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_train, train_losses, label="Training Loss", alpha=0.7)
ax.plot(epoch_eval, eval_losses, label="Validation Loss", marker='o')
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss")
ax.legend()
ax.grid(True)
plt.show()

## 🧪 6. Post-Training Evaluation

Now let's evaluate the fine-tuned model and compare it with the baseline to measure the improvement.

In [None]:
banner = "=" * 50
print(banner)
print("POST-TRAINING EVALUATION (Fine-tuned)")
print(banner)
print(f"\nEvaluating fine-tuned model on {len(eval_test_cases)} test cases...\n")

# Same test cases, tools and prompt as the baseline run, so the numbers are comparable
finetuned_results = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    test_samples=eval_test_cases,
    tools=TOOLS,
    system_prompt=SYSTEM_PROMPT,
)

# Print the first four responses, each output truncated to 200 characters
print("\n--- Sample Outputs (Fine-tuned Model) ---")
for detail in finetuned_results["details"][:4]:
    if detail["tool_correct"]:
        status = "‚úÖ"
    else:
        status = "‚ùå"
    print(f"\n{status} Input: {detail['input']}")
    print(f"   Expected: {detail['expected_tool']}")
    print(f"   Output: {detail['response'][:200]}...")

In [None]:
# Side-by-side table of baseline vs fine-tuned accuracy
def _report_metric(label, metric_key):
    """Print one table row for ``metric_key`` and return fine-tuned minus baseline."""
    before = baseline_results[metric_key]
    after = finetuned_results[metric_key]
    delta = after - before
    print(f"{label:<30} {before:>11.1f}% {after:>11.1f}% {delta:>+11.1f}%")
    return delta


rule = "-" * 66

print("=" * 60)
print("üìä COMPARISON: Baseline vs Fine-tuned")
print("=" * 60)

print(f"\n{'Metric':<30} {'Baseline':>12} {'Fine-tuned':>12} {'Improvement':>12}")
print(rule)

# One row per metric; the returned deltas are in percentage points
tool_improvement = _report_metric("Tool Accuracy", "tool_accuracy")
full_improvement = _report_metric("Full Accuracy (tool + args)", "full_accuracy")

print(rule)

# The verdict is based on full accuracy (tool + args) only
if full_improvement > 0:
    summary = f"\n‚úÖ Fine-tuning improved accuracy by {full_improvement:.1f} percentage points!"
elif full_improvement == 0:
    summary = "\n‚ö†Ô∏è No change in accuracy. Consider adjusting training parameters."
else:
    summary = "\n‚ùå Accuracy decreased. Check for overfitting or data issues."
print(summary)

In [None]:
# Visualization: Baseline vs Fine-tuned comparison
import matplotlib.pyplot as plt
import numpy as np

# Two panels: aggregate accuracy bars on the left, per-sample correctness on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
accuracy_ax, per_sample_ax = axes

baseline_details = baseline_results["details"]
finetuned_details = finetuned_results["details"]

# --- Panel 1: grouped bars for tool accuracy and full accuracy ---
metrics = ['Tool\nAccuracy', 'Full\nAccuracy']
metric_keys = ("tool_accuracy", "full_accuracy")
baseline_vals = [baseline_results[key] for key in metric_keys]
finetuned_vals = [finetuned_results[key] for key in metric_keys]

x = np.arange(len(metrics))
width = 0.35

bars1 = accuracy_ax.bar(x - width/2, baseline_vals, width, label='Baseline', color='#ff6b6b', alpha=0.8)
bars2 = accuracy_ax.bar(x + width/2, finetuned_vals, width, label='Fine-tuned', color='#4ecdc4', alpha=0.8)

accuracy_ax.set_ylabel('Accuracy (%)')
accuracy_ax.set_title('Model Performance: Baseline vs Fine-tuned')
accuracy_ax.set_xticks(x)
accuracy_ax.set_xticklabels(metrics)
accuracy_ax.legend()
accuracy_ax.set_ylim(0, 110)
accuracy_ax.axhline(y=100, color='gray', linestyle='--', alpha=0.3)

# Label every bar (baseline bars first, then fine-tuned) with its percentage
for bar in [*bars1, *bars2]:
    height = bar.get_height()
    accuracy_ax.annotate(
        f'{height:.1f}%',
        xy=(bar.get_x() + bar.get_width() / 2, height),
        xytext=(0, 3),
        textcoords="offset points",
        ha='center',
        va='bottom',
        fontsize=10,
    )

# --- Panel 2: per-sample tool correctness (1 = correct, 0 = incorrect) ---
sample_labels = [d["input"][:20] + "..." for d in baseline_details]
baseline_correct = [int(d["tool_correct"]) for d in baseline_details]
finetuned_correct = [int(d["tool_correct"]) for d in finetuned_details]

x2 = np.arange(len(sample_labels))
width2 = 0.35

per_sample_ax.barh(x2 - width2/2, baseline_correct, width2, label='Baseline', color='#ff6b6b', alpha=0.8)
per_sample_ax.barh(x2 + width2/2, finetuned_correct, width2, label='Fine-tuned', color='#4ecdc4', alpha=0.8)

per_sample_ax.set_xlabel('Correct (1) / Incorrect (0)')
per_sample_ax.set_title('Per-Sample Results')
per_sample_ax.set_yticks(x2)
per_sample_ax.set_yticklabels(sample_labels, fontsize=8)
per_sample_ax.legend(loc='lower right')
per_sample_ax.set_xlim(-0.1, 1.5)

plt.tight_layout()
plt.show()

# Text listing of the same per-sample comparison, flagging changes in tool correctness
print("\nüìã Detailed Per-Sample Comparison:")
print("-" * 80)
change_notes = {
    (False, True): " üéâ FIXED!",
    (True, False): " ‚ö†Ô∏è REGRESSED",
}
for b, f in zip(baseline_details, finetuned_details):
    b_status = "‚úÖ" if b["tool_correct"] else "‚ùå"
    f_status = "‚úÖ" if f["tool_correct"] else "‚ùå"
    change = change_notes.get((bool(b["tool_correct"]), bool(f["tool_correct"])), "")
    print(f"{b['input'][:40]:<42} Base: {b_status}  Fine-tuned: {f_status}{change}")

## 📤 7. Push to Hugging Face Hub

In [None]:
# Push to Hub
# Uploads via the trainer; requires the Hugging Face login performed at the top of
# the notebook.
trainer.push_to_hub()

# The printed URL is built from the repository id stored on the trainer.
print(f"\n‚úÖ Model pushed to: https://huggingface.co/{trainer.hub_model_id}")