In [1]:
import sqlite3

import requests
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool

import os
import getpass
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from pprint import pprint

from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# from langchain_community.tools.sql_database.tool import (
#     InfoSQLDatabaseTool,
#     ListSQLDatabaseTool,
#     QuerySQLCheckerTool,
#     QuerySQLDatabaseTool,
# )

# 强制覆盖已存在的环境变量
load_dotenv(override=True)



# def get_engine_for_chinook_db():
#     """Pull sql file, populate in-memory database, and create engine."""
#     url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
#     response = requests.get(url)
#     sql_script = response.text

#     connection = sqlite3.connect(":memory:", check_same_thread=False)
#     connection.executescript(sql_script)
#     return create_engine(
#         "sqlite://",
#         creator=lambda: connection,
#         poolclass=StaticPool,
#         connect_args={"check_same_thread": False},
#     )


# engine = get_engine_for_chinook_db()

# db = SQLDatabase(engine)

True

In [2]:
sqldb_directory = here("data/travel2.sqlite")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
print(db.dialect)
print(db.get_usable_table_names())

sqlite
['aircrafts_data', 'airports_data', 'boarding_passes', 'bookings', 'car_rentals', 'flights', 'hotels', 'seats', 'ticket_flights', 'tickets', 'trip_recommendations']


In [3]:
import getpass
import os

if not os.environ.get("GROQ_API_KEY"):
  os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

from langchain.chat_models import init_chat_model

llm = init_chat_model("llama3-70b-8192", model_provider="groq")

In [22]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [23]:
toolkit.get_tools()

[QuerySQLDataBaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000014955DE2AD0>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000014955DE2AD0>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x0000014955DE2AD0>),
 QuerySQLCheckerTool(description='Use this tool to 

In [12]:
from langchain import hub

prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

assert len(prompt_template.messages) == 1
print(prompt_template.input_variables)

['dialect', 'top_k']


In [15]:
system_message = prompt_template.format(dialect="SQLite", top_k=5)

In [24]:
from langgraph.prebuilt import create_react_agent

agent_executor = create_react_agent(llm, toolkit.get_tools())

In [36]:
example_query = "Hi there, what time is my flight? my passenger id is 8149 604011"

events = agent_executor.stream(
    {"messages": [("user", example_query)]},
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()


Hi there, what time is my flight? my passenger id is 8149 604011
Tool Calls:
  sql_db_query (call_pqhp)
 Call ID: call_pqhp
  Args:
    query: SELECT * FROM flight_schedules WHERE passenger_id = '8149 604011'
Name: sql_db_query

Error: (sqlite3.OperationalError) no such table: flight_schedules
[SQL: SELECT * FROM flight_schedules WHERE passenger_id = '8149 604011']
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Tool Calls:
  sql_db_list_tables (call_6jj0)
 Call ID: call_6jj0
  Args:
    tool_input:
Name: sql_db_list_tables

aircrafts_data, airports_data, boarding_passes, bookings, car_rentals, flights, hotels, seats, ticket_flights, tickets, trip_recommendations
Tool Calls:
  sql_db_schema (call_gmas)
 Call ID: call_gmas
  Args:
    table_names: flights, bookings
Name: sql_db_schema


CREATE TABLE bookings (
	book_ref TEXT, 
	book_date TIMESTAMP, 
	total_amount INTEGER
)

/*
3 rows from bookings table:
book_ref	book_date	total_amount
00000F	2024-03-20 01:21:03.561731+00:

In [4]:
from langchain.chat_models import ChatOpenAI
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import AgentType, tool, create_sql_agent

@tool
def my_first_awesome_tool(human_message: str) -> list:
    """
    Searches for my awesome tool that works.
    Input should be a human message as string in the form of a question or request.
    It returns a tuple.
    :param human_message: This is the input string.
    :return: A tuple.
    """
    return "it","works"

@tool
def my_second_awesome_tool(mfat_output: str) -> str:
    """
    Combines the output from my awesome tool that works into a string.
    Input should be a string.
    It returns a string with relevant information
    :param mfat_output: A string.
    :return: A string.
    """
    return " ".join(mfat_output.split(','))

prefix_template = """You are an agent designed to interact with a SQL database.

DO NOT check their schemas to understand their structure.

You do not care about the database schema.

You will ALWAYS search for my awesome tool that works.

You will ALWAYS combine the output from my awesome tool that works into a string. Create a short rhyme with the output and this will be your final answer.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I do not know" as the answer.
"""

format_instructions_template = """Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question"""

# llm = ChatOpenAI(model_name = LLM_MODEL, temperature = 0)
toolkit = SQLDatabaseToolkit(db =db, llm = llm)
agent = create_sql_agent(llm = llm,
                         toolkit = toolkit,
                         verbose = True,
                         prefix = prefix_template,
                         format_instructions = format_instructions_template,
                         agent_type = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                         # agent_type = AgentType.OPENAI_FUNCTIONS,
                         extra_tools = [my_first_awesome_tool, my_second_awesome_tool]
                        )
agent.run("Give me something that works.")

  agent.run("Give me something that works.")




[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
Action: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3maircrafts_data, airports_data, boarding_passes, bookings, car_rentals, flights, hotels, seats, ticket_flights, tickets, trip_recommendations[0m[32;1m[1;3mThought: I should query the schema of the tables to see what fields are available.
Action: sql_db_schema
Action Input: aircrafts_data, airports_data, boarding_passes, bookings, car_rentals, flights, hotels, seats, ticket_flights, tickets, trip_recommendations[0m[33;1m[1;3m
CREATE TABLE aircrafts_data (
	aircraft_code TEXT, 
	model TEXT, 
	range INTEGER
)

/*
3 rows from aircrafts_data table:
aircraft_code	model	range
773	Boeing 777-300	11100
763	Boeing 767-300	7900
SU9	Sukhoi Superjet-100	3000
*/


CREATE TABLE airports_data (
	airport_code TEXT, 
	airport_name TEXT,

'"It works, no worries, no strife, my awesome tool is always rife!"'