In [None]:
import os
import torch
import chromadb
from typing import List, Optional, Literal

from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker

from langchain_huggingface import HuggingFaceEmbeddings

from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStoreRetriever

from langchain_chroma import Chroma

from langchain.retrievers.ensemble import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.retrievers import BaseRetriever

from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_transformers import LongContextReorder

from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import DocumentCompressorPipeline

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_ollama.chat_models import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.language_models.chat_models import BaseChatModel

from langchain_core.runnables import Runnable
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate

# Provider names accepted by initialize_llm(); used as the type of its `provider` argument.
LLMProvider = Literal["openai", "gemini", "groq", "ollama"]

In [2]:
def load_document(file_path: str) -> List[Document]:
    """
    Load a single .txt or .pdf file into LangChain Document objects.

    Args:
        file_path (str): Location of the file to read.

    Returns:
        List[Document]: Whatever the matching loader's ``load()`` returns.

    Raises:
        FileNotFoundError: If nothing exists at ``file_path``.
        ValueError: If the extension is neither .txt nor .pdf (case-insensitive).
    """
    # Guard: the path has to exist before we look at anything else.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found at: {file_path}")

    file_extension = os.path.splitext(file_path)[1]
    suffix = file_extension.lower()

    # Guard: reject anything that is not one of the two supported formats.
    if suffix not in ('.txt', '.pdf'):
        raise ValueError(f"Unsupported file extension: '{file_extension}'. Only '.txt' and '.pdf' are supported.")

    if suffix == '.txt':
        loader = TextLoader(file_path, encoding='utf-8')
    else:
        loader = PyPDFLoader(file_path)

    print(f"Loading '{os.path.basename(file_path)}'...")
    documents = loader.load()
    print("Loading complete.")

    return documents


def load_documents_from_directory(directory_path: str) -> List[Document]:
    """
    Loads all supported documents (.txt and .pdf) from a specified directory.

    Only the directory's direct entries are inspected (no recursion). Entries
    are visited in sorted name order so the resulting document order is
    reproducible. Loading is best-effort: a file whose loader cannot be
    created or whose content cannot be read is reported with a warning and
    skipped, and the scan continues with the next file.

    Args:
        directory_path (str): The path to the directory to scan for documents.

    Returns:
        List[Document]: A single list containing all loaded documents from the
            directory (empty when nothing could be loaded).

    Raises:
        FileNotFoundError: If the specified directory does not exist.
    """
    if not os.path.isdir(directory_path):
        raise FileNotFoundError(f"Directory not found at: {directory_path}")

    all_documents = []
    supported_extensions = ['.txt', '.pdf']

    print(f"Scanning directory '{directory_path}' for supported files ({', '.join(supported_extensions)})...")

    # os.listdir() returns entries in arbitrary order; sort for determinism.
    for filename in sorted(os.listdir(directory_path)):
        file_path = os.path.join(directory_path, filename)
        file_extension = os.path.splitext(filename)[1].lower()

        # Skip sub-directories and files with unsupported extensions.
        if file_extension not in supported_extensions or not os.path.isfile(file_path):
            continue

        print(f"  -> Found and loading '{filename}'...")

        # Loader construction lives inside the try as well: a loader that
        # cannot be created must not abort the whole directory scan.
        try:
            if file_extension == '.txt':
                loader = TextLoader(file_path, encoding='utf-8')
            else:
                loader = PyPDFLoader(file_path)

            all_documents.extend(loader.load())

        except Exception as e:
            print(f"    [Warning] Failed to load or process {filename}: {e}")

    if not all_documents:
        print("Warning: No supported documents were found in the directory.")

    print(f"Directory scan complete. Total documents loaded: {len(all_documents)}")

    return all_documents



def split_text(documents: List[Document], method: str = "recursive", **kwargs) -> List[Document]:
    """
    Break documents into chunks using either a recursive or a semantic splitter.

    Args:
        documents (List[Document]): LangChain Document objects to split.
        method (str): "recursive" (character based) or "semantic" (embedding based).
        **kwargs: Options read depending on ``method``.
            - "recursive": `chunk_size` (default 1000), `chunk_overlap` (default 100)
            - "semantic": `embeddings` (required embedding model instance)

    Returns:
        List[Document]: The chunks produced by the chosen splitter.

    Raises:
        ValueError: If ``method`` is unknown, or "semantic" is requested
            without an `embeddings` keyword argument.
    """
    print(f"Starting splitting process with '{method}' method...")

    # Reject unknown methods up front so the branches below stay simple.
    if method not in ("recursive", "semantic"):
        raise ValueError(f"Invalid method: '{method}'. Choose 'recursive' or 'semantic'.")

    if method == "semantic":
        embedding_model = kwargs.get('embeddings')
        if embedding_model is None:
            raise ValueError("An 'embeddings' model is required for the semantic method.")
        splitter = SemanticChunker(embedding_model)
    else:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=kwargs.get('chunk_size', 1000),
            chunk_overlap=kwargs.get('chunk_overlap', 100),
        )

    chunks = splitter.split_documents(documents)
    print(f"The text was split into {len(chunks)} chunks.")

    return chunks

def create_or_load_chroma_retriever(
    persist_directory: str,
    collection_name: str,
    embeddings: HuggingFaceEmbeddings,
    documents: Optional[List[Document]] = None,
    search_type: str = "mmr",
    search_kwargs: Optional[dict] = None,
) -> VectorStoreRetriever:
    """
    Creates and persists a new Chroma vector store or loads an existing one.

    This function checks if a collection with the given name already exists in the
    persist_directory. If it does, it loads the store (``documents`` is not used
    in that case). If not, it creates a new one using the provided documents.

    Args:
        persist_directory (str): The directory to save to or load from.
        collection_name (str): The name for the collection within Chroma.
        embeddings (HuggingFaceEmbeddings): The embedding model to use.
        documents (Optional[List[Document]], optional): The list of split documents.
            Required only if the collection does not already exist. Defaults to None.
        search_type (str, optional): The type of search for the retriever. Defaults to "mmr".
        search_kwargs (Optional[dict], optional): Keyword arguments for the search.
            Defaults to None, which is treated as {"k": 5}.

    Returns:
        VectorStoreRetriever: A configured retriever for the Chroma vector store.

    Raises:
        ValueError: If the collection does not exist and no (or an empty list of)
            documents are provided.
    """
    # Build the default per call rather than sharing one mutable dict between
    # calls, and copy caller input so the retriever never aliases it.
    search_kwargs = {"k": 5} if search_kwargs is None else dict(search_kwargs)

    client = chromadb.PersistentClient(path=persist_directory)

    # Depending on the chromadb release, list_collections() yields either
    # collection objects or plain collection names; accept both.
    existing_collections = {getattr(c, "name", c) for c in client.list_collections()}

    if collection_name in existing_collections:

        print(f"Collection '{collection_name}' found in '{persist_directory}'. Loading from disk.")
        vectorstore = Chroma(
            persist_directory=persist_directory,
            embedding_function=embeddings,
            collection_name=collection_name
        )

    else:

        print(f"Collection '{collection_name}' not found. Creating a new one...")

        # `not documents` covers both None and an empty list.
        if not documents:
            raise ValueError(
                "Documents must be provided to create a new collection, but 'documents' was None or empty."
            )

        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embeddings,
            collection_name=collection_name,
            persist_directory=persist_directory
        )

        print(f"New collection created and persisted at '{persist_directory}'.")

    print(f"Creating retriever with search_type='{search_type}' and search_kwargs={search_kwargs}.")

    retriever = vectorstore.as_retriever(
        search_type=search_type,
        search_kwargs=search_kwargs
    )

    return retriever


def create_bm25_retriever(
    documents: List[Document],
    k: int = 5
) -> BM25Retriever:
    """
    Creates a BM25Retriever for keyword-based search from a list of documents.

    BM25 is a ranking function that scores documents based on the query terms
    appearing in each document, without using semantic understanding.

    Args:
        documents (List[Document]): The list of documents to index for keyword search.
            Must contain at least one document.
        k (int, optional): The number of documents to retrieve. Defaults to 5.

    Returns:
        BM25Retriever: A configured retriever for keyword search.

    Raises:
        ValueError: If ``documents`` is empty or None.
    """
    # A BM25 index cannot be built from an empty corpus; fail here with a
    # clear message instead of an obscure error from inside the retriever.
    if not documents:
        raise ValueError("At least one document is required to create a BM25Retriever.")

    print(f"Creating BM25Retriever with k={k}...")

    bm25_retriever = BM25Retriever.from_documents(
        documents=documents,
        k=k
    )

    print("BM25Retriever created successfully.")

    return bm25_retriever



def create_ensemble_retriever(
    retrievers: List[BaseRetriever],
    weights: Optional[List[float]] = None,
    c: int = 60,
) -> EnsembleRetriever:
    """
    Creates an EnsembleRetriever to combine the results of multiple retrievers.
    It uses Reciprocal Rank Fusion (RRF) to re-rank the combined results,
    providing a more robust final ranking.

    Args:
        retrievers (List[BaseRetriever]): A list of retrievers to combine
            (e.g., [chroma_retriever, bm25_retriever]). Must not be empty.
        weights (Optional[List[float]], optional): One weight per retriever.
            When None, no weights are passed and EnsembleRetriever applies its
            own default weighting.
        c (int, optional): The rank constant forwarded to EnsembleRetriever.
            Defaults to 60.

    Returns:
        EnsembleRetriever: A retriever that combines and re-ranks results.

    Raises:
        ValueError: If ``retrievers`` is empty, or ``weights`` is given with a
            different length than ``retrievers``.
    """
    if not retrievers:
        raise ValueError("At least one retriever is required to create an EnsembleRetriever.")

    if weights is not None and len(weights) != len(retrievers):
        raise ValueError(
            f"Expected {len(retrievers)} weight(s), one per retriever, but got {len(weights)}."
        )

    print(f"Creating EnsembleRetriever for {len(retrievers)} retrievers...")

    # NOTE: the former `fusion_type="RRF"` argument is intentionally dropped.
    # EnsembleRetriever exposes no such option (rank fusion is its only mode),
    # so only the options it actually declares are forwarded.
    ensemble_options = {"retrievers": retrievers, "c": c}
    if weights is not None:
        ensemble_options["weights"] = list(weights)

    ensemble_retriever = EnsembleRetriever(**ensemble_options)

    print("EnsembleRetriever created successfully.")

    return ensemble_retriever


def create_compression_retriever(
    base_retriever: BaseRetriever,
    embeddings: HuggingFaceEmbeddings,
    reranker_model_name: str = "BAAI/bge-reranker-large",
    top_n: int = 5,
    similarity_threshold: float = 0.95
) -> ContextualCompressionRetriever:
    """
    Put a filter -> rerank -> reorder post-processing pipeline around a retriever.

    The three stages, applied in this order to the base retriever's results:
    1. EmbeddingsRedundantFilter: drops near-duplicate documents.
    2. CrossEncoderReranker: scores the remainder with a Cross-Encoder model
       and keeps the best ``top_n``.
    3. LongContextReorder: rearranges the kept documents, intended to counter
       the "lost in the middle" effect in LLM prompts.

    Args:
        base_retriever (BaseRetriever): Retriever whose output is post-processed
            (e.g., an EnsembleRetriever).
        embeddings (HuggingFaceEmbeddings): Embedding model used by the redundancy filter.
        reranker_model_name (str, optional): Cross-Encoder model name for reranking.
        top_n (int, optional): Number of documents kept by the reranker. Defaults to 5.
        similarity_threshold (float, optional): Threshold handed to the redundancy
            filter. Defaults to 0.95.

    Returns:
        ContextualCompressionRetriever: The wrapped retriever.
    """
    print("Creating advanced compression and reranking pipeline...")
    print(f"Loading reranker model: {reranker_model_name}...")

    # Stage 2 is built first because loading the cross-encoder is the slow part.
    rerank_stage = CrossEncoderReranker(
        model=HuggingFaceCrossEncoder(model_name=reranker_model_name),
        top_n=top_n,
    )
    dedup_stage = EmbeddingsRedundantFilter(
        embeddings=embeddings,
        similarity_threshold=similarity_threshold,
    )
    reorder_stage = LongContextReorder()

    # Stage order matters: filter -> rerank -> reorder.
    stages = [dedup_stage, rerank_stage, reorder_stage]
    pipeline = DocumentCompressorPipeline(transformers=stages)
    print("Compressor pipeline created successfully.")

    wrapped_retriever = ContextualCompressionRetriever(
        base_compressor=pipeline,
        base_retriever=base_retriever,
    )
    print("ContextualCompressionRetriever created successfully.")

    return wrapped_retriever



def initialize_llm(
    provider: LLMProvider,
    model_name: str = None,
    **kwargs
) -> BaseChatModel:
    """
    Build a LangChain chat model for the requested provider.

    Acts as a small factory: the provider name (case-insensitive) selects the
    chat-model class, the environment variable that must be set, and the model
    used when ``model_name`` is not given.

    Args:
        provider (Literal["openai", "gemini", "groq", "ollama"]):
            Which LLM provider to use.
        model_name (str, optional): Model identifier to request from the provider.
            When None (or empty), the per-provider default below is used.
        **kwargs: Extra keyword arguments forwarded unchanged to the chat-model
                  constructor (e.g., temperature=0.7, max_tokens=1024).

    Returns:
        BaseChatModel: The constructed chat model.

    Raises:
        ValueError: If the provider is not supported, or the provider's API key
                    environment variable is missing or empty.
    """
    provider = provider.lower()

    print(f"Initializing LLM from provider: '{provider}'...")

    # provider -> (required API-key environment variable or None, default model)
    provider_settings = {
        "gemini": ("GOOGLE_API_KEY", "gemini-2.5-flash"),
        "openai": ("OPENAI_API_KEY", "gpt-4o"),
        "groq": ("GROQ_API_KEY", "meta-llama/llama-4-scout-17b-16e-instruct"),
        "ollama": (None, "llama3"),
    }

    if provider not in provider_settings:
        raise ValueError(
            f"Unsupported LLM provider: '{provider}'. "
            "Supported providers are 'openai', 'gemini', 'groq', 'ollama'."
        )

    api_key_variable, default_model = provider_settings[provider]

    # Ollama has no API-key requirement here; every other provider does.
    if api_key_variable is not None and not os.getenv(api_key_variable):
        raise ValueError(f"{api_key_variable} environment variable not set.")

    model = model_name or default_model

    if provider == "gemini":
        return ChatGoogleGenerativeAI(model=model, **kwargs)

    if provider == "openai":
        return ChatOpenAI(model=model, **kwargs)

    if provider == "groq":
        return ChatGroq(model_name=model, **kwargs)

    # Only "ollama" is left at this point.
    print(f"Note: Ensure the Ollama service is running and you have pulled the '{model}' model.")

    return ChatOllama(model=model, **kwargs)


def create_rag_chain(
    retriever: BaseRetriever,
    llm: BaseChatModel,
    prompt_template: Optional[str] = None
) -> Runnable:
    """
    Assemble a Retrieval-Augmented Generation (RAG) chain.

    The chain is built from two pieces: a "stuff documents" chain that fills
    the system prompt's '{context}' placeholder and sends the prompt plus the
    user's '{input}' to the LLM, and a retrieval chain that feeds it with the
    given retriever.

    Args:
        retriever (BaseRetriever): Retriever supplying the context documents
            (e.g., the final compression retriever).
        llm (BaseChatModel): The initialized language model.
        prompt_template (Optional[str], optional): Custom system prompt text.
            Must include a '{context}' placeholder. When None, a built-in
            question-answering prompt is used.

    Returns:
        Runnable: The object returned by ``create_retrieval_chain``; invoke it
                  with ``{"input": <question>}``. The demo script in this file
                  reads "context" and "answer" from the result.
    """
    print("Creating the final RAG chain...")

    system_prompt = prompt_template
    if system_prompt is None:
        # Built-in prompt: answer strictly from the retrieved context.
        system_prompt = (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer the question. "
            "If you don't know the answer, just say that you don't know. "
            "Keep the answer concise and based ONLY on the provided context.\n\n"
            "CONTEXT:\n{context}"
        )

    messages = [("system", system_prompt), ("human", "{input}")]
    prompt = ChatPromptTemplate.from_messages(messages)

    answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, answer_chain)

    print("RAG chain created successfully.")

    return rag_chain

In [None]:
if __name__ == "__main__":
    # End-to-end smoke test of the pipeline defined above:
    #   1. load the embedding model (optional; failures are reported, not fatal)
    #   2. load and split a test PDF
    #   3. build the vector + BM25 + ensemble + compression retrievers
    #   4. ping every LLM provider that is configured
    #   5. run a full RAG query with the first usable LLM (Gemini preferred)
    # Every stage that depends on an earlier, optional stage is skipped with a
    # message when its prerequisite is missing, instead of failing on an
    # undefined variable.

    test_file_path = "./data/test.pdf"
    embeddings_model = None

    try:

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using '{device}' for the embedding model...")

        embeddings_model = HuggingFaceEmbeddings(
            model_name="BAAI/bge-large-en-v1.5",
            model_kwargs={'device': device},
            encode_kwargs={'normalize_embeddings': True}
        )

        print("Embedding model loaded successfully.")

    except Exception as e:
        print(f"An error occurred while loading the embedding model: {e}")
        print("Semantic splitting test will be skipped.")

    try:

        loaded_documents = load_document(test_file_path)

        print(f"A total of {len(loaded_documents)} document(s) were loaded.\n")

        print("--- Recursive Method Test ---")

        recursive_chunks = split_text(
            loaded_documents,
            method="recursive",
            chunk_size=1000,
            chunk_overlap=100
        )

        if recursive_chunks:
            print("\nFirst recursive chunk:\n", recursive_chunks[0].page_content)
        else:
            print("\nThe recursive splitter produced no chunks.")
        print("-" * 25)

        # Stays empty when the embedding model is unavailable, so the
        # retrieval stage below can tell whether it has anything to index.
        semantic_chunks = []

        if embeddings_model is not None:
            print("\n--- Semantic Method Test ---")

            semantic_chunks = split_text(
                loaded_documents,
                method="semantic",
                embeddings=embeddings_model
            )

            if semantic_chunks:
                print("\nFirst semantic chunk:\n", semantic_chunks[0].page_content)
            else:
                print("\nThe semantic splitter produced no chunks.")
            print("-" * 25)

        # Remains None when the retrieval stage is skipped.
        final_retriever = None

        if semantic_chunks:

            db_directory = "./chroma_db_store"
            collection = "test_collection"

            vector_retriever = create_or_load_chroma_retriever(
                persist_directory=db_directory,
                collection_name=collection,
                embeddings=embeddings_model,
                documents=semantic_chunks,
                search_type="mmr",
                search_kwargs={"k": 5}
            )

            keyword_retriever = create_bm25_retriever(
                documents=semantic_chunks,
                k=5
            )

            ensemble_retriever = create_ensemble_retriever(
                retrievers=[vector_retriever, keyword_retriever]
            )

            final_retriever = create_compression_retriever(
                base_retriever=ensemble_retriever,
                embeddings=embeddings_model,
                top_n=5
            )

            print("\n" + "="*30 + "\n")

            query = "What is a Semantic Analysis?"

            print(f"Testing FINAL retriever with query: '{query}'")

            final_docs = final_retriever.invoke(query)

            print("\n--- Final Retrieved & Reranked Documents ---")

            for i, doc in enumerate(final_docs):
                print(f"Document {i+1}: {doc.page_content} (Source: {doc.metadata.get('source')})")

        else:
            print("\nSkipping retriever tests: no semantic chunks are available "
                  "(embedding model missing or nothing to split).")

        simple_query = "Hello! Is there anybody there?"

        # (provider, section title, short label, required env var, model name)
        # A model name of None lets initialize_llm pick its default.
        llm_tests = [
            ("openai", "OpenAI", "OpenAI", "OPENAI_API_KEY", "gpt-4o"),
            ("gemini", "Google Gemini", "Gemini", "GOOGLE_API_KEY", "gemini-2.5-flash"),
            ("groq", "Groq", "Groq", "GROQ_API_KEY", None),
            ("ollama", "Ollama", "Ollama", None, "llama3"),
        ]

        # provider -> chat model that answered the ping successfully
        working_llms = {}

        for provider, title, label, env_var, model in llm_tests:

            if env_var is not None and not os.getenv(env_var):
                print(f"\nSkipping {label} test: {env_var} not set.")
                continue

            try:

                print(f"\n--- Testing {title} ---")

                llm = initialize_llm(provider, model_name=model, temperature=0)

                response = llm.invoke(simple_query)

                print(f"{label} Response:\n", response.content)

                working_llms[provider] = llm

            except Exception as e:
                if provider == "ollama":
                    print("Could not connect to Ollama. Please ensure the service is running and the model is pulled.")
                    print(f"Error: {e}")
                else:
                    print(f"Error testing {label}: {e}")

        # Prefer Gemini for the RAG run; otherwise fall back to any provider
        # that responded above.
        rag_llm = working_llms.get("gemini")
        if rag_llm is None and working_llms:
            rag_llm = next(iter(working_llms.values()))

        if final_retriever is None:
            print("\nSkipping RAG chain test: no retriever was built.")

        elif rag_llm is None:
            print("\nSkipping RAG chain test: no LLM provider is available.")

        else:

            rag_chain = create_rag_chain(
                retriever=final_retriever,
                llm=rag_llm
            )

            print("\n" + "="*30 + "\n")

            query = "What is Lexical Analysis?"
            print(f"Invoking chain with query: '{query}'")

            response = rag_chain.invoke({"input": query})

            print("\n--- Full Response Dictionary ---")
            print(response)

            print("\n--- Retrieved Context ---")

            for doc in response["context"]:
                print(doc.page_content)

            print("\n--- Final Answer ---")
            print(response["answer"])

    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"An error occurred during the workflow: {e}")
        