# SQL Agent


Highlights:
1. If the database schema is too large to fit in the context length (for example, too many tables or too many columns), the agent truncates the schema and asks the LLM to select the tables it needs on the fly.
    a. We still recommend making the context size as large as possible, because more schema information can be significantly beneficial.
2. The example uses a customized, real-world database, because GPT is already pretrained on many public databases, such as IMDB, TPC-H, and Chinook. The provided database comes from the Azure Brain project and has real tables and columns. Names and values in the rows are randomly anonymized for privacy reasons. This simple two-table database serves as a minimal example of SQLAgent for educational purposes.
3. By default, SQLAgent is initialized with a warm start, which loads some example questions and answers into the history (i.e., `oai_messages`).
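
The schema fallback in the first highlight can be sketched as a tiered check against a token budget. This is only an illustrative sketch, not the notebook's implementation: `count_tokens` (a whitespace word count standing in for a real tokenizer), `describe_schema`, and the toy tables are assumptions made for this example.

```python
import sqlite3


def count_tokens(text):
    # Assumption: whitespace-separated words stand in for real tokens.
    return len(text.split())


def describe_schema(cursor, max_tokens):
    """Return (tier, text), using the most detailed description that fits the budget."""
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [row[0] for row in cursor.fetchall()]
    header = "Tables: " + ", ".join(tables) + "\n"

    full = ""        # column names with declared types
    names_only = ""  # column names only
    for table in tables:
        cursor.execute(f"PRAGMA table_info({table});")
        cols = cursor.fetchall()
        full += f"{table}: " + ", ".join(f"{c[1]} {c[2]}" for c in cols) + "\n"
        names_only += f"{table}: " + ", ".join(c[1] for c in cols) + "\n"

    # Tier 1: full schema. Tier 2: column names. Tier 3: a notice only.
    if count_tokens(header + full) <= max_tokens:
        return "full", header + full
    if count_tokens(header + names_only) <= max_tokens:
        return "names", header + names_only
    return "skipped", header + "Schemas are skipped; look up columns on demand.\n"


# Hypothetical toy database, used only for this illustration.
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE vm (vmId INTEGER, tenantId INTEGER, vmTypeId INTEGER);")
cur.execute("CREATE TABLE vmType (vmTypeId INTEGER, core REAL);")

for budget in (100, 12, 5):
    tier, _ = describe_schema(cur, budget)
    print(budget, tier)
```

With a generous budget the full schema is returned; as the budget shrinks, the description degrades to column names and finally to a notice telling the model to look up columns itself.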


TODOs:
Add a flag that tracks which column schemas the agent has already seen.


Acknowledgement:
This notebook is inspired by LangChain's SQL Chain, Microsoft's DB CoPilot, and Microsoft's C&AI Team. We would also like to thank the Azure Allocator Brain team.

### Before everything starts, install AutoGen
```bash
pip install "pyautogen~=0.2.0b4"
```

In [1]:
import json
import os
import random
import re
import sqlite3
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import autogen
import requests
import tiktoken
from autogen import AssistantAgent, Agent, UserProxyAgent, ConversableAgent
from autogen.code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
from termcolor import colored


In [2]:
config_list_gpt4 = autogen.config_list_from_json(
    "OAI_CONFIG_LIST",
    filter_dict={
        "model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
    },
)

gpt4_llm_config = {"config_list": config_list_gpt4, "cache_seed": 42}



config_list_gpt35 = autogen.config_list_from_json(
    "OAI_CONFIG_LIST",
    filter_dict={
        "model": ["gpt-35-turbo", "gpt-3.5-turbo"],
    },
)

gpt35_llm_config = {"config_list": config_list_gpt35, "cache_seed": 42}


## Download SQL data for test

In [3]:
# !wget https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sqlite
# filename = "Chinook_Sqlite.sqlite"

In [4]:
import requests
import zipfile
import io

# URL of the file to be downloaded
url = "https://azurecloudpublicdataset2.blob.core.windows.net/azurepublicdatasetv2/azurevmallocation_dataset2020/AzurePackingTraceV1.zip"
folder_name = "vm"


if not os.path.exists(folder_name):
    # Send a GET request to the URL
    response = requests.get(url)

    # Check if the request was successful
    if response.status_code == 200:
        # Create a ZipFile object in memory using the content of the response
        with zipfile.ZipFile(io.BytesIO(response.content)) as z:
            # Extract all the contents into a desired directory
            z.extractall(folder_name)
        print("File downloaded and unzipped successfully.")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")
    
# filename = "Chinook_Sqlite.sqlite"
filename = os.path.join(folder_name, "packing_trace_zone_a_v1.sqlite")
assert os.path.exists(filename)


## Define the SQLAgent

This is a naive version, but it works reasonably well.
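
The core mechanism of the agent is to run model-written Python against a shared `cursor` and read the result back from a variable named `rst`. The sketch below illustrates that convention on an in-memory database. `run_snippet` and the toy `vm` table are hypothetical names for this example. Note that `exec` runs arbitrary code, so this pattern is only suitable for trusted or sandboxed input.

```python
import sqlite3


def run_snippet(code, cursor):
    """Execute a code string with `cursor` in scope and return (exitcode, result)."""
    # Namespace visible to the snippet; only `cursor` is exposed here.
    namespace = {"cursor": cursor}
    try:
        exec(code, namespace)
    except Exception as e:
        # Failures are reported as text so they can be sent back to the model.
        return 1, f"Error: {e}"
    # By convention, the snippet stores its result in `rst`.
    return 0, namespace.get("rst")


# Hypothetical toy table, used only for this illustration.
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE vm (vmId INTEGER, priority INTEGER);")
cur.executemany("INSERT INTO vm VALUES (?, ?);", [(1, 0), (2, 1), (3, 1)])

ok = run_snippet('rst = cursor.execute("SELECT COUNT(*) FROM vm;").fetchone()', cur)
bad = run_snippet('rst = cursor.execute("SELECT * FROM missing;").fetchall()', cur)
print(ok)
print(bad)
```

Returning the error text instead of raising lets the conversation continue, so the model can read the message and correct its query in the next turn.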

In [5]:
# Function to list all tables in the database
def list_tables(cursor):
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    ans = []
    for table in cursor.fetchall():
        ans.append(table[0])
    return ans
    
# Function to show the schema of a table
def show_table_schema(cursor, table_name):
    cursor.execute(f"PRAGMA table_info({table_name});")
    
    ans = f"Here is the table information for Table: {table_name}\n"
    ans += "-" * 30 + "\n"
    ans += "Column ID | Column Name | Data Type | Not Null Constraint | Default Val | Primary Key | \n"
    ans += "-" * 30 + "\n"
    for row in cursor.fetchall():
        ans += " | ".join([str(x) for x in row]) + "\n"
    ans += "-" * 30 + "\n\n\n"
    return ans

# Function to execute a query
def execute_query(cursor, query):
    cursor.execute(query)
    return cursor.fetchall()


def find_potential_joins(cursor):
    # Query to get foreign key relationships
    cursor.execute("SELECT tbl_name, sql FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()

    joins = []
    for table, create_statement in tables:
        cursor.execute(f"PRAGMA foreign_key_list({table})")
        foreign_keys = cursor.fetchall()
        for fk in foreign_keys:
            # Each item in foreign_keys is a tuple like:
            # (id, seq, table, from, to, on_update, on_delete, match)
            from_table = table
            from_column = fk[3]
            to_table = fk[2]
            to_column = fk[4]
            join_info = f"{from_table}.{from_column} -> {to_table}.{to_column}"
            joins.append(join_info)

    return joins

def column_names(cursor, table_name):
    """
    Retrieves the column names for a specified table.
    
    Args:
    cursor (sqlite3.Cursor): The cursor object connected to the database.
    table_name (str): The name of the table for which to retrieve column names.

    Returns:
    list: A list of column names for the specified table.
    """
    query = f"PRAGMA table_info({table_name});"
    cursor.execute(query)
    # Fetches all rows from the query result, extracts the second column which contains column names
    return [row[1] for row in cursor.fetchall()]




In [6]:
def table_prompt(cursor, question="", model="gpt-4", max_tokens=10e9):
    db_info = ""
    if question:
        db_info = f"## My question\n{question}\n\n\n"
        
    db_info += "## Table information\n"
    tables = list_tables(cursor)
    db_info += f"There are {len(tables)} tables in the database, which are:"
    for t in tables:
        db_info += "\n- " + t
        
    joins = find_potential_joins(cursor)
    if joins:
        db_info += "\n\n## Join information\n"
        db_info += "Here are the columns (and their tables) that could be potentially joined:"
        for j in joins:
            db_info += "\n- " + j
        db_info += "\n\n"    
        
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    n_tokens = lambda x: len(encoding.encode(x))
    if max_tokens is None:
        max_tokens = 10e9
        
    assert n_tokens(db_info) < max_tokens
    
    #### First, try full schema info
    schema_info = "\n\n## Schema\n"
    for table in tables:
        table_info = show_table_schema(cursor, table)
        schema_info += table_info

    if n_tokens(db_info + schema_info) < max_tokens:
        return db_info + schema_info
        
    #### if not possible, try column names
    column_info = "\n\n## Columns\n"
    for table in tables:
        column_info += f"{table}: {column_names(cursor, table)}\n"

    if n_tokens(db_info + column_info) < max_tokens:
        return db_info + column_info
    
    alert = "\n\n### Schemas are Skipped for some tables.\n__Note__: "
    alert += "You SHOULD use the `column_names(cursor, table_name)` function to show columns of a table."
    
    return db_info + alert
    
# print(table_prompt(cursor, max_tokens=100))

In [7]:

def execute_code_with_cursor(code, cursor, **kwargs) -> Union[Dict, str]:
    # Names made available to the executed code.
    locals_dict = {"cursor": cursor, "show_table_schema": show_table_schema, "column_names": column_names}

    if 'lang' in kwargs and kwargs['lang'].lower() == "sql":
        print(colored("WARNING: raw sql code is received for execution. We allow it for now, but the LLM should be improved.", "yellow"))
        # Quote the raw SQL so that it reaches cursor.execute as a string literal.
        code = f"rst = cursor.execute({code!r}).fetchall()"

    try:
        exec(code, locals_dict, locals_dict)
    except Exception as e:
        return "Error: " + str(e)
    if "rst" in locals_dict:
        return {"rst": locals_dict["rst"]}
    known_keys = list(locals().keys()) + list(globals().keys())
    return {k: v for k, v in locals_dict.items() if k not in known_keys}
    
# execute_code_with_cursor("x = print(cursor)", cursor)

In [8]:
DEFAULT_SQL_CODE_CONFIG = {
                "last_n_messages": 3,
                "work_dir": ".",
                "use_docker": False
            }

IS_TERMINATE_MSG = lambda x: content_str(x.get("content")).find("ANSWER") >= 0
# or content_str(x.get("content")).strip() == ""

class SQLAgent(ConversableAgent):
    
    def __init__(self, sql_file, name, max_sys_msg_tokens, llm_config: dict,
                 is_termination_msg=IS_TERMINATE_MSG,
                 **kwargs):
        # Pass the caller's termination check through instead of always using the default.
        super().__init__(name, llm_config=llm_config, is_termination_msg=is_termination_msg,
                         max_consecutive_auto_reply=0,
                         **kwargs)
        self.sql_file = sql_file
        connection = sqlite3.connect(sql_file)
        self.cursor = connection.cursor()
        
        self.assistant = AssistantAgent(name="sql llm", system_message="""You can interact with the SQLite3 database and answer questions.
        
There are two types of answers you can provide: code and answer.

## Code
In the code format, you must write code directly in code block to answer questions. For instance,
```python
# Python code by using the `cursor` variable.
```
The cursor variable is provided directly to you. 
Save the result to the "rst" variable so that you can see the output.
Note that you are interacting with a Python environment, and it can only reply you if you write code.


## Answer
If you gather enough information (e.g., from coding or prior knowledge) to answer the user question, then you should reply with your final answer in the format:
ANSWER: your answer goes here.
""" + table_prompt(self.cursor, question="", model=llm_config["config_list"][0]["model"], max_tokens=max_sys_msg_tokens), 
                                        llm_config=llm_config,
                                        is_termination_msg=is_termination_msg,
                                       max_consecutive_auto_reply=10)
        
        self.proxy = SQLProxy(sql_file=sql_file, name="sql proxy", 
                              human_input_mode="NEVER",
                             is_termination_msg=is_termination_msg)
        self.register_reply([Agent, None], SQLAgent.generate_sql_reply)
        self.synth_history()
        
    def synth_history(self):
        # we inject some synthetic Q&A with the desired format, so that the LLM can learn the pattern in context.
        tables = list_tables(self.cursor)
        n_tables = len(tables)
        query = f"SELECT COUNT(*) FROM {tables[0]};"
        n_rows_table0 = self.cursor.execute(query).fetchall()[0][0]        
        table0_info = show_table_schema(self.cursor, tables[0])
        self.table_memory = {table_name: show_table_schema(self.cursor, table_name) for table_name in tables}
        
        prior = [{'content': 'QUESTION: How many tables are in the database?',
'role': 'user'},
{'content': 'We can count the number of tables in the database.\n```python\ncursor.execute("SELECT COUNT(name) FROM sqlite_master WHERE type=\'table\';")\nrst = cursor.fetchone()\nrst\n```',
'role': 'assistant'},
{'content': "exitcode: 0 (execution succeeded)\nCode output: \n{'rst': (" + str(n_tables) + ",)}",
'role': 'user'},
{'content': f'ANSWER: The database contains {n_tables} tables. If you need any more information, feel free to ask!',
'role': 'assistant'},
 {'content': f'QUESTION: How many rows are in the {tables[0].lower()} table?',
'role': 'user'},
{'content': f'We can run the following code\n```python\nquery = "SELECT COUNT(*) FROM {tables[0]};"\nrst = cursor.execute(query).fetchall()\n```',
'role': 'assistant'},
{'content': "exitcode: 0 (execution succeeded)\nCode output: \n{'rst': [(" + str(n_rows_table0) + ",)]}",
'role': 'user'},
{'content': f'ANSWER: There are {n_rows_table0} rows in the {tables[0]} table.',
'role': 'assistant'},
{'content': f'QUESTION: What columns are in the {tables[0].lower()} table?',
'role': 'user'},
{'content': f'We can run the function\n```python\nrst = show_table_schema(cursor, "{tables[0]}")\n```',
'role': 'assistant'},
{'content': "exitcode: 0 (execution succeeded)\nCode output: \n{'rst': " + str(table0_info) + "}",
'role': 'user'},
{'content': f'ANSWER: The columns in the {tables[0]} table are: ' + ", ".join(column_names(self.cursor, tables[0])),
'role': 'assistant'},
 {'content': "QUESTION: Who are you?",
'role': 'user'},
{'content': "ANSWER: Sorry, I don't have enough information to answer your question, because it is not in the database.",
'role': 'assistant'},
]
        self.assistant._oai_messages[self.proxy] = prior
        
    def generate_sql_reply(self, messages: Optional[List[Dict]], sender: "Agent", config):
        """Generate a reply by forwarding the question to the inner proxy/assistant chat."""
        client = self.client if config is None else config
        if client is None:
            return False, None
        if messages is None:
            messages = self._oai_messages[sender]
        prompt = messages[-1]["content"]

        
        self.proxy.initiate_chat(self.assistant, message="QUESTION: " + prompt, clear_history=False)
        ans = content_str(self.assistant.last_message()["content"])
        ans = ans.replace("ANSWER:", "").strip()
        
        return True, ans
    
class SQLProxy(UserProxyAgent):
    def __init__(self, sql_file, name, 
                 code_execution_config: dict = DEFAULT_SQL_CODE_CONFIG, **kwargs):
        super().__init__(name,
                         code_execution_config=code_execution_config, **kwargs)
        self.sql_file = sql_file
        connection = sqlite3.connect(sql_file)
        self.cursor = connection.cursor()
        
    def run_code(self, code, **kwargs):
        """Run the code and return the result.

        Override this function to modify the way to run the code.
        Args:
            code (str): the code to be executed.
            **kwargs: other keyword arguments.

        Returns:
            A tuple of (exitcode, logs, image).
            exitcode (int): the exit code of the code execution.
            logs (str): the logs of the code execution.
            image (str or None): the docker image used for the code execution.
        """
        rst = execute_code_with_cursor(code, self.cursor, **kwargs)
        exitcode = 0 if str(rst).find("Error: ") != 0 else 1
        
        return exitcode, str(rst), None

    
sql_assistant = SQLAgent(sql_file=filename, name="SQL Helper", max_sys_msg_tokens=2000,
                         llm_config=gpt35_llm_config)
user = UserProxyAgent(name="user", human_input_mode="NEVER", max_consecutive_auto_reply=0)
user.initiate_chat(sql_assistant, message="Which VM has the longest runtime and what type is it?")


user (to SQL Helper):

Which VM has the longest runtime and what type is it?

--------------------------------------------------------------------------------
sql proxy (to sql llm):

QUESTION: Which VM has the longest runtime and what type is it?

--------------------------------------------------------------------------------
sql llm (to sql proxy):

We would need to find the VM with the longest runtime and its type. We can achieve this with the following SQL query:
```python
query = '''
SELECT v.vmId, v.starttime, v.endtime, t.vmTypeId
FROM vm AS v
JOIN vmType AS t ON v.vmTypeId = t.id
ORDER BY (v.endtime - v.starttime) DESC
LIMIT 1;
'''
rst = cursor.execute(query).fetchall()
```

--------------------------------------------------------------------------------
>>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...
sql proxy (to sql llm):

exitcode: 0 (execution succeeded)
Code output: 
{'rst': [(9543, -7157.0, 89.22275358811021

In [12]:
print(sql_assistant.assistant.system_message)

You can interact with the SQLite3 database and answer questions.
        
There are two types of answers you can provide: code and answer.

## Code
In the code format, you must write code directly in code block to answer questions. For instance,
```python
# Python code by using the `cursor` variable.
```
The cursor variable is provided directly to you. 
Save the result to the "rst" variable so that you can see the output.
Note that you are interacting with a Python environment, and it can only reply you if you write code.


## Answer
If you gather enough information (e.g., from coding or prior knowledge) to answer the user question, then you should reply with your final answer in the format:
ANSWER: your answer goes here.
## Table information
There are 2 tables in the database, which are:
- vm
- vmType

## Schema
Here is the table information for Table: vm
------------------------------
Column ID | Column Name | Data Type | Not Null Constraint | Default Val | Primary Key | 
----------

## Example with Limited Tokens

What if the context length is limited? In this example, we limit the system message to 100 tokens so that enough space is left for the chat. In this case, no table schemas are included in the system prompt, yet the agent can still complete the task.

For instance, to answer the question below, the agent needs to discover the "tenantId" and "core" columns, which live in different tables of the dataset.
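
When the schema is left out of the system prompt, the column names have to be discovered at query time. The following sketch shows that lookup followed by a join, under stated assumptions: the two toy tables, their rows, and the helper `lookup_columns` are hypothetical and only mirror the `tenantId` and `core` columns mentioned above.

```python
import sqlite3

# Hypothetical toy database; column names mirror the ones discussed above.
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE vm (vmId INTEGER, tenantId INTEGER, vmTypeId INTEGER);")
cur.execute("CREATE TABLE vmType (vmTypeId INTEGER, core INTEGER);")
cur.executemany("INSERT INTO vm VALUES (?, ?, ?);", [(1, 10, 1), (2, 10, 2), (3, 20, 2)])
cur.executemany("INSERT INTO vmType VALUES (?, ?);", [(1, 2), (2, 8)])


def lookup_columns(cursor, table_name):
    # PRAGMA table_info returns one row per column; index 1 holds the column name.
    cursor.execute(f"PRAGMA table_info({table_name});")
    return [row[1] for row in cursor.fetchall()]


# Step 1: discover the columns of each table on demand.
vm_cols = lookup_columns(cur, "vm")
type_cols = lookup_columns(cur, "vmType")
print(vm_cols, type_cols)

# Step 2: use the discovered names to join the tables and aggregate per tenant.
query = """
SELECT vm.tenantId, SUM(vmType.core) AS total_cores
FROM vm JOIN vmType ON vm.vmTypeId = vmType.vmTypeId
GROUP BY vm.tenantId
ORDER BY total_cores DESC
LIMIT 1;
"""
top = cur.execute(query).fetchone()
print(top)
```

The lookup costs an extra round trip, but it keeps the system prompt small regardless of how many tables the database contains.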

In [15]:
sql_assistant_short = SQLAgent(sql_file=filename, name="SQL Helper", max_sys_msg_tokens=100, llm_config=gpt35_llm_config)

print("Let's first show the system message of the assistant")
print(colored("Assistant's System Message:", "green"))
print(sql_assistant_short.assistant.system_message)
print("\n\n")

print(colored("History:", "green"))
print(sql_assistant_short.assistant._oai_messages)

Let's first show the system message of the assistant
Assistant's System Message:
You can interact with the SQLite3 database and answer questions.
        
There are two types of answers you can provide: code and answer.

## Code
In the code format, you must write code directly in code block to answer questions. For instance,
```python
# Python code by using the `cursor` variable.
```
The cursor variable is provided directly to you. 
Save the result to the "rst" variable so that you can see the output.
Note that you are interacting with a Python environment, and it can only reply you if you write code.


## Answer
If you gather enough information (e.g., from coding or prior knowledge) to answer the user question, then you should reply with your final answer in the format:
ANSWER: your answer goes here.
## Table information
There are 2 tables in the database, which are:
- vm
- vmType

### Schemas are Skipped for some tables.
__Note__: You SHOULD use the `column_names(cursor, ta

In [16]:
sql_assistant_short = SQLAgent(sql_file=filename, name="SQL Helper", max_sys_msg_tokens=100, llm_config=gpt4_llm_config)
user = UserProxyAgent(name="user", human_input_mode="NEVER", max_consecutive_auto_reply=0)
user.initiate_chat(sql_assistant_short, message="Which tenant used the most number of CPUs?")



user (to SQL Helper):

Which tenant used the most number of CPUs?

--------------------------------------------------------------------------------
sql proxy (to sql llm):

QUESTION: Which tenant used the most number of CPUs?

--------------------------------------------------------------------------------
sql llm (to sql proxy):

To determine which tenant used the most number of CPUs, we need to first join the `vm` table with the `vmType` table on the `vmTypeId` to get the CPU information for each VM. Then, we need to sum the number of CPUs used by each tenant and find the tenant with the maximum sum. Let's run the appropriate SQL query to find this out.
```python
query = """
SELECT vm.tenantId, SUM(vmType.vcpu) as total_cpus
FROM vm
JOIN vmType ON vm.vmTypeId = vmType.vmTypeId
GROUP BY vm.tenantId
ORDER BY total_cpus DESC
LIMIT 1;
"""
rst = cursor.execute(query).fetchone()
```

--------------------------------------------------------------------------------