In [None]:
# PostgreSQL connection settings read by connect_postgresql().
# They are blank here and must be filled in before the notebook is run.
DATABASE = ""
DB_USER = ""
DB_HOST = ""
DB_PASSWORD = ""
DB_PORT = ""

# Path of the JSON file with the questions; passed to load_and_sample_by_difficulty().
JSON_QUESTIONS = ""

In [None]:
import psycopg2

def connect_postgresql():
    """Open a new psycopg2 connection using the module-level DB settings.

    The settings are handed to psycopg2 as keyword arguments instead of being
    spliced into a DSN string, so values that contain spaces, quotes or
    backslashes (typically the password) are passed through intact.

    Returns:
        A new psycopg2 connection. The caller is responsible for closing it.
    """
    return psycopg2.connect(
        dbname=DATABASE,
        user=DB_USER,
        host=DB_HOST,
        password=DB_PASSWORD,
        port=DB_PORT,
    )

def execute_postgresql_query(cursor, query):
    """Execute ``query`` on ``cursor`` and return everything ``fetchall()`` yields."""
    cursor.execute(query)
    return cursor.fetchall()

def perform_query_on_postgresql_databases(query):
    """Run ``query`` on a fresh connection and return all fetched rows.

    A new connection is opened for every call and is always closed, including
    when the query raises, so failing queries do not leak connections.

    Raises:
        Whatever connecting or executing the query raises; nothing is swallowed.
    """
    db = connect_postgresql()
    try:
        cursor = db.cursor()
        return execute_postgresql_query(cursor, query)
    finally:
        # Runs on success and on error alike.
        db.close()

def execute_sql_query(sql_query: str):
    """Execute a SQL query against the PostgreSQL database and return the result as text.

    The stringified result is cut to 1000 characters, with "..." appended when
    anything was dropped, to avoid passing answers longer than the model's
    context window.
    """
    limit = 1000
    rendered = str(perform_query_on_postgresql_databases(sql_query))
    if len(rendered) > limit:
        return rendered[:limit] + "..."
    return rendered

# Function-tool schema placed in the agents' llm_config "tools" list; it
# describes the `execute_sql_query` function and its single `sql_query` argument.
# NOTE: the "Only SELECT queries are permitted" rule exists only in this
# description text; execute_sql_query() does not check the statement type.
EXECUTE_INTERMEDIATE_SQL_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "execute_sql_query",
        "description": "Executes a given SQL SELECT query against the database. Use this to test parts of a query, explore data patterns, or verify assumptions for constructing the final SQL. The query should be specific and aim to return a small result set (e.g., use LIMIT). Only SELECT queries are permitted.",
        "parameters": {
            "type": "object",
            "properties": {
                "sql_query": {
                    "type": "string",
                    "description": "The PostgreSQL SELECT query to execute for analysis or testing."
                }
            },
            "required": ["sql_query"]
        }
    }
}


In [None]:
# System prompts for the agents, keyed by role.
prompts = {}
# The prompts tell each model it has fewer auto replies than its agent is actually
# configured with: models do not accurately track which reply they are on, and
# understating the budget helps avoid mistakes where the model thinks it still has 'time' left.
prompts["SQL_REPAIR"] = """
You are an expert SQL engineer. Your task is to complete a partially correct, incomplete, or non-working SQL query so that it answers the original user question.
You should aim to finish it as fast as possible as long as it works.
To asses wether it works or not you will be given access to a tool to run SQL.

You will be given:
- The chat history with the other agent. In this chat history there will be:
    - The schema of the database.
    - The question to be answered, with evidence.
    - The SQL process to generate an SQL by another agent.
- Any error messages or execution feedback from running the query

Your job:
1. Analyze the provided SQL query, error messages, and context.
2. Identify and fix any syntax errors, logical mistakes, or missing parts.
3. Ensure the repaired SQL query answers the original question as intended.
4. Use the schema and context to ensure all table and column names are correct.
5. Return ONLY the repaired SQL query.

You will have a max of 4 auto replies.
Remember to try to get it working in as few steps as possibe and use the tool to test it. Terminate only if it works.

- Your FINAL response MUST be ONLY the SQL query to answer the original NLQ. Do not include explanations or markdown.
- You then terminate by puting "TERMINATE_CHAT" in that same message.
"""
# System prompt for the SQLGenerator agent. The only tool exposed to the agents
# is `execute_sql_query`, so every mention of the tool in the prompt must use
# exactly that name; otherwise the model may try to call a tool that does not exist.
prompts["SQL_GENERATOR_SYSTEM_MESSAGE"] = """You are an expert SQL Reasoner and Generator.
Your task is to generate a final, syntactically correct, and semantically accurate PostgreSQL query
to answer the given Natural Language Question (NLQ) based on the provided context (schema, evidence, data profile, cleaning advice).

You have a tool to help you understand the schema and data before committing to a final query. As well as run tests:
1.  `execute_sql_query(sql_query)`: To run a SELECT query against the database to test a hypothesis, check data values, or explore relationships. Use this iteratively if needed. For example, you can use it to check distinct values in a column before deciding on a filter, or to verify a join condition on a small sample. Ensure queries are specific and return small results (e.g., use LIMIT).

Reasoning Process:
1. Understand the NLQ and the initially provided schema and context.
2. If you need to understand actual data patterns, NULL distributions, or test a query fragment, use the `execute_sql_query` tool. You can use this tool multiple times if necessary to refine your understanding or query.
3. Based on all information gathered, construct your final PostgreSQL query.
4. Try the query and iterated on it when appropiate. 

When you have an idea for the final answer and see if its responding appropiately.
Once you produce a final SQL query or an answer, do not repeat it or continue generating more.
You will have a max of 13 auto replies.
Make sure the SQL answer works.
Do not terminate unless you reach the max or get a working SQL.

- Your FINAL response MUST be ONLY the SQL query to answer the original NLQ. Do not include explanations or markdown.
- You then terminate by puting "TERMINATE_CHAT" in that same message.
"""
# System prompt for the answer_evaluator agent: it is shown the question, the
# ground-truth answer and the generated answer, and must reply with only
# CORRECT or INCORRECT.
prompts["ANSWER_DECIDER"] = """You are an expert SQL evaluator.
You will be given a question, evidence (extra context) and two answers, the first the ground truth and then the best guess.
The answers will be preluded with their corresponding SQL.
You need to determine if the guess is appropiate with the context of the correct answer and question.
It will be considered correct if it holds the same information, even with different formats.
Example:
    Question: In what month of 2025 did we achieve X?
    Correct: March
    Guess: 03/25
It will also be considered correct if it has extra data. 
Example:
    Question: What month of 2025 had the highest income?
    Correct: March
    Guess: [03, $542]
It will not be considered correct if numeric values such as floats, decimals and integers have discrepancies.
Example:
    Question: How much were earnings in July 2022?
    Correct: 486.64
    Guess: 480
Answer with only CORRECT or INCORRECT.
"""

In [None]:
from autogen import UserProxyAgent, ConversableAgent
import os
import json
from typing import List, Dict, Any
import pickle
import textwrap

In [None]:
# Other model names kept for reference:
# "model" = "gpt-4.1-nano"
# "model":"gpt-4.1",

# Model used by the SQLGenerator agent; the API key is read from the
# OPENAI_API_KEY environment variable.
LLM_CONFIG_LIST = [
    {
        "model":"o4-mini",
        "api_key": os.getenv("OPENAI_API_KEY")
    }
]

# NOTE(review): sql_generator_agents is not referenced anywhere else in the visible code.
sql_generator_agents = {}
llm_config = {"config_list": LLM_CONFIG_LIST, "cache_seed": 42, "timeout": 180}
# Generator config = base LLM settings + the SQL execution tool schema.
agent_specific_llm_config = {
    **llm_config,
    "tools": [
        EXECUTE_INTERMEDIATE_SQL_TOOL_SCHEMA 
    ],
}
# Agent that writes the SQL. max_consecutive_auto_reply (15) is higher than the
# 13 replies announced in its system prompt.
SQLGenerator = ConversableAgent(
    name="SQLGenerator",
    system_message=prompts["SQL_GENERATOR_SYSTEM_MESSAGE"],
    llm_config=agent_specific_llm_config,
    human_input_mode="NEVER",
    max_consecutive_auto_reply=15,
    code_execution_config=False,
)

In [None]:
# Model used by the SQLRepair agent; the API key is read from the
# OPENAI_API_KEY environment variable.
LLM__REPAIR_CONFIG_LIST = [
    {
        "model":"gpt-4.1",
        "api_key":os.getenv("OPENAI_API_KEY")
    }
]
# Rebinds llm_config / agent_specific_llm_config for the repair agent.
llm_config = {"config_list": LLM__REPAIR_CONFIG_LIST, "cache_seed": 42, "timeout": 180}
# Repair config = base LLM settings + the SQL execution tool schema.
agent_specific_llm_config = {
    **llm_config,
    "tools": [
        EXECUTE_INTERMEDIATE_SQL_TOOL_SCHEMA
    ],
}
# Agent that fixes a query the generator could not get to execute.
# max_consecutive_auto_reply (5) is higher than the 4 replies announced in its system prompt.
SQLRepair = ConversableAgent(
    name="SQLRepair",
    system_message=prompts["SQL_REPAIR"], 
    llm_config=agent_specific_llm_config,
    human_input_mode="NEVER",
    max_consecutive_auto_reply=5,
    code_execution_config=False,
)

In [None]:
# Proxy agent used to start the generator and repair chats.
# function_map registers execute_sql_query as the callable behind the tool name
# the agents are given. The termination check matches any message whose string
# content contains "terminate_chat", ignoring letter case; non-string content
# (e.g. None) never matches.
user_proxy = UserProxyAgent(
    name="UserProxyOrchestrator",
    human_input_mode="NEVER",
    code_execution_config=False,
    function_map={
        "execute_sql_query": execute_sql_query
    },
    is_termination_msg=lambda msg: isinstance(msg.get("content", ""), str) and "terminate_chat" in msg["content"].lower()
)

In [None]:
def load_and_sample_by_difficulty(json_path: str, max_simple: int, max_moderate: int, max_challenging: int) -> List[Dict[str, Any]]:
    """Load the questions to test, capped per difficulty level.

    Reads a JSON list of question dicts from ``json_path`` and keeps, in file
    order, the first ``max_simple`` / ``max_moderate`` / ``max_challenging``
    entries whose "difficulty" is "simple" / "moderate" / "challenging".
    Entries with any other (or no) difficulty are ignored, and scanning stops
    once every cap is reached.

    Returns:
        The kept questions: all simple ones, then moderate, then challenging.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        questions = json.load(handle)

    caps = {
        "simple": max_simple,
        "moderate": max_moderate,
        "challenging": max_challenging,
    }
    picked: Dict[str, List[Dict[str, Any]]] = {level: [] for level in caps}

    for question in questions:
        level = question.get("difficulty", "")
        if level in picked and len(picked[level]) < caps[level]:
            picked[level].append(question)
        # Stop as soon as no bucket has room left.
        if not any(len(picked[name]) < caps[name] for name in picked):
            break

    return [*picked["simple"], *picked["moderate"], *picked["challenging"]]

In [None]:
# Questions to evaluate: at most 500 per difficulty level, read from JSON_QUESTIONS.
queries = load_and_sample_by_difficulty( json_path=JSON_QUESTIONS, max_simple=500, max_moderate=500, max_challenging=500)

In [None]:
def _resolve_array_type(cur, udt_name):
    """Return the ``ELEM[]`` spelling for an ARRAY column whose udt_name is given."""
    elem_type_query = """
        SELECT e.typname
        FROM pg_type t
        JOIN pg_namespace n ON n.oid = t.typnamespace
        JOIN pg_type e ON e.oid = t.typelem
        WHERE n.nspname = 'public' AND t.typname = %s;
    """
    cur.execute(elem_type_query, (udt_name,))
    elem_res = cur.fetchone()
    if elem_res:
        return f"{elem_res[0].upper()}[]"
    # No row for this array type in 'public': derive the element name by
    # dropping the leading underscore of the array type name (e.g. _int4 -> INT4[]).
    base_type_from_udt = udt_name[1:] if udt_name.startswith('_') else udt_name
    return f"{base_type_from_udt.upper()}[]"


def _format_scalar_type(data_type, udt_name, char_max_len, num_prec, num_scale, dt_prec):
    """Render a non-array information_schema column type as DDL type text."""
    kind = data_type.lower()
    if kind in ('character varying', 'varchar'):
        return f"VARCHAR({char_max_len})" if char_max_len else "VARCHAR"
    if kind in ('character', 'char'):
        return f"CHAR({char_max_len})" if char_max_len else "CHAR"
    if kind == 'numeric':
        if num_prec and num_scale is not None:
            return f"NUMERIC({num_prec}, {num_scale})"
        if num_prec:
            return f"NUMERIC({num_prec})"
        return "NUMERIC"
    if kind.startswith('time'):  # covers both 'timestamp ...' and 'time ...'
        if dt_prec is None:
            return data_type.upper()
        base = "TIMESTAMP" if kind.startswith('timestamp') else "TIME"
        # Keep the zone qualifier; without it timestamptz and timestamp look identical.
        zone = " WITH TIME ZONE" if kind.endswith(' with time zone') else ""
        return f"{base}({dt_prec}){zone}"
    if kind == 'user-defined':
        # data_type is only the placeholder 'USER-DEFINED' here (e.g. for enum
        # columns); the actual type name is in udt_name.
        return udt_name
    return data_type.upper()


def _column_definitions(cur, table_name):
    """Return one DDL line per column of ``table_name``, in ordinal order."""
    query_columns = """
        SELECT
            column_name,
            data_type,
            udt_name,
            is_nullable,
            column_default,
            character_maximum_length,
            numeric_precision,
            numeric_scale,
            datetime_precision
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = %s
        ORDER BY ordinal_position;
    """
    cur.execute(query_columns, (table_name,))
    # Fetch everything first: resolving array types re-uses the same cursor.
    columns_data = cur.fetchall()

    column_defs = []
    for (col_name, data_type, udt_name, is_nullable, col_default,
         char_max_len, num_prec, num_scale, dt_prec) in columns_data:
        if data_type == 'ARRAY':
            col_type_str = _resolve_array_type(cur, udt_name)
        else:
            col_type_str = _format_scalar_type(
                data_type, udt_name, char_max_len, num_prec, num_scale, dt_prec
            )

        col_def = f"    \"{col_name}\" {col_type_str}"
        if is_nullable == 'NO':
            col_def += " NOT NULL"
        if col_default:
            col_def += f" DEFAULT {col_default}"
        column_defs.append(col_def)
    return column_defs


def _primary_key_clause(cur, table_name):
    """Return the PRIMARY KEY line for ``table_name``, or None if it has none."""
    query_pk = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY'
            AND tc.table_name = %s AND tc.table_schema = 'public'
        ORDER BY kcu.ordinal_position;
    """
    cur.execute(query_pk, (table_name,))
    pk_columns = [f"\"{row[0]}\"" for row in cur.fetchall()]
    if not pk_columns:
        return None
    return f"    PRIMARY KEY ({', '.join(pk_columns)})"


def _foreign_key_clauses(cur, table_name):
    """Return one FOREIGN KEY constraint line per foreign key of ``table_name``."""
    query_fk = """
        SELECT
            tc.constraint_name,
            kcu.column_name AS local_column,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON tc.constraint_name = ccu.constraint_name AND tc.table_schema = ccu.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_name = %s AND tc.table_schema = 'public';
    """
    cur.execute(query_fk, (table_name,))

    grouped = {}
    for cons_name, loc_col, f_table, f_col in cur.fetchall():
        info = grouped.setdefault(cons_name, {
            'local_columns': set(),
            'foreign_table_name': f_table,
            'foreign_columns': set(),
        })
        info['local_columns'].add(f"\"{loc_col}\"")
        info['foreign_columns'].add(f"\"{f_col}\"")

    # NOTE(review): for multi-column foreign keys this query pairs every local
    # column with every referenced column, and each side is de-duplicated and
    # sorted independently below, so the positional correspondence between
    # local and referenced columns is not guaranteed. Single-column keys are
    # unaffected.
    clauses = []
    for cons_name, info in grouped.items():
        local_cols_str = ", ".join(sorted(info['local_columns']))
        foreign_cols_str = ", ".join(sorted(info['foreign_columns']))
        clauses.append(
            f"    CONSTRAINT \"{cons_name}\" FOREIGN KEY ({local_cols_str}) "
            f"REFERENCES public.\"{info['foreign_table_name']}\" ({foreign_cols_str})"
        )
    return clauses


def _check_clauses(cur, table_name):
    """Return one CHECK constraint line per check constraint of ``table_name``."""
    query_check = """
        SELECT cc.constraint_name, cc.check_clause
        FROM information_schema.check_constraints cc
        JOIN information_schema.table_constraints tc
            ON cc.constraint_name = tc.constraint_name AND cc.constraint_schema = tc.table_schema
        WHERE tc.table_name = %s AND tc.table_schema = 'public' AND tc.constraint_type = 'CHECK';
    """
    cur.execute(query_check, (table_name,))
    return [
        f"    CONSTRAINT \"{cons_name}\" CHECK ({check_clause})"
        for cons_name, check_clause in cur.fetchall()
    ]


def fetch_full_schema_ddl() -> str:
    """
    Fetches CREATE TABLE DDL statements for tables in the 'public' schema of a PostgreSQL database.
    Includes columns, data types, NULL constraints, defaults, primary keys, foreign keys, and basic CHECK constraints.

    Returns:
        The CREATE TABLE statements separated by blank lines. When the
        connection cannot be opened, a query fails, or there is nothing to
        describe, a single SQL comment line ("-- ...") explaining why is
        returned instead; this function does not raise for those cases.
    """
    try:
        conn = connect_postgresql()
    except Exception as e:
        # Connection failures surface as exceptions, so report them in the
        # same "-- ..." form used for every other failure of this function.
        return f"-- ERROR: Could not establish database connection to fetch schema: {e}"
    if not conn:
        return "-- ERROR: Could not establish database connection to fetch schema."

    schema_ddl_parts = []
    try:
        with conn.cursor() as cur:
            query_tables = """
                SELECT tablename
                FROM pg_catalog.pg_tables
                WHERE schemaname = 'public'
                ORDER BY tablename;
            """
            cur.execute(query_tables)
            tables_to_process = [row[0] for row in cur.fetchall()]

            if not tables_to_process:
                return "-- No tables found or selected in 'public' schema to generate DDL."

            for table_name in tables_to_process:
                table_definition_parts = _column_definitions(cur, table_name)

                pk_clause = _primary_key_clause(cur, table_name)
                if pk_clause:
                    table_definition_parts.append(pk_clause)

                table_definition_parts.extend(_foreign_key_clauses(cur, table_name))
                table_definition_parts.extend(_check_clauses(cur, table_name))

                if table_definition_parts:
                    schema_ddl_parts.append(
                        f"CREATE TABLE public.\"{table_name}\" (\n" +
                        ",\n".join(table_definition_parts) +
                        "\n);"
                    )

        # Only catalog reads were issued; end the transaction without committing.
        conn.rollback()
        return "\n\n".join(schema_ddl_parts) if schema_ddl_parts else "-- No DDL generated (no tables or error)."
    except Exception as e:
        if conn and not conn.closed:
            conn.rollback()
        return f"-- Unexpected error fetching schema: {e}"
    finally:
        conn.close()

In [None]:
# Schema DDL text, built once here; generate_answer() includes it in the
# message sent to the SQL generator for every question.
schema_sql = fetch_full_schema_ddl()

In [None]:
# Model used by the answer_evaluator agent.
LLM_DECIDER_CONFIG_LIST = [
    {
        "model":"gpt-4.1-nano",
        "api_key": os.getenv("OPENAI_API_KEY") # read from the environment
    }
]

# Rebinds llm_config / agent_specific_llm_config for the evaluator.
# Unlike the generator and repair configs, no "tools" entry is added here.
llm_config = {"config_list": LLM_DECIDER_CONFIG_LIST, "cache_seed": 42, "timeout": 180}
agent_specific_llm_config = {
    **llm_config,
}
# Agent that judges a generated answer against the ground truth.
answer_evaluator = ConversableAgent(
        name="answer_evaluator",
        system_message=prompts["ANSWER_DECIDER"], # asks for a bare CORRECT / INCORRECT verdict
        llm_config=agent_specific_llm_config,
        human_input_mode="NEVER",
        max_consecutive_auto_reply=1, # a single reply: the verdict
        code_execution_config=False,
        # Disabled termination check, kept for reference:
        # is_termination_msg=lambda msg: "TERMINATE_CHAT" in msg["content"],
    )

# Proxy agent that starts the evaluation chat. It reuses the name and
# function_map of user_proxy but defines no termination check.
user_proxy_decider = UserProxyAgent(
    name="UserProxyOrchestrator",
    human_input_mode="NEVER",
    code_execution_config=False,
    function_map={
        "execute_sql_query": execute_sql_query
    },
)

In [None]:
results=[]


def _strip_termination_marker(text):
    """Remove the TERMINATE_CHAT marker (in any letter case) and surrounding whitespace."""
    import re  # local import: `re` is not among the notebook's imports

    # The chat termination check ignores case, so the marker must be removed
    # ignoring case too, or a lower-case marker would end up inside the SQL.
    return re.sub("terminate_chat", "", text, flags=re.IGNORECASE).strip()


def generate_answer(item):
    """Generate, run and grade a SQL answer for one question.

    Steps:
      1. Run the ground-truth SQL to obtain the target answer.
      2. Ask SQLGenerator for a query, giving it the question, evidence and schema.
      3. Run the generated query; if running it raises, hand the whole
         generator chat to SQLRepair and run the repaired query instead.
      4. Ask answer_evaluator to compare the target and the generated answer.

    Args:
        item: question dict with the keys 'question_id', 'question',
            'evidence' and 'SQL' (the ground-truth query).

    Returns:
        dict with the keys "id", "target", "target_sql", "answer",
        "answer_sql", "full_answer" (the chat result, a dict holding both the
        generator and repair chats, or an error string) and "evaluation"
        (the evaluator's reply, or an error string).

    Raises:
        Whatever executing the ground-truth SQL raises; failures in the later
        steps are captured into the returned dict instead.
    """
    qid = item['question_id']
    Q = item['question']
    E = item['evidence']
    S = item['SQL']
    C = schema_sql
    sql_guess = ""
    answer = ""

    target = execute_sql_query(S)

    try:
        chat_result = user_proxy.initiate_chat(
            recipient=SQLGenerator,
            # Each section starts on its own line, matching the evaluator message below.
            message=f"\nQuestion:\n{Q}\nEvidence:\n{E}\nSCHEMA:\n{C}",
            clear_history=True,
            silent=False,
        )
        last_msg = chat_result.chat_history[-1]
        sql_guess = _strip_termination_marker(last_msg["content"])
        try:
            answer = execute_sql_query(sql_guess)
        except Exception as first_error:
            # The generated query does not run: let the repair agent see the
            # whole generator conversation and try to fix it.
            chat_repair = user_proxy.initiate_chat(
                recipient=SQLRepair,
                message=str(chat_result.chat_history),
                clear_history=True,
                silent=False,
            )
            answer = f"{first_error}"
            last_msg = chat_repair.chat_history[-1]
            sql_guess = _strip_termination_marker(last_msg["content"])
            chat_result = {"chat_result":chat_result, "chat_repair":chat_repair}
            try:
                answer = execute_sql_query(sql_guess)
            except Exception as repair_error:
                answer = f"{repair_error}"
    except Exception as e:
        # Any failure in the generation/repair chats is recorded, not raised,
        # so one bad question does not stop the whole run.
        chat_result = f"{e}"
    try:
        chat_decider = user_proxy_decider.initiate_chat(
            recipient=answer_evaluator,
            message=f"\nQuestion:\n{Q}\nEvidence:\n{E}\nSQL of correct:\n{S}\nCorrect:\n{target}\nSQL of guess:\n{sql_guess}\nGuess:\n{answer}",
            clear_history=True,
            silent=False,
        )
        # Index 0 is the message sent above; index 1 is the evaluator's verdict.
        chat_decider = chat_decider.chat_history[1]["content"]
    except Exception as e:
        chat_decider = f"{e}"
    return {
        "id":qid,
        "target":target,
        "target_sql":S,
        "answer":answer,
        "answer_sql":sql_guess,
        "full_answer":chat_result,
        "evaluation": chat_decider
    }

# Run the pipeline over every sampled question. The `finally` clause pickles
# whatever has been collected in `results` to results.pkl, even when the loop
# is cut short by an exception.
try:
    for item in queries:
        result = generate_answer(item)
        results.append(result)  
finally:
    with open("results.pkl", "wb") as f:
        pickle.dump(results, f)

In [None]:
def print_wrapped_table(columns, headers, col_widths=None, add_row_numbers=True):
    """Print column-oriented data as a text table with wrapped cells.

    Args:
        columns: one sequence of cell values per column. Rows are formed by
            zipping the columns, so extra values in longer columns are dropped.
        headers: one header label per column.
        col_widths: width of each column in characters. Missing entries
            default to 30. The caller's list is never modified.
        add_row_numbers: when true, prepend a 4-character wide "#" column
            holding 1-based row numbers.

    Raises:
        ValueError: if the number of columns differs from the number of headers.
    """
    num_cols = len(headers)

    # Validate data
    if len(columns) != num_cols:
        raise ValueError("Number of column data lists must match number of headers")

    # Work on a copy: padding the widths must not extend the caller's list.
    widths = [] if col_widths is None else list(col_widths)
    if len(widths) < num_cols:
        widths += [30] * (num_cols - len(widths))

    header_cells = list(headers)
    rows = list(zip(*columns))  # Transpose to rows

    if add_row_numbers:
        header_cells = ['#', *header_cells]
        widths = [4, *widths]
        rows = [(str(number), *row) for number, row in enumerate(rows, start=1)]

    # Print headers
    header_line = "  ".join(f"{h:<{w}}" for h, w in zip(header_cells, widths))
    print(header_line)
    print("-" * len(header_line))

    for row in rows:
        # Wrap each cell; an empty cell still occupies one (blank) line.
        wrapped_cells = [
            textwrap.wrap(str(cell), width) or ['']
            for cell, width in zip(row, widths)
        ]
        max_lines = max(len(cell) for cell in wrapped_cells)

        # Pad shorter cells so every cell has the same number of lines.
        for cell in wrapped_cells:
            cell.extend([''] * (max_lines - len(cell)))

        # Print the row one physical line at a time.
        for line_parts in zip(*wrapped_cells):
            print("  ".join(f"{part:<{width}}" for part, width in zip(line_parts, widths)))

        print()  # Space between rows


In [None]:
def count_statuses(strings):
    """Tally evaluation verdicts.

    Each string is stripped and upper-cased; exactly "CORRECT" or "INCORRECT"
    is counted under that key, anything else under "UNKNOWN".
    """
    recognised = ("CORRECT", "INCORRECT")
    tally = dict.fromkeys((*recognised, "UNKNOWN"), 0)

    for raw in strings:
        verdict = raw.strip().upper()
        tally[verdict if verdict in recognised else "UNKNOWN"] += 1

    return tally

In [None]:
# Summary: verdict counts first, then one table row per evaluated question.
print(count_statuses([result["evaluation"] for result in results]))
print_wrapped_table(
    columns = [
        [result["id"] for result in results],
        [result["target_sql"] for result in results],
        [result["target"] for result in results],
        [result["answer_sql"] for result in results],
        [result["answer"] for result in results],
        [result["evaluation"] for result in results]
    ], 
    headers = ["id", "target_sql", "target", "answer_sql", "answer", "evaluation"], 
    col_widths = [5,40,10,40,10,10]
)

In [None]:
def get_question_and_evidence(data, question_id):
    """Look up the question and evidence text for ``question_id``.

    Returns a dict with the keys "question" and "evidence" taken from the
    first entry of ``data`` whose "question_id" matches; both values are
    empty strings when no entry matches.
    """
    match = next(
        (entry for entry in data if entry.get("question_id") == question_id),
        None,
    )
    if match is None:
        # Not found: fall back to empty strings so callers always get both keys.
        return {"question": "", "evidence": ""}
    return {"question": match.get("question"), "evidence": match.get("evidence")}

# Filter results to only those that are not correct, and show them together
# with the original question and evidence.
incorrect_results = [
    result for result in results
    if str(result["evaluation"]).strip().upper() != "CORRECT"
]
# One lookup per result, shared by the question and evidence columns.
incorrect_context = [
    get_question_and_evidence(queries, result["id"]) for result in incorrect_results
]
print_wrapped_table(
    columns = [
        [result["id"] for result in incorrect_results],
        [context["question"] for context in incorrect_context],
        [context["evidence"] for context in incorrect_context],
        [result["target_sql"] for result in incorrect_results],
        [result["target"] for result in incorrect_results],
        [result["answer_sql"] for result in incorrect_results],
        [result["answer"] for result in incorrect_results],
        [result["evaluation"] for result in incorrect_results]
    ],
    headers = ["id","question", "evidence", "target_sql", "target", "answer_sql", "answer", "evaluation"],
    col_widths = [5,30,30,30,10,30,10,10]
)