In [41]:
# Setting up api key for langsmith

import os,getpass 

# Turn LangSmith tracing on for this kernel session.
os.environ['LANGSMITH_TRACING'] = 'true'

# Keep a key that is already exported; only prompt (without echo) when it is missing or empty.
os.environ['LANGSMITH_API_KEY'] = (
    os.environ.get('LANGSMITH_API_KEY')
    or getpass.getpass('Enter Your LangSmith Api Key')
)

In [42]:
# # setting up api key for Gemini


# if not os.environ.get('GOOGLE_API_KEY'): 
#     os.environ['GOOGLE_API_KEY'] = getpass.getpass('Enter Your Google Api Key') 


# from google import genai

# # The API key is passed explicitly, read from the `GOOGLE_API_KEY` environment variable.
# client = genai.Client(api_key=os.environ.get('GOOGLE_API_KEY'))

# response = client.models.generate_content(
#     model="gemini-3-flash-preview", contents="Explain how AI works in a few words"
# )
# print(response.text)

# print('list of available models:') 
# for m in client.models.list(): 
#     print(m.name)

In [43]:
# # setting up the model 

# from langchain_google_genai import ChatGoogleGenerativeAI 

# model = ChatGoogleGenerativeAI(model='gemini-2.0-flash-001') 

In [44]:
# Reuse an exported Groq key when present; otherwise prompt for it without echoing.
os.environ['GROQ_API_KEY'] = (
    os.environ.get('GROQ_API_KEY')
    or getpass.getpass('Enter Your GROQ API KEY')
)


In [45]:
from langchain_groq import ChatGroq 
# Chat model handed to both the SQL toolkit and the agents created below.
GROQ_MODEL_NAME = 'llama-3.3-70b-versatile'
model = ChatGroq(model=GROQ_MODEL_NAME)

In [46]:
# Downloading the database 

import requests,pathlib 

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
local_path = pathlib.Path('Chinook.db')

# Seconds to wait on the server before giving up; without a timeout
# requests.get can block the cell indefinitely on a stalled connection.
DOWNLOAD_TIMEOUT_S = 30

if local_path.exists():
    print(f'{local_path} already exists, skipping download')
else:
    response = requests.get(url, timeout=DOWNLOAD_TIMEOUT_S)
    if response.status_code == 200:
        # Write to a sibling temp file and rename it into place, so an
        # interrupted write never leaves a truncated Chinook.db that the
        # exists() check above would silently accept on the next run.
        partial_path = local_path.with_name(local_path.name + '.part')
        partial_path.write_bytes(response.content)
        partial_path.replace(local_path)
        print(f'File downloaded and saved as {local_path}')
    else:
        print(f'Failed to download the file.Status code: {response.status_code}')

Chinook.db already exists, skipping download


In [47]:
from langchain_community.utilities import SQLDatabase 

# Open the local Chinook SQLite file through its database URI.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Quick sanity report: dialect, visible tables, and a handful of Artist rows.
print(f'dialect: {db.dialect}')
table_names = db.get_usable_table_names()
print(f'available tabels: {table_names}')
sample_rows = db.run("SELECT * FROM Artist LIMIT 5;")
print(f'sample output: {sample_rows}')

dialect: sqlite
available tabels: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]


In [48]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit 

# Build the SQL toolkit from the database handle and the chat model.
toolkit = SQLDatabaseToolkit(llm=model, db=db)
tools = toolkit.get_tools()

# List every tool with its description, separated by a blank line.
for tool in tools:
    summary = f'{tool.name}: {tool.description}'
    print(summary, end='\n\n')

sql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.

sql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3

sql_db_list_tables: Input is an empty string, output is a comma-separated list of tables in the database.

sql_db_query_checker: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!



In [49]:
# Instruction template for the SQL agent; {dialect} and {top_k} are filled in below.
SQL_AGENT_PROMPT_TEMPLATE = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
"""

# Row cap the prompt tells the agent to apply when the user gives no explicit count.
DEFAULT_TOP_K = 5

system_prompt = SQL_AGENT_PROMPT_TEMPLATE.format(
    dialect=db.dialect,
    top_k=DEFAULT_TOP_K,
)

In [50]:
# Display the formatted prompt, with the dialect and top_k placeholders already filled in.
print(system_prompt)


You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct sqlite query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most 5 results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.



In [51]:
from langchain.agents import create_agent 

# Plain SQL agent: chat model + toolkit tools + the formatted system prompt.
agent = create_agent(
    model=model,
    tools=tools,
    system_prompt=system_prompt,
)

In [52]:
question = 'How many tabels are there in the Database?'

# Send the question as a single user message and pretty-print the newest
# message of every streamed step.
user_turn = {'messages': [{'role': 'user', 'content': question}]}
for step in agent.stream(user_turn, stream_mode='values'):
    latest_message = step['messages'][-1]
    latest_message.pretty_print()


How many tabels are there in the Database?
Tool Calls:
  sql_db_list_tables (51h7968vb)
 Call ID: 51h7968vb
  Args:
    tool_input:
Name: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Tool Calls:
  sql_db_schema (cxvy6tc2a)
 Call ID: cxvy6tc2a
  Args:
    table_names: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Name: sql_db_schema


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aero

In [53]:
# Count the tables straight from sqlite_master (presumably as a cross-check
# of the agent's answer above - confirm).
table_count_sql = "SELECT count(name) FROM sqlite_master WHERE type='table';"
l = db.run(table_count_sql)
print(l)


[(11,)]


In [54]:
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware 
from langgraph.checkpoint.memory import InMemorySaver 


# Middleware configured to interrupt on the sql_db_query tool, with the
# given prefix on the approval request description.
query_approval = HumanInTheLoopMiddleware(
    interrupt_on={"sql_db_query": True},
    description_prefix="Tool execution pending approval",
)

# In-memory checkpointer passed to the agent below.
memory_checkpointer = InMemorySaver()

# Rebuild the agent with the same model, tools and prompt, now with the
# middleware and checkpointer attached.
agent = create_agent(
    model,
    tools,
    system_prompt=system_prompt,
    middleware=[query_approval],
    checkpointer=memory_checkpointer,
)

In [55]:
question = "How many tables are there in database ? "
# Thread id passed on every stream call for this conversation.
config = {"configurable": {"thread_id": "1"}}

user_turn = {"messages": [{"role": "user", "content": question}]}

# Stream the run: report pending approval requests when a step carries an
# interrupt, otherwise pretty-print the newest message of the step.
for step in agent.stream(user_turn, config, stream_mode="values"):
    if "__interrupt__" in step:
        print("INTERRUPTED:")
        interrupt = step["__interrupt__"][0]
        pending_requests = interrupt.value["action_requests"]
        for request in pending_requests:
            print(request["description"])
        continue
    if "messages" in step:
        step["messages"][-1].pretty_print()


How many tables are there in database ? 
Tool Calls:
  sql_db_list_tables (3ntpjyq66)
 Call ID: 3ntpjyq66
  Args:
    tool_input:
Name: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Tool Calls:
  sql_db_schema (5jhkkj3ce)
 Call ID: 5jhkkj3ce
  Args:
    table_names: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Name: sql_db_schema


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosm

In [56]:
from langgraph.types import Command 

# Resume on the same thread config, sending an "approve" decision.
approval = Command(resume={"decisions": [{"type": "approve"}]})

for step in agent.stream(approval, config, stream_mode="values"):
    if "messages" in step:
        step["messages"][-1].pretty_print()
        continue
    if "__interrupt__" in step:
        # A further interrupt arrived; list what is awaiting approval.
        print("INTERRUPTED:")
        interrupt = step["__interrupt__"][0]
        pending_requests = interrupt.value["action_requests"]
        for request in pending_requests:
            print(request["description"])


There are 11 tables in the database: 

1. Album
2. Artist
3. Customer
4. Employee
5. Genre
6. Invoice
7. InvoiceLine
8. MediaType
9. Playlist
10. PlaylistTrack
11. Track
