https://python.langchain.com/docs/tutorials/sql_qa/

https://python.langchain.com/docs/how_to/sql_prompting/


In [1]:
from langchain_community.utilities import SQLDatabase

In [4]:
# Connect to the IRIS instance behind the "iris" host (path segment LLMRAG)
# and limit reflection to the Holefoods schema; each table's description in
# table_info will carry 3 sample rows.
db = SQLDatabase.from_uri(
    "iris://superuser:SYS@iris:1972/LLMRAG",
    schema="Holefoods",
    sample_rows_in_table_info=3,
)

In [5]:
# Sanity check: the SQLAlchemy dialect name, then the tables visible in the schema
# (one per line, exactly as two separate print calls would show them).
print(db.dialect, db.get_usable_table_names(), sep="\n")

iris
['Country', 'Outlet', 'Product', 'Region', 'SalesTransaction']


In [6]:
# Peek at ten raw rows of the sales table (IRIS uses TOP rather than LIMIT).
print(
    db.run("SELECT TOP 10 * FROM Holefoods.SalesTransaction")
)

[(1, 1, Decimal('2.66'), '2', None, datetime.date(2024, 7, 17), Decimal('0.1'), None, None, 33, 'SKU-101', None, 1, None), (2, 1, Decimal('5.36'), '2', None, datetime.date(2024, 7, 6), Decimal('0.1'), 39.673679, -89.711212, 26, 'SKU-222', None, 1, '62629'), (3, 1, Decimal('1.56'), '2', None, datetime.date(2021, 9, 10), Decimal('0.2'), None, None, 2, 'SKU-296', None, 1, None), (4, 1, Decimal('1.56'), None, None, datetime.date(2020, 3, 25), Decimal('0.2'), 42.291074, -86.274294, 26, 'SKU-296', None, 1, '49043'), (5, 1, Decimal('0.98'), '2', 'Shopper ranted about high prices', datetime.date(2023, 6, 11), Decimal('0.5'), None, None, 12, 'SKU-287', None, 1, None), (6, 1, Decimal('12.95'), '2', None, datetime.date(2020, 9, 24), 0, None, None, 1, 'SKU-195', None, 1, None), (7, 1, Decimal('44.75'), '2', None, datetime.date(2021, 10, 2), 0, 39.193793, -87.713073, 26, 'SKU-900', None, 5, '62478'), (8, 1, Decimal('0.92'), '2', None, datetime.date(2021, 6, 6), Decimal('0.2'), 30.529819, -96.714292

In [7]:
context = db.get_context()
# get_context() returns a mapping: list its keys first, then print the schema
# description (CREATE TABLE statements plus sample rows).
print([*context])
print(context["table_info"])

['table_info', 'table_names']

CREATE TABLE "Holefoods"."Country" (
	"ID" IDENTITY DEFAULT $i(^HoleFoods.CountryD) NOT NULL, 
	"Name" VARCHAR(90), 
	"Region" BIGINT, 
	CONSTRAINT "RowIDField_As_PKey" PRIMARY KEY ("ID")
) WITH %CLASSPARAMETER ALLOWIDENTITYINSERT = 1

/*
3 rows from Country table:
ID	Name	Region
1	China	1
2	India	1
3	Japan	1
*/


CREATE TABLE "Holefoods"."Outlet" (
	"ID" IDENTITY DEFAULT $i(^HoleFoods.OutletD) NOT NULL, 
	"City" VARCHAR(100), 
	"Country" BIGINT, 
	"Latitude" DOUBLE, 
	"Longitude" DOUBLE, 
	"Population" INTEGER, 
	"Type" VARCHAR(50), 
	CONSTRAINT "RowIDField_As_PKey" PRIMARY KEY ("ID")
) WITH %CLASSPARAMETER ALLOWIDENTITYINSERT = 1

/*
3 rows from Outlet table:
ID	City	Country	Latitude	Longitude	Population	Type
1	Beijing	1	39.86	116.412	17400000	None
2	Shanghai	1	31.224	121.466	16738000	None
3	Bangalore	2	12.963	77.587	6200000	None
*/


CREATE TABLE "Holefoods"."Product" (
	"ID" VARCHAR(22) NOT NULL, 
	"Category" VARCHAR(100), 
	"Name" VARCHAR(120), 
	"Pr

In [8]:
import getpass
import os

# Ask for the key only when it is not already in the environment, so the cell
# can be re-run (or executed non-interactively) without blocking on input.
if not os.environ.get("MISTRAL_API_KEY"):
    os.environ["MISTRAL_API_KEY"] = getpass.getpass("Enter API key for Mistral AI: ")

from langchain_mistralai import ChatMistralAI

# Note: if the OpenAI cell below is run afterwards, it rebinds `llm`.
llm = ChatMistralAI(model="mistral-large-latest")

 ········


In [22]:
import getpass
import os

# Ask for the key only when it is not already in the environment, so the cell
# can be re-run (or executed non-interactively) without blocking on input.
# Presumably OpenAIEmbeddings() in the example-selector cell reads this same
# variable — confirm if that cell is run on its own.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

from langchain_openai import ChatOpenAI

# This rebinds `llm`, replacing any Mistral model created earlier.
llm = ChatOpenAI(model="gpt-4o-mini")

 ········


In [23]:
from langchain_core.prompts import PromptTemplate

# Instructions that turn a natural-language question into one IRIS SQL query.
# Placeholders: {top_k} (default row cap), {table_info} (schema + sample rows)
# and {input} (the user question). create_sql_query_chain requires all three
# and fills top_k itself (5 in the printed prompt below).
template = '''
You are an InterSystems IRIS SQL expert. 
Given an input question, first create a syntactically correct InterSystems IRIS SQL query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP as defined in InterSystems IRIS syntax: ```SELECT [DISTINCT] TOP int select-item, select-item,...```
Always specify table names using schema as prefix.
Do not use LIMIT clause as it is not correct in IRIS dialect.
Do not end SQL sentences with an ;
Do not enclose fields in quotes or double quotes.
Do not enclose table names in quotes or double quotes.
You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".

Return only plain SQL without any formatting.

Only use the following tables:

{table_info}.
Question: {input}'''

# from_template derives input_variables from the placeholders actually present,
# so the declared variables cannot drift from the text (the previous hand-written
# list also declared a "dialect" variable the template never uses).
custom_prompt = PromptTemplate.from_template(template)

In [24]:
from langchain.chains import create_sql_query_chain

# Question -> SQL chain built around the custom IRIS prompt.
chain = create_sql_query_chain(llm=llm, db=db, prompt=custom_prompt)
# Show the first prompt inside the chain; the output shows top_k already
# filled in with 5.
chain.get_prompts()[0].pretty_print()


You are an InterSystems IRIS SQL expert. 
Given an input question, first create a syntactically correct InterSystems IRIS SQL query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the TOP as defined in InterSystems IRIS syntax: ```SELECT [DISTINCT] TOP int select-item, select-item,...```
Always specify table names using schema as prefix.
Do not use LIMIT clause as it is not correct in IRIS dialect.
Do not end SQL sentences with an ;
Do not enclose fields in quotes or double quotes.
Do not enclose table names in quotes or double quotes.
You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay at

In [25]:
from langchain.callbacks.tracers import ConsoleCallbackHandler

In [26]:
# Generate (not run) the SQL for a simple count question.
chain.invoke(dict(question="how many products are there?"))

'SELECT COUNT(DISTINCT SKU) FROM Holefoods.Product'

In [27]:
# Generate the SQL for a date-filtered count.
chain.invoke(dict(question="how many sales in 2023?"))

'SELECT COUNT(*) FROM Holefoods.SalesTransaction WHERE YEAR(DateOfSale) = 2023'

In [28]:
# For a step-by-step trace of the chain, pass the ConsoleCallbackHandler
# imported above, e.g.:
#   chain.invoke({"question": "what are the product categories?"},
#                config={'callbacks': [ConsoleCallbackHandler()]})
chain.invoke(dict(question="what are the product categories?"))

'SELECT DISTINCT TOP 5 Category FROM Holefoods.Product'

In [29]:
# A question that needs a JOIN, an aggregate and ordering.
chain.invoke(dict(question="what are the most sold product categories during 2023?"))

"SELECT TOP 5 p.Category, SUM(st.UnitsSold) AS TotalUnitsSold \nFROM Holefoods.SalesTransaction st \nJOIN Holefoods.Product p ON st.Product = p.SKU \nWHERE st.DateOfSale >= CAST('2023-01-01' AS date) AND st.DateOfSale < CAST('2024-01-01' AS date) \nGROUP BY p.Category \nORDER BY TotalUnitsSold DESC"

# Execute query

In [30]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Step 1: question -> SQL text, using the same custom IRIS prompt as before.
write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)
# Step 2: SQL text -> query result, executed against `db`.
execute_query = QuerySQLDataBaseTool(db=db)
# Pipe the two steps so a question goes straight to its answer rows.
chain = write_query | execute_query
chain.invoke({"question": "what are the most sold product categories during 2023?"})

"[('PASTA', 104), ('VEGETABLE', 90), ('SNACK', 79), ('CANDY', 43), ('CEREAL', 37)]"

# Dynamic few-shot examples

In [31]:
# Few-shot examples: each maps a natural-language "input" to an IRIS SQL "query".
# The model imitates these, so they must obey the same rules as the prompt:
# schema-prefixed table names and no trailing ';' (the first query used to end
# with one, contradicting the "Do not end SQL sentences with an ;" instruction).
examples = [
    {
        "input": "List all regions.",
        "query": "SELECT ID, Name FROM HoleFoods.Region",
    },
    {
        "input": "List all countries.",
        "query": "SELECT c.ID, c.Name, r.Name Region FROM HoleFoods.Country c JOIN HoleFoods.Region r on c.Region=r.ID",
    },
    {
        "input": "What are the different product categories ?",
        "query": "SELECT DISTINCT(Category) Categories FROM HoleFoods.Product",
    },
    {
        "input": "How many pasta products where sold online in 2023 ?",
        "query": "SELECT SUM(UnitsSold) FROM HoleFoods.SalesTransaction st JOIN HoleFoods.Product p ON st.Product=p.ID WHERE st.Channel='Online' AND YEAR(st.DateOfSale) = 2023 AND p.Category = 'Pasta'",
    },
    {
        "input": "Find all snack products",
        "query": "SELECT SKU, Name, Price FROM HoleFoods.Product p WHERE p.Category='Snack'",
    },
    {
        "input": "Find all candy products",
        "query": "SELECT SKU, Name, Price FROM HoleFoods.Product p WHERE p.Category='Candy'",
    },
    {
        "input": "How many products were sold in Europe in 2022 ?",
        "query": "SELECT SUM(UnitsSold) FROM HoleFoods.SalesTransaction st JOIN HoleFoods.Outlet o ON st.Outlet=o.ID JOIN HoleFoods.Country c ON o.Country=c.ID JOIN HoleFoods.Region r ON c.Region=r.ID WHERE r.Name='Europe' AND YEAR(st.DateOfSale) = 2022",
    },
]

In [32]:
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_iris import IRISVector
from langchain_openai import OpenAIEmbeddings

# Embed each example's "input" with OpenAI embeddings and store the vectors in
# an IRIS vector collection; select_examples() later returns the k=3 stored
# examples most similar to a new question.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    IRISVector,
    k=3,
    input_keys=["input"],
    # The remaining keyword arguments are passed through to IRISVector.
    collection_name="sql_samples",
    # Presumably drops any existing "sql_samples" collection before re-inserting,
    # so re-running the cell does not duplicate examples — confirm in langchain_iris.
    pre_delete_collection=True,
    connection_string='iris://superuser:SYS@iris:1972/LLMRAG',
)

In [33]:
# Show which stored examples are nearest to a new question.
example_selector.select_examples(dict(input="how many products were sold in America?"))

[{'input': 'How many products were sold in Europe in 2022 ?',
  'query': "SELECT SUM(UnitsSold) FROM HoleFoods.SalesTransaction st JOIN HoleFoods.Outlet o ON st.Outlet=o.ID JOIN HoleFoods.Country c ON o.Country=c.ID JOIN HoleFoods.Region r ON c.Region=r.ID WHERE r.Name='Europe' AND YEAR(st.DateOfSale) = 2022"},
 {'input': 'How many pasta products where sold online in 2023 ?',
  'query': "SELECT SUM(UnitsSold) FROM HoleFoods.SalesTransaction st JOIN HoleFoods.Product p ON st.Product=p.ID WHERE st.Channel='Online' AND YEAR(st.DateOfSale) = 2023 AND p.Category = 'Pasta'"},
 {'input': 'What are the different product categories ?',
  'query': 'SELECT DISTINCT(Category) Categories FROM HoleFoods.Product'}]

In [34]:
from langchain_core.prompts import FewShotPromptTemplate

# How each selected example is rendered inside the prompt.
example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")

# `template` ends with "Question: {input}". If it were used verbatim as the
# prefix, the user question would be rendered twice: once before the examples
# and again in the suffix. Strip that trailing line so the question appears
# only once, after the examples, right where the SQL answer is expected.
few_shot_prefix = (
    template.removesuffix("Question: {input}")
    + "\nBelow are a number of examples of questions and their corresponding SQL queries."
)

# Final prompt = instructions + schema, then the selected examples, then the
# user question. create_sql_query_chain needs input, top_k and table_info.
prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=few_shot_prefix,
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)

In [35]:
# Render the few-shot prompt once to inspect it; table_info is a short
# placeholder so the output stays readable.
print(
    prompt.format(
        table_info="foo",
        top_k=3,
        input="how many products were sold in Europe?",
    )
)


You are an InterSystems IRIS SQL expert. 
Given an input question, first create a syntactically correct InterSystems IRIS SQL query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 3 results using the TOP as defined in InterSystems IRIS syntax: ```SELECT [DISTINCT] TOP int select-item, select-item,...```
Always specify table names using schema as prefix.
Do not use LIMIT clause as it is not correct in IRIS dialect.
Do not end SQL sentences with an ;
Do not enclose fields in quotes or double quotes.
Do not enclose table names in quotes or double quotes.
You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay at

In [36]:
# Rebuild the SQL chain around the few-shot prompt; the examples are chosen
# per question by the semantic example selector.
chain = create_sql_query_chain(llm, db, prompt=prompt)
query = chain.invoke(dict(question="how many products were sold in America?"))
query

"SELECT SUM(UnitsSold) FROM HoleFoods.SalesTransaction st JOIN HoleFoods.Outlet o ON st.Outlet=o.ID JOIN HoleFoods.Country c ON o.Country=c.ID JOIN HoleFoods.Region r ON c.Region=r.ID WHERE r.Name='N. America'"

In [37]:
# Execute the SQL generated above against IRIS and print the result.
print(db.run(query))

[(531,)]
