In [None]:
from langchain.sql_database import SQLDatabase
from langchain_groq import ChatGroq
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from pprint import pprint
from typing_extensions import TypedDict

import os
from dotenv import load_dotenv



For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from langchain_openai.chat_models.azure import AzureChatOpenAI
* 'allow_population_by_field_name' has been renamed to 'populate_by_name'


True

In [4]:
import getpass

# Make sure the Groq client can authenticate. If the key is not already in the
# environment (or is an empty string), ask for it interactively; getpass keeps
# the secret out of the notebook source and its saved output.
# Environment values must be strings, so a missing key has to be replaced with
# a real value rather than copied back from os.environ.get(), which yields None.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

In [5]:
from langchain_groq import ChatGroq

# Model identifier passed to Groq; kept in one place so it is easy to swap.
GROQ_MODEL_NAME = "llama3-8b-8192"

# Chat model shared by the SQL toolkit and the agent built further down.
llm = ChatGroq(model=GROQ_MODEL_NAME)

In [6]:
from langchain import hub

# Fetch the SQL query-writing prompt from the LangChain hub.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# The template is expected to hold exactly one message; check that, then
# pretty-print it to see its text and placeholders.
query_prompt_messages = query_prompt_template.messages
assert len(query_prompt_messages) == 1
query_prompt_messages[0].pretty_print()




Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Question: [33;1m[1;3m{input}[0m


In [7]:
from langchain_community.utilities import SQLDatabase

# Open the SQLite file that sits next to the notebook.
db = SQLDatabase.from_uri("sqlite:///travel.sqlite")

# Report the SQL dialect and the usable table names, one per line.
print(db.dialect, db.get_usable_table_names(), sep="\n")

# Final expression of the cell: a ten-row sample of the airports table,
# shown as the cell result.
db.run("SELECT * FROM airports_data LIMIT 10;")

sqlite
['aircrafts_data', 'airports_data', 'boarding_passes', 'bookings', 'flights', 'seats', 'ticket_flights', 'tickets']


'[(\'YKS\', \'{"en": "Yakutsk Airport", "ru": "Якутск"}\', \'{"en": "Yakutsk", "ru": "Якутск"}\', \'(129.77099609375,62.0932998657226562)\', \'Asia/Yakutsk\'), (\'MJZ\', \'{"en": "Mirny Airport", "ru": "Мирный"}\', \'{"en": "Mirnyj", "ru": "Мирный"}\', \'(114.03900146484375,62.534698486328125)\', \'Asia/Yakutsk\'), (\'KHV\', \'{"en": "Khabarovsk-Novy Airport", "ru": "Хабаровск-Новый"}\', \'{"en": "Khabarovsk", "ru": "Хабаровск"}\', \'(135.18800354004,48.5279998779300001)\', \'Asia/Vladivostok\'), (\'PKC\', \'{"en": "Yelizovo Airport", "ru": "Елизово"}\', \'{"en": "Petropavlovsk", "ru": "Петропавловск-Камчатский"}\', \'(158.453994750976562,53.1679000854492188)\', \'Asia/Kamchatka\'), (\'UUS\', \'{"en": "Yuzhno-Sakhalinsk Airport", "ru": "Хомутово"}\', \'{"en": "Yuzhno-Sakhalinsk", "ru": "Южно-Сахалинск"}\', \'(142.718002319335938,46.8886985778808594)\', \'Asia/Sakhalin\'), (\'VVO\', \'{"en": "Vladivostok International Airport", "ru": "Владивосток"}\', \'{"en": "Vladivostok", "ru": "Влад

In [18]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Bundle the database connection and the chat model into a SQL toolkit.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# Tool objects the agent is allowed to call.
tools = toolkit.get_tools()

# Last expression of the cell: display the tool list as the cell result.
tools

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7f7a22dd8af0>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7f7a22dd8af0>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7f7a22dd8af0>),
 QuerySQLCheckerTool(description='Use this tool to double check

In [19]:
from langchain import hub

# Fetch the system prompt for the SQL agent from the LangChain hub.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# The template is expected to contain a single message; verify that and
# pretty-print it so the placeholders to fill in are visible.
agent_prompt_messages = prompt_template.messages
assert len(agent_prompt_messages) == 1
agent_prompt_messages[0].pretty_print()




You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most [33;1m[1;3m{top_k}[0m results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

To start you should ALWAYS look at

In [20]:
# Values substituted into the {dialect} and {top_k} placeholders of the
# agent system prompt.
system_prompt_values = {"dialect": "SQLite", "top_k": 5}
system_message = prompt_template.format(**system_prompt_values)

In [21]:
import inspect

from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent

# Newer LangGraph releases take the system prompt through the `prompt`
# keyword and have deprecated `state_modifier`. Choose whichever keyword the
# installed version exposes so this cell runs on both old and new releases.
_agent_params = inspect.signature(create_react_agent).parameters
_system_prompt_kwarg = "prompt" if "prompt" in _agent_params else "state_modifier"

# ReAct-style agent that drives the SQL tools under the formatted system prompt.
agent_executor = create_react_agent(
    llm, tools, **{_system_prompt_kwarg: system_message}
)

In [25]:
question = "what is the name of tables ?"

# Initial agent state: one user message carrying the question.
agent_input = {"messages": [{"role": "user", "content": question}]}

# Stream the state after every step and show the newest message each time.
# `step` keeps the last streamed state once the loop ends.
for step in agent_executor.stream(agent_input, stream_mode="values"):
    latest_message = step["messages"][-1]
    latest_message.pretty_print()


what is the name of tables ?
Tool Calls:
  sql_db_list_tables (call_ngrm)
 Call ID: call_ngrm
  Args:
    tool_input:
Name: sql_db_list_tables

aircrafts_data, airports_data, boarding_passes, bookings, flights, seats, ticket_flights, tickets
Tool Calls:
  sql_db_schema (call_sde4)
 Call ID: call_sde4
  Args:
    table_names: aircrafts_data, airports_data
Name: sql_db_schema

Error: TypeError('must be real number, not str')
 Please fix your mistakes.
Tool Calls:
  sql_db_schema (call_bb41)
 Call ID: call_bb41
  Args:
    table_names: 'aircrafts_data, airports_data'
Name: sql_db_schema

Error: table_names {"'aircrafts_data", "airports_data'"} not found in database
Tool Calls:
  sql_db_list_tables (call_03yq)
 Call ID: call_03yq
  Args:
    tool_input:
Name: sql_db_list_tables

aircrafts_data, airports_data, boarding_passes, bookings, flights, seats, ticket_flights, tickets
Tool Calls:
  sql_db_schema (call_fjqb)
 Call ID: call_fjqb
  Args:
    table_names: flights, bookings
Name: sql_db

In [45]:
question = "Describe the tables of database "

# Wrap the question as a single user turn and start streaming agent states.
user_turn = {"role": "user", "content": question}
state_stream = agent_executor.stream({"messages": [user_turn]}, stream_mode="values")

# Print the most recent message of every streamed state; `step` is left
# holding the final state, which the following cell reads.
for step in state_stream:
    step["messages"][-1].pretty_print()


Describe the tables of database 
Tool Calls:
  sql_db_list_tables (call_mzf5)
 Call ID: call_mzf5
  Args:
Name: sql_db_list_tables

aircrafts_data, airports_data, boarding_passes, bookings, flights, seats, ticket_flights, tickets

Let me examine the tables you provided.

It seems like we have a database related to the aviation industry. The tables provided are:

* aircrafts_data: This table likely contains information about the aircrafts used for flights, such as their models, capacities, and other relevant details.
* airports_data: This table might contain information about the airports, such as their locations, codes, and other relevant details.
* boarding_passes: This table probably contains information about the boarding passes issued to passengers, including details such as flight numbers, passenger names, and seat numbers.
* bookings: This table likely contains information about the bookings made by passengers, including details such as booking dates, flight numbers, and passeng

In [47]:
# `step` is the last state left over from the streaming loop in the previous
# cell; its final message carries the agent's answer text.
final_message = step["messages"][-1]
print(final_message.content)

Let me examine the tables you provided.

It seems like we have a database related to the aviation industry. The tables provided are:

* aircrafts_data: This table likely contains information about the aircrafts used for flights, such as their models, capacities, and other relevant details.
* airports_data: This table might contain information about the airports, such as their locations, codes, and other relevant details.
* boarding_passes: This table probably contains information about the boarding passes issued to passengers, including details such as flight numbers, passenger names, and seat numbers.
* bookings: This table likely contains information about the bookings made by passengers, including details such as booking dates, flight numbers, and passenger names.
* flights: This table probably contains information about the flights, including details such as flight numbers, departure and arrival dates, and departure and arrival airports.
* seats: This table might contain informatio

In [28]:
question = "What is the total total_amount "

# Conversation seed handed to the agent: just the user's question.
conversation = [{"role": "user", "content": question}]

# Walk through the streamed states, pretty-printing the newest message of
# each one. After the loop `step` still refers to the final state.
for step in agent_executor.stream({"messages": conversation}, stream_mode="values"):
    newest = step["messages"][-1]
    newest.pretty_print()


What is the total total_amount 
Tool Calls:
  sql_db_list_tables (call_8rqz)
 Call ID: call_8rqz
  Args:
  sql_db_schema (call_zsnw)
 Call ID: call_zsnw
  Args:
    table_names: 
  sql_db_query (call_p7g8)
 Call ID: call_p7g8
  Args:
    query: SELECT SUM(total_amount) FROM table_name
Name: sql_db_query

Error: (sqlite3.OperationalError) no such table: table_name
[SQL: SELECT SUM(total_amount) FROM table_name]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Tool Calls:
  sql_db_list_tables (call_r93e)
 Call ID: call_r93e
  Args:
  sql_db_schema (call_fsk3)
 Call ID: call_fsk3
  Args:
    table_names: bookings,ticket_flights
Name: sql_db_schema

Error: TypeError('must be real number, not str')
 Please fix your mistakes.
Tool Calls:
  sql_db_query (call_qxdk)
 Call ID: call_qxdk
  Args:
    query: SELECT SUM(total_amount) FROM bookings
Name: sql_db_query

[(20766980900,)]

The total total_amount is 20766980900.


In [41]:
# Take the last message of the final streamed state and display its text as
# the cell result.
answer_message = step["messages"][-1]
answer_message.content

'The total total_amount is 20766980900.'