In [12]:
from google.colab import files

# Upload your three files: corrected_schema.json, schema_triplets.json, business_rules_tripplets.json
uploaded = files.upload()


Saving corrected_schema.json to corrected_schema (1).json
Saving schema_triplets.json to schema_triplets (1).json
Saving business_rules_triplets.json to business_rules_triplets (1).json


In [26]:
import json

with open("corrected_schema (1).json") as f:
    corrected_schema = json.load(f)

with open("schema_triplets.json") as f:
    schema_triplets = json.load(f)

with open("business_rules_triplets.json") as f:
    business_rules_triplets = json.load(f)


In [14]:
# OR, print just the first table and its columns:
first_table = next(iter(corrected_schema))
print(f"Table: {first_table}")
print(json.dumps(corrected_schema[first_table], indent=2))

Table: salesperson
{
  "description": "Contains sales performance data for each salesperson.",
  "businessentityid": {
    "type": "INTEGER",
    "notnull": true,
    "pk": true,
    "sample": [
      274,
      275
    ],
    "fk": {
      "table": "employee",
      "column": "businessentityid"
    },
    "description": "Primary key. Unique identifier for the salesperson."
  },
  "territoryid": {
    "type": "INTEGER",
    "notnull": false,
    "pk": false,
    "sample": [
      "",
      2
    ],
    "fk": {
      "table": "salesterritory",
      "column": "territoryid"
    },
    "description": "Sales territory associated with the salesperson."
  },
  "salesquota": {
    "type": "FLOAT",
    "notnull": false,
    "pk": false,
    "sample": [
      "",
      300000
    ],
    "description": "Sales quota for the salesperson."
  },
  "bonus": {
    "type": "FLOAT",
    "notnull": false,
    "pk": false,
    "sample": [
      0,
      4100
    ],
    "description": "Bonus awarded to the

In [61]:
schema_chunks = []

for table_name, table_info in corrected_schema.items():
    table_desc = table_info.get("description", "")
    column_lines = []

    for column_name, column_info in table_info.items():
        if column_name == "description":
            continue

        col_type = column_info.get("type", "UNKNOWN")
        col_desc = column_info.get("description", "")
        sample_vals = column_info.get("sample", [])
        pk = "PK" if column_info.get("pk", False) else ""
        fk = column_info.get("fk")
        fk_info = f"FK → {fk['table']}.{fk['column']}" if fk else ""

        col_metadata = f"{column_name} ({col_type}"
        if pk: col_metadata += f", {pk}"
        if fk_info: col_metadata += f", {fk_info}"
        col_metadata += f"): {col_desc} [Samples: {sample_vals}]"
        column_lines.append(f"- {col_metadata}")

    chunk_text = (
        f"--- Table: {table_name} ---\n\n"
        f"Table Description: {table_desc}\n\n"
        f"Columns:\n" + "\n".join(column_lines)
    )

    schema_chunks.append({
        "table": table_name,
        "text": chunk_text
    })


In [62]:
print(f"Total schema chunks: {len(schema_chunks)}")

Total schema chunks: 12


In [63]:
# Print the first 5 schema embedding chunks as a preview
for i, chunk in enumerate(schema_chunks[:5], start=1):
    print(f"--- Chunk {i} ---")
    print(chunk)
    print()


--- Chunk 1 ---
{'table': 'salesperson', 'text': "--- Table: salesperson ---\n\nTable Description: Contains sales performance data for each salesperson.\n\nColumns:\n- businessentityid (INTEGER, PK, FK → employee.businessentityid): Primary key. Unique identifier for the salesperson. [Samples: [274, 275]]\n- territoryid (INTEGER, FK → salesterritory.territoryid): Sales territory associated with the salesperson. [Samples: ['', 2]]\n- salesquota (FLOAT): Sales quota for the salesperson. [Samples: ['', 300000]]\n- bonus (FLOAT): Bonus awarded to the salesperson. [Samples: [0, 4100]]\n- commissionpct (FLOAT): Commission percentage for the salesperson. [Samples: [0.0, 0.012]]\n- salesytd (FLOAT): Year-to-date sales for the salesperson. [Samples: [559697.5639, 3763178.1787]]\n- saleslastyear (FLOAT): Sales for the previous year. [Samples: [0.0, 1750406.4785]]\n- rowguid (TEXT): Unique identifier for the row. [Samples: ['48754992-9ee0-4c0e-8c94-9451604e3e02', '1e0a7274-3064-4f58-88ee-4c6586c87

In [144]:
relationship_chunks = []

for table_name, table_info in corrected_schema.items():
    for column_name, column_info in table_info.items():
        if column_name == "description" or not isinstance(column_info, dict):
            continue
        fk = column_info.get("fk")
        if fk:
            rel_text = (
                f"Relationship: Table '{table_name}' joins '{fk['table']}' "
                f"on {table_name}.{column_name} = {fk['table']}.{fk['column']}. "
                f"{column_info.get('description', '')}"
            )
            relationship_chunks.append({
                "tables": [table_name, fk["table"]],
                "columns": [f"{table_name}.{column_name}", f"{fk['table']}.{fk['column']}"],  # Always present
                "text": rel_text
            })


In [145]:
print(f"Total schema chunks: {len(relationship_chunks)}")

Total schema chunks: 24


In [146]:
# Print the first 5 schema embedding chunks as a preview
for i, chunk in enumerate(relationship_chunks[:5], start=1):
    print(f"--- Chunk {i} ---")
    print(chunk)
    print()

--- Chunk 1 ---
{'tables': ['salesperson', 'employee'], 'columns': ['salesperson.businessentityid', 'employee.businessentityid'], 'text': "Relationship: Table 'salesperson' joins 'employee' on salesperson.businessentityid = employee.businessentityid. Primary key. Unique identifier for the salesperson."}

--- Chunk 2 ---
{'tables': ['salesperson', 'salesterritory'], 'columns': ['salesperson.territoryid', 'salesterritory.territoryid'], 'text': "Relationship: Table 'salesperson' joins 'salesterritory' on salesperson.territoryid = salesterritory.territoryid. Sales territory associated with the salesperson."}

--- Chunk 3 ---
{'tables': ['product', 'productsubcategory'], 'columns': ['product.productsubcategoryid', 'productsubcategory.productsubcategoryid'], 'text': "Relationship: Table 'product' joins 'productsubcategory' on product.productsubcategoryid = productsubcategory.productsubcategoryid. Subcategory of the product."}

--- Chunk 4 ---
{'tables': ['product', 'productmodel'], 'columns'

In [19]:
!pip install openai --quiet

In [36]:
# ---------- Configuration ----------
from openai import OpenAI
from google.colab import userdata
OPENAI_API_KEY = userdata.get('OPENAI_KEY')
 #---------- Client Setup ----------
client = OpenAI(api_key=OPENAI_API_KEY)

In [64]:
import openai  # or from openai import OpenAI if you're using OpenAI class
import time  # Optional, for pacing if rate limits are a concern

# Assuming you already have:
# - schema_chunks: list of table-level chunks with keys 'table' and 'text'
# - client: your OpenAI client

schema_embeddings = []

for i, chunk in enumerate(schema_chunks):
    try:
        # Embed the table-level chunk
        response = client.embeddings.create(
            input=chunk["text"],
            model="text-embedding-3-small"
        )
        embedding = response.data[0].embedding

        # Append embedding with metadata
        schema_embeddings.append({
            "chunk_id": i,
            "table": chunk["table"],
            "text": chunk["text"],
            "embedding": embedding,
            "preview": chunk["text"][:120] + "..."  # Optional: for visual debugging
        })

        print(f"✅ Embedded chunk {i+1}/{len(schema_chunks)}: {chunk['table']}")

    except Exception as e:
        print(f"❌ Error embedding chunk {i+1}: {e}")
        time.sleep(1)  # in case of transient failure, wait and continue


✅ Embedded chunk 1/12: salesperson
✅ Embedded chunk 2/12: product
✅ Embedded chunk 3/12: productmodelproductdescriptionculture
✅ Embedded chunk 4/12: productdescription
✅ Embedded chunk 5/12: productreview
✅ Embedded chunk 6/12: productcategory
✅ Embedded chunk 7/12: productsubcategory
✅ Embedded chunk 8/12: salesorderdetail
✅ Embedded chunk 9/12: salesorderheader
✅ Embedded chunk 10/12: salesterritory
✅ Embedded chunk 11/12: countryregioncurrency
✅ Embedded chunk 12/12: currencyrate


In [147]:
relationship_embeddings = []

for i, chunk in enumerate(relationship_chunks):
    try:
        response = client.embeddings.create(
            input=chunk["text"],
            model="text-embedding-3-small"
        )
        embedding = response.data[0].embedding
        relationship_embeddings.append({
            "chunk_id": i,
            "tables": chunk["tables"],
            "columns": chunk["columns"],
            "text": chunk["text"],
            "embedding": embedding,
            "preview": chunk["text"][:120] + "..."
        })
        print(f"✅ Embedded relationship chunk {i+1}/{len(relationship_chunks)}: {chunk['tables']}")
    except Exception as e:
        print(f"❌ Error embedding relationship chunk {i+1}: {e}")
        time.sleep(1)


✅ Embedded relationship chunk 1/24: ['salesperson', 'employee']
✅ Embedded relationship chunk 2/24: ['salesperson', 'salesterritory']
✅ Embedded relationship chunk 3/24: ['product', 'productsubcategory']
✅ Embedded relationship chunk 4/24: ['product', 'productmodel']
✅ Embedded relationship chunk 5/24: ['productmodelproductdescriptionculture', 'productmodel']
✅ Embedded relationship chunk 6/24: ['productmodelproductdescriptionculture', 'productdescription']
✅ Embedded relationship chunk 7/24: ['productreview', 'product']
✅ Embedded relationship chunk 8/24: ['productsubcategory', 'productcategory']
✅ Embedded relationship chunk 9/24: ['salesorderdetail', 'salesorderheader']
✅ Embedded relationship chunk 10/24: ['salesorderdetail', 'product']
✅ Embedded relationship chunk 11/24: ['salesorderdetail', 'specialoffer']
✅ Embedded relationship chunk 12/24: ['salesorderheader', 'customer']
✅ Embedded relationship chunk 13/24: ['salesorderheader', 'salesperson']
✅ Embedded relationship chunk 14

In [148]:
!pip install rdflib



In [101]:
from rdflib import Graph, Namespace, URIRef, Literal
import json
import re

# Namespaces for schema and business logic
SCHEMA = Namespace("http://example.org/schema/")
BUSINESS = Namespace("http://example.org/business/")

# Initialize RDF Graph
g = Graph()

def safe_uri(name):
    """Ensure safe URI naming for RDF (alphanumeric, underscores, dots)."""
    if isinstance(name, str):
        return re.sub(r'[^A-Za-z0-9_\.]', '_', name)
    return str(name)

# --- 1. Add schema triplets ---
for triple in schema_triplets:
    subj, pred, obj = triple

    subj_uri = SCHEMA[safe_uri(subj)]
    pred_uri = SCHEMA[safe_uri(pred)]

    # FK/PK object as URI, else Literal
    if isinstance(obj, dict) and 'table' in obj and 'column' in obj:
        obj_val = SCHEMA[f"{safe_uri(obj['table'])}.{safe_uri(obj['column'])}"]
    elif isinstance(obj, str) and re.match(r"^[A-Za-z0-9_]+\.[A-Za-z0-9_]+$", obj):
        obj_val = SCHEMA[safe_uri(obj)]
    else:
        obj_val = Literal(obj)

    g.add((subj_uri, pred_uri, obj_val))

# --- 2. Add business rules triplets ---
for br in business_rules_triplets:
    subj_uri = BUSINESS[safe_uri(br["subject"])]
    pred_uri = BUSINESS[safe_uri(br["predicate"])]
    obj_val = Literal(br["object"])  # Always a string literal
    g.add((subj_uri, pred_uri, obj_val))

# --- 3. Save as Turtle format ---
g.serialize("schema_and_business_rules.ttl", format="turtle")

# --- 4. (Optional) Print sample triples for debug ---
print("--- First 5 schema triples ---")
for t in list(g)[:5]:
    print(t)


--- First 5 schema triples ---
(rdflib.term.URIRef('http://example.org/schema/salesterritory.territoryid'), rdflib.term.URIRef('http://example.org/schema/description'), rdflib.term.Literal('Primary key. Unique identifier for the sales territory.'))
(rdflib.term.URIRef('http://example.org/schema/product.productline'), rdflib.term.URIRef('http://example.org/schema/description'), rdflib.term.Literal('Product line code.'))
(rdflib.term.URIRef('http://example.org/schema/salesperson.businessentityid'), rdflib.term.URIRef('http://example.org/schema/description'), rdflib.term.Literal('Primary key. Unique identifier for the salesperson.'))
(rdflib.term.URIRef('http://example.org/schema/productsubcategory.productcategoryid'), rdflib.term.URIRef('http://example.org/schema/description'), rdflib.term.Literal('Product category identifier.'))
(rdflib.term.URIRef('http://example.org/schema/salesorderheader.currencyrateid'), rdflib.term.URIRef('http://example.org/schema/description'), rdflib.term.Liter

In [102]:
from rdflib import Namespace

SCHEMA = Namespace("http://example.org/schema/")

for s, p, o in g.triples((None, SCHEMA['foreign_key'], None)):
    print(f"{s} is FK to {o}")

http://example.org/schema/salesperson.businessentityid is FK to http://example.org/schema/person.businessentityid
http://example.org/schema/salesperson.territoryid is FK to http://example.org/schema/salesterritory.territoryid
http://example.org/schema/salesorderheader.territoryid is FK to http://example.org/schema/salesterritory.territoryid
http://example.org/schema/product.sizeunitmeasurecode is FK to http://example.org/schema/unitmeasure.unitmeasurecode
http://example.org/schema/product.weightunitmeasurecode is FK to http://example.org/schema/unitmeasure.unitmeasurecode
http://example.org/schema/product.productsubcategoryid is FK to http://example.org/schema/productsubcategory.productsubcategoryid
http://example.org/schema/product.productmodelid is FK to http://example.org/schema/productmodel.productmodelid
http://example.org/schema/productmodelproductdescriptionculture.productmodelid is FK to http://example.org/schema/productmodel.productmodelid
http://example.org/schema/productmode

In [66]:
from rdflib.plugins.sparql import prepareQuery

q = prepareQuery("""
  SELECT ?s ?p ?o WHERE {
    ?s ?p ?o .
  }
""")

for row in g.query(q):
    print(row)


(rdflib.term.URIRef('http://example.org/schema/salesperson.businessentityid'), rdflib.term.URIRef('http://example.org/schema/column_type'), rdflib.term.URIRef('http://example.org/schema/INTEGER'))
(rdflib.term.URIRef('http://example.org/schema/salesorderheader.modifieddate'), rdflib.term.URIRef('http://example.org/schema/column_type'), rdflib.term.URIRef('http://example.org/schema/DATETIME'))
(rdflib.term.URIRef('http://example.org/schema/productmodelproductdescriptionculture.modifieddate'), rdflib.term.URIRef('http://example.org/schema/column_type'), rdflib.term.URIRef('http://example.org/schema/DATETIME'))
(rdflib.term.URIRef('http://example.org/schema/product.rowguid'), rdflib.term.URIRef('http://example.org/schema/column_type'), rdflib.term.URIRef('http://example.org/schema/TEXT'))
(rdflib.term.URIRef('http://example.org/schema/productreview.reviewername'), rdflib.term.URIRef('http://example.org/schema/description'), rdflib.term.URIRef('http://example.org/schema/Name_of_the_reviewe

In [32]:
# Golden NLQs and their SQL
nlqs = [
    "What are the top 3 products by average rating in each territory?",
    "Which salespersons exceeded their quota and contributed to >5% YoY growth?",
    "How many orders had FX-adjusted revenue above $5000 in 2013?",
    "What is the product bundle most frequently bought together in each category?",
    "Which customers submitted reviews within 30 days of purchase?"
]

gold_sqls = [
    """SELECT
    st.name AS territory,
    p.productid,
    p.NAME,
    AVG(pr.rating) AS avg_rating
FROM
    salesterritory st
JOIN salesperson sp ON sp.territoryid = st.territoryid
JOIN salesorderheader soh ON soh.salespersonid = sp.businessentityid
JOIN salesorderdetail sod ON sod.salesorderid = soh.salesorderid
JOIN product p ON sod.productid = p.productid
JOIN productreview pr ON pr.productid = p.productid
GROUP BY
    st.name, p.productid, p.NAME
ORDER BY
    st.name, avg_rating DESC
LIMIT 3;
""",
    """SELECT
    sp.businessentityid,
    sp.territoryid,
    sp.salesquota,
    sp.salesytd,
    sp.saleslastyear,
    ((sp.salesytd - sp.saleslastyear) / NULLIF(sp.saleslastyear, 0)) * 100 AS yoy_growth_pct
FROM
    salesperson sp
WHERE
    sp.salesytd > sp.salesquota
    AND ((sp.salesytd - sp.saleslastyear) / NULLIF(sp.saleslastyear, 0)) * 100 > 5;
""",
    """SELECT
    COUNT(*) AS num_orders
FROM
    salesorderheader soh
LEFT JOIN currencyrate cr ON soh.currencyrateid = cr.currencyrateid
WHERE
    strftime('%Y', soh.orderdate) = '2013'
    AND (soh.totaldue * COALESCE(cr.averagerate, 1.0)) > 5000;
""",
    """WITH product_pairs AS (
    SELECT
        sod1.salesorderid,
        LEAST(sod1.productid, sod2.productid) AS productid1,
        GREATEST(sod1.productid, sod2.productid) AS productid2
    FROM
        salesorderdetail sod1
    JOIN salesorderdetail sod2
        ON sod1.salesorderid = sod2.salesorderid
        AND sod1.productid < sod2.productid
)
SELECT
    pc.name AS category,
    p1.NAME AS product1,
    p2.NAME AS product2,
    COUNT(*) AS bundle_count
FROM
    product_pairs pp
JOIN product p1 ON pp.productid1 = p1.productid
JOIN product p2 ON pp.productid2 = p2.productid
JOIN productsubcategory ps1 ON p1.productsubcategoryid = ps1.productsubcategoryid
JOIN productcategory pc ON ps1.productcategoryid = pc.productcategoryid
GROUP BY
    pc.name, p1.NAME, p2.NAME
ORDER BY
    pc.name, bundle_count DESC
""",
    """SELECT DISTINCT
    soh.customerid,
    pr.productid,
    pr.reviewername,
    soh.orderdate,
    pr.reviewdate
FROM
    salesorderheader soh
JOIN salesorderdetail sod ON soh.salesorderid = sod.salesorderid
JOIN productreview pr ON sod.productid = pr.productid
WHERE
    julianday(pr.reviewdate) - julianday(soh.orderdate) BETWEEN 0 AND 30
""",
]

In [149]:
import numpy as np
import heapq
from rdflib import URIRef  # Add this import!
from openai import OpenAI
from rdflib import Namespace

# Cosine similarity
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

from rdflib import URIRef

def get_fk_pk_neighbors(table, graph, SCHEMA=Namespace("http://example.org/schema/")):
    uri = SCHEMA[table]
    neighbors = set()

    # Outgoing FK
    for s, p, o in graph.triples((None, SCHEMA["foreign_key"], None)):
        if str(s).startswith(str(uri) + "."):
            if isinstance(o, URIRef) and str(o).startswith(str(SCHEMA)):
                obj_str = str(o).replace(str(SCHEMA), "")
                if '.' in obj_str:
                    neighbor = obj_str.split('.')[0]
                    if neighbor != table:
                        neighbors.add(neighbor)

    # Incoming FK
    for s, p, o in graph.triples((None, SCHEMA["foreign_key"], None)):
        if isinstance(o, URIRef) and str(o).startswith(str(uri) + "."):
            if isinstance(s, URIRef) and str(s).startswith(str(SCHEMA)):
                subj_str = str(s).replace(str(SCHEMA), "")
                if '.' in subj_str:
                    neighbor = subj_str.split('.')[0]
                    if neighbor != table:
                        neighbors.add(neighbor)
    return neighbors

def semantic_search(query, top_k=5, expand_hops=1, graph=None, real_tables=None):
    if client is None:
        raise ValueError("OpenAI client not initialized")
    if schema_embeddings is None:
        raise ValueError("schema_embeddings must be provided")

    response = client.embeddings.create(input=query, model="text-embedding-3-small")
    query_embedding = np.array(response.data[0].embedding)
    similarities = []
    all_embeddings = schema_embeddings + relationship_embeddings
    for item in all_embeddings:
        sim = cosine_similarity(query_embedding, np.array(item["embedding"]))
        similarities.append((sim, item["table"]))
    top_k_hits = heapq.nlargest(top_k, similarities, key=lambda x: x[0])
    tables = set([table for _, table in top_k_hits])

    if expand_hops > 0 and graph is not None:
        expanded = set(tables)
        for table in list(tables):
            neighbors = get_fk_pk_neighbors(table, graph)
            expanded.update(neighbors)
        tables = expanded

    # Only keep real tables if provided
    if real_tables is not None:
        tables = set(t for t in tables if t in real_tables)

    return [(table, None) for table in sorted(tables)]



In [150]:
top_tables = [t for t, _ in semantic_search("Show FX-adjusted revenue", top_k=5)]
print(top_tables)

KeyError: 'table'

In [124]:
expanded_tables = set(top_tables)
for t in top_tables:
    neighbors = get_fk_pk_neighbors(t, g)
    print(f"1-hop neighbors of {t}: {neighbors}")
    expanded_tables.update(neighbors)
print("Expanded tables (semantic + 1-hop):", sorted(expanded_tables))

1-hop neighbors of countryregioncurrency: {'currency', 'countryregion'}
1-hop neighbors of currencyrate: {'currency', 'salesorderheader'}
1-hop neighbors of salesorderheader: {'shipmethod', 'creditcard', 'salesperson', 'currencyrate', 'salesterritory', 'salesorderdetail', 'customer', 'address'}
1-hop neighbors of salesperson: {'person', 'salesterritory', 'salesorderheader'}
1-hop neighbors of salesterritory: {'salesperson', 'salesorderheader', 'countryregion'}
Expanded tables (semantic + 1-hop): ['address', 'countryregion', 'countryregioncurrency', 'creditcard', 'currency', 'currencyrate', 'customer', 'person', 'salesorderdetail', 'salesorderheader', 'salesperson', 'salesterritory', 'shipmethod']


In [129]:
import networkx as nx

def build_fk_graph(corrected_schema):
    """
    Build an undirected graph of table-to-table edges for join path search.
    """
    G = nx.Graph()
    for table, meta in corrected_schema.items():
        G.add_node(table)
        for col, col_meta in meta.items():
            if col == "description": continue
            fk = col_meta.get("fk")
            if fk:
                target_table = fk["table"]
                G.add_edge(table, target_table)
    return G

def get_minimal_table_set(initial_tables, fk_graph):
    """
    Given initial tables, return all tables on the minimal join path(s) between them.
    """
    # Connect every pair of initial tables, union all shortest paths
    minimal_tables = set(initial_tables)
    for i in range(len(initial_tables)):
        for j in range(i + 1, len(initial_tables)):
            t1, t2 = initial_tables[i], initial_tables[j]
            if nx.has_path(fk_graph, t1, t2):
                path = nx.shortest_path(fk_graph, t1, t2)
                minimal_tables.update(path)
    return list(minimal_tables)


In [136]:
def format_join_path(join_path):
    # If join_path contains tuples of (from_table, from_column, to_table, to_column), format accordingly
    if isinstance(join_path, list) and all(isinstance(x, tuple) and len(x) == 4 for x in join_path):
        return " -> ".join([f"{a}.{b} → {c}.{d}" for a, b, c, d in join_path])
    # If join_path is just a list of columns or tables, join them with arrows
    return " -> ".join(join_path)

def build_prompt(nlq, final_tables, schema, join_paths=None):
    lines = [
        "You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.",
        ""
    ]
    if join_paths:
        lines.append("# Join path(s):")
        for jp in join_paths:
            lines.append("  " + format_join_path(jp))
        lines.append("")  # for spacing

    lines.append("### Schema Context:\n")
    for table in final_tables:
        t_meta = schema.get(table, {})
        desc = t_meta.get("description", "")
        lines.append(f"## Table: {table}")
        if desc:
            lines.append(f"Table Description: {desc}")
        col_lines = []
        for col, meta in t_meta.items():
            if col == "description":
                continue
            coltype = meta.get("type", "UNKNOWN")
            coldesc = meta.get("description", "")
            pk = "PK" if meta.get("pk") else ""
            fk = meta.get("fk")
            fk_info = f"FK to {fk['table']}.{fk['column']}" if fk else ""
            col_str = f"- {col} ({coltype} {pk} {fk_info})"
            if coldesc:
                col_str += f": {coldesc}"
            col_lines.append(col_str)
        if col_lines:
            lines += col_lines
        lines.append("")  # blank line between tables

    lines.append(f"### User Question:\n{nlq}\n")
    lines.append("Return only the SQL.")

    return "\n".join(lines)


In [137]:
import pandas as pd

results = []

for i, nlq in enumerate(nlqs):
    # 1. Find initial top tables via semantic search (vector similarity)
    initial_tables = [t for t, _ in semantic_search(nlq, top_k=5)]

    # 2. Expand tables with 1-hop neighbors from KG
    expanded_tables = set(initial_tables)
    for t in initial_tables:
        neighbors = get_fk_pk_neighbors(t, g)
        expanded_tables.update(neighbors)

    # 3. Only keep real tables (those present in your schema)
    final_tables = [t for t in expanded_tables if t in corrected_schema]

    # 4. Build the prompt using ONLY the final_tables
    prompt = build_prompt(nlq, final_tables, corrected_schema, g)

    # 5. (Optional) Generate SQL with your LLM (not included here)
    generated_sql = "[PLACEHOLDER: LLM SQL OUTPUT HERE]"

    # 6. Collect for report/eval
    results.append({
        "NLQ": nlq,
        "Initial Tables": initial_tables,
        "Expanded Tables": sorted(list(expanded_tables)),
        "Final Tables": final_tables,
        "Prompt": prompt,
        "Golden_SQL": gold_sqls[i],
        "Generated_SQL": generated_sql
    })

# Display as DataFrame
pd.set_option('display.max_colwidth', 200)
df_results = pd.DataFrame(results)
display(df_results[["NLQ", "Initial Tables", "Expanded Tables", "Final Tables", "Prompt", "Golden_SQL", "Generated_SQL"]])

# Optionally: Save to CSV
# df_results.to_csv("llm_sql_eval_results_full.csv", index=False)


Unnamed: 0,NLQ,Initial Tables,Expanded Tables,Final Tables,Prompt,Golden_SQL,Generated_SQL
0,What are the top 3 products by average rating in each territory?,"[product, productmodelproductdescriptionculture, productreview, salesperson, salesterritory]","[countryregion, culture, person, product, productdescription, productmodel, productmodelproductdescriptionculture, productreview, productsubcategory, salesorderdetail, salesorderheader, salesperso...","[productreview, productmodelproductdescriptionculture, productsubcategory, salesperson, salesterritory, product, salesorderdetail, productdescription, salesorderheader]",You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/s...,"SELECT\n st.name AS territory,\n p.productid,\n p.NAME,\n AVG(pr.rating) AS avg_rating\nFROM\n salesterritory st\nJOIN salesperson sp ON sp.territoryid = st.territoryid\nJOIN saleso...",[PLACEHOLDER: LLM SQL OUTPUT HERE]
1,Which salespersons exceeded their quota and contributed to >5% YoY growth?,"[product, salesorderdetail, salesorderheader, salesperson, salesterritory]","[address, countryregion, creditcard, currencyrate, customer, person, product, productmodel, productreview, productsubcategory, salesorderdetail, salesorderheader, salesperson, salesterritory, ship...","[productreview, productsubcategory, salesperson, salesorderdetail, currencyrate, salesterritory, product, salesorderheader]",You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/s...,"SELECT\n sp.businessentityid,\n sp.territoryid,\n sp.salesquota,\n sp.salesytd,\n sp.saleslastyear,\n ((sp.salesytd - sp.saleslastyear) / NULLIF(sp.saleslastyear, 0)) * 100 AS yo...",[PLACEHOLDER: LLM SQL OUTPUT HERE]
2,How many orders had FX-adjusted revenue above $5000 in 2013?,"[currencyrate, salesorderdetail, salesorderheader, salesperson, salesterritory]","[address, countryregion, creditcard, currency, currencyrate, customer, person, product, salesorderdetail, salesorderheader, salesperson, salesterritory, shipmethod, specialoffer]","[salesperson, currencyrate, salesterritory, product, salesorderdetail, salesorderheader]",You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/s...,"SELECT\n COUNT(*) AS num_orders\nFROM\n salesorderheader soh\nLEFT JOIN currencyrate cr ON soh.currencyrateid = cr.currencyrateid\nWHERE\n strftime('%Y', soh.orderdate) = '2013'\n AND ...",[PLACEHOLDER: LLM SQL OUTPUT HERE]
3,What is the product bundle most frequently bought together in each category?,"[product, productcategory, productmodelproductdescriptionculture, productreview, productsubcategory]","[culture, product, productcategory, productdescription, productmodel, productmodelproductdescriptionculture, productreview, productsubcategory, salesorderdetail, unitmeasure]","[productreview, productcategory, productsubcategory, productmodelproductdescriptionculture, product, salesorderdetail, productdescription]",You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/s...,"WITH product_pairs AS (\n SELECT\n sod1.salesorderid,\n LEAST(sod1.productid, sod2.productid) AS productid1,\n GREATEST(sod1.productid, sod2.productid) AS productid2\n F...",[PLACEHOLDER: LLM SQL OUTPUT HERE]
4,Which customers submitted reviews within 30 days of purchase?,"[product, productdescription, productreview, salesorderdetail, salesorderheader]","[address, creditcard, currencyrate, customer, product, productdescription, productmodel, productmodelproductdescriptionculture, productreview, productsubcategory, salesorderdetail, salesorderheade...","[productreview, productsubcategory, productmodelproductdescriptionculture, salesperson, currencyrate, product, salesterritory, salesorderdetail, productdescription, salesorderheader]",You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/s...,"SELECT DISTINCT\n soh.customerid,\n pr.productid,\n pr.reviewername,\n soh.orderdate,\n pr.reviewdate\nFROM\n salesorderheader soh\nJOIN salesorderdetail sod ON soh.salesorderid ...",[PLACEHOLDER: LLM SQL OUTPUT HERE]


In [138]:
#let's generate the SQL
def get_sql_from_llm(prompt, model="gpt-4o"):
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": "You are a world-class SQL data analyst. Write only the SQL query, no explanation or comments."
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        temperature=0.0
    )
    return response.choices[0].message.content.strip()

In [139]:
for i, row in df_results.iterrows():
    prompt = row["Prompt"]
    try:
        llm_sql = get_sql_from_llm(prompt)
        df_results.at[i, "Generated_SQL"] = llm_sql
        print(f"✅ Generated SQL for NLQ {i + 1}")
    except Exception as e:
        df_results.at[i, "Generated_SQL"] = f"[ERROR: {str(e)}]"
        print(f"❌ Error for NLQ {i + 1}: {e}")
pd.set_option("display.max_colwidth", 400)
display(df_results[["NLQ", "Prompt", "Golden_SQL", "Generated_SQL"]])

# Optional: Save to file
#df_results.to_csv("llm_sql_eval_results.csv", index=False)

✅ Generated SQL for NLQ 1
✅ Generated SQL for NLQ 2
✅ Generated SQL for NLQ 3
✅ Generated SQL for NLQ 4
✅ Generated SQL for NLQ 5


Unnamed: 0,NLQ,Prompt,Golden_SQL,Generated_SQL
0,What are the top 3 products by average rating in each territory?,You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/schema/description -> Primary key. Unique identifier for the sales territory.\n http://example.org/schema/product.productline -> http://example.org/schema/description -> Product line code.\n http://e...,"SELECT\n st.name AS territory,\n p.productid,\n p.NAME,\n AVG(pr.rating) AS avg_rating\nFROM\n salesterritory st\nJOIN salesperson sp ON sp.territoryid = st.territoryid\nJOIN salesorderheader soh ON soh.salespersonid = sp.businessentityid\nJOIN salesorderdetail sod ON sod.salesorderid = soh.salesorderid\nJOIN product p ON sod.productid = p.productid\nJOIN productreview pr ON pr....","```sql\nSELECT \n st.territoryid,\n p.productid,\n p.name AS product_name,\n AVG(pr.rating) AS average_rating\nFROM \n productreview pr\nJOIN \n product p ON pr.productid = p.productid\nJOIN \n salesorderdetail sod ON p.productid = sod.productid\nJOIN \n salesorderheader soh ON sod.salesorderid = soh.salesorderid\nJOIN \n salesterritory st ON soh.territoryid = st.ter..."
1,Which salespersons exceeded their quota and contributed to >5% YoY growth?,You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/schema/description -> Primary key. Unique identifier for the sales territory.\n http://example.org/schema/product.productline -> http://example.org/schema/description -> Product line code.\n http://e...,"SELECT\n sp.businessentityid,\n sp.territoryid,\n sp.salesquota,\n sp.salesytd,\n sp.saleslastyear,\n ((sp.salesytd - sp.saleslastyear) / NULLIF(sp.saleslastyear, 0)) * 100 AS yoy_growth_pct\nFROM\n salesperson sp\nWHERE\n sp.salesytd > sp.salesquota\n AND ((sp.salesytd - sp.saleslastyear) / NULLIF(sp.saleslastyear, 0)) * 100 > 5;\n",```sql\nSELECT s.businessentityid\nFROM salesperson s\nWHERE s.salesytd > s.salesquota\nAND ((s.salesytd - s.saleslastyear) / s.saleslastyear) * 100 > 5;\n```
2,How many orders had FX-adjusted revenue above $5000 in 2013?,You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/schema/description -> Primary key. Unique identifier for the sales territory.\n http://example.org/schema/product.productline -> http://example.org/schema/description -> Product line code.\n http://e...,"SELECT\n COUNT(*) AS num_orders\nFROM\n salesorderheader soh\nLEFT JOIN currencyrate cr ON soh.currencyrateid = cr.currencyrateid\nWHERE\n strftime('%Y', soh.orderdate) = '2013'\n AND (soh.totaldue * COALESCE(cr.averagerate, 1.0)) > 5000;\n",```sql\nSELECT COUNT(DISTINCT soh.salesorderid) AS order_count\nFROM salesorderheader soh\nJOIN currencyrate cr ON soh.currencyrateid = cr.currencyrateid\nWHERE YEAR(soh.orderdate) = 2013\nAND (soh.totaldue * cr.endofdayrate) > 5000;\n```
3,What is the product bundle most frequently bought together in each category?,You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/schema/description -> Primary key. Unique identifier for the sales territory.\n http://example.org/schema/product.productline -> http://example.org/schema/description -> Product line code.\n http://e...,"WITH product_pairs AS (\n SELECT\n sod1.salesorderid,\n LEAST(sod1.productid, sod2.productid) AS productid1,\n GREATEST(sod1.productid, sod2.productid) AS productid2\n FROM\n salesorderdetail sod1\n JOIN salesorderdetail sod2\n ON sod1.salesorderid = sod2.salesorderid\n AND sod1.productid < sod2.productid\n)\nSELECT\n pc.name AS category,\n...","```sql\nSELECT \n pc.name AS category_name,\n p1.productid AS product1_id,\n p2.productid AS product2_id,\n COUNT(*) AS bundle_count\nFROM \n salesorderdetail sod1\nJOIN \n salesorderdetail sod2 ON sod1.salesorderid = sod2.salesorderid AND sod1.productid < sod2.productid\nJOIN \n product p1 ON sod1.productid = p1.productid\nJOIN \n product p2 ON sod2.productid = p2.prod..."
4,Which customers submitted reviews within 30 days of purchase?,You are a SQL expert. Use only the tables and columns listed below. Do not invent tables/columns.\n\n# Join path(s):\n http://example.org/schema/salesterritory.territoryid -> http://example.org/schema/description -> Primary key. Unique identifier for the sales territory.\n http://example.org/schema/product.productline -> http://example.org/schema/description -> Product line code.\n http://e...,"SELECT DISTINCT\n soh.customerid,\n pr.productid,\n pr.reviewername,\n soh.orderdate,\n pr.reviewdate\nFROM\n salesorderheader soh\nJOIN salesorderdetail sod ON soh.salesorderid = sod.salesorderid\nJOIN productreview pr ON sod.productid = pr.productid\nWHERE\n julianday(pr.reviewdate) - julianday(soh.orderdate) BETWEEN 0 AND 30\n","```sql\nSELECT DISTINCT soh.customerid\nFROM salesorderheader soh\nJOIN salesorderdetail sod ON soh.salesorderid = sod.salesorderid\nJOIN productreview pr ON sod.productid = pr.productid\nWHERE DATEDIFF(day, soh.orderdate, pr.reviewdate) <= 30;\n```"
