In [1]:
# IMPORT

import os
from typing import Any, Dict, List, Optional, Sequence, TypedDict, Union
from typing_extensions import Annotated

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# --- LangChain / LangGraph core ---
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.documents import Document as LCDocument
# from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from openai import AzureOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_community.vectorstores import Chroma
# from langchain_community.vectorstores import FAISS

from langchain_core.tools import tool
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate

from langgraph.graph import END, StateGraph, MessagesState
from langgraph.graph.message import add_messages

# --- SQL and validation ---
import sqlglot
from sqlglot import parse_one
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

# --- Utilities ---
import json
from pathlib import Path
import time


In [2]:
# OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

OPENAI_API_KEY=os.getenv('AZURE_OPENAI_API_KEY_US')
OPENAI_API_KEY_E=os.getenv('AZURE_OPENAI_API_KEY_US2')

# os.environ['OPENAI_API_TYPE'] = 'azure'
os.environ['OPENAI_API_VERSION'] = '2024-08-01-preview'
os.environ['AZURE_OPENAI_ENDPOINT'] = 'https://azure-chat-try-2.openai.azure.com/'

os.environ['OPENAI_API_VERSION_E'] = '2024-12-01-preview'
os.environ['AZURE_OPENAI_ENDPOINT_E'] = 'https://agents-4on.openai.azure.com/'

# LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')
# os.environ['LANGCHAIN_TRACING_V2'] = 'true'
# os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
# os.environ['LANGCHAIN_PROJECT'] = "rag-sql"

In [3]:
emb_model = AzureOpenAI(
    api_key = OPENAI_API_KEY_E,  
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT_E"),
    api_version=os.getenv("OPENAI_API_VERSION_E")
)
deployment="text-embedding-3-small-eus2"

llm = AzureChatOpenAI(
    api_key = OPENAI_API_KEY,  
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    openai_api_version=os.getenv("OPENAI_API_VERSION"),
    azure_deployment="chat-endpoint-us-gpt4o"
)

In [None]:
# def get_embedding(text: str, model=deployment, **kwargs) -> List[float]:
#     # replace newlines, which can negatively affect performance.
#     text = text.replace("\n", " ")

#     response = emb_model.embeddings.create(input=[text], model=model, **kwargs)

#     return response.data[0].embedding

In [5]:
# # Database
# from langchain_community.utilities import SQLDatabase
# db = SQLDatabase.from_uri("sqlite:///./database/credit-risk.db", sample_rows_in_table_info=2)
# print(db.dialect)
# print(db.get_usable_table_names())

In [4]:
# CONFIG

DB_CONN_STR = os.getenv("DB_CONN_STR", "sqlite:///./database/credit-risk.db")
VECTOR_DIR = os.getenv("VECTOR_DIR", "./database")
DB_DOCS_DIR = os.getenv("DB_DOCS_DIR", "./database")

# Safety limits
MAX_ROWS_DEFAULT = int(os.getenv("MAX_ROWS_DEFAULT", "2000"))
QUERY_TIMEOUT_SECS = int(os.getenv("QUERY_TIMEOUT_SECS", "60"))
REWRITE_MAX_ATTEMPTS = int(os.getenv("REWRITE_MAX_ATTEMPTS", "2"))

In [6]:
# 1) LOAD DB DOCS (JSON) AND PREP RAG INDEX

def load_db_docs(dir_path: str) -> Dict[str, Dict[str, Any]]:
    """
    Load JSON documentation files from a folder.
    Returns a dict keyed by table name -> metadata dict.
    """
    docs = {}
    for p in Path(dir_path).glob("*.json"):
        with open(p, "r", encoding="utf-8") as f:
            d = json.load(f)
            table = d["table"]
            docs[table] = d
    return docs

DB_DOCS = load_db_docs(DB_DOCS_DIR)

def docs_to_text_chunks(db_docs: Dict[str, Dict[str, Any]]) -> List[LCDocument]:
    """
    Flatten JSON table docs into LangChain Document chunks with helpful metadata.
    """
    records = []
    for table, meta in db_docs.items():
        # full text for table meta
        full = json.dumps(meta, ensure_ascii=False, indent=2)
        records.append(LCDocument(
            page_content=f"TABLE: {table}\n{full}",
            metadata={"table": table, "section": "full"}
        ))
        # optionally split by columns for finer recall
        for col in meta.get("columns", []):
            col_text = json.dumps(col, ensure_ascii=False, indent=2)
            records.append(LCDocument(
                page_content=f"TABLE: {table}\nCOLUMN: {col['name']}\n{col_text}",
                metadata={"table": table, "section": "column", "column": col["name"]}
            ))
        # relationships
        for rel in meta.get("relationships", []):
            rel_text = json.dumps(rel, ensure_ascii=False, indent=2)
            records.append(LCDocument(
                page_content=f"TABLE: {table}\nRELATIONSHIP:\n{rel_text}",
                metadata={"table": table, "section": "relationship"}
            ))
    return records


def build_vector_store(records: List[LCDocument], persist_dir: str = VECTOR_DIR) -> Chroma:
    splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=120)
    chunks = splitter.split_documents(records)
    vs = Chroma.from_documents(chunks, embedding=emb_model, persist_directory=persist_dir)
    vs.persist()
    return vs

def ensure_vector_store(vdir: str = VECTOR_DIR) -> Chroma:
    # Create or load
    if not Path(vdir).exists():
        Path(vdir).mkdir(parents=True, exist_ok=True)
    # Try to open
    try:
        vs = Chroma(embedding_function=emb_model, persist_directory=vdir)
        # quick sanity retrieve
        _ = vs.similarity_search("test", k=1)
        return vs
    except Exception:
        # rebuild if load fails
        records = docs_to_text_chunks(DB_DOCS)
        return build_vector_store(records, persist_dir=vdir)

VECTORSTORE = ensure_vector_store()
RETRIEVER = VECTORSTORE.as_retriever(search_kwargs={"k": 6})

  vs = Chroma(embedding_function=emb_model, persist_directory=vdir)


AttributeError: 'AzureOpenAI' object has no attribute 'embed_documents'

In [None]:
VECTOR_DIR = os.getenv("VECTOR_DIR", "./vectorstore")

In [None]:
VECTOR_DIR

In [None]:

from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings


In [None]:
parent_dir

In [None]:
class DeepAgentState(AgentState):
    """Extended agent state that includes task tracking and virtual file system.

    Inherits from LangGraph's AgentState and adds:
    - schema: The retrieved database schema
    - sql_query: The query that refelcts the user question and the database schema
    - query_valid: a validation of the query correctness
    - query_result: the result of running the qyery against the database  
    """
    schema: str
    sql_query: str
    query_valid: bool
    
    
    # class GraphState(MessagesState):
#     sql_query: str
#     query_valid: bool
#     query_result: str
    query_result: str


In [None]:
# @tool
def get_schema_tool(input: None) -> str:
    """Retrieve the database schema"""
    schema = db.get_table_info()
    return schema

In [None]:
write_query_instructions = """
You are a SQL expert with a strong attention to detail.

Given an input question, create a syntactically correct SQLite query to run to help find the answer. 

When generating the query:

Unless the user specifies in his question a specific number of examples they wish to obtain, limit your query to at most 3 results. 
For example, if the user asks for the top 5 results, you should NOT limit the query to 3 results.

You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

Only use the following tables:
{schema}

Question: {question}
"""

@tool(parse_docstring=True)
def generate_query_tool(
    ,
    state: Annotated[dict, GraphState]
   ) -> str:
    """Write a syntactically valid SQL query.
        Args:
            query: Search query to execute
            state: Injected agent state for file storage
            tool_call_id: Injected tool call identifier
            max_results: Maximum number of results to return (default: 1)
            topic: Topic filter - 'general', 'news', or 'finance' (default: 'general')

        Returns:
            Command that saves full results to files and provides minimal summary
    """



    question = input["question"]
    schema = input["schema"]
    prompt = f"Given the question: {question} and the schema: {schema}, write a correct SQL query."
    response = await llm.ainvoke([HumanMessage(content=prompt)], RunnableConfig())
    return response.content


In [None]:
print(get_schema_tool(None))

In [None]:
from typing_extensions import Annotated

write_query_instructions = """
You are a SQL expert with a strong attention to detail.

Given an input question, create a syntactically correct SQLite query to run to help find the answer. 

When generating the query:

Unless the user specifies in his question a specific number of examples they wish to obtain, limit your query to at most 3 results. 
For example, if the user asks for the top 5 results, you should NOT limit the query to 3 results.

You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

Only use the following tables:
{schema}

Question: {question}
"""

class QueryOutput(TypedDict):
    """Generated SQL query."""
    query: Annotated[str, ..., "Syntactically valid SQL query."]

def write_query(state: State):
    """Generate SQL query to fetch information."""

    question = state["question"]
    schema = db.get_table_info()
    
    prompt = write_query_instructions.format(
        question = question,
        schema = schema        
    ) 

    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    # return {"query": result["query"]}    
    return {"query": "SELECT \n    SUM(EXPOSURE) AS total_off_balance_exposure \nFROM \n    transactions\nWHERE \n    DATE LIKE '2023-sep%' AND\n    EXPOSURE_TYPE = 'OFF_BALANCE';"}


In [None]:

@tool
def generate_sql_tool(question: str, docs: str) -> str:
    """Generate SQL query from question + docs."""
    prompt = f"""
    You are an assistant that writes SQL queries.
    User question: {question}
    Database info: {docs}
    
    Write a SQL query that answers the question.
    """
    return llm.invoke(prompt).content
