# Enhanced RAG + Multi-Index Chroma Conversational Agent

A conversational AI agent built using LangChain, Chroma, and Google Gemini, capable of retrieval-augmented generation (RAG) across multiple domains with persistent memory. This agent can process PDFs, index them into domain-specific vector stores, and answer user queries using both indexed documents and tools like a calculator.

In [None]:
"""
Enhanced RAG + Multi-Index Chroma Conversational Agent
- Persistent Chroma vector store (multi-collection / domain-specific)
- HuggingFace embeddings
- Google Gemini (ChatGoogleGenerativeAI) as LLM
- Persistent ConversationBufferMemory (saved to a JSON file)
- Tools (Calculator example, RAG tool)
- Single Agent with verbose=True that uses memory, tools, and RAG
- Metadata using LangChain Document objects (richer metadata)
- Incremental indexing via add_documents() with duplicate avoidance
- Query rewriter chain (context optimizer)

Instructions:
- Set GOOGLE_API_KEY in a .env file or environment
- pip install required packages (langchain, langchain-community, langchain-google-genai, chromadb, sentence-transformers, PyMuPDF, python-dotenv)
"""

import ast
import gc
import hashlib
import json
import operator
import os
import re
from tkinter import Tk
from tkinter.filedialog import askopenfilenames
from typing import List, Dict, Optional

import fitz  # PyMuPDF
from dotenv import load_dotenv

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import Tool, initialize_agent, AgentType
from langchain.chains import LLMChain, RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory, ChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader
from langchain.schema import Document
# *** IMPORTS FOR PERSISTENT MEMORY ***
from langchain.schema.messages import messages_from_dict, messages_to_dict

# ===============================
# Config / Env
# ===============================
# Load variables from a local .env file (if any) before reading them below.
load_dotenv()
# In the code shown here this value is only used for the startup warning at the
# end of this section; it is not passed to the LLM constructor explicitly.
GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
# Model name, overridable through the GEMINI_MODEL environment variable.
MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.0-flash")

# Local storage (all paths are relative to the current working directory)
CHROMA_DIR = "./chroma_db"  # persist directory shared by every Chroma collection
CACHE_FILE = "./pdf_cache.json"  # JSON map: PDF path -> SHA-1 of the last processed content
# JSON file holding the serialized conversation history
MEMORY_FILE = "./conversation_history.json"


# Embeddings model, overridable through the EMBEDDING_MODEL environment variable
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

# RAG settings: default chunk size and overlap handed to the text splitter
DEFAULT_CHUNK_SIZE = 1000
DEFAULT_CHUNK_OVERLAP = 100

# Domain mapping (basic). Extend as needed.
# Keywords are matched as case-insensitive substrings, domains are tried in the
# order written here, and the first domain with a hit wins. "general" has no
# keywords: it is the fallback returned when nothing matches.
DOMAIN_KEYWORDS = {
    "legal": ["law", "contract", "agreement", "court", "legal"],
    "finance": ["invoice", "payment", "finance", "budget", "tax"],
    "research": ["chapter", "study", "research", "methodology", "results"],
    "general": []
}

# Safety: warn (but do not abort) when the API key is missing from the environment
if GEMINI_API_KEY is None:
    print("⚠️ WARNING: GOOGLE_API_KEY not found in environment. The LLM may fail at runtime if not set.")

# ===============================
# Persistent Memory Functions
# ===============================

def save_memory(memory_object: ConversationBufferMemory):
    """Persist the conversation history to ``MEMORY_FILE`` as JSON.

    The history is serialized completely before any file is touched, written
    to a temporary sibling file, and then swapped into place with
    ``os.replace``. A failure while serializing or writing therefore leaves the
    previously saved history intact instead of a truncated, unloadable file.

    Args:
        memory_object: Memory whose ``chat_memory.messages`` are saved.

    Persistence is best-effort: errors are printed and never raised.
    """
    tmp_path = f"{MEMORY_FILE}.tmp"
    try:
        history_dict = messages_to_dict(memory_object.chat_memory.messages)
        # Serialize up front so a non-serializable message cannot leave a
        # half-written file behind.
        payload = json.dumps(history_dict, indent=2)
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, MEMORY_FILE)
    except Exception as e:
        print(f"Error saving memory: {e}")
        # Best-effort cleanup of the temporary file; it may not exist at all.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def load_memory() -> ConversationBufferMemory:
    """Build the conversation memory, restoring history from ``MEMORY_FILE`` when possible.

    Returns:
        A ``ConversationBufferMemory`` (``memory_key="chat_history"``,
        ``input_key="input"``, ``return_messages=True``). It is pre-filled with
        the saved messages when the file exists and loads cleanly; otherwise it
        is empty. Load errors are printed, not raised.
    """
    # Settings shared by the restored and the fresh memory object.
    base_kwargs = {
        "memory_key": "chat_history",
        "input_key": "input",
        "return_messages": True,
    }

    if os.path.exists(MEMORY_FILE):
        try:
            with open(MEMORY_FILE, "r") as fh:
                raw_history = json.load(fh)

            restored = messages_from_dict(raw_history)
            restored_memory = ConversationBufferMemory(
                chat_memory=ChatMessageHistory(messages=restored),
                **base_kwargs,
            )
            print(f"🧠 Loaded persistent memory from {MEMORY_FILE} with {len(restored)} messages.")
            return restored_memory
        except Exception as err:
            # Fall through to the empty-memory path below.
            print(f"⚠️ Error loading memory file, starting fresh: {err}")

    print("🧠 Starting with new, empty memory.")
    return ConversationBufferMemory(**base_kwargs)

# ===============================
# Globals
# ===============================
# Cache of Chroma vector-store wrappers keyed by domain (collection) name;
# filled lazily by ensure_chroma_for_domain().
_chroma_clients: Dict[str, Chroma] = {}
# Embedding function shared by every collection.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
# Chat model used by the query rewriter, the RAG answer step and the agent.
# No API key is passed here — presumably the client reads GOOGLE_API_KEY from
# the environment itself; verify against the langchain-google-genai docs.
llm = ChatGoogleGenerativeAI(model=MODEL_NAME, temperature=0.2)

# Load memory from file on startup (prints whether history was restored)
memory = load_memory()

# Query rewriter chain: asks the LLM to condense a user message into a
# retrieval-oriented query. Takes a single input variable, "query".
rewrite_prompt = PromptTemplate(
    input_variables=["query"],
    template=(
        "Rewrite the user query to be concise and focused for document retrieval. "
        "Keep entities and important keywords but remove chit-chat. Return only the rewritten query.\n\n"
        "User query: {query}\n\n"
        "Rewritten query:"
    )
)
query_rewriter = LLMChain(llm=llm, prompt=rewrite_prompt, verbose=False)

# ===============================
# Helpers: Vector store (Chroma multi-collection)
# ===============================

def ensure_chroma_for_domain(domain: str) -> Chroma:
    """Return the Chroma store for ``domain``, creating and caching it on first use.

    A falsy ``domain`` is treated as "general". All collections share the
    ``CHROMA_DIR`` persist directory, which is created if missing.
    """
    collection = domain if domain else "general"

    try:
        return _chroma_clients[collection]
    except KeyError:
        pass  # not created yet in this process; build it below

    os.makedirs(CHROMA_DIR, exist_ok=True)
    store = Chroma(
        collection_name=collection,
        embedding_function=embeddings,
        persist_directory=CHROMA_DIR,
    )
    _chroma_clients[collection] = store
    print(f"✅ Initialized Chroma collection for domain: {collection}")
    return store


def file_hash(file_path: str) -> str:
    """Return the hexadecimal SHA-1 digest of the file at ``file_path``.

    The file is read in 8 KiB blocks so large files are never loaded whole.
    """
    digest = hashlib.sha1()
    with open(file_path, "rb") as stream:
        # iter() with a sentinel stops at the empty bytes object read() returns at EOF.
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()


def detect_domain_from_text(text: str) -> str:
    """Guess the domain of a piece of document text.

    Returns the first domain of ``DOMAIN_KEYWORDS`` (in dictionary order) that
    has a keyword occurring in ``text`` as a case-insensitive substring, or
    "general" when no keyword occurs.
    """
    lowered = text.lower()
    return next(
        (
            name
            for name, keywords in DOMAIN_KEYWORDS.items()
            if any(word in lowered for word in keywords)
        ),
        "general",
    )


def detect_domain_from_query(query: str) -> str:
    """Guess which domain collection a user query should be routed to.

    Domains are tried in ``DOMAIN_KEYWORDS`` order; the first one with a
    keyword contained in the lower-cased query wins. Falls back to "general".
    """
    haystack = query.lower()
    for name in DOMAIN_KEYWORDS:
        if any(term in haystack for term in DOMAIN_KEYWORDS[name]):
            return name
    return "general"


# ===============================
# PDF processing with metadata and incremental indexing
# ===============================

def _load_pdf_cache() -> Dict[str, str]:
    """Read the ``path -> SHA-1`` map from CACHE_FILE; a missing or unreadable file yields ``{}``."""
    try:
        with open(CACHE_FILE, "r") as f:
            cache = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, json.JSONDecodeError) as e:
        print(f"⚠️ Could not read PDF cache '{CACHE_FILE}', rebuilding it: {e}")
        return {}
    return cache if isinstance(cache, dict) else {}


def _save_pdf_cache(pdf_cache: Dict[str, str]) -> None:
    """Write the ``path -> SHA-1`` map back to CACHE_FILE."""
    with open(CACHE_FILE, "w") as f:
        json.dump(pdf_cache, f, indent=2)


def _chunk_signature(metadata: dict) -> tuple:
    """Identify a chunk by its ``(source, page, chunk)`` metadata triple."""
    return (metadata.get("source"), metadata.get("page"), metadata.get("chunk"))


def _pdf_to_documents(file_path: str, domain: str, chunk_size: int, chunk_overlap: int) -> "List[Document]":
    """Split every non-blank page of the PDF into chunk Documents carrying source/page/chunk/domain metadata."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    source_name = os.path.basename(file_path)
    documents = []
    # The context manager closes the PDF even if text extraction raises.
    with fitz.open(file_path) as pdf:
        for page_idx, page in enumerate(pdf):
            page_text = page.get_text()
            if not page_text.strip():
                continue
            for i, chunk in enumerate(splitter.split_text(page_text)):
                metadata = {
                    "source": source_name,
                    "page": page_idx + 1,  # 1-based page number
                    "chunk": i,
                    "domain": domain,
                }
                documents.append(Document(page_content=chunk, metadata=metadata))
    return documents


def process_pdf(file_path: str, domain: Optional[str] = None, chunk_size: int = DEFAULT_CHUNK_SIZE, chunk_overlap: int = DEFAULT_CHUNK_OVERLAP) -> int:
    """Index a PDF page by page into its domain collection, incrementally.

    The file's SHA-1 is compared with the value cached under ``file_path`` in
    CACHE_FILE; an unchanged file is skipped. Otherwise the PDF is split into
    chunks and reconciled with what the collection already stores for the same
    source name:

    * chunks whose ``(source, page, chunk)`` exists with identical text are kept;
    * chunks whose text changed, or whose position no longer exists, are
      deleted from the collection;
    * new or changed chunks are added.

    Comparing the stored text (not only the position) is what lets a modified
    PDF be re-indexed; position-only de-duplication would discard every
    updated chunk as "already indexed".

    Args:
        file_path: Path of the PDF to index.
        domain: Target collection. When ``None`` it is detected from the text
            of the first (up to) three pages.
        chunk_size: Maximum chunk size handed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        Number of chunks added (0 if nothing new had to be stored).

    Note:
        Chunks are identified by the file's base name, so two different PDFs
        with the same file name in one domain are treated as the same source.
    """
    if domain is None:
        # Only the auto-detection needs the sample text, so skip the read otherwise.
        with fitz.open(file_path) as doc:
            sample_text = "\n".join(page.get_text() for page in doc[:min(3, len(doc))])
        domain = detect_domain_from_text(sample_text)

    client = ensure_chroma_for_domain(domain)
    source_name = os.path.basename(file_path)

    pdf_cache = _load_pdf_cache()
    h = file_hash(file_path)
    if pdf_cache.get(file_path) == h:
        print(f"ℹ️ No changes detected in '{os.path.basename(file_path)}' (domain={domain}), skipping processing.")
        return 0

    docs_to_add = _pdf_to_documents(file_path, domain, chunk_size, chunk_overlap)

    try:
        existing = client.get(include=['metadatas', 'documents'])
    except Exception as e:
        # Best effort, as before: an unreadable collection is treated as empty.
        print(f"⚠️ Could not read existing chunks for domain '{domain}': {e}")
        existing = {}

    # signature -> (stored id, stored text) for chunks of this source only
    stored = {}
    for chunk_id, md, text in zip(
        existing.get('ids') or [],
        existing.get('metadatas') or [],
        existing.get('documents') or [],
    ):
        if isinstance(md, dict) and md.get('source') == source_name:
            stored[_chunk_signature(md)] = (chunk_id, text)

    fresh_by_sig = {_chunk_signature(d.metadata): d for d in docs_to_add}

    stale_ids = [
        chunk_id
        for sig, (chunk_id, text) in stored.items()
        if sig not in fresh_by_sig or fresh_by_sig[sig].page_content != text
    ]
    filtered_docs = [
        d
        for sig, d in fresh_by_sig.items()
        if sig not in stored or stored[sig][1] != d.page_content
    ]

    if stale_ids:
        client.delete(ids=stale_ids)
        print(f"♻️ Removed {len(stale_ids)} outdated chunks of '{source_name}' from domain '{domain}'.")
    if filtered_docs:
        client.add_documents(filtered_docs)
    if stale_ids or filtered_docs:
        client.persist()

    # Record the hash only after the index has been updated successfully.
    pdf_cache[file_path] = h
    _save_pdf_cache(pdf_cache)

    if not filtered_docs:
        print(f"ℹ️ All chunks from '{os.path.basename(file_path)}' already indexed in domain '{domain}'.")
        return 0

    print(f"✅ Processed '{os.path.basename(file_path)}' into domain '{domain}' and added {len(filtered_docs)} chunks to vector store.")
    return len(filtered_docs)


def upload_and_process_pdfs():
    """Let the user pick PDF files in a dialog and index each of them.

    A hidden Tk root window hosts the file dialog and is destroyed as soon as
    the dialog closes, so repeated calls do not accumulate root windows.
    A PDF that fails to process is reported and skipped; the remaining files
    are still indexed. Prints the total number of chunks added.
    """
    root = Tk()
    root.withdraw()  # hide the empty root window; only the dialog is shown
    try:
        pdf_files = askopenfilenames(title="Select PDF files", filetypes=[("PDF Files", "*.pdf")])
    finally:
        root.destroy()

    if not pdf_files:
        print("❌ No PDFs selected.")
        return

    total = 0
    for p in pdf_files:
        try:
            total += process_pdf(p)
        except Exception as e:
            # One unreadable or corrupt file must not abort the whole batch.
            print(f"⚠️ Failed to process '{os.path.basename(p)}': {e}")
    print(f"🔁 Done. Total chunks added: {total}")


def list_stored_pdfs():
    """Print the source file names indexed in each domain collection.

    The domains inspected are the keys of ``DOMAIN_KEYWORDS``, "general", and
    any domain already opened in this process (``_chroma_clients``), each
    visited once. Only metadata is fetched from the store. A domain that
    cannot be read is reported and skipped.
    """
    # dict.fromkeys de-duplicates while keeping order ("general" is already a
    # DOMAIN_KEYWORDS key). The list is built up front because
    # ensure_chroma_for_domain() adds entries to _chroma_clients.
    domains = list(dict.fromkeys([*DOMAIN_KEYWORDS, "general", *_chroma_clients]))

    found = {}
    for domain in domains:
        try:
            client = ensure_chroma_for_domain(domain)
            info = client.get(include=['metadatas'])
        except Exception as e:
            print(f"⚠️ Could not read domain '{domain}': {e}")
            continue
        mds = info.get('metadatas') or []
        sources = {md['source'] for md in mds if isinstance(md, dict) and md.get('source')}
        if sources:
            found[domain] = sources

    if not found:
        print("(No PDFs indexed yet)")
        return

    print("📄 Indexed PDFs by domain:")
    for dom, sources in found.items():
        print(f" - {dom}:")
        # Sorted so the listing is stable between runs.
        for s in sorted(sources, key=str):
            print("    -", s)


# ===============================
# Retrieval (Hybrid) + RAG answer with domain awareness and query rewriter
# ===============================

def hybrid_retrieve(query: str, domain: Optional[str] = None, k: int = 5, keyword_search: bool = True, show_sources: bool = True) -> List[Document]:
    """Retrieve up to ``k`` chunks for ``query`` from one domain collection.

    The query is first condensed by the query-rewriter chain (the original
    query is used if rewriting fails or returns nothing). Vector similarity
    hits come first. Keyword matches (stored chunks that contain the rewritten
    query as a case-insensitive substring) only fill the slots that vector
    search left empty. The full-collection keyword scan is therefore skipped
    when vector search already returned ``k`` hits, because its matches would
    be cut off by the final ``[:k]`` anyway.

    Args:
        query: The user's question.
        domain: Collection to search; detected from ``query`` when ``None``.
        k: Maximum number of documents to return.
        keyword_search: Whether to top up with keyword matches.
        show_sources: Whether to print a summary of the returned chunks.

    Returns:
        At most ``k`` Documents. Keyword matches carry the metadata
        ``{"source": "keyword_match", "domain": <domain>}``.
    """
    try:
        rewritten = query_rewriter.run(query=query).strip()
    except Exception:
        rewritten = ""
    use_query = rewritten or query

    if domain is None:
        domain = detect_domain_from_query(query)

    client = ensure_chroma_for_domain(domain)

    try:
        vector_matches = client.similarity_search(use_query, k=k)
    except Exception:
        try:
            retr = client.as_retriever()
            vector_matches = retr.get_relevant_documents(use_query)[:k]
        except Exception as e:
            print("⚠️ Vector search failed:", e)
            vector_matches = []

    results: List[Document] = list(vector_matches)

    if keyword_search and len(results) < k:
        added_texts = {d.page_content for d in results}
        needle = use_query.lower()  # hoisted: constant for the whole scan
        try:
            all_docs = client.get(include=['documents']).get('documents') or []
        except Exception as e:
            print("⚠️ Keyword search failed:", e)
            all_docs = []
        for doc_text in all_docs:
            if len(results) >= k:
                break  # anything further would be truncated below
            if not isinstance(doc_text, str):
                continue
            if needle in doc_text.lower() and doc_text not in added_texts:
                results.append(Document(page_content=doc_text, metadata={"source": "keyword_match", "domain": domain}))
                added_texts.add(doc_text)

    if show_sources:
        print(f"\n🔍 Retrieved chunks (domain={domain}):")
        for i, d in enumerate(results[:k], 1):
            meta = getattr(d, "metadata", {}) or {}
            preview = getattr(d, "page_content", str(d))
            print(f" {i}. Source: {meta.get('source','unknown')} | page={meta.get('page','?')} | len={len(preview)} chars")

    return results[:k]


def _response_text(response) -> str:
    """Extract plain text from an LLM response whose ``content`` is a string or a list of parts."""
    if not hasattr(response, "content"):
        return str(response)
    content = response.content
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
            elif isinstance(getattr(part, "text", None), str):
                parts.append(part.text)
        if parts:
            return "".join(parts)
    # Unknown shape: fall back to its string form rather than failing.
    return str(content)


def get_rag_answer(query: str, domain: Optional[str] = None, k: int = 5) -> str:
    """Answer ``query`` from retrieved document chunks only.

    Retrieves up to ``k`` chunks with ``hybrid_retrieve``, joins them into a
    CONTEXT section, and asks the LLM to answer strictly from that context.

    Args:
        query: The user's question.
        domain: Collection to search; detected from the query when ``None``.
        k: Maximum number of chunks placed in the context.

    Returns:
        The answer text. When the model returns its content as a list of parts
        (plain strings, dicts with a "text" key, or objects with a ``.text``
        attribute), the text parts are concatenated instead of returning the
        list's repr.
    """
    docs = hybrid_retrieve(query, domain=domain, k=k, show_sources=True)
    context = "\n\n---\n\n".join(d.page_content for d in docs)

    prompt = f"""
You are a helpful assistant. Answer the user's question using ONLY the provided CONTEXT. If the answer is not contained in the context, say "I don't know based on the provided documents." Do NOT hallucinate.

CONTEXT:
{context}

QUESTION:
{query}
"""
    return _response_text(llm.invoke(prompt))


# ===============================
# Tools (example: calculator and RAG tool)
# ===============================

# Characters an expression may contain: digits, '.', + - * /, parentheses,
# spaces and newlines.
_CALC_ALLOWED_CHARS = re.compile(r"^[0-9\.\+\-\*\/\(\) \n]+$")

# AST operator node -> implementation. Anything not listed is rejected.
_CALC_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Pow: operator.pow,
}
_CALC_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}

# Upper bound on the estimated bit size of an integer power result.
_CALC_MAX_POWER_BITS = 100_000


def _eval_arithmetic(node):
    """Evaluate a parsed arithmetic expression made of numbers, + - * / // ** and unary signs."""
    if isinstance(node, ast.Expression):
        return _eval_arithmetic(node.body)
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.UnaryOp) and type(node.op) in _CALC_UNARY_OPS:
        return _CALC_UNARY_OPS[type(node.op)](_eval_arithmetic(node.operand))
    if isinstance(node, ast.BinOp) and type(node.op) in _CALC_BINARY_OPS:
        left = _eval_arithmetic(node.left)
        right = _eval_arithmetic(node.right)
        if (
            isinstance(node.op, ast.Pow)
            and isinstance(left, int)
            and isinstance(right, int)
            and right > 0
            and left.bit_length() * right > _CALC_MAX_POWER_BITS
        ):
            # e.g. 9**9**9**9 would otherwise compute for an unbounded time.
            raise ValueError("Result too large")
        return _CALC_BINARY_OPS[type(node.op)](left, right)
    raise ValueError("Unsupported expression")


def calculator_tool(query: str) -> str:
    """Evaluate a plain arithmetic expression and return the result as text.

    Supported: integer and decimal literals, ``+ - * / // **``, unary signs and
    parentheses. The expression is parsed with ``ast`` and evaluated by a small
    interpreter instead of ``eval``, so nothing but arithmetic can run, and
    integer powers with an oversized result are refused rather than computed.

    Args:
        query: The expression, e.g. ``"25*6"``.

    Returns:
        ``str(result)`` on success,
        ``"Error: Unsafe characters in expression."`` when ``query`` contains
        characters outside the allowed set, or ``"Error: <message>"`` for any
        other failure (syntax error, division by zero, oversized power, ...).
    """
    try:
        expression = query.strip()
        if not _CALC_ALLOWED_CHARS.match(expression):
            return "Error: Unsafe characters in expression."
        return str(_eval_arithmetic(ast.parse(expression, mode="eval")))
    except Exception as e:
        return f"Error: {e}"


def rag_tool(query: str) -> str:
    """Agent tool entry point: answer ``query`` from the indexed documents.

    The target domain is detected from the query text and passed explicitly to
    ``get_rag_answer``.
    """
    return get_rag_answer(query, domain=detect_domain_from_query(query))


# Tool wrappers handed to the agent. The descriptions are part of what the LLM
# sees when deciding which tool to call.
calc_tool = Tool(name="Calculator", func=calculator_tool, description="Performs math calculations")
rag_tool_wrapper = Tool(name="RAGRetriever", func=rag_tool, description="Answers questions using the indexed documents (RAG). Use this for any query about PDFs, documents, chapters, summaries, or specific content from a file.")

# ===============================
# Initialize Agent (with tools, memory, verbose=True)
# ===============================
# The agent shares the module-level LLM instance.
agent_llm = llm

# Custom prompt prefix. The {chat_history} placeholder matches the
# memory_key="chat_history" configured in load_memory().
agent_kwargs = {
    "prefix": """You are a friendly conversational assistant.
You have access to the following tools:
1. Calculator: Performs math calculations.
2. RAGRetriever: Answers questions using the indexed documents (RAG). Use this for any query about PDFs, documents, chapters, summaries, or specific content from a file.

You also have access to the conversation history.
If the user is just chatting, respond conversationally.
If the user asks a question, decide if it requires a tool or if you can answer from history.

Here is the conversation history:
{chat_history}
"""
}

# NOTE(review): LangChain documents the keyword that selects the agent type as
# `agent`, not `agent_type`. Presumably `agent_type=` is forwarded as an extra
# executor kwarg and the default agent type is used instead of STRUCTURED_CHAT
# — confirm against the installed LangChain version before changing it.
agent = initialize_agent(
    tools=[calc_tool, rag_tool_wrapper],
    llm=agent_llm,
    agent_type=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    max_iterations=4,
    memory=memory,  # the memory object returned by load_memory() at startup
    agent_kwargs=agent_kwargs,
    handle_parsing_errors=True
)

# ===============================
# CLI / Main loop
# ===============================
# Interactive menu loop. Runs only when the file is executed as a script.
# Options 1 and 2 call their helpers directly (no error handling here);
# options 3 and 4 run the agent and save the memory after every question.
if __name__ == "__main__":
    print("=== RAG + Agent Conversational App (multi-domain Chroma) ===")
    while True:
        print("\nOptions:\n1) Upload PDFs to index\n2) List indexed PDFs\n3) Ask a question (Agent handles RAG/Tools/Chat)\n4) Run example queries\n5) Clear Memory\n6) Exit")
        choice = input("Choose: ").strip()
        if choice == "1":
            upload_and_process_pdfs()
        elif choice == "2":
            list_stored_pdfs()
        elif choice == "3":
            q = input("Enter question: ")
            try:
                answer = agent.run(q)
                print("\n== Answer ==\n", answer)
            except Exception as e:
                print(f"Error during agent execution: {e}")
            finally:
                # Save in 'finally' so the history is written even when the agent raised.
                save_memory(memory)
                print("💾 Memory saved.")
        elif choice == "4":
            # Canned conversation: a chat turn, a calculator question, and a
            # question that can only be answered from the conversation history.
            examples = [
                "Hi, my name is Alex.",
                "What is 25*6?",
                "What is my name?" # exercises the conversation memory
            ]
            for ex in examples:
                print("\n>>", ex)
                try:
                    print(agent.run(ex))
                except Exception as e:
                    print(f"Error during agent execution: {e}")
                finally:
                    # Save after every example, including failed ones.
                    save_memory(memory)
            print("💾 Memory saved for examples.")
        elif choice == "5":
            # Remove the persisted history first so load_memory() below starts empty.
            if os.path.exists(MEMORY_FILE):
                os.remove(MEMORY_FILE)
                print("🗑️ Cleared persistent memory file.")
            memory.clear() # clear the in-RAM history of the current object
            # Rebind the module-level memory to a fresh object
            memory = load_memory()
            # Rebuild the agent so it holds the new memory object; the arguments
            # mirror the module-level initialize_agent() call above.
            agent = initialize_agent(
                tools=[calc_tool, rag_tool_wrapper],
                llm=agent_llm,
                agent_type=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
                verbose=True,
                max_iterations=4,
                memory=memory,
                agent_kwargs=agent_kwargs,
                handle_parsing_errors=True
            )
            print("🧠 In-memory history cleared and agent re-initialized.")
        elif choice == "6":
            print("Goodbye.")
            break
        else:
            print("Invalid choice. Try again.")