In [5]:
import sqlite3

import requests
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool

import os
import getpass
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from pprint import pprint

from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# from langchain_community.tools.sql_database.tool import (
#     InfoSQLDatabaseTool,
#     ListSQLDatabaseTool,
#     QuerySQLCheckerTool,
#     QuerySQLDatabaseTool,
# )

# Load variables from the .env file, force-overriding any environment variables that already exist.
load_dotenv(override=True)



# def get_engine_for_chinook_db():
#     """Pull sql file, populate in-memory database, and create engine."""
#     url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
#     response = requests.get(url)
#     sql_script = response.text

#     connection = sqlite3.connect(":memory:", check_same_thread=False)
#     connection.executescript(sql_script)
#     return create_engine(
#         "sqlite://",
#         creator=lambda: connection,
#         poolclass=StaticPool,
#         connect_args={"check_same_thread": False},
#     )


# engine = get_engine_for_chinook_db()

# db = SQLDatabase(engine)

True

In [6]:
# Locate the SQLite file via pyprojroot's here() and wrap it in a LangChain SQLDatabase.
# Despite its name, sqldb_directory points at the database *file*; the flight tools
# defined later open it directly with sqlite3.connect().
sqldb_directory = here("data/travel2.sqlite")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
# Sanity check: print the SQL dialect and the table names the wrapper exposes.
print(db.dialect)
print(db.get_usable_table_names())

sqlite
['aircrafts_data', 'airports_data', 'boarding_passes', 'bookings', 'car_rentals', 'flights', 'hotels', 'seats', 'ticket_flights', 'tickets', 'trip_recommendations']


In [7]:
import getpass
import os

# Prompt for the Groq API key only when it is missing (or empty) in the environment,
# so the key never has to be written into the notebook.
if not os.environ.get("GROQ_API_KEY"):
  os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

from langchain.chat_models import init_chat_model

# Chat model shared by the SQL toolkit and the agents built in later cells.
llm = init_chat_model("llama3-70b-8192", model_provider="groq")

In [8]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# Bundle the database wrapper and the chat model into a SQL toolkit.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# Last expression of the cell: display the tools the toolkit provides.
toolkit.get_tools()

In [9]:
# Repeats the previous cell's call to display the toolkit's tools.
toolkit.get_tools()

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000001DE47A8D050>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000001DE47A8D050>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000001DE47A8D050>),
 QuerySQLCheckerTool(description='Use this tool to 

In [37]:
from langchain_core.tools import tool
from datetime import date, datetime
from typing import Optional
from langchain_core.runnables import ensure_config
import json

@tool
def search_flights(
    departure_airport: Optional[str] = None,
    arrival_airport: Optional[str] = None,
    start_time: Optional[date | datetime] = None,
    end_time: Optional[date | datetime] = None,
    limit: int = 20,
) -> str:
    """Search for flights based on departure airport, arrival airport, and departure time range.

    Args:
        departure_airport (Optional[str]): The airport code from which the flight is departing.
        arrival_airport (Optional[str]): The airport code where the flight is arriving.
        start_time (Optional[date | datetime]): The earliest departure time to filter flights.
        end_time (Optional[date | datetime]): The latest departure time to filter flights.
        limit (int): The maximum number of flights to return. Defaults to 20.

    Returns:
        str: A JSON-encoded list of objects where each object contains the details of a flight
            including all columns from the flights table.
    """

    def _as_sql_text(moment):
        # Produce the same text sqlite3's implicit default adapters would bind
        # (datetime -> isoformat(" "), date -> isoformat()). Those adapters are
        # deprecated since Python 3.12, so the conversion is done explicitly.
        # datetime is checked first because it is a subclass of date.
        if isinstance(moment, datetime):
            return moment.isoformat(" ")
        if isinstance(moment, date):
            return moment.isoformat()
        return moment

    # Every optional filter contributes one parameterized clause; values are
    # bound through placeholders and never interpolated into the SQL text.
    optional_filters = (
        ("departure_airport = ?", departure_airport),
        ("arrival_airport = ?", arrival_airport),
        ("scheduled_departure >= ?", _as_sql_text(start_time)),
        ("scheduled_departure <= ?", _as_sql_text(end_time)),
    )

    # "WHERE 1 = 1" lets each active filter be appended uniformly as " AND ...".
    query = "SELECT * FROM flights WHERE 1 = 1"
    params = []
    for clause, value in optional_filters:
        if value:
            query += f" AND {clause}"
            params.append(value)

    query += " LIMIT ?"
    params.append(limit)

    conn = sqlite3.connect(sqldb_directory)
    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        rows = cursor.fetchall()
        column_names = [column[0] for column in cursor.description]
        results = [dict(zip(column_names, row)) for row in rows]
    finally:
        # Release the connection even when the query fails.
        conn.close()

    # The tool hands the agent a JSON string rather than Python objects.
    return json.dumps(results, ensure_ascii=False, indent=2)


@tool
def fetch_user_flight_information() -> list[dict]:
    """Fetch all tickets for the user along with corresponding flight information and seat assignments.

    Returns:
        A list of dictionaries where each dictionary contains the ticket details,
        associated flight details, and the seat assignments for each ticket belonging to the user.
    """
    # The docstring above doubles as the tool description, so it is kept as-is.
    # Raises ValueError when the active run config carries no passenger_id
    # under its "configurable" section.
    config = ensure_config()  # Fetch from the context
    configuration = config.get("configurable", {})
    passenger_id = configuration.get("passenger_id", None)
    if not passenger_id:
        raise ValueError("No passenger ID configured.")

    # Inner joins throughout: only ticket/flight pairs that also have a
    # boarding pass row are returned.
    query = """
    SELECT 
        t.ticket_no, t.book_ref,
        f.flight_id, f.flight_no, f.departure_airport, f.arrival_airport, f.scheduled_departure, f.scheduled_arrival,
        bp.seat_no, tf.fare_conditions
    FROM 
        tickets t
        JOIN ticket_flights tf ON t.ticket_no = tf.ticket_no
        JOIN flights f ON tf.flight_id = f.flight_id
        JOIN boarding_passes bp ON bp.ticket_no = t.ticket_no AND bp.flight_id = f.flight_id
    WHERE 
        t.passenger_id = ?
    """

    conn = sqlite3.connect(sqldb_directory)
    try:
        cursor = conn.cursor()
        cursor.execute(query, (passenger_id,))
        rows = cursor.fetchall()
        column_names = [column[0] for column in cursor.description]
        results = [dict(zip(column_names, row)) for row in rows]
    finally:
        # Release the connection even when the query fails.
        conn.close()

    return results

In [38]:
# Start from the tools the SQL toolkit already provides.
existing_tools = toolkit.get_tools()

# Extend that list with the two custom flight tools,
# search_flights and fetch_user_flight_information.
all_tools = [*existing_tools, search_flights, fetch_user_flight_information]

# The merged tool list is what the ReAct agent gets built from.
# agent_executor = create_react_agent(llm, all_tools)


In [2]:
from langchain import hub

# Pull the SQL-agent system prompt template from the LangChain Hub.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# The code below relies on the template holding exactly one message;
# then print the variable names that must be supplied when formatting it.
assert len(prompt_template.messages) == 1
print(prompt_template.input_variables)



['dialect', 'top_k']


In [3]:
# Fill in the two variables the template declares ('dialect' and 'top_k').
# NOTE(review): system_message is not referenced again in the cells visible here — confirm it reaches the agent.
system_message = prompt_template.format(dialect="SQLite", top_k=5)

In [4]:
from langgraph.prebuilt import create_react_agent

# Build the ReAct agent from the chat model and the merged tool list.
# This cell needs `llm` and `all_tools` from earlier cells; on a fresh kernel it
# fails with NameError unless those cells have been run first.
# NOTE(review): the formatted `system_message` is not passed here, so the agent runs
# without that system prompt — confirm whether it should be supplied (the keyword
# for it depends on the installed langgraph version).
agent_executor = create_react_agent(llm, all_tools)

NameError: name 'llm' is not defined

In [1]:
# Alternative prompt kept for reference:
# "Hi there, what time is my flight? my passenger id is 8149 604011"
example_query = "Search for all flights to Beijing on 2024-05-02"

config = {
    "configurable": {
        # The passenger_id is used in our flight tools to
        # fetch the user's flight information
        "passenger_id": "3442 587242",
        # Checkpoints are accessed by thread_id
        # "thread_id": thread_id,
    },
    # recursion_limit is a top-level key of the run config. Nested inside
    # "configurable" it is treated as an ordinary user value and the graph
    # silently keeps its default step limit.
    "recursion_limit": 50,
}

# Stream the full message state after every step and print the newest message.
# Requires `agent_executor` from the create_react_agent cell to exist.
events = agent_executor.stream(
    {"messages": [("user", example_query)]},
    config,
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()

NameError: name 'agent_executor' is not defined

In [4]:
from langchain.chat_models import ChatOpenAI
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import AgentType, tool, create_sql_agent

@tool
def my_first_awesome_tool(human_message: str) -> tuple[str, str]:
    """
    Searches for my awesome tool that works.
    Input should be a human message as string in the form of a question or request.
    It returns a tuple.
    :param human_message: This is the input string.
    :return: A tuple.
    """
    # The docstring above is used as the tool description, so its text is left as-is.
    # human_message is accepted but not used; the result is always the same pair.
    return ("it", "works")

@tool
def my_second_awesome_tool(mfat_output: str) -> str:
    """
    Combines the output from my awesome tool that works into a string.
    Input should be a string.
    It returns a string with relevant information
    :param mfat_output: A string.
    :return: A string.
    """
    # The docstring above is used as the tool description, so its text is left as-is.
    # Turning every comma into a single space gives exactly the same text as
    # splitting on commas and re-joining the pieces with spaces.
    combined = mfat_output.replace(",", " ")
    return combined

# Custom prefix passed to create_sql_agent below. It tells the model to skip
# schema inspection and to always call the two demo tools.
prefix_template = """You are an agent designed to interact with a SQL database.

DO NOT check their schemas to understand their structure.

You do not care about the database schema.

You will ALWAYS search for my awesome tool that works.

You will ALWAYS combine the output from my awesome tool that works into a string. Create a short rhyme with the output and this will be your final answer.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I do not know" as the answer.
"""

# ReAct-style output format passed to create_sql_agent below.
# {tool_names} is a placeholder left in the string; it is not filled in here.
format_instructions_template = """Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question"""

# llm = ChatOpenAI(model_name = LLM_MODEL, temperature = 0)
# Reuses `db` and `llm` from earlier cells. Note that this rebinds the
# notebook-level name `toolkit`.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# Legacy ReAct SQL agent with the custom prefix and format instructions, plus
# the two demo tools on top of the toolkit's own tools.
agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    prefix=prefix_template,
    format_instructions=format_instructions_template,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    # agent_type = AgentType.OPENAI_FUNCTIONS,
    extra_tools=[my_first_awesome_tool, my_second_awesome_tool],
)

# Chain.run is deprecated; invoke() is its replacement. Indexing "output"
# keeps the cell's displayed value a plain string, as run() returned.
agent.invoke({"input": "Give me something that works."})["output"]

  agent.run("Give me something that works.")




[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
Action: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3maircrafts_data, airports_data, boarding_passes, bookings, car_rentals, flights, hotels, seats, ticket_flights, tickets, trip_recommendations[0m[32;1m[1;3mThought: I should query the schema of the tables to see what fields are available.
Action: sql_db_schema
Action Input: aircrafts_data, airports_data, boarding_passes, bookings, car_rentals, flights, hotels, seats, ticket_flights, tickets, trip_recommendations[0m[33;1m[1;3m
CREATE TABLE aircrafts_data (
	aircraft_code TEXT, 
	model TEXT, 
	range INTEGER
)

/*
3 rows from aircrafts_data table:
aircraft_code	model	range
773	Boeing 777-300	11100
763	Boeing 767-300	7900
SU9	Sukhoi Superjet-100	3000
*/


CREATE TABLE airports_data (
	airport_code TEXT, 
	airport_name TEXT,

'"It works, no worries, no strife, my awesome tool is always rife!"'