In [None]:
import builtins
import calendar
import datetime
import math
import os
import re
import traceback

import altair as alt
import pandas as pd
import psycopg
from dotenv import load_dotenv
from openai import OpenAI, AuthenticationError


In [None]:
from widgets import Chat, ChatInput, Message
from widgets import Expander
from widgets.md import Markdown

In [None]:
# Load environment variables (OPENAI_KEY, DB_* settings used below) from a .env file.
# Assigning the return value to `_` keeps the cell from echoing it as output.
_ = load_dotenv()

In [None]:
# OpenAI client connection. OPENAI_KEY must be set in the environment; a missing
# variable raises KeyError here so the problem surfaces on the first run.
api_key = os.environ["OPENAI_KEY"]


def create_client(api_key):
    """Create an OpenAI client and verify the key with a lightweight API call.

    Args:
        api_key: OpenAI API key.

    Returns:
        A ready-to-use ``OpenAI`` client, or ``None`` when the key is rejected
        (an explanatory message is printed in that case). Other errors raised by
        the verification call (e.g. network failures) propagate to the caller.
    """
    client = OpenAI(api_key=api_key)
    try:
        # Listing models is a cheap request that fails fast on a bad key.
        client.models.list()
    except AuthenticationError as err:
        print(f"Incorrect OpenAI API key: {err}")
        return None
    return client


client = create_client(api_key)

In [None]:
# db connection function
def _quote_conninfo_value(value: str) -> str:
    """Format *value* for a libpq ``key=value`` connection string.

    Plain values are returned unchanged. Empty values, or values containing
    whitespace, single quotes or backslashes, are wrapped in single quotes with
    quotes and backslashes backslash-escaped, as libpq requires.
    """
    if value and not any(ch.isspace() or ch in "'\\" for ch in value):
        return value
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def db_engine():
    """Build a libpq connection string from DB_* environment variables.

    Reads DB_USERNAME, DB_PASSWORD, DB_HOST, DB_PORT and DB_DATABASE. The
    misspelled DB_POST is still honoured as a fallback for the port so existing
    .env files keep working. Unset variables are omitted (instead of being sent
    as the literal text ``None``), letting libpq fall back to its defaults.

    Returns:
        A ``key=value`` connection string suitable for ``psycopg.connect``.
    """
    params = {
        "user": os.environ.get("DB_USERNAME"),
        "password": os.environ.get("DB_PASSWORD"),
        "host": os.environ.get("DB_HOST"),
        "port": os.environ.get("DB_PORT") or os.environ.get("DB_POST"),
        "dbname": os.environ.get("DB_DATABASE"),
    }
    return " ".join(
        f"{key}={_quote_conninfo_value(value)}"
        for key, value in params.items()
        if value is not None
    )

In [None]:
# Connectivity check: open a single connection and ask the server for its version.
# On success `version` holds the fetched row; on failure the error is printed
# instead of being raised.
try:
    with psycopg.connect(db_engine()) as conn, conn.cursor() as cur:
        cur.execute("SELECT version();")
        version = cur.fetchone()
except Exception as err:
    print(f"database error: {err}")

In [None]:
# Counters that give each rendered expander widget a unique key.
iterator_py = 0
iterator_sql = 0

# Builtins exposed to model-generated Python code.
safe_builtins = {
    name: getattr(builtins, name)
    for name in dir(builtins)
    if not name.startswith("__")  # omit dunder names (this already drops __import__)
}

# `class` statements compile to a call to __build_class__; without it any class
# definition in model-generated code fails with NameError.
safe_builtins["__build_class__"] = builtins.__build_class__

# Remove dangerous ones. "get_ipython" only exists when running under IPython and
# would expose the kernel itself; "breakpoint" would block waiting for debugger input.
for name in [
    "open", "exec", "eval", "compile", "input", "help", "breakpoint",
    "dir", "vars", "globals", "locals",
    "exit", "quit", "getattr", "setattr", "delattr",
    "get_ipython",
    "__import__",  # explicit guard; blocks `import` statements
]:
    safe_builtins.pop(name, None)

# NOTE(review): this restricted namespace is a convenience guard, not a security
# boundary — attribute walks such as `().__class__.__base__.__subclasses__()`
# can still reach unrestricted objects. Use least-privilege credentials.
SAFE_GLOBALS = {
    "__builtins__": safe_builtins,
    "__name__": "__sandbox__",  # read by `class` bodies to set __module__
    "alt": alt,
    "pd": pd,
    "math": math,
    "datetime": datetime,
    "calendar": calendar,
    "display": display,  # provided by IPython at runtime (not imported here)
}


In [None]:
def get_schema_summary():
    """Describe every column of every table in the ``public`` schema as text.

    Returns:
        ``"Database structure:\\n"`` followed by one ``- table: col (type), ...``
        line per table, with columns listed in ordinal position order.
    """
    query = """
    SELECT table_name, column_name, data_type
    FROM information_schema.columns
    WHERE table_schema = 'public'
    ORDER BY table_name, ordinal_position;
    """

    with psycopg.connect(db_engine()) as conn:
        with conn.cursor() as cur:
            cur.execute(query)
            records = cur.fetchall()

    # Group "column (type)" labels per table, keeping first-seen table order.
    columns_by_table = {}
    for table_name, column_name, data_type in records:
        if table_name not in columns_by_table:
            columns_by_table[table_name] = []
        columns_by_table[table_name].append(f"{column_name} ({data_type})")

    lines = ["Database structure:\n"]
    lines.extend(
        f"- {table_name}: {', '.join(labels)}\n"
        for table_name, labels in columns_by_table.items()
    )
    return "".join(lines)

In [None]:
def is_safe_sql(sql: str) -> bool:
    """Return True if *sql* looks like a single read-only SELECT statement.

    This is a keyword heuristic for model-generated SQL, not a parser; a
    read-only database role remains the real protection.
      * The statement must start with the word ``select`` (case-insensitive).
      * Only one statement is allowed. A trailing ``;`` is fine; any other ``;``
        rejects the query, even inside a string literal.
      * Forbidden keywords are matched as whole words, so columns such as
        ``created_at`` or ``is_deleted`` are not false positives. ``into``
        (``SELECT ... INTO`` creates a table) and ``copy`` are also rejected.
    """
    statement = sql.strip().rstrip(";").strip()
    lowered = statement.lower()
    if not re.match(r"select\b", lowered):
        return False
    if ";" in lowered:
        return False
    forbidden = [
        "insert", "update", "delete", "drop", "alter", "truncate",
        "create", "grant", "revoke", "into", "copy",
    ]
    pattern = r"\b(?:" + "|".join(forbidden) + r")\b"
    return re.search(pattern, lowered) is None

In [None]:
def _escape_md_cell(value) -> str:
    """Render one value as Markdown table-cell text (escape pipes, flatten newlines)."""
    return str(value).replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def _rows_to_markdown_table(columns, rows, max_rows):
    """Format up to *max_rows* rows as a Markdown table, noting how many were omitted."""
    header = "| " + " | ".join(_escape_md_cell(col) for col in columns) + " |"
    separator = "| " + " | ".join(["---"] * len(columns)) + " |"
    body = [
        "| " + " | ".join(_escape_md_cell(value) for value in row) + " |"
        for row in rows[:max_rows]
    ]
    table = "\n".join([header, separator, *body])
    if len(rows) > max_rows:
        table += f"\n\n_Showing first {max_rows} of {len(rows)} rows._"
    return table


def run_sql_query_in_message(message, sql, max_rows: int = 10):
    """Run a SELECT query and render its result inside a chat message.

    Everything is rendered within ``with message:``. The result is a single
    value, a Markdown table, or "No results found.". A collapsed "Show SQL"
    expander follows, so both appear inline in that message. The query and the
    rendered result are then recorded via ``add_to_history_executed_sql``.

    Args:
        message: Chat message widget used as the rendering context.
        sql: Query text; it must pass ``is_safe_sql``.
        max_rows: Maximum number of rows shown when the result is a table.

    Returns:
        None on success; otherwise a human-readable string saying the query was
        blocked, or containing the SQL error traceback.
    """
    global iterator_sql

    try:
        with message:
            # validate query type
            if not is_safe_sql(sql):
                return "Unsafe SQL query detected — execution blocked."

            with psycopg.connect(db_engine()) as conn:
                # Defence in depth: have PostgreSQL reject writes in this
                # transaction even if the keyword filter misses something.
                conn.read_only = True
                with conn.cursor() as cur:
                    cur.execute(sql)
                    rows = cur.fetchall()
                    columns = [desc[0] for desc in cur.description]

            if not rows:
                result = "No results found."
            elif len(rows) == 1 and len(rows[0]) == 1:
                # handle simple single-value query
                result = str(rows[0][0])
            else:
                result = _rows_to_markdown_table(columns, rows, max_rows)
            Markdown(result)

            # collapsed expander showing the SQL code used
            ex = Expander(label="Show SQL", expanded=False, key=f"show-sql-expander-{iterator_sql}")
            with ex:
                Markdown(f"""
```sql
{sql.strip()}
```
                """)

        iterator_sql += 1
        add_to_history_executed_sql(sql, result)

    except Exception:
        return "SQL error:\n" + traceback.format_exc()

        

In [None]:
def extract_sql(text):
    """Extract the first SQL SELECT query from a model reply.

    Sources are tried in priority order, so a fenced block wins over a stray
    "SELECT" appearing earlier in the surrounding prose:
      1. a fenced code block (```sql, ```postgresql, ```pgsql or untagged)
         whose content starts with SELECT;
      2. an inline backtick span starting with SELECT;
      3. a plain-text upper-case ``SELECT`` statement ending with ``;``;
      4. a plain-text upper-case ``SELECT`` statement running to the end.
    Plain-text matches (3, 4) require an upper-case whole-word ``SELECT``, so
    ordinary English such as "please select a date" is not mistaken for SQL.

    Returns:
        The query with surrounding whitespace stripped, or None (also for
        empty/None input).
    """
    if not text:
        return None

    patterns = [
        (r"```[ \t]*(?:sql|postgresql|pgsql)?\s*(SELECT\b[\s\S]+?)\s*```", re.IGNORECASE),
        (r"`\s*(SELECT\b[^`]+?)\s*`", re.IGNORECASE),
        (r"\b(SELECT\b[\s\S]+?;)", 0),
        (r"\b(SELECT\b[\s\S]+)$", 0),
    ]
    for pattern, flags in patterns:
        match = re.search(pattern, text, flags)
        if match:
            return match.group(1).strip()
    return None

In [None]:
def extract_python(text):
    """Extract Python code from a model reply, with import lines removed.

    Sources are tried in priority order:
      1. the first fenced code block tagged ``python``/``py``/``python3`` or
         untagged; blocks tagged with another language (e.g. ``sql``) are skipped;
      2. a line starting a ``def name(`` ... running to the end of the text;
      3. a line starting a ``class Name:``/``class Name(`` ... to the end.
    Inline single-backtick spans are deliberately NOT treated as code: replies
    routinely mention names such as `orders`, which would otherwise be executed
    and hide the plain-text answer.

    ``import``/``from`` lines are stripped (the execution namespace has no
    ``__import__``) and runs of blank lines are collapsed.

    Returns:
        The cleaned code, or None if no code (or only import lines) was found.
    """
    if not text:
        return None

    code = None
    # finditer pairs opening/closing fences left to right, so a closing fence is
    # never mistaken for the start of a new block.
    for fence in re.finditer(r"```([\w+-]*)[^\n]*\n([\s\S]*?)```", text):
        if fence.group(1).lower() in ("", "python", "py", "python3"):
            code = fence.group(2)
            break

    if code is None:
        # Non-capturing line anchors: group 1 is always the code itself.
        for pattern in (
            r"(?:^|\n)([ \t]*def\s+\w+\s*\([\s\S]*)$",
            r"(?:^|\n)([ \t]*class\s+\w+\s*[:(][\s\S]*)$",
        ):
            match = re.search(pattern, text)
            if match:
                code = match.group(1)
                break

    if code is None:
        return None

    # Remove all import lines (e.g., "import ..." or "from ... import ...")
    clean_code = re.sub(r"(?m)^[ \t]*(?:import|from)[ \t]+[^\n]*", "", code.strip())
    # Also remove excessive blank lines left after stripping
    clean_code = re.sub(r"\n{2,}", "\n", clean_code).strip()
    return clean_code or None


In [None]:
def run_python_code_in_message(message, code):
    """Execute model-generated Python inside a chat message context.

    Running inside ``with message:`` makes Altair charts and other displayed
    output render within that message. A collapsed "Show code" expander
    follows, and the code is recorded via ``add_to_history_executed_code``.

    The code runs in a fresh copy of ``SAFE_GLOBALS`` used as a single
    namespace. Functions and comprehensions in the snippet can therefore see
    the snippet's own top-level variables (separate globals/locals would give
    class-body scoping). Names the snippet assigns do not leak into
    ``SAFE_GLOBALS`` for later runs.

    NOTE(review): the restricted namespace is not a security sandbox.

    Returns:
        None on success, or ``"Python error:\\n"`` plus the traceback on failure.
    """
    global iterator_py

    try:
        with message:
            exec_env = dict(SAFE_GLOBALS)
            exec(code, exec_env)

            ex = Expander(label="Show code", expanded=False, key=f"show-code-expander-{iterator_py}")
            with ex:
                Markdown(f"""
```python
{code}
```
                """)

        iterator_py += 1
        add_to_history_executed_code(code)

    except Exception:
        return "Python error:\n" + traceback.format_exc()

In [None]:
def add_to_history_executed_sql(sql: str, result_preview: str, max_chars: int = 2000):
    """Append a ``system`` message to ``history`` describing an executed query.

    Newlines in the query and in the result preview are flattened to spaces.
    Each part is cut to at most *max_chars* characters (with ``...`` appended
    when cut), so the "(truncated)" label holds. Long results therefore cannot
    grow the model context without bound.
    """
    def _shorten(text) -> str:
        flat = str(text).replace("\n", " ")
        return flat if len(flat) <= max_chars else flat[:max_chars] + "..."

    history.append({
        "role": "system",
        "content": (
            f"Executed SQL query (truncated): {_shorten(sql)} "
            f"\nResult preview (truncated): {_shorten(result_preview)}"
        )
    })

In [None]:
def add_to_history_executed_code(code: str, max_chars: int = 2000):
    """Append a ``system`` message to ``history`` describing executed Python code.

    Newlines are flattened to spaces. The code is cut to at most *max_chars*
    characters (with ``...`` appended when cut), so the "(truncated)" label holds.
    """
    flat_code = code.replace("\n", " ")
    if len(flat_code) > max_chars:
        flat_code = flat_code[:max_chars] + "..."
    history.append({
        "role": "system",
        "content": f"Executed Python code (truncated): {flat_code}"
    })

In [None]:
def agent(prompt, context=None, message=None, model="gpt-4.1"):
    """Send *prompt* to the model and act on its reply.

    If the reply contains a SQL SELECT query, it is executed and rendered in
    *message*. Otherwise, if it contains Python code, that code is executed and
    rendered in *message*. Any other reply is returned as text.

    Args:
        prompt: The user's question.
        context: Prior chat messages (role/content dicts); not modified.
        message: Chat message widget to render into; cleared before rendering.
        model: OpenAI chat model name.

    Returns:
        The reply text for plain answers. A blocked/error message if running the
        extracted SQL or Python failed. ``""`` when the output was rendered
        directly into *message*.
    """
    messages = list(context or []) + [{"role": "user", "content": prompt}]
    response = client.chat.completions.create(
        model=model,
        messages=messages,
    )
    # message.content may be None; treat that as an empty reply.
    answer = response.choices[0].message.content or ""

    if message:
        message.clear()

    sql = extract_sql(answer)
    if sql:
        # The runner returns a string only when the query was blocked or failed;
        # surface it instead of leaving the cleared message empty.
        return run_sql_query_in_message(message, sql) or ""

    python = extract_python(answer)
    if python:
        # render the chart directly in this message; surface errors as text
        return run_python_code_in_message(message, python) or ""

    return answer


In [None]:
# Conversation history sent with every model request (system prompt appended next).
history = list()
# Text description of the public schema, embedded in the system prompt.
db_schema = get_schema_summary()

In [None]:
# System prompt: describes what the agent loop provides (automatic SQL execution,
# Python execution with preloaded libraries) and embeds the live schema summary.
history.append({
    "role": "system",
    "content": (
        "You are an SQL assistant connected directly to a PostgreSQL database. "
        "You can execute SELECT queries on this database, "
        "and your system will automatically run any SQL query you provide. "
        "Do not say that you don’t have access — you have full read-only access to the data. "
        "Always try to answer user questions by generating and executing an SQL query first, "
        "even if you think you already know the answer logically. "
        "Never assume the result — always verify it in the database. "
        "Only if the question cannot possibly be answered with SQL, then ask for clarification. "
        "Use SELECT statements only (no INSERT, UPDATE, DELETE). "
        "Put either one SQL query or one Python code block in a response, not both: "
        "if a response contains SQL, only the SQL is executed. "
        "When creating visualizations (such as charts, graphs, or plots), "
        "use the Altair library for all visual outputs. "
        "To display charts, use `display(chart)`. "
        "Python code cannot query the database; to chart query results, build a "
        "pandas DataFrame from values returned by earlier queries. "
        "Do not include any `import` statements or module-loading code in your Python responses. "
        "All Python code you generate must run using only the preloaded libraries available "
        "in the environment (Altair as `alt`, Pandas as `pd`, the `math`, `datetime` and "
        "`calendar` modules, and standard Python built-ins). "
        "If you think you need another library, do not attempt to import it — "
        "simply explain that it is not available. "
        "Database schema:\n"
        f"{db_schema}"
    )
})


In [None]:
# Chat container widget; user and assistant messages are added to it with
# chat.add(...) when a message is submitted.
chat = Chat()

In [None]:
# Chat input widget; its `submitted` value is read by the message-handling cell.
chat_input = ChatInput()

In [None]:
# Handle one submitted chat message: echo it, ask the agent, render the reply.
msg = (chat_input.submitted or "").strip()

if msg:
    # user msg
    user_msg = Message(role="user", emoji="👤")
    user_msg.set_message(markdown=msg)
    chat.add(user_msg)

    # ai msg
    bot_msg = Message(role="assistant", emoji="🤖")
    bot_msg.set_gradient_text("Thinking hard 🤔")
    chat.add(bot_msg)

    # Record the user turn before calling the agent, so "Executed SQL/Python"
    # system notes appended while the agent runs follow the question they answer.
    # The agent appends the prompt itself, so pass history without this entry.
    history.append({"role": "user", "content": msg})
    try:
        reply = agent(msg, history[:-1], message=bot_msg)
    except Exception as err:
        # e.g. API/network failures or a client that failed to initialise (None):
        # show the error instead of leaving the "Thinking" placeholder on screen.
        reply = f"Error: {err}"
    history.append({"role": "assistant", "content": reply})

    # A non-empty reply is either a plain-text answer or an error/blocked message;
    # it replaces whatever the message currently shows.
    if reply:
        bot_msg.clear()
        bot_msg.append_markdown(reply)
