# Lab 3: Enterprise RAG System

**Week 4 - RAG Fundamentals**

**Provided by:** ADC ENGINEERING & CONSULTING LTD

## Objectives

In this lab, you will:
- Build production-ready RAG systems
- Implement advanced RAG patterns
- Handle concurrent requests
- Add caching and optimization
- Implement monitoring and logging
- Build fault-tolerant systems
- Add security and access control
- Deploy RAG at scale
- Implement RAG pipelines with streaming

## Prerequisites

- Completed Week 4 Labs 1-2
- Understanding of RAG architecture
- Experience with embeddings and vector search
- OpenAI API key configured
- Python 3.9+

## Setup and Installation

In [None]:
# Install required packages
!pip install openai python-dotenv tiktoken numpy scikit-learn pandas redis --quiet

In [None]:
import os
import json
import re
import hashlib
import time
import logging
import threading
from typing import List, Dict, Optional, Tuple, Any, Callable, Iterator
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from pathlib import Path
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor, as_completed
import pickle

from openai import OpenAI
from dotenv import load_dotenv
import tiktoken
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load environment variables
load_dotenv()

# Initialize OpenAI client
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

print("✓ Setup complete!")

## Part 1: Production-Ready RAG Components

Building enterprise-grade RAG requires robust, scalable components.

### 1.1 Embeddings Cache

Cache embeddings to avoid redundant API calls.

In [None]:
class EmbeddingsCache:
    """
    Cache for embeddings to reduce API calls and costs.
    """
    
    def __init__(
        self,
        cache_dir: str = ".embeddings_cache",
        ttl_hours: int = 24 * 7  # 1 week
    ):
        """
        Initialize embeddings cache.
        
        Args:
            cache_dir: Directory to store cache
            ttl_hours: Time-to-live in hours
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.ttl = timedelta(hours=ttl_hours)
        
        # In-memory cache
        self.memory_cache: Dict[str, Tuple[List[float], datetime]] = {}
        
        # Statistics
        self.hits = 0
        self.misses = 0
        
        logger.info(f"Initialized embeddings cache at {cache_dir}")
    
    def _get_cache_key(self, text: str, model: str) -> str:
        """Generate cache key from text and model."""
        combined = f"{model}:{text}"
        return hashlib.sha256(combined.encode()).hexdigest()
    
    def get(self, text: str, model: str) -> Optional[List[float]]:
        """
        Get embedding from cache.
        
        Args:
            text: Text to embed
            model: Embedding model name
        
        Returns:
            Cached embedding or None
        """
        cache_key = self._get_cache_key(text, model)
        
        # Check memory cache first
        if cache_key in self.memory_cache:
            embedding, timestamp = self.memory_cache[cache_key]
            if datetime.now() - timestamp < self.ttl:
                self.hits += 1
                return embedding
            else:
                # Expired
                del self.memory_cache[cache_key]
        
        # Check disk cache
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        if cache_file.exists():
            try:
                with open(cache_file, 'rb') as f:
                    data = pickle.load(f)
                
                timestamp = data['timestamp']
                if datetime.now() - timestamp < self.ttl:
                    embedding = data['embedding']
                    # Load into memory cache
                    self.memory_cache[cache_key] = (embedding, timestamp)
                    self.hits += 1
                    return embedding
                else:
                    # Expired
                    cache_file.unlink()
            except Exception as e:
                logger.warning(f"Error reading cache: {e}")
        
        self.misses += 1
        return None
    
    def set(self, text: str, model: str, embedding: List[float]):
        """
        Store embedding in cache.
        
        Args:
            text: Text that was embedded
            model: Embedding model name
            embedding: Embedding vector
        """
        cache_key = self._get_cache_key(text, model)
        timestamp = datetime.now()
        
        # Store in memory
        self.memory_cache[cache_key] = (embedding, timestamp)
        
        # Store on disk
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        try:
            with open(cache_file, 'wb') as f:
                pickle.dump({
                    'embedding': embedding,
                    'timestamp': timestamp,
                    'text_length': len(text),
                    'model': model
                }, f)
        except Exception as e:
            logger.warning(f"Error writing cache: {e}")
    
    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0.0
        
        return {
            "hits": self.hits,
            "misses": self.misses,
            "total_requests": total,
            "hit_rate": hit_rate,
            "memory_cache_size": len(self.memory_cache),
            "disk_cache_size": len(list(self.cache_dir.glob("*.pkl")))
        }
    
    def clear(self):
        """Clear all cache."""
        self.memory_cache.clear()
        for cache_file in self.cache_dir.glob("*.pkl"):
            cache_file.unlink()
        logger.info("Cache cleared")

# Test embeddings cache
print("="*80)
print("EMBEDDINGS CACHE")
print("="*80)

cache = EmbeddingsCache()

def get_embedding_with_cache(text: str, model: str = "text-embedding-3-small") -> List[float]:
    """Get embedding with caching."""
    # Normalize first so the cache key matches the text actually sent to the API
    text = text.replace("\n", " ")
    
    # Check cache first
    embedding = cache.get(text, model)
    if embedding is not None:
        return embedding
    
    # Get from API
    response = client.embeddings.create(input=[text], model=model)
    embedding = response.data[0].embedding
    
    # Store in cache (same normalized text, so the next lookup hits)
    cache.set(text, model, embedding)
    
    return embedding

# Test cache
test_text = "This is a test sentence for caching."

print("\nFirst call (cache miss):")
start = time.time()
emb1 = get_embedding_with_cache(test_text)
time1 = time.time() - start
print(f"  Time: {time1:.4f}s")

print("\nSecond call (cache hit):")
start = time.time()
emb2 = get_embedding_with_cache(test_text)
time2 = time.time() - start
print(f"  Time: {time2:.4f}s")
print(f"  Speedup: {time1 / max(time2, 1e-9):.1f}x")  # guard against a 0.0s reading

print("\nCache statistics:")
stats = cache.get_stats()
for key, value in stats.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")
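
One detail worth checking in any text-keyed cache is that lookups and writes hash the same string. The sketch below is a standalone illustration, not part of `EmbeddingsCache`: the helper name `normalized_cache_key` is hypothetical, and it assumes the only normalization applied before embedding is replacing newlines with spaces, as `get_embedding_with_cache` does.

```python
import hashlib


def normalized_cache_key(text: str, model: str) -> str:
    """Hypothetical helper: hash the model name plus the normalized text."""
    # Apply the same normalization used before the API call, so that
    # "line one\nline two" and "line one line two" share one cache entry.
    normalized = text.replace("\n", " ")
    return hashlib.sha256(f"{model}:{normalized}".encode()).hexdigest()


key_multiline = normalized_cache_key("line one\nline two", "text-embedding-3-small")
key_flat = normalized_cache_key("line one line two", "text-embedding-3-small")
key_other_model = normalized_cache_key("line one line two", "text-embedding-3-large")

print(key_multiline == key_flat)    # same text after normalization
print(key_flat == key_other_model)  # different model, so a different key
```

Including the model name in the key keeps vectors from different embedding models from being mixed in one cache directory.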

### 1.2 Rate Limiter

Track requests and estimated tokens over a sliding 60-second window so that calls stay within the API's rate limits.

In [None]:
class RateLimiter:
    """
    Rate limiter for API calls.
    """
    
    def __init__(
        self,
        max_requests_per_minute: int = 3500,
        max_tokens_per_minute: int = 90000
    ):
        """
        Initialize rate limiter.
        
        Args:
            max_requests_per_minute: Maximum requests per minute
            max_tokens_per_minute: Maximum tokens per minute
        """
        self.max_requests_per_minute = max_requests_per_minute
        self.max_tokens_per_minute = max_tokens_per_minute
        
        # Tracking
        self.request_times: deque = deque()
        self.token_usage: deque = deque()  # (timestamp, tokens)
        
        # Thread safety
        self.lock = threading.Lock()
        
        logger.info(f"Initialized rate limiter: {max_requests_per_minute} req/min, {max_tokens_per_minute} tokens/min")
    
    def _clean_old_entries(self):
        """Remove entries older than 1 minute."""
        now = time.time()
        cutoff = now - 60
        
        # Clean requests
        while self.request_times and self.request_times[0] < cutoff:
            self.request_times.popleft()
        
        # Clean tokens
        while self.token_usage and self.token_usage[0][0] < cutoff:
            self.token_usage.popleft()
    
    def acquire(self, estimated_tokens: int = 1000) -> bool:
        """
        Try to acquire permission for an API call.
        
        Args:
            estimated_tokens: Estimated token usage
        
        Returns:
            True if acquired, False if rate limit exceeded
        """
        with self.lock:
            self._clean_old_entries()
            
            # Check request limit
            if len(self.request_times) >= self.max_requests_per_minute:
                return False
            
            # Check token limit
            current_tokens = sum(tokens for _, tokens in self.token_usage)
            if current_tokens + estimated_tokens > self.max_tokens_per_minute:
                return False
            
            # Acquire
            now = time.time()
            self.request_times.append(now)
            self.token_usage.append((now, estimated_tokens))
            
            return True
    
    def wait_if_needed(self, estimated_tokens: int = 1000, max_wait: float = 60.0):
        """
        Wait until rate limit allows the request.
        
        Args:
            estimated_tokens: Estimated token usage
            max_wait: Maximum wait time in seconds
        """
        start = time.time()
        
        while not self.acquire(estimated_tokens):
            if time.time() - start > max_wait:
                raise TimeoutError("Rate limit wait timeout")
            time.sleep(0.1)
    
    def get_stats(self) -> Dict[str, Any]:
        """Get rate limiter statistics."""
        with self.lock:
            self._clean_old_entries()
            current_requests = len(self.request_times)
            current_tokens = sum(tokens for _, tokens in self.token_usage)
            
            return {
                "current_requests_per_minute": current_requests,
                "current_tokens_per_minute": current_tokens,
                "request_capacity_used": current_requests / self.max_requests_per_minute,
                "token_capacity_used": current_tokens / self.max_tokens_per_minute
            }

# Test rate limiter
print("\n" + "="*80)
print("RATE LIMITER")
print("="*80)

rate_limiter = RateLimiter(max_requests_per_minute=10, max_tokens_per_minute=5000)

print("\nSimulating API calls:")
for i in range(12):
    if rate_limiter.acquire(estimated_tokens=500):
        print(f"  Request {i+1}: Accepted")
    else:
        print(f"  Request {i+1}: Rate limited")
    
    if i == 5:
        print("\n  Checking stats:")
        stats = rate_limiter.get_stats()
        for key, value in stats.items():
            if isinstance(value, float):
                print(f"    {key}: {value:.2%}")
            else:
                print(f"    {key}: {value}")
        print()
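
The demo above only shows requests being rejected; it does not show capacity returning once old entries age out of the window. The following is a minimal, single-threaded sketch of the same sliding-window idea with an injectable clock, so the 60-second expiry can be observed without sleeping. The class name `SlidingWindowCounter` and its `clock` parameter are assumptions made for this illustration and are not part of `RateLimiter`.

```python
import time
from collections import deque
from typing import Callable, Deque


class SlidingWindowCounter:
    """Hypothetical, single-threaded sketch of a sliding-window request limit."""

    def __init__(
        self,
        max_requests: int,
        window_seconds: float = 60.0,
        clock: Callable[[], float] = time.monotonic,
    ):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.clock = clock  # injectable so the demo does not have to sleep
        self.request_times: Deque[float] = deque()

    def acquire(self) -> bool:
        now = self.clock()
        cutoff = now - self.window_seconds
        # Drop timestamps that have fallen out of the window
        while self.request_times and self.request_times[0] < cutoff:
            self.request_times.popleft()
        if len(self.request_times) >= self.max_requests:
            return False
        self.request_times.append(now)
        return True


fake_now = [0.0]  # mutable cell standing in for the wall clock
limiter = SlidingWindowCounter(max_requests=2, clock=lambda: fake_now[0])

results = [limiter.acquire() for _ in range(3)]  # third call exceeds the limit
fake_now[0] = 61.0                               # move past the 60-second window
results.append(limiter.acquire())                # old entries have expired
print(results)
```

Passing the clock in as a parameter is a testing convenience; the lab's `RateLimiter` reads `time.time()` directly and adds a lock for thread safety.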

### Exercise 1.1: Build Connection Pool

Implement a connection pool for API clients:

In [None]:
# TODO: Implement connection pool

class APIConnectionPool:
    """
    TODO: Connection pool for API clients.
    
    Should implement:
    1. Pool of OpenAI clients
    2. Connection reuse
    3. Health checking
    4. Automatic retry with backoff
    5. Circuit breaker pattern
    6. Load balancing across clients
    """
    
    def __init__(self, pool_size: int = 5):
        """Initialize connection pool."""
        pass
    
    def get_client(self) -> OpenAI:
        """TODO: Get an available client from pool."""
        pass
    
    def release_client(self, client: OpenAI):
        """TODO: Return client to pool."""
        pass
    
    def health_check(self) -> Dict[str, bool]:
        """TODO: Check health of all clients."""
        pass

# Test your connection pool
# pool = APIConnectionPool(pool_size=3)
# client = pool.get_client()
# # Use client...
# pool.release_client(client)
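
As a possible starting point for item 4 of the exercise (automatic retry with backoff), here is a small sketch that is independent of the OpenAI client. The function name `call_with_backoff` and its parameters are hypothetical; a real pool would likely catch only retryable errors (timeouts, rate-limit responses) and add random jitter to the delay.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


def call_with_backoff(
    fn: Callable[[], T],
    max_attempts: int = 4,
    base_delay: float = 0.5,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Hypothetical helper: retry fn, doubling the delay after each failure."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(base_delay * (2 ** attempt))
    raise RuntimeError("max_attempts must be at least 1")


# Demo with a fake operation that fails twice, then succeeds.
calls = {"count": 0}
recorded_delays: List[float] = []


def flaky_operation() -> str:
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("simulated transient failure")
    return "ok"


# Record the delays instead of sleeping so the demo runs instantly.
result = call_with_backoff(flaky_operation, base_delay=0.1, sleep=recorded_delays.append)
print(result, recorded_delays)
```

The `sleep` argument is injected only so the demo can capture delays; in the pool itself the default `time.sleep` would apply.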

## Part 2: Enterprise RAG System

This part combines the embeddings cache and rate limiter into a complete RAG system with concurrency control, batch queries, and basic metrics.

In [None]:
@dataclass
class QueryRequest:
    """RAG query request."""
    query: str
    top_k: int = 3
    min_similarity: float = 0.0
    filters: Optional[Dict[str, Any]] = None
    user_id: Optional[str] = None
    session_id: Optional[str] = None
    
@dataclass
class QueryResponse:
    """RAG query response."""
    query: str
    answer: str
    sources: List[Dict[str, Any]]
    confidence: float
    latency_ms: float
    num_chunks_retrieved: int
    metadata: Dict[str, Any] = field(default_factory=dict)

class EnterpriseRAG:
    """
    Production-ready RAG system with enterprise features.
    """
    
    def __init__(
        self,
        embedding_model: str = "text-embedding-3-small",
        llm_model: str = "gpt-3.5-turbo",
        cache_dir: str = ".rag_cache",
        max_concurrent_requests: int = 10
    ):
        """
        Initialize enterprise RAG.
        
        Args:
            embedding_model: Model for embeddings
            llm_model: Model for generation
            cache_dir: Directory for caching
            max_concurrent_requests: Max concurrent queries
        """
        self.embedding_model = embedding_model
        self.llm_model = llm_model
        
        # Components
        self.embeddings_cache = EmbeddingsCache(cache_dir)
        self.rate_limiter = RateLimiter()
        
        # Storage
        self.chunks: List[Dict[str, Any]] = []
        self.embeddings: Optional[np.ndarray] = None
        
        # Concurrency
        self.executor = ThreadPoolExecutor(max_workers=max_concurrent_requests)
        self.semaphore = threading.Semaphore(max_concurrent_requests)
        
        # Monitoring
        self.query_history: List[Dict[str, Any]] = []
        self.error_count = 0
        
        logger.info(f"Initialized EnterpriseRAG with model={llm_model}")
    
    def get_embedding(self, text: str) -> List[float]:
        """Get embedding with caching and rate limiting."""
        # Normalize first so the cache key matches the text sent to the API
        text = text.replace("\n", " ")
        
        # Check cache
        embedding = self.embeddings_cache.get(text, self.embedding_model)
        if embedding is not None:
            return embedding
        
        # Rate limit (word count is only a rough stand-in for the token count)
        self.rate_limiter.wait_if_needed(estimated_tokens=len(text.split()))
        
        # Get from API
        response = client.embeddings.create(
            input=[text],
            model=self.embedding_model
        )
        embedding = response.data[0].embedding
        
        # Cache
        self.embeddings_cache.set(text, self.embedding_model, embedding)
        
        return embedding
    
    def ingest_document(
        self,
        text: str,
        document_id: str,
        metadata: Optional[Dict] = None,
        chunk_size: int = 500
    ):
        """
        Ingest a document.
        
        Args:
            text: Document text
            document_id: Document identifier
            metadata: Optional metadata
            chunk_size: Chunk size in characters
        """
        logger.info(f"Ingesting document: {document_id}")
        
        # Simple chunking
        chunks = []
        words = text.split()
        current_chunk = []
        current_size = 0
        
        for word in words:
            current_chunk.append(word)
            current_size += len(word) + 1
            
            if current_size >= chunk_size:
                chunk_text = " ".join(current_chunk)
                chunks.append(chunk_text)
                current_chunk = []
                current_size = 0
        
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        
        logger.info(f"  Created {len(chunks)} chunks")
        
        # Generate embeddings (with parallelization)
        logger.info(f"  Generating embeddings...")
        
        def embed_chunk(chunk_text, idx):
            embedding = self.get_embedding(chunk_text)
            return {
                'chunk_id': f"{document_id}_chunk_{idx}",
                'document_id': document_id,
                'content': chunk_text,
                'embedding': embedding,
                'metadata': metadata or {}
            }
        
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(embed_chunk, chunk, i)
                for i, chunk in enumerate(chunks)
            ]
            
            for future in as_completed(futures):
                try:
                    chunk_data = future.result()
                    self.chunks.append(chunk_data)
                except Exception as e:
                    logger.error(f"Error embedding chunk: {e}")
                    self.error_count += 1
        
        # Rebuild embeddings matrix
        self._rebuild_embeddings()
        
        logger.info(f"✓ Document ingested. Total chunks: {len(self.chunks)}")
    
    def _rebuild_embeddings(self):
        """Rebuild embeddings matrix."""
        if self.chunks:
            embeddings_list = [chunk['embedding'] for chunk in self.chunks]
            self.embeddings = np.array(embeddings_list)
    
    def retrieve(
        self,
        query: str,
        top_k: int = 3,
        min_similarity: float = 0.0,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[Dict, float]]:
        """
        Retrieve relevant chunks.
        
        Args:
            query: Search query
            top_k: Number of chunks
            min_similarity: Minimum similarity
            filters: Metadata filters
        
        Returns:
            List of (chunk, score) tuples
        """
        if not self.chunks:
            return []
        
        # Get query embedding
        query_embedding = np.array(self.get_embedding(query))
        
        # Calculate similarities
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            self.embeddings
        )[0]
        
        # Apply filters
        valid_indices = []
        for idx, chunk in enumerate(self.chunks):
            if filters:
                match = all(
                    chunk['metadata'].get(k) == v
                    for k, v in filters.items()
                )
                if not match:
                    continue
            valid_indices.append(idx)
        
        # Get top-k from valid indices
        valid_similarities = [(idx, similarities[idx]) for idx in valid_indices]
        valid_similarities.sort(key=lambda x: x[1], reverse=True)
        
        results = []
        for idx, score in valid_similarities[:top_k]:
            if score >= min_similarity:
                results.append((self.chunks[idx], float(score)))
        
        return results
    
    def generate_answer(
        self,
        query: str,
        context_chunks: List[Dict],
        temperature: float = 0.7
    ) -> Tuple[str, float]:
        """
        Generate answer with confidence score.
        
        Args:
            query: User query
            context_chunks: Retrieved chunks
            temperature: LLM temperature
        
        Returns:
            (answer, confidence_score)
        """
        # Assemble context
        context = "\n\n".join([
            f"[Source {i+1}]: {chunk['content']}"
            for i, chunk in enumerate(context_chunks)
        ])
        
        # Create prompt
        prompt = f"""Answer the question based on the context provided. If the answer cannot be found in the context, say "I don't have enough information to answer that."

Context:
{context}

Question: {query}

Answer:"""
        
        # Rate limit
        self.rate_limiter.wait_if_needed(estimated_tokens=len(prompt.split()) + 200)
        
        # Generate
        response = client.chat.completions.create(
            model=self.llm_model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that answers questions based on the provided context."},
                {"role": "user", "content": prompt}
            ],
            temperature=temperature,
            max_tokens=500
        )
        
        answer = response.choices[0].message.content
        
        # Estimate confidence (simple heuristic)
        confidence = 0.5
        if "don't have enough information" not in answer.lower():
            confidence = 0.8
        if len(context_chunks) >= 3:
            confidence = min(confidence + 0.1, 1.0)
        
        return answer, confidence
    
    def query(self, request: QueryRequest) -> QueryResponse:
        """
        Process a query request.
        
        Args:
            request: Query request
        
        Returns:
            Query response
        """
        start_time = time.time()
        
        try:
            # Acquire semaphore for concurrency control
            with self.semaphore:
                logger.info(f"Processing query: {request.query}")
                
                # Retrieve
                retrieved = self.retrieve(
                    request.query,
                    request.top_k,
                    request.min_similarity,
                    request.filters
                )
                
                if not retrieved:
                    latency = (time.time() - start_time) * 1000
                    response = QueryResponse(
                        query=request.query,
                        answer="I don't have any relevant information to answer that question.",
                        sources=[],
                        confidence=0.0,
                        latency_ms=latency,
                        num_chunks_retrieved=0
                    )
                    self._log_query(request, response)
                    return response
                
                # Generate
                chunks = [chunk for chunk, _ in retrieved]
                answer, confidence = self.generate_answer(request.query, chunks)
                
                # Build response
                latency = (time.time() - start_time) * 1000
                
                response = QueryResponse(
                    query=request.query,
                    answer=answer,
                    sources=[
                        {
                            'chunk_id': chunk['chunk_id'],
                            'document_id': chunk['document_id'],
                            'content': chunk['content'],
                            'score': score,
                            'metadata': chunk['metadata']
                        }
                        for chunk, score in retrieved
                    ],
                    confidence=confidence,
                    latency_ms=latency,
                    num_chunks_retrieved=len(retrieved),
                    metadata={
                        'user_id': request.user_id,
                        'session_id': request.session_id
                    }
                )
                
                self._log_query(request, response)
                
                return response
                
        except Exception as e:
            logger.error(f"Error processing query: {e}")
            self.error_count += 1
            
            latency = (time.time() - start_time) * 1000
            return QueryResponse(
                query=request.query,
                answer=f"Error processing query: {str(e)}",
                sources=[],
                confidence=0.0,
                latency_ms=latency,
                num_chunks_retrieved=0,
                metadata={'error': str(e)}
            )
    
    def query_batch(
        self,
        requests: List[QueryRequest]
    ) -> List[QueryResponse]:
        """
        Process multiple queries in parallel.
        
        Args:
            requests: List of query requests
        
        Returns:
            List of query responses
        """
        logger.info(f"Processing batch of {len(requests)} queries")
        
        futures = [
            self.executor.submit(self.query, request)
            for request in requests
        ]
        
        # Collect results in submission order so responses line up with requests
        responses = []
        for future in futures:
            try:
                response = future.result()
                responses.append(response)
            except Exception as e:
                logger.error(f"Error in batch query: {e}")
        
        return responses
    
    def _log_query(self, request: QueryRequest, response: QueryResponse):
        """Log query for monitoring."""
        self.query_history.append({
            'timestamp': datetime.now().isoformat(),
            'query': request.query,
            'answer': response.answer,
            'confidence': response.confidence,
            'latency_ms': response.latency_ms,
            'num_sources': response.num_chunks_retrieved,
            'user_id': request.user_id,
            'session_id': request.session_id
        })
    
    def get_metrics(self) -> Dict[str, Any]:
        """Get system metrics."""
        if not self.query_history:
            return {"message": "No queries processed yet"}
        
        latencies = [q['latency_ms'] for q in self.query_history]
        confidences = [q['confidence'] for q in self.query_history]
        
        cache_stats = self.embeddings_cache.get_stats()
        rate_stats = self.rate_limiter.get_stats()
        
        return {
            "total_queries": len(self.query_history),
            "avg_latency_ms": np.mean(latencies),
            "p50_latency_ms": np.percentile(latencies, 50),
            "p95_latency_ms": np.percentile(latencies, 95),
            "p99_latency_ms": np.percentile(latencies, 99),
            "avg_confidence": np.mean(confidences),
            "error_count": self.error_count,
            "cache_hit_rate": cache_stats['hit_rate'],
            "total_chunks": len(self.chunks),
            **rate_stats
        }
    
    def shutdown(self):
        """Shutdown the system."""
        self.executor.shutdown(wait=True)
        logger.info("EnterpriseRAG shutdown complete")

# Test Enterprise RAG
print("\n" + "="*80)
print("ENTERPRISE RAG SYSTEM")
print("="*80)

# Create system
rag = EnterpriseRAG()

# Ingest documents
documents = [
    {
        "id": "python_guide",
        "text": """Python is a versatile programming language. It supports multiple programming paradigms including object-oriented and functional programming. Python has a rich ecosystem of libraries for data science, web development, and automation. Popular frameworks include Django and Flask for web development, and NumPy and Pandas for data analysis.""",
        "metadata": {"category": "programming", "language": "python"}
    },
    {
        "id": "ml_intro",
        "text": """Machine learning is a subset of artificial intelligence. It enables systems to learn from data without explicit programming. Common algorithms include linear regression, decision trees, and neural networks. Deep learning uses multi-layer neural networks to learn complex patterns. Applications include image recognition, natural language processing, and recommendation systems.""",
        "metadata": {"category": "ai", "topic": "machine_learning"}
    }
]

for doc in documents:
    rag.ingest_document(doc['text'], doc['id'], doc['metadata'])

print("\n" + "="*80)
print("TESTING QUERIES")
print("="*80)

# Single query
request = QueryRequest(
    query="What is Python used for?",
    top_k=2,
    user_id="user123",
    session_id="session456"
)

response = rag.query(request)
print(f"\nQuery: {response.query}")
print(f"Answer: {response.answer}")
print(f"Confidence: {response.confidence:.2f}")
print(f"Latency: {response.latency_ms:.2f}ms")
print(f"Sources: {response.num_chunks_retrieved}")

# Batch queries
print("\n" + "="*80)
print("BATCH PROCESSING")
print("="*80)

batch_requests = [
    QueryRequest(query="Tell me about machine learning", top_k=2),
    QueryRequest(query="What frameworks does Python have?", top_k=2),
    QueryRequest(query="What is deep learning?", top_k=2)
]

batch_responses = rag.query_batch(batch_requests)
print(f"\nProcessed {len(batch_responses)} queries in parallel")

# Metrics
print("\n" + "="*80)
print("SYSTEM METRICS")
print("="*80)

metrics = rag.get_metrics()
for key, value in metrics.items():
    if isinstance(value, float):
        print(f"{key}: {value:.2f}")
    else:
        print(f"{key}: {value}")
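
`ingest_document` treats `chunk_size` as a threshold, not a hard cap: a chunk is closed only after the word that pushes it past the limit, so chunks can run slightly longer than `chunk_size`, and neighboring chunks do not overlap. The standalone copy of that loop below makes this easy to experiment with; the function name `chunk_by_words` and the sample sentence are made up for this illustration.

```python
from typing import List


def chunk_by_words(text: str, chunk_size: int) -> List[str]:
    """Standalone copy of the word-based chunking loop, for experimentation."""
    chunks: List[str] = []
    current_chunk: List[str] = []
    current_size = 0
    for word in text.split():
        current_chunk.append(word)
        current_size += len(word) + 1  # +1 accounts for the joining space
        if current_size >= chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk, current_size = [], 0
    if current_chunk:
        chunks.append(" ".join(current_chunk))  # leftover words form the last chunk
    return chunks


sample = "retrieval augmented generation combines search with a language model"
sample_chunks = chunk_by_words(sample, chunk_size=20)
for chunk in sample_chunks:
    print(len(chunk), repr(chunk))
```

Printing the lengths shows which chunks overshoot the threshold. If a strict upper bound or overlapping chunks are needed, the loop would have to check the size before appending each word.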

### Exercise 2.1: Add Monitoring and Alerting

Implement comprehensive monitoring:

In [None]:
# TODO: Implement monitoring system

class RAGMonitor:
    """
    TODO: Monitoring and alerting for RAG system.
    
    Should track:
    1. Query volume and patterns
    2. Latency percentiles (p50, p95, p99)
    3. Error rates and types
    4. Cache hit rates
    5. Token usage and costs
    6. User satisfaction scores
    
    Should alert on:
    - High latency
    - High error rate
    - Low cache hit rate
    - Unusual query patterns
    """
    
    def __init__(self, rag_system: EnterpriseRAG):
        """Initialize monitor."""
        self.rag = rag_system
        self.alerts: List[Dict] = []
    
    def check_health(self) -> Dict[str, str]:
        """
        TODO: Check system health.
        
        Return health status for:
        - Latency
        - Error rate
        - Cache performance
        - Overall status
        """
        pass
    
    def generate_alert(self, alert_type: str, message: str):
        """TODO: Generate alert."""
        pass
    
    def get_dashboard_data(self) -> Dict:
        """TODO: Get data for monitoring dashboard."""
        pass

# Test monitoring
# monitor = RAGMonitor(rag)
# health = monitor.check_health()
# dashboard = monitor.get_dashboard_data()
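
One way to approach `check_health` is to reduce the raw metrics to a status per dimension and then derive an overall status. The sketch below works on plain lists and counts so it can be tried without a running RAG system. The function names and the threshold values (2000 ms for p95 latency, 5% error rate) are illustrative assumptions, not recommended settings, and the nearest-rank percentile used here can differ slightly from `np.percentile`, which interpolates by default.

```python
import math
from typing import Dict, List


def nearest_rank_percentile(values: List[float], pct: float) -> float:
    """Nearest-rank percentile for pct in (0, 100]."""
    ordered = sorted(values)
    rank = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[rank - 1]


def classify_health(
    latencies_ms: List[float],
    errors: int,
    total_requests: int,
    p95_limit_ms: float = 2000.0,    # assumed threshold
    error_rate_limit: float = 0.05,  # assumed threshold
) -> Dict[str, str]:
    """Hypothetical helper: map raw metrics to 'ok' / 'degraded' statuses."""
    p95 = nearest_rank_percentile(latencies_ms, 95)
    error_rate = errors / total_requests if total_requests else 0.0
    status = {
        "latency": "ok" if p95 <= p95_limit_ms else "degraded",
        "error_rate": "ok" if error_rate <= error_rate_limit else "degraded",
    }
    status["overall"] = "ok" if all(v == "ok" for v in status.values()) else "degraded"
    return status


sample_latencies = [100.0 * i for i in range(1, 11)]  # 100 ms .. 1000 ms
health = classify_health(sample_latencies, errors=1, total_requests=10)
print(health)
```

In `RAGMonitor`, the inputs could come from `rag.query_history` and `rag.error_count`, and each "degraded" status could trigger `generate_alert`.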

## Part 3: Advanced RAG Patterns

This part covers techniques that refine the basic pipeline, starting with query routing and streaming responses.

### 3.1 Query Routing

Classify each query by type, then choose retrieval parameters (`top_k`, `min_similarity`, `temperature`) to match.

In [None]:
class QueryRouter:
    """
    Route queries to appropriate RAG strategies.
    """
    
    def __init__(self):
        """Initialize query router."""
        self.query_patterns = {
            'factual': re.compile(r'\b(what|who|when|where|which)\b', re.IGNORECASE),
            'comparative': re.compile(r'\b(compare|difference|versus|vs|better)\b', re.IGNORECASE),
            'procedural': re.compile(r'\b(how|steps|process|procedure)\b', re.IGNORECASE),
            'analytical': re.compile(r'\b(why|explain|analyze|reason)\b', re.IGNORECASE)
        }
        
        logger.info("Initialized query router")
    
    def classify_query(self, query: str) -> str:
        """
        Classify query type.
        
        Args:
            query: User query
        
        Returns:
            Query type
        """
        for query_type, pattern in self.query_patterns.items():
            if pattern.search(query):
                return query_type
        
        return 'general'
    
    def get_strategy(self, query: str) -> Dict[str, Any]:
        """
        Get retrieval strategy for query.
        
        Args:
            query: User query
        
        Returns:
            Strategy parameters
        """
        query_type = self.classify_query(query)
        
        strategies = {
            'factual': {
                'top_k': 2,
                'min_similarity': 0.7,
                'temperature': 0.3,
                'description': 'Precise, focused retrieval'
            },
            'comparative': {
                'top_k': 5,
                'min_similarity': 0.5,
                'temperature': 0.5,
                'description': 'Broader retrieval for comparison'
            },
            'procedural': {
                'top_k': 4,
                'min_similarity': 0.6,
                'temperature': 0.4,
                'description': 'Sequential, step-by-step context'
            },
            'analytical': {
                'top_k': 5,
                'min_similarity': 0.5,
                'temperature': 0.7,
                'description': 'Diverse context for analysis'
            },
            'general': {
                'top_k': 3,
                'min_similarity': 0.6,
                'temperature': 0.5,
                'description': 'Balanced retrieval'
            }
        }
        
        return {
            'query_type': query_type,
            **strategies.get(query_type, strategies['general'])
        }

# Test query router
print("\n" + "="*80)
print("QUERY ROUTING")
print("="*80)

router = QueryRouter()

test_queries = [
    "What is machine learning?",
    "Compare Python and Java",
    "How do I build a web app?",
    "Why is caching important?",
    "Tell me about databases"
]

for query in test_queries:
    strategy = router.get_strategy(query)
    print(f"\nQuery: {query}")
    print(f"  Type: {strategy['query_type']}")
    print(f"  Strategy: {strategy['description']}")
    print(f"  Parameters: top_k={strategy['top_k']}, temp={strategy['temperature']}")
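
Because `classify_query` returns the first pattern that matches, a query containing keywords from several categories is labeled by whichever pattern is checked first. The following is a minimal sketch of one alternative, assuming that counting keyword hits per category is acceptable for your data. The names `classify_by_score` and `PRIORITY` are hypothetical and not part of the lab code, and the tie-break order is an assumption.

```python
import re
from typing import Dict, List

# Same keyword patterns as QueryRouter above.
PATTERNS: Dict[str, re.Pattern] = {
    'factual': re.compile(r'\b(what|who|when|where|which)\b', re.IGNORECASE),
    'comparative': re.compile(r'\b(compare|difference|versus|vs|better)\b', re.IGNORECASE),
    'procedural': re.compile(r'\b(how|steps|process|procedure)\b', re.IGNORECASE),
    'analytical': re.compile(r'\b(why|explain|analyze|reason)\b', re.IGNORECASE),
}

# Assumed tie-break order: more specific intents win over 'factual'.
PRIORITY: List[str] = ['comparative', 'procedural', 'analytical', 'factual']


def classify_by_score(query: str) -> str:
    """Pick the category with the most keyword hits; ties follow PRIORITY."""
    scores = {name: len(pattern.findall(query)) for name, pattern in PATTERNS.items()}
    best = max(scores.values())
    if best == 0:
        return 'general'  # no keyword matched at all
    for name in PRIORITY:
        if scores[name] == best:
            return name
    return 'general'


print(classify_by_score("What is the difference between lists and tuples?"))
print(classify_by_score("What is machine learning?"))
```

Keyword counting is still a heuristic. For ambiguous traffic, an LLM-based or embedding-based classifier may work better, at the cost of an extra call per query.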

### 3.2 Streaming Responses

Stream tokens to the client as they are generated so users see output before the full answer is ready.

In [None]:
class StreamingRAG(EnterpriseRAG):
    """
    RAG system with streaming responses.
    """
    
    def query_stream(
        self,
        request: QueryRequest
    ) -> Iterator[Dict[str, Any]]:
        """
        Process query with streaming response.
        
        Args:
            request: Query request
        
        Yields:
            Response chunks
        """
        start_time = time.time()
        
        try:
            # Yield status: retrieving
            yield {
                'type': 'status',
                'message': 'Retrieving relevant information...'
            }
            
            # Retrieve
            retrieved = self.retrieve(
                request.query,
                request.top_k,
                request.min_similarity,
                request.filters
            )
            
            # Yield sources
            yield {
                'type': 'sources',
                'count': len(retrieved),
                'sources': [
                    {
                        'document_id': chunk['document_id'],
                        'score': float(score)
                    }
                    for chunk, score in retrieved
                ]
            }
            
            if not retrieved:
                yield {
                    'type': 'answer',
                    'content': "I don't have any relevant information to answer that question.",
                    'done': True
                }
                return
            
            # Yield status: generating
            yield {
                'type': 'status',
                'message': 'Generating answer...'
            }
            
            # Prepare context
            chunks = [chunk for chunk, _ in retrieved]
            context = "\n\n".join([
                f"[Source {i+1}]: {chunk['content']}"
                for i, chunk in enumerate(chunks)
            ])
            
            prompt = f"""Answer the question based on the context provided.

Context:
{context}

Question: {request.query}

Answer:"""
            
            # Stream generation
            # Rough token estimate: prompt word count plus the max_tokens budget below
            self.rate_limiter.wait_if_needed(estimated_tokens=len(prompt.split()) + 500)
            
            response = client.chat.completions.create(
                model=self.llm_model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.7,
                max_tokens=500,
                stream=True
            )
            
            full_answer = ""
            for chunk in response:
                # Some stream chunks carry no choices (e.g. a trailing usage chunk)
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    full_answer += content
                    yield {
                        'type': 'answer_chunk',
                        'content': content
                    }
            
            # Yield completion
            latency = (time.time() - start_time) * 1000
            yield {
                'type': 'complete',
                'latency_ms': latency,
                'full_answer': full_answer
            }
            
        except Exception as e:
            logger.error(f"Error in streaming query: {e}")
            yield {
                'type': 'error',
                'message': str(e)
            }

# Test streaming
print("\n" + "="*80)
print("STREAMING RESPONSES")
print("="*80)

streaming_rag = StreamingRAG()

# Ingest documents
for doc in documents:
    streaming_rag.ingest_document(doc['text'], doc['id'], doc['metadata'])

# Stream query
request = QueryRequest(query="What is Python?", top_k=2)

print(f"\nQuery: {request.query}\n")
print("Streaming response:")
print("-" * 80)

for event in streaming_rag.query_stream(request):
    if event['type'] == 'status':
        print(f"\n[{event['message']}]")
    elif event['type'] == 'sources':
        print(f"\n[Found {event['count']} relevant sources]")
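    elif event['type'] == 'answer':
        # Emitted instead of chunks when nothing relevant was retrieved
        print(f"\n{event['content']}")
    elif event['type'] == 'answer_chunk':
        print(event['content'], end='', flush=True)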
    elif event['type'] == 'complete':
        print(f"\n\n[Completed in {event['latency_ms']:.2f}ms]")
    elif event['type'] == 'error':
        print(f"\n[Error: {event['message']}]")

print("\n" + "-" * 80)
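
The event dictionaries yielded by `query_stream` map naturally onto a wire format for web clients. The sketch below formats them as Server-Sent Events frames. The helper `to_sse` is hypothetical, and the sample events are stand-ins with the same shape as the real ones, so no API call is made.

```python
import json
from typing import Any, Dict, Iterable, Iterator


def to_sse(events: Iterable[Dict[str, Any]]) -> Iterator[str]:
    """Format event dicts as Server-Sent Events frames."""
    for event in events:
        # The 'type' key becomes the SSE event name; the rest is the JSON payload.
        payload = {k: v for k, v in event.items() if k != 'type'}
        yield f"event: {event['type']}\ndata: {json.dumps(payload)}\n\n"
        if event['type'] in ('complete', 'error'):
            break  # nothing meaningful follows a terminal event


# Stand-in events with the same shape that query_stream yields.
sample_events = [
    {'type': 'status', 'message': 'Generating answer...'},
    {'type': 'answer_chunk', 'content': 'Python is'},
    {'type': 'complete', 'latency_ms': 12.5, 'full_answer': 'Python is'},
]

frames = list(to_sse(sample_events))
print("".join(frames))
```

In a web framework, the generator returned by `to_sse(streaming_rag.query_stream(request))` could be passed to a streaming response with the `text/event-stream` content type.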

### Exercise 3.1: Implement Advanced Retrieval

Add sophisticated retrieval techniques:

In [None]:
# TODO: Implement advanced retrieval

class AdvancedRetriever:
    """
    TODO: Advanced retrieval techniques.
    
    Implement:
    1. Maximal Marginal Relevance (MMR)
       - Balance relevance and diversity
    
    2. Reciprocal Rank Fusion (RRF)
       - Combine multiple retrieval strategies
    
    3. Parent-Child Chunking
       - Small chunks for retrieval
       - Large chunks for context
    
    4. Hypothetical Document Embeddings (HyDE)
       - Generate hypothetical answer
       - Use it for retrieval
    
    5. Multi-Query Retrieval
       - Generate multiple query variations
       - Combine results
    """
    
    def __init__(self, rag_system: EnterpriseRAG):
        """Initialize advanced retriever."""
        self.rag = rag_system
    
    def mmr_retrieve(
        self,
        query: str,
        top_k: int = 5,
        lambda_param: float = 0.5
    ) -> List[Tuple[Dict, float]]:
        """
        TODO: Maximal Marginal Relevance retrieval.
        
        MMR = λ * relevance - (1-λ) * redundancy
        """
        pass
    
    def hyde_retrieve(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Tuple[Dict, float]]:
        """
        TODO: HyDE retrieval.
        
        1. Generate hypothetical answer to query
        2. Embed the hypothetical answer
        3. Use it to retrieve similar chunks
        """
        pass
    
    def multi_query_retrieve(
        self,
        query: str,
        num_variants: int = 3,
        top_k: int = 5
    ) -> List[Tuple[Dict, float]]:
        """
        TODO: Multi-query retrieval.
        
        1. Generate query variations
        2. Retrieve with each variation
        3. Combine and deduplicate results
        """
        pass

# Test advanced retrieval
# retriever = AdvancedRetriever(rag)
# results = retriever.mmr_retrieve("What is Python?", lambda_param=0.7)
# results = retriever.hyde_retrieve("Explain machine learning")
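
As a starting point for the MMR part of the exercise, here is a minimal sketch on toy vectors. It is independent of `EnterpriseRAG`. The function name `mmr_select` is hypothetical, and redundancy is assumed to be the highest cosine similarity to any chunk already selected.

```python
import numpy as np
from typing import List


def mmr_select(
    query_vec: np.ndarray,
    doc_vecs: np.ndarray,
    top_k: int = 3,
    lambda_param: float = 0.5,
) -> List[int]:
    """Return indices of documents chosen by Maximal Marginal Relevance."""
    def unit(x: np.ndarray) -> np.ndarray:
        return x / np.linalg.norm(x, axis=-1, keepdims=True)

    docs = unit(doc_vecs.astype(float))
    relevance = docs @ unit(query_vec.astype(float))  # cosine similarity to the query
    pairwise = docs @ docs.T                          # cosine similarity between documents

    selected: List[int] = []
    candidates = list(range(len(docs)))
    while candidates and len(selected) < top_k:
        def score(i: int) -> float:
            # Redundancy: similarity to the closest already-selected document.
            redundancy = max((pairwise[i, j] for j in selected), default=0.0)
            return lambda_param * relevance[i] - (1 - lambda_param) * redundancy

        best = max(candidates, key=score)
        selected.append(best)
        candidates.remove(best)
    return selected


# Toy data: documents 0 and 1 are near-duplicates, document 2 is less relevant but different.
query = np.array([1.0, 0.0])
docs = np.array([[1.0, 0.0], [0.99, 0.05], [0.6, 0.8]])

print(mmr_select(query, docs, top_k=2, lambda_param=0.3))  # favors diversity
print(mmr_select(query, docs, top_k=2, lambda_param=1.0))  # pure relevance
```

With `lambda_param=1.0` the selection reduces to plain similarity ranking. Lower values penalize chunks that repeat what has already been selected.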

## Part 4: Security and Access Control

Enterprise deployments must control who can ingest documents and which documents each user's queries may draw on.

In [None]:
class AccessControl:
    """
    Access control for RAG system.
    """
    
    def __init__(self):
        """Initialize access control."""
        # User permissions
        self.user_roles: Dict[str, str] = {}
        self.role_permissions: Dict[str, List[str]] = {
            'admin': ['read', 'write', 'delete', 'manage_users'],
            'editor': ['read', 'write'],
            'viewer': ['read']
        }
        
        # Document access control
        self.document_acl: Dict[str, List[str]] = {}  # doc_id -> allowed_users
        
        logger.info("Initialized access control")
    
    def add_user(self, user_id: str, role: str = 'viewer'):
        """Add user with role."""
        self.user_roles[user_id] = role
        logger.info(f"Added user {user_id} with role {role}")
    
    def set_document_access(self, document_id: str, allowed_users: List[str]):
        """Set document access control."""
        self.document_acl[document_id] = allowed_users
        logger.info(f"Set access for {document_id}: {len(allowed_users)} users")
    
    def can_access(self, user_id: str, document_id: str) -> bool:
        """
        Check if user can access document.
        
        Args:
            user_id: User identifier
            document_id: Document identifier
        
        Returns:
            True if allowed
        """
        # Admins can access everything
        if self.user_roles.get(user_id) == 'admin':
            return True
        
        # Check document ACL
        if document_id in self.document_acl:
            return user_id in self.document_acl[document_id]
        
        # Default: no access
        return False
    
    def has_permission(self, user_id: str, permission: str) -> bool:
        """
        Check if user has permission.
        
        Args:
            user_id: User identifier
            permission: Permission to check
        
        Returns:
            True if allowed
        """
        role = self.user_roles.get(user_id, 'viewer')
        return permission in self.role_permissions.get(role, [])

class SecureRAG(EnterpriseRAG):
    """
    RAG system with access control.
    """
    
    def __init__(self, *args, **kwargs):
        """Initialize secure RAG."""
        super().__init__(*args, **kwargs)
        self.access_control = AccessControl()
    
    def ingest_document(
        self,
        text: str,
        document_id: str,
        metadata: Optional[Dict] = None,
        allowed_users: Optional[List[str]] = None,
        user_id: Optional[str] = None
    ):
        """
        Ingest document with access control.
        
        Args:
            text: Document text
            document_id: Document identifier
            metadata: Optional metadata
            allowed_users: List of users who can access
            user_id: User performing the action
        """
        # Check permission
        if user_id and not self.access_control.has_permission(user_id, 'write'):
            raise PermissionError(f"User {user_id} does not have write permission")
        
        # Ingest document
        super().ingest_document(text, document_id, metadata)
        
        # Set access control (an explicit empty list means "no users")
        if allowed_users is not None:
            self.access_control.set_document_access(document_id, allowed_users)
    
    def query(self, request: QueryRequest) -> QueryResponse:
        """
        Query with access control.
        
        Args:
            request: Query request with user_id
        
        Returns:
            Query response (filtered by access)
        """
        # Get base response
        response = super().query(request)
        
        # Filter sources by access. A request without a user_id is treated
        # as having no access instead of bypassing the check.
        filtered_sources = [
            source for source in response.sources
            if request.user_id
            and self.access_control.can_access(request.user_id, source['document_id'])
        ]
        denied = len(response.sources) - len(filtered_sources)
        
        response.sources = filtered_sources
        response.num_chunks_retrieved = len(filtered_sources)
        
        # The base answer was generated from every retrieved chunk, so it may
        # quote restricted text. Withhold it if any source was denied.
        if denied:
            response.answer = "You don't have access to information needed to answer this question."
            response.confidence = 0.0
        
        return response

# Test secure RAG
print("\n" + "="*80)
print("SECURE RAG WITH ACCESS CONTROL")
print("="*80)

secure_rag = SecureRAG()

# Add users
secure_rag.access_control.add_user("alice", "admin")
secure_rag.access_control.add_user("bob", "editor")
secure_rag.access_control.add_user("charlie", "viewer")

# Ingest documents with access control
secure_rag.ingest_document(
    "This is confidential company data.",
    "confidential_doc",
    metadata={"classification": "confidential"},
    allowed_users=["alice"],
    user_id="alice"
)

secure_rag.ingest_document(
    "This is public information available to everyone.",
    "public_doc",
    metadata={"classification": "public"},
    allowed_users=["alice", "bob", "charlie"],
    user_id="alice"
)

# Test queries with different users
print("\nAlice (admin) query:")
request = QueryRequest(query="What is in the documents?", user_id="alice", top_k=5)
response = secure_rag.query(request)
print(f"  Sources accessible: {response.num_chunks_retrieved}")

print("\nCharlie (viewer) query:")
request = QueryRequest(query="What is in the documents?", user_id="charlie", top_k=5)
response = secure_rag.query(request)
print(f"  Sources accessible: {response.num_chunks_retrieved}")
print(f"  Answer: {response.answer[:100]}...")
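
Filtering `response.sources` after `super().query()` runs only after the answer has been generated, so restricted text may already have reached the LLM. A safer pattern is to drop inaccessible chunks before the prompt is built. The sketch below shows that step on toy data in the same `(chunk, score)` shape that `retrieve` returns. The helper `filter_retrieved` and its parameters are hypothetical, not part of the lab classes.

```python
from typing import Dict, List, Set, Tuple

Chunk = Dict[str, str]


def filter_retrieved(
    retrieved: List[Tuple[Chunk, float]],
    user_id: str,
    acl: Dict[str, Set[str]],
    admins: Set[str],
) -> List[Tuple[Chunk, float]]:
    """Keep only chunks whose document the user may read (default deny)."""
    if user_id in admins:
        return list(retrieved)
    return [
        (chunk, score)
        for chunk, score in retrieved
        if user_id in acl.get(chunk['document_id'], set())
    ]


# Toy retrieval results and ACL mirroring the demo above.
retrieved = [
    ({'document_id': 'confidential_doc', 'content': 'This is confidential company data.'}, 0.82),
    ({'document_id': 'public_doc', 'content': 'This is public information.'}, 0.75),
]
acl = {'confidential_doc': {'alice'}, 'public_doc': {'alice', 'bob', 'charlie'}}

visible = filter_retrieved(retrieved, 'charlie', acl, admins={'alice'})
# Only permitted chunks are used to build the prompt context.
context = "\n\n".join(chunk['content'] for chunk, _ in visible)
print(context)
```

In a vector database, the same effect is usually achieved with a metadata filter on the search itself, which also avoids wasting `top_k` slots on chunks the user cannot see.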

### Exercise 4.1: Add Audit Logging

Implement comprehensive audit logging:

In [None]:
# TODO: Implement audit logging

class AuditLogger:
    """
    TODO: Audit logging for RAG system.
    
    Should log:
    1. All queries (who, what, when)
    2. Document access attempts
    3. Document modifications
    4. Permission changes
    5. Authentication events
    6. Errors and exceptions
    
    Should support:
    - Log rotation
    - Log levels
    - Structured logging
    - Export to external systems
    """
    
    def __init__(self, log_file: str = "rag_audit.log"):
        """Initialize audit logger."""
        pass
    
    def log_query(self, user_id: str, query: str, results: int):
        """TODO: Log query event."""
        pass
    
    def log_access_denied(self, user_id: str, resource: str):
        """TODO: Log access denied event."""
        pass
    
    def log_document_change(self, user_id: str, document_id: str, action: str):
        """TODO: Log document change."""
        pass
    
    def get_user_activity(self, user_id: str) -> List[Dict]:
        """TODO: Get activity for a user."""
        pass

# Test audit logging
# auditor = AuditLogger()
# auditor.log_query("alice", "What is Python?", 3)
# activity = auditor.get_user_activity("alice")
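
One possible foundation for the exercise is the standard library's `logging.handlers.RotatingFileHandler` combined with one JSON object per line, which covers log rotation and structured logging. The sketch below is a starting point under those assumptions, not a complete `AuditLogger`. The helper names `make_audit_logger` and `log_event`, the field names, and the rotation sizes are all hypothetical.

```python
import json
import logging
import os
import tempfile
from datetime import datetime, timezone
from logging.handlers import RotatingFileHandler


def make_audit_logger(path: str) -> logging.Logger:
    """Create a logger that appends one JSON object per line, with rotation."""
    audit = logging.getLogger(f"rag.audit.{path}")
    audit.setLevel(logging.INFO)
    audit.propagate = False  # keep audit records out of the application log
    if not audit.handlers:
        handler = RotatingFileHandler(path, maxBytes=1_000_000, backupCount=5)
        handler.setFormatter(logging.Formatter('%(message)s'))
        audit.addHandler(handler)
    return audit


def log_event(audit: logging.Logger, event_type: str, user_id: str, **details) -> None:
    """Write a structured audit record (who, what, when)."""
    record = {
        'timestamp': datetime.now(timezone.utc).isoformat(),
        'event': event_type,
        'user_id': user_id,
        **details,
    }
    audit.info(json.dumps(record))


# Demo: write two events to a temporary file, then read one user's activity back.
log_path = os.path.join(tempfile.mkdtemp(), 'rag_audit.log')
audit = make_audit_logger(log_path)
log_event(audit, 'query', 'alice', query='What is Python?', results=3)
log_event(audit, 'access_denied', 'charlie', resource='confidential_doc')

for handler in audit.handlers:
    handler.flush()
with open(log_path) as f:
    events = [json.loads(line) for line in f]
alice_activity = [e for e in events if e['user_id'] == 'alice']
print(alice_activity)
```

Line-delimited JSON is easy to ship to external systems, since most log collectors can parse it without custom rules.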

## Challenge Projects

### Challenge 1: Multi-Tenant RAG

Build a RAG system that supports multiple tenants:

In [None]:
class MultiTenantRAG:
    """
    Multi-tenant RAG system.
    
    TODO: Implement:
    1. Tenant isolation
       - Separate data for each tenant
       - No cross-tenant access
    
    2. Resource quotas
       - Limit queries per tenant
       - Limit storage per tenant
    
    3. Custom configurations
       - Per-tenant models
       - Per-tenant parameters
    
    4. Cost tracking
       - Track usage per tenant
       - Billing/chargeback
    
    5. Tenant management
       - Create/delete tenants
       - Tenant statistics
    """
    
    def __init__(self):
        """Initialize multi-tenant RAG."""
        self.tenants: Dict[str, EnterpriseRAG] = {}
        self.quotas: Dict[str, Dict] = {}
    
    def create_tenant(
        self,
        tenant_id: str,
        quota_config: Optional[Dict] = None
    ):
        """TODO: Create new tenant."""
        pass
    
    def get_tenant_rag(self, tenant_id: str) -> EnterpriseRAG:
        """TODO: Get RAG system for tenant."""
        pass
    
    def query_with_tenant(
        self,
        tenant_id: str,
        request: QueryRequest
    ) -> QueryResponse:
        """TODO: Query with tenant isolation."""
        pass
    
    def get_tenant_usage(self, tenant_id: str) -> Dict:
        """TODO: Get usage statistics for tenant."""
        pass

# Usage:
# multi_tenant = MultiTenantRAG()
# multi_tenant.create_tenant("company_a", quota_config={"max_queries": 1000})
# multi_tenant.create_tenant("company_b", quota_config={"max_queries": 5000})
# 
# request = QueryRequest(query="What is AI?")
# response = multi_tenant.query_with_tenant("company_a", request)
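
The resource-quota requirement can be prototyped separately from retrieval. Below is a minimal sketch of per-tenant query accounting. The names `TenantQuota`, `QuotaRegistry`, and `QuotaExceeded` are hypothetical, and the sketch is single-threaded: a real implementation would need a lock or an atomic counter.

```python
from dataclasses import dataclass
from typing import Dict


class QuotaExceeded(Exception):
    """Raised when a tenant has used up its query allowance."""


@dataclass
class TenantQuota:
    max_queries: int
    used_queries: int = 0

    def consume(self) -> None:
        if self.used_queries >= self.max_queries:
            raise QuotaExceeded(f"limit of {self.max_queries} queries reached")
        self.used_queries += 1


class QuotaRegistry:
    """Tracks query usage per tenant; tenants never share counters."""

    def __init__(self) -> None:
        self._quotas: Dict[str, TenantQuota] = {}

    def create_tenant(self, tenant_id: str, max_queries: int) -> None:
        if tenant_id in self._quotas:
            raise ValueError(f"tenant {tenant_id!r} already exists")
        self._quotas[tenant_id] = TenantQuota(max_queries)

    def charge_query(self, tenant_id: str) -> None:
        # Unknown tenants raise KeyError, so no request runs without a quota.
        self._quotas[tenant_id].consume()

    def usage(self, tenant_id: str) -> Dict[str, int]:
        quota = self._quotas[tenant_id]
        return {'used': quota.used_queries, 'remaining': quota.max_queries - quota.used_queries}


registry = QuotaRegistry()
registry.create_tenant("company_a", max_queries=2)
registry.create_tenant("company_b", max_queries=5)

registry.charge_query("company_a")
registry.charge_query("company_a")
try:
    registry.charge_query("company_a")  # third query exceeds the limit of 2
    blocked = False
except QuotaExceeded:
    blocked = True

print(registry.usage("company_a"), registry.usage("company_b"), blocked)
```

Calling `charge_query` before `query_with_tenant` dispatches to the tenant's RAG instance keeps over-quota requests from reaching the LLM at all.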

### Challenge 2: Federated RAG

Build a RAG system that queries multiple distributed sources:

In [None]:
class FederatedRAG:
    """
    Federated RAG across distributed sources.
    
    TODO: Implement:
    1. Multiple data sources
       - Different databases
       - Different embedding models
       - Different formats
    
    2. Federated search
       - Query all sources in parallel
       - Merge and rank results
    
    3. Source prioritization
       - Weight by source quality
       - Prefer certain sources for certain queries
    
    4. Caching strategy
       - Cache results from remote sources
       - Invalidation policy
    
    5. Fault tolerance
       - Handle source failures gracefully
       - Partial results
    """
    
    def __init__(self):
        """Initialize federated RAG."""
        self.sources: Dict[str, EnterpriseRAG] = {}
        self.source_weights: Dict[str, float] = {}
    
    def register_source(
        self,
        source_id: str,
        rag_system: EnterpriseRAG,
        weight: float = 1.0
    ):
        """TODO: Register a data source."""
        pass
    
    def federated_query(
        self,
        request: QueryRequest,
        source_ids: Optional[List[str]] = None
    ) -> QueryResponse:
        """
        TODO: Query across federated sources.
        
        Should:
        - Query all sources in parallel
        - Merge results
        - Re-rank combined results
        - Handle partial failures
        """
        pass
    
    def health_check_sources(self) -> Dict[str, bool]:
        """TODO: Check health of all sources."""
        pass

# Usage:
# federated = FederatedRAG()
# federated.register_source("internal_docs", internal_rag, weight=1.0)
# federated.register_source("public_docs", public_rag, weight=0.8)
# federated.register_source("wiki", wiki_rag, weight=0.6)
# 
# request = QueryRequest(query="What is machine learning?")
# response = federated.federated_query(request)
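
The "query in parallel, merge, re-rank, tolerate failures" steps can be sketched without any real RAG backend. The example below fans a query out with `ThreadPoolExecutor` and merges the ranked lists with Reciprocal Rank Fusion. Multiplying each contribution by a source weight is an assumption added here to mirror `source_weights`, and is not part of standard RRF. The function `federated_search` and the toy sources are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, List


def federated_search(
    sources: Dict[str, Callable[[str], List[str]]],
    weights: Dict[str, float],
    query: str,
    k: int = 60,
) -> List[str]:
    """Query every source in parallel and fuse rankings with weighted RRF."""
    rankings: Dict[str, List[str]] = {}
    with ThreadPoolExecutor(max_workers=max(1, len(sources))) as pool:
        futures = {pool.submit(search, query): name for name, search in sources.items()}
        for future in as_completed(futures):
            try:
                rankings[futures[future]] = future.result()
            except Exception:
                continue  # a failed source contributes nothing (partial results)

    fused: Dict[str, float] = {}
    for name, ranked_ids in rankings.items():
        for rank, doc_id in enumerate(ranked_ids, start=1):
            # RRF term 1 / (k + rank), scaled by the source weight.
            fused[doc_id] = fused.get(doc_id, 0.0) + weights.get(name, 1.0) / (k + rank)
    # Highest fused score first; document id breaks ties deterministically.
    return sorted(fused, key=lambda d: (-fused[d], d))


def broken_source(query: str) -> List[str]:
    raise ConnectionError("source unavailable")


sources = {
    'internal_docs': lambda q: ['doc_a', 'doc_b', 'doc_c'],
    'public_docs': lambda q: ['doc_b', 'doc_d'],
    'wiki': broken_source,  # simulates an outage
}
weights = {'internal_docs': 1.0, 'public_docs': 0.8, 'wiki': 0.6}

merged = federated_search(sources, weights, "What is machine learning?")
print(merged)
```

Rank-based fusion avoids comparing raw similarity scores across sources, which matters when the sources use different embedding models.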

### Challenge 3: Self-Healing RAG

Build a RAG system that monitors and improves itself:

In [None]:
class SelfHealingRAG:
    """
    Self-healing RAG system.
    
    TODO: Implement:
    1. Automatic problem detection
       - High error rate
       - Poor answer quality
       - Slow performance
       - Low cache hit rate
    
    2. Automatic remediation
       - Adjust parameters
       - Refresh cache
       - Re-index documents
       - Switch models
    
    3. A/B testing
       - Test different configurations
       - Choose best performer
    
    4. Learning from feedback
       - Track user satisfaction
       - Improve based on feedback
    
    5. Predictive maintenance
       - Predict issues before they occur
       - Proactive optimization
    """
    
    def __init__(self, rag_system: EnterpriseRAG):
        """Initialize self-healing RAG."""
        self.rag = rag_system
        self.health_history: List[Dict] = []
        self.remediation_history: List[Dict] = []
    
    def monitor_health(self) -> Dict[str, Any]:
        """
        TODO: Monitor system health.
        
        Check:
        - Error rate
        - Latency
        - Cache performance
        - Answer quality
        """
        pass
    
    def detect_issues(self) -> List[str]:
        """TODO: Detect system issues."""
        pass
    
    def remediate(self, issue: str):
        """
        TODO: Automatically fix issue.
        
        Actions might include:
        - Clear cache
        - Adjust top_k
        - Change temperature
        - Re-index documents
        """
        pass
    
    def run_health_check_loop(self, interval_seconds: int = 60):
        """TODO: Continuous health monitoring."""
        pass

# Usage:
# self_healing = SelfHealingRAG(rag)
# self_healing.run_health_check_loop(interval_seconds=300)  # Check every 5 min
# 
# # System automatically detects and fixes issues
# issues = self_healing.detect_issues()
# for issue in issues:
#     self_healing.remediate(issue)
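
Problem detection can start from a rolling window of recent request outcomes compared against thresholds. The sketch below illustrates that idea only. The class `HealthWindow`, the issue labels, and every threshold value are assumptions chosen for the example.

```python
from collections import deque
from typing import Deque, Dict, List


class HealthWindow:
    """Rolling window of recent request outcomes."""

    def __init__(self, size: int = 100) -> None:
        self.samples: Deque[Dict] = deque(maxlen=size)

    def record(self, latency_ms: float, ok: bool, cache_hit: bool) -> None:
        self.samples.append({'latency_ms': latency_ms, 'ok': ok, 'cache_hit': cache_hit})

    def detect_issues(
        self,
        max_error_rate: float = 0.05,
        max_avg_latency_ms: float = 1000.0,
        min_cache_hit_rate: float = 0.5,
    ) -> List[str]:
        """Compare window averages against thresholds and name what is wrong."""
        if not self.samples:
            return []
        n = len(self.samples)
        error_rate = sum(not s['ok'] for s in self.samples) / n
        avg_latency = sum(s['latency_ms'] for s in self.samples) / n
        hit_rate = sum(s['cache_hit'] for s in self.samples) / n

        issues: List[str] = []
        if error_rate > max_error_rate:
            issues.append('high_error_rate')
        if avg_latency > max_avg_latency_ms:
            issues.append('slow_performance')
        if hit_rate < min_cache_hit_rate:
            issues.append('low_cache_hit_rate')
        return issues


window = HealthWindow(size=10)
for _ in range(8):
    window.record(latency_ms=300.0, ok=True, cache_hit=True)
for _ in range(2):
    window.record(latency_ms=4500.0, ok=False, cache_hit=False)

issues = window.detect_issues()
print(issues)
```

Each label returned by `detect_issues` can then be mapped to a remediation action, for example clearing the cache or lowering `top_k`.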

## Summary

In this lab, you've learned:

1. ✅ Production-ready RAG components
2. ✅ Embeddings caching for cost reduction
3. ✅ Rate limiting and concurrency control
4. ✅ Enterprise RAG implementation
5. ✅ Batch processing and parallelization
6. ✅ Advanced RAG patterns
7. ✅ Query routing and optimization
8. ✅ Streaming responses
9. ✅ Security and access control
10. ✅ Monitoring and metrics

### Key Takeaways

**Enterprise Requirements:**

**Performance:**
- Sub-second query latency (p95)
- High throughput (100+ queries/sec)
- Efficient caching strategies
- Parallel processing

**Reliability:**
- Error handling and recovery
- Rate limiting
- Circuit breakers
- Health monitoring

**Security:**
- Access control and authentication
- Audit logging
- Data isolation
- Encryption at rest/transit

**Scalability:**
- Horizontal scaling
- Load balancing
- Database sharding
- Caching layers

**Cost Optimization:**
- Embeddings caching (can reduce API calls by 70-90% when the same text is embedded repeatedly)
- Batch processing
- Resource pooling
- Usage tracking

### Best Practices

**Caching:**
- Cache embeddings aggressively
- Use memory + disk caching
- Set appropriate TTL
- Monitor hit rates (target >80%)
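
To make the TTL and hit-rate points concrete, here is a minimal in-memory sketch. The class `TTLCache` is hypothetical and separate from the `EmbeddingsCache` built earlier in the lab. The injectable `clock` parameter is an assumption that makes expiry testable without sleeping.

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple


class TTLCache:
    """In-memory cache with per-entry expiry and hit-rate tracking."""

    def __init__(self, ttl_seconds: float, clock: Callable[[], float] = time.monotonic) -> None:
        self.ttl = ttl_seconds
        self.clock = clock
        self._store: Dict[str, Tuple[float, Any]] = {}
        self.hits = 0
        self.misses = 0

    def get(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is not None and self.clock() - entry[0] < self.ttl:
            self.hits += 1
            return entry[1]
        self._store.pop(key, None)  # drop the expired entry, if any
        self.misses += 1
        return None

    def set(self, key: str, value: Any) -> None:
        self._store[key] = (self.clock(), value)

    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total else 0.0


# A fake clock lets the demo expire an entry instantly.
now = [0.0]
cache = TTLCache(ttl_seconds=60, clock=lambda: now[0])

cache.get('q1')               # miss: nothing stored yet
cache.set('q1', [0.1, 0.2])
cache.get('q1')               # hit
now[0] = 61.0
cache.get('q1')               # miss: entry is older than the TTL

print(cache.hits, cache.misses)
```

Exporting `hit_rate` as a metric makes it possible to check the >80% target above against real traffic.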

**Rate Limiting:**
- Stay within API limits
- Implement backoff strategies
- Queue requests during peaks
- Monitor capacity usage
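
A common backoff strategy is exponential delay with random jitter. The sketch below wraps any callable. The helper `call_with_backoff` and its defaults are hypothetical. It catches every `Exception` for brevity, whereas production code would normally retry only on rate-limit or transient errors.

```python
import random
import time
from typing import Callable, List, TypeVar

T = TypeVar('T')


def call_with_backoff(
    func: Callable[[], T],
    max_attempts: int = 5,
    base_delay: float = 0.5,
    max_delay: float = 8.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Retry func with exponentially growing, jittered delays."""
    for attempt in range(max_attempts):
        try:
            return func()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            # Full jitter: wait a random time up to the capped exponential delay.
            delay = random.uniform(0, min(max_delay, base_delay * 2 ** attempt))
            sleep(delay)
    raise RuntimeError("max_attempts must be at least 1")


# Demo with a function that fails twice, then succeeds.
attempts = {'count': 0}


def flaky() -> str:
    attempts['count'] += 1
    if attempts['count'] < 3:
        raise RuntimeError("rate limited")
    return "ok"


delays: List[float] = []
result = call_with_backoff(flaky, sleep=delays.append)  # record delays instead of sleeping
print(result, attempts['count'], len(delays))
```

Jitter spreads retries from many clients over time, so they do not all hit the API again at the same moment.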

**Concurrency:**
- Limit concurrent requests
- Use connection pools
- Implement request queuing
- Thread-safe operations

**Monitoring:**
- Track latency percentiles
- Monitor error rates
- Watch cache hit rates
- Collect user satisfaction scores

**Security:**
- Implement RBAC (Role-Based Access Control)
- Audit all operations
- Sanitize user inputs
- Validate permissions

### Production Checklist

**Before Deployment:**

- [ ] Load testing completed
- [ ] Error handling tested
- [ ] Monitoring configured
- [ ] Logging set up
- [ ] Security review done
- [ ] Backup strategy in place
- [ ] Rollback plan ready
- [ ] Documentation complete

**Ongoing Operations:**

- [ ] Monitor metrics daily
- [ ] Review error logs
- [ ] Optimize based on usage
- [ ] Update models regularly
- [ ] Test disaster recovery
- [ ] Track costs
- [ ] Collect user feedback
- [ ] Iterate and improve

### Performance Targets

**Latency:**
- p50: <500ms
- p95: <1000ms
- p99: <2000ms
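
Latency percentiles can be computed from recorded request times with `numpy.percentile`. The sketch below checks a sample against the three targets above. The function `latency_report` and the sample values are made up for illustration.

```python
import numpy as np
from typing import Dict, Sequence

# Targets from the list above, in milliseconds.
TARGETS_MS = {'p50': 500.0, 'p95': 1000.0, 'p99': 2000.0}


def latency_report(latencies_ms: Sequence[float]) -> Dict[str, Dict]:
    """Compare observed latency percentiles against the targets."""
    values = np.asarray(latencies_ms, dtype=float)
    report = {}
    for name, limit in TARGETS_MS.items():
        # 'p95' -> 95.0; numpy interpolates linearly between samples by default.
        observed = float(np.percentile(values, float(name[1:])))
        report[name] = {'observed_ms': observed, 'target_ms': limit, 'ok': observed < limit}
    return report


# Made-up sample: mostly fast requests with one slow outlier.
sample = [120, 180, 250, 300, 320, 400, 450, 480, 900, 2500]
report = latency_report(sample)
for name, row in report.items():
    print(name, row)
```

A single slow outlier barely moves the median but dominates p95 and p99 on small samples, which is why tail percentiles need far more than ten data points in practice.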

**Availability:**
- Uptime: >99.9%
- Error rate: <0.1%

**Quality:**
- Answer accuracy: >90%
- User satisfaction: >4.5/5

**Cost:**
- Cache hit rate: >80%
- Cost per query: <$0.01

### Common Production Issues

1. **Cache Misses**: Tune cache size and TTL
2. **Rate Limits**: Implement better queuing
3. **Slow Queries**: Optimize retrieval, reduce top_k
4. **High Costs**: Increase caching, batch operations
5. **Poor Quality**: Improve chunking, adjust prompts
6. **Concurrency Issues**: Add locks, use thread pools

### Next Steps

- Deploy RAG system to production
- Implement monitoring dashboards
- Set up alerting
- Conduct load testing
- Gather user feedback
- Iterate based on metrics
- Explore advanced frameworks:
  - **LangChain**: RAG orchestration
  - **LlamaIndex**: Advanced indexing
  - **Weaviate**: Vector database
  - **Pinecone**: Managed vector DB

**Congratulations!** You've completed the RAG Fundamentals week and built enterprise-grade RAG systems! 🎉
