In [1]:
%run ../utils/init_env.py

In [2]:
import os
import re
import ast
import json
import config
import sqlite3
from collections import defaultdict

In [3]:
def parse_schema_sqlite(db_path, max_samples=3):
    """Summarize every table of a SQLite database as one prompt-ready line.

    Each table is rendered as ``# table(col1[v1, v2, ...], col2[...])`` where
    the bracketed values are taken from the first ``max_samples`` rows returned
    by ``SELECT *``. A table without rows yields empty brackets for each column.

    Args:
        db_path: Path of the SQLite database file to inspect.
        max_samples: Maximum number of sample values listed per column.

    Returns:
        A list of strings, one per table listed in ``sqlite_master``.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """

    def quote_identifier(name):
        # Table names may be reserved words or contain spaces/punctuation, so
        # they must be double-quoted (with embedded quotes doubled) before
        # being spliced into SQL text; identifiers cannot be bound with "?".
        return '"' + name.replace('"', '""') + '"'

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]

        formatted_tables = []

        for table in tables:
            quoted_table = quote_identifier(table)

            cursor.execute(f"PRAGMA table_info({quoted_table});")
            columns = [row[1] for row in cursor.fetchall()]  # row[1] is column name

            cursor.execute(f"SELECT * FROM {quoted_table} LIMIT ?;", (max_samples,))
            rows = cursor.fetchall()

            # Transpose row tuples into per-column tuples; with no rows, fall
            # back to one empty sample list per column so every column is
            # still listed.
            col_samples = list(zip(*rows)) if rows else [[] for _ in columns]

            col_strs = []
            for col, vals in zip(columns, col_samples):
                val_list = ", ".join(str(v) for v in vals[:max_samples])
                col_strs.append(f"{col}[{val_list}]")
            formatted_tables.append(f"# {table}(" + ", ".join(col_strs) + ")")

        return formatted_tables
    finally:
        # Always release the file handle, even when a query fails midway.
        conn.close()


def build_llm_prompt_with_data(entry, formatted_data_lines, example_qas=None):
    """Assemble the text-to-SQL prompt for a single benchmark entry.

    The prompt is a concatenation of, in order: the instruction header, the
    schema taken from ``entry['simplified_ddl']``, the sample-data lines, the
    foreign keys from ``entry['foreign_key']``, optional few-shot examples,
    and finally the question from ``entry['question']``.

    Args:
        entry: Mapping providing ``simplified_ddl``, ``foreign_key`` (an
            iterable of strings) and ``question``.
        formatted_data_lines: Iterable of strings, one per line of the
            sample-data section.
        example_qas: Optional iterable of ``(question, sql)`` pairs; when
            empty or ``None`` the few-shot section is omitted entirely.

    Returns:
        The complete prompt as a single string.
    """
    pieces = [
        # Instruction header.
        "### Answer the question by SQLite SQL query only and with no explanation. "
        "You must minimize SQL execution time while ensuring correctness.\n",
        # Schema section.
        f"### Sqlite SQL tables, with their properties:\n#\n{entry['simplified_ddl']}\n#\n",
        # Sample-data section.
        "### Here is some data information about database references.\n#\n",
        "\n".join(formatted_data_lines) + "\n#\n",
        # Foreign-key section, one "# "-prefixed line per key.
        "### Foreign key information of SQLite tables, used for table joins:\n#\n",
        "# " + "\n# ".join(entry["foreign_key"]) + "\n#\n",
    ]

    if example_qas:
        pieces.append(
            "### Some example pairs of questions and corresponding SQL queries are provided based on similar questions:\n"
        )
        pieces.extend(f"### {q}\n{sql}\n" for q, sql in example_qas)

    # The question always comes last so the model's answer follows it.
    pieces.append(f"### {entry['question']}\n")

    return "".join(pieces)

In [4]:
# Load the entries from the JSON file named by config.PREPROCESSED_JSON.
# (A single-argument os.path.join() returns its argument unchanged, so the
# path is passed to open() directly.)
with open(config.PREPROCESSED_JSON, "r", encoding="utf-8") as f:
    entries = json.load(f)

print(len(entries))

2147


In [None]:
# Few-shot (question, SQL) pairs included in every prompt.
example_qas = [
    ("How many farms are there?", "SELECT count(*) FROM farm"),
    ("What is the average, minimum, and maximum age for all French singers?", "SELECT avg(age) ,  min(age) ,  max(age) FROM singer WHERE country  =  'France'"),
    ("Show the ID of the high schooler named Kyle.", "SELECT ID FROM Highschooler WHERE name  =  \"Kyle\""),
]

# Database files are read from <SPIDER_DATA_DIR>/test_database/<db>/<db>.sqlite.
database_root = os.path.join(config.SPIDER_DATA_DIR, "test_database")

# Many entries share a database, so parse each database file only once.
schema_lines_by_db = {}

prompts = []
for entry in entries:
    db_name = entry["db"]
    if db_name not in schema_lines_by_db:
        schema_path = os.path.join(database_root, db_name, f"{db_name}.sqlite")
        schema_lines_by_db[db_name] = parse_schema_sqlite(schema_path)
    formatted_data_lines = schema_lines_by_db[db_name]

    prompt = build_llm_prompt_with_data(entry, formatted_data_lines, example_qas)
    prompts.append(
        {
            "id": entry["id"],
            "db": db_name,
            "gold_sql": entry["gold_sql"],
            "prompt": prompt,
        }
    )

# Save output
with open(config.PROMPTS_JSON, "w", encoding="utf-8") as f:
    json.dump(prompts, f, indent=4)

IndentationError: unexpected indent (1371661793.py, line 12)