# RAG Hallucination Detection with Conformal Prediction

This notebook shows how to build a hallucination detector for RAG systems using conformal prediction, then stress-test it under distribution shift.

**What you'll learn:**
- Build a conformal predictor for RAG hallucination detection
- Calibrate it on your documents
- Test what happens when users ask about new topics

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/debu-sinha/conformaldrift/blob/main/examples/01_rag_hallucination_detection.ipynb)

In [None]:
!pip install conformal-drift langchain langchain-openai langchain-community chromadb sentence-transformers -q

In [None]:
import os
import numpy as np

# Set your OpenAI API key
os.environ['OPENAI_API_KEY'] = 'your-key-here'  # Replace with your key

## Step 1: Build a Simple RAG System

We'll create a RAG system with some technical documentation.

In [None]:
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# Sample knowledge base - Python documentation
knowledge_base = [
    "Python lists are ordered, mutable sequences. You can add items with append() or extend().",
    "Python dictionaries store key-value pairs. Keys must be hashable. Access values with dict[key].",
    "List comprehensions provide a concise way to create lists: [x**2 for x in range(10)].",
    "The with statement ensures proper resource management. Files opened with 'with' are automatically closed.",
    "Decorators modify function behavior. Use @decorator syntax above the function definition.",
    "Python's GIL (Global Interpreter Lock) allows only one thread to execute Python bytecode at a time.",
    "Virtual environments isolate project dependencies. Create with: python -m venv myenv",
    "Type hints improve code readability: def greet(name: str) -> str: return f'Hello {name}'",
    "Generators yield values lazily, saving memory for large datasets. Use yield instead of return.",
    "Context managers implement __enter__ and __exit__ methods for resource management.",
]

# Create vector store
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_texts(knowledge_base, embeddings)

# Create RAG chain
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

prompt = PromptTemplate(
    template="""Answer the question based ONLY on the following context. 
If you cannot answer from the context, say "I don't have information about that."

Context: {context}

Question: {question}

Answer:""",
    input_variables=["context", "question"]
)

rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True
)

print("RAG system ready!")

## Step 2: Define Hallucination Score (Nonconformity Score)

The working assumption: **responses that don't align semantically with the retrieved documents are more likely to be hallucinations**.

Our nonconformity score is one minus the highest cosine similarity between the response and any retrieved document:
- **Low score** (high similarity): Response grounded in documents
- **High score** (low similarity): Potential hallucination
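To make the score concrete, here is a NumPy-only sketch in which hand-made 3-dimensional vectors stand in for sentence embeddings (the vectors and the `nonconformity` helper are illustrative, not part of the notebook's pipeline). It computes one minus the highest cosine similarity between a response and the retrieved documents. Cosine similarity can be negative, so the score can exceed 1:

```python
import numpy as np

def nonconformity(response_vec: np.ndarray, doc_vecs: np.ndarray) -> float:
    """Hypothetical helper: 1 - max cosine similarity between a response and its documents."""
    # Normalize to unit length so dot products equal cosine similarities
    r = response_vec / np.linalg.norm(response_vec)
    d = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    return float(1.0 - (d @ r).max())

docs = np.array([[1.0, 0.0, 0.0],    # "document" 1
                 [0.0, 1.0, 0.0]])   # "document" 2

grounded = np.array([0.9, 0.1, 0.0])    # close to document 1
off_topic = np.array([0.0, 0.0, 1.0])   # orthogonal to both documents
opposite = np.array([-1.0, -1.0, 0.0])  # negative similarity to both documents

print(round(nonconformity(grounded, docs), 3))   # 0.006
print(round(nonconformity(off_topic, docs), 3))  # 1.0
print(round(nonconformity(opposite, docs), 3))   # 1.707
```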

In [None]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Small, fast sentence-embedding model used only for scoring
# (retrieval above uses OpenAI embeddings)
scorer = SentenceTransformer('all-MiniLM-L6-v2')

def compute_hallucination_score(response: str, source_docs: list) -> float:
    """
    Compute a nonconformity score for a RAG response.
    
    Returns:
        1 - (max cosine similarity between the response and any source document).
        The range is [0, 2] because cosine similarity can be negative; in practice
        it is usually between 0 and 1. Higher = more likely hallucination.
    """
    if not source_docs:
        return 1.0  # No retrieved sources: treat as suspicious
    
    # Embed the response and the retrieved documents
    response_emb = scorer.encode([response])
    doc_texts = [doc.page_content for doc in source_docs]
    doc_embs = scorer.encode(doc_texts)
    
    # Max similarity to any source document
    similarities = cosine_similarity(response_emb, doc_embs)[0]
    max_similarity = similarities.max()
    
    # Nonconformity = 1 - similarity (cast the NumPy scalar to a plain float)
    return float(1 - max_similarity)

# Test it on one in-domain question
question = "How do I create a list in Python?"
result = rag_chain.invoke({"query": question})
score = compute_hallucination_score(result['result'], result['source_documents'])
print(f"Question: {question}")
print(f"Response: {result['result'][:100]}...")
print(f"Hallucination score: {score:.3f}")

## Step 3: Calibrate the Hallucination Detector

Run questions we know are answerable from the knowledge base. Their scores form the calibration set: a reference distribution of what grounded responses look like.
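Conformal guarantees only cover test questions that were not used for calibration. If you have more answerable questions than the 15 listed below, a minimal sketch of holding some out for evaluation looks like this; the `split_calibration_test` helper and its 50/50 default are illustrative assumptions:

```python
import random

def split_calibration_test(questions, cal_fraction=0.5, seed=0):
    """Hypothetical helper: shuffle once, then split into disjoint calibration and test lists."""
    shuffled = list(questions)              # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)   # local RNG: reproducible without global seeding
    n_cal = int(len(shuffled) * cal_fraction)
    return shuffled[:n_cal], shuffled[n_cal:]

pool = [f"question {i}" for i in range(20)]
cal_qs, test_qs = split_calibration_test(pool)
print(len(cal_qs), len(test_qs))  # 10 10
```

Scoring only `cal_qs` for calibration and `test_qs` for evaluation keeps the in-domain coverage estimate honest.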

In [None]:
# Calibration questions - things we KNOW are in our knowledge base
calibration_questions = [
    "What are Python lists?",
    "How do dictionaries work in Python?",
    "What is a list comprehension?",
    "How does the with statement work?",
    "What are decorators in Python?",
    "What is the GIL?",
    "How do I create a virtual environment?",
    "What are type hints?",
    "How do generators work?",
    "What are context managers?",
    "How do I add items to a list?",
    "What makes dictionary keys valid?",
    "Why use the with statement for files?",
    "How do I use the @ syntax?",
    "What does yield do?",
]

print("Calibrating hallucination detector...")
calibration_scores = []

for q in calibration_questions:
    result = rag_chain.invoke({"query": q})
    score = compute_hallucination_score(result['result'], result['source_documents'])
    calibration_scores.append(score)
    print(f"  {q[:40]:40} -> score: {score:.3f}")

calibration_scores = np.array(calibration_scores)
print("\nCalibration complete!")
print(f"Mean score: {calibration_scores.mean():.3f}")
print(f"Std score: {calibration_scores.std():.3f}")

## Step 4: Set Up Conformal Prediction Threshold

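Using split conformal prediction, we pick a threshold so that a new grounded response, exchangeable with the calibration responses, passes with probability at least 90% (i.e. at most 10% of grounded responses are falsely flagged).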
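The finite-sample correction matters at this calibration size. In the NumPy-only sketch below, `conformal_threshold` is a hypothetical helper, not the `conformal_drift` API. Split conformal uses the `ceil((n + 1) * (1 - alpha))`-th smallest calibration score, which guarantees P(score <= threshold) >= 1 - alpha for a new exchangeable score. With 15 toy scores and alpha = 0.1 it lands on the maximum, above the interpolated 90th percentile, and with fewer than 9 scores no finite threshold carries the guarantee:

```python
import numpy as np

def conformal_threshold(cal_scores, alpha: float) -> float:
    """Hypothetical helper: split-conformal threshold with the finite-sample correction.

    Returns the k-th smallest calibration score, k = ceil((n + 1) * (1 - alpha)).
    If k > n, no finite threshold carries the guarantee, so return +inf (flag nothing).
    """
    scores = np.sort(np.asarray(cal_scores, dtype=float))
    n = len(scores)
    k = int(np.ceil((n + 1) * (1 - alpha)))
    if k > n:
        return float("inf")
    return float(scores[k - 1])

demo_scores = np.arange(1, 16) / 100               # 15 toy scores: 0.01, 0.02, ..., 0.15
print(conformal_threshold(demo_scores, 0.1))       # 0.15 (k = 15, i.e. the maximum)
print(np.percentile(demo_scores, 90))              # about 0.136 (linear interpolation)
print(conformal_threshold(demo_scores[:8], 0.1))   # inf: 8 scores are too few for alpha = 0.1
print(conformal_threshold(demo_scores[:9], 0.1))   # 0.09: 9 is the smallest n that works
```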

In [None]:
from conformal_drift import ConformalDriftAuditor

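# Target 90% coverage: at most 10% of grounded responses flagged (false positive rate)
alpha = 0.1

# Initialize auditor with calibration scores
auditor = ConformalDriftAuditor(
    calibration_scores=calibration_scores,
    alpha=alpha
)

# Split-conformal threshold: the ceil((n+1)(1-alpha))-th smallest calibration score.
# A plain np.percentile(calibration_scores, 90) interpolates below this and can miss 90%.
n = len(calibration_scores)
k = int(np.ceil((n + 1) * (1 - alpha)))
threshold = np.sort(calibration_scores)[k - 1] if k <= n else np.inf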
print(f"Hallucination threshold: {threshold:.3f}")
print(f"Responses with score > {threshold:.3f} will be flagged as potential hallucinations")

## Step 5: Test on Out-of-Domain Questions (Distribution Shift!)

Now let's see what happens when users ask about topics NOT in our knowledge base.

In [None]:
# Questions about topics NOT in our knowledge base
out_of_domain_questions = [
    "How does asyncio work in Python?",  # Not in KB
    "What is FastAPI?",  # Not in KB
    "How do I use pandas DataFrames?",  # Not in KB
    "Explain Python's memory management",  # Not in KB
    "What is the difference between Python 2 and 3?",  # Not in KB
    "How do I deploy a Flask app?",  # Not in KB
    "What are dataclasses?",  # Not in KB
    "How does multiprocessing work?",  # Not in KB
]

print("Testing on out-of-domain questions...")
print("="*60)

ood_scores = []
for q in out_of_domain_questions:
    result = rag_chain.invoke({"query": q})
    score = compute_hallucination_score(result['result'], result['source_documents'])
    ood_scores.append(score)
    
    flagged = "FLAGGED" if score > threshold else "OK"
    print(f"Q: {q}")
    print(f"A: {result['result'][:80]}...")
    print(f"Score: {score:.3f} [{flagged}]")
    print()

ood_scores = np.array(ood_scores)
print(f"\nOut-of-domain: {sum(s > threshold for s in ood_scores)}/{len(ood_scores)} flagged")

## Step 6: Audit Coverage Under Distribution Shift

The key question: **Does our 90% coverage guarantee hold when the question distribution shifts?**
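Before spending API calls, it can help to see what the guarantee does and does not promise. The sketch below uses synthetic Gaussian scores as an assumption standing in for real hallucination scores (in-domain ~ N(0, 1), out-of-domain ~ N(3, 1)) together with the finite-sample conformal threshold. Coverage measured only on in-domain points stays near 90% at every mixing ratio, while the fraction of the whole incoming stream that passes the threshold falls as the out-of-domain share grows, because those points are no longer exchangeable with the calibration data:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.1

# Synthetic calibration scores from the in-domain distribution (assumption: N(0, 1))
sim_cal = np.sort(rng.normal(0.0, 1.0, size=1000))
k = int(np.ceil((len(sim_cal) + 1) * (1 - alpha)))
sim_threshold = sim_cal[k - 1]  # finite-sample split-conformal threshold

sim_shifts = [0.0, 0.25, 0.5, 0.75, 1.0]
n_test = 5000
stream_coverage, in_domain_coverage = [], []

for shift in sim_shifts:
    n_ood = int(round(shift * n_test))
    id_part = rng.normal(0.0, 1.0, size=n_test - n_ood)   # exchangeable with calibration
    ood_part = rng.normal(3.0, 1.0, size=n_ood)           # shifted (assumption: N(3, 1))
    stream = np.concatenate([id_part, ood_part])

    # Marginal coverage: fraction of the whole mixed stream at or below the threshold
    stream_coverage.append(np.mean(stream <= sim_threshold))
    # Coverage on the in-domain part only (NaN when there are no in-domain points)
    in_domain_coverage.append(np.mean(id_part <= sim_threshold) if len(id_part) else np.nan)

for s, sc, ic in zip(sim_shifts, stream_coverage, in_domain_coverage):
    print(f"shift {s:.0%}: stream coverage {sc:.1%}, in-domain coverage {ic:.1%}")
```

The in-domain column is the quantity the conformal guarantee is about; the stream column is what a deployed guardrail actually experiences.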

In [None]:
# Mix in-domain and out-of-domain questions at different ratios.
# Caveat: these in-domain questions are the calibration questions themselves, so in-domain
# coverage here is optimistic; use held-out in-domain questions when you have them.
in_domain_questions = calibration_questions.copy()

def create_shifted_test_set(shift_intensity: float, n_samples: int = 20, seed: int = 42):
    """
    Create a test set with the given fraction of out-of-domain questions.
    shift_intensity = 0.0 means all in-domain
    shift_intensity = 1.0 means all out-of-domain
    """
    n_ood = int(n_samples * shift_intensity)
    n_id = n_samples - n_ood
    
    # Sample with replacement so every test set has exactly n_samples questions
    rng = np.random.default_rng(seed)
    id_qs = [str(q) for q in rng.choice(in_domain_questions, n_id, replace=True)]
    ood_qs = [str(q) for q in rng.choice(out_of_domain_questions, n_ood, replace=True)]
    
    questions = id_qs + ood_qs
    labels = [True] * len(id_qs) + [False] * len(ood_qs)  # True = grounded, False = should be flagged
    
    return questions, labels

# Test at different shift intensities
shift_levels = [0.0, 0.25, 0.5, 0.75, 1.0]
results = []

print("Auditing coverage under distribution shift...")
print("="*60)

for shift in shift_levels:
    questions, labels = create_shifted_test_set(shift)
    
    # Compute scores
    scores = []
    for q in questions:
        result = rag_chain.invoke({"query": q})
        score = compute_hallucination_score(result['result'], result['source_documents'])
        scores.append(score)
    scores = np.array(scores)
    
    # Coverage = fraction of grounded responses correctly not flagged
    grounded_mask = np.array(labels)
    grounded_scores = scores[grounded_mask]
    
    if len(grounded_scores) > 0:
        coverage = np.mean(grounded_scores <= threshold)
    else:
        coverage = np.nan
    
    # Fraction of the whole shifted stream that passes the threshold
    stream_coverage = np.mean(scores <= threshold)
    
    results.append({
        'shift': shift,
        'coverage': coverage,
        'stream_coverage': stream_coverage,
        'n_grounded': sum(labels),
        'n_halluc': len(labels) - sum(labels)
    })
    
    print(f"Shift {shift:.0%}: Coverage = {coverage:.1%} (on {sum(labels)} grounded responses), "
          f"stream coverage = {stream_coverage:.1%}")

print("\nNominal coverage target: 90%")

## Step 7: Visualize Coverage Degradation

In [None]:
import matplotlib.pyplot as plt

shifts = [r['shift'] for r in results]
coverages = [r['coverage'] for r in results]

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(shifts, coverages, 'b-o', linewidth=2, markersize=10)
ax.axhline(y=0.9, color='r', linestyle='--', linewidth=2, label='90% Target')
ax.fill_between([0, 1], [0.85, 0.85], [0.95, 0.95], alpha=0.2, color='green', label='±5% Tolerance')

ax.set_xlabel('Out-of-Domain Question Ratio', fontsize=12)
ax.set_ylabel('Coverage on Grounded Responses', fontsize=12)
ax.set_title('RAG Hallucination Detector Coverage Under Topic Shift', fontsize=14)
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(0, 1.05)
ax.legend()
ax.grid(True, alpha=0.3)

# Add annotations
for s, c in zip(shifts, coverages):
    if not np.isnan(c):
        ax.annotate(f'{c:.0%}', (s, c), textcoords="offset points", xytext=(0, 10), ha='center')

plt.tight_layout()
plt.savefig('rag_coverage_shift.png', dpi=150)
plt.show()

## Key Takeaways

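1. **Conformal prediction gives coverage guarantees** - but only when test data are exchangeable with the calibration data (in practice: drawn from the same distribution)

2. **Topic shift breaks guarantees** - when users ask about topics not in your knowledge base, their scores no longer follow the calibration distribution, so the nominal 90% rate no longer describes how often responses pass the threshold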

3. **Test before you deploy** - use this audit to understand how your guardrail degrades under realistic shift scenarios

4. **Monitor in production** - track the distribution of hallucination scores to detect shift early
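One simple way to act on the monitoring advice, sketched under assumptions: keep a rolling window of recent scores and alert when the share above the calibrated threshold climbs well past alpha, the rate you would expect if production traffic still matched calibration. The `FlagRateMonitor` class, window size, and tolerance below are illustrative choices, not part of `conformal_drift`:

```python
from collections import deque

class FlagRateMonitor:
    """Hypothetical rolling monitor for the fraction of flagged responses."""

    def __init__(self, threshold: float, alpha: float = 0.1, window: int = 200, tolerance: float = 0.05):
        self.threshold = threshold
        self.alpha = alpha
        self.tolerance = tolerance
        self.flags = deque(maxlen=window)  # keeps only the most recent `window` outcomes

    def update(self, score: float) -> bool:
        """Record one score; return True if the rolling flag rate exceeds alpha + tolerance."""
        self.flags.append(score > self.threshold)
        if len(self.flags) < self.flags.maxlen:
            return False  # not enough data yet to judge the rate
        flag_rate = sum(self.flags) / len(self.flags)
        return flag_rate > self.alpha + self.tolerance

monitor = FlagRateMonitor(threshold=0.5, alpha=0.1, window=10, tolerance=0.1)
# 1 of 10 scores above the threshold: rate 0.1, no alert
print([monitor.update(s) for s in [0.2] * 9 + [0.9]][-1])  # False
# Three more high scores push the window to 4/10 flagged: alert
print([monitor.update(0.9) for _ in range(3)][-1])         # True
```

A fixed tolerance keeps the sketch simple; a binomial test on the window would account for sampling noise more carefully.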

## What to do when coverage drops:
- Expand your knowledge base to cover new topics
- Recalibrate the detector periodically with new data
- Set more conservative thresholds for high-stakes applications