This notebook demonstrates a basic SQL agent that translates natural-language questions into SQL queries, runs them against a SQLite database, and uses the query results to answer the question.

In [1]:
import os
from dotenv import load_dotenv
from typing import Annotated
from openai import AzureOpenAI
import sqlite3
from typing import Any, List
import pandas as pd

# Load variables from a .env file, then read the Azure OpenAI settings from
# the process environment. Any variable that is not set resolves to None.
load_dotenv()
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION")
AZURE_OPENAI_EMBEDDINGS_ADA_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_EMBEDDINGS_ADA_DEPLOYMENT_NAME")
AZURE_OPENAI_GPT4_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_GPT4_DEPLOYMENT_NAME")


In [2]:
# Shared Azure OpenAI client used by call_openAI for every chat request.
llm = AzureOpenAI(
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_key=AZURE_OPENAI_API_KEY,
    api_version=AZURE_OPENAI_API_VERSION,
)

def call_openAI(user_prompt):
    """Send one system + user exchange to the chat model and return its reply.

    Args:
        user_prompt: Text placed in the "user" message of the request.

    Returns:
        The ``content`` of the first choice in the completion response.
    """
    conversation = [
        {"role": "system", "content": "You are an assistant designed to answer questions."},
        {"role": "user", "content": user_prompt},
    ]
    completion = llm.chat.completions.create(
        model=AZURE_OPENAI_GPT4_DEPLOYMENT_NAME,
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


In [3]:
def get_conn(db_file: str) -> sqlite3.Connection:
    """Open and return a SQLite connection for the database at ``db_file``."""
    connection = sqlite3.connect(db_file)
    return connection

def execute_query(conn: sqlite3.Connection, query: str) -> List[Any]:
    """Run ``query`` on ``conn`` and return all result rows as a list of tuples."""
    rows = conn.execute(query).fetchall()
    return rows


def execute_query_pd(conn: sqlite3.Connection, query: str) -> pd.DataFrame:
    """Run ``query`` on ``conn`` and return the result set as a DataFrame."""
    frame = pd.read_sql_query(sql=query, con=conn)
    return frame

# Get a description of a table into a pandas dataframe
def get_table_schema(conn: sqlite3.Connection, table_name: str) -> pd.DataFrame:
    """Return the column metadata of ``table_name`` as a DataFrame.

    The table name is passed as a bound parameter to SQLite's
    ``pragma_table_info`` table-valued function instead of being spliced into
    a ``PRAGMA`` statement. This makes names containing spaces, punctuation or
    reserved words work, and the name can no longer inject SQL text.
    (The table-valued PRAGMA functions require SQLite 3.16 or newer.)

    Args:
        conn: Open SQLite connection to inspect.
        table_name: Exact table name, without surrounding quotes.

    Returns:
        One row per column with the fields ``cid``, ``name``, ``type``,
        ``notnull``, ``dflt_value`` and ``pk``.
    """
    query = "SELECT * FROM pragma_table_info(?);"
    return pd.read_sql_query(query, conn, params=(table_name,))

def get_db_creation_sql(conn: sqlite3.Connection) -> str:
    """Describe the database schema as the CREATE statements of its tables.

    Reads the ``sql`` column of every ``type='table'`` row in
    ``sqlite_master`` and joins the statements with newlines, giving a
    schema description suitable for inclusion in an LLM prompt.
    """
    table_rows = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type='table'"
    ).fetchall()
    return '\n'.join(row[0] for row in table_rows)


def get_tables_from_db(conn: sqlite3.Connection) -> List[str]:
    """List the names of all ``type='table'`` entries in ``sqlite_master``."""
    table_names = []
    for (table_name,) in conn.execute("SELECT name FROM sqlite_master WHERE type='table';"):
        table_names.append(table_name)
    return table_names


In [10]:
# Open the bookstore database and capture its CREATE TABLE statements.
# `conn` and `schema` are notebook-level globals that respond() reads directly.
database = 'bookstore.db'
conn = sqlite3.connect(database)
# NOTE(review): TABLE_NAME is not referenced by any other cell shown here.
TABLE_NAME = 'Books'
schema = get_db_creation_sql(conn)
# print() (rather than a bare expression) so the newlines in the schema render.
print(schema)

CREATE TABLE Authors (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            biography TEXT
        )
CREATE TABLE sqlite_sequence(name,seq)
CREATE TABLE Publishers (
            d INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            address TEXT
        )
CREATE TABLE Books (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            price REAL,
            author_id INTEGER NOT NULL, FOREIGN KEY (author_id) REFERENCES Authors(id)
        )


In [5]:
from typing import List, Optional, Tuple
import pandas as pd
import json

# JSON answer template appended to the prompts: the model is asked to fill in
# an explanation and a fenced SQL query.
EXP_QUERY_TEMPLATE = json.dumps(
    dict(
        explanation='[explain what the query does]',
        query='```[sql query]```',
    ),
    indent=4,
)

def get_error_prompt(schema: str,question: str,query: str,sql_error_message: str) -> str:
    """Build the prompt that asks the model to repair a failed query.

    Args:
        schema: Schema description placed at the top of the prompt.
        question: The user's original question.
        query: The SQL query that failed.
        sql_error_message: The error raised when the query was executed.

    Returns:
        The table prompt, a fix-the-query instruction, the failure details
        serialised as JSON, and the answer template, concatenated in that order.
    """
    failure_details = json.dumps(
        {
            'Question': question,
            'Query': query,
            'SQL Error': str(sql_error_message),
        }
    )
    sections = [
        get_table_prompt(schema),
        'Given the following question, query, and sql error, fix the query.',
        failure_details,
        '\nReturn your answer by filling out the following template:\n',
        EXP_QUERY_TEMPLATE,
    ]
    return ''.join(sections)

# Extracts the explanation and SQL query out of the response from the model
def extract_query_from_response(response: str) -> Tuple[str, Optional[str]]:
    """Pull the explanation and SQL query out of a model reply.

    The reply is expected to contain a JSON object with ``explanation`` and
    ``query`` keys, possibly surrounded by prose or a Markdown code fence.
    Compared with naive brace slicing, this parser:

    * decodes the object with ``json.JSONDecoder.raw_decode``, so braces
      inside the explanation and text after the object do not break parsing;
    * uses ``strict=False`` so raw newlines inside the JSON strings are kept
      as whitespace instead of being deleted (deleting them glues SQL tokens
      together, e.g. ``name\\nFROM`` -> ``nameFROM``);
    * strips the code-fence markers and an optional ``sql``/``sqlite`` tag
      case-insensitively, so "```SQL SELECT ..." no longer leaves a stray
      ``SQL`` prefix in the query;
    * returns ``(response, None)`` instead of raising when no JSON object
      with a string ``query`` can be decoded.

    Args:
        response: Raw text returned by the model.

    Returns:
        ``(explanation, query)``. ``query`` is ``None`` when no usable query
        was found; in that case ``explanation`` is the raw response, unless a
        JSON object was decoded whose query was empty, in which case it is
        that object's explanation.
    """
    import re  # local import: `re` is not imported at notebook level

    if not isinstance(response, str):
        # Defensive: a missing reply (e.g. None) is treated as "no query".
        return '', None

    decoder = json.JSONDecoder(strict=False)
    brace = response.find('{')
    while brace != -1:
        try:
            payload, _ = decoder.raw_decode(response, brace)
        except ValueError:
            # Not valid JSON from this brace; try the next candidate.
            payload = None
        if isinstance(payload, dict) and isinstance(payload.get('query'), str):
            query = re.sub(
                r'```(?:sqlite|sql)?', '', payload['query'], flags=re.IGNORECASE
            ).strip()
            explanation = str(payload.get('explanation', ''))
            # An empty query is not executable; report it as "no query".
            return explanation, query or None
        brace = response.find('{', brace + 1)

    # No JSON returned, so probably just an explanation for why it can't complete it
    return response, None

def get_table_prompt(schema: str) -> str:
    """Build the prompt preamble: intro sentence, the schema, and a divider line."""
    intro = 'Below is the information for an SQLite table.'
    schema_section = f'Schema:\n{schema}\n\n'
    divider = '\n------------------------------------------------------\n'
    return intro + schema_section + divider

def get_user_agent_prompt(schema: str, question: str) -> str:
    """Build the prompt asking the model to write SQL for the user's question.

    Args:
        schema: Schema description placed at the top of the prompt.
        question: The user's natural-language question.

    Returns:
        The table prompt, the instructions, the question, and the answer
        template, concatenated in that order.
    """
    sections = [
        get_table_prompt(schema),
        'Below is a question input from a user. ',
        'Generate an SQL query that pulls the necessary data to answer the question.\n\n',
        f'Question: {question}\n\n',
        'Return your answer by filling out the following template:\n',
        EXP_QUERY_TEMPLATE,
    ]
    return ''.join(sections)

def get_question_result(question: str, query: str, explanation: str, result: str) -> str:
    """Serialise the question, query, explanation and result as indented JSON."""
    payload = dict(
        question=question,
        query=query,
        explanation=explanation,
        result=result,
    )
    return json.dumps(payload, indent=4)

def get_db_agent_prompt(schema: str,question: str,query: str,explanation: str,result: pd.DataFrame) -> str:
    """Build the prompt asking the model to answer the question from query output.

    Args:
        schema: Schema description placed at the top of the prompt.
        question: The user's original question.
        query: The SQL query that was executed.
        explanation: The model's explanation of that query.
        result: Query output; rendered into the prompt via ``to_string()``.

    Returns:
        The table prompt, the instructions, and the JSON summary of the
        question, query, explanation and result.
    """
    instructions = (
        'Below is a question, SQL query, explanation, and the result from executing the query. '
        'Use these pieces of information to answer the question.\n\n'
    )
    rendered_result = result.to_string()
    summary = get_question_result(question, query, explanation, rendered_result)
    return get_table_prompt(schema) + instructions + summary


def respond(question: str, chat_history: List[Tuple[str, str]], max_attempts: int = 5) -> Tuple:
    """Answer ``question`` by generating, executing and summarising an SQL query.

    Reads the notebook-level ``schema`` and ``conn`` globals. The shared
    connection is deliberately NOT closed here, even when the model declines
    to produce a query: closing it would make every later ``respond()`` call
    fail on a closed database. The notebook closes ``conn`` in its final cell.

    Args:
        question: The user's natural-language question.
        chat_history: List of ``(question, answer)`` pairs; the new pair is
            appended in place when the query succeeds.
        max_attempts: Maximum number of times a query is executed. After each
            failure the model is asked to repair the query. Defaults to 5.

    Returns:
        A 5-tuple ``('', chat_history, query_result, query, explanation)``:

        * success: ``query_result`` is the DataFrame returned by the query;
        * the model returned no query: ``query_result`` is ``None`` and
          ``query`` is ``''``;
        * every attempt failed: ``query_result`` is ``''`` and ``query`` is
          the model's last (unexecuted) repair.

        The first element is always the empty string.
    """
    user_prompt = get_user_agent_prompt(schema, question)
    explanation, query = extract_query_from_response(call_openAI(user_prompt))

    if query is None:
        # The model explained why it cannot answer; nothing to execute.
        return '', chat_history, None, '', explanation

    # NOTE(review): the model-generated SQL is executed as-is on the shared
    # connection; consider a read-only connection if writes must be prevented.
    success = False
    for _ in range(max_attempts):
        try:
            query_result = pd.read_sql_query(query, conn)
            success = True
            break
        except Exception as sql_error:
            # Feed the failure back to the model and retry with its repair.
            repair_prompt = get_error_prompt(schema, question, query, str(sql_error))
            explanation, query = extract_query_from_response(call_openAI(repair_prompt))
            if query is None:
                return '', chat_history, None, '', explanation

    if not success:
        return '', chat_history, '', query, explanation

    answer_prompt = get_db_agent_prompt(schema, question, query, explanation, query_result)
    chat_history.append((question, call_openAI(answer_prompt)))
    return '', chat_history, query_result, query, explanation

In [6]:
# Single-table question: the recorded run answers it with a COUNT(*) over Books.
respond("How many books are in the bookstore?", [])

('',
 [('How many books are in the bookstore?',
   'Based on the provided SQL query and its result, there are 2 books in the bookstore.')],
    total_books
 0            2,
 'SELECT COUNT(*) AS total_books FROM Books;',
 "The SQL query counts the total number of rows in the 'Books' table, which corresponds to the total number of books available in the bookstore.")

In [7]:
# Question that needs a join: the recorded run joins Books to Authors and filters by author name.
respond("How many books by Harper Lee are in the bookstore?", [])

('',
 [('How many books by Harper Lee are in the bookstore?',
   'The result of the query indicates that there is 1 book by Harper Lee in the bookstore.')],
    COUNT(*)
 0         1,
 "SELECT COUNT(*) FROM Books JOIN Authors ON Books.author_id = Authors.id WHERE Authors.name = 'Harper Lee';",
 "The query attempts to count the number of books written by Harper Lee by joining the Books table with the Authors table. The join is performed based on the author_id field in the Books table and the id field in the Authors table. The WHERE clause filters the results to only include books where the author's name is 'Harper Lee'. The error occurred due to an extra 'SQL' prefix in the query, which is not valid SQL syntax.")

In [8]:
# Release the shared SQLite connection opened in the database setup cell.
conn.close()