# inference example

In [31]:
import json
import os
import re
import sqlite3

from typing import Any, Literal

import openai

from dotenv import load_dotenv
from openai import AzureOpenAI
from tenacity import retry, wait_random_exponential, stop_after_attempt


In [7]:
load_dotenv()

True

In [9]:
# for o3, we need version 1.61.1 so upgrade if needed
print(openai.__version__)

1.61.1


## basic functions

In [15]:
# create the client (CAPS because GLOBAL)
# the API key and endpoint are read from the AZURE_OPENAI_API_KEY and
# AZURE_OPENAI_ENDPOINT environment variables; os.getenv returns None for an
# unset variable, so make sure load_dotenv() (or the shell) has provided them
AZURE_CLIENT = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version="2024-12-01-preview",
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    timeout=60,  # request timeout (presumably in seconds -- confirm against the openai client docs)
    max_retries=5,  # retry budget handed to the client
)

### database functions

Taken from `src/text2sql/data/sqlite_functions.py` and `src/text2sql/data/schema_to_text.py`.

For schema formats other than the DAIL-SQL-like one used here, see `schema_to_text.py`.

In [27]:
def get_sqlite_database_file(base_dir: str, database: str) -> str:
    """locate the ``.sqlite`` file that belongs to ``database`` under ``base_dir``

    two directory layouts are probed, in this order:
      1. flat:   ``<base_dir>/<database>.sqlite``
      2. nested: ``<base_dir>/<database>/<database>.sqlite``

    returns the first path that exists; raises FileNotFoundError if neither does
    """
    file_name = database + ".sqlite"
    candidates = (
        os.path.join(base_dir, file_name),
        os.path.join(base_dir, database, file_name),
    )
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is None:
        raise FileNotFoundError(f"Database file for {database=} not found in {base_dir=}")
    return found

In [28]:
def query_sqlite_database(base_dir: str, database: str, sql_query: str) -> list[dict]:
    """run ``sql_query`` against a sqlite database and return the rows as dictionaries

    Args:
        base_dir: directory that holds the sqlite database files
        database: database name, resolved to a file by ``get_sqlite_database_file``
        sql_query: SQL statement to execute

    Returns:
        one ``{column_name: value}`` dict per result row

    Raises:
        FileNotFoundError: if the database file cannot be located
        sqlite3.Error: if the query fails (e.g. invalid model-generated SQL)
    """
    db_path = get_sqlite_database_file(base_dir=base_dir, database=database)
    connection = sqlite3.connect(db_path)
    try:
        # sqlite3.Row lets each row be converted to a dict keyed by column name
        connection.row_factory = sqlite3.Row
        cursor = connection.cursor()
        rows = cursor.execute(sql_query).fetchall()
        return [dict(row) for row in rows]
    finally:
        # always release the connection, even when the query raises
        connection.close()

In [29]:
def get_sqlite_schema(base_dir: str, database: str) -> dict[str, Any]:
    """get sqlite schema, columns, relations as a dictionary

    Args:
        base_dir: directory that holds the sqlite database files
        database: database name, resolved to a file by ``get_sqlite_database_file``

    Returns:
        a dictionary shaped like::

            {"tables": {
                table_name: {
                    "columns": {column_name: declared_type, ...},
                    "keys": {"primary_key": [column_name, ...]},   # only if the table has one
                    "foreign_keys": {
                        column_name: {"referenced_table": ..., "referenced_column": ...},
                    },
                },
            }}

    Raises:
        FileNotFoundError: if the database file cannot be located
    """
    database_path = get_sqlite_database_file(base_dir=base_dir, database=database)
    connection = sqlite3.connect(database_path)
    try:
        cursor = connection.cursor()
        schema: dict[str, Any] = {"tables": {}}

        # Get table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        table_names = [row[0] for row in cursor.fetchall()]

        for table_name in table_names:
            table_schema: dict[str, Any] = {"columns": {}, "keys": {}, "foreign_keys": {}}
            schema["tables"][table_name] = table_schema

            # PRAGMA arguments cannot be bound as parameters, so escape embedded
            # single quotes to keep the string literal well-formed
            quoted_name = table_name.replace("'", "''")

            # Get column information
            cursor.execute(f"PRAGMA table_info('{quoted_name}')")
            pk_columns = []
            for _cid, col_name, col_type, _is_notnull, _default_value, pk_position in cursor.fetchall():
                table_schema["columns"][col_name] = col_type
                # the pk field is 0 for non-key columns, otherwise the 1-based
                # position of the column inside the (possibly composite) primary key
                if pk_position:
                    pk_columns.append((pk_position, col_name))
            if pk_columns:
                # keep every primary-key column, ordered by its position in the key
                table_schema["keys"]["primary_key"] = [name for _, name in sorted(pk_columns)]

            # Get foreign key information
            cursor.execute(f"PRAGMA foreign_key_list('{quoted_name}')")
            for _, _, ref_table, col_name, ref_col, *_ in cursor.fetchall():
                table_schema["foreign_keys"][col_name] = {
                    "referenced_table": ref_table,
                    "referenced_column": ref_col,
                }

        cursor.close()
        return schema
    finally:
        # always release the connection, even if a PRAGMA call raises
        connection.close()

In [30]:
def schema_to_basic_format(
    database_name: str, schema: dict[str, Any], include_types: bool = False, include_relations: bool = False
) -> str:
    """render a schema dictionary as one ``table '<name>' with columns: ...`` line per table (DAIL-SQL style)

    Args:
        database_name: not used by this formatter
        schema: dictionary with a "tables" mapping; each table provides "columns"
            (name -> type) and optionally "foreign_keys"
        include_types: when True, each column is written as ``name (type)``
        include_relations: when True, a "Relations:" section listing every
            foreign key as ``table.column -> ref_table.ref_column`` is appended

    Returns:
        the newline-joined text description
    """
    tables = schema["tables"]
    lines = []

    for name, info in tables.items():
        # str() guards against non-string column names (e.g. integers)
        rendered = [
            f"{str(column)} ({dtype})" if include_types else str(column)
            for column, dtype in info["columns"].items()
        ]
        lines.append(f"table '{name}' with columns: {' , '.join(rendered)}")

    if not include_relations:
        return "\n".join(lines)

    lines.append("\nRelations:")
    for name, info in tables.items():
        # a missing or empty "foreign_keys" entry contributes no relation lines
        relations = info.get("foreign_keys") or {}
        for source_column, target in relations.items():
            source_column = str(source_column)
            target_table = target["referenced_table"]
            target_column = target["referenced_column"]
            lines.append(f"{name}.{source_column} -> {target_table}.{target_column}")

    return "\n".join(lines)

### inference functions

In [55]:
def inference_gpt4o(messages: list[dict], deployment_name: str, temperature: float) -> str:
    """inference using gpt4o family model

    Args:
        messages: chat messages (dicts with "role" and "content") to send
        deployment_name: Azure deployment name, passed as ``model``
        temperature: sampling temperature forwarded to the API

    Returns:
        the message content of the first returned choice
    """
    result = AZURE_CLIENT.chat.completions.create(
        model=deployment_name,
        messages=messages,
        # forward the caller's value; this was previously hard-coded to the
        # string "medium", which is not a valid (numeric) temperature
        temperature=temperature,
    )
    result_text = result.choices[0].message.content
    return result_text

In [56]:
def inference_o3(messages: list[dict], deployment_name: str, reasoning_effort: Literal["low", "medium", "high"]) -> str:
    """send ``messages`` to an o3 family deployment and return the reply text

    ``deployment_name`` is passed as ``model`` and ``reasoning_effort`` is
    forwarded unchanged; the content of the first choice is returned
    """
    completion = AZURE_CLIENT.chat.completions.create(
        messages=messages,
        model=deployment_name,
        reasoning_effort=reasoning_effort,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

In [48]:
def extract_first_code_block(text: str) -> "str | None":
    """extract the contents of the first markdown code block in llm output

    the opening fence may carry any language tag made of word characters
    (``sql``, ``sqlite``, ``json``, ...) or no tag at all; the tag is not part
    of the returned content

    Args:
        text: raw llm output

    Returns:
        the stripped contents of the first fenced block, or None when ``text``
        contains no fenced block
    """
    # `\w*` consumes the whole language tag. The earlier `(?:sql|json|\w*)`
    # alternation stopped after the `sql` prefix of tags such as `sqlite`,
    # leaving the remainder ("ite") at the start of the extracted code.
    pattern = r'```\w*\n?(.*?)\n?```'
    match = re.search(pattern, text, re.DOTALL)
    if match is None:
        return None
    return match.group(1).strip()

## data processing

In [35]:
# point to BIRD dev.json file
dev_json_file = "/data/sql_datasets/bird/dev_20240627/dev.json"

# point to BIRD dev databases base directory
dev_database_base_dir = "/data/sql_datasets/bird/dev_20240627/dev_databases"


In [36]:
# load dev data
# JSON is UTF-8 by specification, so read it explicitly as UTF-8 rather than
# relying on the platform's default locale encoding
with open(dev_json_file, "r", encoding="utf-8") as f:
    dev_data = json.load(f)
print(f"Loaded {len(dev_data)} examples from {dev_json_file}")

Loaded 1534 examples from /data/sql_datasets/bird/dev_20240627/dev.json


In [37]:
print(json.dumps(dev_data[:2], indent=2))

[
  {
    "question_id": 0,
    "db_id": "california_schools",
    "question": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?",
    "evidence": "Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`",
    "SQL": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
    "difficulty": "simple"
  },
  {
    "question_id": 1,
    "db_id": "california_schools",
    "question": "Please list the lowest three eligible free rates for students aged 5-17 in continuation schools.",
    "evidence": "Eligible free rates for students aged 5-17 = `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)`",
    "SQL": "SELECT `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)` FROM frpm WHERE `Educational Option Type` = 'Continuation School' AND `Free Meal Count (Ages 5-17)` / `Enrollment (Ages 5-17)` IS 

In [41]:
# get the schema for the first sample ("california_schools")
# first, get the tables and columns info
schema_info: dict = get_sqlite_schema(base_dir=dev_database_base_dir, database=dev_data[0]["db_id"])
# then, format it to DAIL-SQL format, here we add types and relations (db name not used here, but used in related functions)
schema_description: str = schema_to_basic_format(dev_data[0]["db_id"], schema_info, include_types=True, include_relations=True)
print(schema_description)

table 'frpm' with columns: CDSCode (TEXT) , Academic Year (TEXT) , County Code (TEXT) , District Code (INTEGER) , School Code (TEXT) , County Name (TEXT) , District Name (TEXT) , School Name (TEXT) , District Type (TEXT) , School Type (TEXT) , Educational Option Type (TEXT) , NSLP Provision Status (TEXT) , Charter School (Y/N) (INTEGER) , Charter School Number (TEXT) , Charter Funding Type (TEXT) , IRC (INTEGER) , Low Grade (TEXT) , High Grade (TEXT) , Enrollment (K-12) (REAL) , Free Meal Count (K-12) (REAL) , Percent (%) Eligible Free (K-12) (REAL) , FRPM Count (K-12) (REAL) , Percent (%) Eligible FRPM (K-12) (REAL) , Enrollment (Ages 5-17) (REAL) , Free Meal Count (Ages 5-17) (REAL) , Percent (%) Eligible Free (Ages 5-17) (REAL) , FRPM Count (Ages 5-17) (REAL) , Percent (%) Eligible FRPM (Ages 5-17) (REAL) , 2013-14 CALPADS Fall 1 Certification Status (INTEGER)
table 'satscores' with columns: cds (TEXT) , rtype (TEXT) , sname (TEXT) , dname (TEXT) , cname (TEXT) , enroll12 (INTEGER) 

## inference

This is an example with the `o3` model.

In [43]:
# create messages for inference
# gpt-4o uses "system prompts" and o3 uses "developer messages"
# but the api is forwards and backwards-compatible so we use "system prompt" for both
system_prompt = "You are a data scientist who writes SQL. Based on the information provided by the user, write an SQL query to answer their question. Output the result as SQL inside a markdown code block. do not output anything else."

# load the data for the selected question (change idx to try another sample)
idx = 0
question = dev_data[idx]["question"]
evidence = dev_data[idx]["evidence"]
# look the schema up through idx as well, so the schema always belongs to the
# same sample as the question and evidence
schema_info: dict = get_sqlite_schema(base_dir=dev_database_base_dir, database=dev_data[idx]["db_id"])
schema_description: str = schema_to_basic_format(dev_data[idx]["db_id"], schema_info, include_types=True, include_relations=True)

# user input
user_message = f"""database schema:
```
{schema_description}
```

question: {question}
evidence: {evidence}

give me the appropriate SQLite SQL query in a markdown code block."""

# messages
messages: list[dict] = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_message}
]

In [57]:
# run inference
raw_output: str = inference_o3(
    messages=messages, 
    deployment_name="gena-o3-mini", 
    reasoning_effort="medium"
)
print(raw_output)

```sql
SELECT MAX("Free Meal Count (K-12)" / "Enrollment (K-12)") AS highest_free_rate
FROM frpm
WHERE "County Name" = 'Alameda County'
  AND "Enrollment (K-12)" > 0;
```


In [58]:
# process output
sql_prediction: str = extract_first_code_block(raw_output)
print(sql_prediction)

SELECT MAX("Free Meal Count (K-12)" / "Enrollment (K-12)") AS highest_free_rate
FROM frpm
WHERE "County Name" = 'Alameda County'
  AND "Enrollment (K-12)" > 0;


## get database result

In [59]:
# execute the SQL query
result: list[dict] = query_sqlite_database(
    base_dir=dev_database_base_dir, 
    database=dev_data[idx]["db_id"],  # "california_schools"
    sql_query=sql_prediction
)

In [61]:
print(result)

[{'highest_free_rate': None}]
