# Demo #7: Corrective RAG (CRAG) - Self-Correcting Retrieval

## Overview

This notebook demonstrates **Corrective RAG (CRAG)**, a self-reflective system that evaluates retrieval quality and triggers corrective actions when internal knowledge is insufficient.

### Key Concepts

1. **Self-Correction**: Automatically evaluate whether retrieved documents are relevant
2. **Dynamic Routing**: Route queries based on confidence scores
3. **Fallback Mechanisms**: Use external knowledge sources (web search) when internal knowledge fails
4. **Knowledge Refinement**: Filter retrieved documents at the sentence level, keeping only sentences relevant to the query

### The CRAG Architecture

```
Query → Retrieve from Internal KB → Evaluate Relevance → Route:
├─ High Confidence (≥0.7): Use internal documents directly
├─ Low Confidence (<0.4): Discard internal results, use web search
└─ Ambiguous (0.4 ≤ score < 0.7): Merge internal + web search results
```
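The routing rule above can be written as a small pure function. This is a minimal sketch that assumes the same thresholds as the `CRAGQueryEngine` defined later (0.7 and 0.4). The `Route` enum and the `route_by_confidence` name are illustrative and are not part of LlamaIndex or the CRAG paper's code. The sketch makes the boundary cases explicit: a score of exactly 0.7 counts as high confidence, and a score of exactly 0.4 counts as ambiguous.

```python
from enum import Enum


class Route(str, Enum):
    """The three CRAG actions used in this notebook (illustrative names)."""
    HIGH_CONFIDENCE = "HIGH_CONFIDENCE"
    AMBIGUOUS = "AMBIGUOUS"
    LOW_CONFIDENCE = "LOW_CONFIDENCE"


def route_by_confidence(confidence: float, high: float = 0.7, low: float = 0.4) -> Route:
    """Map an evaluator confidence score to a route.

    A score equal to `high` is treated as high confidence, and a score equal
    to `low` is treated as ambiguous, matching the engine's if/elif order.
    """
    if not 0.0 <= low < high <= 1.0:
        raise ValueError("thresholds must satisfy 0 <= low < high <= 1")
    if confidence >= high:
        return Route.HIGH_CONFIDENCE
    if confidence < low:
        return Route.LOW_CONFIDENCE
    return Route.AMBIGUOUS


for score in (0.85, 0.7, 0.55, 0.4, 0.2):
    print(score, route_by_confidence(score).value)
```

Running the loop prints HIGH_CONFIDENCE for 0.85 and 0.7, AMBIGUOUS for 0.55 and 0.4, and LOW_CONFIDENCE for 0.2.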

### Citation

This implementation is based on:
- **Corrective Retrieval Augmented Generation** (arXiv:2401.15884)
- Reference #67 in the workshop curriculum

## 1. Setup and Imports

In [None]:
import os
import sys
from pathlib import Path
from typing import List, Dict, Tuple
import json

# LlamaIndex core imports
from llama_index.core import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    Settings,
    StorageContext,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.query_engine import RetrieverQueryEngine

# Azure OpenAI imports
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding

# Web search (DuckDuckGo)
from duckduckgo_search import DDGS

# Load environment variables
from dotenv import load_dotenv
load_dotenv()

print("✓ All imports successful")

## 2. Configure Azure OpenAI

In [None]:
# Azure OpenAI Configuration
api_key = os.getenv("AZURE_OPENAI_API_KEY")
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview")

# Initialize LLM
llm = AzureOpenAI(
    model="gpt-4",
    deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
    api_key=api_key,
    azure_endpoint=azure_endpoint,
    api_version=api_version,
    temperature=0.1,
)

# Initialize Embedding Model
embed_model = AzureOpenAIEmbedding(
    model="text-embedding-ada-002",
    deployment_name=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT"),
    api_key=api_key,
    azure_endpoint=azure_endpoint,
    api_version=api_version,
)

# Configure global settings
Settings.llm = llm
Settings.embed_model = embed_model
Settings.chunk_size = 512
Settings.chunk_overlap = 50

print("✓ Azure OpenAI configured successfully")

## 3. Create Limited Internal Knowledge Base

For this CRAG demonstration, we use a **small, focused knowledge base** of technical documents. Its narrow scope lets us test queries that fall both inside and outside the knowledge base.

In [None]:
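# Load documents from a subset of technical docs
# Using tech_docs directory which contains: BERT, GPT-4, Docker, REST API, etc.
data_path = Path("../RAG_v2/data/tech_docs")

if not data_path.exists():
    print(f"Warning: Data path {data_path} not found. Using sample documents.")
    # Fall back to a small local sample corpus so the demo still runs
    data_path = Path("./sample_data")
    data_path.mkdir(exist_ok=True)
    # SimpleDirectoryReader raises an error on an empty directory,
    # so write minimal placeholder documents if none exist yet
    if not any(data_path.iterdir()):
        (data_path / "bert.txt").write_text(
            "BERT (Bidirectional Encoder Representations from Transformers) is an "
            "encoder-only transformer model introduced by Google in 2018. It is "
            "pre-trained with masked language modeling."
        )
        (data_path / "docker.txt").write_text(
            "Docker packages applications and their dependencies into containers "
            "that run consistently across environments."
        )

# Load documents
documents = SimpleDirectoryReader(str(data_path)).load_data()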

print(f"✓ Loaded {len(documents)} documents from internal knowledge base")
print(f"\nDocument sources:")
for i, doc in enumerate(documents, 1):
    filename = Path(doc.metadata.get('file_name', 'unknown')).name
    preview = doc.text[:100].replace('\n', ' ')
    print(f"{i}. {filename}: {preview}...")

## 4. Build Internal Vector Index

In [None]:
# Parse documents into chunks
node_parser = SentenceSplitter(
    chunk_size=512,
    chunk_overlap=50,
)

nodes = node_parser.get_nodes_from_documents(documents)
print(f"✓ Created {len(nodes)} chunks from documents")

# Build vector index
internal_index = VectorStoreIndex(
    nodes=nodes,
    embed_model=embed_model,
)

print("✓ Vector index created successfully")

## 5. Implement Retrieval Evaluator

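The evaluator asks the LLM to judge how well the top three retrieved documents can answer the query. It returns a confidence score between 0 and 1 together with a short explanation; if the LLM's reply cannot be parsed, it falls back to 0.5, which sends the query down the ambiguous route.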
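LLM judges do not always return clean JSON, so the parsing step is worth isolating and testing on its own. Below is a minimal sketch of a standalone parser that assumes the same `confidence_score` and `explanation` keys as the prompt in the next cell. The function name `parse_evaluator_reply` and the regex-based fence stripping are illustrative choices, not part of any library. The regex uses a `` `{3} `` quantifier to match a triple-backtick fence, with an optional `json` tag.

```python
import json
import re
from typing import Tuple

# Matches a fenced block (three backticks, optional "json" tag) and captures its body.
_FENCE_RE = re.compile(r"`{3}(?:json)?\s*(.*?)`{3}", re.DOTALL | re.IGNORECASE)


def parse_evaluator_reply(reply: str, default: float = 0.5) -> Tuple[float, str]:
    """Extract (confidence_score, explanation) from an LLM reply.

    Returns `default` when the reply is not a JSON object, and clamps the
    score into [0.0, 1.0] so a misbehaving model cannot break routing.
    """
    text = reply.strip()
    match = _FENCE_RE.search(text)
    if match:
        text = match.group(1).strip()
    try:
        data = json.loads(text)
        score = float(data.get("confidence_score", default))
        explanation = str(data.get("explanation", "No explanation provided"))
    except (json.JSONDecodeError, AttributeError, TypeError, ValueError) as exc:
        return default, f"Evaluation error: {exc}"
    return max(0.0, min(1.0, score)), explanation


print(parse_evaluator_reply('{"confidence_score": 1.3, "explanation": "Fully covered."}'))
print(parse_evaluator_reply("not json at all"))
```

The first call returns `(1.0, 'Fully covered.')` because the out-of-range score is clamped. The second call falls back to the default score of 0.5.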

In [None]:
def evaluate_retrieval_relevance(
    query: str,
    retrieved_nodes: List[NodeWithScore],
    llm: AzureOpenAI
) -> Tuple[float, str]:
    """
    Evaluate the relevance of retrieved documents to the query.
    
    Args:
        query: The user's question
        retrieved_nodes: List of retrieved document nodes
        llm: The language model for evaluation
    
    Returns:
        Tuple of (confidence_score, explanation)
        - confidence_score: 0.0-1.0 indicating relevance
        - explanation: Text explanation of the score
    """
    # Combine retrieved documents
    retrieved_text = "\n\n".join([
        f"Document {i+1}:\n{node.node.text[:500]}..."
        for i, node in enumerate(retrieved_nodes[:3])  # Evaluate top 3
    ])
    
    # Create evaluation prompt
    eval_prompt = f"""You are a relevance evaluator for a retrieval system.

Given a user query and retrieved documents, assess how well the documents can answer the query.

USER QUERY:
{query}

RETRIEVED DOCUMENTS:
{retrieved_text}

Evaluate the relevance and provide:
1. A confidence score between 0.0 and 1.0:
   - Below 0.4: Documents are irrelevant or missing critical information
   - 0.4 to 0.7: Documents have partial information but gaps exist
   - 0.7 or above: Documents contain sufficient relevant information

2. A brief explanation (2-3 sentences)

Respond in JSON format:
{{
    "confidence_score": <float 0.0-1.0>,
    "explanation": "<your explanation>"
}}
"""
    
    try:
        # Get LLM evaluation
        response = llm.complete(eval_prompt)
        response_text = response.text.strip()
        
        # Parse JSON response
        # Handle markdown code blocks if present
        if "```json" in response_text:
            response_text = response_text.split("```json")[1].split("```")[0].strip()
        elif "```" in response_text:
            response_text = response_text.split("```")[1].split("```")[0].strip()
        
        result = json.loads(response_text)
        confidence = float(result.get("confidence_score", 0.5))
        explanation = result.get("explanation", "No explanation provided")
        
        # Clamp confidence to valid range
        confidence = max(0.0, min(1.0, confidence))
        
        return confidence, explanation
        
    except Exception as e:
        print(f"Error in evaluation: {e}")
        # Default to medium confidence on error
        return 0.5, f"Evaluation error: {str(e)}"

print("✓ Retrieval evaluator function defined")

## 6. Implement Web Search Fallback

When internal knowledge is insufficient, we use DuckDuckGo to search the web for current information.
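The fallback below pastes every search hit into the prompt as-is. Search engines sometimes return the same page twice, or snippets that are empty or very long. Below is a minimal post-processing sketch that assumes each hit is a dict with the `title`, `href`, and `body` keys that `web_search_fallback` reads. The helper name `format_search_results`, the 500-character trim, and the dedupe-by-URL policy are illustrative assumptions.

```python
from typing import Dict, Iterable, List


def format_search_results(
    results: Iterable[Dict[str, str]],
    max_results: int = 5,
    max_body_chars: int = 500,
) -> List[str]:
    """Turn raw search hits into numbered context strings for the generator.

    Hits with an empty body or an already-seen URL are skipped, and long
    snippets are trimmed to `max_body_chars` characters.
    """
    formatted: List[str] = []
    seen_urls = set()
    for hit in results:
        url = hit.get("href", "")
        body = (hit.get("body") or "").strip()
        # Skip empty snippets and URLs that are already in the context
        if not body or (url and url in seen_urls):
            continue
        if url:
            seen_urls.add(url)
        if len(body) > max_body_chars:
            body = body[:max_body_chars].rstrip() + "..."
        number = len(formatted) + 1
        formatted.append(
            f"Source {number}: {hit.get('title', 'Untitled')}\nURL: {url}\nContent: {body}"
        )
        if len(formatted) >= max_results:
            break
    return formatted
```

Requesting a few more hits than you need (for example, twice `max_results`) and passing them through a helper like this leaves room for duplicates to be dropped without shrinking the context.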

In [None]:
def web_search_fallback(query: str, max_results: int = 5) -> List[str]:
    """
    Perform web search using DuckDuckGo as fallback for missing knowledge.
    
    Args:
        query: The search query
        max_results: Maximum number of results to return
    
    Returns:
        List of formatted search result texts
    """
    try:
        ddgs = DDGS()
        results = ddgs.text(query, max_results=max_results)
        
        formatted_results = []
        for i, result in enumerate(results, 1):
            formatted = f"""Source {i}: {result['title']}
URL: {result['href']}
Content: {result['body']}
"""
            formatted_results.append(formatted)
        
        return formatted_results
        
    except Exception as e:
        print(f"Web search error: {e}")
        return [f"Web search unavailable: {str(e)}"]

# Test web search
test_results = web_search_fallback("machine learning latest trends", max_results=2)
print("✓ Web search function tested successfully")
print(f"\nSample web search result:")
print(test_results[0][:200] + "..." if test_results else "No results")

## 7. Implement Sentence-Level Knowledge Filtering

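This optional knowledge-refinement step filters retrieved documents at the sentence level. It embeds the query and each sentence, then keeps only the sentences whose cosine similarity to the query meets a threshold. The default threshold of 0.5 is only a starting point. With `text-embedding-ada-002`, cosine similarities tend to cluster in a narrow, high range, so tune the threshold on your own data.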
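The filtering logic can be checked without any API calls by passing precomputed vectors. This minimal sketch assumes the embeddings arrive as NumPy arrays; in the notebook, they would come from the embedding model (for example, through LlamaIndex's `get_text_embedding_batch`). It swaps the notebook's `split('.')` for a regex splitter that breaks only at `.`, `!`, or `?` followed by whitespace, so decimals such as "3.5" stay intact. The function names `split_sentences` and `select_relevant` are illustrative.

```python
import re
from typing import List, Sequence

import numpy as np


def split_sentences(text: str, min_chars: int = 20) -> List[str]:
    """Split after ., ! or ? when followed by whitespace; drop short fragments."""
    parts = re.split(r"(?<=[.!?])\s+", text.strip())
    return [p.strip() for p in parts if len(p.strip()) > min_chars]


def select_relevant(
    sentences: Sequence[str],
    sentence_vecs: np.ndarray,  # shape (n_sentences, dim)
    query_vec: np.ndarray,      # shape (dim,)
    threshold: float = 0.5,
    fallback_k: int = 3,
) -> List[str]:
    """Keep sentences whose cosine similarity to the query meets `threshold`.

    If none qualify, keep the `fallback_k` most similar sentences instead.
    Kept sentences always stay in their original document order.
    """
    norms = np.linalg.norm(sentence_vecs, axis=1) * np.linalg.norm(query_vec)
    sims = (sentence_vecs @ query_vec) / np.maximum(norms, 1e-12)
    keep = np.flatnonzero(sims >= threshold)  # indices in ascending order
    if keep.size == 0:
        keep = np.sort(np.argsort(sims)[-fallback_k:])
    return [sentences[i] for i in keep]


query = np.array([1.0, 0.0])
vecs = np.array([[1.0, 0.0], [0.0, 1.0], [0.8, 0.6]])
print(select_relevant(["a", "b", "c"], vecs, query, threshold=0.5))
```

With these toy vectors, the similarities are 1.0, 0.0, and 0.8, so the call keeps `["a", "c"]`. Sorting the fallback indices preserves reading order, which matters when the kept sentences are stitched back into a context passage.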

In [None]:
def filter_sentences_by_relevance(
    query: str,
    text: str,
    embed_model: AzureOpenAIEmbedding,
    threshold: float = 0.5
) -> str:
    """
    Filter sentences in text by relevance to query.
    
    Args:
        query: The search query
        text: The document text to filter
        embed_model: Embedding model for similarity
        threshold: Minimum similarity score to keep sentence
    
    Returns:
        Filtered text with only relevant sentences
    """
    # Split into sentences (simple approach)
    sentences = [s.strip() for s in text.split('.') if len(s.strip()) > 20]
    
    if len(sentences) == 0:
        return text
    
    try:
        # Get embeddings: one batched call for all sentences instead of one request per sentence
        query_embedding = embed_model.get_query_embedding(query)
        sentence_embeddings = embed_model.get_text_embedding_batch(sentences)
        
        # Calculate cosine similarities
        import numpy as np
        
        def cosine_similarity(a, b):
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        
        similarities = [
            cosine_similarity(query_embedding, sent_emb)
            for sent_emb in sentence_embeddings
        ]
        
        # Filter sentences above threshold
        relevant_sentences = [
            sentences[i] for i, sim in enumerate(similarities)
            if sim >= threshold
        ]
        
        if len(relevant_sentences) == 0:
            # If nothing passes, keep the top 3 sentences in their original order
            top_indices = sorted(np.argsort(similarities)[-3:])
            relevant_sentences = [sentences[i] for i in top_indices]
        
        return '. '.join(relevant_sentences) + '.'
        
    except Exception as e:
        print(f"Error in sentence filtering: {e}")
        return text  # Return original on error

print("✓ Sentence filtering function defined")

## 8. Build CRAG Query Engine

This is the core of CRAG: a custom query engine that evaluates retrieval quality and routes accordingly.
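You can separate the engine's control flow from Azure, LlamaIndex, and the network by injecting each step as a function. This minimal sketch assumes hypothetical callables `retrieve`, `evaluate`, `web_search`, and `generate`. It mirrors the routing of the class below (3 web results on the low route; the top 2 internal documents plus 2 web results on the ambiguous route). It omits sentence filtering and logging.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class CragResult:
    route: str
    confidence: float
    context: List[str] = field(default_factory=list)
    answer: str = ""


def crag_answer(
    query: str,
    retrieve: Callable[[str], List[str]],
    evaluate: Callable[[str, List[str]], float],
    web_search: Callable[[str, int], List[str]],
    generate: Callable[[str, List[str]], str],
    high: float = 0.7,
    low: float = 0.4,
) -> CragResult:
    """Retrieve, grade, route, then generate, with every step injected."""
    docs = retrieve(query)
    confidence = evaluate(query, docs)
    if confidence >= high:
        # Internal knowledge is sufficient: no web search at all
        route, context = "HIGH_CONFIDENCE", docs
    elif confidence < low:
        # Internal results are discarded entirely
        route, context = "LOW_CONFIDENCE", web_search(query, 3)
    else:
        # Keep the best internal documents and supplement them from the web
        route, context = "AMBIGUOUS", docs[:2] + web_search(query, 2)
    return CragResult(route, confidence, context, generate(query, context))
```

Because every dependency is a parameter, stub functions can exercise all three routes in unit tests without API keys or network access.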

In [None]:
class CRAGQueryEngine:
    """
    Corrective RAG Query Engine with self-evaluation and dynamic routing.
    """
    
    def __init__(
        self,
        internal_index: VectorStoreIndex,
        llm: AzureOpenAI,
        embed_model: AzureOpenAIEmbedding,
        high_threshold: float = 0.7,
        low_threshold: float = 0.4,
        top_k: int = 5,
        use_sentence_filtering: bool = True,
    ):
        self.internal_index = internal_index
        self.llm = llm
        self.embed_model = embed_model
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.top_k = top_k
        self.use_sentence_filtering = use_sentence_filtering
        
        # Create internal retriever
        self.internal_retriever = internal_index.as_retriever(
            similarity_top_k=top_k
        )
    
    def query(self, query_str: str) -> Dict:
        """
        Execute CRAG query with evaluation and routing.
        
        Returns:
            Dictionary with:
            - response: Final answer
            - confidence: Confidence score
            - route: Which route was taken
            - source: Where information came from
            - explanation: Evaluation explanation
        """
        print(f"\n{'='*80}")
        print(f"CRAG Query: {query_str}")
        print(f"{'='*80}")
        
        # Step 1: Retrieve from internal knowledge base
        print("\n[Step 1] Retrieving from internal knowledge base...")
        retrieved_nodes = self.internal_retriever.retrieve(query_str)
        print(f"Retrieved {len(retrieved_nodes)} documents")
        
        # Step 2: Evaluate relevance
        print("\n[Step 2] Evaluating retrieval relevance...")
        confidence, explanation = evaluate_retrieval_relevance(
            query_str, retrieved_nodes, self.llm
        )
        print(f"Confidence Score: {confidence:.2f}")
        print(f"Explanation: {explanation}")
        
        # Step 3: Route based on confidence
        context_chunks = []
        source = ""
        route = ""
        
        if confidence >= self.high_threshold:
            # HIGH CONFIDENCE: Use internal documents only
            route = "HIGH_CONFIDENCE"
            source = "Internal Knowledge Base"
            print(f"\n[Step 3] Route: {route} → Using internal documents")
            
            for node in retrieved_nodes:
                text = node.node.text
                if self.use_sentence_filtering:
                    text = filter_sentences_by_relevance(
                        query_str, text, self.embed_model
                    )
                context_chunks.append(text)
        
        elif confidence < self.low_threshold:
            # LOW CONFIDENCE: Use web search only
            route = "LOW_CONFIDENCE"
            source = "Web Search"
            print(f"\n[Step 3] Route: {route} → Triggering web search fallback")
            
            web_results = web_search_fallback(query_str, max_results=3)
            context_chunks = web_results
        
        else:
            # AMBIGUOUS: Merge internal + web search
            route = "AMBIGUOUS"
            source = "Internal KB + Web Search"
            print(f"\n[Step 3] Route: {route} → Merging internal and web sources")
            
            # Add filtered internal documents
            for node in retrieved_nodes[:2]:  # Top 2 internal
                text = node.node.text
                if self.use_sentence_filtering:
                    text = filter_sentences_by_relevance(
                        query_str, text, self.embed_model
                    )
                context_chunks.append(text)
            
            # Add web search results
            web_results = web_search_fallback(query_str, max_results=2)
            context_chunks.extend(web_results)
        
        # Step 4: Generate final answer
        print(f"\n[Step 4] Generating final answer with {len(context_chunks)} context chunks...")
        
        combined_context = "\n\n".join(context_chunks)
        
        generation_prompt = f"""You are a helpful assistant. Answer the question based on the provided context.

CONTEXT:
{combined_context}

QUESTION:
{query_str}

Provide a comprehensive answer based on the context. If the context is insufficient, acknowledge this.
"""
        
        response = self.llm.complete(generation_prompt)
        
        return {
            "response": response.text,
            "confidence": confidence,
            "route": route,
            "source": source,
            "explanation": explanation,
            "context_chunks": context_chunks,
        }

# Create CRAG query engine
crag_engine = CRAGQueryEngine(
    internal_index=internal_index,
    llm=llm,
    embed_model=embed_model,
    high_threshold=0.7,
    low_threshold=0.4,
    top_k=5,
    use_sentence_filtering=True,
)

print("\n✓ CRAG Query Engine initialized")

## 9. Test Scenario 1: In-Domain Query (High Confidence)

A query that should be well covered by the internal knowledge base.

In [None]:
# Test query that should be in our tech_docs knowledge base
query_1 = "What is BERT and how does it differ from GPT models?"

result_1 = crag_engine.query(query_1)

print("\n" + "="*80)
print("FINAL RESULT")
print("="*80)
print(f"Route: {result_1['route']}")
print(f"Confidence: {result_1['confidence']:.2f}")
print(f"Source: {result_1['source']}")
print(f"\nAnswer:\n{result_1['response']}")

## 10. Test Scenario 2: Out-of-Domain Query (Low Confidence)

A query that falls outside the scope of the internal knowledge base.

In [None]:
# Test query about current events (not in our static knowledge base)
query_2 = "What are the latest developments in quantum computing as of 2025?"

result_2 = crag_engine.query(query_2)

print("\n" + "="*80)
print("FINAL RESULT")
print("="*80)
print(f"Route: {result_2['route']}")
print(f"Confidence: {result_2['confidence']:.2f}")
print(f"Source: {result_2['source']}")
print(f"\nAnswer:\n{result_2['response']}")

## 11. Test Scenario 3: Ambiguous Query (Medium Confidence)

A query with partial coverage in the internal knowledge base that may benefit from external sources.

In [None]:
# Test query that might have partial coverage
query_3 = "How are transformer models being used in production applications today?"

result_3 = crag_engine.query(query_3)

print("\n" + "="*80)
print("FINAL RESULT")
print("="*80)
print(f"Route: {result_3['route']}")
print(f"Confidence: {result_3['confidence']:.2f}")
print(f"Source: {result_3['source']}")
print(f"\nAnswer:\n{result_3['response']}")

## 12. Comparative Analysis

Compare CRAG with a standard (non-corrective) RAG system.

In [None]:
# Create standard RAG baseline for comparison
baseline_engine = internal_index.as_query_engine(
    similarity_top_k=5,
    llm=llm,
)

print("Standard RAG Comparison")
print("="*80)

# Test on the out-of-domain query
baseline_response = baseline_engine.query(query_2)

print(f"\nQuery: {query_2}")
print("\n--- Standard RAG (No Self-Correction) ---")
print(baseline_response.response)
print("\n--- CRAG (With Self-Correction) ---")
print(result_2['response'])
print("\n" + "="*80)
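print("\n🔍 Key Difference:")
print("Standard RAG uses the retrieved internal documents even when they are irrelevant.")
# Report what CRAG actually did instead of assuming the low-confidence route
if result_2['route'] == "LOW_CONFIDENCE":
    print(f"CRAG detected low confidence ({result_2['confidence']:.2f}) and switched to web search.")
else:
    print(f"CRAG scored confidence {result_2['confidence']:.2f} and took the {result_2['route']} route.")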

## 13. Summarize CRAG Routing Decisions

In [None]:
import pandas as pd

# Create summary table
results_summary = pd.DataFrame([
    {
        "Query": query_1[:50] + "...",
        "Confidence": f"{result_1['confidence']:.2f}",
        "Route": result_1['route'],
        "Source": result_1['source'],
    },
    {
        "Query": query_2[:50] + "...",
        "Confidence": f"{result_2['confidence']:.2f}",
        "Route": result_2['route'],
        "Source": result_2['source'],
    },
    {
        "Query": query_3[:50] + "...",
        "Confidence": f"{result_3['confidence']:.2f}",
        "Route": result_3['route'],
        "Source": result_3['source'],
    },
])

print("\nCRAG Routing Summary")
print("="*80)
print(results_summary.to_string(index=False))
print("\n" + "="*80)

## 14. Key Takeaways

### What We Learned

1. **Self-Evaluation is Critical**: Blindly using retrieved documents can lead to poor answers when internal knowledge is insufficient.

2. **Dynamic Routing Improves Robustness**: By evaluating confidence and routing accordingly, CRAG adapts to different query types:
   - High confidence → Use internal KB (no web-search latency; reliable for in-domain queries)
   - Low confidence → Use web search (access to current information)
   - Ambiguous → Merge both (comprehensive coverage)

3. **Knowledge Refinement Helps**: Sentence-level filtering removes noise and focuses on relevant information.

4. **Fallback Mechanisms Are Essential**: External knowledge sources (web search) handle queries outside the knowledge base scope.

### When to Use CRAG

- **Limited/Specialized Knowledge Bases**: When you know your KB won't cover all queries
- **Current Information Needs**: When users ask about recent events
- **High-Stakes Applications**: When wrong answers are costly
- **Hybrid Scenarios**: When you want the speed of internal KB with web fallback

### Implementation Considerations

- **Threshold Tuning**: Adjust high/low thresholds based on your domain
- **Evaluator Cost**: LLM evaluation adds an extra LLM call, and therefore latency and cost, to every query
- **Web Search Limits**: Consider rate limits and API costs
- **Caching**: Cache evaluation results for common queries
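Below is a minimal caching sketch using `functools.lru_cache`. It assumes that retrieved nodes can be identified by stable IDs (for example, `node.node.node_id` on LlamaIndex nodes) and that `score_fn` is a hypothetical callable that runs the LLM evaluation. The helper name `make_cached_evaluator` is illustrative.

```python
from functools import lru_cache
from typing import Callable, Sequence, Tuple


def make_cached_evaluator(
    score_fn: Callable[[str, Tuple[str, ...]], float],
    maxsize: int = 1024,
) -> Callable[[str, Sequence[str]], float]:
    """Wrap an expensive relevance scorer in an in-memory LRU cache.

    The cache key is the whitespace-normalized query plus the tuple of
    retrieved node IDs, so a repeated question that retrieves the same
    nodes reuses the earlier score instead of calling the LLM again.
    """

    @lru_cache(maxsize=maxsize)
    def _cached(query: str, node_ids: Tuple[str, ...]) -> float:
        return score_fn(query, node_ids)

    def evaluate(query: str, node_ids: Sequence[str]) -> float:
        normalized = " ".join(query.split())  # collapse extra whitespace
        return _cached(normalized, tuple(node_ids))

    evaluate.cache_info = _cached.cache_info  # expose hit/miss statistics
    return evaluate
```

Keying on node IDs as well as the query means a cached score is reused only when the same documents come back, so a changed index does not reuse a stale judgment. An in-process LRU cache is lost on restart and is not shared across workers; a shared store would be needed for that.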

## 15. Data Flow Diagram

```
┌─────────────────────────────────────────────────────────────────┐
│                        CRAG ARCHITECTURE                        │
└─────────────────────────────────────────────────────────────────┘

User Query
    ↓
┌──────────────────────┐
│ Retrieve from        │
│ Internal KB          │
│ (Vector Search)      │
└──────────────────────┘
    ↓
┌──────────────────────┐
│ LLM Evaluates        │
│ Relevance            │
│ → Confidence Score   │
└──────────────────────┘
    ↓
    ├─── Confidence ≥ 0.7 (HIGH) ────→ Use Internal Docs
    │                                            ↓
    ├─── Confidence < 0.4 (LOW) ─────→ Web Search Only
    │                                            ↓
    └─── 0.4 ≤ Confidence < 0.7 ─────→ Merge Internal + Web
                                                ↓
                                    ┌──────────────────────┐
                                    │ Optional:            │
                                    │ Sentence-Level       │
                                    │ Filtering            │
                                    │ (internal docs only) │
                                    └──────────────────────┘
                                                ↓
                                    ┌──────────────────────┐
                                    │ LLM Generation       │
                                    │ with Refined Context │
                                    └──────────────────────┘
                                                ↓
                                        Final Answer
```

## References

1. **Corrective Retrieval Augmented Generation** - Yan et al., 2024
   - arXiv:2401.15884
   - Introduces the CRAG framework with retrieval evaluation and correction

2. **Workshop Curriculum Reference #67**: "Corrective Retrieval Augmented Generation"

3. **Related Approaches**:
   - Self-RAG: Learning to retrieve, generate and critique through self-reflection
   - Adaptive RAG: Dynamically selecting retrieval strategies