In [1]:
from dotenv import load_dotenv,find_dotenv
from typing import Optional,List

from langchain_core.pydantic_v1 import BaseModel, Field
# Load variables from the .env file located by find_dotenv() before the chat
# model is built (presumably this is where the OpenAI API key comes from — confirm).
# The return value is deliberately discarded.
_ = load_dotenv(find_dotenv())
from langchain_openai import ChatOpenAI

# Chat model reused by the chains built in the later cells.
chatbot = ChatOpenAI(model="gpt-4o-mini")

In [13]:
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
# Connect to the local SQLite database holding the street-tree data.
db = SQLDatabase.from_uri("sqlite:///data/street_tree_db.sqlite")

# Chain that turns a natural-language question into a SQL query for `db`.
write_chain = create_sql_query_chain(llm=chatbot, db=db)

# Show the prompt the chain sends to the model.
sql_prompt = write_chain.get_prompts()[0]
sql_prompt.pretty_print()
# response

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [19]:
# Ask the SQL-writing chain for a query and show the raw text the model returned.
question = "How many species of trees are in San Francisco?"
response = write_chain.invoke({"question": question})
print(response)

```sql
SQLQuery: SELECT DISTINCT "qSpecies" FROM street_trees LIMIT 5;
```


In [24]:
# Remove, in order, the opening Markdown fence, the closing fence and the
# "SQLQuery:" label so that only the SQL statement is left.
cleaned_response = response
for marker in ("```sql\n", "\n```", "SQLQuery:"):
    cleaned_response = cleaned_response.replace(marker, "")
print(cleaned_response)

# Run the cleaned statement against the database; the returned value is the cell output.
db.run(cleaned_response)

 SELECT DISTINCT "qSpecies" FROM street_trees LIMIT 5;


'[("Arbutus \'Marina\' :: Hybrid Strawberry Tree",), (\'Afrocarpus gracilior :: Fern Pine\',), ("Thuja occidentalis \'Emerald\' :: Emerald Arborvitae",), ("Magnolia grandiflora \'Little Gem\' :: Little Gem Magnolia",), (\'Platanus x hispanica :: Sycamore: London Plane\',)]'

In [33]:
from operator import itemgetter
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
def replace_sql_formatting(response: str) -> str:
    """Extract the bare SQL statement from the text returned by the SQL-writing chain.

    The model may wrap its query in a Markdown code fence (with or without a
    language tag, with LF or CRLF line endings, possibly surrounded by prose)
    and may prefix it with the ``SQLQuery:`` label from the prompt format.
    Both are removed, together with surrounding whitespace, so the result can
    be handed straight to the database.

    Args:
        response: Raw text produced by the SQL-writing chain.

    Returns:
        The SQL text without code fences, without a leading ``SQLQuery:``
        label and without leading/trailing whitespace.
    """
    import re  # local import so the function does not depend on another cell

    # Echo the raw model output so the notebook shows what is being cleaned.
    print(response)

    # Opening fence. The language tag is only consumed when it sits alone on
    # the fence line, so "```SELECT 1```" does not lose its first keyword.
    fence_open = r"```(?:[A-Za-z]+[ \t]*\r?\n)?"

    fenced = re.search(fence_open + r"\s*(.*?)\s*```", response, flags=re.DOTALL)
    if fenced:
        # A complete fenced block exists: keep only what is inside it.
        sql = fenced.group(1)
    else:
        # No complete block: drop any stray fence markers and keep the rest.
        sql = re.sub(fence_open, "", response)

    # Drop the "SQLQuery:" label when it introduces the statement.
    sql = re.sub(r"^\s*SQLQuery:\s*", "", sql, flags=re.IGNORECASE)
    return sql.strip()
# Pipe the SQL-writing chain into the cleanup step so the chain returns plain SQL.
clean_sql = RunnableLambda(replace_sql_formatting)
chain_res = write_chain | clean_sql

# The cleaned query is the last expression, so it is shown as the cell output.
chain_res.invoke({"question": "How many species of trees are in San Francisco?"})

SQLQuery: SELECT DISTINCT "qSpecies" FROM street_trees LIMIT 5;


' SELECT DISTINCT "qSpecies" FROM street_trees LIMIT 5;'

In [52]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool, QuerySQLCheckerTool

# Tool that runs a SQL string against `db`; its output becomes "result" below.
execute_chain = QuerySQLDataBaseTool(db=db)
# Query-checker tool backed by the chat model. It is created here but is not
# wired into the chain built in this cell.
check_chain = QuerySQLCheckerTool(db=db, llm=chatbot)

# How the chain is assembled:
# 1. RunnablePassthrough forwards the user input ("question") to the output.
# 2. The first .assign() adds "query": the output of write_chain.
# 3. The second .assign() adds "result": itemgetter picks "query",
#    replace_sql_formatting cleans it, and execute_chain runs it.
# Only the copy handed to execute_chain is cleaned; "query" keeps the raw model text.
run_cleaned_query = (
    itemgetter("query")
    | RunnableLambda(replace_sql_formatting)
    | execute_chain
)
chain = (
    RunnablePassthrough
    .assign(query=write_chain)
    .assign(result=run_cleaned_query)
)

response = chain.invoke(
    {"question": "How many species of trees are in San Francisco?"}
)
print(response)

```sql
SELECT DISTINCT "qSpecies" FROM street_trees LIMIT 5;
```
{'question': 'How many species of trees are in San Francisco?', 'query': '```sql\nSELECT DISTINCT "qSpecies" FROM street_trees LIMIT 5;\n```', 'result': '[("Arbutus \'Marina\' :: Hybrid Strawberry Tree",), (\'Afrocarpus gracilior :: Fern Pine\',), ("Thuja occidentalis \'Emerald\' :: Emerald Arborvitae",), ("Magnolia grandiflora \'Little Gem\' :: Little Gem Magnolia",), (\'Platanus x hispanica :: Sycamore: London Plane\',)]'}


In [53]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
# Prompt that combines the question, the generated SQL and its result and asks
# the model for a natural-language answer.
answer_template = """Given the following user question, 
    corresponding SQL query, and SQL result, 
    answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
answer_prompt = PromptTemplate.from_template(answer_template)

# Full pipeline: question -> query + result -> filled prompt -> chat model -> plain string.
chat_chain = chain | answer_prompt | chatbot | StrOutputParser()

# The SQL-writing prompt caps results at 5 rows by default, so the question
# explicitly asks the model to drop that limit.
response = chat_chain.invoke(
    {"question": "How many species of trees are in San Francisco? Do unset the default limit"}
)
print(response)

```sql
SQLQuery: SELECT DISTINCT "qSpecies" FROM street_trees
```
There are 113 distinct species of trees in San Francisco.


Let's review what is happening in the above chain.
The user asks a question (identified by the variable name "question").
We use RunnablePassthrough to get that "question" variable, and we use .assign() twice to get the other two variables required by the prompt template: "query" and "result".
With the first .assign(), the write_chain chain takes the question as input and outputs the SQL query (identified by the variable name "query").
With the second .assign(), the SQL query is cleaned by replace_sql_formatting and passed to execute_chain, which takes the SQL query as input and outputs the result of executing it (identified by the variable name "result").
The prompt template has the question (identified by the variable name "question"), the SQL query (identified by the variable name "query") and the SQL query execution (identified by the variable name "result") as input, and the final prompt as the output.
The chat model has the prompt as the input and the AIMessage with the response in natural language as the output.
The StrOutputParser has the AIMessage with the response in natural language as the input and the response in natural language as a string of text as the output.