<a href="https://colab.research.google.com/github/duyvm/funny_stuff_with_llm/blob/main/learning-rag/Langchain_query_validation_SQL_qa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install --quiet --upgrade langchain-chroma langchain[openai] langchain langchain-community langgraph langchain-core langchain-text-splitters > /dev/null

In [2]:
from google.colab import userdata
import os

os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = "langchain-learning-rag"
os.environ["LANGSMITH_API_KEY"] = userdata.get('LANGSMITH_API_KEY')
os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')

# Overview

Guide: [How to do query validation as part of SQL question-answering](https://python.langchain.com/docs/how_to/sql_query_checking/)

## Objectives

Validate the generated SQL query.

# Load data

Load the Chinook SQL script and build the SQLite database from it.
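The cell below shells out to the `sqlite3` CLI. As a possible alternative that needs no `apt-get` step, the same build can be sketched with Python's standard `sqlite3` module. `build_db_from_script` and the tiny `demo_sql` script are hypothetical stand-ins for illustration; in the notebook the script text would be the contents of `Chinook_Sqlite.sql`.

```python
import sqlite3
from pathlib import Path


def build_db_from_script(sql_text: str, db_path: str) -> list[str]:
    """Create a SQLite file from a SQL script and return its table names."""
    Path(db_path).unlink(missing_ok=True)  # start clean so re-runs do not collide
    conn = sqlite3.connect(db_path)
    try:
        conn.executescript(sql_text)  # runs every statement in the script
        conn.commit()
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
        ).fetchall()
    finally:
        conn.close()
    return [name for (name,) in rows]


# Invented two-statement script, used only to show the call.
demo_sql = """
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
INSERT INTO Artist (Name) VALUES ('AC/DC'), ('Accept');
"""
print(build_db_from_script(demo_sql, "demo.db"))
```

Returning the table names gives a quick check that the script actually created something before the database is handed to `SQLDatabase.from_uri`.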

In [3]:
!sudo apt-get install sqlite3
!curl -s https://raw.githubusercontent.com/duyvm/funny_stuff_with_llm/refs/heads/main/learning-rag/db/Chinook_Sqlite.sql | sqlite3 Chinook.db


In [4]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db", sample_rows_in_table_info=3)
print(f"dialect: {db.dialect}")
print(f"table_names: {db.get_usable_table_names()}")
print(f"Test 'SELECT * FROM Artist LIMIT 10;': {db.run('SELECT * FROM Artist LIMIT 10;')}")

dialect: sqlite
table_names: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Test 'SELECT * FROM Artist LIMIT 10;': [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


- `create_sql_query_chain` builds a query-writing chain from the given LLM and database. It detects the database dialect automatically and uses the appropriate prompt template.

- Two LLM calls: generate the query, then validate it with a second prompt.
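An LLM review is not the only possible check. A cheap complementary step, which is not part of the linked guide, is to ask SQLite to compile the query with `EXPLAIN` before running it. The sketch below assumes a plain `sqlite3` connection; `compile_error` is a hypothetical helper, and the one-table schema is invented for the demo.

```python
import sqlite3
from typing import Optional


def compile_error(conn: sqlite3.Connection, query: str) -> Optional[str]:
    """Return None if SQLite can compile the query, else the error message."""
    try:
        # EXPLAIN makes SQLite prepare the statement (parsing it and resolving
        # table names) and return a bytecode listing instead of the query result.
        conn.execute(f"EXPLAIN {query}")
        return None
    except sqlite3.Error as exc:
        return str(exc)


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Invoice" ("InvoiceId" INTEGER, "Total" NUMERIC)')

print(compile_error(conn, 'SELECT AVG("Total") FROM "Invoice"'))  # compiles
print(compile_error(conn, 'SELECT 1 FROM "Invoices"'))            # unknown table
print(compile_error(conn, "SELEC 1"))                             # syntax error
```

One caveat: many SQLite builds treat an unknown double-quoted name as a string literal, so this check may not flag a misspelled quoted column. It is best seen as a filter for syntax errors and unknown tables, not a replacement for the validation prompt.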

In [35]:
from langchain.chat_models import init_chat_model
from langchain.chains import create_sql_query_chain

llm = init_chat_model("openai:gpt-4o-mini")
chain = create_sql_query_chain(llm, db)
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [41]:
# extend the chain by adding 2nd prompt
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

system = """Double check the user's {dialect} query for common mistakes, including:
 - Using NOT IN with NULL values
 - Using UNION when UNION ALL should have been used
 - Using BETWEEN for exclusive ranges
 - Data type mismatch in predicates
 - Properly quoting identifiers
 - Using the correct number of arguments for functions
 - Casting to the correct data type
 - Using the proper columns for joins
 - Clarifying which column belongs to which table in a join query

If there are any of the above mistakes, rewrite the query.
If there are no mistakes, just reproduce the original query with no further commentary.

Output the final SQL query only.
DO NOT output any markdown type text string."""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)

validation_chain = prompt | llm | StrOutputParser()

full_chain = {"query": chain} | validation_chain
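In `{"query": chain} | validation_chain`, the dict on the left of the pipe is coerced into a parallel step: the original input goes to `chain`, and its output is stored under the `"query"` key, which is the variable the validation prompt expects. The data flow can be sketched in plain Python with no LLM involved. All three functions below are hypothetical stand-ins, not LangChain APIs.

```python
from typing import Callable, Dict

Step = Callable[[Dict[str, str]], str]


def fake_query_chain(inputs: Dict[str, str]) -> str:
    # A real chain would send inputs["question"] to the LLM; this returns a
    # fixed query with stray whitespace so the second step has work to do.
    return '  SELECT COUNT(*) FROM "Invoice";\n'


def fake_validation_chain(inputs: Dict[str, str]) -> str:
    # Reads the "query" key, just as the validation prompt reads {query}.
    return inputs["query"].strip()


def pipe_with_key(key: str, first: Step, second: Step) -> Step:
    """Run first on the input, wrap its result as {key: result}, run second."""
    def run(inputs: Dict[str, str]) -> str:
        return second({key: first(inputs)})
    return run


full = pipe_with_key("query", fake_query_chain, fake_validation_chain)
print(full({"question": "How many invoices are there?"}))
```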

In [42]:
query = full_chain.invoke({"question": "What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010"})
print(query)

SELECT AVG("Total") AS "AverageInvoice"
FROM "Invoice" AS inv
JOIN "Customer" AS cust ON inv."CustomerId" = cust."CustomerId"
WHERE cust."Country" = 'USA'
AND cust."Fax" IS NULL
AND inv."InvoiceDate" >= '2003-01-01' 
AND inv."InvoiceDate" < '2010-01-01'
LIMIT 5;


In [43]:
db.run(query)

'[(None,)]'
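A result of `[(None,)]` is what SQLite returns when `AVG` has no rows to aggregate, so one plausible reading is that no invoice matched the filters. Running `COUNT(*)` with the same `WHERE` clause can help tell "no matching rows" apart from other problems. The sketch below uses an invented two-row table, not Chinook data.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, InvoiceDate TEXT, Total NUMERIC);
INSERT INTO Invoice (InvoiceDate, Total) VALUES
    ('2021-01-01 00:00:00', 1.98),
    ('2021-02-01 00:00:00', 3.96);
""")

# Same shape of date filter as the generated query; no toy row falls inside it.
date_filter = "InvoiceDate >= '2003-01-01' AND InvoiceDate < '2010-01-01'"

avg_row = conn.execute(f"SELECT AVG(Total) FROM Invoice WHERE {date_filter}").fetchone()
count_row = conn.execute(f"SELECT COUNT(*) FROM Invoice WHERE {date_filter}").fetchone()

print(avg_row)    # (None,) because AVG over zero rows is NULL
print(count_row)  # (0,)
```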

- One LLM call: generate the query and check it in a single prompt.

In [46]:
system = """You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

Write an initial draft of the query. Then double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
- Do not output markdown-type text

Use format:

First draft: <<FIRST_DRAFT_QUERY>>
Final answer: <<FINAL_ANSWER_QUERY>>
"""
prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{input}")]
).partial(dialect=db.dialect)


def parse_final_answer(output: str) -> str:
    # Keep only the text after the "Final answer: " marker requested in the prompt.
    return output.split("Final answer: ")[1]


chain = create_sql_query_chain(llm, db, prompt=prompt) | parse_final_answer
prompt.pretty_print()


You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

Write an initial draft of the query. T
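`parse_final_answer` indexes `[1]` after splitting, so it raises an `IndexError` if the model omits the `Final answer:` marker. A more defensive variant might look like the sketch below. `extract_final_query` is a hypothetical name, and the fence-stripping step is an added assumption for cases where the model wraps the query in a Markdown fence despite the instruction.

```python
FINAL_MARKER = "Final answer:"
FENCE = "`" * 3  # a Markdown code fence


def extract_final_query(output: str) -> str:
    """Return the SQL after the last marker, without any surrounding code fence."""
    _, marker, tail = output.rpartition(FINAL_MARKER)
    if not marker:
        raise ValueError(f"No {FINAL_MARKER!r} marker in model output: {output!r}")
    query = tail.strip()
    if query.startswith(FENCE):
        # Drop the opening fence line (which may carry a language tag) ...
        query = query.split("\n", 1)[1] if "\n" in query else query[len(FENCE):]
        # ... and everything from the closing fence onward.
        query = query.rsplit(FENCE, 1)[0]
    return query.strip()


sample = "First draft: SELECT 1;\nFinal answer: " + FENCE + "sql\nSELECT 2;\n" + FENCE
print(extract_final_query(sample))
```

Raising a `ValueError` that includes the raw output makes a malformed response easier to debug than a bare `IndexError`.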

In [47]:
query = chain.invoke(
    {
        "question": "What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010"
    }
)
print(query)

SELECT AVG("Total") AS "AverageInvoice" FROM "Invoice" INNER JOIN "Customer" ON "Invoice"."CustomerId" = "Customer"."CustomerId" WHERE "Customer"."Fax" IS NULL AND "Invoice"."InvoiceDate" >= '2003-01-01' AND "Invoice"."InvoiceDate" < '2010-01-01' AND "Customer"."Country" = 'United States';


In [48]:
db.run(query)

'[(None,)]'
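The two generated queries filter on different country literals: `'USA'` in the two-call version and `'United States'` in the one-call version. Because a literal that never occurs in the column silently matches nothing, it may be worth checking which spelling the data contains. The sketch below uses an invented toy table, and `matching_values` is a hypothetical helper.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT, Fax TEXT);
INSERT INTO Customer (Country, Fax) VALUES
    ('USA', NULL), ('USA', '+1 555 0100'), ('Canada', NULL);
""")


def matching_values(conn, table, column, candidates):
    """Return which candidate literals actually occur in table.column."""
    placeholders = ", ".join("?" for _ in candidates)
    rows = conn.execute(
        f'SELECT DISTINCT "{column}" FROM "{table}" '
        f'WHERE "{column}" IN ({placeholders}) ORDER BY "{column}"',
        list(candidates),
    ).fetchall()
    return [value for (value,) in rows]


# On the toy data only one of the two spellings is present.
print(matching_values(conn, "Customer", "Country", ["USA", "United States"]))
```

Against the real database the same check could be run through `db.run` with a `SELECT DISTINCT "Country" FROM "Customer"` query.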