Adapted from the LangChain SQL Q&A tutorial: https://python.langchain.com/docs/tutorials/sql_qa/

In [1]:
from langchain_community.utilities import SQLDatabase

In [64]:
import os

# Connection URI for the InterSystems IRIS instance. It is read from the environment so
# that real credentials need not be committed with the notebook. The fallback is the
# original local-demo URI, so existing setups keep working unchanged.
# NOTE(review): SQL generated by the LLM is executed on this connection later in the
# notebook. Prefer a read-only (least-privilege) IRIS user over `superuser`.
IRIS_URI = os.environ.get("IRIS_URI", "iris://superuser:SYS@iris:1972/LLMRAG")

db = SQLDatabase.from_uri(
    IRIS_URI,
    sample_rows_in_table_info=3,  # sample rows included in table_info for the LLM prompt
    schema="Holefoods",  # only tables from the Holefoods schema are exposed
)

In [65]:
# Sanity-check the connection: the dialect name first, then the usable table names.
dialect_name = db.dialect
print(dialect_name)
usable_tables = db.get_usable_table_names()
print(usable_tables)

iris
['Country', 'Outlet', 'Product', 'Region', 'SalesTransaction']


In [4]:
# Peek at the first 10 rows of SalesTransaction to see what the raw data looks like.
preview_query = "SELECT TOP 10 * FROM Holefoods.SalesTransaction"
preview_rows = db.run(preview_query)
print(preview_rows)

[(1, 1, Decimal('1.95'), '2', None, datetime.date(2022, 4, 4), 0, None, None, 10, 'SKU-296', None, 1, None), (2, 1, Decimal('2.30'), '2', None, datetime.date(2023, 6, 6), 0, 33.875377, -84.685645, 24, 'SKU-451', None, 2, '30073'), (3, 1, Decimal('29.70'), '2', None, datetime.date(2019, 6, 24), 0, None, None, 4, 'SKU-708', None, 6, None), (4, 1, Decimal('69.93'), '2', None, datetime.date(2023, 5, 23), Decimal('0.1'), None, None, 33, 'SKU-195', None, 6, None), (5, 1, Decimal('1.48'), '2', None, datetime.date(2020, 8, 27), Decimal('0.5'), None, None, 22, 'SKU-101', None, 1, None), (6, 1, Decimal('2.95'), '2', None, datetime.date(2024, 3, 28), 0, None, None, 9, 'SKU-192', None, 1, None), (7, 1, Decimal('8.95'), None, None, datetime.date(2019, 10, 25), 0, None, None, 17, 'SKU-900', None, 1, None), (8, 1, Decimal('20.66'), '2', None, datetime.date(2022, 4, 29), Decimal('0.1'), None, None, 6, 'SKU-601', None, 1, None), (9, 1, Decimal('11.66'), '2', None, datetime.date(2021, 3, 21), Decimal('0

In [66]:
# get_context() returns the text LangChain derives from the database for prompting.
# Print its keys, then the generated CREATE TABLE statements plus sample rows.
context = db.get_context()
context_keys = [*context]
print(context_keys)
table_info_text = context["table_info"]
print(table_info_text)

['table_info', 'table_names']

CREATE TABLE "Holefoods"."Country" (
	"ID" IDENTITY DEFAULT $i(^HoleFoods.CountryD) NOT NULL, 
	"Name" VARCHAR(90), 
	"Region" BIGINT, 
	CONSTRAINT "RowIDField_As_PKey" PRIMARY KEY ("ID")
) WITH %CLASSPARAMETER ALLOWIDENTITYINSERT = 1

/*
3 rows from Country table:
ID	Name	Region
1	China	1
2	India	1
3	Japan	1
*/


CREATE TABLE "Holefoods"."Outlet" (
	"ID" IDENTITY DEFAULT $i(^HoleFoods.OutletD) NOT NULL, 
	"City" VARCHAR(100), 
	"Country" BIGINT, 
	"Latitude" DOUBLE, 
	"Longitude" DOUBLE, 
	"Population" INTEGER, 
	"Type" VARCHAR(50), 
	CONSTRAINT "RowIDField_As_PKey" PRIMARY KEY ("ID")
) WITH %CLASSPARAMETER ALLOWIDENTITYINSERT = 1

/*
3 rows from Outlet table:
ID	City	Country	Latitude	Longitude	Population	Type
1	Beijing	1	39.86	116.412	17400000	None
2	Shanghai	1	31.224	121.466	16738000	None
3	Bangalore	2	12.963	77.587	6200000	None
*/


CREATE TABLE "Holefoods"."Product" (
	"ID" VARCHAR(22) NOT NULL, 
	"Category" VARCHAR(100), 
	"Name" VARCHAR(120), 
	"Pr

In [8]:
import getpass
import os

from langchain_mistralai import ChatMistralAI

# Option A: Mistral. This cell and the OpenAI cell both assign `llm`. Whichever one
# runs last is the model used by the rest of the notebook.
if not os.environ.get("MISTRAL_API_KEY"):
    # Ask only when the key is not already set (e.g. exported in the shell or entered
    # on a previous run). The prompt text names the key being requested.
    os.environ["MISTRAL_API_KEY"] = getpass.getpass("Mistral API key: ")

llm = ChatMistralAI(model="mistral-large-latest")

 ········


In [14]:
import getpass
import os

from langchain_openai import ChatOpenAI

# Option B: OpenAI. This cell and the Mistral cell both assign `llm`. Whichever one
# runs last is the model used by the rest of the notebook.
if not os.environ.get("OPENAI_API_KEY"):
    # Ask only when the key is not already set (e.g. exported in the shell or entered
    # on a previous run). The prompt text names the key being requested.
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

llm = ChatOpenAI(model="gpt-4o-mini")

 ········


In [58]:
from langchain_core.prompts import PromptTemplate

# Custom text-to-SQL prompt for create_sql_query_chain, tailored to InterSystems IRIS SQL.
# Placeholders: {input} is the user's question, {table_info} is the schema and sample rows
# taken from `db`, and {top_k} is the default row cap (rendered as 5 by the chain below).
template = '''
You are an InterSystems IRIS SQL expert.
Given an input question, first create a syntactically correct InterSystems IRIS SQL query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP as defined in InterSystems IRIS syntax: ```SELECT [DISTINCT] TOP int select-item, select-item,...```
Always specify table names using schema as prefix.
Do not use LIMIT clause as it is not correct in IRIS dialect.
Do not end SQL statements with a semicolon (;).
Do not enclose fields in quotes or double quotes.
Do not enclose table names in quotes or double quotes.
You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".

Return only plain SQL without any formatting.

Only use the following tables:

{table_info}

Question: {input}'''

# from_template derives input_variables from the placeholders actually present, so the
# declared variables cannot drift from the template text. The previous hand-written list
# also declared "dialect", which this template never uses.
custom_prompt = PromptTemplate.from_template(template)

In [59]:
from langchain.chains import create_sql_query_chain

# Build the question -> SQL chain around the IRIS prompt, then print the prompt the
# chain will use (with the chain's own defaults, such as top_k, filled in).
chain = create_sql_query_chain(llm, db, prompt=custom_prompt)
sql_prompt = chain.get_prompts()[0]
sql_prompt.pretty_print()


You are an InterSystems IRIS SQL expert. 
Given an input question, first create a syntactically correct InterSystems IRIS SQL query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the TOP as defined in InterSystems IRIS syntax: ```SELECT [DISTINCT] TOP int select-item, select-item,...```
Always specify table names using schema as prefix.
Do not use LIMIT clause as it is not correct in IRIS dialect.
Do not end SQL sentences with an ;
Do not enclose fields in quotes or double quotes.
Do not enclose table names in quotes or double quotes.
You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay at

In [20]:
from langchain.callbacks.tracers import ConsoleCallbackHandler

In [60]:
# Generate SQL only: this chain returns the query string and does not execute it.
chain.invoke(dict(question="how many products are there?"))

'SELECT COUNT(DISTINCT SKU) FROM Holefoods.Product'

In [49]:
# Generate SQL only: this chain returns the query string and does not execute it.
chain.invoke(dict(question="how many sales in 2023?"))

'SELECT COUNT(*) AS SalesCount FROM Holefoods.SalesTransaction WHERE YEAR(DateOfSale) = 2023'

In [61]:
# Set to True to attach LangChain's ConsoleCallbackHandler and trace this run on the
# console. This replaces the previously commented-out duplicate call.
TRACE_CHAIN = False
chain.invoke(
    {"question": "what are the product categories?"},
    config={"callbacks": [ConsoleCallbackHandler()]} if TRACE_CHAIN else None,
)

'SELECT DISTINCT TOP 5 Category FROM Holefoods.Product'

In [62]:
# Generate SQL only: this chain returns the query string and does not execute it.
chain.invoke(dict(question="what are the most sold product categories during 2023?"))

"SELECT TOP 5 p.Category, SUM(st.UnitsSold) AS TotalUnitsSold \nFROM Holefoods.SalesTransaction st \nJOIN Holefoods.Product p ON st.Product = p.SKU \nWHERE st.DateOfSale >= CAST('2023-01-01' AS date) AND st.DateOfSale < CAST('2024-01-01' AS date) \nGROUP BY p.Category \nORDER BY TotalUnitsSold DESC"

In [63]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

import re

# Tool that runs a SQL string against `db` and returns the result as text.
execute_query = QuerySQLDataBaseTool(db=db)
# Same question -> SQL chain as earlier, rebuilt so this cell is self-contained.
write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)

_FENCE_RE = re.compile(r"^```[a-zA-Z]*\s*|\s*```$")
_LABEL_RE = re.compile(r"^SQLQuery:\s*", re.IGNORECASE)


def _clean_sql(query: str) -> str:
    """Return *query* as plain SQL ready to execute.

    Removes surrounding Markdown code fences, a leading echoed ``SQLQuery:`` label,
    and any trailing semicolons. The prompt already asks the model for plain SQL with
    no trailing ';'. This step enforces that in case the model does not comply.
    Well-formed output passes through unchanged.
    """
    cleaned = _FENCE_RE.sub("", query.strip()).strip()
    cleaned = _LABEL_RE.sub("", cleaned)
    return cleaned.rstrip().rstrip(";").rstrip()


# NOTE(review): the generated SQL is executed as-is on the `db` connection. Use a
# read-only database user if the model could emit INSERT/UPDATE/DELETE statements.
chain = write_query | _clean_sql | execute_query
chain.invoke({"question": "what are the most sold product categories during 2023?"})

"[('PASTA', 113), ('SNACK', 84), ('FRUIT', 37), ('SEAFOOD', 33), ('CEREAL', 24)]"