In [2]:
import os
from dotenv import load_dotenv

load_dotenv()

if not os.getenv("OPENAI_API_KEY") or not os.getenv("COHERE_API_KEY"):
    print("Please create a .env file and add your OPENAI_API_KEY and COHERE_API_KEY.")
else:
    print("API keys loaded successfully.")

API keys loaded successfully.


In [3]:
import uuid
from pathlib import Path
from typing import List, Tuple

# LangChain Imports
from langchain_community.document_loaders import UnstructuredPDFLoader
from langchain_chroma import Chroma
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_cohere import CohereRerank
from langchain_core.documents import Document
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool, BaseTool, create_retriever_tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings


class VectorStoreManager:
    def __init__(self, chroma_path: Path):
        self.chroma_path = chroma_path
        self.db_connection = None
        self.collection_name = "rag_document_collection"

    def get_connection(self) -> Chroma:
        if self.db_connection:
            return self.db_connection
        
        embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
        self.db_connection = Chroma(
            persist_directory=str(self.chroma_path),
            embedding_function=embeddings,
            collection_name=self.collection_name,
        )
        return self.db_connection


class DocumentIndexer:
    def __init__(self, vector_store: Chroma, doc_store: InMemoryStore, llm_for_summarization: ChatOpenAI):
        self.vector_store = vector_store
        self.doc_store = doc_store
        self.llm = llm_for_summarization
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    def _load_and_partition(self, file_paths: List[str]) -> List[Document]:
        all_elements = []
        for path in file_paths:
            loader = UnstructuredPDFLoader(path, mode="elements", strategy="hi_res")
            elements = loader.load()
            for element in elements:
                element.metadata["filename"] = os.path.basename(path)
            all_elements.extend(elements)
        return all_elements

    def _create_chunks_and_summaries(self, elements: List[Document]):
        text_elements, table_elements = [], []
        for el in elements:
            if el.metadata.get("category") == "Table":
                table_elements.append(el)
            else:
                text_elements.append(el)

        text_chunks = self.text_splitter.split_documents(text_elements)

        class TableSummary(BaseModel):
            summary: str = Field(description="A concise summary of the table's purpose and key information.")
        
        prompt = ChatPromptTemplate.from_messages([
             ("system", "You are an expert at summarizing technical tables."),
             ("human", "Summarize the following table to make it easily searchable...\n\nTable:\n{table_html}"),
        ])
        summarize_chain = prompt | self.llm.with_structured_output(TableSummary)
        table_htmls = [el.metadata.get("text_as_html", el.page_content) for el in table_elements]
        
        summaries = summarize_chain.batch([{"table_html": html} for html in table_htmls]) if table_htmls else []

        table_summaries = []
        for i, summary_obj in enumerate(summaries):
            if summary_obj:
                table_el = table_elements[i]
                summary_doc = Document(
                    page_content=summary_obj.summary,
                    metadata={**table_el.metadata, "content_type": "table_summary", "original_content": table_el.metadata.get("text_as_html", table_el.page_content)},
                )
                table_summaries.append(summary_doc)
        
        return text_chunks, table_summaries

    def _store_chunks(self, text_chunks: List[Document], table_summaries: List[Document]):
        doc_ids = [str(uuid.uuid4()) for _ in text_chunks]
        summary_ids = [str(uuid.uuid4()) for _ in table_summaries]

        if text_chunks:
            for i, chunk in enumerate(text_chunks): chunk.metadata["doc_id"] = doc_ids[i]
            self.doc_store.mset(list(zip(doc_ids, text_chunks)))
            filtered_chunks = filter_complex_metadata(text_chunks)
            self.vector_store.add_documents(filtered_chunks)

        if table_summaries:
            original_tables = []
            for i, summary in enumerate(table_summaries):
                summary.metadata["doc_id"] = summary_ids[i]
                original_tables.append(Document(page_content=summary.metadata["original_content"], metadata=summary.metadata))
            self.doc_store.mset(list(zip(summary_ids, original_tables)))
            filtered_summaries = filter_complex_metadata(table_summaries)
            self.vector_store.add_documents(filtered_summaries)

    def process_files(self, file_paths: List[str]):
        print(f"--- Starting Indexing for {len(file_paths)} file(s) ---")
        elements = self._load_and_partition(file_paths)
        text_chunks, table_summaries = self._create_chunks_and_summaries(elements)
        self._store_chunks(text_chunks, table_summaries)
        print(f"--- Finished Indexing ---")


def create_rag_retriever_tool(vector_store: Chroma, doc_store: InMemoryStore):
    base_retriever = MultiVectorRetriever(
        vectorstore=vector_store, docstore=doc_store, id_key="doc_id", search_kwargs={"k": 20}
    )
    compressor = CohereRerank(
        top_n=4,
        model="rerank-v3.5",
        cohere_api_key=os.getenv("COHERE_API_KEY"),
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=base_retriever
    )
    return base_retriever, compressor, compression_retriever

print("Classes and functions defined.")

Classes and functions defined.


In [4]:
# --- 1. Setup and Initialization ---
# Create a directory for your database
os.makedirs("chroma_db_test", exist_ok=True)
pdf_directory = Path("./data")
chroma_directory = Path("./chroma_db_test")

# Initialize components
vector_store_manager = VectorStoreManager(chroma_path=chroma_directory)
vector_store_conn = vector_store_manager.get_connection()
doc_store = InMemoryStore()
llm = ChatOpenAI(model="gpt-4.1-mini", temperature=0)

# --- 2. Index the Documents ---
pdf_files = [str(p) for p in pdf_directory.glob("*.pdf")]

if not pdf_files:
    print("🛑 No PDF files found in the 'data' directory. Please add your PDFs to test.")
else:
    # Create and run the indexer
    indexer = DocumentIndexer(
        vector_store=vector_store_conn,
        doc_store=doc_store,
        llm_for_summarization=llm
    )
    indexer.process_files(pdf_files)
    
  

--- Starting Indexing for 1 file(s) ---
--- Finished Indexing ---


In [5]:
# --- 3. Create the Retriever for testing ---
base_retriever, compressor, retriever = create_rag_retriever_tool(vector_store=vector_store_conn, doc_store=doc_store)

In [6]:
from langchain_core.retrievers import BaseRetriever
from typing import List
def test_retriever_query(query: str, retriever: BaseRetriever):
    """
    Tests a query against the retriever and prints the formatted results.
    
    Args:
        query: The question to ask the retriever.
        retriever: The retriever instance to test.
    """
    print("\n" + "="*50)
    print(f"Testing retriever with query: '{query}'")
    print("="*50)
    
    # Invoke the retriever directly with the provided query
    retrieved_docs = retriever.invoke(query)
    
    # Print the results
    if not retrieved_docs:
        print("No documents were retrieved.")
    else:
        print(f"\nRetrieved {len(retrieved_docs)} documents:\n")
        for i, doc in enumerate(retrieved_docs):
            print(f"--- Document {i+1} ---")
            print(f"Content: {doc.page_content[:500]}...")
            print(f"Source: {doc.metadata.get('filename')}")
            print(f"Page: {doc.metadata.get('page_number')}")
            print(f"Relevance Score (from re-ranker): {doc.metadata.get('relevance_score', 'N/A')}")
            print("-"*(len(str(i+1))+13) + "\n")


In [7]:
query1 = "What should be done if a newly started motor does not run smoothly?"
test_retriever_query(query=query1, retriever=retriever)


Testing retriever with query: 'What should be done if a newly started motor does not run smoothly?'

Retrieved 4 documents:

--- Document 1 ---
Content: Connect the motor as shown in the connection diagram. The wiring, fusing and grounding must comply with the National Electrical Code and local codes. When the motor is connected to the load for proper direction of rotation and started, it should start quickly and run smoothly. If not, stop the motor immediately and determine the cause. Possible causes are: low voltage at the motor, motor connections are not correct or the load is too heavy. Check the motor current after a few minutes of operatio...
Source: LB5001.pdf
Page: 1
Relevance Score (from re-ranker): 0.769306
--------------

--- Document 2 ---
Content: New motors that have been stored for a year or more should be relubricated. Lubrication is also recommended at these intervals:...
Source: LB5001.pdf
Page: 2
Relevance Score (from re-ranker): 0.5035894
--------------

--- Docume

In [8]:
query1 = "What should be done if a newly started motor does not run smoothly?"
test_retriever_query(query=query1, retriever=base_retriever)


Testing retriever with query: 'What should be done if a newly started motor does not run smoothly?'

Retrieved 20 documents:

--- Document 1 ---
Content: Connect the motor as shown in the connection diagram. The wiring, fusing and grounding must comply with the National Electrical Code and local codes. When the motor is connected to the load for proper direction of rotation and started, it should start quickly and run smoothly. If not, stop the motor immediately and determine the cause. Possible causes are: low voltage at the motor, motor connections are not correct or the load is too heavy. Check the motor current after a few minutes of operatio...
Source: LB5001.pdf
Page: 1
Relevance Score (from re-ranker): N/A
--------------

--- Document 2 ---
Content: Clean the grease fitting (or area around grease hole, if equipped with slotted grease screws). If motor has a purge plug, remove it. Motors can be regreased while stopped (at less than 80°C) or running....
Source: LB5001.pdf
Page: 2

In [11]:
vector_retriever = vector_store_conn.as_retriever()

In [12]:
query1 = "What should be done if a newly started motor does not run smoothly?"
test_retriever_query(query=query1, retriever=vector_retriever)


Testing retriever with query: 'What should be done if a newly started motor does not run smoothly?'

Retrieved 4 documents:

--- Document 1 ---
Content: Connect the motor as shown in the connection diagram. The wiring, fusing and grounding must comply with the National Electrical Code and local codes. When the motor is connected to the load for proper direction of rotation and started, it should start quickly and run smoothly. If not, stop the motor immediately and determine the cause. Possible causes are: low voltage at the motor, motor connections are not correct or the load is too heavy. Check the motor current after a few minutes of operatio...
Source: LB5001.pdf
Page: 1
Relevance Score (from re-ranker): N/A
--------------

--- Document 2 ---
Content: Clean the grease fitting (or area around grease hole, if equipped with slotted grease screws). If motor has a purge plug, remove it. Motors can be regreased while stopped (at less than 80°C) or running....
Source: LB5001.pdf
Page: 2


In [14]:
from pprint import pprint
pprint(retriever)

ContextualCompressionRetriever(base_compressor=CohereRerank(client=<cohere.client_v2.ClientV2 object at 0x7fe2e4d76bd0>, top_n=4, model='rerank-v3.5', cohere_api_key=SecretStr('**********'), base_url=None, user_agent='langchain:partner'), base_retriever=MultiVectorRetriever(vectorstore=<langchain_chroma.vectorstores.Chroma object at 0x7fe2e426b740>, docstore=<langchain_core.stores.InMemoryStore object at 0x7fe2e406bd10>, search_kwargs={'k': 20}))


In [16]:
retriever.retriever()

AttributeError: 'ContextualCompressionRetriever' object has no attribute 'retriever'