<a href="https://colab.research.google.com/drive/1WjBv6UrUIZ7KU2x0eFP9kQgNXbv9JpwI?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>

### RAG: Retrieval-Augmented Generation with ChromaDB + Google Embeddings

In [1]:
!pip install -qU google-generativeai chromadb

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/67.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.9/19.9 MB[0m [31m97.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.2/278.2 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m51.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.3/103.3 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m95.9 MB/s[0m eta [36m0:00:

In [2]:
import google.generativeai as genai
import chromadb
import getpass

Get free-tier Google's Gemini API Key here: https://aistudio.google.com/app/apikey

In [3]:
API_KEY = getpass.getpass("Enter your Google API key: ")

Enter your Google API key: ··········


In [5]:
genai.configure(api_key=API_KEY)

In [6]:
class RAG:
    def __init__(self, collection_name="knowledge_base"):
        self.model = genai.GenerativeModel("gemini-2.0-flash")

        # Initialize ChromaDB
        self.chroma_client = chromadb.Client()
        self.collection = self.chroma_client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "RAG knowledge base"}
        )
        self.doc_counter = 0

    def get_embedding(self, text):
        """Get embedding using Google's embedding model"""
        result = genai.embed_content(
            model="models/text-embedding-004",
            content=text,
            task_type="retrieval_document"
        )
        return result['embedding']

    def add_document(self, content, metadata=None):
        """Add document with embedding to ChromaDB"""
        doc_id = f"doc_{self.doc_counter}"
        self.doc_counter += 1

        # Get embedding
        embedding = self.get_embedding(content)

        # Add to ChromaDB
        self.collection.add(
            ids=[doc_id],
            embeddings=[embedding],
            documents=[content],
            metadatas=[metadata or {}]
        )

        print(f"✅ Added {doc_id}: {content[:60]}...")
        return doc_id

    def retrieve(self, query, top_k=3):
        """Semantic search using embeddings"""
        # Get query embedding
        query_embedding = genai.embed_content(
            model="models/text-embedding-004",
            content=query,
            task_type="retrieval_query"
        )['embedding']

        # Search in ChromaDB
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )

        # Format results
        docs = []
        if results['documents'] and results['documents'][0]:
            for i, (doc, metadata, distance) in enumerate(zip(
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0]
            )):
                docs.append({
                    "content": doc,
                    "metadata": metadata,
                    "similarity": 1 - distance  # Convert distance to similarity
                })

        return docs

    def generate_response(self, query, retrieved_docs):
        """Generate answer using retrieved context"""
        if not retrieved_docs:
            return "No relevant information found."

        # Build context from retrieved docs
        context = "\n\n".join([
            f"Document {i+1} (relevance: {doc['similarity']:.2f}):\n{doc['content']}"
            for i, doc in enumerate(retrieved_docs)
        ])

        prompt = f"""Answer the question using the provided documents. Cite which document(s) you use.

        Context:
        {context}

        Question: {query}

        Answer:"""

        response = self.model.generate_content(prompt).text
        return response.strip()

    def query(self, user_input, top_k=3):
        """Full RAG pipeline"""
        print(f"\n{'='*60}")
        print(f"❓ Query: {user_input}")
        print(f"{'='*60}\n")

        # Retrieve relevant documents
        print(f"🔍 Retrieving top {top_k} relevant documents...")
        docs = self.retrieve(user_input, top_k)

        if docs:
            print(f"📚 Found {len(docs)} documents:\n")
            for i, doc in enumerate(docs, 1):
                print(f"   {i}. [Similarity: {doc['similarity']:.3f}]")
                print(f"      {doc['content'][:80]}...\n")
        else:
            print("   No relevant documents found.\n")

        # Generate response
        print("✨ Generating response...\n")
        response = self.generate_response(user_input, docs)

        print(f"{'='*60}")
        print(f"💬 Answer:")
        print(f"{'='*60}")
        print(response)
        print()

        return response

    def get_stats(self):
        """Get collection statistics"""
        count = self.collection.count()
        print(f"📊 Knowledge Base: {count} documents")

In [7]:
# Example 1: Company Knowledge Base
print("="*60)
print("EXAMPLE 1: Company Knowledge Base")
print("="*60)

company_rag = RAG("company_kb")

# Add documents
company_rag.add_document(
    "Full-time employees receive 15 days of paid vacation per year. Part-time employees "
    "receive pro-rated vacation days. All vacation requests must be approved by your "
    "direct manager at least 2 weeks in advance. Unused vacation days do not roll over.",
    {"department": "HR", "type": "vacation_policy"}
)

company_rag.add_document(
    "Employees are entitled to 10 sick days per year. For absences of 1-2 days, no "
    "medical documentation is required. For extended illness beyond 3 days, a doctor's "
    "note must be submitted to HR.",
    {"department": "HR", "type": "sick_leave"}
)

company_rag.add_document(
    "To connect to company VPN: Download Cisco AnyConnect from the IT portal. "
    "Use your company email and standard password. If you encounter connection issues, "
    "contact IT helpdesk at extension 5555 or helpdesk@company.com.",
    {"department": "IT", "type": "vpn_guide"}
)

company_rag.add_document(
    "All business expenses must be submitted through Expensify within 30 days of purchase. "
    "Receipts are mandatory. Expenses over $500 require manager approval. Reimbursement "
    "is processed within 5-7 business days after approval.",
    {"department": "Finance", "type": "expense_policy"}
)

company_rag.add_document(
    "Remote work policy: Employees can work remotely up to 3 days per week with manager "
    "approval. Must be available during core hours 10am-3pm. Home office stipend of $500 "
    "available annually for equipment purchases.",
    {"department": "HR", "type": "remote_work"}
)

company_rag.get_stats()

# Test semantic search
company_rag.query("How many days off do I get for vacation?")
company_rag.query("I'm sick, what's the policy?")
company_rag.query("How do I connect to VPN from home?")
company_rag.query("Can I work from home?")


# Example 2: Technical Documentation
print("\n" + "="*60)
print("EXAMPLE 2: Technical Documentation")
print("="*60)

tech_rag = RAG("tech_docs")

tech_rag.add_document(
    "Authentication uses JWT Bearer tokens. Obtain a token by sending POST request to "
    "/api/v1/auth with JSON body containing username and password. Token expires after "
    "24 hours. Include token in Authorization header: 'Bearer <token>'.",
    {"category": "authentication", "version": "v1"}
)

tech_rag.add_document(
    "User management endpoints: GET /api/v1/users (list all users), "
    "POST /api/v1/users (create new user, requires admin role), "
    "PUT /api/v1/users/{id} (update user), DELETE /api/v1/users/{id} (delete user, admin only). "
    "All endpoints require authentication.",
    {"category": "endpoints", "resource": "users"}
)

tech_rag.add_document(
    "Rate limiting: API requests are limited to 100 requests per minute per API key. "
    "Exceeded limits return 429 Too Many Requests. Rate limit info in response headers: "
    "X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset.",
    {"category": "rate_limiting"}
)

tech_rag.add_document(
    "Error handling: API returns standard HTTP status codes. 400 for bad requests, "
    "401 for unauthorized, 403 for forbidden, 404 for not found, 500 for server errors. "
    "Error responses include JSON with 'error' (code) and 'message' (description) fields.",
    {"category": "errors"}
)

tech_rag.get_stats()

tech_rag.query("How do I authenticate?")
tech_rag.query("What happens if I make too many requests?")
tech_rag.query("How do I create a new user?")


# Example 3: Research Papers (Semantic Understanding)
print("\n" + "="*60)
print("EXAMPLE 3: Research Database (Semantic Search)")
print("="*60)

research_rag = RAG("research")

research_rag.add_document(
    "Remote work productivity study 2023: Survey of 5000 employees found 65% report "
    "increased productivity when working remotely. Key contributing factors include "
    "flexible scheduling, elimination of commute time, and personalized work environment. "
    "However, 28% experienced productivity decline due to home distractions.",
    {"year": 2023, "topic": "remote_work", "type": "survey"}
)

research_rag.add_document(
    "Challenges of distributed teams: Research shows 40% of remote workers struggle with "
    "work-life balance. Communication gaps reported in 35% of fully remote teams. "
    "Social isolation affects 30% of remote employees. Regular video meetings and virtual "
    "social events help mitigate these issues.",
    {"year": 2023, "topic": "remote_challenges"}
)

research_rag.add_document(
    "Hybrid work model analysis 2024: Companies implementing 3-2 model (3 days office, "
    "2 days remote) report 22% higher employee satisfaction compared to fully office or "
    "fully remote. This model balances collaboration benefits with flexibility. "
    "Wednesday is most common mandatory office day.",
    {"year": 2024, "topic": "hybrid_work"}
)

research_rag.add_document(
    "Impact of AI on software development: Study shows developers using AI assistants "
    "complete tasks 35% faster. Code quality metrics remain similar. 78% of developers "
    "report AI tools helpful for boilerplate code. Learning curve for effective AI "
    "prompting takes 2-3 weeks.",
    {"year": 2024, "topic": "ai_development"}
)

research_rag.get_stats()

# Semantic search should find relevant docs even with different wording
research_rag.query("What are the benefits and drawbacks of working from home?")
research_rag.query("What's the best office-remote split?")
research_rag.query("How does artificial intelligence help programmers?")

print("\n✅ RAG with ChromaDB + Google Embeddings Complete!")

EXAMPLE 1: Company Knowledge Base
✅ Added doc_0: Full-time employees receive 15 days of paid vacation per yea...
✅ Added doc_1: Employees are entitled to 10 sick days per year. For absence...
✅ Added doc_2: To connect to company VPN: Download Cisco AnyConnect from th...
✅ Added doc_3: All business expenses must be submitted through Expensify wi...
✅ Added doc_4: Remote work policy: Employees can work remotely up to 3 days...
📊 Knowledge Base: 5 documents

❓ Query: How many days off do I get for vacation?

🔍 Retrieving top 3 relevant documents...
📚 Found 3 documents:

   1. [Similarity: 0.408]
      Full-time employees receive 15 days of paid vacation per year. Part-time employe...

   2. [Similarity: 0.234]
      Employees are entitled to 10 sick days per year. For absences of 1-2 days, no me...

   3. [Similarity: 0.054]
      Remote work policy: Employees can work remotely up to 3 days per week with manag...

✨ Generating response...

💬 Answer:
Full-time employees receive 15 days of 