# ✈️ Natural-Language → SQL with Reinforcement-Fine-Tuning (RFT)
A full, reproducible demo that trains a 3-B parameter model to answer English questions by writing SQL, without touching real production data.

### 0. Preface – Why this matters
Large-language-model copilots are great, but even the best models don't know your private data, and can hallucinate a column that doesn’t exist.
Reinforcement-Fine-Tuning (RFT) fixes this by teaching your language model about your data and how to write accurate queries. In this tutorial you will:

| You will learn to 📚                               | By the end you’ll have 🎁                          |
| -------------------------------------------------- | -------------------------------------------------- |
| ✓ create a synthetic DB that *mimics* your schema  | `openflights.db` (<20 MB) wrapped by an MCP server |
| ✓ generate a MECE query set & ground-truth answers | `queries.json`, `gt_rows.json`                     |
| ✓ build NL ↔ SQL result training pairs             | `train_pairs.jsonl`                                |
| ✓ run an RFT job on Fireworks AI                   | a tuned **Qwen-2.5-Coder-3B-RFT** model            |
| ✓ benchmark baseline vs. RFT accuracy              | >2× exact-match gain                               |

<br>

> **A Note on Our Method: Demo vs. Real World 🌍** <br>
> Throughout this tutorial, we will be clear about what we're doing for the purpose of this self-contained demo versus what you would do in a real-world scenario.
> - **Real World 🌍**: Look for these notes to see the parallel step you would take in your own environment if you wanted to apply this workflow, typically by swapping in your own private business assets like schemas or query logs.

### 1. Development Environment Setup
**Complete these steps once in your terminal, *outside* this notebook.**

1.  **Get a Fireworks AI API Key**
    - Go to [fireworks.ai](https://fireworks.ai) and sign up.
    - Create an API key from your settings page.
    - Create a file named `.env` in your project directory and add your key:
      ```
      FIREWORKS_API_KEY="YOUR_API_KEY_HERE"
      ```

2.  **Install `uv`**
    - `uv` is a fast Python package manager from Astral. Follow the official installation instructions at [docs.astral.sh/uv/](https://docs.astral.sh/uv/).

3.  **Create a Virtual Environment and Install Packages**
    - Once `uv` is installed, create and activate a virtual environment.
    ```bash
    # Run this in your terminal
    uv venv .venv
    source .venv/bin/activate  # On Windows PowerShell: .venv\Scripts\Activate.ps1
    ```
    - Install all required packages using `uv add`.
    ```bash
    # Run this in your terminal
    uv add duckdb tabulate pandas pyarrow requests \
           faker python-dotenv \
           jsonlines fireworks-ai \
           mcp-sdk mcp-server-motherduck
    ```
After running these commands, your environment is ready. You can proceed with the cells inside this notebook.

### 2. Simulate the "Production" Database
First, we'll create a database that represents your real, populated production database. We'll download the public OpenFlights dataset and load it into a DuckDB file.

> **Real World 🌍**: You already have this! It's your live production database (or a replica). You would skip this entire step.

In [11]:
import urllib.request
import pathlib
import pandas as pd
import duckdb

# --- Download the raw data files ---
DATA_DIR = pathlib.Path("data")
DATA_DIR.mkdir(exist_ok=True)
BASE_URL = "https://raw.githubusercontent.com/jpatokal/openflights/master/data/"
FILES_TO_DOWNLOAD = {
    "airports": "airports.dat",
    "airlines": "airlines.dat",
    "routes": "routes.dat",
    "countries": "countries.dat",
    "planes": "planes.dat"
}
# Define column names as the files don't have headers
COLUMN_NAMES = {
    "airports": ["airport_id", "name", "city", "country", "iata", "icao", "latitude", "longitude", "altitude", "timezone", "dst", "tz_db", "type", "source"],
    "airlines": ["airline_id", "name", "alias", "iata", "icao", "callsign", "country", "active"],
    "routes": ["airline", "airline_id", "source_airport", "source_airport_id", "destination_airport", "destination_airport_id", "codeshare", "stops", "equipment"],
    "countries": ["name", "iso_code", "dafif_code"],
    "planes": ["name", "iata", "icao"]
}

PROD_DB_PATH = "data/prod_openflights.db"

# --- Load the real data into our "production" DuckDB ---
with duckdb.connect(PROD_DB_PATH) as con:
    for name, filename in FILES_TO_DOWNLOAD.items():
        url = f"{BASE_URL}{filename}"
        path = DATA_DIR / filename
        if not path.exists():
            urllib.request.urlretrieve(url, path)
            print(f"✅ Downloaded: {path}")

        # Load data using pandas to handle missing headers and null values
        df = pd.read_csv(path, header=None, names=COLUMN_NAMES[name], na_values=["\\N"])
        con.execute(f"CREATE OR REPLACE TABLE {name} AS SELECT * FROM df")

    print(f"\n✅ 'Production' database simulated at: {PROD_DB_PATH}")
    print("Tables created:", con.sql("SHOW TABLES;").fetchall())


✅ 'Production' database simulated at: data/prod_openflights.db
Tables created: [('airlines',), ('airports',), ('countries',), ('planes',), ('routes',)]


### 3. Acquire the Schema (No Data!)
This is a critical step. We connect to our "production" database and extract **only its schema** (the table structure, column names, and data types). We do not touch or read any of the data rows. This schema is the only artifact we need from the production environment.

In [12]:
import duckdb

# Connect to the "production" database we just created
with duckdb.connect(PROD_DB_PATH, read_only=True) as con:
    # The DESCRIBE command gives us the schema information for all tables
    schema_df = con.sql("DESCRIBE;").df()

print("✅ Schema successfully extracted from 'production' database:")
print(schema_df.to_markdown(index=False))

# We can also store this for later use in prompts
schema_for_prompt = schema_df.to_markdown(index=False)

✅ Schema successfully extracted from 'production' database:
| database         | schema   | name      | column_names                                                               | column_types                                                          | temporary   |
|:-----------------|:---------|:----------|:---------------------------------------------------------------------------|:----------------------------------------------------------------------|:------------|
| prod_openflights | main     | airlines  | ['airline_id' 'name' 'alias' 'iata' 'icao' 'callsign' 'country' 'active']  | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' | False       |
|                  |          |           |                                                                            |  'VARCHAR']                                                           |             |
| prod_openflights | main     | airports  | ['airport_id' 'name' 'city' 'country' 'iata' 'icao' 'latitude' 'long

### 4. Create the Synthetic Training Sandbox with an LLM
Now that we have the schema, we will use a large language model to generate a complete, contextually-aware synthetic dataset.

To ensure the LLM's output is structured and parseable, we will **dynamically generate a Pydantic schema** based on the `DESCRIBE` output from the previous step. This is a powerful, generic technique that adapts to any database schema.

To fine-tune our model with RFT, **we will only interact with this synthetic database.**

In [3]:
import pandas as pd
import os
from pydantic import create_model, BaseModel
from fireworks import LLM
import duckdb
import json
from dotenv import load_dotenv
from typing import List, Optional, Any, Dict, Type
import datetime
import decimal
import uuid
import math
import time


TARGET_ROW_COUNT = 100  # The number of rows to generate for each table.

# --- 1. Dynamically Create Pydantic Models from the SQL Schema ---
def map_sql_type_to_python(sql_type: str) -> Type:
    """Maps SQL data types to Python types for Pydantic models."""
    sql_type_upper = str(sql_type).upper()
    if 'DECIMAL' in sql_type_upper: return decimal.Decimal
    if 'DOUBLE' in sql_type_upper or 'FLOAT' in sql_type_upper or 'REAL' in sql_type_upper: return float
    if 'BIGINT' in sql_type_upper or 'INT' in sql_type_upper: return int
    if 'VARCHAR' in sql_type_upper or 'TEXT' in sql_type_upper or 'STRING' in sql_type_upper: return str
    if 'TIMESTAMP' in sql_type_upper: return datetime.datetime
    if 'DATE' in sql_type_upper: return datetime.date
    if 'TIME' in sql_type_upper: return datetime.time
    if 'BOOLEAN' in sql_type_upper: return bool
    if 'BLOB' in sql_type_upper or 'BYTEA' in sql_type_upper: return bytes
    if 'UUID' in sql_type_upper: return uuid.UUID
    return object

pydantic_models: Dict[str, Type[BaseModel]] = {}
table_names = schema_df['name'].unique()

for table_name in table_names:
    table_schema = schema_df[schema_df['name'] == table_name].iloc[0]
    fields: Dict[str, Any] = {}
    col_names = table_schema['column_names']
    col_types = table_schema['column_types']
    for i, col_name in enumerate(col_names):
        python_type = map_sql_type_to_python(col_types[i])
        fields[col_name] = (Optional[python_type], None)
    model_name = table_name.capitalize() + "Model"
    pydantic_models[table_name] = create_model(model_name, **fields)

dataset_fields: Dict[str, Any] = {
    table_name: (List[model], ...) for table_name, model in pydantic_models.items()
}
SyntheticDataset = create_model('SyntheticDataset', **dataset_fields)
print("✅ Dynamically created Pydantic models for all tables.")


# --- 2. Define Total Row Counts and Chunking Strategy ---
TOTAL_ROW_COUNTS = {name: TARGET_ROW_COUNT for name in table_names}
ROWS_PER_API_CALL = 2 # Ask for data in small, safe chunks
print("\n✅ Data Generation Plan:")
print(f" - Target rows per table: {list(TOTAL_ROW_COUNTS.values())[0]}")
print(f" - Will make API calls asking for {ROWS_PER_API_CALL} rows/call until target is met.")


# --- 3. Setup LLM and Loop to Generate Data in Chunks ---
SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"
load_dotenv()
llm = LLM(model="accounts/fireworks/models/deepseek-v3", deployment_type="serverless", api_key=os.getenv("FIREWORKS_API_KEY"))

all_synthetic_data: Dict[str, List[Dict]] = {name: [] for name in table_names}
chunk_row_counts = {name: ROWS_PER_API_CALL for name in table_names}

base_generation_prompt = f"""
You are a highly intelligent AI data generator. Your task is to create a realistic, synthetic dataset based on the provided database schema.
The data you generate must be internally consistent. For example, an `airline_id` in a `routes` table must correspond to an existing `airline_id` in an `airlines` table within this same generated chunk.
This applies to any schema you might be working with, not just airlines.
You must generate a single JSON object that strictly adheres to the provided JSON schema.

The database schema is as follows:
{schema_for_prompt}
"""

call_count = 0
# Loop until all tables have at least the desired number of rows
while not all(len(rows) >= TOTAL_ROW_COUNTS[name] for name, rows in all_synthetic_data.items()):
    call_count += 1
    print(f"\n📞 --- Generating data chunk #{call_count} ---")
    
    # --- Create a summary of existing data to guide the LLM ---
    existing_data_summary = ""
    if any(len(rows) > 0 for rows in all_synthetic_data.values()):
        summary_parts = ["\nYou have already generated the following data. Do NOT generate rows that are substantially similar to these examples. Create new, distinct data.\n"]
        for table_name, rows in all_synthetic_data.items():
            if rows:
                summary_parts.append(f"\n--- Existing data in '{table_name}' table ---")
                df = pd.DataFrame(rows)
                if len(df.columns) > 10:
                    df = df.iloc[:, :10]
                markdown_summary = df.to_markdown(index=False, tablefmt="grid")
                if markdown_summary:
                    summary_parts.append(markdown_summary)
        existing_data_summary = "\n".join(summary_parts)


    # --- Construct the final prompt for this iteration ---
    final_prompt = (
        base_generation_prompt +
        existing_data_summary +
        f"\n\nNow, generate a NEW JSON object with a key for each table. The number of new rows for each table should be:\n" +
        json.dumps(chunk_row_counts, indent=2)
    )

    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": final_prompt}],
        response_format={"type": "json_schema", "json_schema": {"name": "SyntheticDataset", "schema": SyntheticDataset.model_json_schema()}},
        temperature=0.7
    )

    choice = response.choices[0]
    response_content = choice.message.content

    if choice.finish_reason == "length":
        print(f"⚠️ WARNING: Chunk #{call_count} was truncated. Skipping.")
        continue
    if not response_content:
        print(f"⚠️ WARNING: Received empty content for chunk #{call_count}. Skipping.")
        continue

    try:
        chunk_data = json.loads(response_content)
        print(f"✅ Received and parsed chunk #{call_count}.")
        for table_name, rows in chunk_data.items():
            if table_name in all_synthetic_data and rows:
                all_synthetic_data[table_name].extend(rows)
        # Log progress
        for name, rows in all_synthetic_data.items():
             print(f"   - '{name}': {len(rows)} / {TOTAL_ROW_COUNTS[name]} rows")
    except json.JSONDecodeError as e:
        print(f"❌ ERROR: Failed to parse JSON for chunk #{call_count}. Reason: {e}. Skipping.")
    
    time.sleep(1)

# --- 4. Deduplicate and Write to DB ---
print("\n✨ Data generation complete. Aggregating, deduplicating, and saving to database...")

synthetic_data = all_synthetic_data
print("\n--- Deduplicating generated data ---")
for table_name, rows in synthetic_data.items():
    if not rows: continue
    initial_count = len(rows)
    df = pd.DataFrame(rows).drop_duplicates()
    final_count = len(df)
    synthetic_data[table_name] = df.to_dict('records')
    print(f" - Table '{table_name}': Removed {initial_count - final_count} duplicates ({initial_count} -> {final_count}).")

# Final trim to ensure exact counts
for table_name, total_rows_needed in TOTAL_ROW_COUNTS.items():
    if table_name in synthetic_data:
        synthetic_data[table_name] = synthetic_data[table_name][:total_rows_needed]

with duckdb.connect(SYNTHETIC_DB_PATH) as con:
    for table_name, rows in synthetic_data.items():
        if rows:
            df = pd.DataFrame(rows)
            schema_cols = schema_df[schema_df['name'] == table_name].iloc[0]['column_names']
            for col in schema_cols:
                if col not in df.columns: df[col] = None
            df = df[schema_cols]
            con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df")
    
    print(f"\n✅ Synthetic training sandbox created at: {SYNTHETIC_DB_PATH}")
    print("Tables created:", con.sql("SHOW TABLES;").fetchall())

✅ Dynamically created Pydantic models for all tables.

✅ Data Generation Plan:
 - Target rows per table: 100
 - Will make API calls asking for 2 rows/call until target is met.

📞 --- Generating data chunk #1 ---
✅ Received and parsed chunk #1.
   - 'airlines': 2 / 100 rows
   - 'airports': 2 / 100 rows
   - 'countries': 2 / 100 rows
   - 'planes': 2 / 100 rows
   - 'routes': 2 / 100 rows

📞 --- Generating data chunk #2 ---
✅ Received and parsed chunk #2.
   - 'airlines': 4 / 100 rows
   - 'airports': 4 / 100 rows
   - 'countries': 4 / 100 rows
   - 'planes': 4 / 100 rows
   - 'routes': 4 / 100 rows

📞 --- Generating data chunk #3 ---
✅ Received and parsed chunk #3.
   - 'airlines': 6 / 100 rows
   - 'airports': 6 / 100 rows
   - 'countries': 6 / 100 rows
   - 'planes': 6 / 100 rows
   - 'routes': 6 / 100 rows

📞 --- Generating data chunk #4 ---
✅ Received and parsed chunk #4.
   - 'airlines': 8 / 100 rows
   - 'airports': 8 / 100 rows
   - 'countries': 8 / 100 rows
   - 'planes': 8 / 1

### 5. Validate the Synthetic Sandbox
Let's run a few queries against our new synthetic database to ensure the LLM did a good job generating plausible, interconnected data. We expect to see non-empty results from these joins, which confirms that the referential integrity is holding up.

In [4]:
import duckdb
from tabulate import tabulate

# Connect to the synthetic database
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    
    # Get the list of all tables created
    all_tables = [table[0] for table in con.sql("SHOW TABLES;").fetchall()]
    
    # Select the first 5 tables to display (or all if fewer than 5)
    tables_to_validate = all_tables[:5]

    print("--- Validating the first few tables in the synthetic sandbox ---\n")

    # Execute and print results for the selected tables
    for table_name in tables_to_validate:
        print(f"--- SELECT * FROM {table_name} LIMIT 10; ---")
        try:
            result_df = con.sql(f"SELECT * FROM {table_name} LIMIT 10;").df()
            if not result_df.empty:
                print(tabulate(result_df, headers='keys', tablefmt='psql'))
            else:
                print(f"(Table '{table_name}' is empty)")
        except Exception as e:
            print(f"Query failed for table '{table_name}': {e}")
        print("\n")

--- Validating the first few tables in the synthetic sandbox ---

--- SELECT * FROM airlines LIMIT 10; ---
+----+--------------+------------------+---------+--------+--------+----------------+----------------+----------+
|    |   airline_id | name             | alias   | iata   | icao   | callsign       | country        | active   |
|----+--------------+------------------+---------+--------+--------+----------------+----------------+----------|
|  0 |            1 | SkyHigh Airlines | SHA     | SH     | SKY    | SKYHIGH        | United States  | Y        |
|  1 |            2 | Oceanic Airways  | OA      | OC     | OCN    | OCEANIC        | United Kingdom | Y        |
|  2 |            3 | Global Wings     | GW      | GL     | GLB    | GLOBALWINGS    | Canada         | Y        |
|  3 |            4 | Pacific Horizon  | PH      | PA     | PHZ    | PACIFICHORIZON | Australia      | Y        |
|  4 |            5 | Arctic Air       | AA      | AR     | ARC    | ARCTIC         | Canada   

### 5. Generate Example SQL Queries
With our synthetic database in place, the next step is to create a set of synthetic SQL queries. These SQL queries will be executed against our database of synthetic data to get the ground truth labels for RFT. Furthermore, these same SQL queries will be used as input to an LLM to generate queries in natural language. This will enable us to form our final RFT dataset, which pairs natural language queries with ground truth results from the database.

> **Real World 🌍**: You would use a historical log of real SQL queries that have been run against your production database. These logs are the most valuable source of training data because they represent the *actual* way your users query your data.

In [None]:
import pandas as pd
import json
import time
from pydantic import BaseModel, Field
from typing import List
from fireworks import LLM

# --- 1. Define Generation Parameters and Pydantic Model ---
llm = LLM(model="accounts/fireworks/models/qwen3-coder-480b-a35b-instruct", deployment_type="serverless", api_key=os.getenv("FIREWORKS_API_KEY"))  # Use Qwen3-coder for SQL queries
TOTAL_QUERIES_TO_GENERATE = 400
QUERIES_PER_API_CALL = 10

class SqlQueryBatch(BaseModel):
    queries: List[str] = Field(description=f"A list of exactly {QUERIES_PER_API_CALL} unique and diverse SQL queries.")

print(f"🎯 Goal: Generate {TOTAL_QUERIES_TO_GENERATE} unique queries in batches of {QUERIES_PER_API_CALL}.")

# --- 2. Get Clean Schema From Synthetic DB ---
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    schema_df = con.sql("DESCRIBE;").df()
    schema_for_prompt = schema_df.to_markdown(index=False)

# --- 3. Setup Base Prompt and Generation Loop ---
base_query_generation_prompt = f"""
You are an expert SQL data analyst. Your task is to generate unique and diverse SQL queries based on the database schema provided.
The queries should be realistic and cover a range of complexities and SQL features (JOINS, GROUP BY, aggregates, etc.).
Ensure you break ties with ORDER BY clauses so that the same queries produce the same results when executed against the database.
Write on the SQL query and nothing else.
Ensure the generated SQL is valid for DuckDB.

**Database Schema:**
{schema_for_prompt}
"""

all_generated_queries = []
# Loop until we have enough queries
while len(all_generated_queries) < TOTAL_QUERIES_TO_GENERATE:
    print(f"\n📞 --- Generating batch #{len(all_generated_queries) // QUERIES_PER_API_CALL + 1} ---")

    # Create a summary of queries generated so far to prevent duplicates
    existing_queries_summary = ""
    if all_generated_queries:
        summary_parts = ["\nYou have already generated the following queries. Generate NEW, DISTINCT queries that are not on this list and cover different analytic scenarios.\n"]
        for i, q in enumerate(all_generated_queries):
            summary_parts.append(f"{i+1}. {q}")
        existing_queries_summary = "\n".join(summary_parts)

    # Construct the final prompt for this iteration
    final_prompt = (
        base_query_generation_prompt +
        existing_queries_summary +
        f"\n\nNow, generate {QUERIES_PER_API_CALL} new and unique SQL queries. Return your response as a single JSON object adhering to the specified schema."
    )

    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": final_prompt}],
        response_format={"type": "json_schema", "json_schema": {"name": "SqlQueryBatch", "schema": SqlQueryBatch.model_json_schema()}},
        temperature=0.8
    )

    response_content = response.choices[0].message.content
    if response_content:
        try:
            new_queries = json.loads(response_content).get("queries", [])
            all_generated_queries.extend(new_queries)
            print(f"   - Received {len(new_queries)} new queries. Total now: {len(all_generated_queries)} / {TOTAL_QUERIES_TO_GENERATE}")
        except json.JSONDecodeError as e:
            print(f"❌ ERROR: Failed to parse generated queries in this batch: {e}")
    
    time.sleep(1) # Be nice to the API

# --- 4. Deduplicate, Trim, and Save --- 
print("\n✨ Generation complete. Deduplicating and saving...")
initial_count = len(all_generated_queries)
# Simple, fast deduplication preserving order
unique_queries = list(dict.fromkeys(all_generated_queries))
final_count = len(unique_queries)
print(f" - Removed {initial_count - final_count} duplicates ({initial_count} -> {final_count}).")

# Trim to the exact number we need
final_queries = unique_queries[:TOTAL_QUERIES_TO_GENERATE]

# Save the final list to a file
QUERIES_FILE_PATH = "data/generated_queries.json"
with open(QUERIES_FILE_PATH, 'w') as f:
    json.dump({"queries": final_queries}, f, indent=2)

print(f"\n✅ Successfully saved {len(final_queries)} unique queries to `{QUERIES_FILE_PATH}`.")
print("\n--- Here are a few examples: ---")
for query in final_queries[:5]:
    print(f"- {query}")

🎯 Goal: Generate 400 unique queries in batches of 10.

📞 --- Generating batch #1 ---
   - Received 10 new queries. Total now: 10 / 400

📞 --- Generating batch #2 ---
   - Received 10 new queries. Total now: 20 / 400

📞 --- Generating batch #3 ---
   - Received 10 new queries. Total now: 30 / 400

📞 --- Generating batch #4 ---
   - Received 10 new queries. Total now: 40 / 400

📞 --- Generating batch #5 ---
   - Received 10 new queries. Total now: 50 / 400

📞 --- Generating batch #6 ---
   - Received 10 new queries. Total now: 60 / 400

📞 --- Generating batch #7 ---
   - Received 10 new queries. Total now: 70 / 400

📞 --- Generating batch #8 ---
   - Received 10 new queries. Total now: 80 / 400

📞 --- Generating batch #9 ---
   - Received 10 new queries. Total now: 90 / 400

📞 --- Generating batch #10 ---
   - Received 10 new queries. Total now: 100 / 400

📞 --- Generating batch #11 ---
   - Received 10 new queries. Total now: 110 / 400

📞 --- Generating batch #12 ---
   - Received 10 ne

### 6. Execute Queries to Get Ground-Truth Answers
Now we will act as the "system" and run the queries we just generated against our synthetic sandbox. The output of each query is the **ground-truth result**. During Reinforcement Fine-Tuning, our model will be rewarded if the SQL it writes produces this exact same result.

> **Real World 🌍**: You would run your real historical queries against the synthetic database we previously created. The correctness of the data is not a concern here, as our aim is to see what a correct query would have generated, so we can compare it to our LLM's generations during the RFT process.

In [14]:
import duckdb
import json
import pandas as pd

# --- 1. Define File Paths ---
SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"
QUERIES_FILE_PATH = "data/generated_queries.json"
GROUND_TRUTH_FILE_PATH = "data/ground_truth_results.jsonl"

# --- 2. Load Generated Queries ---
with open(QUERIES_FILE_PATH, 'r') as f:
    queries_data = json.load(f)
    queries_to_execute = queries_data.get("queries", [])

print(f"Loaded {len(queries_to_execute)} queries to execute.")

# --- 3. Execute Queries and Store Results ---
ground_truth_results = []
successful_executions = 0
failed_executions = 0

print("Executing queries against the synthetic database...")
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    for query in queries_to_execute:
        try:
            # Execute the query and convert the result to a pandas DataFrame
            result_df = con.sql(query).df()

            # Replace any NaN/NaT values with None, which serializes to JSON `null`
            result_df = result_df.astype(object).where(pd.notna(result_df), None)
            
            result_records = result_df.to_dict('records')
            
            # Pair the query with its result
            ground_truth_results.append({
                "query": query,
                "result": result_records
            })
            successful_executions += 1
        except Exception as e:
            # The LLM might have occasionally generated a slightly invalid query
            print(f"⚠️  Skipping query due to execution error: {query}\n   Error: {e}\n")
            failed_executions += 1

print(f"\nExecution complete. Success: {successful_executions}, Failed: {failed_executions}.")

# --- 4. Save the Ground-Truth Data ---
with open(GROUND_TRUTH_FILE_PATH, 'w') as f:
    for entry in ground_truth_results:
        f.write(json.dumps(entry) + '\n')

print(f"\n✅ Successfully saved {len(ground_truth_results)} ground-truth results to `{GROUND_TRUTH_FILE_PATH}`.")

# --- 5. Print an Example ---
if ground_truth_results:
    print("\n--- Example ground_truth_results dataset entry ---")
    print(json.dumps(ground_truth_results[0], indent=2))

Loaded 263 queries to execute.
Executing queries against the synthetic database...

Execution complete. Success: 263, Failed: 0.

✅ Successfully saved 263 ground-truth results to `data/ground_truth_results.jsonl`.

--- Example ground_truth_results dataset entry ---
{
  "query": "SELECT country, COUNT(*) AS airline_count FROM airlines GROUP BY country ORDER BY airline_count DESC, country ASC",
  "result": [
    {
      "country": "Canada",
      "airline_count": 10
    },
    {
      "country": "Sweden",
      "airline_count": 10
    },
    {
      "country": "Kenya",
      "airline_count": 9
    },
    {
      "country": "United States",
      "airline_count": 9
    },
    {
      "country": "Australia",
      "airline_count": 8
    },
    {
      "country": "Spain",
      "airline_count": 6
    },
    {
      "country": "Italy",
      "airline_count": 4
    },
    {
      "country": "Switzerland",
      "airline_count": 4
    },
    {
      "country": "Finland",
      "airline_count":

### 7. Generate Natural Language Questions for Final RFT Training Data
We now have pairs of `(SQL Query, Ground-Truth Result)`. The final piece missing from our training data is the user's input: a question in natural language. This is because our final goal is to use RFT to tune an LLM to map from a natural language question to a SQL query, having the reward signal be the actual result of the query, rather than just the query itself. This is important because there are many ways to write the same SQL query that yield the same, correct result.

Thus, we will use an LLM once again to translate our "historical" SQL queries into plausible questions a business user might ask, corresponding to that query. This will yield our final training dataset in the format: `(Natural Language Question, SQL Query, Ground-Truth Result)`. Note that the SQL queries themselves will not be used as part of the RFT job itself, but are useful for debugging our evaluation function (more details in a later section).

> **Real World 🌍**: You might not need this step! If you have logs that already link user questions to the queries they ran (e.g., from a BI tool's search bar), you can use those directly. If not, this LLM-based translation is a powerful technique to bootstrap your training data.

In [15]:
import json
import time
import jsonlines
from typing import List
import random
from fireworks import LLM

# --- 1. Define File Paths and Parameters ---
llm = LLM(model="accounts/fireworks/models/qwen3-coder-480b-a35b-instruct", deployment_type="serverless", api_key=os.getenv("FIREWORKS_API_KEY"))
GROUND_TRUTH_FILE_PATH = "data/ground_truth_results.jsonl"
FINAL_TRAINING_DATA_PATH = "data/final_rft_sql_train_data.jsonl"
FINAL_TEST_DATA_PATH = "data/final_rft_sql_test_data.jsonl"

# --- 2. Load Ground-Truth Data ---
query_result_pairs = []
with jsonlines.open(GROUND_TRUTH_FILE_PATH) as reader:
    for obj in reader:
        query_result_pairs.append(obj)

print(f"Loaded {len(query_result_pairs)} query-result pairs.")

# --- 3. Use LLM to Generate Natural Language Questions ---
nl_generation_prompt_template = f"""
You are an expert data analyst who is great at translating SQL queries into plain English.
Based on the database schema and the provided SQL query, what is a natural language question a business user would ask to get this information?
Ensure that the question is precise enough to accurately map to the corresponding SQL query.

**Database Schema:**
{schema_for_prompt}

**SQL Query:**
```sql
{{query}}
```

Provide only the user's question, without any preamble or explanation.
"""

# The system prompt that will be included in the final training data for the RFT job.
# It gives the model its instructions at inference time.
rft_system_prompt = f"""
You are an expert SQL data analyst. Your task is to write a single, valid DuckDB SQL query to answer the user's question, based on the provided database schema. Do not provide any explanation or text other than the SQL query itself.

**Database Schema:**
{schema_for_prompt}
"""

final_generated_data = []
print(f"Generating natural language questions and formatting for RFT for {len(query_result_pairs)} queries...")

for i, pair in enumerate(query_result_pairs):
    print(f" - Processing query {i+1}/{len(query_result_pairs)}...")
    query = pair['query']
    ground_truth = pair['result']
    nl_generation_prompt = nl_generation_prompt_template.format(query=query)
    
    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": nl_generation_prompt}],
        temperature=0.5
    )
    
    nl_question = response.choices[0].message.content
    if nl_question:  # Only include the entry if the LLM generated a question
        # Assemble the final data structure
        rft_entry = {
            "messages": [
                {"role": "system", "content": rft_system_prompt},
                {"role": "user", "content": nl_question.strip()},
                {"role": "assistant", "content": query}
            ],
            "ground_truth": ground_truth  # The ground-truth result for the evaluator
        }
        final_generated_data.append(rft_entry)
    
    time.sleep(1) # Be nice to the API

# --- 4. Shuffle and Split the Dataset ---
print(f"\nGenerated {len(final_generated_data)} total examples. Now splitting into train and test sets.")
random.seed(42)
random.shuffle(final_generated_data)

split_index = int(len(final_generated_data) * 0.8)
train_data = final_generated_data[:split_index]
test_data = final_generated_data[split_index:]

print(f"Train set size: {len(train_data)}")
print(f"Test set size: {len(test_data)}")

# --- 5. Save the Final RFT-Ready Datasets ---
with jsonlines.open(FINAL_TRAINING_DATA_PATH, mode='w') as writer:
    writer.write_all(train_data)
print(f"\n✅ Successfully saved training dataset to `{FINAL_TRAINING_DATA_PATH}`.")

with jsonlines.open(FINAL_TEST_DATA_PATH, mode='w') as writer:
    writer.write_all(test_data)
print(f"✅ Successfully saved test dataset to `{FINAL_TEST_DATA_PATH}`.")

# --- 6. Print an Example ---
if train_data:
    print("\n--- Example RFT training entry ---")
    print(json.dumps(train_data[0], indent=2))


Loaded 263 query-result pairs.
Generating natural language questions and formatting for RFT for 263 queries...
 - Processing query 1/263...
 - Processing query 2/263...
 - Processing query 3/263...
 - Processing query 4/263...
 - Processing query 5/263...
 - Processing query 6/263...
 - Processing query 7/263...
 - Processing query 8/263...
 - Processing query 9/263...
 - Processing query 10/263...
 - Processing query 11/263...
 - Processing query 12/263...
 - Processing query 13/263...
 - Processing query 14/263...
 - Processing query 15/263...
 - Processing query 16/263...
 - Processing query 17/263...
 - Processing query 18/263...
 - Processing query 19/263...
 - Processing query 20/263...
 - Processing query 21/263...
 - Processing query 22/263...
 - Processing query 23/263...
 - Processing query 24/263...
 - Processing query 25/263...
 - Processing query 26/263...
 - Processing query 27/263...
 - Processing query 28/263...
 - Processing query 29/263...
 - Processing query 30/263..

### 8. 🛰️ Deploy an MCP Server for the Synthetic Data
Now, we'll start a remote server that speaks the Model Context Protocol (MCP). This server will wrap our synthetic DuckDB database, providing a standardized way for any external tool—in our case, the Fireworks RFT evaluator—to interact with it.
> Real World 🌍: This pattern is directly applicable. You would run a similar MCP server to provide a secure, read-only interface to a production database replica or a data warehouse, allowing the fine-tuning process to happen without granting direct database credentials to the training environment.

10. a) Create a server script in this project's root directory (`run_mcp_server.py`). This Python script starts our database server. It is configured to be read-only.

```python
    import os, contextlib, uvicorn
    from starlette.applications import Starlette
    from starlette.routing import Mount
    from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
    from mcp_server_motherduck import build_application

    DB = "data/synthetic_openflights.db"          # ← path from previous steps
    PORT = int(os.environ.get("PORT", 8080))        # Cloud Run injects $PORT

    # 1️⃣ Build the core SQL-aware MCP server (read-only for safety).
    server, _ = build_application(db_path=DB, read_only=True)

    # 2️⃣ Wrap it so HTTP clients can talk to it (ASGI handler).
    sess = StreamableHTTPSessionManager(app=server, event_store=None, stateless=True)

    async def handler(scope, receive, send):
        await sess.handle_request(scope, receive, send)

    @contextlib.asynccontextmanager
    async def lifespan(app):
        async with sess.run():
            yield                                        # keep sessions alive

    # 3️⃣ Starlette turns that handler into a full ASGI app Uvicorn can serve.
    app = Starlette(routes=[Mount("/mcp", app=handler)], lifespan=lifespan)

    if __name__ == "__main__":
        print(f"🔥 MCP endpoint → http://0.0.0.0:{PORT}/mcp")
        uvicorn.run(app, host="0.0.0.0", port=PORT)
```


### 9. ☁️ Set Up Google Cloud CLI & .gcloudignore
We'll first set up the Google Cloud CLI and authenticate.

> **Real World 🌍**  
> You would follow along here in the same way

9. a) **Install** the SDK (macOS/Linux):

      ```bash
      curl -sSL https://sdk.cloud.google.com | bash
      exec -l $SHELL  # reload shell so 'gcloud' is available
      ```

<br>

9. b) **Log in** (creates local access token):
      ```bash
      gcloud auth login
      ```

<br>

9. c) **Set your active project desired gcloud project**:
      ```bash
      gcloud config set project < YOUR_PROJECT_ID >  # set up project in gcloud console before running this if not already done
      ```

### 10. 📦 Containerize & Deploy the MCP Server  
We’ll build a Docker image and push it straight to Cloud Run.  
Remember to replace **`YOUR_PROJECT_ID`** with the project you actually want to bill.

> **Real World 🌍**  
> You would follow along in the same way here.

10. a) Create `mcp_requirements.txt` containing the following:

<br>

```bash
mcp
mcp-server-motherduck
duckdb
uvicorn
starlette
```

<br>

10. b) Create a `Dockerfile` (no extension) containing the following
```bash
base
FROM python:3.11-slim
WORKDIR /app

COPY mcp_requirements.txt .
RUN pip install --no-cache-dir -r mcp_requirements.txt

COPY run_mcp_server.py .
COPY data/synthetic_openflights.db ./data/

EXPOSE 8080

CMD ["python", "run_mcp_server.py"]
```

<br>

10. c) Create a .gcloudignore file in your root dir (to only deploy files needed for MCP server) containing:
```bash
# .gcloudignore

# 1. Ignore EVERYTHING in the directory by default.
*

# 2. Now, create exceptions for ONLY the files needed by the Dockerfile.
# The "!" character means "do not ignore this file".

# The Dockerfile itself is needed for the build process.
!Dockerfile

# The files explicitly copied by your Dockerfile:
!mcp_requirements.txt
!run_mcp_server.py

# 3. To include a specific file in a subdirectory, use this
#    three-line pattern to un-ignore the directory, re-ignore its
#    contents, and then un-ignore the specific file.
!data/
data/*
!data/synthetic_openflights.db
```

<br>

10. d) Deploy your MCP server as a Cloud Run app by running (from your project root):
```bash
FIREWORKS_API_KEY=$(grep FIREWORKS_API_KEY .env | cut -d '=' -f2) reward-kit deploy-mcp \
--id mcp-sql-rft-server \
--dockerfile Dockerfile \
--port 8080 \
--gcp-project < YOUR_GCP_PROJECT_ID > \
--gcp-region < YOUR_GCP_REGION >
```

<br>

10. e) Test that your MCP server is working as expected by running the following from your terminal:
10. e) i. To get your MCP server's URL:
```bash
gcloud run services describe mcp-sql-rft-server \
--project < YOUR_GCP_PROJECT_ID > \
--region < YOUR_GCP_REGION > \
--format="value(status.url)"
```

10. e) ii. (optional) To check the names of the MCP server's available tools:
```bash
curl -X POST "< YOUR_MCP_SERVER_URL_FROM_STEP_i >/mcp/" \
-H "Content-Type: application/json" \
-H "Accept: application/json, text/event-stream" \
-d '{
    "id": "list-tools-1",
    "jsonrpc": "2.0",
    "method": "tools/list",
    "params": {
        "session": {"id": "test-from-my-laptop"}
    }
}'
```
>Note that the above is a generally useful way to check an MCP server's tools.
>In this case, the tool of interest is the "query" tool.

10. e) iii. To send a test request to the MCP server:
```bash
curl -X POST "< YOUR_MCP_SERVER_URL_FROM_STEP_i >/mcp/" \
-H "Content-Type: application/json" \
-H "Accept: application/json, text/event-stream" \
-d '{
    "id": "query-1",
    "jsonrpc": "2.0",
    "method": "tools/call",
    "params": {
        "session": {"id": "test-from-my-laptop"},
        "name": "query",
        "arguments": {
            "query": "SELECT COUNT(*) FROM airlines;"
        }
    }
}'
```


### 11. 📦 Define an evaluation function for RFT
Here, we define an `evaluate` function for RFT, which will interface with our MCP server. Note that you will not directly execute the function here, but will use it as part of the Fireworks Evaluations UI.
Ensure that you set MCP_SERVER_URL to be your actual MCP server URL from step 10. e) i.

> **Real World 🌍**  
> You would follow along in the same way here.

In [None]:
import requests
import json
import math

MCP_SERVER_URL = None  # <--- PUT MCP SERVER URL HERE without the /mcp/ suffix at the end

def evaluate(messages: list[dict], ground_truth: list[dict], **kwargs) -> dict:
    """
    Evaluates the model's generated SQL query by executing it against a live
    MCP server and comparing the result with the ground_truth.
    """
    
    def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
        """
        Parses a DuckDB-style ASCII table string into a list of dictionaries.
        This version robustly handles 'NULL' values and empty strings.
        """
        lines = table_string.strip().split('\n')
        content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
        if len(content_lines) < 2:
            return []
        
        header_raw = [h.strip() for h in content_lines[0].split('|')[1:-1]]
        data_lines = content_lines[1:]
        
        if len(data_lines) > 0:
            try:
                first_data_values = [v.strip() for v in data_lines[0].split('|')[1:-1]]
                if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
                    data_lines = data_lines[1:]
            except IndexError:
                pass

        rows = []
        for line in data_lines:
            try:
                values_raw = [v.strip() for v in line.split('|')[1:-1]]
                if len(values_raw) == len(header_raw):
                    row_dict = {}
                    for i, header in enumerate(header_raw):
                        value_str = values_raw[i]
                        if value_str.upper() == 'NULL' or value_str == '':
                            row_dict[header] = None
                            continue
                        
                        try:
                            if '.' in value_str:
                                row_dict[header] = float(value_str)
                            else:
                                row_dict[header] = int(value_str)
                        except (ValueError, TypeError):
                            row_dict[header] = value_str
                    rows.append(row_dict)
            except IndexError:
                continue
        return rows

    # --- 1. Get MCP Server URL from Environment Variables ---
    mcp_server_url = MCP_SERVER_URL
    if not mcp_server_url:
        return {"score": 0, "is_score_valid": False, "reason": "FATAL: MCP_SERVER_URL environment variable is not set."}

    # --- 2. Get the SQL query from the model's response ---
    sql_query = messages[-1]['content'].strip()
    if not sql_query:
        return {"score": 0, "reason": "Model returned an empty response."}

    # --- 3. Execute the Query against the MCP Server ---
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream"
    }
    payload = {
        "id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call",
        "params": {"session": {"id": "stateless-eval-session"}, "name": "query", "arguments": {"query": sql_query}}
    }
    try:
        with requests.post(f"{mcp_server_url}/mcp/", headers=headers, json=payload, timeout=15, stream=True) as response:
            response.raise_for_status()
            response_data = None
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    if decoded_line.startswith('data:'):
                        json_part = decoded_line[len('data:'):].strip()
                        if json_part:
                            response_data = json.loads(json_part)
                            break
            if response_data is None:
                return {"score": 0, "reason": "Could not find JSON data in event stream response from MCP server."}

        if "error" in response_data:
            return {"score": 0, "reason": f"SQL execution failed. Error: {response_data['error'].get('message', 'Unknown')}"}

        ascii_table = response_data['result']['content'][0]['text']
        predicted_rows = parse_duckdb_ascii_table(ascii_table)

    except requests.exceptions.RequestException as e:
        return {"score": 0, "reason": f"Network error calling MCP server: {e}"}
    except json.JSONDecodeError as e:
        return {"score": 0, "reason": f"JSON decode error from server response: {e}"}
    except (KeyError, IndexError):
        return {"score": 0, "reason": f"Failed to parse predicted result from MCP server response structure. Data found: {json.dumps(response_data)}"}
    except Exception as e:
        return {"score": 0, "reason": f"An unexpected error occurred during query execution: {e}"}

    # --- 4. Process Ground Truth ---
    if not isinstance(ground_truth, list):
        return {"score": 0, "is_score_valid": False, "reason": f"FATAL: ground_truth is not a list as expected. Got type: {type(ground_truth)}"}
    ground_truth_rows = ground_truth


    # --- 5. Comparison Logic ---
    def normalize_and_stringify(v):
        """
        Normalizes numbers and None before string conversion.
        """
        if v is None:
            return str(v)
        
        if isinstance(v, float) and not math.isinf(v) and not math.isnan(v) and v == int(v):
            v = int(v)
        return str(v)

    try:
        gt_values = sorted([sorted(map(normalize_and_stringify, row.values())) for row in ground_truth_rows])
        predicted_values = sorted([sorted(map(normalize_and_stringify, row.values())) for row in predicted_rows])

        if gt_values == predicted_values:
            score = 1
            reason = "Success: The SQL query produced the exact expected result."
        else:
            score = 0
            gt_json = json.dumps(ground_truth_rows)
            pred_json = json.dumps(predicted_rows)
            reason = f"Incorrect result. Expected (from ground_truth): {gt_json}. Got (from query): {pred_json}."
    
    except Exception as e:
        return {"score": 0, "reason": f"Error during result comparison: {e}"}

    return {"score": score, "reason": reason}

### 12. 🧪 Test English -> SQL of a base model without fine-tuning
Here, we test a base model's natural language to SQL capability without fine-tuning for a single example.
Ensure that you set MCP_SERVER_URL to be your actual MCP server URL from step 10. e) i.

> **Real World 🌍**  
> You would follow along in the same way here.

In [None]:
import requests
import json
import os
from fireworks import LLM

# --- 1. SETUP: Define API keys, server URLs, and the model to use ---

# IMPORTANT: Make sure your FIREWORKS_API_KEY is set as an environment variable.
# You can get one from https://fireworks.ai
if "FIREWORKS_API_KEY" not in os.environ:
    print("FATAL: FIREWORKS_API_KEY environment variable not set.")
    # If not set, you can hardcode it here for testing, but this is not recommended:
    # os.environ["FIREWORKS_API_KEY"] = "YOUR_API_KEY_HERE"

# The model we'll use to generate the SQL. This acts as our "base" model.
LLM_MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"
llm = LLM(model=LLM_MODEL, deployment_type="serverless", api_key=os.getenv("FIREWORKS_API_KEY"))

# The URL for your running MCP server.
MCP_SERVER_URL = None  # PUT MCP SERVER URL HERE without the /mcp/ suffix at the end


# --- 2. LOAD THE EXAMPLE DATA ---

# This is the example data you provided.
DATASET_FILE_PATH = "data/final_rft_sql_train_data_v3.jsonl"
ROW_INDEX_TO_TEST = 0  # 0 is the first row, 1 is the second row, etc.

EXAMPLE_DATA = None
try:
    with open(DATASET_FILE_PATH, 'r') as f:
        for i, line in enumerate(f):
            if i == ROW_INDEX_TO_TEST:
                EXAMPLE_DATA = json.loads(line)
                break
    
    if EXAMPLE_DATA is None:
        with open(DATASET_FILE_PATH, 'r') as f:
            line_count = sum(1 for line in f)
        raise IndexError(f"row index {ROW_INDEX_TO_TEST} is out of bounds for file with {line_count} rows.")

    print(f"Successfully loaded row {ROW_INDEX_TO_TEST} from '{DATASET_FILE_PATH}'.\n")

except Exception as e:
    print(f"Warning: Could not load from file. Reason: {e}")

# If loading from file failed for any reason, use the hardcoded fallback data.
if EXAMPLE_DATA is None:
    print("Using hardcoded fallback EXAMPLE_DATA.\n")
    EXAMPLE_DATA = {
        "messages": [
            {"role": "system", "content": "\nYou are an expert SQL data analyst. Your task is to write a single, valid DuckDB SQL query to answer the user's question, based on the provided database schema. Do not provide any explanation or text other than the SQL query itself.\n\n**Database Schema:**\n| database              | schema   | name      | column_names                                                               | column_types                                                          | temporary   |\n|:----------------------|:---------|:----------|:---------------------------------------------------------------------------|:----------------------------------------------------------------------|:------------|\n| synthetic_openflights | main     | airlines  | ['airline_id' 'name' 'alias' 'iata' 'icao' 'callsign' 'country' 'active']  | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' | False       |\n|                       |          |           |                                                                            |  'VARCHAR']                                                           |             |\n| synthetic_openflights | main     | airports  | ['airport_id' 'name' 'city' 'country' 'iata' 'icao' 'latitude' 'longitude' | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'DOUBLE'  | False       |\n|                       |          |           |  'altitude' 'timezone' 'dst' 'tz_db' 'type' 'source']                      |  'DOUBLE' 'BIGINT' 'DOUBLE' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR']  |             |\n| synthetic_openflights | main     | countries | ['name' 'iso_code' 'dafif_code']                                           | ['VARCHAR' 'VARCHAR' 'VARCHAR']                                       | False       |\n| synthetic_openflights | main     | planes    | ['name' 'iata' 'icao']                                                     | ['VARCHAR' 'VARCHAR' 'VARCHAR']                                       | False       |\n| synthetic_openflights | main     | routes    | ['airline' 'airline_id' 'source_airport' 'source_airport_id'               | ['VARCHAR' 'BIGINT' 'VARCHAR' 'BIGINT' 'VARCHAR' 'BIGINT' 'VARCHAR'   | False       |\n|                       |          |           |  'destination_airport' 'destination_airport_id' 'codeshare' 'stops'        |  'BIGINT' 'VARCHAR']                                                  |             |\n|                       |          |           |  'equipment']                                                              |                                                                       |             |\n"},
            {"role": "user", "content": "Which countries have the most airlines, and how many airlines are there in each country, listed in descending order by the number of airlines and then alphabetically by country name?"},
            {"role": "assistant", "content": "SELECT country, COUNT(*) AS airline_count FROM airlines GROUP BY country ORDER BY airline_count DESC, country ASC"}
        ],
        "ground_truth": [{"country": "Canada", "airline_count": 10}, {"country": "Sweden", "airline_count": 10}, {"country": "Kenya", "airline_count": 9}, {"country": "United States", "airline_count": 9}, {"country": "Australia", "airline_count": 8}, {"country": "Spain", "airline_count": 6}, {"country": "Italy", "airline_count": 4}, {"country": "Switzerland", "airline_count": 4}, {"country": "Finland", "airline_count": 3}, {"country": "France", "airline_count": 3}, {"country": "Mexico", "airline_count": 3}, {"country": "Costa Rica", "airline_count": 2}, {"country": "Germany", "airline_count": 2}, {"country": "Iceland", "airline_count": 2}, {"country": "Ireland", "airline_count": 2}, {"country": "Japan", "airline_count": 2}, {"country": "Norway", "airline_count": 2}, {"country": "Singapore", "airline_count": 2}, {"country": "United Kingdom", "airline_count": 2}, {"country": "Argentina", "airline_count": 1}, {"country": "Brazil", "airline_count": 1}, {"country": "China", "airline_count": 1}, {"country": "Egypt", "airline_count": 1}, {"country": "Fiji", "airline_count": 1}, {"country": "Greece", "airline_count": 1}, {"country": "India", "airline_count": 1}, {"country": "Jordan", "airline_count": 1}, {"country": "Netherlands", "airline_count": 1}, {"country": "New Zealand", "airline_count": 1}, {"country": "Portugal", "airline_count": 1}, {"country": "Saudi Arabia", "airline_count": 1}, {"country": "South Africa", "airline_count": 1}, {"country": "Thailand", "airline_count": 1}, {"country": "United Arab Emirates", "airline_count": 1}]
    }

# Extract the prompts and ground truth from the data
system_prompt = EXAMPLE_DATA["messages"][0]["content"]
user_prompt = EXAMPLE_DATA["messages"][1]["content"]
GROUND_TRUTH_ROWS = EXAMPLE_DATA["ground_truth"]

# --- 3. HELPER FUNCTION: To parse the server's ASCII table response ---

def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
    """
    Parses a DuckDB-style ASCII table string into a list of dictionaries.
    This is a helper function to handle the specific string output from the MCP server.
    """
    lines = table_string.strip().split('\n')
    content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
    if len(content_lines) < 1:
        return []
    header_raw = [h.strip() for h in content_lines[0].split('|') if h.strip()]
    data_lines = content_lines[1:]
    if len(data_lines) > 0:
        first_data_values = [v.strip() for v in data_lines[0].split('|') if v.strip()]
        if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
            data_lines = data_lines[1:]
    rows = []
    for line in data_lines:
        values_raw = [v.strip() for v in line.split('|') if v.strip()]
        if len(values_raw) == len(header_raw):
            row_dict = {}
            for i, header in enumerate(header_raw):
                value_str = values_raw[i]
                try:
                    if '.' in value_str:
                        row_dict[header] = float(value_str)
                    else:
                        row_dict[header] = int(value_str)
                except (ValueError, TypeError):
                    row_dict[header] = value_str
            rows.append(row_dict)
    return rows

# --- 4. GENERATE SQL QUERY USING THE LLM ---

print("="*20)
print("LLM QUERY GENERATION")
print("="*20)

model_generated_sql = ""
try:
    print(f"Calling model '{LLM_MODEL}' to generate SQL query...")
    
    messages_for_llm = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    
    response = llm.chat.completions.create(
        model=LLM_MODEL,
        messages=messages_for_llm,
        temperature=0.0  # Set to 0 for deterministic output
    )
    
    model_generated_sql = response.choices[0].message.content.strip()
    print("\nModel Generated SQL Query:")
    print(model_generated_sql)
    
except Exception as e:
    print(f"\nAN ERROR OCCURRED during LLM call: {e}")


# --- 5. EXECUTE GENERATED QUERY ON MCP SERVER ---

predicted_rows = []
if model_generated_sql:
    try:
        print("\n" + "="*20)
        print("MCP SERVER EXECUTION")
        print("="*20)
        print(f"Sending query to MCP server...")
        
        headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
        payload = {
            "id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call",
            "params": {"session": {"id": "stateless-eval-session"}, "name": "query", "arguments": {"query": model_generated_sql}}
        }

        with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=20, stream=True) as response:
            response.raise_for_status()
            response_data = None
            for line in response.iter_lines():
                if line and line.decode('utf-8').startswith('data:'):
                    json_part = line.decode('utf-8')[len('data:'):].strip()
                    if json_part:
                        response_data = json.loads(json_part)
                        break
            
            if response_data is None: raise RuntimeError("No JSON data in event stream.")
            if "error" in response_data: raise RuntimeError(f"SQL Error: {response_data['error'].get('message', 'Unknown')}")

            ascii_table = response_data['result']['content'][0]['text']
            predicted_rows = parse_duckdb_ascii_table(ascii_table)
            print("\nParsed Result from Server:")
            print(json.dumps(predicted_rows, indent=2))

    except Exception as e:
        print(f"\nAN ERROR OCCURRED during MCP call: {e}")

# --- 6. FINAL COMPARISON ---
print("\n" + "="*20)
print("COMPARISON")
print("="*20)

if not predicted_rows:
    print("Skipping comparison: no rows returned from query or an error occurred.")
else:
    gt_values = sorted([sorted(map(str, row.values())) for row in GROUND_TRUTH_ROWS])
    predicted_values = sorted([sorted(map(str, row.values())) for row in predicted_rows])

    if gt_values == predicted_values:
        print("\n✅ GOOD RESULT: The base model generated SQL that produced the correct data.\n")
    else:
        print("\n❌ BAD RESULT: The base model's SQL produced different data than expected.\n")
        print("This is often the intended outcome when testing a base model, as it highlights what fine-tuning needs to correct.")

Successfully loaded row 0 from 'data/final_rft_sql_train_data_v3.jsonl'.

LLM QUERY GENERATION
Calling model 'accounts/fireworks/models/llama-v3p1-8b-instruct' to generate SQL query...

Model Generated SQL Query:
SELECT country, COUNT(*) FROM airlines GROUP BY country ORDER BY COUNT(*) DESC, country ASC

MCP SERVER EXECUTION
Sending query to MCP server...

Parsed Result from Server:
[
  {
    "country": "Canada",
    "count_star()": 10
  },
  {
    "country": "Sweden",
    "count_star()": 10
  },
  {
    "country": "Kenya",
    "count_star()": 9
  },
  {
    "country": "United States",
    "count_star()": 9
  },
  {
    "country": "Australia",
    "count_star()": 8
  },
  {
    "country": "Spain",
    "count_star()": 6
  },
  {
    "country": "Italy",
    "count_star()": 4
  },
  {
    "country": "Switzerland",
    "count_star()": 4
  },
  {
    "country": "Finland",
    "count_star()": 3
  },
  {
    "country": "France",
    "count_star()": 3
  },
  {
    "country": "Mexico",
    "cou

### 13. 🚀 Launch the Fine-Tuning Job & Deploy via the UI
Now we'll use the Fireworks AI web interface to take our prepared dataset and fine-tune a model. This process uses your custom `evaluate` function to teach a base model how to generate SQL correctly.

> **Real World 🌍**  
> This is the core of the RFT process. You're teaching a general-purpose model a very specific and valuable new skill using a powerful, UI-driven workflow. You may follow along as described below

As described in the [Fireworks RFT documentation](https://fireworks.ai/docs/fine-tuning/reinforcement-fine-tuning-models), the process involves uploading your data, creating an evaluator, running the job, and deploying.

<br>

**13. a) Upload Your Dataset**

1.  Navigate to the **Datasets** tab in your [Fireworks AI dashboard](https://fireworks.ai/account/datasets).
2.  Click **"Upload Dataset"**.
3.  Upload your training file: `data/final_rft_sql_train_data.jsonl`.
4.  Give it a memorable name, like `rft-sql-train-data-v1`, and save it.

<br>

**13. b) Create the Evaluator**

1.  Navigate to the **Evaluations** tab in the dashboard.
2.  Click **"Create Evaluator"**. This will open the web IDE.
3.  In the editor on the left, replace the template code with your full `evaluate` function from step 11 above. This function already contains the logic to connect to your MCP server and compare the results. You just need to add your MCP server URL to the MCP_SERVER_URL line.
4.  Save the evaluator with a name like `rft-sql-mcp-evaluator`.

<br>

**13. c) Launch the Fine-Tuning Job**

1.  Navigate to the **Fine-Tuning** tab.
2.  Click **"Fine-Tune a Model"** and select **Reinforcement**.
3.  Configure the job:
    *   **Model Selection:** Select a model, for example `accounts/fireworks/models/llama-v3p1-8b-instruct`.
    *   **Dataset:** Select the `rft-sql-train-data-v1` you uploaded.
    *   **Evaluator:** Select the `rft-sql-mcp-evaluator` you just created.
    *   **Rollout:** You can leave these as the default values.
    *   **Optional Settings:** You can leave the Model Output Name blank and get the default name, or enter a name of your choosing.
4.  You can leave the other hyperparameters as their defaults for the first run.
5.  Click **"Create Job"**.

<br>

**13. d) Monitor and Deploy**

1.  You can monitor the progress of your job in the **Fine-tuning** tab.
2.  Once the job status is `Completed`, you can deploy your model. To deploy, click "Deploy" on the top right of your fine-tuning job's page. Please note:
    -  The Model under "Select base model*" should be the one from your Reinforcement Fine-Tuning job (this should be populated automatically)
    -  Speculative decoding is an advanced technique that can improve latency, but is not needed for this use-case
    -  Feel free to make the other selections (Performance, Scaling, and Metadata) as needed; enabling autoscaling is recommended to reduce costs
3.  Find this new model and click the **Deploy** button to create an API endpoint.

<br>

**13. e) Test Your New Model!**
Once deployed, copy your new model's ID and paste it into the `LLM_MODEL` variable in the testing cell (step #12) to make sure it works as expected.

### 14. ⚖️ Evaluate Model Performance
Now for the moment of truth. We will systematically compare the performance of the original base model against our newly fine-tuned model to quantify the improvement.

We'll run both models against every entry in our training dataset (final_rft_sql_train_data_v3.jsonl). For each entry, we will:
1. Provide the same system and user prompt to both the base model and the fine-tuned model.
2. Capture the SQL query generated by each.
3. Execute each query against our live MCP server.
4. Compare the query result to the ground_truth from our dataset.
5. Keep a running score for each model.

This process will give us a clear, data-driven view of how much more accurate our model became after reinforcement fine-tuning.
> Real World 🌍
> This is a critical step in any MLOps loop. Evaluating a model on a consistent "test set" (in this case, we're re-using our training set for simplicity) is the only way to prove that your efforts have resulted in a tangible improvement.

In [None]:
import requests
import json
import os
import time
from fireworks import LLM
from tqdm.auto import tqdm
from dotenv import load_dotenv

load_dotenv()

# --- 1. SETUP: Define the models to compare, server URL, and dataset path ---

# IMPORTANT: Make sure your FIREWORKS_API_KEY is set as an environment variable.
if "FIREWORKS_API_KEY" not in os.environ:
    print("FATAL: FIREWORKS_API_KEY environment variable not set.")

# The base model you used for the fine-tuning job.
BASE_MODEL_ID = "accounts/fireworks/models/qwen2p5-7b"  # <--- Replace if you used a different base model
LARGE_BASE_MODEL_ID = "accounts/fireworks/models/qwen3-coder-480b-a35b-instruct"

# IMPORTANT: Replace this with the model ID of your new fine-tuned model.
FINE_TUNED_MODEL_ID = "accounts/<your-account-id>/models/ft-mdpe6xlm-z9vp7"  # <--- Replace with your fine-tuned model ID

MCP_SERVER_URL = None  # <--- PUT MCP SERVER URL HERE without the /mcp/ suffix at the end
DATASET_FILE_PATH = "data/final_rft_sql_test_data.jsonl"

# --- 2. Create LLM Objects ---
base_model_llm = None
large_base_model_llm = None
fine_tuned_model_llm = None
try:
    base_model_llm = LLM(model=BASE_MODEL_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    large_base_model_llm = LLM(model=LARGE_BASE_MODEL_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    fine_tuned_model_llm = LLM(model=FINE_TUNED_MODEL_ID, deployment_type="auto", api_key=os.getenv("FIREWORKS_API_KEY"))
    print("LLM objects for all three models created successfully.")
except Exception as e:
    print(f"FATAL: Could not create LLM objects. Error: {e}")

# --- 3. Load Dataset ---
dataset = []
if all([base_model_llm, large_base_model_llm, fine_tuned_model_llm]):
    try:
        with open(DATASET_FILE_PATH, 'r') as f:
            dataset = [json.loads(line) for line in f]
        print(f"Loaded {len(dataset)} evaluation examples from '{DATASET_FILE_PATH}'.")
    except Exception as e:
        print(f"FATAL: Could not load dataset. Error: {e}")
        dataset = []

# --- 4. HELPER AND EVALUATION FUNCTIONS ---

def parse_duckdb_ascii_table(table_string: str) -> list[dict]:
    """
    Parses a DuckDB-style ASCII table string into a list of dictionaries.
    This is a helper function to handle the specific string output from the MCP server.
    """
    lines = table_string.strip().split('\n')
    content_lines = [line for line in lines if line.strip() and not line.startswith('+')]
    if len(content_lines) < 1:
        return []
    header_raw = [h.strip() for h in content_lines[0].split('|') if h.strip()]
    data_lines = content_lines[1:]
    if len(data_lines) > 0:
        first_data_values = [v.strip() for v in data_lines[0].split('|') if v.strip()]
        if len(first_data_values) == len(header_raw) and all(v.isupper() for v in first_data_values):
            data_lines = data_lines[1:]
    rows = []
    for line in data_lines:
        values_raw = [v.strip() for v in line.split('|') if v.strip()]
        if len(values_raw) == len(header_raw):
            row_dict = {}
            for i, header in enumerate(header_raw):
                value_str = values_raw[i]
                try:
                    if '.' in value_str:
                        row_dict[header] = float(value_str)
                    else:
                        row_dict[header] = int(value_str)
                except (ValueError, TypeError):
                    row_dict[header] = value_str
            rows.append(row_dict)
    return rows

def are_results_equal(predicted_rows: list[dict], ground_truth_rows: list[dict]) -> bool:
    """
    Compares datasets by converting all values to strings and sorting them,
    which ignores row order, column order, and data types (e.g., int vs float).
    """
    try:
        gt_values = sorted([sorted(map(str, row.values())) for row in ground_truth_rows])
        predicted_values = sorted([sorted(map(str, row.values())) for row in predicted_rows])
        return gt_values == predicted_values
    except Exception:
        return False

def get_sql_and_evaluate(llm_obj, system_prompt: str, user_prompt: str, ground_truth: list[dict]) -> int:
    """
    Calls a pre-configured LLM object to get a SQL query, executes it, and compares to ground truth.
    Returns 1 for a correct result, 0 for an incorrect one.
    """
    try:
        # Step 1: Get SQL from the model
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
        response = llm_obj.chat.completions.create(messages=messages, temperature=0.0)
        sql_query = response.choices[0].message.content.strip()

        # Step 2: Execute SQL on MCP server
        headers = {"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}
        payload = {"id": "eval-query-1", "jsonrpc": "2.0", "method": "tools/call", "params": {"session": {"id": "full-eval-session"}, "name": "query", "arguments": {"query": sql_query}}}

        response_data = None
        with requests.post(f"{MCP_SERVER_URL}/mcp/", headers=headers, json=payload, timeout=30, stream=True) as mcp_response:
            mcp_response.raise_for_status()
            for line in mcp_response.iter_lines():
                if line and line.decode('utf-8').startswith('data:'):
                    json_part = line.decode('utf-8')[len('data:'):].strip()
                    if json_part:
                        response_data = json.loads(json_part)
                        break

        if response_data is None or "error" in response_data:
            return 0

        # Step 3: Parse and compare results
        ascii_table = response_data['result']['content'][0]['text']
        predicted_rows = parse_duckdb_ascii_table(ascii_table)

        return 1 if are_results_equal(predicted_rows, ground_truth) else 0
    except Exception as e:
        print(f"--> Error during evaluation for model {llm_obj.model}: {e}")
        return 0

# --- 5. RUN THE FULL EVALUATION ---

base_model_score = 0
large_base_model_score = 0
fine_tuned_model_score = 0

if dataset:
    print("\nStarting evaluation...")
    for item in tqdm(dataset, desc="Evaluating models"):
        system_prompt = item["messages"][0]["content"]
        user_prompt = item["messages"][1]["content"]
        ground_truth = item["ground_truth"]

        # Evaluate base model
        base_model_score += get_sql_and_evaluate(base_model_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)  # Be nice to the API

        # Evaluate large base model
        large_base_model_score += get_sql_and_evaluate(large_base_model_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)

        # Evaluate fine-tuned model
        fine_tuned_model_score += get_sql_and_evaluate(fine_tuned_model_llm, system_prompt, user_prompt, ground_truth)
        time.sleep(1)

# --- 6. REPORT RESULTS ---

if dataset:
    total = len(dataset)
    base_accuracy = (base_model_score / total) * 100
    large_base_accuracy = (large_base_model_score / total) * 100
    tuned_accuracy = (fine_tuned_model_score / total) * 100

    print("\n" + "="*25)
    print("  EVALUATION COMPLETE")
    print("="*25)
    print(f"Total Examples: {total}\n")
    print("--- BASE MODEL ---")
    print(f"Model ID: {BASE_MODEL_ID}")
    print(f"Correct: {base_model_score}/{total}")
    print(f"Accuracy: {base_accuracy:.2f}%\n")

    print("--- LARGE BASE MODEL ---")
    print(f"Model ID: {LARGE_BASE_MODEL_ID}")
    print(f"Correct: {large_base_model_score}/{total}")
    print(f"Accuracy: {large_base_accuracy:.2f}%\n")

    print("--- FINE-TUNED MODEL ---")
    print(f"Model ID: {FINE_TUNED_MODEL_ID}")
    print(f"Correct: {fine_tuned_model_score}/{total}")
    print(f"Accuracy: {tuned_accuracy:.2f}%\n")
    
    print("="*25)
    print("  PERFORMANCE LIFT")
    print("="*25)
    print(f"Fine-Tuned vs. Base: {tuned_accuracy - base_accuracy:+.2f}%")
    print(f"Fine-Tuned vs. Large Base: {tuned_accuracy - large_base_accuracy:+.2f}%")

else:
    print("Evaluation skipped because the dataset or LLM objects could not be loaded.")

LLM objects for all three models created successfully.
Loaded 53 evaluation examples from 'data/final_rft_sql_test_data.jsonl'.

Starting evaluation...


Evaluating models: 100%|██████████| 53/53 [11:05<00:00, 12.55s/it]  


  EVALUATION COMPLETE
Total Examples: 53

--- BASE MODEL ---
Model ID: accounts/fireworks/models/qwen2p5-7b
Correct: 21/53
Accuracy: 39.62%

--- LARGE BASE MODEL ---
Model ID: accounts/fireworks/models/qwen3-235b-a22b-instruct-2507
Correct: 20/53
Accuracy: 37.74%

--- FINE-TUNED MODEL ---
Model ID: accounts/pyroworks/models/ft-mdpe6xlm-z9vp7
Correct: 28/53
Accuracy: 52.83%

  PERFORMANCE LIFT
Fine-Tuned vs. Base: +13.21%
Fine-Tuned vs. Large Base: +15.09%





### 15. ✨ Cleanup & Conclusion
Congratulations! You've successfully completed the entire Reinforcement Fine-Tuning loop. You started with just a database schema and ended with a highly specialized, performant, and data-aware AI model.

#### Cleanup
Model deployments can incur costs, so it's good practice to clean up any resources you no longer need.

*   **Check your Deployments:** Navigate to the [Deployments tab](https://app.fireworks.ai/dashboard/deployments) in your Fireworks AI dashboard. Here you can monitor and manage all your deployed models.
*   **Delete Unneeded Models:** Feel free to delete any deployments you no longer need. For example, you might have deployed the base or large-base models during the evaluation step to compare against your fine-tuned model. These can now be safely removed to save costs.

You can, of course, continue using your new fine-tuned SQL generation model for any application you see fit!

#### Conclusions
The evaluation results from the previous step highlight the power of this approach.

*   **Performance on par with massive models:** Our fine-tuned 7B parameter model performs on par with a much larger model like `qwen3-coder-480b-a35b-instruct` on this specific dataset. This is because it has been fine-tuned to understand the data schema via real query generation and execution.
*   **Efficiency Gains:** This specialized 7B model is significantly faster and cheaper to run than its 480B counterpart, offering production-grade performance at a fraction of the cost and latency.
*   **High-Level Capability on Complex Tasks:** The queries in this dataset are relatively complex, which is reflected in the final accuracy score of around 60%. This is a strong result, demonstrating that for a specialized domain, a smaller model can be tuned to achieve a level of performance comparable to elite, state-of-the-art coding models like Claude 4 Sonnet.

---

Throughout this tutorial, we demonstrated a complete, end-to-end workflow for creating a fine-tuned text-to-SQL model. We began with the absolute minimum requirement, a database schema, and used a series of LLM-driven steps to generate a safe, synthetic data sandbox. From there, we generated a rich dataset of queries and answers, which we used to fine-tune a model using the Fireworks RFT platform. The final result is a small, efficient model that can accurately query data it has never seen, a task that was previously only possible with vastly larger and more expensive models.

This pattern of **schema → synthetic data → RFT** is a secure, effective, and repeatable methodology for teaching language models to become expert users of your private data and custom APIs, without ever exposing the underlying sensitive information.