**Reference: https://python.langchain.com/v0.1/docs/use_cases/sql/large_db/**

What happens in this notebook:

### **Table Model Definition**
   - **`Table` Class**: This is a simple Pydantic model representing a SQL table. It has one attribute, `name`, which is a string and is described as "Name of table in SQL database."
     - This model is used in the extraction process to match relevant SQL tables based on the user's query.

### **Helper Function - `get_tables`**
   - **`get_tables`**: This function takes a list of `Table` objects (i.e., categories such as "Music" or "Business") and returns a list of corresponding SQL table names based on the category.
     - For example, if the category is `"Music"`, the tables `"Album"`, `"Artist"`, `"Genre"`, etc., are added to the result.
     - Similarly, for `"Business"`, the corresponding tables like `"Customer"`, `"Employee"`, etc., are included.
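
The category-to-table mapping can also be kept in a dictionary, which avoids a growing `if`/`elif` chain as categories are added. The following is a minimal sketch, not the notebook's implementation: `Category` is a hypothetical stand-in dataclass for the `Table` model (so the snippet runs without Pydantic), and `CATEGORY_TABLES` and `tables_for` are illustrative names.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Category:
    """Hypothetical stand-in for the notebook's `Table` model."""
    name: str


# Assumed grouping, copied from the Music/Business split used in this notebook.
CATEGORY_TABLES: Dict[str, List[str]] = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def tables_for(categories: List[Category]) -> List[str]:
    """Return the table names for the given categories, without duplicates."""
    tables: List[str] = []
    for category in categories:
        # Unknown categories contribute nothing, like the if/elif version.
        for table in CATEGORY_TABLES.get(category.name, []):
            if table not in tables:
                tables.append(table)
    return tables


print(tables_for([Category("Business"), Category("Business")]))
```

Unlike the `if`/`elif` version, this variant also drops duplicates when the model returns the same category twice.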

### **Designing the agent for the large DB**

- **Step 1: Initialize LLM (`sql_agent_llm`)**: The LLM is instantiated with a given model name (e.g., `"llama3-70b-8192"`, served through Groq in this notebook) and temperature. The temperature controls how random the model's responses are.
- **Step 2: Connect to the SQL Database (`db`)**: The connection to the Chinook SQLite database is established. The database URI is constructed using the `sqldb_directory` provided.
- **Step 3: Define Category Chain (`category_chain`)**: The `category_chain_system` is defined, which is a string explaining the categories available (like "Music" and "Business"). This chain determines which SQL tables are relevant to the user query based on the category.
- **Step 4: Chain Creation**:
  - **`category_chain`**: This uses the `create_extraction_chain_pydantic` function, which creates an extraction chain that identifies relevant SQL tables from the user's question using the `Table` Pydantic model and the LLM.
  - **`table_chain`**: A chain is formed by combining the output from `category_chain` with the `get_tables` function, so it maps categories to the actual SQL tables.
- **Step 5: Query Chain (`query_chain`)**: `create_sql_query_chain` builds a chain from the LLM and the database (`self.db`). Given the question and the selected SQL tables, it generates a SQL query.
- **Step 6: Table Chain Input Handling**: The `"question"` key from the user input is mapped to the `"input"` key expected by the `table_chain`. This enables the chain to process user queries correctly.
- **Step 7: Full Chain Construction**: Finally, the full chain (`full_chain`) is created by combining:
  1. **`RunnablePassthrough.assign`**: This step adds `table_names_to_use` to the input, using the result of the `table_chain`.
  2. **`query_chain`**: Generates the SQL query once the relevant tables are identified. The query is then executed separately with `db.run`.
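
The data flow in Steps 6 and 7 can be traced in plain Python. This is a rough sketch with no LangChain or LLM involved: `fake_table_chain` and `fake_query_chain` are hypothetical stubs, and `assign` only mimics the behavior of `RunnablePassthrough.assign`, which keeps the input dictionary and adds the computed keys to it.

```python
from typing import Any, Callable, Dict, List


def assign(**steps: Callable[[Dict[str, Any]], Any]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Mimic RunnablePassthrough.assign: keep the input dict, add computed keys."""
    def run(inputs: Dict[str, Any]) -> Dict[str, Any]:
        computed = {key: step(inputs) for key, step in steps.items()}
        return {**inputs, **computed}
    return run


def fake_table_chain(inputs: Dict[str, Any]) -> List[str]:
    """Stub for table_chain: a keyword check stands in for the LLM call."""
    question = inputs["question"].lower()
    if "genre" in question or "song" in question:
        return ["Album", "Artist", "Genre", "Track"]
    return ["Customer", "Invoice"]


def fake_query_chain(inputs: Dict[str, Any]) -> str:
    """Stub for query_chain: report what a real chain would receive."""
    tables = ", ".join(inputs["table_names_to_use"])
    return f"-- question: {inputs['question']}\n-- tables: {tables}"


def full_chain(inputs: Dict[str, Any]) -> str:
    """Add the table names to the input, then hand everything to the query step."""
    enriched = assign(table_names_to_use=fake_table_chain)(inputs)
    return fake_query_chain(enriched)


print(full_chain({"question": "What are all the genres of Alanis Morissette songs"}))
```

The key point is that the original `question` survives the `assign` step, so the query step sees both the question and the table list.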

In [7]:
import os
import getpass
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from pprint import pprint

from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

load_dotenv()

True

**Set the environment variables and load the LLM**

In [8]:
# Prompt the user for GROQ_API_KEY if it is not already set
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
# Initialize the Llama model, using Groq as the backend
table_extractor_llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)

In [9]:
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [12]:
db.run('''
    SELECT DISTINCT "g"."Name"
    FROM "Track" AS "t"
    JOIN "Album" AS "a" ON "t"."AlbumId" = "a"."AlbumId"
    JOIN "Genre" AS "g" ON "t"."GenreId" = "g"."GenreId"
    JOIN "Artist" AS "ar" ON "a"."ArtistId" = "ar"."ArtistId"
    WHERE "ar"."Name" = 'Alanis Morissette'
    ORDER BY "g"."Name";
''')


"[('Rock',)]"

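`db.run` returns the rows as a string, as the outputs above show. If downstream code needs Python objects, the standard-library `ast.literal_eval` can parse that representation. A small sketch, assuming the result contains only simple literals such as strings and numbers; `parse_rows` is an illustrative name, not part of LangChain.

```python
import ast
from typing import Any, List, Tuple


def parse_rows(result: str) -> List[Tuple[Any, ...]]:
    """Parse the string returned by db.run into a list of row tuples."""
    if not result.strip():
        return []  # assumption: an empty string means the query returned no rows
    return ast.literal_eval(result)


rows = parse_rows("[('Rock',)]")
genres = [row[0] for row in rows]
print(genres)
```

`ast.literal_eval` only evaluates literals, so it is a safer choice than `eval` for text that originates from a database.
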
**Prepare the `Table` class**

In [4]:
from pydantic import BaseModel, Field

class Table(BaseModel):
    """
    Represents a table in the SQL database.

    Attributes:
        name (str): The name of the table in the SQL database.
    """
    name: str = Field(description="Name of table in SQL database.")

### **Strategy A:**

In [6]:
table_names = "\n".join(db.get_usable_table_names())
pprint(table_names)

('Album\n'
 'Artist\n'
 'Customer\n'
 'Employee\n'
 'Genre\n'
 'Invoice\n'
 'InvoiceLine\n'
 'MediaType\n'
 'Playlist\n'
 'PlaylistTrack\n'
 'Track')


In [8]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
table_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})



[Table(name='Genre'),
 Table(name='Artist'),
 Table(name='Track'),
 Table(name='Album')]
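
Because Strategy A lets the model write table names freely, it may help to check its answer against the real schema before using it. A minimal sketch, assuming the names arrive as plain strings; `keep_known_tables` is a hypothetical helper, not a LangChain function.

```python
from typing import Iterable, List

# Table names as reported by db.get_usable_table_names() earlier in this notebook.
KNOWN_TABLES = [
    "Album", "Artist", "Customer", "Employee", "Genre", "Invoice",
    "InvoiceLine", "MediaType", "Playlist", "PlaylistTrack", "Track",
]


def keep_known_tables(candidates: Iterable[str], known: Iterable[str] = KNOWN_TABLES) -> List[str]:
    """Drop names that are not real tables, fixing letter case and duplicates."""
    by_lowercase = {name.lower(): name for name in known}
    result: List[str] = []
    for candidate in candidates:
        match = by_lowercase.get(candidate.strip().lower())
        if match is not None and match not in result:
            result.append(match)
    return result


print(keep_known_tables(["Genre", "artist", "Songs", "Track", "Genre"]))
# prints ['Genre', 'Artist', 'Track']
```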

### **Strategy B:**

Music:

- "Album"
- "Artist"
- "Genre"
- "MediaType"
- "Playlist"
- "PlaylistTrack"
- "Track"

Business:

- "Customer"
- "Employee"
- "Invoice"
- "InvoiceLine"

In [10]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = f"""You will receive a question.

If the question is about **Music**, return **ALL** these tables:
  - "Album"
  - "Artist"
  - "Genre"
  - "MediaType"
  - "Playlist"
  - "PlaylistTrack"
  - "Track"

If the question is about **Business**, return **ALL** these tables:
  - "Customer"
  - "Employee"
  - "Invoice"
  - "InvoiceLine"

If you are unsure, return the full list of all available tables for both Music and Business categories."""
table_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Album'),
 Table(name='Artist'),
 Table(name='Genre'),
 Table(name='MediaType'),
 Table(name='Playlist'),
 Table(name='PlaylistTrack'),
 Table(name='Track')]
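
The Strategy B prompt lists the tables by hand, so it can drift from the actual grouping as the database changes. One option is to generate the prompt text from a single mapping. This is a sketch only: `build_system_prompt` and `CATEGORY_TABLES` are illustrative names, and the wording follows the prompt used in this section.

```python
from typing import Dict, List

# Assumed grouping, matching the Music/Business lists in this section.
CATEGORY_TABLES: Dict[str, List[str]] = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def build_system_prompt(category_tables: Dict[str, List[str]]) -> str:
    """Render one 'If the question is about ...' block per category."""
    parts = ["You will receive a question."]
    for category, tables in category_tables.items():
        bullet_list = "\n".join(f'  - "{table}"' for table in tables)
        parts.append(
            f"If the question is about **{category}**, return **ALL** these tables:\n{bullet_list}"
        )
    names = " and ".join(category_tables)
    parts.append(
        "If you are unsure, return the full list of all available tables "
        f"for both {names} categories."
    )
    return "\n\n".join(parts)


print(build_system_prompt(CATEGORY_TABLES))
```

The resulting string could be passed as `system_message` in place of the hand-written prompt.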

### **Strategy C:**

- **Step 1: Define the category**

In [19]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

def get_tables(categories: List[Table]) -> List[str]:
    """Maps category names to corresponding SQL table names.

    Args:
        categories (List[Table]): A list of `Table` objects representing different categories.

    Returns:
        List[str]: A list of SQL table names corresponding to the provided categories.
    """
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables 

In [18]:
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Music')]

- **Step 2: Execute the python function**

In [20]:
def get_tables(categories: List[Table]) -> List[str]:
    """Maps category names to corresponding SQL table names.

    Args:
        categories (List[Table]): A list of `Table` objects representing different categories.

    Returns:
        List[str]: A list of SQL table names corresponding to the provided categories.
    """
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables 
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']

### **Final step:**

**Attach the desired strategy to your SQL agent**

In [24]:
# Define a custom prompt template for generating SQL queries
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)


# Build the input dictionary containing the dialect, table schema, question, and row limit
# input_data = {
#     "dialect": db.dialect,                    # database dialect, e.g. "sqlite"
#     "table_info": db.get_table_info(),          # table schema information
#     "question": "What name of MediaType is?",
#     "top_k": 5
# }

In [25]:
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter

from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

def get_tables(categories: List[Table]) -> List[str]:
    """Maps category names to corresponding SQL table names.

    Args:
        categories (List[Table]): A list of `Table` objects representing different categories.

    Returns:
        List[str]: A list of SQL table names corresponding to the provided categories.
    """
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables 

# Define a custom prompt template for generating SQL queries
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt)

# Use bind to attach fixed parameters to the chain
bound_chain = query_chain.bind(
    dialect=db.dialect,
    table_info=db.get_table_info(),
    top_k=55
)
        
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | bound_chain

query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs"}
)
print(query)
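
The purpose of `table_names_to_use` is to keep the schema text that reaches the prompt small. The effect can be illustrated with the standard-library `sqlite3` module on a tiny in-memory database. The toy schema and the `schema_for` helper are assumptions for illustration, not the Chinook schema or LangChain's implementation.

```python
import sqlite3
from typing import List

# A toy database with three tables (not the real Chinook schema).
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, Total REAL);
    """
)


def schema_for(connection: sqlite3.Connection, table_names: List[str]) -> str:
    """Return the CREATE TABLE statements for the requested tables only."""
    placeholders = ", ".join("?" for _ in table_names)
    rows = connection.execute(
        "SELECT sql FROM sqlite_master "
        f"WHERE type = 'table' AND name IN ({placeholders}) ORDER BY name",
        table_names,
    ).fetchall()
    return "\n\n".join(row[0] for row in rows)


# Only the tables picked by the table chain end up in the prompt.
print(schema_for(conn, ["Artist", "Genre"]))
```

With eleven tables the saving is modest, but on a database with hundreds of tables this kind of filtering decides whether the schema fits in the context window at all.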

**Test the agent**

In [42]:
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic

# System message asking the LLM to return the SQL table categories relevant to the question
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""

# Initialize the LLM
table_extractor_llm = init_chat_model("llama3-70b-8192", model_provider="groq", temperature=0)

# Create the extraction chain, which turns the user question into instances of the Table model
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

# Define a function that maps Table objects to concrete SQL table names
def get_tables(categories: List[Table]) -> List[str]:
    """Map category names to the corresponding list of SQL table names."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables

# Combine the category extraction chain with the mapping function to get a chain that returns a list of SQL table names
table_chain = category_chain | get_tables 

# Define a custom SQL prompt template for generating SQL queries
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Don't limit the results to {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

# Create the SQL query chain
query_chain = create_sql_query_chain(table_extractor_llm, db, prompt=custom_prompt)

# Use bind to attach fixed parameters to the SQL query chain
bound_chain = query_chain.bind(
    dialect=db.dialect,
    table_info=db.get_table_info(),
    top_k=55
)

# Copy the "question" key of the input to the "input" key while keeping the original data
table_chain = (lambda x: {**x, "input": x["question"]}) | table_chain

# Use RunnablePassthrough.assign to add the extracted table names to the context, then combine it with the SQL query chain
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | bound_chain

# Invoke the full chain to generate the SQL query
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs?non rep!"}
)
print(query)


SELECT DISTINCT g.Name 
FROM Track t 
JOIN Genre g ON t.GenreId = g.GenreId 
JOIN Album a ON t.AlbumId = a.AlbumId 
JOIN Artist ar ON a.ArtistId = ar.ArtistId 
WHERE ar.Name = 'Alanis Morissette';
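
Although the prompt asks for bare SQL, a model may still wrap its answer in a Markdown fence or prefix it with a label, and passing that text to `db.run` would fail. A defensive cleanup step could look like the sketch below. `clean_sql` is a hypothetical helper, and the two patterns it handles are assumptions about typical model output.

```python
import re

# A Markdown code fence, built indirectly so this snippet stays readable.
FENCE = "`" * 3


def clean_sql(raw: str) -> str:
    """Strip a surrounding Markdown fence and a leading 'SQLQuery:' label."""
    text = raw.strip()
    fenced = re.match(
        rf"^{FENCE}(?:sql)?\s*(.*?)\s*{FENCE}$",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if fenced:
        text = fenced.group(1)
    text = re.sub(r"^SQLQuery:\s*", "", text, flags=re.IGNORECASE)
    return text.strip()


print(clean_sql(FENCE + "sql\nSELECT DISTINCT g.Name FROM Genre g;\n" + FENCE))
# prints SELECT DISTINCT g.Name FROM Genre g;
```

Text that is already clean passes through unchanged, so the helper can be applied unconditionally before the query is executed.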


In [43]:
db.run(query)

"[('Rock',)]"

**Prepare the tool (Don't run the following cell)**

In [6]:
import os
import getpass
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from pprint import pprint
from pydantic import BaseModel
from langchain_core.tools import tool
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities import SQLDatabase

# Define the Pydantic model used to extract table categories
class Table(BaseModel):
    name: str

# Define a mapping function that converts category names into concrete SQL table names
def get_tables(categories: List[Table]) -> List[str]:
    """Map category names to the corresponding list of SQL table names."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend([
                "Album",
                "Artist",
                "Genre",
                "MediaType",
                "Playlist",
                "PlaylistTrack",
                "Track",
            ])
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables

class ChinookSQLAgent:
    """
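    An agent dedicated to querying the Chinook SQL database.
    It uses an LLM to interpret the user's question, determines which table
    categories are relevant, and generates the corresponding SQL query.

    Attributes:
        sql_agent_llm: The LLM used to interpret questions and generate SQL queries.
        db: The connection object for the Chinook database.
        full_chain: A chain that turns a user question into a SQL query.

    Args:
        sqldb_directory (str): Path to the Chinook SQLite database file.
        llm (str): Name of the LLM to use (for example "llama3-70b-8192"); it is loaded through the Groq provider.
        llm_temperature (float): Temperature of the LLM, which controls the randomness of its output.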
    """
    def __init__(self, sqldb_directory: str, llm: str, llm_temperature: float) -> None:
        # Initialize the LLM (the model name comes from `llm` and is served by Groq)
        self.sql_agent_llm = init_chat_model(llm, model_provider="groq", temperature=llm_temperature)

        # Connect to the Chinook SQLite database
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        print("Available tables:", self.db.get_usable_table_names())
        
        # Define the system prompt that tells the LLM to return the table categories relevant to the question
        category_chain_system = (
            "Return the names of the SQL tables that are relevant to the user question. "
            "The tables are:\n\nMusic\nBusiness"
        )
        # Create the extraction chain that pulls table categories out of the question (using the Pydantic model Table)
        category_chain = create_extraction_chain_pydantic(Table, self.sql_agent_llm, system_message=category_chain_system)
        # Convert the extracted categories into concrete SQL table names
        table_chain = category_chain | get_tables

        # Define the custom SQL prompt template
        custom_prompt = PromptTemplate(
            input_variables=["dialect", "input", "table_info", "top_k"],
            template=(
                "You are a SQL expert using {dialect}.\n"
                "Given the following table schema:\n"
                "{table_info}\n"
                "Generate a syntactically correct SQL query to answer the question: \"{input}\".\n"
                "Don't limit the results to {top_k} rows.\n"
                "Ensure the query uses DISTINCT to avoid duplicate rows.\n"
                "Return only the SQL query without any additional commentary or Markdown formatting."
            )
        )
        # Create the SQL query chain with the custom prompt template
        query_chain = create_sql_query_chain(self.sql_agent_llm, self.db, prompt=custom_prompt)

        # Convert the "question" key in the input to the "input" key that table_chain expects
        table_chain = {"input": itemgetter("question")} | table_chain

        # Use RunnablePassthrough.assign to add the extracted table names to the context, then pipe the result into the SQL query chain
        self.full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

    def run(self, query: str) -> str:
        """
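        Take the user's query, turn it into a SQL statement, and execute it against the Chinook database.

        Args:
            query (str): The user's natural-language query, for example "What are all the genres of Alanis Morisette songs? Do not repeat!"

        Returns:
            str: The result of executing the SQL query.
        """
        # Invoke the full chain to generate the SQL query
        sql_query = self.full_chain.invoke({"question": query})
        # Execute the generated SQL query and return the result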
        return self.db.run(sql_query)

# Expose the query function as a tool with the @tool decorator
@tool
def query_chinook_sqldb(query: str) -> str:
    """
    Query the Chinook SQL database. The input is the user's natural-language query.

    This function instantiates a ChinookSQLAgent and calls its run method to handle the query.
    """
    # Note: sqldb_directory must be the path to the database file; it is assumed to come from a global variable or configuration entry
    sqldb_directory = here("data/Chinook.db")
    agent = ChinookSQLAgent(
        sqldb_directory=sqldb_directory,  # e.g. TOOLS_CFG.chinook_sqldb_directory
        llm="llama3-70b-8192",              # e.g. TOOLS_CFG.chinook_sqlagent_llm
        llm_temperature=0
    )
    return agent.run(query)


query_chinook_sqldb('What are all the genres of Alanis Morisette songs')


Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[('Rock',)]"