# ✈️ Natural-Language → SQL with Reinforcement-Fine-Tuning (RFT)
A full, reproducible demo that trains a 3B-parameter model to answer English questions by writing SQL, without touching real production data.

### 0. Preface – Why this matters
Large-language-model copilots are great, but even the best models don't know your private data and can hallucinate columns that don't exist.
Reinforcement Fine-Tuning (RFT) addresses this by teaching a language model your schema and rewarding it for writing queries that return the correct results. In this tutorial you will:

| You will learn to 📚                               | By the end you’ll have 🎁                          |
| -------------------------------------------------- | -------------------------------------------------- |
| ✓ create a synthetic DB that *mimics* your schema  | `synthetic_openflights.db` (<20 MB) wrapped by an MCP server |
| ✓ generate a diverse query set & ground-truth answers | `generated_queries.json`, `ground_truth_results.jsonl` |
| ✓ build (NL question, SQL result) training pairs   | `training_data.jsonl`                              |
| ✓ run an RFT job on Fireworks AI                   | a tuned **Qwen-2.5-Coder-3B-RFT** model            |
| ✓ benchmark baseline vs. RFT accuracy              | >2× exact-match gain                               |

<br>

> **A Note on Our Method: Demo vs. Real World 🌍** <br>
> Throughout this tutorial, we will be clear about what we're doing for the purpose of this self-contained demo versus what you would do in a real-world scenario.
> - **Real World 🌍**: Look for these notes to see the parallel step you would take in your own environment if you wanted to apply this workflow, typically by swapping in your own private business assets like schemas or query logs.

### 1. Development Environment Setup
**Complete these steps once in your terminal, *outside* this notebook.**

1.  **Get a Fireworks AI API Key**
    - Go to [fireworks.ai](https://fireworks.ai) and sign up.
    - Create an API key from your settings page.
    - Create a file named `.env` in your project directory and add your key:
      ```
      FIREWORKS_API_KEY="YOUR_API_KEY_HERE"
      ```

2.  **Install `uv`**
    - `uv` is a fast Python package manager from Astral. Follow the official installation instructions at [docs.astral.sh/uv/](https://docs.astral.sh/uv/).

3.  **Create a Virtual Environment and Install Packages**
    - Once `uv` is installed, create and activate a virtual environment.
    ```bash
    # Run this in your terminal
    uv venv .venv
    source .venv/bin/activate  # On Windows PowerShell: .venv\Scripts\Activate.ps1
    ```
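    - Install all required packages into the active environment with `uv pip install`. (`uv add` only works inside a `uv` project with a `pyproject.toml`; run `uv init` first if you prefer that workflow.)
    ```bash
    # Run this in your terminal
    uv pip install duckdb tabulate pandas pyarrow requests \
                   faker python-dotenv \
                   jsonlines fireworks-ai \
                   mcp-sdk mcp-server-motherduck
    ```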
After running these commands, your environment is ready. You can proceed with the cells inside this notebook.
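The notebook cells read your key from the environment after `load_dotenv()`. As a minimal sketch (the helper name `require_env` is hypothetical, not part of any SDK), you can fail fast with a clear message when the variable is missing, instead of discovering the problem through an authentication error mid-run:

```python
import os


def require_env(name: str) -> str:
    """Return the value of environment variable `name`, or raise if it is unset or blank."""
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set. "
            "Add it to your .env file and call load_dotenv() before this check."
        )
    return value


# Usage, after `from dotenv import load_dotenv; load_dotenv()`:
# api_key = require_env("FIREWORKS_API_KEY")
```

The variable name must match the one in your `.env` file exactly, which is why checking it once up front is cheaper than debugging a failed API call later.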

### 2. Simulate the "Production" Database
First, we'll create a database that represents your real, populated production database. We'll download the public OpenFlights dataset and load it into a DuckDB file.

> **Real World 🌍**: You already have this! It's your live production database (or a replica). You would skip this entire step.

In [2]:
import urllib.request
import pathlib
import pandas as pd
import duckdb

# --- Download the raw data files ---
DATA_DIR = pathlib.Path("data")
DATA_DIR.mkdir(exist_ok=True)
BASE_URL = "https://raw.githubusercontent.com/jpatokal/openflights/master/data/"
FILES_TO_DOWNLOAD = {
    "airports": "airports.dat",
    "airlines": "airlines.dat",
    "routes": "routes.dat",
    "countries": "countries.dat",
    "planes": "planes.dat"
}
# Define column names as the files don't have headers
COLUMN_NAMES = {
    "airports": ["airport_id", "name", "city", "country", "iata", "icao", "latitude", "longitude", "altitude", "timezone", "dst", "tz_db", "type", "source"],
    "airlines": ["airline_id", "name", "alias", "iata", "icao", "callsign", "country", "active"],
    "routes": ["airline", "airline_id", "source_airport", "source_airport_id", "destination_airport", "destination_airport_id", "codeshare", "stops", "equipment"],
    "countries": ["name", "iso_code", "dafif_code"],
    "planes": ["name", "iata", "icao"]
}

PROD_DB_PATH = "data/prod_openflights.db"

# --- Load the real data into our "production" DuckDB ---
with duckdb.connect(PROD_DB_PATH) as con:
    for name, filename in FILES_TO_DOWNLOAD.items():
        url = f"{BASE_URL}{filename}"
        path = DATA_DIR / filename
        if not path.exists():
            urllib.request.urlretrieve(url, path)
            print(f"✅ Downloaded: {path}")

        # Load data using pandas to handle missing headers and null values
        df = pd.read_csv(path, header=None, names=COLUMN_NAMES[name], na_values=["\\N"])
        con.execute(f"CREATE OR REPLACE TABLE {name} AS SELECT * FROM df")

    print(f"\n✅ 'Production' database simulated at: {PROD_DB_PATH}")
    print("Tables created:", con.sql("SHOW TABLES;").fetchall())


✅ 'Production' database simulated at: data/prod_openflights.db
Tables created: [('airlines',), ('airports',), ('countries',), ('planes',), ('routes',)]


### 3. Acquire the Schema (No Data!)
This is a critical step. We connect to our "production" database and extract **only its schema** (the table structure, column names, and data types). We do not touch or read any of the data rows. This schema is the only artifact we need from the production environment.
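The key property is that schema extraction reads only catalog metadata. As a sketch of the same idea, here using Python's built-in `sqlite3` as a stand-in for DuckDB so it runs without extra packages (DuckDB exposes the equivalent metadata through `DESCRIBE` and `information_schema.columns`), the hypothetical helper below collects table and column definitions without ever selecting from the tables themselves:

```python
import sqlite3


def extract_schema(con: sqlite3.Connection) -> dict[str, list[tuple[str, str]]]:
    """Map each table name to its (column_name, declared_type) pairs using catalog metadata only."""
    schema = {}
    tables = con.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    for (table,) in tables:
        # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk) for each column.
        cols = con.execute(f"PRAGMA table_info({table})").fetchall()
        schema[table] = [(c[1], c[2]) for c in cols]
    return schema


con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE airlines (airline_id INTEGER, name TEXT, country TEXT)")
con.execute("INSERT INTO airlines VALUES (1, 'Secret Air', 'Nowhere')")  # a row we never read
print(extract_schema(con))
```

Only `sqlite_master` and `PRAGMA table_info` are queried, so the inserted row never leaves the database.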

In [3]:
import duckdb

# Connect to the "production" database we just created
with duckdb.connect(PROD_DB_PATH, read_only=True) as con:
    # The DESCRIBE command gives us the schema information for all tables
    schema_df = con.sql("DESCRIBE;").df()

print("✅ Schema successfully extracted from 'production' database:")
print(schema_df.to_markdown(index=False))

# We can also store this for later use in prompts
schema_for_prompt = schema_df.to_markdown(index=False)

✅ Schema successfully extracted from 'production' database:
| database         | schema   | name      | column_names                                                               | column_types                                                          | temporary   |
|:-----------------|:---------|:----------|:---------------------------------------------------------------------------|:----------------------------------------------------------------------|:------------|
| prod_openflights | main     | airlines  | ['airline_id' 'name' 'alias' 'iata' 'icao' 'callsign' 'country' 'active']  | ['BIGINT' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' 'VARCHAR' | False       |
|                  |          |           |                                                                            |  'VARCHAR']                                                           |             |
| prod_openflights | main     | airports  | ['airport_id' 'name' 'city' 'country' 'iata' 'icao' 'latitude' 'long

### 4. Create the Synthetic Training Sandbox with an LLM
Now that we have the schema, we will use a large language model to generate a complete, contextually aware synthetic dataset.

To ensure the LLM's output is structured and parseable, we will **dynamically generate a Pydantic schema** from the `DESCRIBE` output of the previous step. Because nothing is hard-coded, the same code works for other database schemas, provided their column types are covered by `map_sql_type_to_python` (unrecognized types fall back to `object`).

To fine-tune our model with RFT, **we will only interact with this synthetic database.**
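The dynamic-schema idea can also be sketched without Pydantic. The stdlib-only code below builds, from `(column, SQL type)` pairs, a JSON Schema with the same overall shape that `SyntheticDataset.model_json_schema()` describes: one required array of row objects per table, with every column nullable. The type rules and function names are simplified assumptions, and Pydantic's real output differs in detail (it adds `title`, `$defs`, and `anyOf` for optional fields), so treat this as an illustration of the structure, not a drop-in replacement.

```python
import json

# Simplified SQL -> JSON Schema type rules (an assumption; extend them for your own types).
# Rules are checked in order, so INTERVAL must come before INT (it contains "INT").
_TYPE_RULES = [
    ("INTERVAL", "string"),
    ("DECIMAL", "number"),
    ("DOUBLE", "number"),
    ("FLOAT", "number"),
    ("REAL", "number"),
    ("INT", "integer"),   # matches INTEGER, BIGINT, SMALLINT, ...
    ("BOOL", "boolean"),  # matches BOOL and BOOLEAN
]


def json_type_for(sql_type: str) -> str:
    """Return the JSON Schema type for a SQL type; anything unmatched is sent as a string."""
    upper = sql_type.upper()
    for needle, json_type in _TYPE_RULES:
        if needle in upper:
            return json_type
    return "string"  # VARCHAR, DATE, TIMESTAMP, UUID, ...


def dataset_json_schema(tables: dict[str, list[tuple[str, str]]]) -> dict:
    """One required array of row objects per table; every column is nullable."""
    properties = {}
    for table, columns in tables.items():
        properties[table] = {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {col: {"type": [json_type_for(t), "null"]} for col, t in columns},
            },
        }
    return {"type": "object", "properties": properties, "required": list(tables)}


schema = dataset_json_schema({"airlines": [("airline_id", "BIGINT"), ("name", "VARCHAR")]})
print(json.dumps(schema, indent=2))
```

Making every column nullable mirrors the `Optional[...]` fields in the notebook: the LLM may omit values, and the write step later fills any missing columns with `None`.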

In [4]:
import pandas as pd
import os
from pydantic import create_model, BaseModel
from fireworks import LLM
import duckdb
import json
from dotenv import load_dotenv
from typing import List, Optional, Any, Dict, Type
import datetime
import decimal
import uuid
import math
import time

# --- 1. Dynamically Create Pydantic Models from the SQL Schema ---
def map_sql_type_to_python(sql_type: str) -> Type:
    """Maps SQL data types to Python types for Pydantic models."""
    sql_type_upper = str(sql_type).upper()
    if 'DECIMAL' in sql_type_upper: return decimal.Decimal
    if 'DOUBLE' in sql_type_upper or 'FLOAT' in sql_type_upper or 'REAL' in sql_type_upper: return float
    if 'BIGINT' in sql_type_upper or 'INT' in sql_type_upper: return int
    if 'VARCHAR' in sql_type_upper or 'TEXT' in sql_type_upper or 'STRING' in sql_type_upper: return str
    if 'TIMESTAMP' in sql_type_upper: return datetime.datetime
    if 'DATE' in sql_type_upper: return datetime.date
    if 'TIME' in sql_type_upper: return datetime.time
    if 'BOOLEAN' in sql_type_upper: return bool
    if 'BLOB' in sql_type_upper or 'BYTEA' in sql_type_upper: return bytes
    if 'UUID' in sql_type_upper: return uuid.UUID
    return object

pydantic_models: Dict[str, Type[BaseModel]] = {}
table_names = schema_df['name'].unique()

for table_name in table_names:
    table_schema = schema_df[schema_df['name'] == table_name].iloc[0]
    fields: Dict[str, Any] = {}
    col_names = table_schema['column_names']
    col_types = table_schema['column_types']
    for i, col_name in enumerate(col_names):
        python_type = map_sql_type_to_python(col_types[i])
        fields[col_name] = (Optional[python_type], None)
    model_name = table_name.capitalize() + "Model"
    pydantic_models[table_name] = create_model(model_name, **fields)

dataset_fields: Dict[str, Any] = {
    table_name: (List[model], ...) for table_name, model in pydantic_models.items()
}
SyntheticDataset = create_model('SyntheticDataset', **dataset_fields)
print("✅ Dynamically created Pydantic models for all tables.")


# --- 2. Define Total Row Counts and Chunking Strategy ---
TOTAL_ROW_COUNTS = {name: 50 for name in table_names}
ROWS_PER_API_CALL = 2 # Ask for data in small, safe chunks
print("\n✅ Data Generation Plan:")
print(f" - Target rows per table: {list(TOTAL_ROW_COUNTS.values())[0]}")
print(f" - Will make API calls asking for {ROWS_PER_API_CALL} rows/call until target is met.")


# --- 3. Setup LLM and Loop to Generate Data in Chunks ---
SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"
load_dotenv()
llm = LLM(model="accounts/fireworks/models/deepseek-v3", deployment_type="serverless", api_key=os.getenv("FIREWORKS_API_KEY"))

all_synthetic_data: Dict[str, List[Dict]] = {name: [] for name in table_names}
chunk_row_counts = {name: ROWS_PER_API_CALL for name in table_names}

base_generation_prompt = f"""
You are a highly intelligent AI data generator. Your task is to create a realistic, synthetic dataset based on the provided database schema.
The data you generate must be internally consistent. For example, an `airline_id` in a `routes` table must correspond to an existing `airline_id` in an `airlines` table within this same generated chunk.
This applies to any schema you might be working with, not just airlines.
You must generate a single JSON object that strictly adheres to the provided JSON schema.

The database schema is as follows:
{schema_for_prompt}
"""

call_count = 0
# Loop until all tables have at least the desired number of rows
while not all(len(rows) >= TOTAL_ROW_COUNTS[name] for name, rows in all_synthetic_data.items()):
    call_count += 1
    print(f"\n📞 --- Generating data chunk #{call_count} ---")
    
    # --- Create a summary of existing data to guide the LLM ---
    existing_data_summary = ""
    if any(len(rows) > 0 for rows in all_synthetic_data.values()):
        summary_parts = ["\nYou have already generated the following data. Do NOT generate rows that are substantially similar to these examples. Create new, distinct data.\n"]
        for table_name, rows in all_synthetic_data.items():
            if rows:
                summary_parts.append(f"\n--- Existing data in '{table_name}' table ---")
                df = pd.DataFrame(rows)
                if len(df.columns) > 10:
                    df = df.iloc[:, :10]
                markdown_summary = df.to_markdown(index=False, tablefmt="grid")
                if markdown_summary:
                    summary_parts.append(markdown_summary)
        existing_data_summary = "\n".join(summary_parts)


    # --- Construct the final prompt for this iteration ---
    final_prompt = (
        base_generation_prompt +
        existing_data_summary +
        f"\n\nNow, generate a NEW JSON object with a key for each table. The number of new rows for each table should be:\n" +
        json.dumps(chunk_row_counts, indent=2)
    )

    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": final_prompt}],
        response_format={"type": "json_schema", "json_schema": {"name": "SyntheticDataset", "schema": SyntheticDataset.model_json_schema()}},
        temperature=0.7
    )

    choice = response.choices[0]
    response_content = choice.message.content

    if choice.finish_reason == "length":
        print(f"⚠️ WARNING: Chunk #{call_count} was truncated. Skipping.")
        continue
    if not response_content:
        print(f"⚠️ WARNING: Received empty content for chunk #{call_count}. Skipping.")
        continue

    try:
        chunk_data = json.loads(response_content)
        print(f"✅ Received and parsed chunk #{call_count}.")
        for table_name, rows in chunk_data.items():
            if table_name in all_synthetic_data and rows:
                all_synthetic_data[table_name].extend(rows)
        # Log progress
        for name, rows in all_synthetic_data.items():
            print(f"   - '{name}': {len(rows)} / {TOTAL_ROW_COUNTS[name]} rows")
    except json.JSONDecodeError as e:
        print(f"❌ ERROR: Failed to parse JSON for chunk #{call_count}. Reason: {e}. Skipping.")
    
    time.sleep(1)

# --- 4. Deduplicate and Write to DB ---
print("\n✨ Data generation complete. Aggregating, deduplicating, and saving to database...")

synthetic_data = all_synthetic_data
print("\n--- Deduplicating generated data ---")
for table_name, rows in synthetic_data.items():
    if not rows: continue
    initial_count = len(rows)
    df = pd.DataFrame(rows).drop_duplicates()
    final_count = len(df)
    synthetic_data[table_name] = df.to_dict('records')
    print(f" - Table '{table_name}': Removed {initial_count - final_count} duplicates ({initial_count} -> {final_count}).")

# Final trim to ensure exact counts
for table_name, total_rows_needed in TOTAL_ROW_COUNTS.items():
    if table_name in synthetic_data:
        synthetic_data[table_name] = synthetic_data[table_name][:total_rows_needed]

with duckdb.connect(SYNTHETIC_DB_PATH) as con:
    for table_name, rows in synthetic_data.items():
        if rows:
            df = pd.DataFrame(rows)
            schema_cols = schema_df[schema_df['name'] == table_name].iloc[0]['column_names']
            for col in schema_cols:
                if col not in df.columns: df[col] = None
            df = df[schema_cols]
            con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df")
    
    print(f"\n✅ Synthetic training sandbox created at: {SYNTHETIC_DB_PATH}")
    print("Tables created:", con.sql("SHOW TABLES;").fetchall())

✅ Dynamically created Pydantic models for all tables.

✅ Data Generation Plan:
 - Target rows per table: 50
 - Will make API calls asking for 2 rows/call until target is met.

📞 --- Generating data chunk #1 ---
✅ Received and parsed chunk #1.
   - 'airlines': 2 / 50 rows
   - 'airports': 2 / 50 rows
   - 'countries': 2 / 50 rows
   - 'planes': 2 / 50 rows
   - 'routes': 2 / 50 rows

📞 --- Generating data chunk #2 ---
✅ Received and parsed chunk #2.
   - 'airlines': 4 / 50 rows
   - 'airports': 4 / 50 rows
   - 'countries': 4 / 50 rows
   - 'planes': 4 / 50 rows
   - 'routes': 4 / 50 rows

📞 --- Generating data chunk #3 ---
✅ Received and parsed chunk #3.
   - 'airlines': 6 / 50 rows
   - 'airports': 6 / 50 rows
   - 'countries': 6 / 50 rows
   - 'planes': 6 / 50 rows
   - 'routes': 6 / 50 rows

📞 --- Generating data chunk #4 ---
✅ Received and parsed chunk #4.
   - 'airlines': 8 / 50 rows
   - 'airports': 8 / 50 rows
   - 'countries': 8 / 50 rows
   - 'planes': 8 / 50 rows
   - 'routes

### 4.1 Validate the Synthetic Sandbox
Let's inspect our new synthetic database to check that the LLM generated plausible data. The cell below prints up to 10 rows from each of the first five tables (here, all of them). It only samples rows and does not run any joins, so it is a sanity check on the values rather than proof that the tables reference each other consistently.
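Sample rows alone don't show whether foreign keys line up across tables, especially since the LLM generated the data in independent chunks. A join-based check makes this explicit: count the child rows whose key has no match in the parent table. Below is a minimal sketch using Python's built-in `sqlite3` with toy rows as a stand-in (the same `LEFT JOIN ... IS NULL` SQL also runs in DuckDB); the helper name `count_orphans` is hypothetical, and the table and column names follow the OpenFlights schema above.

```python
import sqlite3


def count_orphans(con, child: str, child_key: str, parent: str, parent_key: str) -> int:
    """Count non-NULL child keys that have no matching row in the parent table."""
    sql = f"""
        SELECT COUNT(*)
        FROM {child} AS c
        LEFT JOIN {parent} AS p ON c.{child_key} = p.{parent_key}
        WHERE c.{child_key} IS NOT NULL AND p.{parent_key} IS NULL
    """
    return con.execute(sql).fetchone()[0]


con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE airlines (airline_id INTEGER, name TEXT);
    CREATE TABLE routes (airline_id INTEGER, source_airport TEXT, destination_airport TEXT);
    INSERT INTO airlines VALUES (1, 'Sky High Airlines'), (2, 'Oceanic Airways');
    INSERT INTO routes VALUES (1, 'JFK', 'LHR'), (2, 'LHR', 'SYD'), (99, 'SYD', 'YYZ');
""")
print(count_orphans(con, "routes", "airline_id", "airlines", "airline_id"))  # → 1
```

On the synthetic database you would run the same check for each relationship you care about, for example `routes.airline_id → airlines.airline_id` and `routes.source_airport_id → airports.airport_id`; a non-zero count means joins on that key will silently drop rows.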

In [5]:
import duckdb
from tabulate import tabulate

# Connect to the synthetic database
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    
    # Get the list of all tables created
    all_tables = [table[0] for table in con.sql("SHOW TABLES;").fetchall()]
    
    # Select the first 5 tables to display (or all if fewer than 5)
    tables_to_validate = all_tables[:5]

    print("--- Validating the first few tables in the synthetic sandbox ---\n")

    # Execute and print results for the selected tables
    for table_name in tables_to_validate:
        print(f"--- SELECT * FROM {table_name} LIMIT 10; ---")
        try:
            result_df = con.sql(f"SELECT * FROM {table_name} LIMIT 10;").df()
            if not result_df.empty:
                print(tabulate(result_df, headers='keys', tablefmt='psql'))
            else:
                print(f"(Table '{table_name}' is empty)")
        except Exception as e:
            print(f"Query failed for table '{table_name}': {e}")
        print("\n")

--- Validating the first few tables in the synthetic sandbox ---

--- SELECT * FROM airlines LIMIT 10; ---
+----+--------------+-------------------+---------+--------+--------+------------+----------------+----------+
|    |   airline_id | name              | alias   | iata   | icao   | callsign   | country        | active   |
|----+--------------+-------------------+---------+--------+--------+------------+----------------+----------|
|  0 |            1 | Sky High Airlines | SHA     | SK     | SKY    | SKYHIGH    | United States  | Y        |
|  1 |            2 | Oceanic Airways   | OA      | OC     | OCN    | OCEANIC    | United Kingdom | Y        |
|  2 |            3 | Pacific Wings     | PW      | PW     | PWI    | PACWINGS   | Australia      | Y        |
|  3 |            4 | Arctic Airlines   | AA      | AR     | ARC    | ARCTIC     | Canada         | Y        |
|  4 |            5 | Global Airways    | GA      | GL     | GLO    | GLOBAL     | Germany        | Y        |
|  5 

### 5. Generate Example SQL Queries
With our synthetic database in place, the next step is to create a set of synthetic SQL queries. We will run these queries against the synthetic database to obtain the ground-truth results that serve as RFT labels, and we will also give them to an LLM to generate matching questions in natural language. Together, these steps produce our final RFT dataset, which pairs natural-language questions with ground-truth query results.

> **Real World 🌍**: You would use a historical log of real SQL queries that have been run against your production database. These logs are the most valuable source of training data because they represent the *actual* way your users query your data.
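LLM-generated SQL can reference columns that don't exist (for example `route_id`, which the `routes` table lacks). One way to catch such queries at generation time, sketched below under the assumption that you call it on each batch, is to deduplicate on a whitespace-normalized key and ask the database to compile each query with `EXPLAIN` before keeping it. The sketch uses Python's `sqlite3` as a stand-in; DuckDB also supports `EXPLAIN`, and its binder errors surface the same way. The helper names are hypothetical.

```python
import re
import sqlite3


def normalize_sql(query: str) -> str:
    """Collapse whitespace and drop a trailing semicolon, for use as a dedup key."""
    return re.sub(r"\s+", " ", query).strip().rstrip(";").strip()


def keep_valid_unique(con: sqlite3.Connection, queries: list[str]) -> list[str]:
    """Keep queries that are new after normalization and that the database can compile."""
    seen: set[str] = set()
    kept = []
    for query in queries:
        key = normalize_sql(query)
        if key in seen:
            continue
        seen.add(key)
        try:
            # EXPLAIN compiles the statement (resolving tables and columns) without running it.
            con.execute(f"EXPLAIN {query}").fetchall()
        except sqlite3.Error as err:
            print(f"Dropping invalid query: {key} ({err})")
            continue
        kept.append(query)
    return kept


con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE routes (airline_id INTEGER, source_airport TEXT)")
candidates = [
    "SELECT COUNT(*) FROM routes;",
    "SELECT  COUNT(*)\nFROM routes",        # same query after normalization
    "SELECT COUNT(route_id) FROM routes;",  # references a column that does not exist
]
print(keep_valid_unique(con, candidates))
```

Filtering inside the generation loop, rather than after it, means invalid or duplicate queries don't count toward `TOTAL_QUERIES_TO_GENERATE`.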

In [6]:
import pandas as pd
import json
import time
from pydantic import BaseModel, Field
from typing import List

# --- 1. Define Generation Parameters and Pydantic Model ---
TOTAL_QUERIES_TO_GENERATE = 20
QUERIES_PER_API_CALL = 10

class SqlQueryBatch(BaseModel):
    queries: List[str] = Field(description=f"A list of exactly {QUERIES_PER_API_CALL} unique and diverse SQL queries.")

print(f"🎯 Goal: Generate {TOTAL_QUERIES_TO_GENERATE} unique queries in batches of {QUERIES_PER_API_CALL}.")

# --- 2. Setup Base Prompt and Generation Loop ---
base_query_generation_prompt = f"""
You are an expert SQL data analyst. Your task is to generate unique and diverse SQL queries based on the database schema provided.
The queries should be realistic and cover a range of complexities and SQL features (JOINS, GROUP BY, aggregates, etc.).
Ensure the generated SQL is valid for DuckDB.

**Database Schema:**
{schema_for_prompt}
"""

all_generated_queries = []
# Loop until we have enough queries
while len(all_generated_queries) < TOTAL_QUERIES_TO_GENERATE:
    print(f"\n📞 --- Generating batch #{len(all_generated_queries) // QUERIES_PER_API_CALL + 1} ---")

    # Create a summary of queries generated so far to prevent duplicates
    existing_queries_summary = ""
    if all_generated_queries:
        summary_parts = ["\nYou have already generated the following queries. Generate NEW, DISTINCT queries that are not on this list and cover different analytic scenarios.\n"]
        for i, q in enumerate(all_generated_queries):
            summary_parts.append(f"{i+1}. {q}")
        existing_queries_summary = "\n".join(summary_parts)

    # Construct the final prompt for this iteration
    final_prompt = (
        base_query_generation_prompt +
        existing_queries_summary +
        f"\n\nNow, generate {QUERIES_PER_API_CALL} new and unique SQL queries. Return your response as a single JSON object adhering to the specified schema."
    )

    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": final_prompt}],
        response_format={"type": "json_schema", "json_schema": {"name": "SqlQueryBatch", "schema": SqlQueryBatch.model_json_schema()}},
        temperature=0.8
    )

    response_content = response.choices[0].message.content
    if response_content:
        try:
            new_queries = json.loads(response_content).get("queries", [])
            all_generated_queries.extend(new_queries)
            print(f"   - Received {len(new_queries)} new queries. Total now: {len(all_generated_queries)} / {TOTAL_QUERIES_TO_GENERATE}")
        except json.JSONDecodeError as e:
            print(f"❌ ERROR: Failed to parse generated queries in this batch: {e}")
    
    time.sleep(1) # Be nice to the API

# --- 3. Deduplicate, Trim, and Save --- 
print("\n✨ Generation complete. Deduplicating and saving...")
initial_count = len(all_generated_queries)
# Simple, fast deduplication preserving order
unique_queries = list(dict.fromkeys(all_generated_queries))
final_count = len(unique_queries)
print(f" - Removed {initial_count - final_count} duplicates ({initial_count} -> {final_count}).")

# Trim to the exact number we need
final_queries = unique_queries[:TOTAL_QUERIES_TO_GENERATE]

# Save the final list to a file
QUERIES_FILE_PATH = "data/generated_queries.json"
with open(QUERIES_FILE_PATH, 'w') as f:
    json.dump({"queries": final_queries}, f, indent=2)

print(f"\n✅ Successfully saved {len(final_queries)} unique queries to `{QUERIES_FILE_PATH}`.")
print("\n--- Here are a few examples: ---")
for query in final_queries[:5]:
    print(f"- {query}")

🎯 Goal: Generate 20 unique queries in batches of 10.

📞 --- Generating batch #1 ---
   - Received 10 new queries. Total now: 10 / 20

📞 --- Generating batch #2 ---
   - Received 10 new queries. Total now: 20 / 20

✨ Generation complete. Deduplicating and saving...
 - Removed 0 duplicates (20 -> 20).

✅ Successfully saved 20 unique queries to `data/generated_queries.json`.

--- Here are a few examples: ---
- SELECT a.name AS airline_name, COUNT(r.route_id) AS num_routes FROM airlines a LEFT JOIN routes r ON a.airline_id = r.airline_id GROUP BY a.name ORDER BY num_routes DESC LIMIT 10;
- SELECT c.name AS country, COUNT(a.airport_id) AS num_airports FROM countries c LEFT JOIN airports a ON c.iso_code = a.country GROUP BY c.name ORDER BY num_airports DESC;
- SELECT p.name AS plane, COUNT(r.route_id) AS num_routes FROM planes p LEFT JOIN routes r ON p.iata = r.equipment GROUP BY p.name ORDER BY num_routes DESC;
- SELECT a.city, COUNT(*) AS num_airports FROM airports a GROUP BY a.city HAVING

### 6. Execute Queries to Get Ground-Truth Answers
Now we will act as the "system" and run the queries we just generated against our synthetic sandbox. The output of each query is the **ground-truth result**. During Reinforcement Fine-Tuning, our model will be rewarded if the SQL it writes produces this exact same result.

> **Real World 🌍**: You would run your real historical queries against the synthetic database we created earlier. Whether the synthetic data is realistic does not matter here: we only need to know what a correct query returns on this database, so that during RFT we can compare it with the results of the queries our LLM writes.
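The comparison itself can be simple. Below is a minimal sketch of a result-based reward, not the evaluator Fireworks actually runs: it executes the model's SQL and returns 1.0 if the rows match the ground truth as a multiset, else 0.0. It uses `sqlite3` as a stand-in for the synthetic DuckDB database, and `result_reward` is a hypothetical name. Ignoring row order and column names is a design choice: it forgives harmless differences such as a different alias, but it also accepts a query that omits a required `ORDER BY`.

```python
import math
import sqlite3
from collections import Counter


def _normalize_row(values) -> tuple:
    """Treat NaN as NULL and round floats so tiny numeric differences don't break a match."""
    out = []
    for v in values:
        if isinstance(v, float):
            v = None if math.isnan(v) else round(v, 6)
        out.append(v)
    return tuple(out)


def result_reward(con: sqlite3.Connection, candidate_sql: str, ground_truth: list[dict]) -> float:
    """Return 1.0 if candidate_sql yields the ground-truth rows (as a multiset), else 0.0."""
    try:
        rows = con.execute(candidate_sql).fetchall()
    except sqlite3.Error:
        return 0.0  # SQL that fails to run earns no reward
    predicted = Counter(_normalize_row(row) for row in rows)
    # Column names are ignored; values are compared positionally in column order.
    expected = Counter(_normalize_row(record.values()) for record in ground_truth)
    return 1.0 if predicted == expected else 0.0


con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE airports (airport_id INTEGER, country TEXT);
    INSERT INTO airports VALUES (1, 'Canada'), (2, 'Canada'), (3, 'Germany');
""")
truth = [{"country": "Canada", "num_airports": 2}, {"country": "Germany", "num_airports": 1}]
good = "SELECT country, COUNT(*) AS n FROM airports GROUP BY country ORDER BY n"
bad = "SELECT country, COUNT(*) FROM airports WHERE country = 'Canada' GROUP BY country"
print(result_reward(con, good, truth), result_reward(con, bad, truth))  # → 1.0 0.0
```

A binary reward like this is the strictest option; a partial-credit variant (for example, the fraction of expected rows recovered) is another possible design.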

In [7]:
import duckdb
import json
import pandas as pd

# --- 1. Define File Paths ---
SYNTHETIC_DB_PATH = "data/synthetic_openflights.db"
QUERIES_FILE_PATH = "data/generated_queries.json"
GROUND_TRUTH_FILE_PATH = "data/ground_truth_results.jsonl"

# --- 2. Load Generated Queries ---
with open(QUERIES_FILE_PATH, 'r') as f:
    queries_data = json.load(f)
    queries_to_execute = queries_data.get("queries", [])

print(f"Loaded {len(queries_to_execute)} queries to execute.")

# --- 3. Execute Queries and Store Results ---
ground_truth_results = []
successful_executions = 0
failed_executions = 0

print("Executing queries against the synthetic database...")
with duckdb.connect(SYNTHETIC_DB_PATH, read_only=True) as con:
    for query in queries_to_execute:
        try:
            # Execute the query and convert the result to a list of dictionaries
            result_df = con.sql(query).df()
            result_records = result_df.to_dict('records')
            
            # Pair the query with its result
            ground_truth_results.append({
                "query": query,
                "result": result_records
            })
            successful_executions += 1
        except Exception as e:
            # The LLM might have occasionally generated a slightly invalid query (this should not apply in the case of a real-world sample of successful queries)
            print(f"⚠️  Skipping query due to execution error: {query}\n   Error: {e}\n")
            failed_executions += 1

print(f"\nExecution complete. Success: {successful_executions}, Failed: {failed_executions}.")

# --- 4. Save the Ground-Truth Data ---
with open(GROUND_TRUTH_FILE_PATH, 'w') as f:
    for entry in ground_truth_results:
        # default=str stringifies values json can't encode natively (e.g., Timestamps, Decimals)
        f.write(json.dumps(entry, default=str) + '\n')

print(f"\n✅ Successfully saved {len(ground_truth_results)} ground-truth results to `{GROUND_TRUTH_FILE_PATH}`.")

# --- 5. Print an Example ---
if ground_truth_results:
    print("\n--- Example ground-truth entry ---")
    print(json.dumps(ground_truth_results[0], indent=2, default=str))

Loaded 20 queries to execute.
Executing queries against the synthetic database...
⚠️  Skipping query due to execution error: SELECT a.name AS airline_name, COUNT(r.route_id) AS num_routes FROM airlines a LEFT JOIN routes r ON a.airline_id = r.airline_id GROUP BY a.name ORDER BY num_routes DESC LIMIT 10;
   Error: Binder Error: Table "r" does not have a column named "route_id"

Candidate bindings: : "source_airport"

⚠️  Skipping query due to execution error: SELECT p.name AS plane, COUNT(r.route_id) AS num_routes FROM planes p LEFT JOIN routes r ON p.iata = r.equipment GROUP BY p.name ORDER BY num_routes DESC;
   Error: Binder Error: Table "r" does not have a column named "route_id"

Candidate bindings: : "source_airport"


Execution complete. Success: 18, Failed: 2.

✅ Successfully saved 18 ground-truth results to `data/ground_truth_results.jsonl`.

--- Example ground-truth entry ---
{
  "query": "SELECT c.name AS country, COUNT(a.airport_id) AS num_airports FROM countries c LEFT JOIN

### 7. Generate Natural Language Questions for Final RFT Training Data
We now have pairs of `(SQL Query, Ground-Truth Result)`. The piece still missing from our training data is the user's input: a question in natural language. Our goal is to use RFT to teach an LLM to map a natural-language question to a SQL query, with the reward computed from the query's *result* rather than from the SQL text itself. This matters because many differently written SQL queries return the same, correct result.

So we will use an LLM once again, this time to translate each "historical" SQL query into a plausible question a business user might ask to get that information. This yields our final training dataset in the format `(Natural Language Question, Ground-Truth Result)`.

> **Real World 🌍**: You might not need this step! If you have logs that already link user questions to the queries they ran (e.g., from a BI tool's search bar), you can use those directly. If not, this LLM-based translation is a powerful technique to bootstrap your training data.
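To see why rewarding results instead of SQL text matters, here are two differently written queries (a `JOIN` with `GROUP BY` versus a correlated subquery) that return the same rows. A text-matching reward would penalize one of them; a result-matching reward accepts both. The sketch uses Python's `sqlite3` and toy rows purely for illustration.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE airlines (airline_id INTEGER, name TEXT);
    CREATE TABLE routes (airline_id INTEGER, source_airport TEXT);
    INSERT INTO airlines VALUES (1, 'Sky High Airlines'), (2, 'Oceanic Airways');
    INSERT INTO routes VALUES (1, 'JFK'), (1, 'LAX'), (2, 'LHR');
""")

# Version A: join, then aggregate.
query_a = """
    SELECT a.name, COUNT(*) AS num_routes
    FROM airlines AS a JOIN routes AS r ON a.airline_id = r.airline_id
    GROUP BY a.name
"""
# Version B: correlated subquery with different text and structure.
query_b = """
    SELECT name,
           (SELECT COUNT(*) FROM routes AS r WHERE r.airline_id = a.airline_id) AS num_routes
    FROM airlines AS a
"""

rows_a = sorted(con.execute(query_a).fetchall())
rows_b = sorted(con.execute(query_b).fetchall())
print(rows_a == rows_b, rows_a)  # → True [('Oceanic Airways', 1), ('Sky High Airlines', 2)]
```

The two agree here only because every airline has at least one route: for an airline with no routes, version A would omit it while version B would report 0. Executing both queries against the database is what exposes differences like this.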

In [8]:
import json
import time
import jsonlines
from typing import List

# --- 1. Define File Paths and Parameters ---
GROUND_TRUTH_FILE_PATH = "data/ground_truth_results.jsonl"
FINAL_TRAINING_DATA_PATH = "data/training_data.jsonl"

# --- 2. Load Ground-Truth Data ---
query_result_pairs = []
with jsonlines.open(GROUND_TRUTH_FILE_PATH) as reader:
    for obj in reader:
        query_result_pairs.append(obj)

print(f"Loaded {len(query_result_pairs)} query-result pairs.")

# --- 3. Use LLM to Generate Natural Language Questions ---
nl_generation_prompt_template = f"""
You are an expert data analyst who is great at translating SQL queries into plain English.
Based on the database schema and the provided SQL query, what is a natural language question a business user would ask to get this information?
Ensure that the question is precise enough to accurately map to the corresponding SQL query.

**Database Schema:**
{schema_for_prompt}

**SQL Query:**
```sql
{{query}}
```

Provide only the user's question, without any preamble or explanation.
"""

# The system prompt that will be included in the final training data for the RFT job.
# It gives the model its instructions at inference time.
rft_system_prompt = f"""
You are an expert SQL data analyst. Your task is to write a single, valid DuckDB SQL query to answer the user's question, based on the provided database schema. Do not provide any explanation or text other than the SQL query itself.

**Database Schema:**
{schema_for_prompt}
"""

final_training_data = []
print(f"Generating natural language questions and formatting for RFT for {len(query_result_pairs)} queries...")

for i, pair in enumerate(query_result_pairs):
    print(f" - Processing query {i+1}/{len(query_result_pairs)}...")
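    # str.replace (not str.format) so braces that may appear in other schemas' type names can't break templating
    nl_generation_prompt = nl_generation_prompt_template.replace("{query}", pair['query'])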
    
    response = llm.chat.completions.create(
        messages=[{"role": "user", "content": nl_generation_prompt}],
        temperature=0.5
    )
    
    nl_question = response.choices[0].message.content
    if nl_question:
        # Assemble the final data structure
        rft_entry = {
            "messages": [
                {"role": "system", "content": rft_system_prompt},
                {"role": "user", "content": nl_question.strip()}
            ],
            "ground_truth": pair['result'] # The ground-truth result for the evaluator
        }
        final_training_data.append(rft_entry)
    
    time.sleep(1) # Be nice to the API

# --- 4. Save the Final RFT-Ready Dataset ---
with jsonlines.open(FINAL_TRAINING_DATA_PATH, mode='w') as writer:
    writer.write_all(final_training_data)

print(f"\n✅ Successfully created final RFT training dataset with {len(final_training_data)} entries at `{FINAL_TRAINING_DATA_PATH}`.")

# --- 5. Print an Example ---
if final_training_data:
    print("\n--- Example RFT training entry ---")
    print(json.dumps(final_training_data[0], indent=2))

Loaded 18 query-result pairs.
Generating natural language questions and formatting for RFT for 18 queries...
 - Processing query 1/18...
 - Processing query 2/18...
 - Processing query 3/18...
 - Processing query 4/18...
 - Processing query 5/18...
 - Processing query 6/18...
 - Processing query 7/18...
 - Processing query 8/18...
 - Processing query 9/18...
 - Processing query 10/18...
 - Processing query 11/18...
 - Processing query 12/18...
 - Processing query 13/18...
 - Processing query 14/18...
 - Processing query 15/18...
 - Processing query 16/18...
 - Processing query 17/18...
 - Processing query 18/18...

✅ Successfully created final RFT training dataset with 18 entries at `data/training_data.jsonl`.

--- Example RFT training entry ---
{
  "messages": [
    {
      "role": "system",
      "content": "\nYou are an expert SQL data analyst. Your task is to write a single, valid DuckDB SQL query to answer the user's question, based on the provided database schema. Do not provide 

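Before handing the file to the RFT job, it can help to re-read it and check that every line has the shape the loop above builds: one system message, one user question, and a `ground_truth` field. The sketch below uses only the standard library. The helper names `validate_rft_entry` and `validate_jsonl` are hypothetical, and the rules it checks come from the `rft_entry` structure above, not from a Fireworks-published schema.

```python
import json
from typing import Any


def validate_rft_entry(entry: Any) -> list[str]:
    """Return the problems found in one training entry (an empty list means it looks OK)."""
    if not isinstance(entry, dict):
        return ["entry is not a JSON object"]
    problems = []
    messages = entry.get("messages")
    # The loop above builds exactly one system message followed by one user message.
    if (
        not isinstance(messages, list)
        or len(messages) != 2
        or not all(isinstance(m, dict) for m in messages)
    ):
        problems.append("expected exactly two message objects (system + user)")
    else:
        roles = [m.get("role") for m in messages]
        if roles != ["system", "user"]:
            problems.append(f"unexpected roles: {roles}")
        if any(not str(m.get("content", "")).strip() for m in messages):
            problems.append("empty message content")
    # The evaluator needs the stored result to score the model's query.
    if "ground_truth" not in entry:
        problems.append("missing ground_truth")
    return problems


def validate_jsonl(path: str) -> dict[int, list[str]]:
    """Map 1-based line numbers to their problems; lines that pass are omitted."""
    failures: dict[int, list[str]] = {}
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            try:
                entry = json.loads(line)
            except json.JSONDecodeError as exc:
                failures[lineno] = [f"invalid JSON: {exc}"]
                continue
            problems = validate_rft_entry(entry)
            if problems:
                failures[lineno] = problems
    return failures
```

Calling `validate_jsonl(FINAL_TRAINING_DATA_PATH)` should return an empty dict when every entry is well-formed. Any non-empty result points at the exact lines to inspect.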
### 8. 🛰️ Deploy an MCP Server for the Synthetic Data
Now we'll write, and later deploy, a remote server that speaks the Model Context Protocol (MCP). This server wraps our synthetic DuckDB database and gives any external tool (in our case, the Fireworks RFT evaluator) a standardized way to interact with it.
> **Real World 🌍**: This pattern is directly applicable. You would run a similar MCP server to provide a secure, read-only interface to a production database replica or a data warehouse, allowing the fine-tuning process to happen without granting direct database credentials to the training environment.
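Whatever the evaluator's actual implementation, its core job is to run the model's SQL through this server and decide whether the returned rows match the `ground_truth` stored with each training entry. The following is a minimal sketch of such an exact-match check. It assumes results arrive as lists of rows (dicts, lists, or bare scalars). The function names, the float-rounding tolerance, and the binary reward are illustrative choices, not the Fireworks evaluator's documented behavior.

```python
from collections import Counter
from typing import Any, Iterable


def _normalize_row(row: Any, ndigits: int) -> tuple:
    # Accept dict rows ({"col": value}), list/tuple rows, or bare scalars.
    if isinstance(row, dict):
        values = list(row.values())
    elif isinstance(row, (list, tuple)):
        values = list(row)
    else:
        values = [row]
    # Round floats so tiny floating-point differences don't cause a mismatch.
    return tuple(round(v, ndigits) if isinstance(v, float) else v for v in values)


def results_match(predicted: Iterable[Any], expected: Iterable[Any], ndigits: int = 4) -> bool:
    """Order-insensitive exact match: same rows, same multiplicities, same column order."""
    predicted_rows = Counter(_normalize_row(r, ndigits) for r in predicted)
    expected_rows = Counter(_normalize_row(r, ndigits) for r in expected)
    return predicted_rows == expected_rows


def exact_match_reward(predicted: Iterable[Any], expected: Iterable[Any]) -> float:
    """Binary reward: 1.0 if the result sets match, else 0.0 (hypothetical scoring rule)."""
    return 1.0 if results_match(predicted, expected) else 0.0
```

Comparing rows as a multiset ignores row order, which SQL does not guarantee without `ORDER BY`, while still catching missing or duplicated rows. A stricter variant could compare the row lists directly whenever the reference query sorts its output.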

8. a) Create a server script, `run_mcp_server.py`, in this project's root directory. This Python script starts our database server in read-only mode.

```python
import os, contextlib, uvicorn
from starlette.applications import Starlette
from starlette.routing import Mount
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp_server_motherduck import build_application

DB = "data/synthetic_openflights.db"          # ← path from previous steps
PORT = int(os.environ.get("PORT", 8080))        # Cloud Run injects $PORT

# 1️⃣ Build the core SQL-aware MCP server (read-only for safety).
server, _ = build_application(db_path=DB, read_only=True)

# 2️⃣ Wrap it so HTTP clients can talk to it (ASGI handler).
sess = StreamableHTTPSessionManager(app=server, event_store=None, stateless=True)

async def handler(scope, receive, send):
    await sess.handle_request(scope, receive, send)

@contextlib.asynccontextmanager
async def lifespan(app):
    async with sess.run():
        yield                                        # keep sessions alive

# 3️⃣ Starlette turns that handler into a full ASGI app Uvicorn can serve.
app = Starlette(routes=[Mount("/mcp", app=handler)], lifespan=lifespan)

if __name__ == "__main__":
    print(f"🔥 MCP endpoint → http://0.0.0.0:{PORT}/mcp")
    uvicorn.run(app, host="0.0.0.0", port=PORT)
```
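To smoke-test the server before containerizing it, you can run `python run_mcp_server.py` and send it a raw JSON-RPC request. The sketch below uses only the standard library to build MCP requests (`tools/list`, `tools/call`) and to parse the reply. The Streamable HTTP transport may return that reply either as plain JSON or as a Server-Sent Events stream. The helper names are hypothetical. The exact name and arguments of the SQL tool exposed by `mcp-server-motherduck` are not covered here, so read them from the `tools/list` response rather than assuming them.

```python
import json
import urllib.request
from itertools import count
from typing import Optional

_ids = count(1)  # JSON-RPC request ids should be unique within a session


def jsonrpc_request(method: str, params: Optional[dict] = None) -> dict:
    """Build a JSON-RPC 2.0 request object, the message format MCP uses."""
    request = {"jsonrpc": "2.0", "id": next(_ids), "method": method}
    if params is not None:
        request["params"] = params
    return request


def tools_call(tool_name: str, arguments: dict) -> dict:
    # MCP's tools/call takes the tool name plus a dict of that tool's arguments.
    return jsonrpc_request("tools/call", {"name": tool_name, "arguments": arguments})


def parse_mcp_response(body: str, content_type: str) -> list:
    """Return JSON-RPC messages from either a JSON body or an SSE stream."""
    if content_type.startswith("application/json"):
        parsed = json.loads(body)
        return parsed if isinstance(parsed, list) else [parsed]
    messages, data_lines = [], []
    # SSE events are separated by blank lines; an event's data may span several lines.
    for line in body.splitlines() + [""]:
        if line.startswith("data:"):
            value = line[len("data:"):]
            data_lines.append(value[1:] if value.startswith(" ") else value)
        elif line == "" and data_lines:
            messages.append(json.loads("\n".join(data_lines)))
            data_lines = []
    return messages


def post_mcp(url: str, request: dict, timeout: float = 30.0) -> list:
    """POST one JSON-RPC request to a Streamable HTTP MCP endpoint."""
    http_request = urllib.request.Request(
        url,
        data=json.dumps(request).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            # Streamable HTTP clients advertise that they accept both response formats.
            "Accept": "application/json, text/event-stream",
        },
        method="POST",
    )
    with urllib.request.urlopen(http_request, timeout=timeout) as resp:
        body = resp.read().decode("utf-8")
        return parse_mcp_response(body, resp.headers.get("Content-Type", ""))
```

For example, `post_mcp("http://localhost:8080/mcp/", jsonrpc_request("tools/list"))` should return the server's tool list. The trailing slash is there because Starlette's `Mount("/mcp", ...)` likely answers a bare `/mcp` with a redirect, which `urllib` does not follow for POST requests. The script's `stateless=True` setting should let a single request work without a prior `initialize` handshake. If your server version rejects it, send an `initialize` request first.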


### 9. ☁️ Set Up Google Cloud CLI
We'll first set up the Google Cloud CLI and authenticate.

> **Real World 🌍**  
> You would follow these steps in the same way.

9. a) **Install** the SDK (macOS/Linux):

      ```bash
      curl -sSL https://sdk.cloud.google.com | bash
      exec -l $SHELL  # reload shell so 'gcloud' is available
      ```

9. b) **Log in** (creates a local access token):
      ```bash
      gcloud auth login
      ```

9. c) **Set your active gcloud project:**
      ```bash
      gcloud config set project YOUR_PROJECT_ID  # create the project in the Cloud Console first if it doesn't exist yet
      ```

### 10. 📦 Containerize & Deploy the MCP Server  
We’ll build a Docker image and push it straight to Cloud Run.  
Remember to replace **`YOUR_PROJECT_ID`** with the project you actually want to bill.

> **Real World 🌍**  
> You would follow along in the same way here.

10. a) Create `mcp_requirements.txt` containing the following:

```text
mcp
mcp-server-motherduck
duckdb
uvicorn
starlette
```

10. b) Create a `Dockerfile` (no extension) containing the following:
```dockerfile
FROM python:3.11-slim
WORKDIR /app

COPY mcp_requirements.txt .
RUN pip install --no-cache-dir -r mcp_requirements.txt

COPY run_mcp_server.py .
COPY data/synthetic_openflights.db ./data/

EXPOSE 8080

CMD ["python", "run_mcp_server.py"]
```

10. c) Deploy your MCP server as a Cloud Run app by running (from your project root):
```bash
reward-kit deploy-mcp \
  --id mcp-sql-rft-server \
  --dockerfile Dockerfile \
  --port 8080
```
