In [1]:
import os
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

OPENAI_API_KEY=os.getenv('AZURE_OPENAI_API_KEY_US')
OPENAI_API_KEY_E=os.getenv('AZURE_OPENAI_API_KEY_US2') 

# os.environ['OPENAI_API_TYPE'] = 'azure'
os.environ['OPENAI_API_VERSION'] = '2024-08-01-preview'
os.environ['AZURE_OPENAI_ENDPOINT'] = 'https://azure-chat-try-2.openai.azure.com/'
os.environ['AZURE_OPENAI_DEPLOYMENT'] = 'chat-endpoint-us-gpt4o'

os.environ['OPENAI_API_VERSION_E'] = '2024-12-01-preview'
os.environ['AZURE_OPENAI_ENDPOINT_E'] = 'https://agents-4on.openai.azure.com/'
os.environ['AZURE_OPENAI_EMBEDDING_DEPLOYMENT_E'] = "text-embedding-3-large-eus2"

# LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')
# os.environ['LANGCHAIN_TRACING_V2'] = 'true'
# os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
# os.environ['LANGCHAIN_PROJECT'] = "rag-sql"


In [3]:

from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings

llm = AzureChatOpenAI(
    api_key = OPENAI_API_KEY,  
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    openai_api_version=os.getenv("OPENAI_API_VERSION"),
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT")
)

emb_model = AzureOpenAIEmbeddings(
    api_key=OPENAI_API_KEY_E,
    azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT_E'),  
    api_version=os.getenv('OPENAI_API_VERSION_E'),
    azure_deployment=os.getenv('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_E')
)

In [5]:
# Database

from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///./database/credit-risk.db", sample_rows_in_table_info=2)
print(db.dialect)
print(db.get_usable_table_names())

# engine = create_engine("sqlite:///./database/credit-risk.db", future=True)

sqlite
['collaterals', 'customers', 'sectors', 'transactions']


In [5]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

vector_store = Chroma(
    collection_name="risk_db_tables",
    embedding_function=emb_model,
    persist_directory="./vector_db"
)

retriever = vector_store.as_retriever()

# # Prompt
# template = """Answer the question based only on the following context:
# {context}

# Question: {question}
# """
# prompt = ChatPromptTemplate.from_template(template)

# rag_chain = (
#     {"context": retriever, "question": RunnablePassthrough()}
#     | prompt
#     | llm
#     | StrOutputParser()
# )

# res = rag_chain.invoke("How can I get the total undrawn exposure per economic sector?")
# print(res)

In [6]:
class ChatState(MessagesState):
    retrieved_docs: List[Dict[str, Any]]
    sql_query: str
    query_results: List[Dict[str, Any]]
    final_answer: str


# class ChatState(TypedDict):
#     # MessagesState requirements
#     messages: Annotated[Sequence[BaseMessage], add_messages]
#     retrieved_docs: List[Dict[str, Any]]
#     sql_query: str
#     query_results: List[Dict[str, Any]]
#     final_answer: str

In [7]:
@tool
def db_retriever(question: str, limit = 4, runtime=None) -> Command:
    """
    Semantic retrieval from Chroma. Returns Command that:
      - appends a ToolMessage describing the retrieval to `messages` (MessagesState reducer will append),
      - updates `retrieved_docs`,
      - routes to "sql_writer".
    """
    results = vector_store.similarity_search(question, k=limit)
    docs_serial = [
        {"table": d.metadata.get("table"), "source_file": d.metadata.get("source_file"), "snippet": d.page_content[:800]}
        for d in results
    ]

    # Stream small progress to caller UI if available
    if runtime and getattr(runtime, "stream_writer", None):
        runtime.stream_writer(f"Retrieved {len(docs_serial)} docs from the schema store.")

    # Append a ToolMessage (messages reducer will append, not overwrite)
    tool_msg = ToolMessage(content=f"Retriever: found {len(docs_serial)} docs relevant to the question.")

    # Command: update messages and retrieved_docs, then goto sql_writer
    return Command(
        update={
            "messages": [tool_msg],         # MessagesState reducer handles appending/deserializing
            "retrieved_docs": docs_serial,  # Regular state key update (replaced/merged)
        },
        goto="sql_writer",
    )

In [8]:
@tool
def sql_writer(question: str, retrieved_docs: List[Dict[str, Any]], runtime: ToolRuntime | None = None) -> Command:
    """
    Produce a SQL query based on the question & retrieved schema snippets.
    Returns a Command that updates `sql_query`, appends a ToolMessage with the SQL,
    and routes to 'execute_sql'.
    """
    docs_text = "\n\n".join([f"Table: {d.get('table')}\nSnippet: {d.get('snippet')}" for d in retrieved_docs])

    prompt = f"""
        Produce a single SQL statement that answers the user's question. Return only the SQL.
        User question:
        {question}

        Relevant schema snippets:
        {docs_text}
    """
    
    # sql = llm(prompt).strip()
    response = llm.invoke(prompt)
    sql = response.content.strip()

    if runtime and getattr(runtime, "stream_writer", None):
        runtime.stream_writer("Generated SQL query.")

    tool_msg = ToolMessage(content=f"SQL generated: {sql}")

    return Command(
        update={"sql_query": sql, "messages": [tool_msg]},
        goto="execute_sql",
    )

In [10]:
# def start_node(state: ChatState): # or node_retriever()
#     """
#     Entry node. Expects the latest user message to be present in state.messages.
#     Calls db_retriever tool by returning its Command (tools may be called directly).
#     """
#     # The latest HumanMessage content
#     msgs = state.get("messages", [])
#     last_msg = msgs[-1] if msgs else HumanMessage(content="")
#     question = last_msg.content if hasattr(last_msg, "content") else str(last_msg)

#     # return a Command or call a tool, e.g. call retriever tool (tool returns Command)
#     return db_retriever(question)
#     # return db_retriever({"question": question})


def start_node(state: ChatState) -> ChatState:
    """
    Entry node. Expects the latest user message to be present in state.messages.
    Calls db_retriever tool by returning its Command (tools may be called directly).
    """
    # The latest HumanMessage content
    msgs = state.get("messages", [])
    last_msg = msgs[-1] if msgs else HumanMessage(content="")
    question = last_msg.content if hasattr(last_msg, "content") else str(last_msg)

    db_retriever.invoke({"question": question})

    return state

# OR

#     return db_retriever.invoke({"question": question})



In [11]:
from sqlalchemy.exc import SQLAlchemyError

def execute_sql_node(state: ChatState):
    """
    Run the SQL stored in state['sql_query'] against a demo sqlite DB,
    update `query_results` and append a ToolMessage with a short summary,
    then route to 'finalize'.
    """
    sql = state.get("sql_query", "")
    if not sql:
        # nothing to run â†’ still go to finalize
        return Command(update={"query_results": []}, goto="finalize")

    try:
        with engine.connect() as conn:
            # Use text() to wrap raw SQL safely in SQLAlchemy
            result = conn.execute(text(sql))
            # If SELECT, fetch results
            rows = result.fetchall()
            # Convert Row objects to dictionaries (Row._mapping is stable)
            results = [dict(row._mapping) for row in rows]
            summary = f"Executed SQL, returned {len(results)} rows."
    except SQLAlchemyError as e:
        results = [{"error": str(e)}]
        summary = f"SQL execution error: {str(e)}"

    # Append a small ToolMessage so the messages channel contains the execution note
    tool_msg = ToolMessage(content=summary)

    # Update structured results and messages; route to finalize
    return Command(update={"query_results": results, "messages": [tool_msg]}, goto="finalize")


    # results = db.run_no_throw(sql)

    # # Demo: in-memory SQLite with tiny sample data (replace in prod)
    # conn = sqlite3.connect(":memory:")
    # cur = conn.cursor()
    # cur.executescript(
    #     """
    #     CREATE TABLE customers(ref_date TEXT, partner_id TEXT, pd REAL, country TEXT);
    #     CREATE TABLE collaterals(ref_date TEXT, coll_id TEXT, mkt_value REAL);
    #     INSERT INTO customers VALUES ('2024-08-31','C001',0.02,'AT');
    #     INSERT INTO customers VALUES ('2024-08-31','C002',0.15,'NL');
    #     INSERT INTO collaterals VALUES ('2024-08-31','L001',100000.0);
    #     """
    # )
    # conn.commit()

    # try:
    #     cur.execute(sql)
    #     cols = [d[0] for d in cur.description] if cur.description else []
    #     rows = cur.fetchall()
    #     results = [dict(zip(cols, row)) for row in rows)]
    #     summary = f"SQL executed successfully: {len(results)} rows."
    # except Exception as e:
    #     results = [{"error": str(e)}]
    #     summary = f"SQL execution failed: {str(e)}"
    # finally:
    #     conn.close()

    # # Append a ToolMessage summarizing execution and update results
    # tool_msg = ToolMessage(content=summary)
    # return Command(update={"query_results": results, "messages": [tool_msg]}, goto="finalize")

In [12]:
def finalize_node(state):
    # read query_results as structured state
    results = state.get("query_results", [])
    # append an AI reply into messages via Command update (MessagesState reducer will append)
    reply = AIMessage(content="Here are the results...")
    return Command(update={"final_answer": "Answer text", "messages": [reply]}, goto="__end__")

# def finalize_node(state: ChatState):
#     """
#     Summarize the query_results for the user using the LLM.
#     Append an AIMessage and set `final_answer`.
#     """
#     last_user = None
#     # Find latest HumanMessage (safely)
#     for m in reversed(state.messages):
#         if isinstance(m, HumanMessage):
#             last_user = m.content
#             break
#     question = last_user or "the user's question"

#     prompt = f"""
#         User asked: {question}

#         SQL results (JSON): {json.dumps(state.get('query_results', []), indent=2)}

#         Write a short, clear answer to the user that explains the results.
#     """
#     answer_text = llm.invoke(prompt)

#     ai_msg = AIMessage(content=answer_text)
#     return Command(update={"final_answer": answer_text, "messages": [ai_msg]}, goto=END)


In [13]:
graph = StateGraph(ChatState)
graph.add_node("start", start_node)
# We add a node wrapper for sql_writer so the graph knows it exists as a node;
# the sql_writer itself is implemented as a tool returning a Command.
graph.add_node("sql_writer", lambda state: sql_writer(state.messages[-1].content, state.get("retrieved_docs", [])), ends=["execute_sql"])
graph.add_node("execute_sql", execute_sql_node, ends=["finalize"])
graph.add_node("finalize", finalize_node)
graph.add_edge(START, "start")
graph = graph.compile()

In [14]:
if __name__ == "__main__":
    # Initialize with a HumanMessage in messages (MessagesState will keep this channel)
    initial_state = {"messages": [HumanMessage(content="What is the total off-balance exposure for September 2023?")]}

    result = graph.invoke(initial_state)
    print("Graph result state keys:")
    print("retrieved_docs:", result.get("retrieved_docs"))
    print("sql_query:", result.get("sql_query"))
    print("query_results:", result.get("query_results"))
    print("final_answer:", result.get("final_answer"))

    # # The `messages` channel contains the conversation history (user, tool annotations, and final AI response)
    # print("\nConversation messages (last 5):")
    # for m in result["messages"][-5:]:
    #     print(type(m).__name__, ":", m.content)

KeyError: 'tool_call_id'