# 12. Contextual RAG - Context-Augmented Retrieval

**Complexity:** ⭐⭐⭐

## Overview

**Contextual RAG**, based on Anthropic's Contextual Retrieval technique, improves retrieval quality by prepending to each text chunk a short description of its role within the larger document.

### The Problem

In standard RAG, document chunks lose their context:
- Chunks are embedded in isolation
- References ("this", "these methods", "the above") become ambiguous
- Topic boundaries aren't clear
- Retrieval may miss relevant chunks due to missing context

### The Solution

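Contextual RAG prepends a short, LLM-generated context to each chunk before embedding. Two pieces of information feed that context:
1. **Document-level summary**: What the overall document is about (generated once per document)
2. **Chunk-level description**: How this specific chunk relates to the document (generated per chunk, using the summary as input)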

```
Standard chunk: "The function takes two parameters..."
Contextual chunk: "[This section describes the authentication module's 
                   login function in the User Management API.] 
                   The function takes two parameters..."
```
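To make the effect concrete, here is a toy sketch that scores the two chunks above against a query using plain word overlap. Word overlap is only a stand-in for embedding similarity (an assumption made so the example runs without an API key), and the query text is invented for illustration.

```python
import re


def tokens(text: str) -> set[str]:
    """Lowercase a string and return its set of alphanumeric word tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def overlap(query: str, chunk: str) -> float:
    """Fraction of query tokens that also appear in the chunk."""
    q = tokens(query)
    return len(q & tokens(chunk)) / len(q) if q else 0.0


standard_chunk = "The function takes two parameters..."
contextual_chunk = (
    "[This section describes the authentication module's login function "
    "in the User Management API.] " + standard_chunk
)

# Hypothetical user query; "login" appears only in the prepended context.
query = "Which parameters does the login function take?"
standard_score = overlap(query, standard_chunk)
contextual_score = overlap(query, contextual_chunk)
print(f"standard:   {standard_score:.2f}")
print(f"contextual: {contextual_score:.2f}")
```

The contextual chunk scores higher because the prepended context supplies terms the bare chunk lacks. Real embedding models capture meaning rather than exact words, but the same intuition applies.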

### Pipeline

```
Documents → Split into chunks → For each chunk:
    1. Generate document summary (once per document)
    2. Generate chunk context (per chunk)
    3. Prepend context to chunk
    4. Embed contextual chunk
→ Store in vector database → Retrieve → Generate answer
```
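The steps above can be sketched in plain Python, with the two LLM calls passed in as ordinary functions. The names `build_contextual_chunks`, `fake_summarize`, and `fake_contextualize` are hypothetical, and the stand-in functions only mimic what an LLM would return; embedding and storage are left out.

```python
from typing import Callable


def build_contextual_chunks(
    chunks_by_doc: dict[str, list[str]],
    summarize: Callable[[str], str],
    contextualize: Callable[[str, str], str],
) -> list[str]:
    """Return one context-prefixed string per chunk, ready for embedding."""
    contextual = []
    for source, chunks in chunks_by_doc.items():
        summary = summarize("\n\n".join(chunks))  # step 1: once per document
        for chunk in chunks:
            context = contextualize(summary, chunk)  # step 2: once per chunk
            contextual.append(f"[Context: {context}]\n\n{chunk}")  # step 3
    return contextual  # step 4 (embedding) and storage happen downstream


# Stand-in callables; a real pipeline would call an LLM here.
def fake_summarize(text: str) -> str:
    return f"a document of {len(text.split())} words"


def fake_contextualize(summary: str, chunk: str) -> str:
    return f"Part of {summary}; begins with '{chunk.split()[0]}'"


result = build_contextual_chunks(
    {"guide.md": ["Install the package.", "Configure the client."]},
    fake_summarize,
    fake_contextualize,
)
print(result[0])
```

Passing the LLM calls in as parameters keeps the pipeline logic testable without network access.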

### When to Use

✅ **Good for:**
- Long documents with many sections
- Technical documentation with cross-references
- Documents where context matters (legal, medical, academic)
- Improving precision on ambiguous queries

❌ **Not ideal for:**
- Short, self-contained documents
- Pipelines that must index documents in real time (context generation adds indexing latency)
- Cost-sensitive applications (extra LLM calls)

### Trade-offs

**Pros:**
- ✅ Better retrieval precision
- ✅ Handles ambiguous references
- ✅ Maintains document structure awareness
- ✅ One-time preprocessing cost

**Cons:**
- ❌ Higher indexing cost (LLM calls for each chunk)
- ❌ Longer text to embed and store per chunk (context + content)
- ❌ More complex preprocessing pipeline
- ❌ Slower indexing time

---

## Implementation

Let's build Contextual RAG step by step.

## 1. Setup and Imports

In [3]:
import sys
import time
from pathlib import Path

# Add parent directory to path for imports
sys.path.append(str(Path("../..").resolve()))

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from shared.config import (
    verify_api_key,
    DEFAULT_MODEL,
    DEFAULT_TEMPERATURE,
    OPENAI_EMBEDDING_MODEL,
    VECTOR_STORE_DIR,
)
from shared.loaders import load_and_split
from shared.prompts import (
    DOCUMENT_SUMMARY_PROMPT,
    CONTEXTUAL_CHUNK_PROMPT,
    CONTEXTUAL_RAG_ANSWER_PROMPT,
    RAG_PROMPT_TEMPLATE,
)
from shared.utils import (
    format_docs,
    print_section_header,
    load_vector_store,
    save_vector_store,
)

# Verify API key
verify_api_key()

print("✓ All imports successful")
print(f"✓ Using model: {DEFAULT_MODEL}")
print(f"✓ Using embeddings: {OPENAI_EMBEDDING_MODEL}")

✓ OpenAI API Key: LOADED
  Preview: sk-proj...vIQA
✓ All imports successful
✓ Using model: gpt-4o-mini
✓ Using embeddings: text-embedding-3-small


## 2. Load and Prepare Documents

In [None]:
print_section_header("Loading Documents")

# Load and split documents (returns tuple: original_docs, chunks)
_, docs = load_and_split(
    chunk_size=1000,
    chunk_overlap=200,
)

print(f"\n✓ Loaded {len(docs)} chunks")
print(f"✓ Average chunk size: {sum(len(d.page_content) for d in docs) / len(docs):.0f} chars")

# Show example chunk
print("\n" + "=" * 80)
print("Example chunk (standard):")
print("=" * 80)
print(docs[min(5, len(docs) - 1)].page_content[:300] + "...")

## 3. Group Chunks by Document

We need to group chunks by their source document to generate document-level summaries.

In [None]:
from collections import defaultdict

print_section_header("Grouping Chunks by Document")

# Group chunks by source document
docs_by_source = defaultdict(list)
for doc in docs:
    source = doc.metadata.get("source", "unknown")
    docs_by_source[source].append(doc)

print(f"\n✓ Found {len(docs_by_source)} unique source documents")
print("\nChunks per source:")
for source, chunks in list(docs_by_source.items())[:5]:
    print(f"  • {source}: {len(chunks)} chunks")

## 4. Generate Document Summaries

For each source document, we'll generate a summary that captures its main purpose and topics.

In [None]:
print_section_header("Generating Document Summaries")

# Initialize LLM for summarization
llm = ChatOpenAI(
    model=DEFAULT_MODEL,
    temperature=DEFAULT_TEMPERATURE,
)

# Create summarization chain
summary_chain = DOCUMENT_SUMMARY_PROMPT | llm | StrOutputParser()

# Generate summaries for each source document
doc_summaries = {}
print("\nGenerating summaries...")

for i, (source, chunks) in enumerate(docs_by_source.items(), 1):
    # Combine first few chunks to represent the document
    # (using all chunks would be too long and expensive)
    doc_text = "\n\n".join([chunk.page_content for chunk in chunks[:3]])
    
    # Generate summary
    summary = summary_chain.invoke({"document": doc_text})
    doc_summaries[source] = summary
    
    print(f"  [{i}/{len(docs_by_source)}] {source[:50]}...")
    
    # Rate limiting to avoid API throttling
    if i < len(docs_by_source):
        time.sleep(0.5)

print(f"\n✓ Generated {len(doc_summaries)} document summaries")

# Show example summary
example_source = list(doc_summaries.keys())[0]
print("\n" + "=" * 80)
print(f"Example summary for: {example_source}")
print("=" * 80)
print(doc_summaries[example_source])
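`DOCUMENT_SUMMARY_PROMPT` and `CONTEXTUAL_CHUNK_PROMPT` come from `shared.prompts`, which this notebook does not show. The sketch below gives one plausible wording as plain Python strings. Only the variable names (`document`, `doc_summary`, `chunk`) are taken from the `invoke` calls in this notebook; the template text and the constant names are assumptions.

```python
# Hypothetical prompt wording; the real templates live in shared/prompts.py.
# Only the variable names (document, doc_summary, chunk) come from the notebook.
SUMMARY_TEMPLATE = (
    "Summarize the following document in 2-3 sentences, covering its "
    "main purpose and topics.\n\nDocument:\n{document}\n\nSummary:"
)

CHUNK_CONTEXT_TEMPLATE = (
    "Document summary:\n{doc_summary}\n\n"
    "Chunk:\n{chunk}\n\n"
    "In 1-2 sentences, describe how this chunk fits into the document. "
    "Return only the description."
)

# Fill the chunk template with example values to inspect the final prompt.
filled = CHUNK_CONTEXT_TEMPLATE.format(
    doc_summary="An API guide for a user management service.",
    chunk="The function takes two parameters...",
)
print(filled)
```

Asking for "only the description" keeps the generated context short, which matters because it is prepended to every chunk. In a LangChain pipeline, strings like these would typically be wrapped with `ChatPromptTemplate.from_template` before being piped into the model.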

## 5. Generate Contextual Chunks

Now we'll augment each chunk with contextual information.

In [None]:
print_section_header("Generating Contextual Chunks")

# Create contextualization chain
context_chain = CONTEXTUAL_CHUNK_PROMPT | llm | StrOutputParser()

# Generate contextual chunks
contextual_docs = []
print(f"\nProcessing {len(docs)} chunks...")
print("(This may take a few minutes)\n")

for i, doc in enumerate(docs, 1):
    source = doc.metadata.get("source", "unknown")
    doc_summary = doc_summaries.get(source, "No summary available")
    
    # Generate contextual description
    context = context_chain.invoke({
        "doc_summary": doc_summary,
        "chunk": doc.page_content[:500],  # Limit chunk size for context generation
    })
    
    # Create new document with context prepended
    contextual_content = f"[Context: {context}]\n\n{doc.page_content}"
    
    contextual_doc = Document(
        page_content=contextual_content,
        metadata={
            **doc.metadata,
            "context": context,
            "original_content": doc.page_content,
        },
    )
    contextual_docs.append(contextual_doc)
    
    # Progress indicator
    if i % 10 == 0:
        print(f"  Processed {i}/{len(docs)} chunks...")
    
    # Rate limiting
    if i < len(docs):
        time.sleep(0.3)

print(f"\n✓ Generated {len(contextual_docs)} contextual chunks")
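The loop above makes one request at a time. As a rough sketch of how the same work could run concurrently, the example below uses a thread pool with a stand-in for `context_chain.invoke`. The helper names `generate_contexts` and `fake_generate` are hypothetical, and the worker count is an arbitrary choice that should respect your API rate limits.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable


def generate_contexts(
    inputs: list[dict],
    generate: Callable[[dict], str],
    max_workers: int = 4,
) -> list[str]:
    """Run `generate` over all inputs concurrently, preserving input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(generate, inputs))


# Stand-in for context_chain.invoke; sleeps briefly to mimic network latency.
def fake_generate(payload: dict) -> str:
    time.sleep(0.05)
    return f"Context for: {payload['chunk'][:20]}"


inputs = [{"doc_summary": "demo", "chunk": f"chunk number {i}"} for i in range(8)]
contexts = generate_contexts(inputs, fake_generate)
print(contexts[0])
```

`pool.map` returns results in input order, so each context can be zipped back onto its chunk. LangChain runnables also expose a `batch` method that accepts a `max_concurrency` setting in its config, which may be the simpler route inside this notebook.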

## 6. Compare Standard vs Contextual Chunks

In [None]:
print_section_header("Comparing Standard vs Contextual Chunks")

# Show example comparison
example_idx = min(5, len(docs) - 1)  # Fall back to the last chunk in small corpora

print("\n" + "=" * 80)
print("STANDARD CHUNK:")
print("=" * 80)
print(docs[example_idx].page_content[:400] + "...")

print("\n" + "=" * 80)
print("CONTEXTUAL CHUNK:")
print("=" * 80)
print(contextual_docs[example_idx].page_content[:500] + "...")

print("\n" + "=" * 80)
print("STATISTICS:")
print("=" * 80)
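avg_context_len = sum(
    len(doc.metadata.get("context", "")) for doc in contextual_docs
) / len(contextual_docs)
# Measure overhead against the actual average chunk length, not the 1000-char target
avg_chunk_len = sum(len(doc.page_content) for doc in docs) / len(docs)
print(f"Average context length: {avg_context_len:.0f} characters")
print(f"Context overhead: {avg_context_len / avg_chunk_len * 100:.1f}% (of the average {avg_chunk_len:.0f}-char chunk)")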

## 7. Create Vector Stores

We'll create two vector stores to compare:
1. Standard RAG (no context)
2. Contextual RAG (with context)

In [None]:
from langchain_community.vectorstores import FAISS

print_section_header("Creating Vector Stores")

# Initialize embeddings
embeddings = OpenAIEmbeddings(model=OPENAI_EMBEDDING_MODEL)

# Try to load existing vector stores
standard_store_path = VECTOR_STORE_DIR / "contextual_rag_standard"
contextual_store_path = VECTOR_STORE_DIR / "contextual_rag_contextual"

print("\nChecking for existing vector stores...")

# Standard vector store
vectorstore_standard = load_vector_store(
    standard_store_path,
    embeddings,
)

if vectorstore_standard is None:
    print("\nCreating standard vector store...")
    vectorstore_standard = FAISS.from_documents(docs, embeddings)
    save_vector_store(vectorstore_standard, standard_store_path)
    print("✓ Standard vector store created and saved")
else:
    print("✓ Loaded existing standard vector store")

# Contextual vector store
vectorstore_contextual = load_vector_store(
    contextual_store_path,
    embeddings,
)

if vectorstore_contextual is None:
    print("\nCreating contextual vector store...")
    vectorstore_contextual = FAISS.from_documents(contextual_docs, embeddings)
    save_vector_store(vectorstore_contextual, contextual_store_path)
    print("✓ Contextual vector store created and saved")
else:
    print("✓ Loaded existing contextual vector store")

print("\n✓ Both vector stores ready")

## 8. Build RAG Chains

In [None]:
print_section_header("Building RAG Chains")

# Create retrievers
retriever_standard = vectorstore_standard.as_retriever(
    search_kwargs={"k": 4}
)
retriever_contextual = vectorstore_contextual.as_retriever(
    search_kwargs={"k": 4}
)

# Standard RAG chain
chain_standard = (
    {"context": retriever_standard | format_docs, "input": RunnablePassthrough()}
    | RAG_PROMPT_TEMPLATE
    | llm
    | StrOutputParser()
)

# Contextual RAG chain
chain_contextual = (
    {"context": retriever_contextual | format_docs, "input": RunnablePassthrough()}
    | CONTEXTUAL_RAG_ANSWER_PROMPT
    | llm
    | StrOutputParser()
)

print("✓ Standard RAG chain created")
print("✓ Contextual RAG chain created")

## 9. Test and Compare

Let's test both approaches with queries that benefit from context.

In [None]:
print_section_header("Testing Queries")

# Test queries that benefit from context
test_queries = [
    "How do I use LCEL to build chains?",
    "What are the different types of memory in LangChain?",
    "Explain the role of retrievers in RAG applications",
]

for i, query in enumerate(test_queries, 1):
    print("\n" + "=" * 80)
    print(f"Query {i}: {query}")
    print("=" * 80)
    
    # Standard RAG
    print("\n[STANDARD RAG]")
    print("-" * 80)
    start_time = time.time()
    response_standard = chain_standard.invoke(query)
    time_standard = time.time() - start_time
    print(response_standard)
    print(f"\n⏱️  Time: {time_standard:.2f}s")
    
    # Contextual RAG
    print("\n[CONTEXTUAL RAG]")
    print("-" * 80)
    start_time = time.time()
    response_contextual = chain_contextual.invoke(query)
    time_contextual = time.time() - start_time
    print(response_contextual)
    print(f"\n⏱️  Time: {time_contextual:.2f}s")
    
    # Comparison
    print("\n" + "-" * 80)
    print("COMPARISON:")
    print(f"  • Latency difference: {abs(time_contextual - time_standard):.2f}s")
    print(f"  • Response length difference: {len(response_contextual) - len(response_standard)} chars")

## 10. Retrieval Quality Comparison

Let's examine what documents each approach retrieves.

In [None]:
print_section_header("Retrieval Quality Analysis")

test_query = "What are the different types of memory in LangChain?"

print(f"\nQuery: {test_query}")
print("\n" + "=" * 80)

# Standard retrieval
docs_standard = retriever_standard.invoke(test_query)
print("\n[STANDARD RETRIEVAL]")
print("-" * 80)
for i, doc in enumerate(docs_standard, 1):
    print(f"\nDocument {i}:")
    print(f"Source: {doc.metadata.get('source', 'unknown')[:60]}")
    print(f"Preview: {doc.page_content[:200]}...")

# Contextual retrieval
docs_contextual = retriever_contextual.invoke(test_query)
print("\n" + "=" * 80)
print("\n[CONTEXTUAL RETRIEVAL]")
print("-" * 80)
for i, doc in enumerate(docs_contextual, 1):
    print(f"\nDocument {i}:")
    print(f"Source: {doc.metadata.get('source', 'unknown')[:60]}")
    print(f"Context: {doc.metadata.get('context', 'N/A')[:150]}...")
    print(f"Preview: {doc.metadata.get('original_content', doc.page_content)[:200]}...")
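Reading the two printouts side by side gets tedious with more than a few results. A small helper can report which chunks both retrievers returned and which are unique to each. This is a sketch with a hypothetical function name, and it compares chunks by their original text.

```python
def retrieval_overlap(
    standard: list[str], contextual: list[str]
) -> dict[str, list[str]]:
    """Split two ranked result lists into shared and retriever-specific chunks."""
    standard_set, contextual_set = set(standard), set(contextual)
    return {
        "shared": [t for t in standard if t in contextual_set],
        "only_standard": [t for t in standard if t not in contextual_set],
        "only_contextual": [t for t in contextual if t not in standard_set],
    }


# Toy data standing in for retrieved chunk texts
report = retrieval_overlap(
    ["chunk A", "chunk B", "chunk C"],
    ["chunk B", "chunk D", "chunk A"],
)
print(report)
```

With the variables from this notebook, the inputs would be `[d.page_content for d in docs_standard]` and `[d.metadata.get("original_content", d.page_content) for d in docs_contextual]`, so that the prepended context does not make identical chunks look different.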

## 11. Performance Metrics

In [None]:
print_section_header("Performance Metrics")

# Indexing costs
num_chunks = len(docs)  # `docs` holds chunks, not whole documents
num_summaries = len(doc_summaries)
num_context_calls = len(contextual_docs)

print("\nINDEXING COSTS:")
print("-" * 80)
print(f"Chunks processed: {num_chunks}")
print(f"Document summaries generated: {num_summaries}")
print(f"Chunk contexts generated: {num_context_calls}")
print(f"Total LLM calls for contextualization: {num_summaries + num_context_calls}")
# Lower bound: counts only the rate-limit sleeps (0.5s per summary, 0.3s per chunk), not LLM latency
sleep_seconds = num_summaries * 0.5 + num_context_calls * 0.3
print(f"\nMinimum additional indexing time (sleeps only): ~{sleep_seconds / 60:.1f} minutes")

# Storage costs
avg_standard_len = sum(len(doc.page_content) for doc in docs) / len(docs)
avg_contextual_len = sum(len(doc.page_content) for doc in contextual_docs) / len(contextual_docs)
overhead = (avg_contextual_len - avg_standard_len) / avg_standard_len * 100

print("\nSTORAGE OVERHEAD:")
print("-" * 80)
print(f"Average standard chunk: {avg_standard_len:.0f} chars")
print(f"Average contextual chunk: {avg_contextual_len:.0f} chars")
print(f"Overhead: {overhead:.1f}%")

# Query costs
print("\nQUERY COSTS:")
print("-" * 80)
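print("Standard RAG: 1 query embedding + 1 top-k search + 1 generation")
print("Contextual RAG: 1 query embedding + 1 top-k search + 1 generation (same calls as standard)")
print("\n✓ No additional query-time LLM calls (retrieved chunks are slightly longer, so prompts use a few more tokens)")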

## 12. Key Takeaways

### Summary

**Contextual RAG** improves retrieval quality by augmenting chunks with contextual information:
- Document-level summaries provide high-level context
- Chunk-level descriptions clarify the role of each chunk
- Better handling of ambiguous references and cross-references

### Cost-Benefit Analysis

| Aspect | Impact | Notes |
|--------|--------|-------|
| **Indexing Time** | ❌ +50-100% | One-time cost |
| **Indexing Cost** | ❌ Higher | One LLM call per chunk, plus one per document summary |
| **Storage** | ❌ +15-30% | Longer stored text; embedding dimension is unchanged |
| **Query Time** | ✅ Same | No runtime overhead |
| **Query Cost** | ✅ Same | No additional calls |
| **Retrieval Quality** | ✅ Better | Improved precision |
| **Answer Quality** | ✅ Better | More context-aware |
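The indexing cost depends on your corpus and model pricing, so it is worth estimating before a large run. The sketch below assumes every call uses roughly the same number of tokens, which is a simplification. The function name, token counts, and prices are placeholders for illustration, not real rates.

```python
def estimate_indexing_cost(
    num_chunks: int,
    num_documents: int,
    input_tokens_per_call: int,
    output_tokens_per_call: int,
    price_per_million_input: float,
    price_per_million_output: float,
) -> float:
    """Estimate the extra LLM spend (in dollars) for contextualization."""
    calls = num_chunks + num_documents  # one context per chunk, one summary per document
    input_cost = calls * input_tokens_per_call * price_per_million_input / 1_000_000
    output_cost = calls * output_tokens_per_call * price_per_million_output / 1_000_000
    return input_cost + output_cost


# Placeholder numbers for illustration only; substitute your provider's current rates.
cost = estimate_indexing_cost(
    num_chunks=1_000,
    num_documents=20,
    input_tokens_per_call=400,
    output_tokens_per_call=60,
    price_per_million_input=1.00,
    price_per_million_output=4.00,
)
print(f"Estimated extra indexing cost: ${cost:.2f}")
```

Because the cost scales linearly with the number of chunks, larger chunk sizes reduce it, at the price of coarser retrieval.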

### Best Practices

1. **Use for long documents**: Most beneficial when documents are long and complex
2. **Batch processing**: Generate contexts in batches to reduce costs
3. **Cache summaries**: Store document summaries to avoid regeneration
4. **Balance context length**: Keep contexts concise (1-2 sentences)
5. **Quality check**: Manually review a sample of generated contexts
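For the caching practice, one simple approach is to key each summary by a hash of the source name and text, and store it as a small JSON file. This is a minimal sketch: the function name `cached_summary`, the file layout, and the stand-in summarizer are all assumptions, not part of the shared utilities used in this notebook.

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import Callable


def cached_summary(
    source: str,
    text: str,
    summarize: Callable[[str], str],
    cache_dir: Path,
) -> str:
    """Return a cached summary if the text is unchanged, else generate and store one."""
    key = hashlib.sha256(f"{source}\n{text}".encode("utf-8")).hexdigest()
    cache_file = cache_dir / f"{key}.json"
    if cache_file.exists():
        return json.loads(cache_file.read_text(encoding="utf-8"))["summary"]
    summary = summarize(text)
    payload = {"source": source, "summary": summary}
    cache_file.write_text(json.dumps(payload), encoding="utf-8")
    return summary


calls = []  # records each time the (fake) LLM is actually invoked


def fake_summarize(text: str) -> str:
    calls.append(text)
    return f"Summary of {len(text)} characters"


with tempfile.TemporaryDirectory() as tmp:
    first = cached_summary("guide.md", "some document text", fake_summarize, Path(tmp))
    second = cached_summary("guide.md", "some document text", fake_summarize, Path(tmp))

print(first, "| LLM calls:", len(calls))
```

Hashing the text means an edited document gets a fresh summary automatically, while unchanged documents never trigger a second LLM call.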

### When to Use

Choose **Contextual RAG** when:
- ✅ Retrieval quality matters more than cost
- ✅ Dealing with technical or complex documents
- ✅ Users ask ambiguous or context-dependent questions
- ✅ One-time indexing cost is acceptable

Stick with **Standard RAG** when:
- ✅ Documents are short and self-contained
- ✅ Cost optimization is critical
- ✅ Real-time indexing is required
- ✅ Simpler implementation is preferred

### Next Steps

- **Combine with other techniques**: Contextual chunks work well with re-ranking
- **Experiment with context generation**: Try different prompt strategies
- **Measure impact**: Use RAGAS metrics to quantify improvement
- **Optimize costs**: Use cheaper models for context generation
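Before reaching for a full evaluation framework, a hand-rolled hit rate can give a first signal. Note that this is a simpler metric than RAGAS, not a substitute for it. The sketch assumes you have a small set of queries labeled with the ID of the chunk that should be retrieved; the function name, IDs, and toy retrievers are hypothetical.

```python
from typing import Callable


def hit_rate_at_k(
    retrieve: Callable[[str], list[str]],
    labeled_queries: list[tuple[str, str]],
    k: int = 4,
) -> float:
    """Fraction of queries whose expected chunk ID appears in the top-k results."""
    if not labeled_queries:
        return 0.0
    hits = sum(
        expected_id in retrieve(query)[:k]
        for query, expected_id in labeled_queries
    )
    return hits / len(labeled_queries)


# Toy retrievers that return ranked chunk IDs for each query
standard_results = {"q1": ["c3", "c1"], "q2": ["c9", "c8"]}
contextual_results = {"q1": ["c1", "c3"], "q2": ["c2", "c9"]}
labels = [("q1", "c1"), ("q2", "c2")]

standard_rate = hit_rate_at_k(lambda q: standard_results[q], labels, k=2)
contextual_rate = hit_rate_at_k(lambda q: contextual_results[q], labels, k=2)
print(f"standard: {standard_rate:.2f}, contextual: {contextual_rate:.2f}")
```

Running the same labeled queries against both retrievers turns "better retrieval" into a number you can track as prompts or chunk sizes change.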

---

**Complexity Rating:** ⭐⭐⭐ (Medium - straightforward concept, some implementation overhead)

**Production Readiness:** ⭐⭐⭐⭐ (High - proven technique, minor trade-offs)

Continue to **13_fusion_rag.ipynb** for RAG-Fusion with Reciprocal Rank Fusion!