In [2]:
import os
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase

# Load variables from a .env file, overriding any values that are already set
# in the environment. The call must sit on its own line: when it trails a
# comment it is part of the comment and never runs.
load_dotenv(override=True)
import sqlalchemy

# Build the path to the Chinook SQLite file with pyprojroot's here() and open it.
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

# get_table_info takes a list of table names, not a bare string.
table_info = db.get_table_info(["Album"])
print(f"Original table info: {table_info}")

Original table info: 
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


**Set the environment variable and load the LLM**

In [None]:
# Schema for the Album table only; the argument must be a list of table names.
table_info = db.get_table_info(table_names=["Album"])
print("Original table info:", table_info)

# Run a raw query, then report the SQL dialect and the usable table names.
result = db.run("SELECT * FROM Album LIMIT 10;")
print(result, db.dialect, sep="\n")
print(db.get_usable_table_names(), end=" \n\n")

In [19]:
import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase

# Load variables from a .env file, overriding values already present in the
# environment, so a GROQ_API_KEY stored there is picked up before the check below.
load_dotenv(override=True)

# Prompt for the key interactively if GROQ_API_KEY is still not set.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"])  # the argument must be a list of table names
print(f"\n Original table info: {table_info}")

# Initialise the Llama chat model, served through the Groq provider.
llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)

# Custom prompt template used to generate the SQL query.
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)

# Input for the chain: dialect, table schema, the question and the row limit.
input_data = {
    "dialect": db.dialect,                    # database dialect, e.g. "sqlite"
    "table_info": db.get_table_info(),          # schema of every usable table
    "question": "What name of MediaType is?",
    "top_k": 5
}

# Step 1: generate the SQL. The chain returns the query text as a plain string
# (see the printed result), not a dict.
write_query_response = write_query.invoke(input_data)
print('\n write_query result：', write_query_response)

# Step 2: execute the generated SQL against the database.
execute_query = QuerySQLDataBaseTool(db=db)
execute_response = execute_query.invoke(write_query_response)
print('\n execute_response result：', execute_response)

# Compose both steps into one chain. Invoking it calls the LLM again, so this
# repeats the generation step rather than reusing write_query_response.
chain = write_query | execute_query
result_chain = chain.invoke(input_data)
print('\n result_chain==', result_chain)


 Original table info: 
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/

 write_query result： SELECT Name FROM MediaType LIMIT 5

 execute_response result： [('MPEG audio file',), ('Protected AAC audio file',), ('Protected MPEG-4 video file',), ('Purchased AAC audio file',), ('AAC audio file',)]

 result_chain== [('MPEG audio file',), ('Protected AAC audio file',), ('Protected MPEG-4 video file',), ('Purchased AAC audio file',), ('AAC audio file',)]


**Load and test the sqlite db**

In [17]:
# Open the Chinook SQLite file; its path is built with pyprojroot's here().
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri("sqlite:///" + str(sqldb_directory))

# Dialect name on the first line, usable table names on the second.
print(db.dialect, db.get_usable_table_names(), sep="\n")

# Kept as the final expression so the query result is rendered as the cell output.
db.run("SELECT * FROM Invoice LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 2, '2021-01-01 00:00:00', 'Theodor-Heuss-Straße 34', 'Stuttgart', None, 'Germany', '70174', 1.98), (2, 4, '2021-01-02 00:00:00', 'Ullevålsveien 14', 'Oslo', None, 'Norway', '0171', 3.96), (3, 8, '2021-01-03 00:00:00', 'Grétrystraat 63', 'Brussels', None, 'Belgium', '1000', 5.94), (4, 14, '2021-01-06 00:00:00', '8210 111 ST NW', 'Edmonton', 'AB', 'Canada', 'T6G 2C7', 8.91), (5, 23, '2021-01-11 00:00:00', '69 Salem Street', 'Boston', 'MA', 'USA', '2113', 13.86), (6, 37, '2021-01-19 00:00:00', 'Berger Straße 10', 'Frankfurt', None, 'Germany', '60316', 0.99), (7, 38, '2021-02-01 00:00:00', 'Barbarossastraße 19', 'Berlin', None, 'Germany', '10779', 1.98), (8, 40, '2021-02-01 00:00:00', '8, Rue Hanovre', 'Paris', None, 'France', '75002', 1.98), (9, 42, '2021-02-02 00:00:00', '9, Place Louis Barthou', 'Bordeaux', None, 'France', '33000', 3.96), (10, 46, '2021-02-03 00:00:00', '3 Chatham Street', 'Dublin', 'Dublin', 'Ireland', None, 5.94)]"

In [18]:
# Dump the Employee table definition together with its sample rows.
# The table names have to be supplied as a list.
table_info = db.get_table_info(table_names=["Employee"])
print("Original table info:", table_info)

Original table info: 
CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATETIME, 
	"HireDate" DATETIME, 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60), 
	PRIMARY KEY ("EmployeeId"), 
	FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Employee table:
EmployeeId	LastName	FirstName	Title	ReportsTo	BirthDate	HireDate	Address	City	State	Country	PostalCode	Phone	Fax	Email
1	Adams	Andrew	General Manager	None	1962-02-18 00:00:00	2002-08-14 00:00:00	11120 Jasper Ave NW	Edmonton	AB	Canada	T5K 2N1	+1 (780) 428-9482	+1 (780) 428-3457	andrew@chinookcorp.com
2	Edwards	Nancy	Sales Manager	1	1958-12-08 00:00:00	2002-05-01 00:00:00	825 8 Ave SW	Calgary	AB	Canada	T2P 2T3	+1 (403) 262-34

In [16]:
from langchain.globals import set_debug

# Enable LangChain's global debug mode. Assigning True to an imported `debug`
# name only rebinds a notebook variable; set_debug() changes the actual setting.
set_debug(True)

import getpass
import os
from langchain.chat_models import init_chat_model
# from langchain_core.prompts import PromptTemplate
# from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Load variables from a .env file, overriding values already present in the
# environment, so a GROQ_API_KEY stored there is picked up before the check below.
load_dotenv(override=True)

# Prompt for the key interactively if GROQ_API_KEY is still not set.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"])  # the argument must be a list of table names
# print(f"\n Original table info: {table_info}")

# Initialise the Llama chat model, served through the Groq provider.
llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

# Chain that turns a natural-language question into a SQL query (default prompt).
write_chain = create_sql_query_chain(llm, db)
# Unvalidated query, kept for inspection; uncomment the print to compare it
# with the validated one produced further down.
response = write_chain.invoke({"question": "What name of MediaType is?"})
# print(response,'\n')

# System prompt for the validation pass. Every checklist item sits on its own
# line: a trailing backslash inside the triple-quoted string would act as a
# line continuation and merge two bullets into one.
system = """Double check the user's {dialect} query for common mistakes, including:
- Only return SQL Query not anything else like ```sql ... ```
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query.
If there are no mistakes, just reproduce the original query with no further commentary.

Output the final SQL query only."""

# {dialect} is bound once here; {query} is supplied when the chain runs.
prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)

validation_chain = prompt | llm | StrOutputParser()

# Generate the query with write_chain, then pass it to the validator as "query".
full_chain = {"query": write_chain} | validation_chain
query = full_chain.invoke(
    {
        "question": "What name of MediaType is?"
    }
)
# print(query)

# Kept as the final expression so the query result is rendered as the cell output.
db.run(query)


"[('MPEG audio file',), ('Protected AAC audio file',), ('Protected MPEG-4 video file',), ('Purchased AAC audio file',), ('AAC audio file',)]"