In [None]:
# Install OpenAI and NetworkX if not already present
!pip install openai networkx --quiet

In [None]:
import sqlite3
import openai
import networkx as nx
import pandas as pd

In [None]:
# ---------- Configuration ----------
# The OpenAI API key is read from the Colab secret named 'OPENAI_KEY' instead of
# being hard-coded in the notebook.
from openai import OpenAI
from google.colab import userdata
OPENAI_API_KEY = userdata.get('OPENAI_KEY')
# ---------- Client Setup ----------
# Single client instance that is handed to the LLM enrichment helpers.
client = OpenAI(api_key=OPENAI_API_KEY)

In [None]:
def extract_schema_info(db_path, sample_rows=3):
    """Read table, column, foreign-key and sample-row metadata from a SQLite file.

    Args:
        db_path: Path of the SQLite database to inspect.
        sample_rows: Number of rows fetched per table to derive sample values.

    Returns:
        dict mapping each user table name to a dict with:
          - "columns": {column name: {"type", "sample", "pk", "notnull"}} where
            "sample" is the first non-NULL value among the fetched rows (or None),
          - "foreign_keys": list of {"from", "to_table", "to_col"} dicts,
          - "sample_row": the first fetched row as {column name: value}, or {}.
    """
    def quote_ident(name):
        # Double-quote the identifier (doubling embedded quotes) so table names
        # containing quotes or spaces cannot break the statement.
        return '"' + name.replace('"', '""') + '"'

    schema = {}
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        tables = [
            row[0]
            for row in cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
        ]

        for table in tables:
            ident = quote_ident(table)
            columns = cursor.execute(f"PRAGMA table_info({ident})").fetchall()
            fkeys = cursor.execute(f"PRAGMA foreign_key_list({ident})").fetchall()
            # LIMIT is bound as a parameter instead of being formatted into the SQL.
            samples = cursor.execute(
                f"SELECT * FROM {ident} LIMIT ?", (int(sample_rows),)
            ).fetchall()
            col_names = [c[1] for c in columns]

            col_meta = {}
            for idx, col in enumerate(columns):
                # table_info rows are (cid, name, type, notnull, dflt_value, pk).
                col_name, col_type = col[1], col[2]
                values = [row[idx] for row in samples if row[idx] is not None]
                col_meta[col_name] = {
                    "type": col_type,
                    "sample": values[0] if values else None,
                    "pk": bool(col[5]),
                    "notnull": bool(col[3]),
                }

            schema[table] = {
                "columns": col_meta,
                # foreign_key_list rows are (id, seq, table, from, to, ...).
                "foreign_keys": [
                    {"from": fk[3], "to_table": fk[2], "to_col": fk[4]}
                    for fk in fkeys
                ],
                "sample_row": dict(zip(col_names, samples[0])) if samples else {},
            }
    finally:
        # Always release the connection, even when a query above raises.
        conn.close()
    return schema

# Usage: upload a SQLite file through the Colab file picker and extract its schema.
from google.colab import files
uploaded = files.upload()  # Upload your baseball.sqlite
if not uploaded:
    # Without this guard an empty upload surfaces as a bare StopIteration from
    # next(iter(...)) below, which gives no hint about what went wrong.
    raise ValueError("No file was uploaded; upload a SQLite database to continue.")
db_path = next(iter(uploaded))  # Only the first uploaded file is used
schema = extract_schema_info(db_path)
table_list = list(schema.keys())

Saving baseball_1.sqlite to baseball_1.sqlite


In [None]:
import json

def enrich_table_metadata_with_llm(table, meta, client, model="gpt-4o"):
    """
    Uses OpenAI LLM to enrich table/column metadata: description, synonyms, tags, aliases.

    Args:
        table: Table name, used in the prompt and in log messages.
        meta: Metadata dict for the table; must contain 'columns' and may contain
            'sample_row'. It is updated in place.
        client: Object exposing `chat.completions.create(...)`.
        model: Chat model name passed to the client.

    Returns:
        The same `meta` dict. On success it gains 'description', 'synonyms',
        'tags' and 'alternatives', and every column gains 'description',
        'synonyms' and 'tags'. If the reply cannot be used, the raw reply is
        stored under 'meta_raw' instead.
    """
    def to_list(val):
        if isinstance(val, list): return val
        if isinstance(val, str): return [x.strip() for x in val.split(',') if x.strip()]
        return []

    def clean_text(val):
        # The model may return null or a non-string where text is expected.
        return val.strip() if isinstance(val, str) else ''

    def as_dict(val):
        return val if isinstance(val, dict) else {}

    columns = meta['columns']
    col_block = [
        f"{col} ({cmeta.get('type', 'TEXT')}, sample: {cmeta.get('sample', '')})"
        for col, cmeta in columns.items()
    ]
    col_block_str = "\n".join(col_block)
    sample_row = meta.get('sample_row', {})

    prompt = f"""
You are an expert in database schema documentation and baseball analytics.
Table: {table}
Columns:
{col_block_str}
Sample row: {sample_row}

Return valid, minified JSON with:
- "table": description, synonyms, tags, aliases (alternatives)
- "columns": for each, description, synonyms, tags.
No markdown, only JSON.
"""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=900
    )
    content = response.choices[0].message.content
    try:
        meta_llm = json.loads((content or '').strip())
        if not isinstance(meta_llm, dict):
            raise ValueError("expected a JSON object at the top level")
        tmeta = as_dict(meta_llm.get('table'))
        # Keep an existing (e.g. backfilled) description when the model gives none.
        meta['description'] = clean_text(tmeta.get('description')) or meta.get('description', '')
        meta['synonyms'] = to_list(tmeta.get('synonyms', []))
        meta['tags'] = to_list(tmeta.get('tags', []))
        # The prompt words this field as "aliases (alternatives)", so accept either key.
        meta['alternatives'] = to_list(tmeta.get('alternatives') or tmeta.get('aliases') or [])
        col_enrichment = as_dict(meta_llm.get('columns'))
        for col, cmeta in columns.items():
            enrich = as_dict(col_enrichment.get(col))
            cmeta['description'] = clean_text(enrich.get('description')) or cmeta.get('description', '')
            cmeta['synonyms'] = to_list(enrich.get('synonyms', []))
            cmeta['tags'] = to_list(enrich.get('tags', []))
    except (ValueError, TypeError, AttributeError) as e:
        # json.JSONDecodeError is a ValueError; keep the raw reply for inspection.
        print(f"[{table}] Meta parse failed: {e}")
        meta["meta_raw"] = content
    return meta


NameError: name 'complex_samples' is not defined

In [None]:
import yaml

def export_schema_semantics_yaml(meta_dict, filename="semantic_layer.yaml"):
    """
    Serializes the entire enriched schema (with all facets) as YAML and JSON.

    Args:
        meta_dict: Mapping of table name to enriched metadata.
        filename: Path of the YAML file. The JSON copy is written next to it with
            the extension replaced by '.json'.
    """
    import os

    # Derive the JSON path from the extension rather than with str.replace():
    # for a name without '.yaml' (e.g. 'layer.yml') replace() returns the same
    # path and the JSON dump would overwrite the YAML file just written.
    stem, _ext = os.path.splitext(filename)
    json_filename = stem + ".json"
    if json_filename == filename:
        json_filename = filename + ".json"

    with open(filename, "w", encoding="utf-8") as f:
        yaml.dump(meta_dict, f, sort_keys=False, allow_unicode=True)
    with open(json_filename, "w", encoding="utf-8") as f:
        json.dump(meta_dict, f, indent=2, ensure_ascii=False)
    print(f"Exported to {filename} and {json_filename}")


In [None]:
import json, yaml
import re
import time

# --- Safe LLM JSON response with code-fence/array handling and retry ---
def safe_llm_json_response(response, table, facet, max_attempts=2):
    """
    Defensively parse the JSON payload of a chat-completion response.

    Markdown code fences are stripped first. The cleaned text is parsed as-is;
    when that fails and `max_attempts` is greater than 1, a second attempt
    parses the span between the first opening and the last closing
    bracket/brace, which recovers JSON wrapped in explanatory prose.
    Re-parsing identical text cannot succeed, so no sleep/retry is done.

    Args:
        response: Object exposing `choices[0].message.content`.
        table: Table name, used only in log messages.
        facet: Facet label, used only in log messages.
        max_attempts: 1 parses the cleaned text only; 2 or more also tries the
            salvage step described above.

    Returns:
        Always a tuple: `(parsed, None)` on success (parsed is a dict or list),
        or `(None, cleaned_text)` on failure.
    """
    content = (response.choices[0].message.content or "").strip()
    # Remove code fencing if present
    content = re.sub(r'^```[a-zA-Z]*\s*', '', content)
    content = re.sub(r'```$', '', content).strip()

    candidates = [content]
    if max_attempts > 1:
        starts = [pos for pos in (content.find('{'), content.find('[')) if pos != -1]
        end = max(content.rfind('}'), content.rfind(']'))
        if starts and end > min(starts):
            salvaged = content[min(starts):end + 1]
            if salvaged != content:
                candidates.append(salvaged)

    failure = "Not valid JSON."
    for candidate in candidates:
        # Only objects and arrays are accepted, never bare scalars.
        if not candidate.startswith(('{', '[')):
            continue
        try:
            return json.loads(candidate), None
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            failure = f"JSON decode error: {e}"

    print(f"[{table}] {facet} parse failed: {failure}\nRaw content:\n{content}")
    return None, content

# --- Utility to ensure lists ---
def to_list(val):
    """Coerce a value into a list.

    Lists are returned unchanged, strings are split on commas into trimmed,
    non-empty items, and anything else becomes an empty list.
    """
    if isinstance(val, list):
        return val
    if not isinstance(val, str):
        return []
    trimmed = (piece.strip() for piece in val.split(','))
    return [piece for piece in trimmed if piece]

# --- Backfill blank descriptions to help LLM ---
def backfill_descriptions(meta, table):
    """Fill in placeholder descriptions wherever the table or a column has none.

    `meta` is modified in place and also returned. Existing non-empty
    descriptions are left untouched.
    """
    column_meta = meta['columns']
    if not meta.get('description'):
        column_names = ', '.join(column_meta.keys())
        meta['description'] = f"{table} is a table in a baseball database. Columns: {column_names}"
    for name in column_meta:
        entry = column_meta[name]
        if entry.get('description'):
            continue
        entry['description'] = f"{name} column in the {table} table."
    return meta

# --- Metadata enrichment ---
def enrich_table_metadata_with_llm(table, meta, client, model="gpt-4o"):
    """Ask the LLM for table/column descriptions, synonyms, tags and aliases.

    Args:
        table: Table name, used in the prompt and in log messages.
        meta: Metadata dict for the table; must contain 'columns' and may contain
            'sample_row'. It is updated in place.
        client: Object exposing `chat.completions.create(...)`.
        model: Chat model name passed to the client.

    Returns:
        The same `meta` dict. When a JSON object is returned, 'description',
        'synonyms', 'tags' and 'alternatives' are set on the table and
        'description', 'synonyms' and 'tags' on every column; otherwise the
        unparsed reply is stored under 'meta_raw'.
    """
    def clean_text(val):
        # The model may return null or a non-string where text is expected.
        return val.strip() if isinstance(val, str) else ''

    def as_dict(val):
        return val if isinstance(val, dict) else {}

    columns = meta['columns']
    col_block = [
        f"{col} ({cmeta.get('type', 'TEXT')}, sample: {cmeta.get('sample', '')})"
        for col, cmeta in columns.items()
    ]
    col_block_str = "\n".join(col_block)
    sample_row = meta.get('sample_row', {})

    prompt = f"""
You are an expert in database schema documentation and baseball analytics.
Table: {table}
Columns:
{col_block_str}
Sample row: {sample_row}
Return valid, minified JSON with:
- "table": description, synonyms, tags, aliases (alternatives)
- "columns": for each, description, synonyms, tags.
No markdown, only JSON.
"""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=900
    )
    meta_llm, raw = safe_llm_json_response(response, table, "metadata")
    # A JSON array parses successfully but is not the requested object shape.
    if meta_llm and isinstance(meta_llm, dict):
        tmeta = as_dict(meta_llm.get('table'))
        # Keep an existing (e.g. backfilled) description when the model gives none.
        meta['description'] = clean_text(tmeta.get('description')) or meta.get('description', '')
        meta['synonyms'] = to_list(tmeta.get('synonyms', []))
        meta['tags'] = to_list(tmeta.get('tags', []))
        # The prompt words this field as "aliases (alternatives)", so accept either key.
        meta['alternatives'] = to_list(tmeta.get('alternatives') or tmeta.get('aliases') or [])
        col_enrichment = as_dict(meta_llm.get('columns'))
        for col, cmeta in columns.items():
            enrich = as_dict(col_enrichment.get(col))
            cmeta['description'] = clean_text(enrich.get('description')) or cmeta.get('description', '')
            cmeta['synonyms'] = to_list(enrich.get('synonyms', []))
            cmeta['tags'] = to_list(enrich.get('tags', []))
    else:
        # `raw` is None when parsing succeeded but the shape was unusable.
        meta["meta_raw"] = raw if raw is not None else response.choices[0].message.content
    return meta

# --- Semantics enrichment ---
def enrich_table_semantics_with_llm(table, meta, client, model="gpt-4o"):
    """Ask the LLM for the table's semantic role and store the reply on `meta`.

    The prompt lists the table's column names and requests typical joins,
    business role, process context and notes as JSON. A truthy parsed reply is
    stored under 'semantics'; otherwise the unparsed text is stored under
    'semantics_raw'. `meta` is updated in place and returned.
    """
    column_names = [*meta['columns']]
    prompt = f"""
You are a semantic data architect for baseball analytics.
Given table: {table}
Columns: {column_names}
Describe as JSON:
- typical_joins: list of common foreign keys/joins and their meaning
- business_role: fact/dimension/entity/event/reference/other
- process_context: static/event/snapshot/temporal
- notes: nuances, ambiguity, or real-world edge cases.
Return ONLY valid, minified JSON.
"""
    chat_messages = [{"role": "user", "content": prompt}]
    response = client.chat.completions.create(
        model=model,
        messages=chat_messages,
        max_tokens=700
    )
    parsed, unparsed = safe_llm_json_response(response, table, "semantics")
    if not parsed:
        meta["semantics_raw"] = unparsed
        return meta
    meta["semantics"] = parsed
    return meta

# --- Ontology enrichment ---
def enrich_table_ontology_with_llm(table, meta, client, model="gpt-4o"):
    """Ask the LLM for an ontology description of the table and store it on `meta`.

    The prompt contains every column with its current description and, when
    present, the table's foreign keys. A truthy parsed reply is stored under
    'ontology'; otherwise the unparsed text is stored under 'ontology_raw'.
    `meta` is updated in place and returned.
    """
    column_lines = []
    for col, cmeta in meta['columns'].items():
        column_lines.append(f"- {col}: {cmeta.get('description','')}")
    col_desc = "\n".join(column_lines)

    foreign_keys = meta.get('foreign_keys')
    if foreign_keys:
        fk_lines = [
            f"- {fk['from']} -> {fk['to_table']}.{fk['to_col']}" for fk in foreign_keys
        ]
        fk_desc = "\nForeign Keys:\n" + "\n".join(fk_lines)
    else:
        fk_desc = ""

    prompt = f"""
You are an ontology engineer for database systems.
Table: {table}
Columns and Descriptions:
{col_desc}
{fk_desc}

Produce minified JSON for this table's ontology:
- entity_type: e.g. Player, Team, Game, Event
- super_type: broader class if any (e.g. Person, Organization, Event)
- is_a: array of super/sub types (hierarchy)
- part_of: larger process/entity if any
- relationships: array of objects, each with:
    - name: relationship (e.g. 'plays_for')
    - from_col: source column (if any)
    - to_table: referenced table (if any)
    - to_col: referenced column (if any)
- related_columns: list of {{"column": ..., "meaning": ...}}
Return ONLY valid, minified JSON. No markdown, no prose, no explanation.
"""
    chat_messages = [{"role": "user", "content": prompt}]
    response = client.chat.completions.create(
        model=model,
        messages=chat_messages,
        max_tokens=900
    )
    parsed, unparsed = safe_llm_json_response(response, table, "ontology")
    if not parsed:
        meta["ontology_raw"] = unparsed
        return meta
    meta["ontology"] = parsed
    return meta

# --- Example enrichment with join/aggregate hints ---
def enrich_table_examples_with_llm(table, meta, client, model="gpt-4o"):
    """Ask the LLM for NLQ/SQL example pairs and store them on `meta`.

    Args:
        table: Table name, used in the prompt and in log messages.
        meta: Metadata dict with 'columns' and optionally 'description' and
            'foreign_keys'. It is updated in place.
        client: Object exposing `chat.completions.create(...)`.
        model: Chat model name passed to the client.

    Returns:
        The same `meta` dict. A truthy parsed reply yields meta['examples'], a
        list of {"NLQ": ..., "SQL": ...} dicts (entries lacking either part or
        having an unexpected shape are dropped). Otherwise the unparsed reply
        is stored under 'examples_raw'.
    """
    def first_of(entry, *keys):
        for key in keys:
            if entry.get(key):
                return entry[key]
        return None

    columns_desc = "\n".join(
        f"- {col} ({cmeta.get('type', 'TEXT')})" for col, cmeta in meta['columns'].items()
    )
    desc = meta.get('description', '')
    join_hint = ""
    if meta.get('foreign_keys'):
        join_hint = "\n- At least one NLQ/SQL pair should demonstrate a JOIN using a foreign key."
    prompt = f"""
Given this baseball table:
Table: {table}
Description: {desc}
Columns:
{columns_desc}
{join_hint}

Generate 4 pairs:
- NLQ: realistic natural language question using column names and plausible values
- SQL: matching SQL (single-table, or JOIN if foreign key exists)
Include at least 1 aggregate query and 1 join if possible.
Respond ONLY with minified JSON: [{{"NLQ": "...", "SQL": "..."}}, ...]
"""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=900
    )
    examples, raw = safe_llm_json_response(response, table, "examples")
    if examples:
        if isinstance(examples, list):
            candidates = examples
        elif isinstance(examples, dict):
            if 'NLQ' in examples or 'nlq' in examples:
                # A single pair returned as a bare object.
                candidates = [examples]
            else:
                # Numbered keys ({"1": {...}}) or a wrapper ({"examples": [...]}).
                candidates = []
                for value in examples.values():
                    if isinstance(value, list):
                        candidates.extend(value)
                    else:
                        candidates.append(value)
        else:
            candidates = []

        pairs = []
        for entry in candidates:
            # Skip strings/numbers etc.; calling .get on them would raise.
            if not isinstance(entry, dict):
                continue
            question = first_of(entry, "NLQ", "nlq")
            query = first_of(entry, "SQL", "sql")
            if question and query:
                pairs.append({"NLQ": question, "SQL": query})
        meta["examples"] = pairs
    else:
        meta["examples_raw"] = raw
    return meta

# --- LLM review/correction ---
def judge_schema_output_with_llm(facet, table, candidate, client, model="gpt-4o"):
    """Ask the LLM to review a candidate facet and propose a corrected version.

    Args:
        facet: Name of the facet under review (used in the prompt and logs).
        table: Table name (used in the prompt and logs).
        candidate: JSON-like object to review.
        client: Object exposing `chat.completions.create(...)`.
        model: Chat model name passed to the client.

    Returns:
        A dict with the keys requested from the model ("assessment", "issues",
        "correction") when the reply parses to a JSON object. Otherwise a
        fallback dict whose "correction" is the unchanged `candidate`.
    """
    # default=str keeps the review from crashing on values json cannot encode
    # (for example bytes sampled from a BLOB column).
    candidate_json = json.dumps(candidate, ensure_ascii=False, default=str)
    prompt = f"""
You are a critical reviewer of database semantic metadata and ontology for baseball analytics.
Given table: {table}
Facet: {facet}
Candidate JSON:
{candidate_json}

Please do ALL of the following:
- Judge the completeness, accuracy, and clarity of the above.
- Point out specific errors, inconsistencies, or missing details.
- Suggest corrections or improvements (output as corrected JSON).

Your output must be valid JSON with these keys:
- "assessment": [your brief review],
- "issues": [list of problems, if any],
- "correction": [your improved/corrected JSON for this facet]
Respond ONLY with JSON.
"""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1000
    )
    review, raw = safe_llm_json_response(response, table, f"{facet}:judgement")
    # Callers use review.get(...), so only a JSON object is an acceptable reply.
    if review and isinstance(review, dict):
        return review
    print(f"[{table}:{facet}] Judgement parse failed. Raw:\n{raw}")
    return {"assessment": "Parse error", "issues": [raw], "correction": candidate}

def export_schema_semantics_yaml(meta_dict, filename="semantic_layer.yaml"):
    """Write the enriched schema to `filename` as YAML and to a sibling JSON file.

    Args:
        meta_dict: Mapping of table name to enriched metadata.
        filename: Path of the YAML file. The JSON copy is written next to it with
            the extension replaced by '.json'.
    """
    import os

    # Derive the JSON path from the extension rather than with str.replace():
    # for a name without '.yaml' (e.g. 'layer.yml') replace() returns the same
    # path and the JSON dump would overwrite the YAML file just written.
    stem, _ext = os.path.splitext(filename)
    json_filename = stem + ".json"
    if json_filename == filename:
        json_filename = filename + ".json"

    with open(filename, "w", encoding="utf-8") as f:
        yaml.dump(meta_dict, f, sort_keys=False, allow_unicode=True)
    with open(json_filename, "w", encoding="utf-8") as f:
        json.dump(meta_dict, f, indent=2, ensure_ascii=False)
    print(f"Exported to {filename} and {json_filename}")

# ================= PIPELINE USAGE ===================
# Enrich every extracted table (metadata -> semantics -> ontology -> examples),
# let the LLM review the result, export everything, then audit example coverage.

schema_meta_dict = {}
table_list = list(schema.keys())
def get_initial_meta(table):
    """Return a shallow copy of the extracted metadata for `table`.

    NOTE(review): the copy is shallow, so nested dicts such as 'columns' are
    still shared with `schema` and column-level enrichment is visible there too.
    """
    return schema[table].copy()

def _apply_review_correction(meta, correction):
    """Overlay an LLM-proposed correction on `meta` without losing extracted facts.

    Values read from the database (column set, types, samples, key flags,
    foreign keys, sample row) are never taken from the correction; only the
    descriptive fields are. Returns a new dict with new column dicts.
    """
    protected_table_keys = ('columns', 'foreign_keys', 'sample_row')
    protected_col_keys = ('type', 'sample', 'pk', 'notnull')

    merged = dict(meta)
    merged.update({k: v for k, v in correction.items() if k not in protected_table_keys})

    corrected_cols = correction.get('columns')
    if not isinstance(corrected_cols, dict):
        corrected_cols = {}
    merged_cols = {}
    for col, cmeta in meta['columns'].items():
        new_cmeta = dict(cmeta)
        proposed = corrected_cols.get(col)
        if isinstance(proposed, dict):
            new_cmeta.update({k: v for k, v in proposed.items() if k not in protected_col_keys})
        merged_cols[col] = new_cmeta
    merged['columns'] = merged_cols
    return merged

for table in table_list:
    meta = get_initial_meta(table)
    # --- Backfill blank descriptions before LLM ---
    meta = backfill_descriptions(meta, table)
    # --- Pipeline in robust order ---
    meta = enrich_table_metadata_with_llm(table, meta, client)
    meta = enrich_table_semantics_with_llm(table, meta, client)
    meta = enrich_table_ontology_with_llm(table, meta, client)
    meta = enrich_table_examples_with_llm(table, meta, client)

    # --- LLM self-assessment for metadata facet (optional) ---
    try:
        review = judge_schema_output_with_llm('metadata', table, meta, client)
        if review.get('issues'):
            print("LLM Reflection/Correction:", review['issues'])
            correction = review.get('correction')
            if isinstance(correction, dict):
                # Merge instead of replacing: the reply is unvalidated model output
                # and may omit or alter columns, keys or other extracted facts.
                meta = _apply_review_correction(meta, correction)
            else:
                print(f"[{table}:metadata] Correction ignored: not a JSON object.")
    except Exception as e:
        print(f"[{table}:metadata] Judgement parse failed: {e}")

    schema_meta_dict[table] = meta

export_schema_semantics_yaml(schema_meta_dict)

# --- Audit tables missing examples ---
missing_examples = [table for table, meta in schema_meta_dict.items() if not meta.get("examples")]
if missing_examples:
    print("Tables without NLQ/SQL examples:", missing_examples)
else:
    print("All tables have example pairs!")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        "Description for 'openings' does not adequately distinguish it from 'games'.",
        "The 'synonyms' and 'tags' sections for some columns could expand to cover edge cases or alternate uses.",
        "In 'typical_joins', references to tables such as 'league' should have 'to_column' specified for consistency.",
        "The 'process_context' should perhaps be 'seasonal' rather than 'temporal' to better reflect the scope."
    ],
    "correction": {
        "columns": {
            "year": {
                "type": "INTEGER",
                "sample": 1871,
                "pk": false,
                "notnull": false,
                "description": "The year in which the home games were played.",
                "synonyms": ["game year", "season year"],
                "tags": ["temporal", "date", "season"]
            },
            "league_id": {
                "type": "TEXT",
                "sample": "NL",
 

In [None]:
import yaml

# Reload the exported YAML and print a readable per-table summary.
with open('semantic_layer.yaml', 'r', encoding='utf-8') as f:
    loaded_schema = yaml.safe_load(f)

for table, meta in loaded_schema.items():
    print(f"\n=== {table.upper()} ===")
    # Table-level facets; '-' marks a facet that is absent from the file.
    for label, key in (
        ("Description", "description"),
        ("Synonyms", "synonyms"),
        ("Tags", "tags"),
        ("Alternatives", "alternatives"),
    ):
        print(f"{label}: {meta.get(key, '-')}")

    column_meta = meta.get('columns', {})
    print(f"Columns: {list(column_meta.keys())}")
    for col in column_meta:
        print(f"  - {col}: {column_meta[col].get('description', '')}")

    # Ontology/semantics are printed only when present and non-empty.
    for label, key in (("Ontology", "ontology"), ("Semantics", "semantics")):
        facet_value = meta.get(key)
        if facet_value:
            print(f"{label}: {facet_value}")

    exs = meta.get('examples', [])
    print(f"Examples ({len(exs)}):")
    for number, ex in enumerate(exs, start=1):
        question = ex.get('NLQ') or ex.get('nlq')
        query = ex.get('SQL') or ex.get('sql')
        print(f"  {number}. NLQ: {question}\n     SQL: {query}")

    print("-" * 60)



=== ALL_STAR ===
Description: Details the records of baseball players selected as All-Stars in various years and their participation in All-Star games.
Synonyms: ['All-Star records', 'All-Star appearances', 'All-Star game participants']
Tags: ['baseball', 'All-Star game', 'sports', 'player participation']
Alternatives: []
Columns: ['player_id', 'year', 'game_num', 'game_id', 'team_id', 'league_id', 'gp', 'starting_pos']
  - player_id: Unique identifier for a player, typically derived from the player's name.
  - year: The year in which the All-Star game took place.
  - game_num: The numeric order of the All-Star game played in that year.
  - game_id: A unique identifier for a specific All-Star game.
  - team_id: The code representing the team for which the player was selected as an All-Star.
  - league_id: The code representing the league in which the player's team competes.
  - gp: Indicates the number of games played by the player in an All-Star series.
  - starting_pos: Indicates th

In [None]:
from google.colab import files
# Download the exported YAML semantic layer through the Colab files helper.
files.download("semantic_layer.yaml")
# Uncomment to download the JSON copy as well:
# files.download("semantic_layer.json")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# End of the contextual-enrichment stage.

In [None]:
# Spot-check: show the foreign keys extracted for the player_college table.
print(schema['player_college'].get('foreign_keys'))

[{'from': 'college_id', 'to_table': 'college', 'to_col': 'college_id'}, {'from': 'player_id', 'to_table': 'player', 'to_col': 'player_id'}]


In [None]:
# Quick assert/validation block
# Validates the enriched metadata in `schema_meta_dict`: the raw `schema` only
# holds what was read from the database, so its tables never carry the
# table-level description/synonyms/tags/examples checked here.
for table, meta in schema_meta_dict.items():
    assert meta.get('description'), f"{table} missing description"
    assert meta.get('synonyms'), f"{table} missing synonyms"
    assert meta.get('tags'), f"{table} missing tags"
    examples = meta.get('examples') or []
    assert len(examples) >= 2, f"{table} missing examples"
    column_names = [col.lower() for col in meta.get('columns', {})]
    for ex in examples:
        # Ensure each example SQL mentions a real column. SQL identifiers are
        # case-insensitive, so compare in lower case; accept either key casing.
        sql_text = (ex.get('SQL') or ex.get('sql') or '').lower()
        assert any(col in sql_text for col in column_names), f"Example SQL missing real column for {table}"
print("All tables pass basic enrichment validation!")


AssertionError: Example SQL missing real column for all_star

In [None]:
import networkx as nx

# Build the schema knowledge graph from the enriched metadata
# (`schema_meta_dict`), which carries the table-level descriptions, synonyms,
# tags and examples; the raw `schema` only has what was read from the database.
# Node kinds: table, column, synonym, tag, example. Edges always point away from
# the table/column: has_column, synonym, tag, foreign_key, example.
G = nx.MultiDiGraph()

for table, meta in schema_meta_dict.items():
    table_synonyms = meta.get("synonyms") or []
    table_tags = meta.get("tags") or []

    # Add table node, including LLM-enriched attributes
    G.add_node(
        table,
        type="table",
        description=meta.get("description", ""),
        synonyms=table_synonyms,
        tags=table_tags,
        alternatives=meta.get("alternatives") or [],
    )

    # Synonyms/tags as their own nodes (optional)
    for syn in table_synonyms:
        syn_node = f"{table}_syn_{syn}"
        G.add_node(syn_node, type="synonym", value=syn)
        G.add_edge(table, syn_node, type="synonym")
    for tag in table_tags:
        tag_node = f"{table}_tag_{tag}"
        G.add_node(tag_node, type="tag", value=tag)
        G.add_edge(table, tag_node, type="tag")

    # Columns and their own attributes/synonyms. `.get` tolerates metadata
    # whose 'columns' entry is missing instead of aborting the whole build.
    for col, cmeta in (meta.get("columns") or {}).items():
        col_node = f"{table}.{col}"
        col_synonyms = cmeta.get("synonyms") or []
        col_tags = cmeta.get("tags") or []
        G.add_node(
            col_node,
            type="column",
            col_type=cmeta.get("col_type") or cmeta.get("type"),
            sample=cmeta.get("sample"),
            pk=cmeta.get("pk", False),
            description=cmeta.get("description", ""),
            synonyms=col_synonyms,
            tags=col_tags,
        )
        G.add_edge(table, col_node, type="has_column")
        # Column synonyms/tags (optional)
        for csyn in col_synonyms:
            csyn_node = f"{col_node}_syn_{csyn}"
            G.add_node(csyn_node, type="synonym", value=csyn)
            G.add_edge(col_node, csyn_node, type="synonym")
        for ctag in col_tags:
            ctag_node = f"{col_node}_tag_{ctag}"
            G.add_node(ctag_node, type="tag", value=ctag)
            G.add_edge(col_node, ctag_node, type="tag")

    # Foreign keys
    for fk in meta.get("foreign_keys") or []:
        to_table = fk["to_table"]
        G.add_edge(
            table, to_table,
            type="foreign_key",
            from_col=fk["from"],
            to_col=fk["to_col"]
        )
        # Optionally, add reverse relationship (not required for basic traversal)
        # G.add_edge(to_table, table, type="referenced_by", from_col=fk["to_col"], to_col=fk["from"])

    # NLQ–SQL Example nodes/edges
    for j, ex in enumerate(meta.get("examples") or []):
        ex_node = f"{table}_example_{j}"
        G.add_node(ex_node, type="example", nlq=ex.get("NLQ") or ex.get("nlq"), sql=ex.get("SQL") or ex.get("sql"))
        G.add_edge(table, ex_node, type="example")

print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")


Graph has 1261 nodes and 1255 edges.


In [None]:
import networkx as nx

def get_relevant_schema_from_kg(nlq, kg, max_hops=2):
    """
    Given an NLQ and the schema KG, return:
    - sub_schema: minimal schema dict with only relevant tables/columns
    - examples: list of relevant example dicts (if present)

    Matching rules:
    - a table is hit when its name, or the value of one of its synonym/tag
      nodes, occurs in the lower-cased NLQ (substring match);
    - a column is hit when its name or one of its synonyms equals an NLQ token
      (punctuation stripped), or, for multi-word synonyms, occurs as a phrase.

    Every hit is expanded along outgoing edges up to `max_hops`. Edges run from
    table to column, so the owning table of each column hit is added explicitly;
    otherwise column hits could never appear in the result.

    Args:
        nlq: Natural-language question.
        kg: Directed graph whose nodes carry a 'type' attribute
            ('table', 'column', 'synonym', 'tag' or 'example').
        max_hops: Expansion radius around each hit.

    Returns:
        (sub_schema, examples) as described above.
    """
    import re

    nlq_lc = nlq.lower()
    # Tokenise on word characters so "year?" still matches the column "year".
    nlq_tokens = set(re.findall(r"[a-z0-9_]+", nlq_lc))
    table_hits = set()
    column_hits = set()

    def mentioned(name):
        name = name.lower().strip()
        if not name:
            return False
        if name in nlq_tokens:
            return True
        # Multi-word names can never equal a single token; match them as phrases.
        return len(name.split()) > 1 and name in nlq_lc

    for node, data in kg.nodes(data=True):
        t = data.get('type')
        val = data.get('value', '')
        # Direct table name match
        if t == 'table' and node.lower() in nlq_lc:
            table_hits.add(node)
        # Tag/synonym match
        if t in ('synonym', 'tag') and isinstance(val, str) and val and val.lower() in nlq_lc:
            for pred in kg.predecessors(node):
                if kg.nodes[pred].get('type') == 'table':
                    table_hits.add(pred)
        # Column match (by name or synonym)
        if t == 'column':
            names = [node.split('.', 1)[-1]]
            names.extend(s for s in (data.get('synonyms') or []) if isinstance(s, str))
            if any(mentioned(name) for name in names):
                column_hits.add(node)

    # Traverse local neighborhood for context
    relevant_nodes = set()
    for node in table_hits | column_hits:
        relevant_nodes.update(
            nx.single_source_shortest_path_length(kg, node, cutoff=max_hops)
        )
    # Pull in the table that owns each matched column (see docstring).
    for col_node in column_hits:
        for pred in kg.predecessors(col_node):
            if kg.nodes[pred].get('type') == 'table':
                relevant_nodes.add(pred)

    # Build minimal schema dict for the prompt
    relevant_tables = {n for n in relevant_nodes if kg.nodes[n].get('type') == 'table'}
    sub_schema = {}
    for t in relevant_tables:
        node_data = kg.nodes[t]
        sub_schema[t] = {
            'description': node_data.get('description', ''),
            'columns': {}
        }
        for neighbor in kg.neighbors(t):
            ndata = kg.nodes[neighbor]
            if ndata.get('type') == 'column' and neighbor in relevant_nodes:
                col = neighbor.split('.', 1)[1]
                # Copy all column metadata except the node-kind marker 'type'
                sub_schema[t]['columns'][col] = {k: v for k, v in ndata.items() if k != 'type'}

    # Collect any example nodes in the relevant context
    examples = []
    for n in relevant_nodes:
        ndata = kg.nodes[n]
        if ndata.get('type') == 'example':
            # Expose the pair under upper-case keys as well as the stored ones
            ex = dict(ndata)
            if 'nlq' in ex: ex['NLQ'] = ex['nlq']
            if 'sql' in ex: ex['SQL'] = ex['sql']
            examples.append(ex)

    return sub_schema, examples


In [None]:
# Test NLQ
nlq = "Which players played for the Boston Red Sox?"

# Extract sub-schema and example info for this query
sub_schema, examples = get_relevant_schema_from_kg(nlq, G)

print("--- Relevant Tables and Descriptions ---")
for t in sub_schema:
    print(f"{t}: {sub_schema[t]['description']}")

print("\n--- Columns per Table ---")
for t in sub_schema:
    print(f"\nTable: {t}")
    for col, cmeta in sub_schema[t]['columns'].items():
        print(f"  - {col}: {cmeta}")

print("\n--- Examples Found (if any) ---")
for ex in examples:
    print(ex)


--- Relevant Tables and Descriptions ---
player: The 'player' table holds information about baseball players, including details such as birth and death dates, names, physical attributes, and career details like debut and final game dates.

--- Columns per Table ---

Table: player
  - player_id: {'col_type': 'TEXT', 'sample': 'aardsda01', 'pk': False, 'description': '', 'synonyms': [], 'tags': []}
  - birth_year: {'col_type': 'NUMERIC', 'sample': 1981, 'pk': False, 'description': '', 'synonyms': [], 'tags': []}
  - birth_month: {'col_type': 'NUMERIC', 'sample': 12, 'pk': False, 'description': '', 'synonyms': [], 'tags': []}
  - birth_day: {'col_type': 'NUMERIC', 'sample': 27, 'pk': False, 'description': '', 'synonyms': [], 'tags': []}
  - birth_country: {'col_type': 'TEXT', 'sample': 'USA', 'pk': False, 'description': '', 'synonyms': [], 'tags': []}
  - birth_state: {'col_type': 'TEXT', 'sample': 'CO', 'pk': False, 'description': '', 'synonyms': [], 'tags': []}
  - birth_city: {'col_typ

In [None]:
# Token extraction
import re

def parse_nlq_terms(nlq):
    """Return the set of lower-cased terms in `nlq`.

    Every run of characters other than a-z, 0-9, underscore and space is
    replaced by a space before the text is split on whitespace.
    """
    normalized = re.sub(r"[^a-z0-9_ ]+", " ", nlq.lower())
    return {term for term in normalized.split()}

In [None]:
#Dynamic KG Querying: Table/Column Matches (Optional Helper)
def match_kg_nodes(nlq, G):
    """Return the table and column nodes of `G` that the NLQ mentions.

    A name matches when all of its terms (as produced by `parse_nlq_terms`)
    occur among the NLQ terms. For single-word names this is plain token
    equality; multi-word or hyphenated names such as "All-Star records" can
    match as well, which token equality alone never allowed.

    Args:
        nlq: Natural-language question.
        G: Graph whose nodes carry a 'type' attribute; table nodes may carry
            'synonyms', 'tags' and 'alternatives', column nodes 'synonyms'.

    Returns:
        (matched_tables, matched_columns): two sets of node identifiers.
    """
    tokens = parse_nlq_terms(nlq)

    def any_name_matches(names):
        for name in names:
            # LLM-produced lists may contain non-string entries; skip them.
            if not isinstance(name, str):
                continue
            name_terms = parse_nlq_terms(name)
            # An empty term set would be a subset of everything; never a match.
            if name_terms and name_terms <= tokens:
                return True
        return False

    matched_tables, matched_columns = set(), set()
    for node, data in G.nodes(data=True):
        # Table match (name, synonyms, tags, alternatives)
        if data.get("type") == "table":
            names = [node]
            for key in ("synonyms", "tags", "alternatives"):
                names.extend(data.get(key) or [])
            if any_name_matches(names):
                matched_tables.add(node)
        # Column match (name or synonyms)
        if data.get("type") == "column":
            names = [node.split('.')[-1]]
            names.extend(data.get("synonyms") or [])
            if any_name_matches(names):
                matched_columns.add(node)
    return matched_tables, matched_columns


In [None]:
def build_dynamic_prompt(
    nlq,
    schema_dict,
    examples=None,
    example_nlq=None,
    example_sql=None,
    max_examples=3
):
    """
    Build a robust prompt for NL-to-SQL generation.
    Includes:
      - Relevant schema subset (with desc, columns, synonyms, and samples)
      - Up to N NLQ–SQL few-shot examples (if any)
      - The NLQ for completion

    Args:
        nlq: Question to translate; must be a non-empty string.
        schema_dict: Non-empty mapping of table name to metadata with optional
            'description' and 'columns'. A column's type is read from
            'col_type', falling back to 'type'.
        examples: Optional list of dicts with 'NLQ'/'SQL' (or 'nlq'/'sql') keys.
        example_nlq, example_sql: Single fallback example, used only when
            `examples` contributes no usable pair.
        max_examples: Maximum number of pairs taken from `examples`.

    Returns:
        The prompt string.

    Raises:
        ValueError: if `nlq` or `schema_dict` is missing or invalid.
    """

    if not isinstance(nlq, str) or not nlq.strip():
        raise ValueError("NLQ is missing or invalid for prompt building.")
    if not isinstance(schema_dict, dict) or not schema_dict:
        raise ValueError("Schema_dict is empty or invalid.")

    if examples is None:
        examples = []

    # --- Schema Context ---
    def render_table(table, meta):
        lines = [
            f"Table: {table}",
            f"  Description: {meta.get('description', '')}",
            "  Columns:",
        ]
        for col, cmeta in meta.get('columns', {}).items():
            # Graph nodes store the SQL type as 'col_type', extracted schema as 'type'.
            coltype = cmeta.get('col_type') or cmeta.get('type') or ''
            entry = f"    - {col} ({coltype})"
            sample = cmeta.get('sample')
            # Test against None/'' explicitly: 0 is a legitimate sample value.
            if sample is not None and sample != '':
                entry += f" [sample: {sample}]"
            syns = [str(s) for s in (cmeta.get('synonyms') or [])]
            if syns:
                entry += f" [synonyms: {', '.join(syns)}]"
            lines.append(entry)
        return "\n".join(lines) + "\n"

    schema_text = "\n".join([render_table(t, schema_dict[t]) for t in schema_dict])
    if not schema_text.strip():
        schema_text = "Schema: [No relevant tables found for this question]\n"

    # --- Example Block ---
    shots = ""
    num_examples = 0
    for ex in (examples or []):
        if num_examples >= max_examples:
            break
        # Accept keys in any case (NLQ/nlq)
        nlq_ex = ex.get('NLQ') or ex.get('nlq')
        sql_ex = ex.get('SQL') or ex.get('sql')
        if nlq_ex and sql_ex:
            shots += f"Example:\nNLQ: {nlq_ex}\nSQL: {sql_ex}\n\n"
            num_examples += 1

    if not shots and example_nlq and example_sql:
        shots += f"Example:\nNLQ: {example_nlq}\nSQL: {example_sql}\n\n"

    # --- Build Prompt String ---
    prompt = f"""You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
{schema_text}

{shots}Now generate ONLY the SQL for this NLQ (no explanation, no markdown):

NLQ: {nlq}
SQL:
""".strip()

    # Defensive: Ensure never returns None
    if not prompt or not isinstance(prompt, str):
        raise ValueError("Prompt built to empty/None. Inputs likely invalid.")

    return prompt


In [None]:
import torch
import re

def _first_statement_end(sql):
    """Return the index of the first ';' outside single-quoted literals, or None."""
    in_quote = False
    for idx, ch in enumerate(sql):
        if ch == "'":
            in_quote = not in_quote
        elif ch == ";" and not in_quote:
            return idx
    return None


def _clean_generated_sql(text):
    """Reduce raw model output to a single-line SQL statement.

    Handles markdown code fences, a leading "SQL:" label, echoed prompt sections
    ("NLQ:", "Schema:", "Example:"), prose before the first SQL keyword and
    statements spread over several lines.
    """
    text = text.strip()
    # Prefer the content of the first fenced block; a language tag only counts
    # when it is followed by a newline, so "```SELECT 1```" keeps its SELECT.
    fence = re.search(r"```(?:[a-zA-Z]+[ \t]*\n)?(.*?)(?:```|$)", text, flags=re.DOTALL)
    fenced = fence is not None
    if fenced:
        text = fence.group(1)
    text = text.strip().lstrip("`")

    lines = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if re.match(r"(?:NLQ|Schema|Example)\s*:", line, flags=re.IGNORECASE):
            if lines:
                break  # the model started echoing the next prompt section
            continue
        # Drop only the label; the SQL usually follows it on the same line.
        line = re.sub(r"^SQL\s*:\s*", "", line, flags=re.IGNORECASE)
        if line:
            lines.append(line)
    if not lines:
        return ""

    body = "\n".join(lines)
    # Whitespace-delimited keyword, matching the original token comparison.
    start = re.search(
        r"(?<!\S)(?:SELECT|INSERT|UPDATE|DELETE|WITH)(?!\S)", body, flags=re.IGNORECASE
    )
    if start is None:
        return lines[0]
    body = body[start.start():]

    end = _first_statement_end(body)
    if end is not None:
        body = body[:end + 1]
    elif not fenced:
        # Without a terminator or a fence, later lines may be commentary.
        body = body.splitlines()[0]
    return " ".join(part.strip() for part in body.splitlines())


def infer_sql(
    prompt,
    tokenizer,
    model,
    max_tokens=256,
    clean=True
):
    """
    Generate SQL from a prompt using a HuggingFace LLM.
    Args:
        prompt (str): The input prompt.
        tokenizer: HuggingFace tokenizer.
        model: HuggingFace model.
        max_tokens (int): Max number of new tokens to generate.
        clean (bool): Whether to post-process the output for SQL only.
    Returns:
        result_sql (str): The generated SQL string.
    Raises:
        ValueError: if the prompt is empty or not a string.
        Exception if model or tokenizer fails.
    """
    if not prompt or not isinstance(prompt, str):
        raise ValueError("Prompt is empty or not a string!")

    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = encoded.input_ids.to(model.device)

    generate_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": False,
        "eos_token_id": getattr(tokenizer, "eos_token_id", None),
    }
    # Forward the attention mask when the tokenizer provides one.
    attention_mask = getattr(encoded, "attention_mask", None)
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(model.device)

    with torch.no_grad():
        outputs = model.generate(input_ids, **generate_kwargs)
    # Only take the generated continuation (not the prompt itself)
    generated = outputs[0][input_ids.shape[-1]:]
    result_sql = tokenizer.decode(generated, skip_special_tokens=True).strip()

    if clean:
        result_sql = _clean_generated_sql(result_sql)

    return result_sql


In [None]:
# 1. Imports
import networkx as nx
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# 2. Model & Tokenizer Path
model_path = "seeklhy/OmniSQL-7B"

# 3. Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# 4. Load Model
# bfloat16 is requested when CUDA is available, float32 otherwise.
# device_map="auto" delegates weight placement to Accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)

# 5. Model in evaluation mode (recommended for inference)
model.eval()

# NOTE: the model is intentionally NOT moved with model.to(...) afterwards.
# With device_map="auto" the weights are already dispatched; an extra .to() is
# redundant when everything fits on one device and raises a RuntimeError when
# some modules were offloaded to CPU/disk. Inputs should be sent to
# `model.device` at call time instead.


In [None]:
# Evaluation samples consumed by the inference loop: each entry pairs a
# natural-language question ("NLQ") with its reference query ("Gold SQL").
# The gold queries are executed as-is for result comparison, so they must stay
# valid against the uploaded database.
complex_samples = [
    {
        "NLQ": "what is the full name and id of the college with the largest number of baseball players?",
        "Gold SQL": "SELECT T1.name_full, T1.college_id FROM college AS T1 JOIN player_college AS T2 ON T1.college_id = T2.college_id GROUP BY T1.college_id ORDER BY count(*) DESC LIMIT 1;"
    },
    {
        "NLQ": "What is average salary of the players in the team named 'Boston Red Stockings' ?",
        "Gold SQL": "SELECT avg(T1.salary) FROM salary AS T1 JOIN team AS T2 ON T1.team_id = T2.team_id_br WHERE T2.name = 'Boston Red Stockings';"
    },
    {
        "NLQ": "What are first and last names of players participating in all star game in 1998?",
        "Gold SQL": "SELECT name_first, name_last FROM player AS T1 JOIN all_star AS T2 ON T1.player_id = T2.player_id WHERE YEAR = 1998;"
    },
    {
        "NLQ": "What are the first name, last name and id of the player with the most all star game experiences? Also list the count.",
        "Gold SQL": "SELECT T1.name_first, T1.name_last, T1.player_id, count(*) FROM player AS T1 JOIN all_star AS T2 ON T1.player_id = T2.player_id GROUP BY T1.player_id ORDER BY count(*) DESC LIMIT 1;"
    },
    {
        "NLQ": "How many players enter hall of fame each year?",
        "Gold SQL": "SELECT yearid, count(*) FROM hall_of_fame GROUP BY yearid;"
    },
    {
        "NLQ": "In 2014, what are the id and rank of the team that has the largest average number of attendance?",
        "Gold SQL": "SELECT T2.team_id, T2.rank FROM home_game AS T1 JOIN team AS T2 ON T1.team_id = T2.team_id WHERE T1.year = 2014 GROUP BY T1.team_id ORDER BY avg(T1.attendance) DESC LIMIT 1;"
    },
    {
        "NLQ": "What are the manager's first name, last name and id who won the most manager award?",
        "Gold SQL": "SELECT T1.name_first, T1.name_last, T2.player_id FROM player AS T1 JOIN manager_award AS T2 ON T1.player_id = T2.player_id GROUP BY T2.player_id ORDER BY count(*) DESC LIMIT 1;"
    },
    {
        "NLQ": "Which 3 players won the most player awards? List their full name and id.",
        "Gold SQL": "SELECT T1.name_first, T1.name_last, T1.player_id FROM player AS T1 JOIN player_award AS T2 ON T1.player_id = T2.player_id GROUP BY T1.player_id ORDER BY count(*) DESC LIMIT 3;"
    },
    {
        "NLQ": "List three countries which are the origins of the least players.",
        "Gold SQL": "SELECT birth_country FROM player GROUP BY birth_country ORDER BY count(*) ASC LIMIT 3;"
    },
    {
        "NLQ": "Find all the players' first name and last name who have empty death record.",
        "Gold SQL": "SELECT name_first, name_last FROM player WHERE death_year = '';"
    }
    ]
#     {
#         "NLQ": "What is the average height of the players from the college named 'Yale University'?",
#         "Gold SQL": "SELECT avg(T1.height) FROM player AS T1 JOIN player_college AS T2 ON T1.player_id = T2.player_id JOIN college AS T3 ON T3.college_id = T2.college_id WHERE T3.name_full = 'Yale University';"
#     },
#     {
#         "NLQ": "What is the highest salary among each team? List the team name, id and maximum salary.",
#         "Gold SQL": "SELECT T1.name, T1.team_id, max(T2.salary) FROM team AS T1 JOIN salary AS T2 ON T1.team_id = T2.team_id GROUP BY T1.team_id;"
#     },
#     {
#         "NLQ": "What are the name and id of the team offering the lowest average salary?",
#         "Gold SQL": "SELECT T1.name, T1.team_id FROM team AS T1 JOIN salary AS T2 ON T1.team_id = T2.team_id GROUP BY T1.team_id ORDER BY avg(T2.salary) ASC LIMIT 1;"
#     },
#     {
#         "NLQ": "Find the players' first name and last name who won award both in 1960 and in 1961.",
#         "Gold SQL": "SELECT T1.name_first, T1.name_last FROM player AS T1 JOIN player_award AS T2 WHERE T2.year = 1960 INTERSECT SELECT T1.name_first, T1.name_last FROM player AS T1 JOIN player_award AS T2 WHERE T2.year = 1961;"
#     },
#     {
#         "NLQ": "List players' first name and last name who have weight greater than 220 or height shorter than 75.",
#         "Gold SQL": "SELECT name_first, name_last FROM player WHERE weight > 220 OR height < 75;"
#     },
#     {
#         "NLQ": "What are the maximum scores the team Boston Red Stockings got when the team won in postseason?",
#         "Gold SQL": "SELECT max(T1.wins) FROM postseason AS T1 JOIN team AS T2 ON T1.team_id_winner = T2.team_id_br WHERE T2.name = 'Boston Red Stockings';"
#     },
#     {
#         "NLQ": "How many times did Boston Red Stockings lose in 2009 postseason?",
#         "Gold SQL": "SELECT count(*) FROM postseason AS T1 JOIN team AS T2 ON T1.team_id_loser = T2.team_id_br WHERE T2.name = 'Boston Red Stockings' AND T1.year = 2009;"
#     },
#     {
#         "NLQ": "What are the name and id of the team with the most victories in 2008 postseason?",
#         "Gold SQL": "SELECT T2.name, T1.team_id_winner FROM postseason AS T1 JOIN team AS T2 ON T1.team_id_winner = T2.team_id_br WHERE T1.year = 2008 GROUP BY T1.team_id_winner ORDER BY count(*) DESC LIMIT 1;"
#     },
#     {
#         "NLQ": "What is the total number of postseason games that team Boston Red Stockings participated in?",
#         "Gold SQL": "SELECT count(*) FROM ( SELECT * FROM postseason AS T1 JOIN team AS T2 ON T1.team_id_winner = T2.team_id_br WHERE T2.name = 'Boston Red Stockings' UNION SELECT * FROM postseason AS T1 JOIN team AS T2 ON T1.team_id_loser = T2.team_id_br WHERE T2.name = 'Boston Red Stockings' );"
#     },
#     {
#         "NLQ": "What is the total salary paid by team Boston Red Stockings in 2010?",
#         "Gold SQL": "SELECT sum(T1.salary) FROM salary AS T1 JOIN team AS T2 ON T1.team_id = T2.team_id_br WHERE T2.name = 'Boston Red Stockings' AND T1.year = 2010;"
#     }
# ]


In [None]:
# Generate a predicted SQL query for every sample in `complex_samples` and
# collect the predictions (plus the prompt that produced them) in `results_df`.
import transformers

# Environment diagnostics do not change between samples, so report them once
# instead of on every loop iteration.
print("Transformers version:", transformers.__version__)
print("Tokenizer class:", type(tokenizer))

results = []

for sample in complex_samples:
    nlq = sample["NLQ"]
    gold_sql = sample["Gold SQL"]

    # Step 1: Use KG to get relevant schema & examples
    sub_schema, exs = get_relevant_schema_from_kg(nlq, G, max_hops=2)

    # Step 2: Keep at most 3 examples that carry both an NLQ and a SQL field
    few_shots = [
        {"NLQ": ex["NLQ"], "SQL": ex["SQL"]}
        for ex in exs
        if "NLQ" in ex and "SQL" in ex
    ][:3]

    # Step 3: Build the dynamic prompt
    prompt = build_dynamic_prompt(
        nlq=nlq,
        schema_dict=sub_schema,
        examples=few_shots,
        example_nlq=None,
        example_sql=None
    )
    print(prompt)

    # Step 4: Model inference. Errors are deliberately not caught here, so a
    # failing sample stops the run instead of being recorded as a prediction.
    result_sql = infer_sql(
        prompt,
        tokenizer=tokenizer,
        model=model,
        max_tokens=256
    )

    # Step 5: Append results
    results.append({
        "NLQ": nlq,
        "Gold SQL": gold_sql,
        "Predicted SQL": result_sql,
        "Prompt": prompt
    })

    print("\n---")
    print("NLQ:", nlq)
    print("Generated SQL:", result_sql)

# One row per sample: NLQ, Gold SQL, Predicted SQL, Prompt
results_df = pd.DataFrame(results)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: player
  Description: 
  Columns:
    - player_id (TEXT) [sample: aardsda01]
    - birth_year (NUMERIC) [sample: 1981]
    - birth_month (NUMERIC) [sample: 12]
    - birth_day (NUMERIC) [sample: 27]
    - birth_country (TEXT) [sample: USA]
    - birth_state (TEXT) [sample: CO]
    - birth_city (TEXT) [sample: Denver]
    - death_year (NUMERIC)
    - death_month (NUMERIC)
    - death_day (NUMERIC)
    - death_country (TEXT)
    - death_state (TEXT)
    - death_city (TEXT)
    - name_first (TEXT) [sample: David]
    - name_last (TEXT) [sample: Aardsma]
    - name_given (TEXT) [sample: David Allan]
    - weight (NUMERIC) [sample: 220]
    - height (NUMERIC) [sample: 75]
    - bats (TEXT) [sample: R]
    - throws (TEXT) [sample: R

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: what is the full name and id of the college with the largest number of baseball players?
Generated SQL: SELECT c.college_id, c.name_full FROM college AS c JOIN player AS p ON c.college_id = p.bbref_id GROUP BY c.college_id ORDER BY COUNT(p.player_id) DESC LIMIT 1;
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: player
  Description: 
  Columns:
    - player_id (TEXT) [sample: aardsda01]
    - birth_year (NUMERIC) [sample: 1981]
    - birth_month (NUMERIC) [sample: 12]
    - birth_day (NUMERIC) [sample: 27]
    - birth_country (TEXT) [sample: USA]
    - birth_state (TEXT) [sample: CO]
    - birth_city (TEXT) [sample: Denver]
    - death_year (NUMERIC)
    - death_month (NUMERIC)
    - death_day (NUMERIC)
    - death_country (TEXT)
    - death_state (TEXT)
    - dea

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: What is average salary of the players in the team named 'Boston Red Stockings' ?
Generated SQL: SELECT AVG(salary.salary) AS average_salary FROM salary JOIN team ON salary.team_id = team.team_id_br WHERE team.name = 'Boston Red Stockings';
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: player
  Description: 
  Columns:
    - player_id (TEXT) [sample: aardsda01]
    - birth_year (NUMERIC) [sample: 1981]
    - birth_month (NUMERIC) [sample: 12]
    - birth_day (NUMERIC) [sample: 27]
    - birth_country (TEXT) [sample: USA]
    - birth_state (TEXT) [sample: CO]
    - birth_city (TEXT) [sample: Denver]
    - death_year (NUMERIC)
    - death_month (NUMERIC)
    - death_day (NUMERIC)
    - death_country (TEXT)
    - death_state (TEXT)
    - death_city (TEXT)
    - name

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: What are first and last names of players participating in all star game in 1998?
Generated SQL: SELECT DISTINCT p.name_first, p.name_last FROM player AS p JOIN player_allstar AS pa ON p.player_id = pa.player_id WHERE pa.year = 1998;
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: player
  Description: 
  Columns:
    - player_id (TEXT) [sample: aardsda01]
    - birth_year (NUMERIC) [sample: 1981]
    - birth_month (NUMERIC) [sample: 12]
    - birth_day (NUMERIC) [sample: 27]
    - birth_country (TEXT) [sample: USA]
    - birth_state (TEXT) [sample: CO]
    - birth_city (TEXT) [sample: Denver]
    - death_year (NUMERIC)
    - death_month (NUMERIC)
    - death_day (NUMERIC)
    - death_country (TEXT)
    - death_state (TEXT)
    - death_city (TEXT)
    - name_first 

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: What are the first name, last name and id of the player with the most all star game experiences? Also list the count.
Generated SQL: SELECT name_first, name_last, player_id, COUNT(*) as all_star_count FROM player GROUP BY player_id ORDER BY all_star_count DESC LIMIT 1;
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: player
  Description: 
  Columns:
    - player_id (TEXT) [sample: aardsda01]
    - birth_year (NUMERIC) [sample: 1981]
    - birth_month (NUMERIC) [sample: 12]
    - birth_day (NUMERIC) [sample: 27]
    - birth_country (TEXT) [sample: USA]
    - birth_state (TEXT) [sample: CO]
    - birth_city (TEXT) [sample: Denver]
    - death_year (NUMERIC)
    - death_month (NUMERIC)
    - death_day (NUMERIC)
    - death_country (TEXT)
    - death_state (TEXT)
    

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: How many players enter hall of fame each year?
Generated SQL: SELECT COUNT(*) FROM player;
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: team
  Description: 
  Columns:
    - year (INTEGER) [sample: 1871]
    - league_id (TEXT)
    - team_id (TEXT) [sample: BS1]
    - franchise_id (TEXT) [sample: BNA]
    - div_id (TEXT)
    - rank (INTEGER) [sample: 3]
    - g (INTEGER) [sample: 31]
    - ghome (NUMERIC)
    - w (INTEGER) [sample: 20]
    - l (INTEGER) [sample: 10]
    - div_win (TEXT)
    - wc_win (TEXT)
    - lg_win (TEXT) [sample: N]
    - ws_win (TEXT)
    - r (INTEGER) [sample: 401]
    - ab (INTEGER) [sample: 1372]
    - h (INTEGER) [sample: 426]
    - double (INTEGER) [sample: 70]
    - triple (INTEGER) [sample: 37]
    - hr (INTEGER) [sample: 3]
    - b

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: In 2014, what are the id and rank of the team that has the largest average number of attendance?
Generated SQL: SELECT team_id, rank FROM team WHERE year = 2014 GROUP BY team_id ORDER BY AVG(attendance) DESC LIMIT 1;
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: manager
  Description: 
  Columns:
    - player_id (TEXT) [sample: wrighha01]
    - year (INTEGER) [sample: 1871]
    - team_id (TEXT) [sample: BS1]
    - league_id (TEXT)
    - inseason (INTEGER) [sample: 1]
    - g (INTEGER) [sample: 31]
    - w (INTEGER) [sample: 20]
    - l (INTEGER) [sample: 10]
    - rank (NUMERIC) [sample: 3]
    - plyr_mgr (TEXT) [sample: Y]

Table: team
  Description: 
  Columns:
    - year (INTEGER) [sample: 1871]
    - league_id (TEXT)
    - team_id (TEXT) [sample: BS1]
    - 

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: What are the manager's first name, last name and id who won the most manager award?
Generated SQL: SELECT m.plyr_mgr AS manager_name, COUNT(*) AS num_awards FROM manager m WHERE m.plyr_mgr = 'Y' GROUP BY m.plyr_mgr ORDER BY num_awards DESC LIMIT 1;
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: player
  Description: 
  Columns:
    - player_id (TEXT) [sample: aardsda01]
    - birth_year (NUMERIC) [sample: 1981]
    - birth_month (NUMERIC) [sample: 12]
    - birth_day (NUMERIC) [sample: 27]
    - birth_country (TEXT) [sample: USA]
    - birth_state (TEXT) [sample: CO]
    - birth_city (TEXT) [sample: Denver]
    - death_year (NUMERIC)
    - death_month (NUMERIC)
    - death_day (NUMERIC)
    - death_country (TEXT)
    - death_state (TEXT)
    - death_city (TEXT)
 

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: Which 3 players won the most player awards? List their full name and id.
Generated SQL: SELECT p.name_first || ' ' || p.name_last AS full_name, p.player_id FROM player p WHERE p.player_id IN (SELECT pa.player_id FROM player_awards pa GROUP BY pa.player_id ORDER BY COUNT(pa.award_id) DESC LIMIT 3);
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: player
  Description: 
  Columns:
    - player_id (TEXT) [sample: aardsda01]
    - birth_year (NUMERIC) [sample: 1981]
    - birth_month (NUMERIC) [sample: 12]
    - birth_day (NUMERIC) [sample: 27]
    - birth_country (TEXT) [sample: USA]
    - birth_state (TEXT) [sample: CO]
    - birth_city (TEXT) [sample: Denver]
    - death_year (NUMERIC)
    - death_month (NUMERIC)
    - death_day (NUMERIC)
    - death_country (TEXT)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



---
NLQ: List three countries which are the origins of the least players.
Generated SQL: SELECT birth_country FROM player GROUP BY birth_country ORDER BY COUNT(*) ASC LIMIT 3;
Transformers version: 4.53.0
Tokenizer class: <class 'transformers.models.qwen2.tokenization_qwen2_fast.Qwen2TokenizerFast'>
You are an expert SQL agent for the baseball database.
OUTPUT ONLY THE SQL QUERY—no explanations, markdown, or commentary.

Schema:
Table: player
  Description: 
  Columns:
    - player_id (TEXT) [sample: aardsda01]
    - birth_year (NUMERIC) [sample: 1981]
    - birth_month (NUMERIC) [sample: 12]
    - birth_day (NUMERIC) [sample: 27]
    - birth_country (TEXT) [sample: USA]
    - birth_state (TEXT) [sample: CO]
    - birth_city (TEXT) [sample: Denver]
    - death_year (NUMERIC)
    - death_month (NUMERIC)
    - death_day (NUMERIC)
    - death_country (TEXT)
    - death_state (TEXT)
    - death_city (TEXT)
    - name_first (TEXT) [sample: David]
    - name_last (TEXT) [sample: Aardsma]
  

In [None]:
import re

def clean_predicted_sql(pred_sql: str) -> str:
    """
    Reduce raw model output to a single SQL statement.

    Steps:
      1. Strip markdown code fences and stray backticks.
      2. Drop any text that precedes the first statement start. A statement
         starts at SELECT / INSERT / UPDATE / DELETE, or at a CTE header of the
         form ``WITH [RECURSIVE] name [(cols)] AS (`` so that a CTE keeps its
         prefix while the English word "with" in prose is ignored.
      3. Cut everything after the first semicolon.
      4. Collapse whitespace runs to one space and terminate with exactly one ``;``.

    SQL operators such as ``*`` and ``-`` are left untouched, so ``COUNT(*)``,
    ``SELECT *``, arithmetic and date literals survive cleaning.

    Args:
        pred_sql: Raw text produced by the model.

    Returns:
        The cleaned statement ending in ``;``, or ``""`` when the input is not
        a string or nothing remains after cleaning.

    Limitations:
        A semicolon or a whitespace run inside a string literal is treated like
        any other (the statement is cut / the run is collapsed).
    """
    if not isinstance(pred_sql, str):
        return ""

    # Remove markdown code blocks and backticks (works for ```sql, ``` and `)
    cleaned = re.sub(r"^```(?:sql)?|```|`", "", pred_sql, flags=re.MULTILINE | re.IGNORECASE)

    # Locate the leftmost statement start; the WITH alternative only matches a
    # real CTE header, never the plain word "with".
    statement_start = re.search(
        r"\bWITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^()]*\)\s*)?AS\s*\("
        r"|\b(?:SELECT|INSERT|UPDATE|DELETE)\b",
        cleaned,
        flags=re.IGNORECASE,
    )
    if statement_start:
        cleaned = cleaned[statement_start.start():]

    # Keep only the first statement; the terminator is re-added below.
    semicolon_pos = cleaned.find(";")
    if semicolon_pos != -1:
        cleaned = cleaned[:semicolon_pos]

    # Collapse internal whitespace to single spaces and trim the ends
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    if not cleaned:
        return ""
    return cleaned + ";"


In [None]:
import time

def _sql_value_sort_key(value):
    """Return a sort key that makes any mix of SQLite result values comparable.

    Values are ranked NULL < numbers < text < blobs < anything else, so that
    sorting never compares e.g. None with a string.
    """
    if value is None:
        return (0, 0)
    if isinstance(value, (int, float)):
        return (1, value)
    if isinstance(value, str):
        return (2, value)
    if isinstance(value, (bytes, bytearray, memoryview)):
        return (3, bytes(value))
    return (4, repr(value))


def run_query(sql, conn):
    """
    Execute a SQL query and return canonicalized results for comparison.

    Rows are sorted with a type-aware key, so result sets containing NULLs or
    mixed-type columns are canonicalized instead of failing with a TypeError
    that would be misreported as a query error.

    Args:
        sql (str): The SQL statement to execute.
        conn (sqlite3.Connection): The SQLite database connection.

    Returns:
        result (tuple of tuples): Sorted, hashable result rows (None on error).
        error (str or None): Error message if any exception occurs, else None.
        elapsed (float): Time taken to execute and fetch, in seconds (0.0 on error).
    """
    try:
        # perf_counter is monotonic, unlike wall-clock time.time()
        start = time.perf_counter()
        rows = conn.execute(sql).fetchall()
        elapsed = time.perf_counter() - start
        # Canonicalize result for consistent comparison (sorted, hashable)
        result = tuple(sorted(
            (tuple(row) for row in rows),
            key=lambda row: tuple(_sql_value_sort_key(value) for value in row),
        ))
        return result, None, elapsed
    except Exception as e:
        # Best-effort by design: any failure is reported through `error`
        return None, str(e), 0.0


In [None]:
# Open a SQLite connection to the database file at `db_path`; the evaluation
# code passes this connection to run_query().
conn = sqlite3.connect(db_path)

In [None]:
import time
import pandas as pd

# Execute each gold query and its cleaned prediction, then record whether the
# two result sets agree.
exec_results = []

for _, row in results_df.iterrows():
    nlq, gold_sql = row["NLQ"], row["Gold SQL"]

    # Normalize the raw model output into a single executable statement
    pred_sql = clean_predicted_sql(row["Predicted SQL"])

    gold_result, gold_error, gold_time = run_query(gold_sql, conn)
    pred_result, pred_error, pred_time = run_query(pred_sql, conn)

    # A match requires both queries to run cleanly and return equal results
    no_errors = gold_error is None and pred_error is None
    exec_match = no_errors and gold_result == pred_result

    record = {
        "NLQ": nlq,
        "Gold SQL": gold_sql,
        "Predicted SQL": pred_sql,
        "Gold Result": gold_result,
        "Predicted Result": pred_result,
        "Gold Error": gold_error,
        "Pred Error": pred_error,
        "Gold Time (s)": gold_time,
        "Pred Time (s)": pred_time,
        "Exec_Match": exec_match,
    }
    exec_results.append(record)

# Tabular view of the per-sample evaluation records
exec_results_df = pd.DataFrame(exec_results)


In [None]:
# Summarize execution accuracy: the share of samples whose results matched.
exec_df = pd.DataFrame(exec_results)
accuracy = exec_df["Exec_Match"].mean()

matched_count = exec_df.Exec_Match.sum()
total_count = len(exec_df)
print(f"Execution Accuracy: {accuracy:.2%} ({matched_count} / {total_count})")

Execution Accuracy: 30.00% (3 / 10)


In [None]:
# Rich table rendering below relies on IPython (Jupyter/Colab only).
from IPython.display import display

# Columns worth inspecting when a prediction disagrees with the gold query
mismatch_columns = [
    "NLQ",
    "Gold SQL",
    "Predicted SQL",
    "Gold Result",
    "Predicted Result",
    "Gold Error",
    "Pred Error",
]
is_mismatch = ~exec_df["Exec_Match"]

print("\nMismatched Results:")
display(exec_df.loc[is_mismatch, mismatch_columns])


Mismatched Results:


Unnamed: 0,NLQ,Gold SQL,Predicted SQL,Gold Result,Predicted Result,Gold Error,Pred Error
0,what is the full name and id of the college wi...,"SELECT T1.name_full, T1.college_id FROM colleg...","SELECT c.college_id, c.name_full FROM college ...","((University of Texas at Austin, texas),)",(),,
2,What are first and last names of players parti...,"SELECT name_first, name_last FROM player AS T1...","SELECT DISTINCT p.name_first, p.name_last FROM...","((Aaron, Sele), (Alex, Rodriguez), (Andres, Ga...",,,no such table: player_allstar
3,"What are the first name, last name and id of t...","SELECT T1.name_first, T1.name_last, T1.player_...","SELECT name_first, name_last, player_id, COUNT...","((Hank, Aaron, aaronha01, 25),)","((Tony, Zych, zychto01, 1),)",,
4,How many players enter hall of fame each year?,"SELECT yearid, count(*) FROM hall_of_fame GROU...",SELECT COUNT() FROM player;,"((1936, 110), (1937, 118), (1938, 122), (1939,...","((18846,),)",,
6,"What are the manager's first name, last name a...","SELECT T1.name_first, T1.name_last, T2.player_...","SELECT m.plyr_mgr AS manager_name, COUNT() AS ...","((Bobby, Cox, coxbo01),)","((Y, 645),)",,
7,Which 3 players won the most player awards? Li...,"SELECT T1.name_first, T1.name_last, T1.player_...",SELECT p.name_first || ' ' || p.name_last AS f...,"((Barry, Bonds, bondsba01), (Joe, DiMaggio, di...",,,no such table: player_awards
9,Find all the players' first name and last name...,"SELECT name_first, name_last FROM player WHERE...","SELECT name_first, name_last FROM player WHERE...","((, Boland), (, Booth), (, Carroll), (, Collin...",(),,
