# Querybot

## Goal

Type a query, get data as a response.

More specifically, by using an LLM, we wish to generate appropriate queries to a database, allowing users to retrieve records without any prior knowledge of the database's structure or the appropriate query language.

The workflow will be:
 1. Take the user input as a text string;
 2. Determine which table(s) contain the relevant data;
 3. Generate a SQL query to get the data from the table(s);
 4. Execute the query;
 5. Report back to the user.

In [1]:
def query_database(user_query:str) -> str:
    """
    Query the IMDB database using LLM-generated SQL.
    """
    tables = select_tables(user_query)
    sql = create_sql_query(user_query, tables)
    answer = execute_sql(sql)
    output = format_output(user_query, answer)

    return output

## Database

The [database](https://www.kaggle.com/datasets/priy998/imdbsqlitedataset) we will use is scraped from [IMDB](https://www.imdb.com).
It can be obtained by running the appropriate `download_imdb_data` script or from [Kaggle](https://www.kaggle.com/datasets/priy998/imdbsqlitedataset).

In [2]:
import os
import sqlite3

__DATABASE_PATH__ = "data/movie.sqlite"

In [3]:
def connect(db_path:str=__DATABASE_PATH__, raise_error:bool=True):
    """
    Connect to a sqlite3 database at `db_path` only if the file is found.
    """
    if os.path.exists(db_path):
        return sqlite3.connect(db_path)
    elif raise_error:
        raise FileNotFoundError(db_path)
    else:
        return None

In [4]:
def execute_sql(query:str, db_path:str=__DATABASE_PATH__) -> list:
    """
    Connect to the database at the indicated path and execute a SQL query.
    """
    connection = connect(db_path) # connect to DB
    try:
        cursor = connection.cursor() # generate a data selector
        cursor.execute(query) # run the SQL to select data
        results = cursor.fetchall() # load the data into memory
    finally:
        connection.close() # release the file handle even if the query fails
    return results

In [5]:
# make sure that accessing the database works
tables_found = set(execute_sql("SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name;"))
assert tables_found == {('directors',), ('movies',), ('sqlite_sequence',)}

There are two user tables in the database, "movies" and "directors" (`sqlite_sequence` is SQLite's internal bookkeeping table).
We can use SQL to access data from these tables, but even knowing the table names is something we want to avoid.
The AI does need the structure of the DB to be explained to it, though, so we should automatically extract the schema.

In [6]:
__SCHEMA_TXT__ = "data/schema.txt"

def extract_schema(db:str=__DATABASE_PATH__, target:str=__SCHEMA_TXT__) -> str:
    """
    Extract the table/column schema of a sqlite3 database to a text file.
    """
    connection = connect(db)
    cursor = connection.cursor()
    # read the table names from the DB metadata
    cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name;")
    tables = [x[0] for x in cursor.fetchall() if not x[0].startswith("sqlite_")]
    # build the schema up table by table, one line per table
    lines = []
    for table in tables:
        # get column names from table metadata
        cursor.execute(f"SELECT name FROM pragma_table_info('{table}');")
        col_str = ", ".join(x[0] for x in cursor.fetchall())
        # format the table schema; "\\n" writes a literal backslash-n so each table stays on one line
        lines.append(f"{table};Table: {table}\\nColumns: \\n{col_str}")
    connection.close()
    schema = "\n".join(lines)
    # write to file
    with open(target, "w") as f:
        f.write(schema)
    return schema

print(extract_schema())

directors;Table: directors\nColumns: \nname, id, gender, uid, department
movies;Table: movies\nColumns: \nid, original_title, budget, popularity, release_date, revenue, title, vote_average, vote_count, overview, tagline, uid, director_id


That's a little messy to my eye, but the LLM will have no problem understanding.
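
To make the format explicit: each line of the schema file is `<table name>;<description>`, and the description contains literal `\n` character pairs rather than real line breaks. As a rough illustration (the `parse_schema_line` helper is hypothetical and is not used elsewhere in this notebook), one line can be decoded like this:

```python
def parse_schema_line(line: str) -> tuple:
    """Split one schema line into (table_name, [column names]). Hypothetical helper."""
    name, description = line.strip().split(";", 1)
    # the description holds literal backslash-n sequences, not real newlines
    columns_part = description.split("\\n")[-1]
    columns = [c.strip() for c in columns_part.split(",")]
    return name, columns

# raw string: keeps the backslash-n pairs exactly as they appear in the file
sample = r"directors;Table: directors\nColumns: \nname, id, gender, uid, department"
table, columns = parse_schema_line(sample)
print(table, columns)
```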

## LLM

We'll use OpenAI's GPT-4o as our LLM.
This requires an [OpenAI API key](https://platform.openai.com/settings/organization/api-keys), which must be either saved to a local file or read from environment variables.

**NEVER** save this key in a public place, or commit it to *ANY* Git repo, *EVER*.
Anyone who holds the API key can run queries on your account, and you will be billed!

In [7]:
from openai import OpenAI

__OPENAI_API_KEY_PATH__=".openai"

if os.path.exists(__OPENAI_API_KEY_PATH__):
    # read key from local file and connect to the OpenAI API
    with open(__OPENAI_API_KEY_PATH__) as f:
        client = OpenAI(api_key=f.read().strip()) # strip the trailing newline
else:
    client = OpenAI() # backup option: read the OPENAI_API_KEY env var

In [8]:
__OPENAI_MODEL__ = "gpt-4o"

def call_model(query:str, system_message:str="", model:str=__OPENAI_MODEL__):
    """
    Run the query with optional `system_message` on the selected model.
    Returns the response message object; its text is in `.content`.
    """
    return client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": query}
        ]
    ).choices[0].message

In [9]:
__SYSTEM_PROMPT_FILES__ = {
    "select_tables" : "system/select_tables.txt",
    "create_sql_query" : "system/create_sql_query.txt"
}

In [10]:
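def select_tables(user_query:str, system_prompt:str=__SYSTEM_PROMPT_FILES__["select_tables"]) -> list:
    """
    Use the LLM to choose tables from the database.
    """
    with open(system_prompt, "r") as f:
        system_message = f.read()
    # call_model takes the user query first, then the system message
    response = call_model(user_query, system_message)
    # strip whitespace so the names match the schema keys exactly
    tables = [name.strip() for name in response.content.split(",")]
    return tables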
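
`select_tables` trusts the model to reply with a clean, comma-separated list of table names. A possible safeguard, sketched here under the assumption that the system prompt asks for exactly that format, is to normalise the reply and drop any name that is not a real table. `clean_table_list` is a hypothetical helper, not part of the pipeline above.

```python
def clean_table_list(reply: str, known_tables: set) -> list:
    """Normalise a comma-separated reply and keep only known table names."""
    # trim whitespace, stray quotes or backticks, and ignore letter case
    names = [n.strip().strip("`'\"").lower() for n in reply.split(",")]
    return [n for n in names if n in known_tables]

# made-up reply containing untidy formatting and one table that does not exist
cleaned = clean_table_list(" Movies, `directors`\n, actors", {"movies", "directors"})
print(cleaned)
```

The set of known tables could be taken from the first field of each line in the schema file.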

In [11]:
def create_sql_query(user_query:str, tables:list, schema_file:str=__SCHEMA_TXT__):
    """
    Produce a SQL query based on the `user_query` and `schema_file`.
    The list of `tables` is used to confine the SQL's scope.
    """
    # parse tables metadata from schema
    tables_data = ""
    with open(schema_file, "r") as f:
        for scheme in f:
            data = scheme.split(";")
            if data[0].strip() in tables:
                tables_data += data[1] + '\n'
    # generate the system prompt
    with open(__SYSTEM_PROMPT_FILES__["create_sql_query"], "r") as f:
        system_message = f.read()
    system_message = system_message.replace("{tables}", tables_data)
    # generate the SQL (user query first, then the system message)
    content = call_model(user_query, system_message).content
    # strip markdown fences and surrounding whitespace
    content = content.replace("```sql", "").replace("```", "").strip()
    return content

In [12]:
test_query = "List the most popular movies from the years 2000-2006."
response = query_database(test_query)
print(response)

OperationalError: no such column: release_year
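
The error shows that the generated SQL referenced a `release_year` column, while the extracted schema lists only `release_date`. For reference, here is a minimal sketch of a query that derives the year from `release_date` instead. It assumes the dates are stored as ISO `YYYY-MM-DD` text (not verified against the actual file) and uses made-up rows in an in-memory table.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE movies (title TEXT, popularity REAL, release_date TEXT)")
# made-up rows, for illustration only
conn.executemany(
    "INSERT INTO movies VALUES (?, ?, ?)",
    [
        ("Old Film", 9.0, "1995-06-01"),
        ("Mid Film", 5.0, "2003-02-14"),
        ("Hit Film", 8.0, "2006-11-30"),
        ("New Film", 7.0, "2010-01-01"),
    ],
)
# strftime('%Y', ...) extracts the year from an ISO date string
query = """
    SELECT title, CAST(strftime('%Y', release_date) AS INTEGER) AS release_year
    FROM movies
    WHERE CAST(strftime('%Y', release_date) AS INTEGER) BETWEEN 2000 AND 2006
    ORDER BY popularity DESC;
"""
rows = conn.execute(query).fetchall()
conn.close()
print(rows)
```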