**Reference: https://python.langchain.com/v0.1/docs/use_cases/sql/large_db/**

What happens in this notebook:

### **Table Model Definition**
   - **`Table` Class**: This is a simple Pydantic model representing a SQL table. It has one attribute, `name`, which is a string and is described as "Name of table in SQL database."
     - This model is used in the extraction process to match relevant SQL tables based on the user's query.

### **Helper Function - `get_tables`**
   - **`get_tables`**: This function takes a list of `Table` objects (i.e., categories such as "Music" or "Business") and returns a list of corresponding SQL table names based on the category.
     - For example, if the category is `"Music"`, the tables `"Album"`, `"Artist"`, `"Genre"`, etc., are added to the result.
     - Similarly, for `"Business"`, the corresponding tables like `"Customer"`, `"Employee"`, etc., are included.
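
A minimal sketch of the mapping described above, assuming a plain dataclass stands in for the Pydantic `Table` model and a dictionary holds the category-to-table lists. `CATEGORY_TABLES` is a hypothetical name; the table lists are the Chinook ones used later in this notebook.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Table:
    """Stand-in for the Pydantic model: holds one category (or table) name."""
    name: str


# Hypothetical lookup table; the lists mirror the Chinook categories in this notebook.
CATEGORY_TABLES: Dict[str, List[str]] = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def get_tables(categories: List[Table]) -> List[str]:
    """Expand category names into table names, skipping unknown categories and duplicates."""
    tables: List[str] = []
    for category in categories:
        for table in CATEGORY_TABLES.get(category.name, []):
            if table not in tables:
                tables.append(table)
    return tables


print(get_tables([Table(name="Business"), Table(name="Business")]))
```

Keeping the lists in a dictionary means a new category only needs a new entry, not a new `elif` branch.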

### **Designing the agent for the large DB**

- **Step 1: Initialize LLM (`sql_agent_llm`)**: The LLM is instantiated with a given model (e.g., `"gpt-3.5-turbo"`) and temperature. The temperature controls how creative/random the model's responses are.
- **Step 2: Connect to the SQL Database (`db`)**: The connection to the Chinook SQLite database is established. The database URI is constructed using the `sqldb_directory` provided.
- **Step 3: Define Category Chain (`category_chain`)**: The `category_chain_system` is defined, which is a string explaining the categories available (like "Music" and "Business"). This chain determines which SQL tables are relevant to the user query based on the category.
- **Step 4: Chain Creation**:
  - **`category_chain`**: This uses the `create_extraction_chain_pydantic` function, which creates an extraction chain that identifies relevant SQL tables from the user's question using the `Table` Pydantic model and the LLM.
  - **`table_chain`**: A chain is formed by combining the output from `category_chain` with the `get_tables` function, so it maps categories to the actual SQL tables.
- **Step 5: Query Chain (`query_chain`)**: This creates a SQL query chain using the LLM and the database (`self.db`). It takes the selected SQL tables and generates a query from the question.
- **Step 6: Table Chain Input Handling**: The `"question"` key from the user input is mapped to the `"input"` key expected by the `table_chain`. This enables the chain to process user queries correctly.
- **Step 7: Full Chain Construction**: Finally, the full chain (`full_chain`) is created by combining:
  1. **`RunnablePassthrough.assign`**: This sets up a step that assigns `table_names_to_use` from the result of the `table_chain`.
  2. **`query_chain`**: Generates the SQL query once the relevant tables are identified. The query is executed separately with `db.run`.
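
To make the data flow of Steps 6 and 7 concrete, here is a plain-Python sketch with no LangChain dependency. `assign`, `fake_table_chain`, and `fake_query_chain` are hypothetical stand-ins: the first mimics what `RunnablePassthrough.assign` does (keep the input dict and add computed keys), and the other two replace the LLM-backed chains with stubs.

```python
from typing import Any, Callable, Dict, List


def assign(**computed: Callable[[Dict[str, Any]], Any]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Return a step that keeps the input dict and adds one key per callable."""
    def step(inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {**inputs, **{key: fn(inputs) for key, fn in computed.items()}}
    return step


def fake_table_chain(inputs: Dict[str, Any]) -> List[str]:
    """Stub for table_chain: a keyword check instead of an LLM call."""
    question = inputs["question"].lower()
    if "song" in question:
        return ["Artist", "Album", "Track", "Genre"]
    return ["Customer", "Invoice"]


def fake_query_chain(inputs: Dict[str, Any]) -> str:
    """Stub for query_chain: reports which tables the real chain could use."""
    return "-- tables in scope: " + ", ".join(inputs["table_names_to_use"])


def full_chain(inputs: Dict[str, Any]) -> str:
    """Equivalent of `RunnablePassthrough.assign(...) | query_chain` in this sketch."""
    with_tables = assign(table_names_to_use=fake_table_chain)(inputs)
    return fake_query_chain(with_tables)


result = full_chain({"question": "Which genres do these songs belong to?"})
print(result)
```

The original `"question"` key survives the `assign` step, which is why the query chain can still read it after the table names have been added.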

In [3]:
import os
import getpass
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from pprint import pprint

from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# Override any environment variables that are already set
load_dotenv(override=True)


True

**Set the environment variables and load the LLM**

In [6]:
# Prompt the user for GROQ_API_KEY if it is not set
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
# Initialize the Llama model, using Groq as the backend
table_extractor_llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)

In [39]:
sqldb_directory = here("data/travel2.sqlite")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
print(db.dialect)
print(db.get_usable_table_names())

# "SELECT * FROM tickets LIMIT 10;"
# sql_txt = "SELECT * FROM flights WHERE destination = 'Beijing' AND departure_date > '2025-05-02';"
# db.run(sql_txt)

sqlite
['aircrafts_data', 'airports_data', 'boarding_passes', 'bookings', 'car_rentals', 'flights', 'hotels', 'seats', 'ticket_flights', 'tickets', 'trip_recommendations']


In [41]:
# Import LangChain's SQLDatabase class, used to connect to the database
from langchain_community.utilities import SQLDatabase
# Import SQLAlchemy's inspect function to read database metadata (tables, columns, etc.)
from sqlalchemy import inspect

# 1. Get the path of the database file
#    here() returns the absolute path of travel2.sqlite; this assumes here is already available (imported from pyprojroot above)
sqldb_directory = here("data/travel2.sqlite")

# 2. Create the database connection object with SQLDatabase.from_uri
#    f"sqlite:///{sqldb_directory}" builds the SQLite database URI
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
#    For example, if sqldb_directory is "/home/user/project/data/travel2.sqlite", the URI is "sqlite:////home/user/project/data/travel2.sqlite"

# 3. Create an inspector object with SQLAlchemy's inspect; it exposes the database metadata
#    db keeps the SQLAlchemy engine internally; db._engine is used here to create the Inspector
inspector = inspect(db._engine)

# 4. Get the names of all tables in the database
#    db.get_usable_table_names() returns a list of table names, e.g. ['flights', 'users', 'bookings']
table_names = db.get_usable_table_names()

# 5. Initialize a dictionary that stores each table and its columns
db_schema = {}

# 6. Loop over the tables and use inspector.get_columns to read each table's columns
for table in table_names:
    # inspector.get_columns(table) returns a list of dictionaries, one per column
    # For example, for the flights table it might return:
    # [{'name': 'id', 'type': 'INTEGER', ...}, {'name': 'departure_date', 'type': 'TEXT', ...}, ...]
    columns_info = inspector.get_columns(table)
    
    # 7. Extract the column name from each dictionary to build a list of column names
    columns = [col["name"] for col in columns_info]
    
    # 8. Store the column list in db_schema, keyed by table name
    db_schema[table] = columns

# 9. Print the full schema to check all tables and columns
print("Database schema:")
for table, columns in db_schema.items():
    print(f"Table {table} columns: {columns}")
#    Example output:
#      Table flights columns: ['id', 'departure_date', 'arrival_time']
#      Table users columns: ['id', 'username', 'email']

# 10. Turn the schema into a context string for the SQL generation model
schema_context = ""
for table, columns in db_schema.items():
    # Each table is formatted as "Table <name>: column1, column2, ...", one table per line
    schema_context += f"Table {table}: " + ", ".join(columns) + "\n"

# Print the generated context for inspection and debugging
print("\nSchema context:")
print(schema_context)
#    Example output:
#      Table flights: id, departure_date, arrival_time
#      Table users: id, username, email

# 11. Mock SQL generation function
#     (named mock_sql_query_chain so it does not shadow LangChain's create_sql_query_chain imported above)
def mock_sql_query_chain(query, context):
    """
    Mock function that generates a SQL statement:
    - query is the user's search value, e.g. "Beijing" to find flights whose destination is Beijing.
    - context is the schema context string with all tables and columns.
    
    The function checks whether the context contains a 'destination' column.
    If it does, it returns a SQL statement; otherwise it returns an error message.
    """
    if "destination" in context:
        # The context contains a destination column: build a query against the flights table
        return f"SELECT * FROM flights WHERE destination = '{query}' AND departure_date > '2025-05-02';"
    else:
        # The context has no destination column, so the generated SQL would fail; return a message instead
        return "Generated SQL is invalid: missing destination column"

# 12. Call the mock function with the example value "Beijing" and the schema context built above
generated_sql = mock_sql_query_chain("Beijing", schema_context)
print("\nGenerated SQL:")
print(generated_sql)
#    Note:
#      If the flights table has no destination column, the function returns the error message;
#      otherwise it returns a SQL query.

# 13. Close the database connection when finished
#    A SQLDatabase object does not necessarily need to be closed manually; no extra cleanup is assumed here


Database schema:
Table aircrafts_data columns: ['aircraft_code', 'model', 'range']
Table airports_data columns: ['airport_code', 'airport_name', 'city', 'coordinates', 'timezone']
Table boarding_passes columns: ['ticket_no', 'flight_id', 'boarding_no', 'seat_no']
Table bookings columns: ['book_ref', 'book_date', 'total_amount']
Table car_rentals columns: ['id', 'name', 'location', 'price_tier', 'start_date', 'end_date', 'booked']
Table flights columns: ['flight_id', 'flight_no', 'scheduled_departure', 'scheduled_arrival', 'departure_airport', 'arrival_airport', 'status', 'aircraft_code', 'actual_departure', 'actual_arrival']
Table hotels columns: ['id', 'name', 'location', 'price_tier', 'checkin_date', 'checkout_date', 'booked']
Table seats columns: ['aircraft_code', 'seat_no', 'fare_conditions']
Table ticket_flights columns: ['ticket_no', 'flight_id', 'fare_conditions', 'amount']
Table tickets columns: ['ticket_no', 'book_ref', 'passenger_id']
Table trip_recommendations columns: ['id', 'name', 'location', 'keywords', 'details', 'booked']

Schema context:
Table aircrafts_data: aircraft_co
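
The same kind of schema summary can be built with only the standard library, which is a quick way to cross-check what the inspector reports. This sketch assumes an in-memory database with two tables whose column names are copied from the output above (the column types are guesses); `build_schema_context` is a hypothetical helper.

```python
import sqlite3


def build_schema_context(conn: sqlite3.Connection) -> str:
    """Return one 'Table <name>: col1, col2, ...' line per table."""
    names = [
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
        )
    ]
    lines = []
    for name in names:
        # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk) per column.
        columns = [row[1] for row in conn.execute(f'PRAGMA table_info("{name}")')]
        lines.append(f"Table {name}: " + ", ".join(columns))
    return "\n".join(lines)


# Assumed toy schema: names from the notebook output, types made up for illustration.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE bookings (book_ref TEXT, book_date TEXT, total_amount INTEGER)")
conn.execute("CREATE TABLE seats (aircraft_code TEXT, seat_no TEXT, fare_conditions TEXT)")

schema_context = build_schema_context(conn)
print(schema_context)
```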



**Prepare the `Table` class**

In [11]:
from pydantic import BaseModel, Field

class Table(BaseModel):
    """
    Represents a table in the SQL database.

    Attributes:
        name (str): The name of the table in the SQL database.
    """
    name: str = Field(description="Name of table in SQL database.")

### **Strategy A:**

In [12]:
table_names = "\n".join(db.get_usable_table_names())
pprint(table_names)

('aircrafts_data\n'
 'airports_data\n'
 'boarding_passes\n'
 'bookings\n'
 'car_rentals\n'
 'flights\n'
 'hotels\n'
 'seats\n'
 'ticket_flights\n'
 'tickets\n'
 'trip_recommendations')


In [18]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
table_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
table_chain.invoke({"input": "Search for all hotels"})

[Table(name='hotels')]
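
Because Strategy A lets the model return free-form table names, it can help to drop any name that is not actually in the database before using the result. A minimal sketch, assuming the extraction output has already been reduced to a list of strings; `filter_known_tables` is a hypothetical helper, not a LangChain function.

```python
from typing import Iterable, List

# Table names as printed by db.get_usable_table_names() earlier in the notebook.
KNOWN_TABLES = [
    "aircrafts_data", "airports_data", "boarding_passes", "bookings", "car_rentals",
    "flights", "hotels", "seats", "ticket_flights", "tickets", "trip_recommendations",
]


def filter_known_tables(candidates: Iterable[str], known: Iterable[str] = KNOWN_TABLES) -> List[str]:
    """Keep candidates that match a known table (case-insensitive), in order, without duplicates."""
    lookup = {name.lower(): name for name in known}
    kept: List[str] = []
    for candidate in candidates:
        match = lookup.get(candidate.strip().lower())
        if match is not None and match not in kept:
            kept.append(match)
    return kept


print(filter_known_tables(["Hotels", "hotel_reviews", "hotels"]))
```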

### **Strategy B:**

flights:

- "aircrafts_data"
- "airports_data"
- "boarding_passes"
- "bookings"
- "flights"
- "seats"
- "ticket_flights"
- "tickets"

hotels:

- "hotels"

In [27]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = f"""You will receive a question.

If the question is about **flights**, return **ALL** these tables:
  - "aircrafts_data"
  - "airports_data"
  - "boarding_passes"
  - "bookings"
  - "flights"
  - "seats"
  - "ticket_flights"
  - "tickets"

If the question is about **hotels**, return **ALL** these tables:
  - "hotels"

If you are unsure, return the full list of all available tables for both the flights and hotels categories."""
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
category_chain.invoke({"input": "What are all the flights"})

[Table(name='aircrafts_data'),
 Table(name='airports_data'),
 Table(name='boarding_passes'),
 Table(name='bookings'),
 Table(name='flights'),
 Table(name='seats'),
 Table(name='ticket_flights'),
 Table(name='tickets')]
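
The Strategy B prompt hard-codes each table list. A small sketch of generating the same kind of prompt from a dictionary, so the lists live in one place. `CATEGORY_TABLES` and `build_category_prompt` are hypothetical names, and the wording only approximates the prompt above.

```python
from typing import Dict, List

CATEGORY_TABLES: Dict[str, List[str]] = {
    "flights": [
        "aircrafts_data", "airports_data", "boarding_passes", "bookings",
        "flights", "seats", "ticket_flights", "tickets",
    ],
    "hotels": ["hotels"],
}


def build_category_prompt(category_tables: Dict[str, List[str]]) -> str:
    """Render one 'If the question is about ...' block per category."""
    parts = ["You will receive a question.", ""]
    for category, tables in category_tables.items():
        parts.append(f"If the question is about **{category}**, return **ALL** these tables:")
        parts.extend(f'  - "{table}"' for table in tables)
        parts.append("")
    parts.append("If you are unsure, return the full list of all available tables.")
    return "\n".join(parts)


system = build_category_prompt(CATEGORY_TABLES)
print(system)
```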

### **Strategy C:**

- **Step 1: Define the category**

In [29]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

flights
hotels"""
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

- **Step 2: Execute the python function**

In [31]:
def get_tables(categories: List[Table]) -> List[str]:
    """Maps category names to corresponding SQL table names.

    Args:
        categories (List[Table]): A list of `Table` objects representing different categories.

    Returns:
        List[str]: A list of SQL table names corresponding to the provided categories.
    """
    tables = []
    for category in categories:
        if category.name == "flights":
            tables.extend(
                [
                    "aircrafts_data",
                    "airports_data",
                    "boarding_passes",
                    "bookings",
                    "flights",
                    "seats",
                    "ticket_flights",
                    "tickets"
                ]
            )
        elif category.name == "hotels":
            tables.extend(["hotels"])
    return tables


table_chain = category_chain | get_tables 
table_chain.invoke({"input": "Search for all flights to Beijing after 2025-05-02"})

['aircrafts_data',
 'airports_data',
 'boarding_passes',
 'bookings',
 'flights',
 'seats',
 'ticket_flights',
 'tickets']

### **Final step:**

**Attach the desired strategy to your SQL agent**

In [32]:
# Define a custom prompt template for generating SQL queries
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)


# Build the input dictionary with the dialect, table schema, question, and row limit
# input_data = {
#     "dialect": db.dialect,                    # database dialect, e.g. "sqlite"
#     "table_info": db.get_table_info(),          # table schema information
#     "question": "What name of MediaType is?",
#     "top_k": 5
# }
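
To see roughly what the model receives, the template can be filled by hand. This sketch uses plain `str.format` instead of `PromptTemplate`, with a one-table schema string and a `top_k` value chosen only for illustration; for a template with simple `{name}` placeholders like this one, the rendered text should be the same.

```python
# Same template text as custom_prompt above.
TEMPLATE = """You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""

# Illustrative values; in the notebook these come from the database and the user question.
prompt_text = TEMPLATE.format(
    dialect="sqlite",
    table_info="Table hotels: id, name, location, price_tier, checkin_date, checkout_date, booked",
    input="Search for all hotels",
    top_k=5,
)
print(prompt_text)
```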

In [33]:
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter

from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

flights
hotels"""
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
category_chain.invoke({"input": "Search for all flights to Beijing after 2025-05-02"})

[Table(name='flights')]

In [48]:
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.prompts import PromptTemplate
from typing import List

# 1. Define the system prompt and initialize the model used to extract table categories
#    (the categories must match the names that get_tables checks below)
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

flights
hotels"""
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

# 2. Build the extraction chain (uses the Pydantic schema Table, assumed to be defined already)
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

# 3. Define the function that maps extracted categories to concrete SQL table names
def get_tables(categories: List[Table]) -> List[str]:
    """Maps category names to corresponding SQL table names.
    
    Args:
        categories (List[Table]): A list of `Table` objects representing different categories.
    
    Returns:
        List[str]: A list of SQL table names corresponding to the provided categories.
    """
    tables = []
    for category in categories:
        if category.name == "flights":
            tables.extend(
                [
                    "aircrafts_data",
                    "airports_data",
                    "boarding_passes",
                    "bookings",
                    "flights",
                    "seats",
                    "ticket_flights",
                    "tickets"
                ]
            )
        elif category.name == "hotels":
            tables.extend(["hotels"])
    return tables

# 4. Pipe the extraction chain into the mapping function
table_chain = category_chain | get_tables

# 5. Define a custom prompt template for generating SQL queries
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

# 6. Build the SQL generation chain with create_sql_query_chain.
#    The chain fills {dialect} and {table_info} from db itself and takes {top_k} from k,
#    so no extra bind() step is needed.
query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt, k=55)

# 7. Map the input key "question" to "input" (the key that table_chain expects)
table_chain = {"input": itemgetter("question")} | table_chain

# 8. Compose the chains: RunnablePassthrough.assign adds table_names_to_use to the input,
#    which restricts the schema that the SQL generation chain puts into the prompt
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

# 9. Invoke full_chain; the only required input key is "question"
query = full_chain.invoke({
    "question": "Search for all flights to Beijing after 2025-05-02",
})

print(query)
# db.run(query)


SELECT * FROM flights WHERE destination = 'Beijing' AND departure_date > '2025-05-02';
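
The query printed above filters on `destination` and `departure_date`, yet the `flights` schema printed earlier lists neither column (it has `arrival_airport` and `scheduled_departure`). One way to catch this before calling `db.run` is to let SQLite compile the statement first. The sketch below assumes an empty in-memory table with a few of the `flights` columns (types are guesses) and a hypothetical helper `check_sql`; the corrected query and the airport code `PEK` are illustrative only.

```python
import sqlite3
from typing import Optional


def check_sql(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Return None if SQLite can compile the statement, else the error message."""
    try:
        # EXPLAIN makes SQLite compile the statement and describe it instead of running it.
        conn.execute("EXPLAIN " + sql)
        return None
    except sqlite3.Error as exc:
        return str(exc)


# Assumed subset of the flights schema; column types are made up for illustration.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE flights (flight_id INTEGER, flight_no TEXT, "
    "scheduled_departure TEXT, departure_airport TEXT, arrival_airport TEXT)"
)

generated = "SELECT * FROM flights WHERE destination = 'Beijing' AND departure_date > '2025-05-02';"
corrected = "SELECT * FROM flights WHERE arrival_airport = 'PEK' AND scheduled_departure > '2025-05-02';"

generated_error = check_sql(conn, generated)
corrected_error = check_sql(conn, corrected)
print(generated_error)
print(corrected_error)
```

A non-`None` result could be fed back to the model as a retry hint, though that loop is outside the scope of this notebook.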


**Test the agent**

In [42]:
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic

# Note: this cell uses the Chinook categories (Music, Business), so `db` must point at the Chinook database here

# System message asking the LLM to return the SQL table categories relevant to the question
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""

# Initialize the LLM
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

# Create the extraction chain: turns the user question into instances of the Table model
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

# Define a function that maps Table objects to concrete SQL table names
def get_tables(categories: List[Table]) -> List[str]:
    """Map category names to the corresponding list of SQL table names."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables

# Combine the category extraction chain with the mapping function to get a chain that returns SQL table names
table_chain = category_chain | get_tables 

# Define a custom SQL prompt template for generating SQL queries
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

# Create the SQL query chain.
# The chain fills {dialect} and {table_info} from db itself and takes {top_k} from k,
# so no extra bind() step is needed.
query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt, k=55)

# Copy the "question" key of the input to an "input" key while keeping the original data
table_chain = (lambda x: {**x, "input": x["question"]}) | table_chain

# Use RunnablePassthrough.assign to add the extracted table names to the input, then pipe into the SQL query chain
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

# Invoke the whole chain to generate the SQL query
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs?non rep!"}
)
print(query)


SELECT DISTINCT g.Name 
FROM Track t 
JOIN Genre g ON t.GenreId = g.GenreId 
JOIN Album a ON t.AlbumId = a.AlbumId 
JOIN Artist ar ON a.ArtistId = ar.ArtistId 
WHERE ar.Name = 'Alanis Morissette';


In [43]:
db.run(query)

"[('Rock',)]"

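`db.run` returned the rows as a string here (note the surrounding quotes in the output). If the rows are needed as Python objects, `ast.literal_eval` can parse that representation. A minimal sketch, assuming the string contains only plain literals such as strings and numbers; `parse_rows` is a hypothetical helper, and treating an empty string as "no rows" is a choice made for this sketch.

```python
import ast
from typing import Any, List, Tuple


def parse_rows(result: str) -> List[Tuple[Any, ...]]:
    """Parse the string form of a row list; an empty string is treated as no rows."""
    if not result.strip():
        return []
    return ast.literal_eval(result)


rows = parse_rows("[('Rock',)]")
genres = [row[0] for row in rows]
print(genres)
```

`ast.literal_eval` only evaluates Python literals, so it is a safer choice than `eval` for text that came out of a database.
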
**Prepare the tool (Don't run the following cell)**

In [6]:
import os
import getpass
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from pprint import pprint
from pydantic import BaseModel
from langchain_core.tools import tool
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities import SQLDatabase

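# Pydantic model used to extract table categories
class Table(BaseModel):
    name: str

# Mapping function that converts category names into concrete SQL table names
def get_tables(categories: List[Table]) -> List[str]:
    """Map category names to the corresponding list of SQL table names."""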
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend([
                "Album",
                "Artist",
                "Genre",
                "MediaType",
                "Playlist",
                "PlaylistTrack",
                "Track",
            ])
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables

class ChinookSQLAgent:
    """
    An agent dedicated to querying the Chinook SQL database.
    It uses an LLM to interpret the user's question, decide which table
    categories are relevant, and generate the SQL query that is then executed.
    
    Attributes:
        sql_agent_llm: The LLM used to interpret the question and generate the SQL query.
        db: The connection object for the Chinook database.
        full_chain: The chain that turns a user question into a SQL query.
    
    Constructor args:
        sqldb_directory (str): Path to the Chinook SQLite database file.
        llm (str): Name of the LLM model served by Groq (e.g. "llama3-70b-8192").
        llm_temperature (float): Temperature of the LLM, which controls the randomness of its output.
    """
    def __init__(self, sqldb_directory: str, llm: str, llm_temperature: float) -> None:
        # Initialize the LLM (the model named by `llm`, served by Groq)
        self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temperature)
        
        # Connect to the Chinook SQLite database
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        print("Available tables:", self.db.get_usable_table_names())
        
        # System prompt that tells the LLM to return the table categories relevant to the question
        category_chain_system = (
            "Return the names of the SQL tables that are relevant to the user question. "
            "The tables are:\n\nMusic\nBusiness"
        )
        # Create the extraction chain that pulls table categories from the question (using the Pydantic model Table)
        category_chain = create_extraction_chain_pydantic(Table, self.sql_agent_llm, system_message=category_chain_system)
        # Convert the extracted categories into concrete SQL table names
        table_chain = category_chain | get_tables
        
        # Define the custom SQL prompt template
        custom_prompt = PromptTemplate(
            input_variables=["dialect", "input", "table_info", "top_k"],
            template=(
                "You are a SQL expert using {dialect}.\n"
                "Given the following table schema:\n"
                "{table_info}\n"
                "Generate a syntactically correct SQL query to answer the question: \"{input}\".\n"
                "Don't limit the results to {top_k} rows.\n"
                "Ensure the query uses DISTINCT to avoid duplicate rows.\n"
                "Return only the SQL query without any additional commentary or Markdown formatting."
            )
        )
        # Create the SQL query chain with the custom prompt template
        query_chain = create_sql_query_chain(self.sql_agent_llm, self.db, prompt=custom_prompt)
        
        # Map the "question" key of the input to the "input" key that table_chain expects
        table_chain = {"input": itemgetter("question")} | table_chain
        
        # Use RunnablePassthrough.assign to add the extracted table names to the input, then pipe into the SQL query chain
        self.full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

    def run(self, query: str) -> str:
        """
        Take the user's question, turn it into a SQL query, and execute it on the Chinook database.
        
        Args:
            query (str): The user's natural-language question, e.g. "What are all the genres of Alanis Morisette songs? Do not repeat!"
        
        Returns:
            str: The result of executing the SQL query.
        """
        # Invoke the full chain to generate the SQL query
        sql_query = self.full_chain.invoke({"question": query})
        # Execute the generated SQL query and return the result
        return self.db.run(sql_query)

# Expose the query functionality as a tool with the @tool decorator
@tool
def query_chinook_sqldb(query: str) -> str:
    """
    Query the Chinook SQL database. The input is the user's question.
    
    This function instantiates a ChinookSQLAgent and calls its run method to handle the question.
    """
    # Note: sqldb_directory should be the path of the database file; here it is set directly, but it could come from a global variable or config
    sqldb_directory = here("data/Chinook.db")
    agent = ChinookSQLAgent(
        sqldb_directory=sqldb_directory,  # e.g. TOOLS_CFG.chinook_sqldb_directory
        llm="llama3-70b-8192",              # e.g. TOOLS_CFG.chinook_sqlagent_llm
        llm_temperature=0
    )
    return agent.run(query)


# A @tool-decorated function is a tool object; call it through .invoke with its arguments
query_chinook_sqldb.invoke({"query": "What are all the genres of Alanis Morisette songs"})


Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[('Rock',)]"