This notebook demonstrates a basic SQL agent that translates natural language questions into SQL queries.

In [1]:
import os
from dotenv import load_dotenv
from typing import Annotated
from openai import AzureOpenAI
import sqlite3
from typing import Any, List
import pandas as pd

load_dotenv()

AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_KEY")
AZURE_OPENAI_GPT4oMINI_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_GPT4oMINI_DEPLOYMENT")
#AZURE_OPENAI_EMBEDDINGS_ADA_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_ADA_EMBEDDING_DEPLOYMENT")
#AZURE_OPENAI_EMBEDDINGS_EB3_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_EB3_EMBEDDING_DEPLOYMENT")
AZURE_OPENAI_API_VERSION=os.getenv("AZURE_OPENAI_API_VERSION")

print(f"Models: {AZURE_OPENAI_GPT4oMINI_DEPLOYMENT_NAME}; API Version:{AZURE_OPENAI_API_VERSION}")
print("Azure OpenAI Model is ready to use!")


Models: gpt-4o-mini; API Version:2024-10-21
Azure OpenAI Model is ready to use!


In [2]:
llm = AzureOpenAI(
        azure_endpoint=AZURE_OPENAI_ENDPOINT,
        api_key=AZURE_OPENAI_API_KEY,
        api_version=AZURE_OPENAI_API_VERSION
)

def call_openAI(user_prompt, use_json_object=True):
    system_message = """You are an assistant designed to answer questions."""
    
    if use_json_object:
        response = llm.chat.completions.create(
            model=AZURE_OPENAI_GPT4oMINI_DEPLOYMENT_NAME,
            messages = [
                {"role":"system","content":system_message},
                {"role":"user","content":user_prompt}
                ],
            response_format={ "type": "json_object" }
        )
        return response.choices[0].message.content
    else:
        response = llm.chat.completions.create(
            model=AZURE_OPENAI_GPT4oMINI_DEPLOYMENT_NAME,
            messages = [
                {"role":"system","content":system_message},
                {"role":"user","content":user_prompt}
                ]
        )
    return response.choices[0].message.content


In [3]:
def get_conn(db_file: str) -> sqlite3.Connection:
    return sqlite3.connect(db_file)

def execute_query(conn: sqlite3.Connection, query: str) -> List[Any]:
    cursor = conn.cursor()
    return cursor.execute(query).fetchall()

def execute_query_pd(conn: sqlite3.Connection, query: str) -> pd.DataFrame:
    return pd.read_sql_query(query, conn)

# Get a description of a table into a pandas dataframe
def get_table_schema(conn: sqlite3.Connection, table_name: str) -> pd.DataFrame:
    query = f"PRAGMA table_info({table_name});"
    return pd.read_sql_query(query, conn)

#Construct a description of the DB schema for the LLM by retrieving the
# CREATE commands used to create the tables
def get_db_creation_sql(conn: sqlite3.Connection) -> str:
    cursor = conn.cursor()
    query = "SELECT sql FROM sqlite_master WHERE type='table'"
    results = cursor.execute(query).fetchall()
    results = [col[0] for col in results]
    schema_description = '\n'.join(results)
    return schema_description


def get_tables_from_db(conn: sqlite3.Connection) -> List[str]:
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    return [x[0] for x in cursor.fetchall()]


In [4]:
database = 'bookstore.db'
conn = sqlite3.connect(database)
TABLE_NAME = 'Books'
schema = get_db_creation_sql(conn)
print(schema)

CREATE TABLE Authors (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            biography TEXT
        )
CREATE TABLE sqlite_sequence(name,seq)
CREATE TABLE Publishers (
            d INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            address TEXT
        )
CREATE TABLE Books (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT NOT NULL,
            price REAL,
            author_id INTEGER NOT NULL, FOREIGN KEY (author_id) REFERENCES Authors(id)
        )


In [5]:
from string import Template

def load_template(filename: str) -> Template:
    file = open(filename, "r")
    content = file.read()
    file.close()
    template = Template(content)
    return template

In [6]:
from typing import List, Tuple
import pandas as pd
import json

def get_correction_prompt(schema: str,question: str,query: str,sql_error_message: str) -> str:
    template_file_name = "./prompts/correction_prompt.txt"
    template = load_template(template_file_name)
    prompt = template.substitute(schema_prompt=get_schema_prompt(),question = question, query = query, sql_error_message = sql_error_message)
    return prompt

def extract_json(response: str):
    json_response = None
    if '{' in response and '}' in response:
        response = response.replace('\n', '')
        json_response = json.loads(response)
    return json_response["explanation"], json_response["query"]
    
def get_schema_prompt() -> str:
    tbl_creation_sql = get_db_creation_sql(conn)
    template_file_name = "./prompts/schema_prompt.txt"
    template = load_template(template_file_name)
    prompt = template.substitute(tbl_creation_sql=tbl_creation_sql)
    return prompt

def get_prompt(question: str) -> str:
    schema_prompt = get_schema_prompt()
    template_file_name = "./prompts/user_prompt.txt"
    template = load_template(template_file_name)
    prompt = template.substitute(schema_prompt=schema_prompt, question=question)
    return prompt

def get_final_answer_prompt(question: str,query: str,explanation: str,result: pd.DataFrame) -> str:
    schema_prompt = get_schema_prompt()
    template_file_name = "./prompts/final_answer_prompt.txt"
    template = load_template(template_file_name)
    prompt = template.substitute(schema_prompt=schema_prompt, question=question, query=query, explanation=explanation, result=result.to_string())
    return prompt

def respond(question: str, chat_history: List[Tuple[str, str]]) -> Tuple:
    user_prompt = get_prompt(question)
    ua_response = call_openAI(user_prompt, True)
    explanation, query = extract_json(ua_response)

    if query is None:
        return '', chat_history, None, '', explanation

    success = False
    for _ in range(5):
        try:
            query_result = execute_query_pd(conn, query)
            success = True
            break
        except Exception as sql_error_message:
            sql_error_prompt = get_correction_prompt(question, query, str(sql_error_message))
            response = call_openAI(sql_error_prompt, True)
            explanation, query = extract_json(response)
            if query is None:
                return '', chat_history, None, '', explanation

    if success:
        final_answer_prompt = get_final_answer_prompt(question, query, explanation, query_result)
        chat_response = call_openAI(final_answer_prompt, False)
        chat_history.append((question, chat_response))
    else:
        query_result = ''

    return '', chat_history, query_result, query, explanation

In [7]:
respond("How many books are in the bookstore?", [])

('',
 [('How many books are in the bookstore?',
   'The question asks how many books are in the bookstore. The provided SQL query counts the total number of entries in the `Books` table by using the `COUNT(*)` function. \n\nFrom the execution of the query, the result indicates that there are a total of 24 books in the bookstore. \n\nSo, the answer is:\n\n**There are 24 books in the bookstore.**')],
    total_books
 0           24,
 'SELECT COUNT(*) AS total_books FROM Books;',
 'This query counts the total number of books in the Books table using the COUNT function.')

In [8]:
respond("How many books by Harper Lee are in the bookstore?", [])

('',
 [('How many books by Harper Lee are in the bookstore?',
   'The provided information indicates that there are 2 books by Harper Lee in the bookstore. The SQL query counts the number of books associated with Harper Lee by joining the Books and Authors tables and filtering for that specific author in the WHERE clause. The result shows that the count of those books is 2.')],
    book_count
 0           2,
 "SELECT COUNT(*) AS book_count FROM Books JOIN Authors ON Books.author_id = Authors.id WHERE Authors.name = 'Harper Lee';",
 "This query counts the number of books written by Harper Lee by joining the Books table with the Authors table on the author_id. It uses a WHERE clause to filter for the author with the name 'Harper Lee' and returns the count of books associated with that author.")

In [9]:
conn.close()