In [2]:
from pyprojroot import here
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
import os
from dotenv import load_dotenv
load_dotenv()

True

**Set the environment variables and load the LLM**

In [3]:
import getpass
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# Prompt for the Groq API key only when the environment does not already
# provide a non-empty value.
if not os.getenv("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

# Open the Chinook SQLite file, located relative to the project root.
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

# Table names are passed as a list, even when only one table is requested.
table_info = db.get_table_info(table_names=["Album"])
print(f"\n Original table info: {table_info}")

# Initialise the Llama chat model, served through the Groq backend.
llm = init_chat_model(
    "llama3-70b-8192",
    model_provider="groq",
    temperature=0,
)


 Original table info: 
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


**Load and test the SQLite database**

In [None]:
# Re-create the database handle from the Chinook SQLite file.
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

# Smoke test: show the SQL dialect and the tables exposed by the wrapper ...
print(db.dialect, db.get_usable_table_names(), sep="\n")

# ... then fetch a few Invoice rows. As the last expression of the cell, the
# raw result string is displayed as the cell output.
db.run("SELECT * FROM Invoice LIMIT 10;")

# Alternative (not executed here): the table names can also be listed with
# SQLAlchemy directly, via inspect(engine).get_table_names().

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 2, '2021-01-01 00:00:00', 'Theodor-Heuss-Straße 34', 'Stuttgart', None, 'Germany', '70174', 1.98), (2, 4, '2021-01-02 00:00:00', 'Ullevålsveien 14', 'Oslo', None, 'Norway', '0171', 3.96), (3, 8, '2021-01-03 00:00:00', 'Grétrystraat 63', 'Brussels', None, 'Belgium', '1000', 5.94), (4, 14, '2021-01-06 00:00:00', '8210 111 ST NW', 'Edmonton', 'AB', 'Canada', 'T6G 2C7', 8.91), (5, 23, '2021-01-11 00:00:00', '69 Salem Street', 'Boston', 'MA', 'USA', '2113', 13.86), (6, 37, '2021-01-19 00:00:00', 'Berger Straße 10', 'Frankfurt', None, 'Germany', '60316', 0.99), (7, 38, '2021-02-01 00:00:00', 'Barbarossastraße 19', 'Berlin', None, 'Germany', '10779', 1.98), (8, 40, '2021-02-01 00:00:00', '8, Rue Hanovre', 'Paris', None, 'France', '75002', 1.98), (9, 42, '2021-02-02 00:00:00', '9, Place Louis Barthou', 'Bordeaux', None, 'France', '33000', 3.96), (10, 46, '2021-02-03 00:00:00', '3 Chatham Street', 'Dublin', 'Dublin', 'Ireland', None, 5.94)]"

**Create the SQL agent chain and run a test query**

In [None]:
# Prompt for the SQL-writing step: the model is told the dialect, shown the
# schema, and asked to return bare SQL for the user's question.
custom_prompt = PromptTemplate(
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
""",
    input_variables=["dialect", "input", "table_info", "top_k"],
)

# Step 1: question -> SQL text.
write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)

# Step 2: SQL text -> raw query result.
execute_query = QuerySQLDataBaseTool(db=db)

# Step 3: question + SQL + result -> natural-language answer.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
    Question: {question}\n
    SQL Query: {query}\n
    SQL Result: {result}\n
    Answer:
    """
answer_prompt = PromptTemplate.from_template(system_role)
answer = answer_prompt | llm | StrOutputParser()

# Full pipeline: each assign step adds one key ("query", then "result") to
# the input dict before the answer step consumes it.
chain = (
    RunnablePassthrough.assign(query=write_query)
    | RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
    | answer
)



In [None]:
# Question to send through the text-to-SQL chain.
msg = "How many Playlist are there? and what are their names?"

# Input dictionary for the chain: SQL dialect, table schema information,
# the question itself and the row limit.
input_data = dict(
    dialect=db.dialect,              # database dialect, e.g. "sqlite"
    table_info=db.get_table_info(),  # table schema information
    question=msg,
    top_k=55,
)

response = chain.invoke(input_data)
response

"There are 13 playlists, and their names are:\n\n1. '90’s Music\n2. Audiobooks\n3. Brazilian Music\n4. Classical\n5. Classical 101 - Deep Cuts\n6. Classical 101 - Next Steps\n7. Classical 101 - The Basics\n8. Grunge\n9. Heavy Metal Classic\n10. Movies\n11. Music\n12. Music Videos\n13. On-The-Go 1\n14. TV Shows"

**Travel SQL-agent Tool Design**

In [51]:
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
import os
from dotenv import load_dotenv
load_dotenv()


class TravelSQLAgentTool:
    """
    Question-answering pipeline over a SQLite database, driven by a Groq-hosted chat model.

    A user question is turned into a SQL query by the language model, the query is executed
    against the SQLite database, and the question, query and raw result are handed back to
    the same model to phrase the final answer.

    Attributes:
        sql_agent_llm: Chat model created with `init_chat_model(..., model_provider="groq")`;
            used both to write the SQL and to phrase the answer.
        system_role (str): Prompt template for the answer step; expects `question`, `query`
            and `result`.
        db (SQLDatabase): Database wrapper the generated queries are executed against.
        chain: Runnable that takes `{"question": <str>}` and returns the answer as a string.

    Methods:
        __init__: Builds the language model, the database wrapper and the query-answering chain.
    """

    def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float, top_k: int = 55) -> None:
        """
        Build the model, the database connection and the question -> SQL -> answer chain.

        Args:
            llm (str): Name of the Groq-served model (for example "llama3-70b-8192").
            sqldb_directory (str): Path of the SQLite database file; it is interpolated into a
                `sqlite:///` URI, so a path-like object works as well.
            llm_temerature (float): Sampling temperature passed to the model. The misspelled
                name is kept because existing callers pass it by keyword.
            top_k (int): Value substituted for `{top_k}` in the SQL-generation prompt.
                Defaults to 55, the value the original code attempted to bind.
        """
        # Llama model served through the Groq backend, e.g. "llama-3.3-70b-specdec".
        self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temerature)

        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

        # Custom prompt for the SQL-generation step.
        # NOTE(review): the "Do not Limit {top_k} the results." sentence reads ambiguously once
        # the number is substituted; it is kept verbatim here so model behaviour is not changed
        # by anything other than the top_k value itself.
        custom_prompt = PromptTemplate(
            input_variables=["dialect", "input", "table_info", "top_k"],
            template="""You are a SQL expert using {dialect}.
        Given the following table schema:
        {table_info}
        Generate a syntactically correct SQL query to answer the question: "{input}".
        Do not Limit {top_k} the results.
        Return only the SQL query without any additional commentary or Markdown formatting.
        """
        )

        # SQL-writing step. The row limit is handed over through `k`, which is how
        # create_sql_query_chain fills `{top_k}`. The previous version tried to supply
        # dialect/table_info/top_k with `chain.bind(...)`, but bound kwargs are invoke-time
        # keyword arguments for the first runnable, not prompt variables, so the prompt always
        # received the default k=5. `dialect` and `table_info` are filled in by
        # create_sql_query_chain from `self.db`.
        write_query = create_sql_query_chain(
            self.sql_agent_llm, self.db, prompt=custom_prompt, k=top_k
        )
        execute_query = QuerySQLDataBaseTool(db=self.db)

        # Answer step: phrase a reply from the question, the SQL and its result.
        self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
            Question: {question}\n
            SQL Query: {query}\n
            SQL Result: {result}\n
            Answer:
            """
        answer_prompt = PromptTemplate.from_template(self.system_role)
        answer = answer_prompt | self.sql_agent_llm | StrOutputParser()

        def _log_generated_query(data: dict) -> None:
            """Print the SQL produced by the write_query step (debug aid)."""
            print("write_query execution result:", data["query"])

        # RunnablePassthrough calls the function for its side effect and forwards its input
        # unchanged, so the logger does not need to return the data itself.
        debug_chain = RunnablePassthrough(_log_generated_query)

        # Full pipeline:
        #   1. write_query stores the generated SQL under "query";
        #   2. debug_chain prints that SQL;
        #   3. execute_query runs it and stores the output under "result";
        #   4. answer turns question + query + result into the final reply.
        self.chain = (
            RunnablePassthrough.assign(query=write_query)
            | debug_chain
            | RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
            | answer
        )
        
sqldb_directory = here("data/Chinook.db")
query = "and calculate the number of all Playlist"

@tool
def query_travel_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # NOTE(review): the docstring above is what @tool exposes as the tool description, so it
    # is left untouched. It mentions a Swiss Airline database while this notebook points the
    # tool at Chinook.db — confirm which one is intended.
    #
    # A fresh TravelSQLAgentTool (model client + database wrapper) is built on every call.
    agent = TravelSQLAgentTool(
        llm="llama3-70b-8192",  # TOOLS_CFG.travel_sqlagent_llm
        sqldb_directory=sqldb_directory,  # TOOLS_CFG.travel_sqldb_directory
        llm_temerature=0,  # TOOLS_CFG.travel_sqlagent_llm_temperature
    )
    return agent.chain.invoke({"question": query})

# A @tool-decorated function is a tool object, not a plain function: run it through
# .invoke() with its arguments as a dict. Calling the object directly goes through the
# deprecated BaseTool.__call__.
print(query_travel_sqldb.invoke({"query": query}))

write_query execution result: SELECT COUNT(*) 
FROM Playlist;
There are 18 playlists.


In [None]:
query = "How many Playlist are there? and what are their names?"

# Input dictionary holding the SQL dialect, the table schema information,
# the question and the row limit.
input_data = dict(
    dialect=db.dialect,              # database dialect, e.g. "sqlite"
    table_info=db.get_table_info(),  # table schema information
    question=query,
    top_k=55,
)

In [27]:
# from agent_graph.load_tools_config import LoadToolsConfig

# TOOLS_CFG = LoadToolsConfig()

# The query_travel_sqldb tool (and the sqldb_directory it reads) is already defined in the
# "Travel SQL-agent Tool Design" cell earlier in this notebook. It is reused here rather than
# declared a second time, so there is a single definition to maintain.
query = "How many Playlist are there? and what are their names?"

# Tools are executed with .invoke(); calling the tool object directly relies on the
# deprecated BaseTool.__call__.
result = query_travel_sqldb.invoke({"query": query})
print(result)

There are 12 playlists, and their names are:

1. 90’s Music
2. Audiobooks
3. Brazilian Music
4. Classical
5. Classical 101 - Deep Cuts
6. Classical 101 - Next Steps
7. Classical 101 - The Basics
8. Grunge
9. Heavy Metal Classic
10. Movies
11. Music
12. Music Videos
13. On-The-Go 1
14. TV Shows


In [29]:
# Ground truth for the playlist count, queried directly without the LLM.
db.run("SELECT COUNT(*) FROM Playlist;")

'[(18,)]'

In [30]:
# Raw playlist names straight from the table. Several names occur twice
# (e.g. 'Music', 'Movies'), which is why a list of distinct names is shorter
# than the row count returned by COUNT(*).
db.run("SELECT Name FROM Playlist;")

"[('Music',), ('Movies',), ('TV Shows',), ('Audiobooks',), ('90’s Music',), ('Audiobooks',), ('Movies',), ('Music',), ('Music Videos',), ('TV Shows',), ('Brazilian Music',), ('Classical',), ('Classical 101 - Deep Cuts',), ('Classical 101 - Next Steps',), ('Classical 101 - The Basics',), ('Grunge',), ('Heavy Metal Classic',), ('On-The-Go 1',)]"