# Fine-tuning FunctionGemma for SQL

Teaching FunctionGemma-270M to be good at text-to-SQL through SFT, GRPO and APO.

## Environment setup

In [1]:
# Install dependencies
!pip install transformers torch accelerate bitsandbytes datasets trl peft sqlparse python-dotenv

Collecting bitsandbytes
  Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting trl
  Downloading trl-0.26.2-py3-none-any.whl.metadata (11 kB)
Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl (59.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading trl-0.26.2-py3-none-any.whl (518 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m518.9/518.9 kB[0m [31m40.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes, trl
Successfully installed bitsandbytes-0.49.0 trl-0.26.2


In [2]:
import json
import sqlite3
import os
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from collections import defaultdict
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tqdm.auto import tqdm
import pandas as pd

## Download Spider benchmark

Spider is the standard text-to-SQL benchmark with 200 databases across 138 domains. We download it with the actual SQLite database files so we can execute generated queries.

In [4]:
!pip install gdown -q

import gdown
import zipfile
from pathlib import Path

SPIDER_DIR = Path("spider")

if not SPIDER_DIR.exists():
    # Download Spider dataset from Google Drive
    # File ID: 1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J
    url = "https://drive.google.com/uc?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J"
    output = "spider.zip"
    
    print("Downloading Spider dataset...")
    gdown.download(url, output, quiet=False)
    
    print("Extracting...")
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall(".")
    
    # Clean up
    os.remove(output)
    print(f"Spider dataset ready at {SPIDER_DIR}")
else:
    print(f"Spider dataset already exists at {SPIDER_DIR}")

Downloading Spider dataset...


Downloading...
From (original): https://drive.google.com/uc?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J
From (redirected): https://drive.google.com/uc?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J&confirm=t&uuid=f9d13177-b23e-4010-a6dd-d6469fc01931
To: /content/spider.zip
100%|██████████| 206M/206M [00:02<00:00, 74.2MB/s] 


Extracting...
Spider dataset ready at spider


In [20]:
# Verify Spider structure
SPIDER_DIR = Path("spider_data")
spider_db_dir = SPIDER_DIR / "database"
spider_train = SPIDER_DIR / "train_spider.json"
spider_dev = SPIDER_DIR / "dev.json"
spider_tables = SPIDER_DIR / "tables.json"

print(f"Database directory exists: {spider_db_dir.exists()}")
print(f"Train file exists: {spider_train.exists()}")
print(f"Dev file exists: {spider_dev.exists()}")
print(f"Tables file exists: {spider_tables.exists()}")

if spider_db_dir.exists():
    databases = list(spider_db_dir.iterdir())
    print(f"\nNumber of databases: {len(databases)}")
    print(f"Sample databases: {[d.name for d in databases[:5]]}")

Database directory exists: True
Train file exists: True
Dev file exists: True
Tables file exists: True

Number of databases: 166
Sample databases: ['race_track', 'entrepreneur', 'flight_1', 'hr_1', 'battle_death']


## SQL Benchmarker

The benchmarker evaluates text-to-SQL models using execution accuracy: whether the generated SQL produces the same results as the gold SQL when run against the actual database.

In [21]:
@dataclass
class BenchmarkResult:
    """Result of a single benchmark query."""
    question: str
    db_id: str
    gold_sql: str
    predicted_sql: str
    gold_result: Optional[list] = None
    predicted_result: Optional[list] = None
    execution_match: bool = False
    syntax_valid: bool = False
    error: Optional[str] = None


class SpiderBenchmarker:
    """Benchmarker for text-to-SQL using Spider dataset with execution accuracy."""
    
    def __init__(self, spider_dir: Path, split: str = "dev"):
        self.spider_dir = Path(spider_dir)
        self.db_dir = self.spider_dir / "database"
        
        # Load the appropriate split
        if split == "dev":
            data_file = self.spider_dir / "dev.json"
        else:
            data_file = self.spider_dir / "train_spider.json"
        
        with open(data_file) as f:
            self.data = json.load(f)
        
        # Load table schemas
        with open(self.spider_dir / "tables.json") as f:
            tables_data = json.load(f)
            self.schemas = {t["db_id"]: t for t in tables_data}
        
        print(f"Loaded {len(self.data)} examples from {split} split")
        print(f"Loaded schemas for {len(self.schemas)} databases")
    
    def get_schema_prompt(self, db_id: str) -> str:
        """Generate a schema description for prompting."""
        schema = self.schemas.get(db_id)
        if not schema:
            return ""
        
        lines = [f"Database: {db_id}", "Tables:"]
        
        table_names = schema["table_names_original"]
        column_names = schema["column_names_original"]
        column_types = schema["column_types"]
        primary_keys = schema.get("primary_keys", [])
        foreign_keys = schema.get("foreign_keys", [])
        
        # Group columns by table
        table_columns = defaultdict(list)
        for i, (table_idx, col_name) in enumerate(column_names):
            if table_idx >= 0:  # Skip the special "*" column
                col_type = column_types[i] if i < len(column_types) else "unknown"
                is_pk = i in primary_keys
                pk_marker = " (PK)" if is_pk else ""
                table_columns[table_idx].append(f"{col_name} {col_type}{pk_marker}")
        
        for idx, table_name in enumerate(table_names):
            cols = table_columns.get(idx, [])
            lines.append(f"  {table_name}: {', '.join(cols)}")
        
        # Add foreign keys
        if foreign_keys:
            lines.append("Foreign keys:")
            for fk in foreign_keys:
                if len(fk) == 2:
                    from_col = column_names[fk[0]]
                    to_col = column_names[fk[1]]
                    if from_col[0] >= 0 and to_col[0] >= 0:
                        from_table = table_names[from_col[0]]
                        to_table = table_names[to_col[0]]
                        lines.append(f"  {from_table}.{from_col[1]} -> {to_table}.{to_col[1]}")
        
        return "\n".join(lines)
    
    def execute_sql(self, db_id: str, sql: str, timeout: float = 5.0) -> tuple[Optional[list], Optional[str]]:
        """Execute SQL against a database and return results."""
        db_path = self.db_dir / db_id / f"{db_id}.sqlite"
        
        if not db_path.exists():
            return None, f"Database not found: {db_path}"
        
        try:
            conn = sqlite3.connect(str(db_path), timeout=timeout)
            conn.text_factory = str
            cursor = conn.cursor()
            cursor.execute(sql)
            results = cursor.fetchall()
            conn.close()
            return results, None
        except Exception as e:
            return None, str(e)
    
    def results_match(self, result1: Optional[list], result2: Optional[list]) -> bool:
        """Check if two query results match (order-independent for sets)."""
        if result1 is None or result2 is None:
            return False
        
        # Convert to sets of tuples for order-independent comparison
        try:
            set1 = set(tuple(row) if isinstance(row, (list, tuple)) else (row,) for row in result1)
            set2 = set(tuple(row) if isinstance(row, (list, tuple)) else (row,) for row in result2)
            return set1 == set2
        except (TypeError, ValueError):
            # Fall back to list comparison if unhashable
            return sorted(str(r) for r in result1) == sorted(str(r) for r in result2)
    
    def evaluate_single(self, example: dict, predicted_sql: str) -> BenchmarkResult:
        """Evaluate a single prediction."""
        question = example["question"]
        db_id = example["db_id"]
        gold_sql = example["query"]
        
        result = BenchmarkResult(
            question=question,
            db_id=db_id,
            gold_sql=gold_sql,
            predicted_sql=predicted_sql
        )
        
        # Execute gold SQL
        gold_result, gold_error = self.execute_sql(db_id, gold_sql)
        result.gold_result = gold_result
        
        if gold_error:
            result.error = f"Gold SQL error: {gold_error}"
            return result
        
        # Execute predicted SQL
        pred_result, pred_error = self.execute_sql(db_id, predicted_sql)
        result.predicted_result = pred_result
        
        if pred_error:
            result.error = f"Predicted SQL error: {pred_error}"
            result.syntax_valid = False
        else:
            result.syntax_valid = True
            result.execution_match = self.results_match(gold_result, pred_result)
        
        return result
    
    def run_benchmark(self, model_fn, num_samples: Optional[int] = None, 
                      verbose: bool = True) -> list[BenchmarkResult]:
        """
        Run the benchmark with a model function.
        
        Args:
            model_fn: Function that takes (question, schema_prompt) and returns SQL
            num_samples: Limit number of samples (None for all)
            verbose: Print progress
        
        Returns:
            List of BenchmarkResult objects
        """
        samples = self.data[:num_samples] if num_samples else self.data
        results = []
        
        iterator = tqdm(samples, desc="Benchmarking") if verbose else samples
        
        for example in iterator:
            schema_prompt = self.get_schema_prompt(example["db_id"])
            
            try:
                predicted_sql = model_fn(example["question"], schema_prompt)
            except Exception as e:
                predicted_sql = ""
                
            result = self.evaluate_single(example, predicted_sql)
            results.append(result)
        
        return results
    
    def compute_metrics(self, results: list[BenchmarkResult]) -> dict:
        """Compute aggregate metrics from results."""
        total = len(results)
        if total == 0:
            return {"total": 0}
        
        syntax_valid = sum(1 for r in results if r.syntax_valid)
        execution_match = sum(1 for r in results if r.execution_match)
        
        return {
            "total": total,
            "syntax_valid": syntax_valid,
            "syntax_accuracy": syntax_valid / total,
            "execution_match": execution_match,
            "execution_accuracy": execution_match / total,
        }


print("SpiderBenchmarker class defined")

SpiderBenchmarker class defined


## Load FunctionGemma-270M

We load the base FunctionGemma-270M model to establish baseline performance before fine-tuning.

In [22]:
# Load FunctionGemma-270M
MODEL_ID = "google/functiongemma-270m-it"

print(f"Loading {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    attn_implementation="eager",
    torch_dtype="auto",
)
model.eval()
print(f"Model loaded on {model.device}")

Loading google/functiongemma-270m-it...
Model loaded on cuda:0


In [26]:
# Define the SQL execution tool for FunctionGemma
SQL_TOOL = {
    "function": {
        "name": "execute_sql",
        "description": "Execute a SQL query to answer the user's question. The query parameter should contain a valid SELECT statement.",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "query": {
                    "type": "STRING",
                    "description": "A valid SQL SELECT statement"
                }
            },
            "required": ["query"]
        }
    }
}


def parse_function_call(output: str) -> dict:
    """Parse FunctionGemma's function call format.
    
    Format: <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
    """
    import re
    
    # Try to extract function call
    match = re.search(r'call:(\w+)\{(.+?)\}', output)
    if match:
        func_name = match.group(1)
        params_str = match.group(2)
        
        # Parse parameters: key:<escape>value<escape>
        params = {}
        param_matches = re.findall(r'(\w+):<escape>(.+?)<escape>', params_str)
        for key, value in param_matches:
            params[key] = value
        
        return {"name": func_name, "parameters": params}
    
    return {}


def generate_sql(question: str, schema: str, max_new_tokens: int = 256) -> str:
    """Generate SQL from a natural language question using FunctionGemma."""
    
    messages = [
        {
            "role": "developer",
            "content": "You are a model that can do function calling with the following functions"
        },
        {
            "role": "user",
            "content": f"""{schema}

Question: {question}

Call execute_sql with the appropriate SQL query."""
        }
    ]
    
    prompt = tokenizer.apply_chat_template(
        messages,
        tools=[SQL_TOOL],
        tokenize=False,
        add_generation_prompt=True
    )
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    # Decode WITHOUT skipping special tokens to preserve function call format
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
    new_output = full_output[len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=False)):]
        
    # Parse the FunctionGemma function call format
    func_call = parse_function_call(new_output)
    
    if func_call and "parameters" in func_call:
        sql = func_call["parameters"].get("query", "")
    else:
        # Fallback: try to extract SQL directly
        sql = new_output.strip()
    
    # Clean up
    if sql:
        sql = sql.split(";")[0].strip()
        if sql and not sql.endswith(";"):
            sql += ";"
    
    return sql


# Test the generation
test_schema = """Database: concert_singer
Tables:
  singer: Singer_ID (PK), Name, Country, Age"""

test_question = "How many singers are there?"

print("Test generation:")
result = generate_sql(test_question, test_schema)
print(f"Extracted SQL: {result}")

Test generation:
Extracted SQL: SELECT singer_id, name, country, age FROM singer;


## Run Spider benchmark

We run the benchmark on the dev set to establish baseline performance.

In [None]:
# Initialise benchmarker
benchmarker = SpiderBenchmarker(SPIDER_DIR, split="dev")

# Run benchmark (use subset for initial testing)
NUM_SAMPLES = 100  # Set to None for full benchmark

print(f"Running benchmark on {NUM_SAMPLES or 'all'} samples...")
results = benchmarker.run_benchmark(generate_sql, num_samples=NUM_SAMPLES)

# Compute and display metrics
metrics = benchmarker.compute_metrics(results)
print("\n" + "="*50)
print("BENCHMARK RESULTS: FunctionGemma-270m-it (baseline)")
print("="*50)
print(f"Total samples:       {metrics['total']}")
print(f"Syntax valid:        {metrics['syntax_valid']} ({metrics['syntax_accuracy']*100:.1f}%)")
print(f"Execution match:     {metrics['execution_match']} ({metrics['execution_accuracy']*100:.1f}%)")
print("="*50)

Loaded 1034 examples from dev split
Loaded schemas for 166 databases
Running benchmark on 100 samples...


Benchmarking:   0%|          | 0/100 [00:00<?, ?it/s]

Raw output: <start_function_call>call:execute_sql{query:<escape>SELECT singer_id, singer_id FROM concert_singer WHERE concert_id = concert_id AND singer_id = singer_id AND concert_id = stadium_id<escape>}<end_function_call><start_function_response>
Raw output: <start_function_call>call:execute_sql{query:<escape>SELECT singer_id FROM singer_data WHERE concert_id = concert_id<escape>}<end_function_call><start_function_response>
Raw output: <start_function_call>call:execute_sql{query:<escape>SELECT singer_id.Singer_ID text ORDER BY age DESC LIMIT 100000;<escape>}<end_function_call><start_function_response>
Raw output: <start_function_call>call:execute_sql{query:<escape>SELECT singer_id.Singer_ID text ORDER BY singer.Age DESC LIMIT 10;<escape>}<end_function_call><start_function_response>
Raw output: <start_function_call>call:execute_sql{query:<escape>SELECT singer_id:PK, singer_id:PK, age: AVG(age) AS min_age, age: AVG(age) AS max_age FROM concert_singer WHERE country = 'FR' AND song_name 

In [27]:
# Show 3 detailed examples
import difflib

def show_example(result, idx):
    print(f"\n{'='*60}")
    print(f"EXAMPLE {idx + 1}")
    print(f"{'='*60}")
    
    print(f"\nDatabase: {result.db_id}")
    print(f"Question: {result.question}")
    
    print(f"\n--- Generated SQL ---")
    print(result.predicted_sql)
    
    print(f"\n--- Gold SQL ---")
    print(result.gold_sql)
    
    # Show diff
    print(f"\n--- Diff (generated vs gold) ---")
    diff = difflib.unified_diff(
        result.gold_sql.splitlines(keepends=True),
        result.predicted_sql.splitlines(keepends=True),
        fromfile='gold',
        tofile='generated',
        lineterm=''
    )
    diff_text = ''.join(diff)
    if diff_text:
        print(diff_text)
    else:
        print("(identical)")
    
    # Show execution results
    print(f"\n--- Execution Results ---")
    print(f"Gold result:      {result.gold_result[:3] if result.gold_result else 'None'}{'...' if result.gold_result and len(result.gold_result) > 3 else ''}")
    print(f"Generated result: {result.predicted_result[:3] if result.predicted_result else 'None'}{'...' if result.predicted_result and len(result.predicted_result) > 3 else ''}")
    
    print(f"\n--- Verdict ---")
    print(f"Syntax valid:     {result.syntax_valid}")
    print(f"Execution match:  {result.execution_match}")
    if result.error:
        print(f"Error:            {result.error}")


# Show 3 examples: 1 success (if any), 1 wrong result, 1 syntax error
examples_to_show = []

# Find one successful match
for r in results:
    if r.execution_match:
        examples_to_show.append(r)
        break

# Find one wrong result (valid syntax but wrong output)
for r in results:
    if r.syntax_valid and not r.execution_match:
        examples_to_show.append(r)
        break

# Find one syntax error
for r in results:
    if not r.syntax_valid:
        examples_to_show.append(r)
        break

# If we don't have 3, just take first 3
while len(examples_to_show) < 3 and len(results) > len(examples_to_show):
    for r in results:
        if r not in examples_to_show:
            examples_to_show.append(r)
            break

for idx, result in enumerate(examples_to_show):
    show_example(result, idx)


EXAMPLE 1

Database: concert_singer
Question: How many singers do we have?

--- Generated SQL ---
SELECT singer_id, singer_id FROM concert_singer WHERE concert_id = concert_id AND singer_id = singer_id AND concert_id = stadium_id;

--- Gold SQL ---
SELECT count(*) FROM singer

--- Diff (generated vs gold) ---
--- gold+++ generated@@ -1 +1 @@-SELECT count(*) FROM singer+SELECT singer_id, singer_id FROM concert_singer WHERE concert_id = concert_id AND singer_id = singer_id AND concert_id = stadium_id;

--- Execution Results ---
Gold result:      [(6,)]
Generated result: None

--- Verdict ---
Syntax valid:     False
Execution match:  False
Error:            Predicted SQL error: no such table: concert_singer

EXAMPLE 2

Database: concert_singer
Question: What is the total number of singers?

--- Generated SQL ---
SELECT singer_id FROM singer_data WHERE concert_id = concert_id;

--- Gold SQL ---
SELECT count(*) FROM singer

--- Diff (generated vs gold) ---
--- gold+++ generated@@ -1 +1 @@-

In [28]:
# Save results for analysis
results_df = pd.DataFrame([
    {
        "question": r.question,
        "db_id": r.db_id,
        "gold_sql": r.gold_sql,
        "predicted_sql": r.predicted_sql,
        "syntax_valid": r.syntax_valid,
        "execution_match": r.execution_match,
        "error": r.error,
    }
    for r in results
])

results_df.to_csv("baseline_benchmark_results.csv", index=False)
print(f"Results saved to baseline_benchmark_results.csv")

# Summary by database
print("\nExecution accuracy by database (top 10):")
db_accuracy = results_df.groupby("db_id")["execution_match"].agg(["sum", "count"])
db_accuracy["accuracy"] = db_accuracy["sum"] / db_accuracy["count"]
print(db_accuracy.sort_values("accuracy", ascending=False).head(10))

Results saved to baseline_benchmark_results.csv

Execution accuracy by database (top 10):
                sum  count  accuracy
db_id                               
car_1             0     13       0.0
concert_singer    0     45       0.0
pets_1            0     42       0.0


So, unsurprisingly, we find what we've known all along: out of the box, FunctionGemma-270m-it is not a very good text-to-SQL model. 

What can we do about that? Fine-tuning!

## Round 1: Supervised Fine-Tuning (SFT)