In [1]:
import os
import sqlite3
from dotenv import dotenv_values
from langchain_openai import ChatOpenAI

# Read key/value pairs from the local env file into a plain dict.
# The path is relative, so it depends on the directory the kernel was started in.
env_file_values = dotenv_values("../configs/local.env")
config = dict(env_file_values)

In [2]:
# Export the OpenAI key so langchain_openai can pick it up. Prefer the value from the
# env file; fall back to a key already present in the process environment. Fail with a
# descriptive error instead of a bare KeyError (missing key) or TypeError (a key line
# with no value, which dotenv_values reports as None).
openai_api_key = config.get("OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    raise KeyError(
        "OPENAI_API_KEY not found in ../configs/local.env or in the process environment"
    )
os.environ["OPENAI_API_KEY"] = openai_api_key

In [3]:
# Chat model used both to write SQL and to phrase the final answer.
# temperature=0 makes SQL generation far more repeatable: with the default sampling
# temperature the same question produced different queries on different runs
# (different table and column names), which makes the notebook's outputs irreproducible.
SQL_LLM_MODEL = "gpt-3.5-turbo-0125"
llm = ChatOpenAI(model=SQL_LLM_MODEL, temperature=0)

In [14]:
from operator import itemgetter
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.utilities import SQLDatabase

from sql_table_qa.answerers.langchain_answerer.langchain_sql_connector import execute_sql
from CONSTANTS import ROOT_DIR

# Executor for generated SQL: the project's own connector. The alias is kept because
# later cells refer to `execute_query`.
# NOTE(review): execute_sql is project code; confirm it runs against the same
# Chinook.db that `db` describes, otherwise the schema the LLM sees and the database the
# query runs on can disagree. It appears to return SQL errors as text rather than raise
# (a 'no such table' result was observed), which the answer step then paraphrases.
execute_query = execute_sql

# Schema source for the query writer: create_sql_query_chain reads table info from `db`.
db = SQLDatabase.from_uri(f"sqlite:///{ROOT_DIR}/data/Chinook.db")
# {"question": ...} -> SQL string
write_query = create_sql_query_chain(llm, db)

# Keep the template flush-left: indentation inside a triple-quoted string is sent to the
# model verbatim, so an indented template prefixes every prompt line with spaces.
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# {question, query, result} -> natural-language answer string
answer = answer_prompt | llm | StrOutputParser()

# {question} -> {question, query} -> {question, query, result}
write_and_execute = RunnablePassthrough.assign(query=write_query).assign(
    result=itemgetter("query") | execute_query
)

# Full pipeline: generate SQL, run it, then let the LLM phrase the answer.
chain = write_and_execute | answer

chain.invoke({"question": "How many employees are there"})

  sample_rows_result = connection.execute(command)  # type: ignore


'There are 8 employees.'

In [5]:
# A harder question that needs an aggregate over each customer's spending.
top_customer_question = {"question": "Which customer has spent the most money in total?"}
chain.invoke(top_customer_question)

'The SQL query is attempting to retrieve the customer who has spent the most money in total by joining the customers and invoices tables, summing up the total amount spent by each customer, and ordering the results in descending order. However, the error "no such table: customers" indicates that the \'customers\' table does not exist in the database. The query needs to be modified to reference the correct table containing customer information.'

In [6]:
# Run only the SQL-writing and execution stages (no answer step) to inspect the
# generated query and the raw result returned by the executor.
sql_debug_steps = RunnablePassthrough.assign(query=write_query).assign(
    result=itemgetter("query") | execute_query
)
sql_debug_steps.invoke({"question": "Which customer has spent the most money in total?"})

{'question': 'Which customer has spent the most money in total?',
 'query': 'SELECT "CustomerID", SUM("TotalAmount") AS total_spent\nFROM "Orders"\nGROUP BY "CustomerID"\nORDER BY total_spent DESC\nLIMIT 1;',
 'result': 'no such table: Orders'}

In [9]:
# Display the runnable built by create_sql_query_chain, including the SQLite prompt
# template it fills with the input question and the table info.
write_query

RunnableAssign(mapper={
  input: RunnableLambda(...),
  table_info: RunnableLambda(...)
})
| RunnableLambda(lambda x: {k: v for k, v in x.items() if k not in ('question', 'table_names_to_use')})
| PromptTemplate(input_variables=['input', 'table_info'], partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. B

In [10]:
# Call the query writer on its own to see the raw SQL string it produces.
question_payload = {"question": "Which customer has spent the most money in total?"}
write_query.invoke(question_payload)

'SELECT "CustomerID", SUM("Total") AS total_spent\nFROM orders\nGROUP BY "CustomerID"\nORDER BY total_spent DESC\nLIMIT 1;'