In [22]:
# IMPORT

import os
from typing import Any, Dict, List, Optional, Sequence, TypedDict, Union
from typing_extensions import Annotated

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# --- LangChain / LangGraph core ---
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.documents import Document as LCDocument
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_community.utilities import SQLDatabase
from langchain_community.vectorstores import Chroma
# from langchain_community.vectorstores import FAISS

from langchain_core.tools import tool
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate

from langgraph.graph import END, StateGraph, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import InjectedState
from langgraph.prebuilt.chat_agent_executor import AgentState

# --- SQL and validation ---
import sqlglot
from sqlglot import parse_one
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

# --- Utilities ---
import json
import re
from pathlib import Path
import time


In [None]:
# Azure OpenAI connection settings.
# The API key comes from the environment (the import cell loads .env first).
OPENAI_API_KEY = os.getenv('AZURE_OPENAI_API_KEY_US')

# setdefault: values already provided by the shell or .env take precedence over
# these notebook defaults, matching the os.getenv(..., default) pattern of the
# CONFIG cell instead of silently overwriting them.
os.environ.setdefault('OPENAI_API_TYPE', 'azure')
os.environ.setdefault('OPENAI_API_VERSION', '2024-08-01-preview')
os.environ.setdefault('AZURE_OPENAI_ENDPOINT', 'https://azure-chat-try-2.openai.azure.com/')

# Optional LangSmith tracing - uncomment to enable.
# LANGCHAIN_API_KEY = os.getenv('LANGCHAIN_API_KEY')
# os.environ['LANGCHAIN_TRACING_V2'] = 'true'
# os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
# os.environ['LANGCHAIN_PROJECT'] = "rag-sql"

In [19]:
# Database
# `db` is read by get_schema_tool and write_query further down, so this cell must
# stay enabled for Restart & Run All to work. It honours the same DB_CONN_STR
# environment variable (and default) as the CONFIG cell.
db = SQLDatabase.from_uri(
    os.getenv("DB_CONN_STR", "sqlite:///./database/credit-risk.db"),
    sample_rows_in_table_info=2,
)
print(db.dialect)
print(db.get_usable_table_names())

In [None]:
# CONFIG
# Every setting below can be overridden through an environment variable.


def _env_int(name: str, fallback: str) -> int:
    """Return env var `name` parsed as int, using `fallback` when it is unset."""
    return int(os.environ.get(name, fallback))


# Locations
DB_CONN_STR = os.environ.get("DB_CONN_STR", "sqlite:///./database/credit-risk.db")
VECTOR_DIR = os.environ.get("VECTOR_DIR", "./vectorstore")
DB_DOCS_DIR = os.environ.get("DB_DOCS_DIR", "./db_docs")

# Safety limits
MAX_ROWS_DEFAULT = _env_int("MAX_ROWS_DEFAULT", "2000")
QUERY_TIMEOUT_SECS = _env_int("QUERY_TIMEOUT_SECS", "60")
REWRITE_MAX_ATTEMPTS = _env_int("REWRITE_MAX_ATTEMPTS", "2")







In [2]:
# Vector-store directory: same env var and default as in the CONFIG cell.
VECTOR_DIR = os.environ.get("VECTOR_DIR", "./vectorstore")

In [3]:
# Display the resolved vector-store directory (bare expression -> cell output).
VECTOR_DIR

'./vectorstore'

In [5]:

from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings


In [2]:
# NOTE(review): `parent_dir` is not assigned in any cell of this notebook, so this
# raises NameError on a fresh kernel; the output below is presumably left over from
# an earlier session. Define it (e.g. with pathlib.Path) or delete this cell.
parent_dir

'd:\\UserData\\BB033377\\my_projects\\text-to-sql'

In [None]:
# Chat model used for SQL generation (Azure OpenAI deployment).
# Non-Azure alternative: ChatOpenAI(model="gpt-4o", temperature=0).
#
# temperature=0 keeps the generated SQL reproducible between runs.
# The deployment name can be overridden with the AZURE_OPENAI_DEPLOYMENT env var;
# endpoint and API version come from the Azure settings cell.
llm = AzureChatOpenAI(
    api_key=OPENAI_API_KEY,
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    openai_api_version=os.getenv("OPENAI_API_VERSION"),
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT", "chat-endpoint-us-gpt4o"),
    temperature=0,
)

sqlite
['collaterals', 'customers', 'industries', 'transactions']


In [None]:
class DeepAgentState(AgentState):
    """Graph state for the text-to-SQL agent.

    Inherits from LangGraph's AgentState and adds:
    - schema: the retrieved database schema
    - sql_query: the query that reflects the user question and the database schema
    - query_valid: whether the query passed validation
    - query_result: the result of running the query against the database
    """

    schema: str
    sql_query: str
    query_valid: bool
    query_result: str


In [None]:
# @tool
def get_schema_tool(input: None) -> str:
    """Return the schema description of the connected database.

    The single positional argument is accepted but ignored. The text is whatever
    `db.get_table_info()` produces (CREATE TABLE statements plus any sample rows
    `db` was configured to include).
    """
    return db.get_table_info()

In [None]:
# Prompt template for SQL generation. `{schema}` and `{question}` are
# str.format placeholders; the text itself is sent to the model verbatim.
write_query_instructions = """
You are a SQL expert with a strong attention to detail.

Given an input question, create a syntactically correct SQLite query to run to help find the answer. 

When generating the query:

Unless the user specifies in his question a specific number of examples they wish to obtain, limit your query to at most 3 results. 
For example, if the user asks for the top 5 results, you should NOT limit the query to 3 results.

You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

Only use the following tables:
{schema}

Question: {question}
"""

@tool(parse_docstring=True)
async def generate_query_tool(
    question: str,
    state: Annotated[dict, InjectedState],
) -> str:
    """Write a syntactically valid SQL query that answers the user's question.

    Args:
        question: The natural-language question to translate into SQL.
        state: Injected graph state; its "schema" entry holds the database schema.

    Returns:
        The SQL query text produced by the model.
    """
    # InjectedState hides `state` from the tool schema shown to the model;
    # LangGraph's ToolNode supplies it when the tool is executed.
    prompt = write_query_instructions.format(
        question=question,
        schema=state["schema"],
    )
    response = await llm.ainvoke([HumanMessage(content=prompt)])
    return response.content


In [20]:
# print() rather than a bare expression so the schema's newlines render as text
# instead of an escaped string repr.
print(get_schema_tool(None))


CREATE TABLE collaterals (
	"DATE" DATE NOT NULL, 
	"DATE_COLLATERAL" TEXT, 
	"COLLATERAL_ID" TEXT NOT NULL, 
	"MARKET_VALUE" REAL NOT NULL, 
	PRIMARY KEY ("DATE_COLLATERAL")
)

/*
2 rows from collaterals table:
DATE	DATE_COLLATERAL	COLLATERAL_ID	MARKET_VALUE
2023-09-29	45198_oehym174le7795wipiwu	oehym174le7795wipiwu	7123.64
2023-09-29	45198_cojvx420gk3315hmyhyc	cojvx420gk3315hmyhyc	2000.0
*/


CREATE TABLE customers (
	"DATE" DATE NOT NULL, 
	"DATE_PARTNER" TEXT, 
	"PARTNER_ID" TEXT NOT NULL, 
	"STATUS" TEXT NOT NULL, 
	"PD" REAL NOT NULL, 
	"COUNTRY" TEXT, 
	"RATING_MODEL" TEXT NOT NULL, 
	"NACE" TEXT NOT NULL, 
	"COMPANY_SIZE" REAL, 
	PRIMARY KEY ("DATE_PARTNER"), 
	FOREIGN KEY("NACE") REFERENCES industries ("NACE")
)

/*
2 rows from customers table:
DATE	DATE_PARTNER	PARTNER_ID	STATUS	PD	COUNTRY	RATING_MODEL	NACE	COMPANY_SIZE
2023-09-29	45198_qvajo192sx6083	qvajo192sx6083	NON_PERFORMING	1.0	BG 	FKBG21	4649	2.85765402
2023-09-29	45198_jvxeh774ov6289	jvxeh774ov6289	NON_PERFORMING	1.

In [None]:
from typing_extensions import Annotated

# SQL-generation prompt template used by write_query below (same text as the
# earlier definition). `{schema}` and `{question}` are filled via str.format.
write_query_instructions = """
You are a SQL expert with a strong attention to detail.

Given an input question, create a syntactically correct SQLite query to run to help find the answer. 

When generating the query:

Unless the user specifies in his question a specific number of examples they wish to obtain, limit your query to at most 3 results. 
For example, if the user asks for the top 5 results, you should NOT limit the query to 3 results.

You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

Only use the following tables:
{schema}

Question: {question}
"""

# Structured-output schema handed to llm.with_structured_output in write_query.
# NOTE: the class docstring and the Annotated description are runtime data
# (presumably forwarded to the model as schema/field descriptions) - edit with care.
class QueryOutput(TypedDict):
    """Generated SQL query."""
    query: Annotated[str, ..., "Syntactically valid SQL query."]

def write_query(state: Dict[str, Any]):
    """Generate a SQL query for the question stored in ``state``.

    Args:
        state: Graph state; must contain a "question" key.

    Returns:
        Partial state update ``{"query": <SQL string>}`` taken from the model's
        structured output.
    """
    prompt = write_query_instructions.format(
        question=state["question"],
        schema=db.get_table_info(),
    )

    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    # The structured output is a QueryOutput dict; forward the model's query.
    return {"query": result["query"]}


In [None]:

def _strip_sql_fences(text: str) -> str:
    """Return the SQL inside a Markdown code fence, or ``text`` stripped if unfenced."""
    # Tolerates an optional ```sql / ```sqlite language tag on the opening fence.
    match = re.search(r"```(?:sqlite|sql)?[ \t]*\n?(.*?)```", text, re.DOTALL | re.IGNORECASE)
    return (match.group(1) if match else text).strip()


@tool
def generate_sql_tool(question: str, docs: str) -> str:
    """Generate a SQLite query that answers the question, using the database docs as context.

    Args:
        question: The user's natural-language question.
        docs: Database documentation / schema text relevant to the question.

    Returns:
        Only the SQL query text (no explanation, no Markdown code fences).
    """
    # SQLite matches the dialect used by the other prompts in this notebook.
    prompt = (
        "You are an assistant that writes SQLite queries.\n"
        f"User question: {question}\n"
        f"Database info: {docs}\n\n"
        "Write a single SQL query that answers the question. "
        "Return only the SQL, with no explanation and no Markdown code fences."
    )
    # Models often wrap code in ``` fences despite instructions; strip them so the
    # result can be executed directly.
    return _strip_sql_fences(llm.invoke(prompt).content)
