In [None]:
import sqlite3
import uuid
from pathlib import Path

from dotenv import load_dotenv
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
# Load environment variables (expects OPENAI_API_KEY in ../.env) BEFORE the chat
# model is constructed: ChatOpenAI picks up OPENAI_API_KEY from the environment
# when it is instantiated, so loading the .env file afterwards is too late on a
# fresh kernel.
load_dotenv("../.env")

model_name = "gpt-5-nano"
llm = ChatOpenAI(model=model_name)

In [None]:
db_name = "banking_customers_indian.db"

# Fail fast on a missing file: with a plain sqlite:/// URL SQLite would silently
# create an empty database, and the sample query below would only fail later
# with a confusing "no such table" error.
if not Path(db_name).is_file():
    raise FileNotFoundError(f"SQLite database not found: {Path(db_name).resolve()}")

# Setup LangChain SQLDatabase over a READ-ONLY connection (SQLite URI mode=ro),
# so no statement the agent generates can modify the data, whatever the prompt says.
db = SQLDatabase.from_uri(f"sqlite:///file:{db_name}?mode=ro&uri=true")

# `llm` is already created in the first cell; it is not rebuilt here.
print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f'Sample output: {db.run("SELECT * FROM consumers LIMIT 5;")}')

In [None]:
# Wrap the database (and the LLM) in LangChain's SQL toolkit and list the tools
# it exposes to the agent.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# One "name: description" entry per tool, each followed by a blank line.
for sql_tool in tools:
    print(f"{sql_tool.name}: {sql_tool.description}", end="\n\n")

In [None]:


# Define Custom System Prompt.
# `{dialect}` and `{top_k}` are filled in by str.format below. The column list is
# written by hand and must be kept in sync with the `consumers` table schema.
TOP_K = 5  # row limit the agent is told to apply when the user does not ask for a count

prompt_template = """
You are an expert banking assistant designed to interact with a SQL database containing customer financial data.
Your goal is to help users understand customer demographics, spending habits, and financial status.

The database has a table named `consumers` with the following columns:
- `customer_name`: Name of the customer.
- `account_number`: Unique identifier for the account.
- `bank_name`: Name of the bank where the account is held (e.g., SBI, HDFC).
- `monthly_average_balance`: Customer's average monthly balance in INR.
- `yearly_income`: Customer's annual income in INR.
- `place`: City and State of the customer (e.g., Mumbai, MH).
- `top_spending_category`: The category where the customer spends the most.

Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

Rules:
1. You can order the results by relevant columns (`yearly_income`, `monthly_average_balance`) to return interesting examples.
2. Never query for all columns from a specific table, only ask for the relevant columns given the question.
3. You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
4. DO NOT run any DML or DDL statements (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, etc.) against the database.
5. If asked about "rich" customers, assume it refers to high `yearly_income` or `monthly_average_balance`.

To start you should ALWAYS look at the tables in the database to see what you can query. Do NOT skip this step.
Then you should query the schema of the most relevant tables.
"""

system_prompt = prompt_template.format(dialect=db.dialect, top_k=TOP_K)

In [None]:
from langchain.agents import create_agent
from langchain.agents.middleware import HumanInTheLoopMiddleware 
from langgraph.checkpoint.memory import InMemorySaver 


# Gate every `sql_db_query` tool call behind human review before it runs; the
# in-memory checkpointer stores state per thread so a paused run can be resumed.
approval_middleware = HumanInTheLoopMiddleware(
    interrupt_on={"sql_db_query": True},
    description_prefix="Tool execution pending approval",
)

agent = create_agent(
    llm,
    tools,
    system_prompt=system_prompt,
    middleware=[approval_middleware],
    checkpointer=InMemorySaver(),
)

In [None]:
# First run: stream the full agent state after every step to see where it pauses.
query2 = "What is the average monthly balance of HDFC Bank customers?"
# Use a fresh thread on every run: re-using a thread that is still parked at an
# approval interrupt can leave an unanswered tool call in its message history.
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
import time

# Delay between streamed steps. NOTE(review): purpose not stated in the notebook
# (presumably to slow the output down or avoid API rate limits) — set to 0 to skip.
STEP_DELAY_SECONDS = 20

for step in agent.stream(
    {"messages": [{"role": "user", "content": query2}]},
    config,
    stream_mode="values",
):
    print(step)
    time.sleep(STEP_DELAY_SECONDS)  # ⏸ pause execution between streamed steps



In [None]:
# Capture the pending approval request from the last streamed step, if any.
if "__interrupt__" in step:
    print("INTERRUPTED:")
    interrupt = step["__interrupt__"][0]
else:
    # Without this branch `interrupt` would stay undefined (or stale from an
    # earlier run) and the inspection cells below would fail obscurely.
    print("No interrupt pending: the last step did not request approval.")

In [None]:
 # All interrupts carried by the last streamed step (a sequence; its first item is used below).
 pending_interrupts = step["__interrupt__"]
 pending_interrupts

In [None]:
 # The first (here: only inspected) interrupt of the last streamed step.
 first_interrupt = step["__interrupt__"][0]
 first_interrupt

In [None]:
# Raw payload attached to the pending interrupt.
pending_payload = interrupt.value
pending_payload

In [None]:
# `interrupt.value` appears to be a mapping (the next cell indexes it with
# "action_requests"), so iterating over it lists its top-level keys.
for payload_key in iter(interrupt.value):
    print(payload_key)

In [None]:
# Print the human-readable description of each tool call awaiting approval.
pending_actions = interrupt.value["action_requests"]
for action in pending_actions:
    print(action["description"])

In [None]:
# Second run: print only the newest message per step and, when the agent pauses
# for approval, the descriptions of the tool calls waiting to be approved.
query2 = "What is the average monthly balance of HDFC Bank customers?"
# Fresh thread per run so re-running this cell never continues a thread that is
# still parked at an approval interrupt. The resume cell below reuses `config`.
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
import time


for step in agent.stream(
    {"messages": [{"role": "user", "content": query2}]},
    config,
    stream_mode="values",
):
    if "__interrupt__" in step:
        print("INTERRUPTED:")
        interrupt = step["__interrupt__"][0]
        for request in interrupt.value["action_requests"]:
            print(request["description"])
    elif "messages" in step:
        step["messages"][-1].pretty_print()

In [None]:
from langgraph.types import Command 

# Resume the paused run on the same thread (`config`), approving every pending
# tool call. The middleware expects one decision per pending action request, so
# a single hard-coded decision would break when the model issued several
# `sql_db_query` calls in one turn.
decisions = [{"type": "approve"} for _ in interrupt.value["action_requests"]]

for step in agent.stream(
    Command(resume={"decisions": decisions}),
    config,
    stream_mode="values",
):
    if "messages" in step:
        step["messages"][-1].pretty_print()
    elif "__interrupt__" in step:
        # Another query needs approval; `interrupt` is updated so re-running
        # this cell approves it.
        print("INTERRUPTED:")
        interrupt = step["__interrupt__"][0]
        for request in interrupt.value["action_requests"]:
            print(request["description"])

In [None]:
# Show the agent's final message. If the resumed stream stopped at another
# approval request instead, that step has no "messages" key, so show the pending
# interrupt rather than raising KeyError.
step["messages"][-1] if "messages" in step else step["__interrupt__"]