In [1]:
import os
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase

load_dotenv()
import sqlalchemy
print(sqlalchemy.__version__)


2.0.34


**Set the environment variable and load the LLM**

In [5]:
sqldb_directory = here("data/travel.sqlite")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["aircrafts_data"])  # 注意需要传递列表
print(f"Original table info: {table_info}")

db.run("SELECT * FROM aircrafts_data LIMIT 10;")
print(db.dialect)
print(db.get_usable_table_names(),'\n')

sqlite
['aircrafts_data', 'airports_data', 'boarding_passes', 'bookings', 'flights', 'seats', 'ticket_flights', 'tickets'] 



TypeError: must be real number, not str

In [6]:
os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")


# llm = ChatOpenAI(temperature=0)
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# llm = ChatOpenAI(model="gpt-4o")

**Load and test the sqlite db**

In [7]:
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Invoice LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 2, '2021-01-01 00:00:00', 'Theodor-Heuss-Straße 34', 'Stuttgart', None, 'Germany', '70174', 1.98), (2, 4, '2021-01-02 00:00:00', 'Ullevålsveien 14', 'Oslo', None, 'Norway', '0171', 3.96), (3, 8, '2021-01-03 00:00:00', 'Grétrystraat 63', 'Brussels', None, 'Belgium', '1000', 5.94), (4, 14, '2021-01-06 00:00:00', '8210 111 ST NW', 'Edmonton', 'AB', 'Canada', 'T6G 2C7', 8.91), (5, 23, '2021-01-11 00:00:00', '69 Salem Street', 'Boston', 'MA', 'USA', '2113', 13.86), (6, 37, '2021-01-19 00:00:00', 'Berger Straße 10', 'Frankfurt', None, 'Germany', '60316', 0.99), (7, 38, '2021-02-01 00:00:00', 'Barbarossastraße 19', 'Berlin', None, 'Germany', '10779', 1.98), (8, 40, '2021-02-01 00:00:00', '8, Rue Hanovre', 'Paris', None, 'France', '75002', 1.98), (9, 42, '2021-02-02 00:00:00', '9, Place Louis Barthou', 'Bordeaux', None, 'France', '33000', 3.96), (10, 46, '2021-02-03 00:00:00', '3 Chatham Street', 'Dublin', 'Dublin', 'Ireland', None, 5.94)]"

In [8]:
table_info = db.get_table_info(["Employee"])  # 注意需要传递列表
print(f"Original table info: {table_info}")

Original table info: 
CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATETIME, 
	"HireDate" DATETIME, 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60), 
	PRIMARY KEY ("EmployeeId"), 
	FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Employee table:
EmployeeId	LastName	FirstName	Title	ReportsTo	BirthDate	HireDate	Address	City	State	Country	PostalCode	Phone	Fax	Email
1	Adams	Andrew	General Manager	None	1962-02-18 00:00:00	2002-08-14 00:00:00	11120 Jasper Ave NW	Edmonton	AB	Canada	T5K 2N1	+1 (780) 428-9482	+1 (780) 428-3457	andrew@chinookcorp.com
2	Edwards	Nancy	Sales Manager	1	1958-12-08 00:00:00	2002-05-01 00:00:00	825 8 Ave SW	Calgary	AB	Canada	T2P 2T3	+1 (403) 262-34

In [9]:
from langchain import debug

debug = True  # 启用调试模式

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many Genre are there?"})
response


'SELECT COUNT("GenreId") AS NumberOfGenres\nFROM "Genre"'

In [10]:
# from langchain.prompts import PromptTemplate
# from langchain.chains import LLMChain

# custom_prompt = PromptTemplate(
#     input_variables=["question"],
#     template="""
#     You are an SQL expert. Given the input question, generate an SQL query that answers the question.
#     The query must be simple and efficient. For example, if the question is about counting rows,
#     the query should use SELECT COUNT(*) syntax.
#     Question: {question}
#     """
# )

# chain = LLMChain(llm=llm, prompt=custom_prompt)
# response = chain.invoke({"question": "How many rows are there in the aircrafts_data table?"})
# print(response['text'])


chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "What type of seat is available on the flight?"})
response

'SELECT DISTINCT "Total" FROM Invoice ORDER BY "Total" LIMIT 5;'

**Create the SQL agent and run a test query**

In [12]:
db.run(response)

'[(0.99,), (1.98,), (1.99,), (2.98,), (3.96,)]'

In [14]:
query = "SELECT COUNT(*) FROM aircrafts_data;"
result = db.run(query)
print(f"SQL Query Result: {result}")


SQL Query Result: [(9,)]
