In [1]:
from langchain_community.utilities import SQLDatabase

In [6]:
# Open the Chinook sample database. The three-slash SQLite URI is a relative
# path, so Chinook.db must sit in the kernel's working directory.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

In [7]:
# Show the SQL dialect; the same value is later injected into the query prompt.
print(db.dialect)

sqlite


In [8]:
# List the tables the wrapper exposes, to confirm the database loaded correctly.
print(db.get_usable_table_names())

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [9]:
# Smoke-test a raw query. Note that db.run returns the rows rendered as a
# single string, not as a list of tuples.
db.run("SELECT * FROM Artist LIMIT 10;")

"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [11]:
from typing_extensions import TypedDict

In [12]:
class State(TypedDict):
    """Shared state passed between the steps of the SQL question-answering flow.

    Each step reads the keys it needs and returns a partial dict holding only
    the key it produces.
    """

    # Natural-language question from the user; read by write_query.
    question: str
    # SQL text produced by write_query and consumed by execute_query.
    query: str
    # Output of running the query, as returned by the SQL tool (a string).
    result: str
    # Final answer text — presumably filled by a later step not shown here; confirm.
    answer: str

In [13]:
from dotenv import load_dotenv

In [17]:
# Load variables from a local .env file into the process environment
# (presumably the OpenAI API key needed by the chat model — confirm).
load_dotenv()

True

In [18]:
from langchain.chat_models import init_chat_model

In [19]:
# Chat model used for query generation. The provider is passed explicitly
# rather than being inferred from the model name.
llm = init_chat_model("gpt-4o-mini", model_provider="openai")

In [20]:
from langchain import hub

In [21]:
# Fetch the text-to-SQL prompt from the LangChain Hub (network call).
# No commit is pinned here, so the prompt content may change between runs.
# TODO(review): consider pinning to a specific commit for reproducibility.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

In [22]:
# Display the template's repr to inspect its input variables and hub metadata.
query_prompt_template

ChatPromptTemplate(input_variables=['dialect', 'input', 'table_info', 'top_k'], input_types={}, partial_variables={}, metadata={'lc_hub_owner': 'langchain-ai', 'lc_hub_repo': 'sql-query-system-prompt', 'lc_hub_commit_hash': '360a0e9d0f0f5da0ee9810a2a0ea3c9dc3de31b3ae9d50272420c33e48e6e323'}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['dialect', 'table_info', 'top_k'], input_types={}, partial_variables={}, template='Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema

In [24]:
# Fail fast if the pulled prompt no longer has the two messages this notebook
# was written against.
assert len(query_prompt_template.messages) == 2

In [25]:
# Print every message in the prompt so the template placeholders
# ({dialect}, {top_k}, {table_info}, {input}) are visible.
for message in query_prompt_template.messages:
    message.pretty_print()


Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Question: [33;1m[1;3m{input}[0m


In [26]:
from typing_extensions import Annotated

In [27]:
class QueryOutput(TypedDict):
    """Generated SQL query."""

    # The Annotated metadata follows LangChain's (type, default, description)
    # convention for TypedDict schemas; `...` marks the field as required.
    # The description is forwarded to the model as part of the output schema,
    # so it is kept correctly spelled.
    query: Annotated[str, ..., "Syntactically valid SQL query."]

def write_query(state: State):
    """Translate the user's question into a SQL query via the LLM.

    Args:
        state: Flow state; only ``state["question"]`` is read.

    Returns:
        A partial state update ``{"query": <sql text>}``.
    """
    # Fill every placeholder the hub prompt declares.
    prompt_inputs = {
        "dialect": db.dialect,
        "top_k": 10,  # row cap the prompt asks the model to respect
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    prompt = query_prompt_template.invoke(prompt_inputs)

    # Constrain the model's reply to the QueryOutput schema so the SQL can be
    # read from a known key rather than parsed out of free text.
    generated = llm.with_structured_output(QueryOutput).invoke(prompt)
    return {"query": generated["query"]}
    

In [28]:
# Smoke test (calls the LLM): write_query only reads "question", so a partial
# state dict is enough here.
write_query({"question": "How many Employees are there?"})

{'query': 'SELECT COUNT(EmployeeId) AS NumberOfEmployees FROM Employee;'}

In [29]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

In [32]:
def execute_query(state: State):
    """Run the SQL held in ``state["query"]`` against the database.

    Args:
        state: Flow state; only ``state["query"]`` is read.

    Returns:
        A partial state update ``{"result": <tool output>}``.
    """
    # NOTE(review): the query text is handed to the tool as-is, with no
    # validation here. When it comes from the LLM, consider a read-only
    # connection or a human approval step before execution.
    sql_tool = QuerySQLDatabaseTool(db=db)
    tool_output = sql_tool.invoke(state["query"])
    return {"result": tool_output}

In [33]:
# Smoke test with a hand-written query, so this cell does not depend on the LLM.
execute_query({"query": "SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;"})

{'result': '[(8,)]'}