In [8]:
%%writefile app_rag_pdf.py

import streamlit as st
import os
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from operator import itemgetter
import pandas as pd
import chromadb
import tempfile

# Load environment variables (GEMINI_API_KEY is read later via os.getenv)
load_dotenv()

# Configure page
st.set_page_config(page_title="PDF QA Chatbot", page_icon="📚")

# Custom CSS for centered chat input and a narrower main column
st.markdown("""
    <style>
    .stChatInput {
        max-width: 800px;
        margin: 0 auto;
    }
    .block-container {
        max-width: 900px;
        padding-left: 2rem;
        padding-right: 2rem;
    }
    </style>
""", unsafe_allow_html=True)

st.title("📚 PDF QA RAG Chatbot")
st.markdown("Upload PDFs and ask questions about them using Gemini")

# Stream handler for live token updates
class StreamHandler(BaseCallbackHandler):
    """Callback handler that mirrors streamed LLM tokens into a Streamlit container.

    The full text received so far is kept in ``self.text`` and re-rendered as
    Markdown into ``self.container`` each time a new token arrives.
    """

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append ``token`` to the accumulated text and redraw the container."""
        accumulated = self.text + token
        self.text = accumulated
        self.container.markdown(accumulated)

# Configure retriever from uploaded PDFs
@st.cache_resource(ttl="1h")
def configure_retriever(uploaded_files):
    """Index the uploaded PDFs in Chroma and return a retriever over them.

    Each upload is copied to a temporary directory (PyMuPDFLoader is given a
    file path), loaded, split into overlapping chunks, embedded with Gemini
    embeddings, and stored in the persistent Chroma collection
    ``pdf_collection`` under ``./chroma_db``. The collection is cleared first,
    so it only ever holds the chunks of the current set of uploads.

    Args:
        uploaded_files: Files returned by ``st.sidebar.file_uploader``; each
            one's ``.name`` and ``.getvalue()`` are used.

    Returns:
        A vector-store retriever returning the 3 most similar chunks.
    """
    collection_name = "pdf_collection"

    docs = []
    # The context manager removes the temporary copies as soon as loading is
    # done, instead of leaving cleanup to garbage collection.
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in uploaded_files:
            temp_filepath = os.path.join(temp_dir, file.name)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            file_docs = PyMuPDFLoader(temp_filepath).load()
            # PyMuPDFLoader records the path it was given (the temporary copy,
            # deleted right after this block) as "source"; show the name the
            # user uploaded instead.
            for doc in file_docs:
                doc.metadata["source"] = file.name
            docs.extend(file_docs)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1500,
        chunk_overlap=200,
    )
    doc_chunks = text_splitter.split_documents(docs)

    embeddings_model = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        google_api_key=os.getenv("GEMINI_API_KEY"),
    )

    client = chromadb.PersistentClient(path="./chroma_db")
    # The store lives on disk and the collection name is fixed. Without a
    # reset, every re-index (new uploads, or a cache miss after the 1h TTL)
    # would append to the chunks of earlier runs, so stale PDFs and duplicate
    # chunks would be retrieved. get_or_create followed by delete clears the
    # collection whether or not it already exists. embedding_function=None
    # avoids attaching chromadb's default embedder, since langchain's Chroma
    # supplies the Gemini embeddings itself.
    client.get_or_create_collection(name=collection_name, embedding_function=None)
    client.delete_collection(name=collection_name)

    vectordb = Chroma.from_documents(
        documents=doc_chunks,
        embedding=embeddings_model,
        client=client,
        collection_name=collection_name,
    )

    return vectordb.as_retriever(search_kwargs={"k": 3})

# File uploader (multiple PDFs may be selected at once)
uploaded_files = st.sidebar.file_uploader(
    label="Upload PDF files",
    type=["pdf"],
    accept_multiple_files=True
)

# Nothing to index yet: show a hint and stop this script run here
if not uploaded_files:
    st.info("📤 Please upload PDF documents to continue.")
    st.stop()

# Configure retriever (configure_retriever is wrapped in st.cache_resource)
retriever = configure_retriever(uploaded_files)

# Prompt: the model must answer only from the retrieved {context}
qa_template = """
You are a helpful assistant that answers questions based on the provided context from PDF documents.

Context from documents:
{context}

Instructions:
- Answer the question using ONLY the information from the context above
- If the answer is not in the context, clearly state "I cannot find this information in the provided documents"
- Be accurate, concise, and direct in your response
- If relevant, cite specific details from the context to support your answer

Question: {question}

Answer:
"""
qa_prompt = ChatPromptTemplate.from_template(qa_template)


def format_docs(docs):
    """Join the ``page_content`` of each document, separated by blank lines."""
    separator = "\n\n"
    return separator.join(doc.page_content for doc in docs)


# Gemini chat model; streaming=True is set so that StreamHandler can receive
# tokens while the response is generated.
gemini = ChatGoogleGenerativeAI(
    model="models/gemini-2.5-flash",
    temperature=0.1,
    streaming=True,
    google_api_key=os.getenv("GEMINI_API_KEY"),
)

# RAG chain: from an input dict {"question": ...}, "context" becomes the
# formatted retrieved chunks and "question" passes through unchanged; both
# fill qa_prompt, whose messages are sent to Gemini.
_get_question = itemgetter("question")
qa_rag_chain = (
    {
        "context": _get_question | retriever | format_docs,
        "question": _get_question,
    }
    | qa_prompt
    | gemini
)

# Initialize chat history
streamlit_msg_history = StreamlitChatMessageHistory(key="langchain_messages")

# Seed a greeting when the history is empty
if not streamlit_msg_history.messages:
    streamlit_msg_history.add_ai_message("Ask me anything about your PDFs!")

# Replay the conversation so far
for msg in streamlit_msg_history.messages:
    st.chat_message(msg.type).write(msg.content)

# Chat input
if user_prompt := st.chat_input("Ask a question about your PDFs..."):
    st.chat_message("human").write(user_prompt)
    streamlit_msg_history.add_user_message(user_prompt)

    with st.chat_message("ai"):
        # Retrieve once and reuse the result for both the answer and the
        # sources table. Invoking qa_rag_chain and then the retriever again
        # would embed the query twice (two embedding API calls per question),
        # and the table would not be tied to the context actually used.
        retrieved_docs = retriever.invoke(user_prompt)

        # Placeholder that StreamHandler fills token by token
        response_placeholder = st.empty()
        stream_handler = StreamHandler(response_placeholder)

        # Same prompt -> model pipeline as qa_rag_chain, fed the context
        # built from the documents retrieved above.
        response = (qa_prompt | gemini).invoke(
            {"context": format_docs(retrieved_docs), "question": user_prompt},
            config={"callbacks": [stream_handler]},
        )

        # Replace the streamed text with the final response
        response_placeholder.markdown(response.content)

        # Add to history
        streamlit_msg_history.add_ai_message(response.content)

        # Show sources
        st.markdown("---")
        st.markdown("**📄 Sources:**")

        # One row per distinct (source, page), in retrieval order
        sources = []
        seen = set()
        for d in retrieved_docs:
            key = (d.metadata.get("source"), d.metadata.get("page"))
            if key in seen:
                continue
            seen.add(key)
            sources.append({
                "source": key[0],
                "page": key[1],
                "content": d.page_content[:200],
            })

        if sources:
            st.dataframe(data=pd.DataFrame(sources[:3]), width=1000)

Overwriting app_rag_pdf.py


In [10]:
!rm -rf chroma_db