In [1]:
#!/usr/bin/env python3
"""
Ask natural-language questions of a SQLite database using GPT.
"""

import os
import sqlite3
from pathlib import Path

from dotenv import load_dotenv
from openai import OpenAI

# ---------- Setup ----------
# Load environment variables from the local ".dbenv" file.
# This dbenv only needs an OPENAI_API_KEY.
# No LANGCHAIN_API_KEY is needed as langchain has been removed.
_ = load_dotenv(".dbenv")

# Created after load_dotenv so the key loaded above is already in the
# environment when the client is constructed.
client = OpenAI()

DB_PATH = "eunomia.sqlite"

# Open the database read-only through a SQLite URI:
#   * every statement executed on this connection is model-generated, so the
#     connection itself must not be able to modify or drop data;
#   * a wrong DB_PATH now fails at connect time instead of silently creating
#     a new, empty database file.
# Path.as_uri() percent-encodes characters that are special inside a URI.
conn = sqlite3.connect(f"{Path(DB_PATH).resolve().as_uri()}?mode=ro", uri=True)
cursor = conn.cursor()


# ---------- Utilities ----------
def get_schema() -> str:
    """Return database schema text for GPT context.

    Reads every table name recorded in ``sqlite_master`` and, for each table,
    the column names reported by ``PRAGMA table_info``. Uses the module-level
    ``cursor``.

    Returns:
        One line per table in the form ``Table <name>: [<column names>]``,
        joined by newlines. An empty string when the database has no tables.
    """
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    # Materialise the names first: the same cursor is reused for the PRAGMAs.
    table_names = [name for (name,) in cursor.fetchall()]

    lines = []
    for table_name in table_names:
        # Quote the identifier (doubling embedded quotes) so names containing
        # spaces, punctuation or SQL keywords do not break the PRAGMA.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_name});")
        # Index 1 of each table_info row is the column name.
        cols = [column[1] for column in cursor.fetchall()]
        lines.append(f"Table {table_name}: {cols}")
    return "\n".join(lines)


def clean_sql(sql: str) -> str:
    """Remove Markdown fences from a model reply and return the bare SQL.

    Handles replies where the fenced block is preceded or followed by prose,
    where the closing fence is missing, and where the whole fenced block sits
    on a single line. Text without any fence is returned stripped.

    Args:
        sql: Raw text returned by the model.

    Returns:
        The contents of the first fenced block (or the whole text when there
        is no fence), with surrounding whitespace removed.
    """
    fence = "```"
    text = sql.strip()

    start = text.find(fence)
    if start == -1:
        return text

    body = text[start + len(fence):]
    end = body.find(fence)
    if end != -1:
        # Discard the closing fence and anything the model wrote after it.
        body = body[:end]

    # Whatever shares a line with the opening fence is the language tag
    # (e.g. "sql"), not part of the query. With no newline the block is a
    # single line and everything between the fences is the query.
    _info, newline, code = body.partition("\n")
    if newline:
        body = code
    return body.strip()


def chat(prompt: str, role: str = "SQL assistant", model: str = "gpt-4o-mini") -> str:
    """Convenience wrapper for a single chat completion.

    Sends one system message ("You are a helpful <role>.") and one user
    message through the module-level ``client`` with temperature 0.

    Args:
        prompt: Text sent as the user message.
        role: Persona inserted into the system message.
        model: Chat model to use.

    Returns:
        The first choice's message text with surrounding whitespace removed,
        or an empty string when the response carries no text content.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": f"You are a helpful {role}."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
    )
    # content may be None, which would make a bare .strip() raise
    # AttributeError; treat that case as an empty reply.
    content = response.choices[0].message.content
    return (content or "").strip()


# ---------- GPT helpers ----------
def generate_sql(question: str, schema: str) -> str:
    """Ask the model for a SQLite query that answers ``question``.

    The schema text and the question are embedded in a prompt sent through
    ``chat`` with the "SQL expert" role; the reply is passed through
    ``clean_sql`` before being returned.
    """
    request = (
        "\n"
        "Database schema:\n"
        "\n"
        f"{schema}\n"
        "\n"
        "Write a valid SQLite SQL query that answers:\n"
        f"{question}\n"
        "\n"
        "Return only SQL, no explanations or code fences.\n"
    )
    return clean_sql(chat(request, role="SQL expert"))


def explain_results(question: str, sql: str, rows: list, max_rows: int = 20) -> str:
    """Ask the model to explain query results in everyday language.

    Args:
        question: The user's original question.
        sql: The SQL that produced ``rows``.
        rows: Result rows; each is rendered with ``str``.
        max_rows: Maximum number of rows included in the prompt. Negative
            values are treated as 0.

    Returns:
        The explanation text returned by ``chat`` (role "data analyst").
    """
    shown = rows[: max(max_rows, 0)]
    result_text = "\n".join(str(row) for row in shown)

    # Tell the model when the sample is truncated so it does not describe the
    # sample as the complete result. Nothing is added when all rows fit, so
    # the prompt is unchanged in that case.
    omitted = len(rows) - len(shown)
    if omitted > 0:
        result_text += f"\n... ({omitted} more rows not shown)"

    prompt = f"""
The user asked: {question}

You wrote this SQL:
{sql}

Results sample:
{result_text}

Explain the results clearly in everyday language.
"""
    return chat(prompt, role="data analyst")


def summarize_query(question: str, sql: str) -> str:
    """Ask the model for a short, plain-language description of ``sql``.

    The question and the SQL are embedded in a prompt sent through ``chat``
    with the "SQL explainer" role; the reply is returned unchanged.
    """
    request_lines = [
        "",
        f"The user asked: {question}",
        "",
        "You wrote this SQL:",
        f"{sql}",
        "",
        "Explain briefly, in human-readable terms, what this SQL query is doing",
        "and how it answers the question.",
        "",
    ]
    return chat("\n".join(request_lines), role="SQL explainer")


# ---------- Main entry point ----------
def ask_question(
    question: str,
    include_sql: bool = True,
    explain_results_flag: bool = True,
    human_summary: bool = True,
    max_retries: int = 3,
):
    """
    Generate SQL for a question, run it, and optionally explain the outcome.

    The SQL comes from ``generate_sql`` and is executed on the module-level
    ``cursor``. When execution fails, the model is asked to correct the SQL
    and the corrected statement is tried again. Progress and results are
    printed as the function runs.

    Args:
        question: Natural-language question to answer.
        include_sql: Print the final SQL and include it in the result.
        explain_results_flag: Add a plain-language explanation of the rows
            (skipped when no rows came back).
        human_summary: Add a plain-language summary of what the SQL does.
        max_retries: Total number of execution attempts. Values below 1 are
            treated as 1 so the SQL is always executed at least once.

    Returns:
        On success, a dict with "rows" plus, depending on the flags, "sql",
        "summary" and "explanation". When every attempt fails, a dict with
        only "error" (the last error message) and "sql" (the last statement
        tried).
    """
    schema = get_schema()
    sql = generate_sql(question, schema)

    # Without this floor, max_retries <= 0 would skip execution entirely and
    # report an empty result as if the query had succeeded.
    attempts = max(1, max_retries)
    rows = []

    for attempt in range(1, attempts + 1):
        try:
            # An empty statement executes without error and yields no rows,
            # so treat it as a failure and let the retry logic handle it.
            if not sql.strip():
                raise ValueError("The model returned an empty SQL statement.")
            # NOTE(review): the SQL is model-generated and executed verbatim;
            # nothing in this function restricts it to read-only statements.
            cursor.execute(sql)
            rows = cursor.fetchall()
            break  # success
        except Exception as e:
            error = str(e)
            if attempt == attempts:
                # Out of attempts: report the failure instead of returning
                # silently, since callers usually only store the result.
                print("\n--- Error ---")
                print(error)
                return {"error": error, "sql": sql}

            # Ask GPT to fix its SQL
            fix_prompt = f"""
The following SQL caused an error when executed on SQLite:

SQL:
{sql}

Error:
{error}

Schema:
{schema}

Please correct the SQL so that it will run successfully on this schema.
Return only corrected SQL.
"""
            sql = clean_sql(chat(fix_prompt, role="SQL debugger"))

    result = {"rows": rows}
    if include_sql:
        print("\n--- SQL ---")
        print(sql)
        result["sql"] = sql
    print("\n--- Rows ---")
    print(rows)
    if human_summary:
        print("\n--- Generating summary ---")
        result["summary"] = summarize_query(question, sql)
        print("\n--- Summary ---\n", result["summary"])
    if explain_results_flag and rows:
        print("\n--- Explaining results ---")
        result["explanation"] = explain_results(question, sql, rows)
        print("\n--- Explanation ---\n", result["explanation"])

    print("\n=== Done ===")

    return result

#
# ---------- Example ----------
#if __name__ == "__main__":
#    q = "What are the top 5 patients with the highest blood pressure?"
#    answer = ask_question(q, include_sql=True, explain_results_flag=True, human_summary=True)
#
#    print("\n--- SQL ---\n", answer.get("sql", ""))
#    print("\n--- Rows ---\n", answer.get("rows", []))
#    print("\n--- Summary ---\n", answer.get("summary", ""))
#    print("\n--- Explanation ---\n", answer.get("explanation", ""))
#    

In [2]:

# Rank patients by blood pressure.
# NOTE(review): the recorded run below returned an empty row list for this question.
question = "What are the top 5 patients with the highest blood pressure?"

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT PERSON_ID, MAX(VALUE_AS_NUMBER) AS MAX_BLOOD_PRESSURE
FROM MEASUREMENT
WHERE MEASUREMENT_CONCEPT_ID IN (
    SELECT CONCEPT_ID FROM CONCEPT WHERE CONCEPT_NAME LIKE '%blood pressure%'
)
GROUP BY PERSON_ID
ORDER BY MAX_BLOOD_PRESSURE DESC
LIMIT 5;

--- Rows ---
[]

--- Generating summary ---

--- Summary ---
 This SQL query is designed to find the top 5 patients with the highest recorded blood pressure measurements. Here's a breakdown of how it works:

1. **Selecting Data**: The query starts by selecting the `PERSON_ID` and the maximum blood pressure value (`MAX(VALUE_AS_NUMBER)`) for each patient from the `MEASUREMENT` table.

2. **Filtering for Blood Pressure**: It filters the measurements to include only those that are related to blood pressure. This is done by checking if the `MEASUREMENT_CONCEPT_ID` is in a subquery that retrieves all concept IDs from the `CONCEPT` table where the concept name includes the phrase "blood pressure".

3. **Grouping Results**: The re

In [3]:
    
# Inspect the column names and declared types of one table.
question = "Display the columns and field types for the condition_occurrence table."

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
PRAGMA table_info(CONDITION_OCCURRENCE);

--- Rows ---
[(0, 'CONDITION_OCCURRENCE_ID', 'REAL', 0, None, 0), (1, 'PERSON_ID', 'REAL', 0, None, 0), (2, 'CONDITION_CONCEPT_ID', 'REAL', 0, None, 0), (3, 'CONDITION_START_DATE', 'REAL', 0, None, 0), (4, 'CONDITION_START_DATETIME', 'REAL', 0, None, 0), (5, 'CONDITION_END_DATE', 'REAL', 0, None, 0), (6, 'CONDITION_END_DATETIME', 'REAL', 0, None, 0), (7, 'CONDITION_TYPE_CONCEPT_ID', 'REAL', 0, None, 0), (8, 'STOP_REASON', 'TEXT', 0, None, 0), (9, 'PROVIDER_ID', 'REAL', 0, None, 0), (10, 'VISIT_OCCURRENCE_ID', 'REAL', 0, None, 0), (11, 'VISIT_DETAIL_ID', 'REAL', 0, None, 0), (12, 'CONDITION_SOURCE_VALUE', 'TEXT', 0, None, 0), (13, 'CONDITION_SOURCE_CONCEPT_ID', 'REAL', 0, None, 0), (14, 'CONDITION_STATUS_SOURCE_VALUE', 'TEXT', 0, None, 0), (15, 'CONDITION_STATUS_CONCEPT_ID', 'REAL', 0, None, 0)]

--- Generating summary ---

--- Summary ---
 The SQL query `PRAGMA table_info(CONDITION_OCCURRENCE);` is used to retrieve information abou

In [4]:

# Enumerate the tables present in the database.
question = "List all table names in the CDM schema"

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT name FROM sqlite_master WHERE type='table';

--- Rows ---
[('CARE_SITE',), ('CDM_SOURCE',), ('COHORT',), ('COHORT_ATTRIBUTE',), ('CONDITION_ERA',), ('CONDITION_OCCURRENCE',), ('COST',), ('DEATH',), ('DEVICE_EXPOSURE',), ('DOSE_ERA',), ('DRUG_ERA',), ('DRUG_EXPOSURE',), ('FACT_RELATIONSHIP',), ('LOCATION',), ('MEASUREMENT',), ('METADATA',), ('NOTE',), ('NOTE_NLP',), ('OBSERVATION',), ('OBSERVATION_PERIOD',), ('PAYER_PLAN_PERIOD',), ('PERSON',), ('PROCEDURE_OCCURRENCE',), ('PROVIDER',), ('SPECIMEN',), ('VISIT_DETAIL',), ('VISIT_OCCURRENCE',), ('CONCEPT',), ('CONCEPT_ANCESTOR',), ('CONCEPT_CLASS',), ('CONCEPT_RELATIONSHIP',), ('CONCEPT_SYNONYM',), ('DOMAIN',), ('DRUG_STRENGTH',), ('RELATIONSHIP',), ('SOURCE_TO_CONCEPT_MAP',), ('VOCABULARY',)]

--- Generating summary ---

--- Summary ---
 The SQL query you provided is designed to retrieve the names of all tables in a SQLite database. Here's a breakdown of what it does:

1. **SELECT name**: This part of the query specifi

In [5]:

# Most frequent conditions, reported by concept name rather than concept id.
question = "list the top 10 distinct conditions and their count in descending count order. Do not list concept_ids but rather list by concept name"

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT C.CONCEPT_NAME, COUNT(*) AS condition_count
FROM CONDITION_OCCURRENCE CO
JOIN CONCEPT C ON CO.CONDITION_CONCEPT_ID = C.CONCEPT_ID
GROUP BY C.CONCEPT_NAME
ORDER BY condition_count DESC
LIMIT 10;

--- Rows ---
[('Viral sinusitis', 17268), ('Acute viral pharyngitis', 10217), ('Acute bronchitis', 8184), ('Otitis media', 3605), ('Osteoarthritis', 2694), ('Streptococcal sore throat', 2656), ('Sprain of ankle', 1915), ('Concussion with no loss of consciousness', 1013), ('Sinusitis', 1001), ('Acute bacterial sinusitis', 939)]

--- Generating summary ---

--- Summary ---
 This SQL query is designed to retrieve the top 10 distinct medical conditions from a database, along with the number of times each condition has been recorded. Here's a breakdown of what each part of the query does:

1. **SELECT C.CONCEPT_NAME, COUNT(*) AS condition_count**: This part selects the name of the condition (from the `CONCEPT` table) and counts how many times each condition appears in the `CONDIT

In [6]:

# People with a condition whose name mentions sinusitis.
question = "list 10 people who have any variation of sinusitis"

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT DISTINCT p.PERSON_ID
FROM PERSON p
JOIN CONDITION_OCCURRENCE co ON p.PERSON_ID = co.PERSON_ID
JOIN CONCEPT c ON co.CONDITION_CONCEPT_ID = c.CONCEPT_ID
WHERE c.CONCEPT_NAME LIKE '%sinusitis%'
LIMIT 10;

--- Rows ---
[(3.0,), (5.0,), (12.0,), (17.0,), (36.0,), (40.0,), (49.0,), (63.0,), (64.0,), (69.0,)]

--- Generating summary ---

--- Summary ---
 This SQL query is designed to find and list up to 10 unique individuals who have been diagnosed with any form of sinusitis. Here's a breakdown of how it works:

1. **Selecting Unique People**: The query starts by selecting distinct `PERSON_ID` values from the `PERSON` table, which represents individuals in the database.

2. **Joining Tables**: It joins the `PERSON` table with the `CONDITION_OCCURRENCE` table to link each person to their medical conditions. It also joins the `CONCEPT` table to access the names of those conditions.

3. **Filtering Conditions**: The `WHERE` clause filters the results to include only those con

In [7]:

# Same idea as the previous question, but asking for descendant concepts
# and for the concept names and ids in the output.
question = "List 10 people with sinusitis or its descendant concepts. Results should also include the concept names and ids."

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT DISTINCT p.PERSON_ID, p.GENDER_CONCEPT_ID, p.YEAR_OF_BIRTH, c.CONCEPT_ID, c.CONCEPT_NAME
FROM PERSON p
JOIN CONDITION_OCCURRENCE co ON p.PERSON_ID = co.PERSON_ID
JOIN CONCEPT c ON co.CONDITION_CONCEPT_ID = c.CONCEPT_ID
JOIN CONCEPT_ANCESTOR ca ON c.CONCEPT_ID = ca.DESCENDANT_CONCEPT_ID
WHERE ca.ANCESTOR_CONCEPT_ID = (SELECT CONCEPT_ID FROM CONCEPT WHERE CONCEPT_NAME = 'Sinusitis')
LIMIT 10;

--- Rows ---
[(3.0, 8507.0, 1916.0, 4294548.0, 'Acute bacterial sinusitis'), (5.0, 8507.0, 1968.0, 4294548.0, 'Acute bacterial sinusitis'), (12.0, 8532.0, 1963.0, 4294548.0, 'Acute bacterial sinusitis'), (17.0, 8532.0, 1950.0, 4294548.0, 'Acute bacterial sinusitis'), (36.0, 8532.0, 1958.0, 4294548.0, 'Acute bacterial sinusitis'), (40.0, 8507.0, 1951.0, 4294548.0, 'Acute bacterial sinusitis'), (49.0, 8507.0, 1972.0, 4294548.0, 'Acute bacterial sinusitis'), (63.0, 8532.0, 1955.0, 4294548.0, 'Acute bacterial sinusitis'), (64.0, 8532.0, 1974.0, 4294548.0, 'Acute bacterial sinusitis'

In [8]:

# Describe the first few measurement rows.
question = "For the first 10 measurements in the measurement table, list the descriptions of each measurement."

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT M.MEASUREMENT_ID, C.CONCEPT_NAME AS MEASUREMENT_DESCRIPTION
FROM MEASUREMENT M
JOIN CONCEPT C ON M.MEASUREMENT_CONCEPT_ID = C.CONCEPT_ID
LIMIT 10;

--- Rows ---
[(10204.0, 'Oral temperature'), (4182.0, 'Throat culture'), (90388.0, 'Throat culture'), (20211.0, 'Oral temperature'), (15652.0, 'Throat culture'), (15744.0, 'Throat culture'), (27099.0, 'Measurement of respiratory function'), (27166.0, 'Throat culture'), (82693.0, 'Throat culture'), (91609.0, 'Measurement of respiratory function')]

--- Generating summary ---

--- Summary ---
 This SQL query retrieves information about the first 10 measurements from a database. Here's a breakdown of what it does:

1. **SELECT Statement**: It specifies that we want to retrieve two pieces of information: the `MEASUREMENT_ID` from the `MEASUREMENT` table and the `CONCEPT_NAME` from the `CONCEPT` table, which serves as the description of each measurement.

2. **FROM Clause**: It indicates that the data is coming from the `MEAS

In [9]:

# Most recent visit, with a readable timestamp.
# NOTE(review): in the recorded run below the formatted datetime came back as
# None; presumably the stored value is not in a form strftime accepts - TODO confirm.
question = "what is the person_id of the person and the visit date and time for the most recent recorded visit. Format the date and time in an understandable format"

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT PERSON_ID, 
       strftime('%Y-%m-%d %H:%M:%S', VISIT_START_DATETIME) AS VISIT_DATETIME
FROM VISIT_OCCURRENCE
ORDER BY VISIT_START_DATETIME DESC
LIMIT 1;

--- Rows ---
[(3026.0, None)]

--- Generating summary ---

--- Summary ---
 This SQL query is designed to retrieve the most recent visit information from a table called `VISIT_OCCURRENCE`. Here's a breakdown of what it does:

1. **SELECT Statement**: It selects two columns: `PERSON_ID`, which identifies the person who made the visit, and `VISIT_START_DATETIME`, which records the date and time of the visit.

2. **Date Formatting**: The `strftime` function is used to format the `VISIT_START_DATETIME` into a more understandable format (YYYY-MM-DD HH:MM:SS). This makes it easier to read and interpret the date and time of the visit.

3. **Ordering**: The results are ordered by `VISIT_START_DATETIME` in descending order, meaning the most recent visit will appear first.

4. **Limiting Results**: The `LIMIT 1` clause ens

In [10]:

# Most prescribed drugs together with the condition treated.
# NOTE(review): the SQL recorded below joins DRUG_EXPOSURE to CONDITION_OCCURRENCE
# on PERSON_ID only, so its counts pair every drug exposure with every condition
# of the same person.
question = "Of the conditions in the database, list the 10 most prescribed drugs and their count and condition treated"

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT d.DRUG_CONCEPT_ID, COUNT(d.DRUG_EXPOSURE_ID) AS drug_count, co.CONDITION_CONCEPT_ID
FROM DRUG_EXPOSURE d
JOIN CONDITION_OCCURRENCE co ON d.PERSON_ID = co.PERSON_ID
GROUP BY d.DRUG_CONCEPT_ID, co.CONDITION_CONCEPT_ID
ORDER BY drug_count DESC
LIMIT 10;

--- Rows ---
[(1127433.0, 64246, 40481087.0), (40213227.0, 47886, 40481087.0), (40213160.0, 45618, 40481087.0), (1127433.0, 38067, 4112343.0), (1127433.0, 36275, 260139.0), (19059056.0, 31961, 40481087.0), (1713671.0, 29791, 40481087.0), (40213227.0, 28263, 4112343.0), (40213160.0, 26715, 4112343.0), (40213227.0, 22659, 260139.0)]

--- Generating summary ---

--- Summary ---
 This SQL query is designed to retrieve information about the most commonly prescribed drugs along with the conditions they are used to treat. Here's a breakdown of what each part of the query does:

1. **SELECT Statement**: The query selects three pieces of information:
   - `d.DRUG_CONCEPT_ID`: The identifier for the drug.
   - `COUNT(d.DRUG_EXPO

In [11]:

# Most prescribed drugs, asking for drug names, concept ids and counts.
# NOTE(review): the SQL recorded below selects DRUG_SOURCE_VALUE as the drug
# name, and the returned values are codes such as '313782' rather than names.
question = "Of the conditions in the database, list the 10 most prescribed drugs using the drug names, corresponding concept_ids, and counts"

result = ask_question(
    question,
    include_sql=True,
    explain_results_flag=True,
    human_summary=True,
)



--- SQL ---
SELECT d.DRUG_SOURCE_VALUE AS drug_name, d.DRUG_CONCEPT_ID AS concept_id, COUNT(d.DRUG_EXPOSURE_ID) AS prescription_count
FROM DRUG_EXPOSURE d
JOIN CONDITION_OCCURRENCE c ON d.PERSON_ID = c.PERSON_ID AND d.DRUG_EXPOSURE_START_DATE BETWEEN c.CONDITION_START_DATE AND c.CONDITION_END_DATE
GROUP BY d.DRUG_SOURCE_VALUE, d.DRUG_CONCEPT_ID
ORDER BY prescription_count DESC
LIMIT 10;

--- Rows ---
[('313782', 1127433.0, 9930), ('243670', 19059056.0, 4689), ('562251', 1713671.0, 4082), ('282464', 1127078.0, 2331), ('1043400', 40229134.0, 2108), ('834061', 19133873.0, 1911), ('310965', 19078461.0, 1846), ('10', 40213160.0, 1237), ('105078', 19006318.0, 1200), ('849574', 1115171.0, 1184)]

--- Generating summary ---

--- Summary ---
 This SQL query is designed to find the 10 most commonly prescribed drugs for patients who have specific conditions recorded in the database. Here's a breakdown of what each part of the query does:

1. **SELECT Statement**: The query selects three pieces o