<a href="https://colab.research.google.com/github/lookmohan/Simple-RAG-Assistant/blob/main/RAG_Implementation_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Required Libraries

In [None]:
# Install the third-party packages this notebook relies on; -q keeps pip's output quiet.
!pip install -q chromadb sentence-transformers langchain langchain-groq langchain-community langchain-google-genai pypdf docx2txt

# Import Libraries and Setup

In [None]:
import os
import chromadb
import json
import pickle
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_text_splitters import RecursiveCharacterTextSplitter
from google.colab import files
import glob
import sys
from datetime import datetime

In [None]:
from google.colab import userdata
# Groq credential, fetched from Colab's userdata store under the name 'grok_api_key'.
GROQ_API_KEY = userdata.get('grok_api_key')

# Where persistent state lives: JSON files under STORAGE_DIR, vectors under CHROMA_DB_PATH.
STORAGE_DIR = "./rag_storage"
CHROMA_DB_PATH = "./chroma_db"
CONVERSATION_HISTORY_FILE = STORAGE_DIR + "/conversation_history.json"
METADATA_FILE = STORAGE_DIR + "/metadata.json"

# Make sure both directories exist before anything tries to write into them.
os.makedirs(CHROMA_DB_PATH, exist_ok=True)
os.makedirs(STORAGE_DIR, exist_ok=True)

print(
    "‚úÖ Configuration loaded!",
    f"üìÅ Storage directory: {STORAGE_DIR}",
    f"üìÅ Vector DB directory: {CHROMA_DB_PATH}",
    sep="\n",
)

In [None]:
def safe_print(text):
    """Print ``text``, degrading gracefully when stdout cannot encode it.

    Args:
        text: The value to print. Any object accepted by ``print`` works.

    If the first ``print`` raises ``UnicodeEncodeError``, the text is
    re-encoded with stdout's encoding (``'ascii'`` when stdout reports none)
    using ``errors='replace'`` and printed again. Any other exception from the
    first ``print`` is reported as a short diagnostic line instead of being
    raised.
    """
    try:
        print(text)
    except UnicodeEncodeError:
        encoding = sys.stdout.encoding or 'ascii'
        # print() accepts arbitrary objects but only str has .encode(), so
        # convert first; otherwise a non-str argument would raise from here.
        encoded_text = str(text).encode(encoding, errors='replace')
        print(encoded_text.decode(encoding, errors='replace'))
    except Exception as e:
        print(f"[Error in safe_print]: Could not print message due to: {e}")


def save_json(data: Any, filepath: str) -> bool:
    """Serialize ``data`` as indented UTF-8 JSON and write it to ``filepath``.

    Args:
        data: Any JSON-serializable value.
        filepath: Destination path; an existing file is overwritten.

    Returns:
        True when the file was written, False when serialization or writing
        failed (the error is reported through ``safe_print``).
    """
    try:
        # Serialize BEFORE opening the file: opening with 'w' truncates the
        # target, so a serialization error raised afterwards would wipe the
        # previously saved contents.
        payload = json.dumps(data, indent=2, ensure_ascii=False)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)
        return True
    except Exception as e:
        safe_print(f"‚ùå Error saving to {filepath}: {e}")
        return False


def load_json(filepath: str) -> Any:
    """Read and parse the JSON file at ``filepath``.

    Returns the parsed value, or None when the file does not exist or when
    reading/parsing fails (failures are reported through ``safe_print``).
    """
    try:
        if not os.path.exists(filepath):
            return None
        with open(filepath, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
        return parsed
    except Exception as err:
        safe_print(f"‚ùå Error loading from {filepath}: {err}")
        return None

# Vector Database Class with Persistence

In [None]:
class VectorDB:
    """Vector database with persistent storage using ChromaDB.

    Documents are split into chunks, embedded with the configured
    SentenceTransformer model and stored in a ChromaDB collection kept under
    ``persist_directory``. Queries are embedded with the same model, so stored
    vectors and query vectors always come from one embedding space.
    """

    # Upper bound on the number of chunks embedded and sent to ChromaDB per
    # ``add`` call, so a large upload is processed in bounded pieces.
    _ADD_BATCH_SIZE = 1000

    def __init__(self, collection_name: str = "rag_documents",
                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
                 persist_directory: str = CHROMA_DB_PATH):
        """Open (or create) the persistent collection and load the embedder.

        Args:
            collection_name: Name of the ChromaDB collection to use.
            embedding_model: SentenceTransformer model name used for both
                stored chunks and queries.
            persist_directory: Directory handed to ``chromadb.PersistentClient``.
        """
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model
        self.embedding_model = SentenceTransformer(embedding_model)
        self.persist_directory = persist_directory

        # Initialize ChromaDB client with persistence
        self.client = chromadb.PersistentClient(path=self.persist_directory)

        # Get or create collection
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"description": "RAG document collection"}
        )

        # Report whether a previous session already left data behind
        existing_count = self.collection.count()
        safe_print(f"‚úÖ Vector database initialized: {self.collection_name}")
        if existing_count > 0:
            safe_print(f"üìö Found {existing_count} existing document chunks in storage")

    def chunk_text(self, text: str, chunk_size=2000, chunk_overlap=200) -> List[str]:
        """Split ``text`` into overlapping chunks measured in characters.

        Splitting prefers paragraph breaks, then line breaks, then spaces,
        and finally arbitrary positions.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=['\n\n', '\n', ' ', '']
        )
        return text_splitter.split_text(text)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """Chunk, embed and store ``documents`` in the collection.

        Args:
            documents: Dicts with a ``'content'`` string and an optional
                ``'metadata'`` dict. Each chunk's metadata is a copy of the
                document metadata plus ``chunk_index``, ``doc_index`` and
                ``added_date``.
        """
        safe_print(f"üìÑ Processing {len(documents)} documents...")

        all_chunks = []
        all_metadatas = []
        all_ids = []

        # One timestamp per call: doc/chunk indices make ids unique within the
        # call, the timestamp keeps them unique across calls.
        now = datetime.now()
        added_date = now.strftime("%Y-%m-%d %H:%M:%S")
        id_suffix = now.strftime("%Y%m%d%H%M%S%f")

        for doc_idx, doc in enumerate(documents):
            content = doc.get('content') or ''
            metadata = doc.get('metadata') or {}

            for chunk_idx, chunk in enumerate(self.chunk_text(content)):
                chunk_metadata = dict(metadata)
                chunk_metadata['chunk_index'] = chunk_idx
                chunk_metadata['doc_index'] = doc_idx
                chunk_metadata['added_date'] = added_date

                all_chunks.append(chunk)
                all_metadatas.append(chunk_metadata)
                all_ids.append(f'doc_{doc_idx}_chunk_{chunk_idx}_{id_suffix}')

        if not all_chunks:
            safe_print("‚ö†Ô∏è No chunks to add")
            return

        for start in range(0, len(all_chunks), self._ADD_BATCH_SIZE):
            stop = start + self._ADD_BATCH_SIZE
            batch = all_chunks[start:stop]
            self.collection.add(
                documents=batch,
                # Embed with the same model search() uses; without explicit
                # embeddings the collection would embed the text itself and
                # the configured model would be ignored on insert.
                embeddings=self.embedding_model.encode(batch).tolist(),
                metadatas=all_metadatas[start:stop],
                ids=all_ids[start:stop]
            )
        safe_print(f"‚úÖ Added {len(all_chunks)} chunks to persistent vector database")

    def search(self, query: str, n_results: int = 5) -> Dict[str, Any]:
        """Return the stored chunks closest to ``query``.

        Args:
            query: Free-text search string.
            n_results: Maximum number of chunks to return; capped at the
                number of chunks currently stored.

        Returns:
            The ChromaDB query result dict. For an empty collection, a dict
            of the same layout with empty result lists is returned without
            querying.
        """
        total = self.collection.count()
        if total == 0:
            return {'ids': [[]], 'documents': [[]], 'metadatas': [[]], 'distances': [[]]}

        query_embedding = self.embedding_model.encode([query]).tolist()
        return self.collection.query(
            query_embeddings=query_embedding,
            n_results=min(n_results, total)
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return the chunk count, collection name and persistence directory."""
        return {
            "total_chunks": self.collection.count(),
            "collection_name": self.collection_name,
            "persist_directory": self.persist_directory
        }

    def clear_database(self):
        """Delete the collection and recreate it empty.

        Returns:
            True on success, False when ChromaDB raised (the error is
            reported through ``safe_print``).
        """
        try:
            self.client.delete_collection(name=self.collection_name)
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "RAG document collection"}
            )
            safe_print("‚úÖ Vector database cleared!")
            return True
        except Exception as e:
            safe_print(f"‚ùå Error clearing database: {e}")
            return False


# Cell-completion notice: printed once the VectorDB class definition has run.
print("‚úÖ VectorDB class defined with persistence!")


# Document Loading Functions

In [None]:
def _list_document_paths(folder_path: str) -> Dict[str, List[str]]:
    """Group the supported files directly inside ``folder_path`` by lowercase extension."""
    grouped = {'.txt': [], '.pdf': [], '.docx': []}
    try:
        # Sorted so document order (and therefore doc_index) is reproducible.
        names = sorted(os.listdir(folder_path))
    except OSError as e:
        safe_print(f"‚ùå Error loading {folder_path}: {e}")
        return grouped

    for name in names:
        if name.startswith('.'):
            # Hidden files were never picked up by the former '*' glob either.
            continue
        extension = os.path.splitext(name)[1].lower()
        if extension in grouped:
            grouped[extension].append(os.path.join(folder_path, name))
    return grouped


def _build_document(content: Any, file_path: str, doc_type: str, **extra: Any) -> Dict[str, Any]:
    """Wrap ``content`` and its source information in the document dict layout."""
    metadata = {
        'source': os.path.basename(file_path),
        'type': doc_type,
        'full_path': file_path
    }
    metadata.update(extra)
    return {'content': content, 'metadata': metadata}


def load_documents_from_folder(folder_path: str = './data') -> List[Dict[str, Any]]:
    """Load documents from a folder (supports .txt, .pdf, .docx).

    Extensions are matched case-insensitively and files are visited in sorted
    name order, grouped as txt, then pdf, then docx. A file that fails to
    load is reported through ``safe_print`` and skipped.

    Args:
        folder_path: Directory to scan (not recursive). If it does not exist
            it is created and an empty list is returned.

    Returns:
        One dict per loaded file with ``'content'`` (the extracted text) and
        ``'metadata'`` (``source``, ``type``, ``full_path`` and, for PDFs,
        ``pages``).
    """
    documents = []

    if not os.path.exists(folder_path):
        safe_print(f"üìÅ Creating folder: {folder_path}")
        os.makedirs(folder_path, exist_ok=True)
        return documents

    paths = _list_document_paths(folder_path)

    # Load .txt files
    for file_path in paths['.txt']:
        try:
            # errors='ignore' drops undecodable bytes instead of failing the file
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()
            documents.append(_build_document(content, file_path, 'txt'))
            safe_print(f"‚úÖ Loaded: {os.path.basename(file_path)}")
        except Exception as e:
            safe_print(f"‚ùå Error loading {file_path}: {e}")

    # Load .pdf files
    try:
        from langchain_community.document_loaders import PyPDFLoader
        for file_path in paths['.pdf']:
            try:
                pages = PyPDFLoader(file_path).load()
                content = "\n\n".join([page.page_content for page in pages])
                documents.append(_build_document(content, file_path, 'pdf', pages=len(pages)))
                safe_print(f"‚úÖ Loaded: {os.path.basename(file_path)} ({len(pages)} pages)")
            except Exception as e:
                safe_print(f"‚ùå Error loading {file_path}: {e}")
    except ImportError:
        safe_print("‚ö†Ô∏è PyPDFLoader not available, skipping PDF files")

    # Load .docx files
    try:
        import docx2txt
        for file_path in paths['.docx']:
            try:
                content = docx2txt.process(file_path)
                documents.append(_build_document(content, file_path, 'docx'))
                safe_print(f"‚úÖ Loaded: {os.path.basename(file_path)}")
            except Exception as e:
                safe_print(f"‚ùå Error loading {file_path}: {e}")
    except ImportError:
        safe_print("‚ö†Ô∏è docx2txt not available, skipping DOCX files")

    return documents


# Cell-completion notice: printed once the document loading function has been defined.
print("‚úÖ Document loading functions defined!")


# Conversation Memory Class with Persistence

In [None]:
class ConversationMemory:
    """Manages conversation history with persistent storage.

    The history is a list of dicts with ``timestamp``, ``question`` and
    ``answer`` keys. It is rewritten to ``storage_file`` after every change
    and reloaded from there on construction.
    """

    def __init__(self, max_history: int = 50, storage_file: str = CONVERSATION_HISTORY_FILE):
        """Create the memory and load any history already saved on disk.

        Args:
            max_history: Maximum number of exchanges to retain; values of 0 or
                less retain none.
            storage_file: JSON file used to persist the history.
        """
        self.max_history = max_history
        self.storage_file = storage_file
        self.history = []
        self.session_start = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Load existing history if available
        self.load_history()

    def _trim_history(self):
        """Drop the oldest exchanges so at most ``max_history`` remain."""
        if self.max_history <= 0:
            # history[-0:] is the whole list, so an empty window needs its own branch.
            self.history = []
        elif len(self.history) > self.max_history:
            self.history = self.history[-self.max_history:]

    def add_exchange(self, question: str, answer: str):
        """Append a Q&A exchange stamped with the current time, then persist."""
        self.history.append({
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'question': question,
            'answer': answer
        })
        self._trim_history()

        # Save to persistent storage
        self.save_history()

    def get_formatted_history(self, num_exchanges: int = 5) -> str:
        """Render the most recent exchanges as a text block.

        Args:
            num_exchanges: How many of the latest exchanges to include.

        Returns:
            The formatted block (answers shortened to 200 characters), or
            ``"No previous conversation."`` when there is nothing to show.
        """
        # num_exchanges <= 0 must not reach the slice: history[-0:] would
        # return the entire history instead of none of it.
        if not self.history or num_exchanges <= 0:
            return "No previous conversation."

        parts = ["CONVERSATION HISTORY:\n" + "=" * 50 + "\n"]
        for i, exchange in enumerate(self.history[-num_exchanges:], 1):
            answer = exchange['answer']
            ellipsis = '...' if len(answer) > 200 else ''
            parts.append(f"\nExchange {i} ({exchange['timestamp']}):\n")
            parts.append(f"User: {exchange['question']}\n")
            parts.append(f"Assistant: {answer[:200]}{ellipsis}\n")
            parts.append("-" * 50 + "\n")

        return "".join(parts)

    def get_all_questions(self) -> List[str]:
        """Return every stored question, oldest first."""
        return [exchange['question'] for exchange in self.history]

    def save_history(self):
        """Write the history plus session bookkeeping to ``storage_file``."""
        data = {
            'session_start': self.session_start,
            'last_updated': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'total_exchanges': len(self.history),
            'history': self.history
        }
        save_json(data, self.storage_file)

    def load_history(self):
        """Load the saved history from ``storage_file`` if there is one.

        Entries that are not dicts carrying ``timestamp``, ``question`` and
        ``answer`` are discarded, because the other methods index those keys
        directly. The loaded list is trimmed to ``max_history``.
        """
        data = load_json(self.storage_file)
        entries = data.get('history') if isinstance(data, dict) else None
        if not isinstance(entries, list):
            return

        required_keys = ('timestamp', 'question', 'answer')
        self.history = [
            entry for entry in entries
            if isinstance(entry, dict) and all(key in entry for key in required_keys)
        ]
        self._trim_history()

        safe_print(f"üìú Loaded {len(self.history)} previous conversation exchanges")
        if 'session_start' in data:
            safe_print(f"üìÖ Previous session started: {data['session_start']}")

    def clear(self):
        """Forget all exchanges, restart the session clock and persist."""
        self.history = []
        self.session_start = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.save_history()

    def export_history(self, filepath: str = None):
        """Write the full conversation to a plain-text file.

        Args:
            filepath: Destination file. When omitted, a timestamped file name
                under ``STORAGE_DIR`` is used.

        Returns:
            The path written, or None when the export failed.
        """
        if not filepath:
            filepath = f"{STORAGE_DIR}/conversation_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write("RAG ASSISTANT - CONVERSATION HISTORY\n")
                f.write("=" * 60 + "\n\n")
                f.write(f"Session Start: {self.session_start}\n")
                f.write(f"Total Exchanges: {len(self.history)}\n")
                f.write("=" * 60 + "\n\n")

                for i, exchange in enumerate(self.history, 1):
                    f.write(f"Exchange {i} - {exchange['timestamp']}\n")
                    f.write(f"User: {exchange['question']}\n\n")
                    f.write(f"Assistant: {exchange['answer']}\n\n")
                    f.write("-" * 60 + "\n\n")

            safe_print(f"‚úÖ Conversation history exported to: {filepath}")
            return filepath
        except Exception as e:
            safe_print(f"‚ùå Error exporting history: {e}")
            return None


# Cell-completion notice: printed once the ConversationMemory class definition has run.
print("‚úÖ ConversationMemory class defined with persistence!")


# RAG Assistant Class with Full Persistence

In [None]:
class RAGAssistant:
    """Advanced RAG-based AI assistant with full persistence.

    Ties together the Groq chat model, the persistent ``VectorDB``, the
    persistent ``ConversationMemory`` and a small metadata record (document
    and query counters) stored in ``METADATA_FILE``.
    """

    def __init__(self, api_key: str):
        """Initialize the RAG assistant with Groq and persistence.

        Args:
            api_key: Groq API key passed to ``ChatGroq``.
        """
        self.llm = ChatGroq(
            groq_api_key=api_key,
            model="llama-3.3-70b-versatile",
            temperature=0.0
        )

        self.vector_db = VectorDB()
        self.memory = ConversationMemory(max_history=50)
        self.metadata = self._load_metadata()

        template = """You are an advanced AI assistant that provides accurate, helpful answers based on provided documents and conversation history.

CORE PRINCIPLES:
- Answer based on the provided documents AND conversation history
- Remember previous questions and answers in this conversation
- Be clear, concise, and accurate
- Cite sources when relevant
- If information isn't in the documents, say so honestly
- If asked about previous conversation, refer to the conversation history

RESPONSE GUIDELINES:
- For definitions/concepts: Use clear explanations with examples
- For comparisons: Use structured format
- For procedures: Provide step-by-step instructions
- For questions about previous conversation: Use the conversation history
- For general questions: Keep answers concise and well-organized

{conversation_history}

CONTEXT FROM DOCUMENTS:
{context}

CURRENT USER QUESTION:
{question}

ANSWER:"""

        self.prompt_template = ChatPromptTemplate.from_template(template)
        self.chain = self.prompt_template | self.llm | StrOutputParser()

        safe_print("‚úÖ RAG Assistant initialized with Groq (Llama 3.3 70B)")
        safe_print("‚úÖ Conversation memory enabled with persistence")
        safe_print("‚úÖ Vector database persistent storage active")

    @staticmethod
    def _default_metadata() -> Dict[str, Any]:
        """Return a fresh metadata record with zeroed counters."""
        return {
            'created_date': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'total_documents_processed': 0,
            'total_queries': 0,
            'document_list': []
        }

    @staticmethod
    def _format_context(chunks, metadatas):
        """Render retrieved chunks for the prompt and list their sources.

        Returns:
            ``(context, sources)`` where ``context`` is the concatenated
            ``[Source: ...]`` sections and ``sources`` holds each distinct
            source once, in retrieval order.
        """
        sections = []
        sources = []
        for chunk, meta in zip(chunks, metadatas):
            source = (meta or {}).get('source', 'Unknown')
            sources.append(source)
            sections.append(f"\n[Source: {source}]\n{chunk}\n{'-' * 40}\n")
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # printed source list is stable between runs (a set is not).
        return "".join(sections), list(dict.fromkeys(sources))

    def _load_metadata(self) -> Dict[str, Any]:
        """Load system metadata from ``METADATA_FILE``, creating it if absent.

        A record loaded from disk is topped up with any missing default keys,
        since other methods index those keys directly.
        """
        metadata = load_json(METADATA_FILE)
        if not isinstance(metadata, dict) or not metadata:
            metadata = self._default_metadata()
            save_json(metadata, METADATA_FILE)
            return metadata

        for key, value in self._default_metadata().items():
            metadata.setdefault(key, value)
        return metadata

    def _update_metadata(self, **kwargs):
        """Merge ``kwargs`` into the metadata, stamp ``last_updated`` and save."""
        self.metadata.update(kwargs)
        self.metadata['last_updated'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        save_json(self.metadata, METADATA_FILE)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """Add documents to the knowledge base with metadata tracking.

        Args:
            documents: Dicts with ``'content'`` and optional ``'metadata'``;
                a missing ``source`` is recorded as ``'Unknown'``.
        """
        self.vector_db.add_documents(documents)

        # .get() mirrors VectorDB.add_documents, which also tolerates missing
        # metadata; indexing here would raise after the chunks were stored.
        doc_names = [
            (doc.get('metadata') or {}).get('source', 'Unknown')
            for doc in documents
        ]
        self.metadata['total_documents_processed'] += len(documents)
        self.metadata['document_list'].extend(doc_names)
        self._update_metadata()

    def query(self, question: str, n_results: int = 3) -> str:
        """Answer ``question`` using retrieved chunks and recent conversation.

        Args:
            question: The user's question.
            n_results: Number of chunks to retrieve from the vector database.

        Returns:
            The model's answer. A ``Sources:`` line is appended unless the
            answer already mentions "source" or the question looks like it
            is about the conversation itself.
        """
        self.metadata['total_queries'] += 1
        self._update_metadata()

        history_keywords = ['first question', 'previous', 'earlier', 'before', 'what did i ask',
                           'conversation', 'history', 'last question', 'my question']
        lowered_question = question.lower()
        is_history_question = any(keyword in lowered_question for keyword in history_keywords)

        search_results = self.vector_db.search(question, n_results=n_results)
        context_chunks = search_results.get('documents', [[]])[0]
        metadatas = search_results.get('metadatas', [[]])[0]

        formatted_context, unique_sources = self._format_context(context_chunks, metadatas)
        conversation_history = self.memory.get_formatted_history(num_exchanges=5)

        answer = self.chain.invoke({
            "conversation_history": conversation_history,
            "context": formatted_context,
            "question": question
        })

        if unique_sources and "source" not in answer.lower() and not is_history_question:
            answer += f"\n\nüìö Sources: {', '.join(unique_sources)}"

        self.memory.add_exchange(question, answer)
        return answer

    def get_conversation_summary(self) -> str:
        """Return a numbered list of every question asked so far."""
        if not self.memory.history:
            return "No conversation history yet."

        all_questions = self.memory.get_all_questions()
        summary = f"üìä CONVERSATION SUMMARY\n{'=' * 50}\n"
        summary += f"Total questions: {len(all_questions)}\n"
        summary += f"Session started: {self.memory.session_start}\n{'=' * 50}\n\n"
        summary += "Questions:\n" + "\n".join([f"{i+1}. {q}" for i, q in enumerate(all_questions)])
        return summary

    def get_system_stats(self) -> str:
        """Return a text report of document, query, chunk and exchange counts."""
        db_stats = self.vector_db.get_stats()
        stats = f"üìä SYSTEM STATISTICS\n{'=' * 50}\n"
        stats += f"Documents processed: {self.metadata['total_documents_processed']}\n"
        stats += f"Queries answered: {self.metadata['total_queries']}\n"
        stats += f"Vector DB chunks: {db_stats['total_chunks']}\n"
        stats += f"Conversation exchanges: {len(self.memory.history)}\n"
        stats += f"System created: {self.metadata['created_date']}\n"
        stats += f"Last updated: {self.metadata.get('last_updated', 'N/A')}\n{'=' * 50}\n"
        return stats

    def clear_history(self):
        """Clear conversation history."""
        self.memory.clear()
        safe_print("üóëÔ∏è Conversation history cleared!")

    def export_conversation(self):
        """Export the conversation to a text file and return its path (or None)."""
        return self.memory.export_history()

    def reset_system(self):
        """Wipe the vector database, history and metadata after confirmation.

        Nothing is deleted unless the user types exactly ``YES``.
        """
        confirm = input("‚ö†Ô∏è This will delete ALL data. Type 'YES' to confirm: ")
        if confirm != "YES":
            safe_print("‚ùå Reset cancelled")
            return

        self.vector_db.clear_database()
        self.memory.clear()
        self.metadata = self._default_metadata()
        self._update_metadata()
        safe_print("‚úÖ System reset complete!")


# Cell-completion notice: printed once the RAGAssistant class definition has run.
print("‚úÖ RAGAssistant class fully defined!")


# Upload Documents

In [None]:
# Banner for the upload step.
safe_print("\n" + "=" * 60)
safe_print("üì§ DOCUMENT UPLOAD")
safe_print("=" * 60)
safe_print("Upload your documents (txt, pdf, docx)\n")

# Uploaded files are copied into ./data, the folder scanned by the loader.
os.makedirs('data', exist_ok=True)
uploaded = files.upload()

if not uploaded:
    safe_print("\n‚ö†Ô∏è No files uploaded")
else:
    # files.upload() yields a mapping of file name to the uploaded bytes.
    for filename, payload in uploaded.items():
        with open(f'data/{filename}', 'wb') as f:
            f.write(payload)
    safe_print(f"\n‚úÖ Uploaded {len(uploaded)} file(s)")

## Initialize RAG System

In [None]:
safe_print("\n" + "=" * 60)
safe_print("ü§ñ INITIALIZING RAG SYSTEM")
safe_print("=" * 60)

# GROQ_API_KEY is read from Colab's userdata store, so an unset key shows up
# as an empty/None value rather than as the placeholder text; reject both
# instead of handing an unusable key to the assistant.
if not GROQ_API_KEY or GROQ_API_KEY == "your_groq_api_key_here":
    safe_print("\n‚ùå ERROR: Configure your Groq API key in CELL 3!")
    safe_print("Get free API key: https://console.groq.com/")
else:
    assistant = RAGAssistant(api_key=GROQ_API_KEY)

    documents = load_documents_from_folder("./data")

    if documents:
        db_stats = assistant.vector_db.get_stats()
        if db_stats['total_chunks'] > 0:
            # Data from an earlier session exists: ask before indexing again,
            # since adding the same files twice stores duplicate chunks.
            safe_print(f"üì¶ Database has {db_stats['total_chunks']} chunks")
            choice = input("Add new documents? (yes/no): ")
            if choice.lower() in ['yes', 'y']:
                assistant.add_documents(documents)
        else:
            assistant.add_documents(documents)

        safe_print("\n‚úÖ System Ready!")
        safe_print(assistant.get_system_stats())
    else:
        safe_print("\n‚ö†Ô∏è No documents found")

## Interactive Q&A Session

In [None]:
safe_print("\n" + "=" * 60)
safe_print("üí¨ INTERACTIVE Q&A SESSION")
safe_print("=" * 60)
safe_print("\nüìã Commands:")
safe_print("  ‚Ä¢ Ask any question")
safe_print("  ‚Ä¢ 'history' - View conversation")
safe_print("  ‚Ä¢ 'stats' - System statistics")
safe_print("  ‚Ä¢ 'export' - Export conversation")
safe_print("  ‚Ä¢ 'clear' - Clear history")
safe_print("  ‚Ä¢ 'reset' - Reset system")
safe_print("  ‚Ä¢ 'quit' - Exit\n" + "=" * 60 + "\n")

# Read-eval loop: built-in commands are matched case-insensitively against the
# whole input line; anything else is sent to the assistant as a question.
while True:
    try:
        question = input("You: ").strip()

        if not question:
            continue

        command = question.lower()

        if command in ("quit", "exit", "q"):
            safe_print(f"\nüëã Goodbye! Data saved to: {STORAGE_DIR}")
            break

        if command == "history":
            safe_print("\n" + assistant.get_conversation_summary() + "\n" + "-" * 60 + "\n")
            continue

        if command == "stats":
            safe_print("\n" + assistant.get_system_stats() + "\n" + "-" * 60 + "\n")
            continue

        if command == "export":
            filepath = assistant.export_conversation()
            if filepath:
                safe_print(f"‚úÖ Exported: {filepath}\n")
            safe_print("-" * 60 + "\n")
            continue

        if command == "clear":
            assistant.clear_history()
            safe_print("-" * 60 + "\n")
            continue

        if command == "reset":
            assistant.reset_system()
            safe_print("-" * 60 + "\n")
            continue

        safe_print("\nü§î Processing...\n")
        answer = assistant.query(question)
        safe_print(f"ü§ñ Assistant:\n{answer}\n" + "-" * 60 + "\n")

    except (KeyboardInterrupt, EOFError):
        # EOFError must end the session as well: once stdin is exhausted every
        # further input() raises again, and the generic handler below would
        # keep the loop spinning forever.
        safe_print(f"\n\nüëã Interrupted. Data saved: {STORAGE_DIR}")
        break
    except Exception as e:
        # Best effort: report the failure and keep the session alive.
        safe_print(f"\n‚ùå Error: {str(e)}\n")