In [None]:
import json
import os
import uuid
import pandas as pd

from tqdm import tqdm
from typing import List, Dict, Any
from dotenv import load_dotenv
from openai import OpenAI
from operator import itemgetter
from IPython.display import display, Markdown

from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.messages import HumanMessage

In [None]:
# Notebook-wide pandas display configuration: set the display width option
# and raise the per-column character limit so long text cells are shown
# instead of being truncated.
for option_name, option_value in (
    ('display.width', -1),
    ('display.max_colwidth', 1000),
):
    pd.set_option(option_name, option_value)

In [None]:
# Load variables from a .env file; override=True lets values from the file
# replace variables that are already set in the process environment.
load_dotenv(override=True)

# Look the key up once with the strict mapping access (raises KeyError when the
# variable is missing, exactly as the client construction did before) and reuse
# it, instead of reading it twice through two different APIs.
api_key = os.environ['OPENAI_API_KEY']
client = OpenAI(api_key=api_key)

### 1. Helper Ingestion Pipeline Functions
---

In [None]:
def table_to_string(table):
    """Render a table (an iterable of rows of cells) as plain text.

    Each cell is converted with ``str``; cells within a row are joined with
    " | " and rows are joined with newlines. An empty table yields "".
    """
    rendered_rows = []
    for row in table:
        rendered_rows.append(" | ".join(map(str, row)))
    return "\n".join(rendered_rows)

In [None]:
def setup_summarization_chain():
    """Build the chain used to summarize a single text or table chunk.

    The chain's input is passed through unchanged as the ``element`` prompt
    variable, sent to a ``gpt-4o-mini`` chat model at temperature 0, and the
    model reply is parsed into a plain string.

    Returns:
        The composed runnable (prompt -> chat model -> string parser).
    """
    summary_template = """
    You are an assistant tasked with summarizing tables and text from financial documents for semantic retrieval.
    These summaries will be embedded and used to retrieve the raw text or table elements.
    Give a detailed summary of the table or text below that is well optimized for retrieval.
    For any tables also add in a description of what the table is about besides the summary.
    Then, include the table in markdown format. Do not add additional words like Summary: etc.

    Table or text chunk:
    {element}
    """

    summary_prompt = ChatPromptTemplate.from_template(summary_template)
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

    # Whatever is fed into the chain becomes the {element} placeholder.
    element_input = {"element": RunnablePassthrough()}
    summarizer = element_input | summary_prompt | llm | StrOutputParser()
    return summarizer

In [None]:
def create_chunks(text: str, chunk_size: int = 4000, chunk_overlap: int = 500) -> List[str]:
    """Split ``text`` into overlapping chunks.

    Args:
        text: The text to split.
        chunk_size: Size limit handed to the splitter, measured with ``len``.
        chunk_overlap: Overlap between consecutive chunks handed to the splitter.

    Returns:
        The chunks produced by ``RecursiveCharacterTextSplitter.split_text``.
    """
    splitter_options = dict(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        # Separators are listed from paragraph break down to the empty string.
        separators=["\n\n", "\n", " ", ""],
    )
    return RecursiveCharacterTextSplitter(**splitter_options).split_text(text)

In [None]:
def process_documents(json_data: List[Dict]) -> tuple:
    """Turn raw JSON entries into text/table documents plus their summaries.

    Entries whose ``qa`` field is missing, ``None`` or an empty dict are
    skipped. For every remaining entry:

    * ``pre_text`` and ``post_text`` (when non-empty) are each joined with
      spaces, combined with a blank line between them, and split with
      ``create_chunks``; every chunk becomes a text ``Document``.
    * ``table_ori`` (when non-empty) is flattened with ``table_to_string`` and
      becomes a single table ``Document`` flagged with ``is_table=True``.

    Every document carries the entry's ``id`` as ``parent_id`` metadata.

    Args:
        json_data: Entry dicts. Keys read here: ``id``, ``qa``, ``pre_text``,
            ``post_text`` and ``table_ori``.

    Returns:
        ``(text_docs, text_summaries, table_docs, table_summaries)`` where each
        summaries list is aligned by position with its documents list.
    """
    text_docs = []
    table_docs = []

    for entry in json_data:
        # Only entries with a usable qa field are ingested.
        qa_field = entry.get("qa")
        if qa_field is None or qa_field == {}:
            continue

        parent_id = entry.get("id")

        text_parts = [
            " ".join(entry[key])
            for key in ("pre_text", "post_text")
            if entry.get(key)
        ]
        full_text = "\n\n".join(text_parts)
        if full_text:
            for chunk in create_chunks(full_text):
                text_docs.append(Document(
                    page_content=chunk,
                    metadata={"parent_id": parent_id}
                ))

        table = entry.get("table_ori", "")
        table_str = table_to_string(table) if table else ""
        if table_str:
            table_docs.append(Document(
                page_content=table_str,
                metadata={"parent_id": parent_id, "is_table": True}
            ))

    # Nothing to summarize: return early so no chat model is constructed
    # (which needs API credentials) for an empty or fully filtered input.
    if not text_docs and not table_docs:
        return [], [], [], []

    summarize_chain = setup_summarization_chain()
    batch_config = {"max_concurrency": 5}

    text_summaries = (
        summarize_chain.batch([doc.page_content for doc in text_docs], batch_config)
        if text_docs else []
    )
    table_summaries = (
        summarize_chain.batch([doc.page_content for doc in table_docs], batch_config)
        if table_docs else []
    )

    return text_docs, text_summaries, table_docs, table_summaries


In [None]:
def split_content_types(docs):
    """Partition retrieved documents into text content and table content.

    A document counts as a table when its metadata has a truthy ``is_table``
    entry; everything else is treated as text.

    Returns:
        ``{"texts": [...], "tables": [...]}`` holding each document's
        ``page_content`` in the order the documents were given.
    """
    buckets = {"texts": [], "tables": []}
    for doc in docs:
        bucket_name = "tables" if doc.metadata.get("is_table", False) else "texts"
        buckets[bucket_name].append(doc.page_content)
    return buckets

### 2. Helper Retrieval-Generation Pipeline Functions
---

In [None]:
def create_multi_vector_retriever(vectorstore, text_summaries, texts, table_summaries, tables):
    """Create a retriever that indexes summaries but returns the raw content.

    Each summary is embedded into ``vectorstore`` with a ``doc_id`` metadata
    entry; the matching raw document is stored under the same id in an
    in-memory docstore.

    Args:
        vectorstore: Vector store that receives the summary documents.
        text_summaries: Summaries aligned by position with ``texts``.
        texts: Raw text documents.
        table_summaries: Summaries aligned by position with ``tables``.
        tables: Raw table documents.

    Returns:
        The populated ``MultiVectorRetriever``.

    Raises:
        ValueError: If a non-empty summaries list and its contents list differ
            in length, since they could not be paired reliably.
    """
    id_key = "doc_id"

    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=InMemoryStore(),
        id_key=id_key,
    )

    def add_documents(retriever, doc_summaries, doc_contents):
        # Summaries and contents are paired purely by position, so a length
        # mismatch would either fail with an opaque IndexError or leave raw
        # documents in the docstore that no summary points to.
        if len(doc_summaries) != len(doc_contents):
            raise ValueError(
                f"Got {len(doc_summaries)} summaries for {len(doc_contents)} documents; "
                "each document needs exactly one summary."
            )
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_id})
            for doc_id, summary in zip(doc_ids, doc_summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    # An empty summaries list means there is nothing to index for that type.
    if text_summaries:
        add_documents(retriever, text_summaries, texts)
    if table_summaries:
        add_documents(retriever, table_summaries, tables)

    return retriever

In [None]:
def prompt_function(data_dict):
    """Assemble the answer-generation prompt from the question and context.

    Args:
        data_dict: Mapping with a ``question`` entry and a ``context`` entry;
            ``context`` holds ``texts`` and ``tables`` lists of strings.

    Returns:
        A one-element list containing a ``HumanMessage`` with the full prompt.
    """
    context = data_dict["context"]
    text_block = "\n".join(context["texts"])
    table_block = "\n".join(context["tables"])
    question = data_dict['question']

    message_body = f"""You are an analyst tasked with understanding detailed information from text documents and data tables.
        Use the provided context information to answer the user's question.
        Do not make up answers, use only the provided context documents below.

        User question:
        {question}

        Text context:
        {text_block}

        Table context:
        {table_block}

        Answer:
    """

    return [HumanMessage(content=message_body)]

In [None]:
def setup_rag_chain(retriever, chatgpt):
    """Compose the retrieval and answer-generation steps into one chain.

    The returned chain expects a dict with an ``input`` key. It first adds a
    ``context`` key (retrieved documents split into texts and tables) and then
    an ``answer`` key (the model's reply parsed into a string).

    Args:
        retriever: Retriever invoked with the value of ``input``.
        chatgpt: Chat model that receives the prompt built by ``prompt_function``.
    """
    # Answer step: reads the already-retrieved context plus the question.
    answer_inputs = {
        "context": itemgetter('context'),
        "question": itemgetter('input'),
    }
    answer_chain = (
        answer_inputs
        | RunnableLambda(prompt_function)
        | chatgpt
        | StrOutputParser()
    )

    # Retrieval step: question -> documents -> {"texts": ..., "tables": ...}.
    context_chain = itemgetter('input') | retriever | RunnableLambda(split_content_types)

    with_context = RunnablePassthrough.assign(context=context_chain)
    return with_context.assign(answer=answer_chain)

In [None]:
def rag_qa(chain, query):
    """Run one question through the RAG chain and render the outcome.

    Invokes ``chain`` with ``{'input': query}``, displays the answer as
    Markdown, then lists the retrieved text sources (as Markdown) and table
    sources (inside a fenced code block).

    Returns:
        The raw response produced by the chain.
    """
    result = chain.invoke({'input': query})
    heavy_rule = '==' * 50

    def show_sources(heading, sources, render):
        # A source group is only printed when it has at least one item.
        if not sources:
            return
        print(heading)
        for source in sources:
            display(Markdown(render(source)))
            print()

    print(heavy_rule)
    print('Answer:')
    display(Markdown(result['answer']))
    print('--' * 50)
    print('Sources:')

    retrieved_texts = result['context']['texts']
    retrieved_tables = result['context']['tables']

    show_sources("\nText Sources:", retrieved_texts, lambda text: text)
    show_sources("\nTable Sources:", retrieved_tables, lambda table: f"```\n{table}\n```")

    print(heavy_rule)
    return result

### 3. Execute Main Code
---

In [None]:
def main(data_path: str = './data/convfinqatrain.json', sample_size: int = 10, reset_collection: bool = True):
    """Build the full RAG pipeline: load data, summarize, index, assemble chain.

    Args:
        data_path: Path of the JSON file holding the list of entries to ingest.
        sample_size: Number of leading entries to process; ``None`` processes
            every entry.
        reset_collection: When true, empty the persisted Chroma collection
            before indexing (see the comment at the call site for why).

    Returns:
        ``(rag_chain, retriever)``.
    """
    # Initialize models
    embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
    chatgpt = ChatOpenAI(model="gpt-4o-mini", temperature=0)

    # Initialize Chroma vectorstore
    vectorstore = Chroma(
        collection_name="text_table_rag",
        embedding_function=embedding_model,
        persist_directory="./chroma_db"
    )
    # The collection is persisted on disk, but create_multi_vector_retriever
    # keeps the raw documents in an InMemoryStore under freshly generated
    # UUIDs. Summary vectors left over from an earlier run therefore point at
    # doc ids that no longer exist; clearing the collection keeps the vector
    # store and the docstore in sync.
    if reset_collection:
        vectorstore.reset_collection()

    # Load and process data (explicit encoding so decoding does not depend on
    # the platform's default locale)
    with open(data_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    test_data = json_data[:sample_size]

    # Process documents with progress bar
    print("Processing documents and generating summaries...")
    with tqdm(total=1, desc="Processing Documents") as pbar:
        text_docs, text_summaries, table_docs, table_summaries = process_documents(test_data)
        pbar.update(1)

    print(f"Processed {len(text_docs)} text documents and {len(table_docs)} table documents")

    # Create retriever
    print("\nEmbedding into doctstore and vectorstore...")
    retriever = create_multi_vector_retriever(
        vectorstore,
        text_summaries,
        text_docs,
        table_summaries,
        table_docs
    )
    print("Embedding complete!")

    # Setup RAG chain
    rag_chain = setup_rag_chain(retriever, chatgpt)

    return rag_chain, retriever

In [None]:
if __name__ == "__main__":
    # Build the pipeline once, then exercise the retriever and the full chain
    # with a sample question.
    rag_chain, retriever = main()

    # Example usage
    question = 'what was the percentage change in the net cash from operating activities from 2008 to 2009?'

    # Test retriever directly
    # NOTE(review): the previous `limit=5` keyword was dropped. It does not
    # appear to be an argument the retriever consumes, so it gave the false
    # impression of capping results at five. The number of hits is presumably
    # governed by the retriever's `search_kwargs` (e.g. {"k": 5}) - confirm
    # against the LangChain retriever documentation before relying on it.
    print("\nTesting retriever directly:")
    docs = retriever.invoke(question)
    print(f"Retrieved {len(docs)} documents")
    for i, doc in enumerate(docs, 1):
        print(f"\nDocument {i}:")
        print(f"Content preview: {doc.page_content[:200]}...")

    # Get full response
    print("\nGetting full response:")
    response = rag_qa(rag_chain, question)