In [1]:
import os
import sys
import json
from glob import glob
from dotenv import load_dotenv, find_dotenv
from typing import Dict, Any, List
# from langchain_openai import OpenAIEmbeddings

from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

In [2]:
OPENAI_API_KEY=os.getenv('AZURE_OPENAI_API_KEY_US')
OPENAI_API_KEY_E=os.getenv('AZURE_OPENAI_API_KEY_US2')

# os.environ['OPENAI_API_TYPE'] = 'azure'
os.environ['OPENAI_API_VERSION'] = '2024-08-01-preview'
os.environ['AZURE_OPENAI_ENDPOINT'] = 'https://azure-chat-try-2.openai.azure.com/'
os.environ['AZURE_OPENAI_DEPLOYMENT'] = 'chat-endpoint-us-gpt4o'

os.environ['OPENAI_API_VERSION_E'] = '2024-12-01-preview'
os.environ['AZURE_OPENAI_ENDPOINT_E'] = 'https://agents-4on.openai.azure.com/'
os.environ['AZURE_OPENAI_EMBEDDING_DEPLOYMENT_E'] = "text-embedding-3-large-eus2"

In [3]:
llm = AzureChatOpenAI(
    api_key = OPENAI_API_KEY,  
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    openai_api_version=os.getenv("OPENAI_API_VERSION"),
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT")
)

emb_model = AzureOpenAIEmbeddings(
    api_key=OPENAI_API_KEY_E,
    azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT_E'),  
    api_version=os.getenv('OPENAI_API_VERSION_E'),
    azure_deployment=os.getenv('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_E')
)

In [4]:
# Convert a single table JSON into readable text
def table_json_to_text(table_json: Dict[str, Any]) -> str:
    lines = []
    lines.append(f"Table: {table_json.get('table', '')}")
    if desc := table_json.get("description"):
        lines.append(f"Description: {desc}")
    lines.append("")  # blank line

    lines.append("Columns:")
    for col in table_json.get("columns", []):
        name = col.get("name")
        ctype = col.get("type")
        nullable = col.get("nullable")
        desc = col.get("description", "")
        allowed = col.get("allowed_values")
        lines.append(f"- {name} ({ctype}) - nullable={nullable}")
        if desc:
            lines.append(f"  Description: {desc}")
        if allowed:
            allowed_items = ", ".join([f"{k}: {v}" for k, v in allowed.items()])
            lines.append(f"  Allowed values: {allowed_items}")
    lines.append("")

    if constraints := table_json.get("constraints"):
        lines.append("Constraints:")
        for k, v in constraints.items():
            cols = v.get("columns")
            desc = v.get("description")
            lines.append(f"- {k}: columns={cols} - {desc}")
        lines.append("")

    if relationships := table_json.get("relationships"):
        lines.append("Relationships:")
        for rel in relationships:
            related = rel.get("related_table")
            join = rel.get("join_type")
            card = rel.get("cardinality")
            notes = rel.get("notes")
            lines.append(f"- Related table: {related} ({join}) -- {card}")
            if notes:
                lines.append(f"  Notes: {notes}")
    return "\n".join(lines)

In [5]:
# Load all .json files from a directory and return Document objects
def load_table_documents_from_dir(dir_path: str) -> List[Document]:
    docs: List[Document] = []
    pattern = os.path.join(dir_path, "*.json")
    files = sorted(glob(pattern))
    if not files:
        raise FileNotFoundError(f"No .json files found in directory: {dir_path}")
    for fp in files:
        with open(fp, "r", encoding="utf-8") as f:
            table_json = json.load(f)
        text = table_json_to_text(table_json)
        metadata = {
            "table": table_json.get("table"),
            "source_file": os.path.basename(fp),
        }
        docs.append(Document(page_content=text, metadata=metadata))
    return docs

In [6]:
docs = load_table_documents_from_dir("database")

In [7]:
vector_store = Chroma(
    collection_name="risk_db_tables",
    embedding_function=emb_model,
    persist_directory="./vector_db"
)

In [8]:
vector_store.add_documents(docs)

['b5de837b-9352-48fd-81bd-c70272b91b89',
 '48def627-2590-426b-97bd-de47a4ff1998',
 'c107e9fb-d6d0-41fa-9f9d-280e76f93ccc',
 '205e0a6b-f373-4b02-b890-d8aa56394cdc']

In [4]:
vector_store = Chroma(
    collection_name="risk_db_tables",
    embedding_function=emb_model,
    persist_directory="./vector_db"
)

In [5]:
retriever = vector_store.as_retriever()

In [6]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

# Prompt
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""

prompt = ChatPromptTemplate.from_template(template)
prompt

ChatPromptTemplate(input_variables=['context', 'question'], input_types={}, partial_variables={}, messages=[HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['context', 'question'], input_types={}, partial_variables={}, template='Answer the question based only on the following context:\n{context}\n\nQuestion: {question}\n'), additional_kwargs={})])

In [7]:
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

In [8]:
rag_chain.invoke("How can I get the total undrawn exposure per economic sector?")

"To get the total undrawn exposure per economic sector, you would need to follow these steps using the provided context:\n\n1. **Identify Transactions with Undrawn Exposure:**\n   - In the `transactions` table, locate transactions where `PROCESSING` is set to 'OB' (off-balance exposure, also referred to as 'undrawn' exposure).\n\n2. **Link Transactions to Customers:**\n   - Use the relationship between the `transactions` table and the `customers` table to join these tables on the composite key `(REF_DATE, PARTNER_ID)`.\n\n3. **Connect Customers to Sectors:**\n   - Link the `customers` table with the `sectors` table using the `NACE` code, which is a unique identifier for economic sectors in the `sectors` table and corresponds to multiple entries in the `customers` table.\n\n4. **Aggregate the Exposure:**\n   - Sum the `EXPOSURE` field from the `transactions` records that are classified as 'undrawn' (i.e., `PROCESSING` = 'OB'). Group this aggregation by the `SECTOR` descriptions from the