In [1]:
import altair as alt
import pandas as pd
import mercury as mr
import ollama
from typing import Optional
from dbclient import DatabaseClient
from safeish import SafeishPythonExecutor

In [2]:
# Client handle used for all database access in this notebook
# (db.get_schema_summary here, db.query inside query_database).
db = DatabaseClient()
# Schema summary text; it is interpolated into the system prompt so the model
# can see the database schema. NOTE(review): exact format is defined by
# DatabaseClient.get_schema_summary - not visible here.
db_schema = db.get_schema_summary()

In [3]:
# Conversation history passed to ollama.chat on every turn. Re-running this
# cell resets the conversation to just the system prompt.
messages = [
    {
        "role": "system",
        # The prompt describes the two tools' contracts: SQL must be SELECT-only,
        # and chart code must use Altair against the existing `df` variable.
        # The schema summary is appended last so the model can see the tables.
        "content": (
            "You are an SQL assistant connected directly to a PostgreSQL database. "
            "You can execute SELECT queries on this database, "
            "and your system will automatically run any SQL query you provide. "
            "Always try to answer user questions by generating and executing an SQL query first, "
            "even if you think you already know the answer logically. "
            "Never assume the result — always verify it in the database. "
            "Only if the question cannot possibly be answered with SQL, then ask for clarification. "
            "Use SELECT statements only (no INSERT, UPDATE, DELETE). "
            "When creating visualizations (such as charts, graphs, or plots), "
            "Use the Altair library for all visual outputs. It is already available\n"
            "Create chart object with Altair. I will do display(chart). Please set width to 600px. "
            "Data is in pandas dataframe called df - use df variable. DONT create sample df variable. "
            "If you think you need another library, do not attempt to import it — "
            "simply explain that it is not available. "
            "Database schema:\n"
            f"{db_schema}"
        ),
    }
]

In [4]:
def query_database(sql: str) -> pd.DataFrame:
    """Query database

    Args:
      sql: SQL query to be executed

    Returns:
      Pandas DataFrame with query results
    """
    # This function is passed to ollama.chat(tools=...), so the docstring above
    # doubles as the tool description shown to the model - keep it accurate.
    return DatabaseClient.to_dataframe(db.query(sql))

In [5]:
def create_atair_chart(python_code: str) -> Optional[alt.Chart]:
    """Execute python code to create Altair plot on last query result

    The code runs with the last query result bound to ``df`` and must assign
    the finished chart to a variable named ``chart``.

    Args:
      python_code: string with python code that will create altair chart

    Returns:
      Altair chart object, or None when no query result exists yet or the
      code fails
    """
    # The last query result lives in the notebook-global `df`, which is only
    # assigned after a query_database tool call. Look it up defensively so a
    # chart request that arrives before any query returns None instead of
    # raising NameError.
    last_result = globals().get("df")
    if last_result is None:
        print("No query result available - run a query before creating a chart")
        return None

    # Only Altair and pandas are exposed to the generated code.
    executor = SafeishPythonExecutor(safe_globals={"alt": alt, "pd": pd})
    res = executor.run(
        python_code,
        context={"df": last_result},
        return_locals=True,
    )

    if not res.ok:
        # Surface the execution error in the cell output; the caller only
        # learns that no chart was produced.
        print(res.error)
        return None

    # None when the generated code did not bind a variable named `chart`.
    return res.locals.get("chart")

In [6]:
# Chat container widget; the handler cell adds user and assistant messages to
# it with chat.add(...).
chat = mr.Chat()

VBox(children=(HTML(value='\n            <div style="\n              color:#b5b5b5;\n              text-align:…

<mercury.chat.chat.ScrollHelper object at 0x74cad79da150>

In [7]:
# Chat input widget; the handler cell reads the submitted text from
# prompt.value and does nothing while it is empty.
# NOTE(review): presumably Mercury re-executes the handler cell when a new
# message is submitted - confirm against the Mercury docs.
prompt = mr.ChatInput()

<mercury.chat.chatinput.ChatInputWidget object at 0x74cb414a2450>

In [10]:
# Chat turn handler: send the submitted prompt to the model, render its reply,
# and execute any tool calls (SQL query / Altair chart) it requested.
if prompt.value:
    # Read the widget once so the chat bubble and the LLM history carry the
    # same text.
    user_text = prompt.value

    user_msg = mr.Message(user_text, role="user", emoji="👤")
    chat.add(user_msg)

    ai_msg = mr.Message(role="assistant", emoji="🤖")
    ai_msg.set_gradient_text("Thinking ...")
    chat.add(ai_msg)

    messages.append({"role": "user", "content": user_text})
    response = ollama.chat(
        model='gpt-oss:20b',
        messages=messages,
        think='low',
        tools=[query_database, create_atair_chart],
    )
    # Keep the assistant turn (including its tool calls) in the history so the
    # tool replies appended below have something to answer.
    messages.append(response.message.model_dump(exclude_none=True))
    if response.message.thinking:
        ai_msg.append_markdown(response.message.thinking)
    if response.message.content:
        ai_msg.append_markdown(response.message.content)

    # Every tool call gets exactly one 'tool' reply in the history, even when
    # the tool fails or is unknown, so the conversation never ends on an
    # unanswered tool call.
    for tc in response.message.tool_calls or []:
        tool_name = tc.function.name
        tool_args = tc.function.arguments

        if tool_name == "query_database":
            with ai_msg:
                sql_expander = mr.Expander("⚒️ SQL query", key=f"expander-{len(messages)}")
                with sql_expander:
                    print(tool_args.get("sql"))

            try:
                # `df` is deliberately a notebook global: create_atair_chart
                # plots the last successful query result.
                df = query_database(**tool_args)
            except Exception as exc:  # report any failure (bad SQL, wrong args) back to the model
                tool_result = f"Query failed: {exc}"
                with ai_msg:
                    print(tool_result)
            else:
                tool_result = DatabaseClient.describe_dataframe_for_llm(df)
                with ai_msg:
                    display(df)
            messages.append({'role': 'tool', 'tool_name': tool_name, 'content': tool_result})

        elif tool_name == "create_atair_chart":
            chart = None
            try:
                chart = create_atair_chart(tool_args["python_code"])
            except Exception as exc:  # missing argument, or no query result to plot yet
                with ai_msg:
                    print(f"Chart creation failed: {exc}")

            # Compare against None explicitly rather than relying on the
            # chart object's truthiness.
            messages.append({'role': 'tool', 'tool_name': tool_name, 'content': "Plot created" if chart is not None else "Cant create a plot"})
            if chart is not None:
                with ai_msg:
                    display(chart)

        else:
            # The model asked for a tool that was never offered; tell it so
            # instead of silently dropping the call.
            messages.append({'role': 'tool', 'tool_name': tool_name, 'content': f"Unknown tool: {tool_name}"})
        