In [10]:
# if doing clean extraction from database, run this
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from unstructured.partition.auto import partition
from langchain.schema import Document
from IPython.display import display, HTML
import os
import pandas as pd
import gradio as gr
import pickle
#from langchain.document_loaders import DirectoryLoader
#from langchain_community.document_loaders import TextLoader

# 1. Document loading and page tracking.
# Every element extracted by `unstructured` becomes one record holding its source file name,
# its text and its page number, so answers can cite pages later. Files are visited in sorted
# order so chunk boundaries (and the chunk ids built below) are reproducible between runs.
doc_folder = r'C:\Users\Ian\Documents\RAG\B1-B data'
processed_dir = r'C:\Users\Ian\Documents\RAG\B-1B data processed'
docs = []
for filename in sorted(os.listdir(doc_folder)):
    filepath = os.path.join(doc_folder, filename)
    if not os.path.isfile(filepath):
        continue  # skip sub-folders
    for element in partition(filename=filepath):
        docs.append({
            "source": filename,
            "content": str(element),
            # Elements whose metadata carries no page number are labelled 'N/A'.
            "page": element.metadata.page_number or 'N/A',
        })

# 2. Chunking while preserving page information.
# Consecutive elements of the SAME file are packed into chunks of at most CHUNK_SIZE
# characters. A chunk is flushed when the next piece would overflow it or when the next
# element comes from a different file, so a chunk never mixes files and its "source" is
# always correct. "page" is the page of the last element packed into the chunk (it is
# shown as "end page" in the source table).
CHUNK_SIZE = 4096
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=300)
all_splits = []
current_chunk = ""
current_metadata = {}

for doc in docs:
    if current_chunk and current_metadata.get("source") != doc['source']:
        all_splits.append(Document(page_content=current_chunk, metadata=current_metadata))
        current_chunk = ""
    for split in text_splitter.split_text(doc['content']):
        # Only flush a non-empty chunk: never emit an empty Document with empty metadata.
        if current_chunk and len(current_chunk) + len(split) > CHUNK_SIZE:
            all_splits.append(Document(page_content=current_chunk, metadata=current_metadata))
            current_chunk = ""
        current_chunk += split + " "
        current_metadata = {"source": doc['source'], "page": doc['page']}

# Append the last chunk
if current_chunk:
    all_splits.append(Document(page_content=current_chunk, metadata=current_metadata))

# Save the raw element records so the demo cell can rebuild chunks without re-parsing files.
os.makedirs(processed_dir, exist_ok=True)
with open(os.path.join(processed_dir, 'docs.pkl'), 'wb') as f:
    pickle.dump(docs, f)

# 3. Vector store and LLM setup.
# Stable per-file ids ("<file>::<n>") mean a re-run writes chunks under the same ids
# instead of adding another copy of the whole corpus to the persisted collection.
chunk_ids = []
chunks_per_source = {}
for chunk in all_splits:
    source = chunk.metadata["source"]
    n = chunks_per_source.get(source, 0)
    chunks_per_source[source] = n + 1
    chunk_ids.append(f"{source}::{n}")

model = OllamaEmbeddings(model="nomic-embed-text")
# Same persist directory as the "docs.pkl already exists" cell, so both paths share one store.
vectorstore = Chroma.from_documents(
    all_splits,
    model,
    ids=chunk_ids,
    collection_name="B-1B",
    persist_directory=os.path.join(processed_dir, 'vectorDB_B-1B'),
)
llm = ChatOllama(model="llama3.1:8b")  # Or your preferred LLM

In [5]:
# if docs.pkl already exists, run this -- run this in demo
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document
from unstructured.partition.auto import partition
import pandas as pd
import gradio as gr
import pickle
# Load the element records saved by the extraction cell.
# NOTE(security): pickle can execute arbitrary code on load; only load a docs.pkl that this
# project produced.
with open(r'C:\Users\Ian\Documents\RAG\B-1B data processed\docs.pkl', 'rb') as f:
    docs = pickle.load(f)

# Optional: Display the loaded docs to verify
print(f"Loaded {len(docs)} documents.")

# Chunking while preserving page information.
# Consecutive elements of the SAME file are packed into chunks of at most CHUNK_SIZE
# characters. A chunk is flushed when the next piece would overflow it or when the next
# element comes from a different file, so a chunk never mixes files and its "source" is
# always correct. "page" is the page of the last element packed into the chunk.
CHUNK_SIZE = 4096
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=300)
all_splits = []
current_chunk = ""
current_metadata = {}

for doc in docs:
    if current_chunk and current_metadata.get("source") != doc['source']:
        all_splits.append(Document(page_content=current_chunk, metadata=current_metadata))
        current_chunk = ""
    for split in text_splitter.split_text(doc['content']):
        # Only flush a non-empty chunk: never emit an empty Document with empty metadata.
        if current_chunk and len(current_chunk) + len(split) > CHUNK_SIZE:
            all_splits.append(Document(page_content=current_chunk, metadata=current_metadata))
            current_chunk = ""
        current_chunk += split + " "
        current_metadata = {"source": doc['source'], "page": doc['page']}

# Append the last chunk
if current_chunk:
    all_splits.append(Document(page_content=current_chunk, metadata=current_metadata))

# Vector store and LLM setup.
# Stable per-file ids ("<file>::<n>") mean re-running this demo cell writes chunks under the
# same ids instead of adding another copy of the corpus to the persisted collection.
# Chunks stored by earlier runs without explicit ids are not replaced; delete the
# vectorDB_B-1B folder once to drop them.
chunk_ids = []
chunks_per_source = {}
for chunk in all_splits:
    source = chunk.metadata["source"]
    n = chunks_per_source.get(source, 0)
    chunks_per_source[source] = n + 1
    chunk_ids.append(f"{source}::{n}")

embedding_function = OllamaEmbeddings(model="nomic-embed-text")
vectorstore = Chroma.from_documents(
    all_splits,
    embedding_function,
    ids=chunk_ids,
    collection_name="B-1B",
    persist_directory=r'C:\Users\Ian\Documents\RAG\B-1B data processed\vectorDB_B-1B',
)
llm = ChatOllama(model="llama3.1:8b")  # Or your preferred LLM

Loaded 28689 documents.


In [2]:
# 4. RAG function (retrieval, optional per-source summaries, answer generation)
def RAG(user_prompt, llm, vectorstore, stream=False, source_summaries=False, retrieval = 'contextual', top_k_hits = 5):
    """Answer ``user_prompt`` from documents retrieved out of ``vectorstore``.

    Documents are retrieved exactly once. The same documents are used both as the LLM
    context and as the rows of the returned source table, so the table always shows what
    the answer was based on.

    Args:
        user_prompt: The user's question.
        llm: Chat model used for contextual compression, per-source summaries and the answer.
        vectorstore: Vector store to search (``as_retriever`` / ``similarity_search``).
        stream: If True, print the answer incrementally and return '' as the answer text.
        source_summaries: If True, add a "short summary" column (one LLM call per source).
        retrieval: 'contextual' (LLM-compressed hits of ``vectorstore.as_retriever()`` with
            its default settings), 'cosine similarity' (top ``top_k_hits`` similarity hits)
            or 'both' (contextual hits followed by similarity hits). 'cosine_similarity' is
            accepted as an alias of 'cosine similarity'.
        top_k_hits: Number of similarity-search hits; not used by 'contextual'.

    Returns:
        ``(answer, source_df)``. ``source_df`` has columns "source", "end page" and "content",
        plus "short summary" when requested. When nothing is retrieved, the answer is
        "No relevant documents found" and ``source_df`` is empty but keeps those columns.

    Raises:
        ValueError: If ``retrieval`` is not one of the supported methods.
    """
    def format_docs(docs):
        return "\n\n".join(
            f"Source: {doc.metadata.get('source', 'N/A')} - Page: {doc.metadata.get('page', 'N/A')}\n\n{doc.page_content}"
            for doc in docs
        )

    RAG_TEMPLATE = """
    This is a chat between a user and an AI assistant. The assistant provides detailed, polite answers based on the context provided.

    Definitions:
    - 'Embedded' content is directly from the PDF.
    - 'Predicted' content is generated by an OCR model and may contain inaccuracies (e.g., spelling, spacing, symbols).

    Each document is tracked by its source and page number. When answering, cite the source document and page number in parentheses, like this: ([Source: document_name, Page X]). If the answer comes from multiple chunks of the same document, reference all relevant chunks.

    Indicate when information cannot be found in the context.

    <context>
    {context}
    </context>

    Answer the following question:

    {question}
    """

    method = str(retrieval).replace('_', ' ')
    if method not in ('contextual', 'cosine similarity', 'both'):
        raise ValueError(
            f"Unknown retrieval method {retrieval!r}; expected 'contextual', 'cosine similarity' or 'both'."
        )

    question = user_prompt
    docs = []
    if method in ('contextual', 'both'):
        compressor = LLMChainExtractor.from_llm(llm)
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=vectorstore.as_retriever()
        )
        docs.extend(compression_retriever.invoke(question))
    if method in ('cosine similarity', 'both'):
        docs.extend(vectorstore.similarity_search(question, k=int(top_k_hits)))

    if not docs:
        # Keep the columns so callers can still index e.g. source_df['source'].
        return "No relevant documents found", pd.DataFrame(columns=["source", "end page", "content"])

    source_df = pd.DataFrame([
        {
            "source": doc.metadata.get('source', 'N/A'),
            "end page": doc.metadata.get('page', 'N/A'),
            "content": doc.page_content,
        }
        for doc in docs
    ])

    if source_summaries:
        source_df["short summary"] = [
            llm.invoke(f'Summarize this in one or two sentences. Only state main point, nothing else. <{doc.page_content}> ').content
            for doc in docs
        ]

    # Feed the already-retrieved docs straight into the prompt instead of re-running retrieval.
    qa_chain = ChatPromptTemplate.from_template(RAG_TEMPLATE) | llm | StrOutputParser()
    chain_input = {"context": format_docs(docs), "question": question}

    if stream:
        for chunk in qa_chain.stream(chain_input):
            print(chunk, end="", flush=True)
        return '', source_df
    return qa_chain.invoke(chain_input), source_df


In [8]:
def RAG_gradio(user_prompt, retrieval_method, top_k_hits=5, root_dir=None):
    """Gradio callback: run ``RAG`` and return the answer plus an HTML table of sources.

    Args:
        user_prompt: Question typed by the user.
        retrieval_method: 'contextual', 'cosine similarity' or 'both' (forwarded to ``RAG``).
        top_k_hits: Slider value; converted to ``int`` before being used as a search count.
        root_dir: Folder the "source" file names are relative to. If None, the global
            ``doc_folder`` is used when the extraction cell defined it; otherwise the folder
            that cell reads from. (The demo cell never defines ``doc_folder``.)

    Returns:
        ``(answer, sources_html)``.
    """
    import os

    if root_dir is None:
        root_dir = globals().get('doc_folder', r'C:\Users\Ian\Documents\RAG\B1-B data')

    result, sources_df = RAG(user_prompt, llm, vectorstore, retrieval=retrieval_method, top_k_hits=int(top_k_hits))

    # Show each source as a full path; guard against an empty result without columns.
    if 'source' in sources_df.columns:
        sources_df['source'] = sources_df['source'].map(lambda name: os.path.join(root_dir, name))
    # Escape cell text: chunk content is raw PDF/OCR text and must not be interpreted as HTML.
    return result, sources_df.to_html(escape=True)

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### B-1B Database Q&A")

    with gr.Row():
        user_input = gr.Textbox(label="Enter your question:")
        retrieval_choice = gr.Radio(
            choices=["contextual", "cosine similarity", "both"],
            label="Retrieval Method:",
            value="cosine similarity",
        )
        top_k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Cosine Similarity Number of Documents:")

    output = gr.Markdown(label="Answer:")
    source_output = gr.HTML(label='Source Information')  # HTML so the source table renders

    def update_slider(method):
        """Reset the slider value and relabel it for the selected retrieval method.

        Returns a single update so the slider receives one consistent change.
        """
        if method == "contextual":
            # The slider stays visible, but contextual retrieval ignores it.
            return gr.update(visible=True, value=1, label="Not Applicable to Contextual Retrieval")
        if method == "cosine similarity":
            return gr.update(visible=True, value=5, label="Cosine Similarity Number of Documents:")
        if method == "both":
            return gr.update(visible=True, value=5, label="Cosine Similarity Number of Documents (not contextual):")
        return gr.update()  # unknown choice: leave the slider unchanged

    retrieval_choice.change(update_slider, inputs=retrieval_choice, outputs=top_k_slider)

    def show_thinking():
        """Placeholder answer shown while the query runs."""
        return "Processing your request..."

    btn = gr.Button("Submit")
    # Chain with .then so the placeholder is always shown before the (slow) RAG call starts.
    btn.click(fn=show_thinking, inputs=None, outputs=output).then(
        fn=RAG_gradio,
        inputs=[user_input, retrieval_choice, top_k_slider],
        outputs=[output, source_output],
    )

# Serve locally only. Other useful launch() options:
#   inbrowser=True            open the interface in a browser tab
#   inline=False              do not render the interface inside the notebook
#   share=True                create a public link
#   auth=("user", "password") require a login
demo.launch(share=False)

* Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


