# Speculative RAG: Enhancing RAG Through Drafting

This notebook implements **Speculative RAG** based on Google Research's paper: ["Speculative RAG: Enhancing Retrieval Augmented Generation through Drafting"](https://arxiv.org/abs/2407.08223)

## What is Speculative RAG?

Speculative RAG is an advanced approach that uses a **two-stage drafter-verifier architecture** to improve both accuracy and latency:

- **Specialist Drafter**: A smaller, fine-tuned model generates multiple answer drafts with rationales **in parallel**, each working with **distinct document subsets**
- **Generalist Verifier**: A larger model calculates **conditional generation probability** P(answer | rationale, documents, query) for each draft and selects the highest-confidence answer
- **Distinct Document Partitioning**: Retrieved documents are split into **non-overlapping subsets**, with each drafter processing different documents to provide diverse perspectives
- **Rationale-Based Scoring**: Each draft includes reasoning, which the verifier uses to compute conditional probabilities for answer verification

## Key Differences from Standard RAG

1. **Standard RAG**: Feeds all retrieved documents directly to a single LLM
2. **Speculative RAG**: Uses a smaller specialist to draft multiple candidates from **distinct document subsets** in parallel, then a larger model verifies using **conditional probability** and selects the best one

## Performance Benefits (from paper)

- **Up to 12.97% accuracy improvement** on the PubHealth dataset
- **51% latency reduction** compared to conventional RAG systems on PubHealth
- State-of-the-art results on TriviaQA, MuSiQue, PubHealth, and ARC-Challenge

## How It Works (Paper Algorithm)

1. **Document Retrieval**: Semantic search finds top-k relevant documents
2. **Document Partitioning**: Split documents into **distinct (non-overlapping) subsets** for parallel processing
3. **Parallel Drafting**: Smaller specialist model generates answer drafts with rationales from each subset simultaneously
4. **Verification**: Larger generalist model calculates P(answer | rationale, documents, query) for each draft using conditional probability
5. **Selection**: Return the draft with the **highest conditional probability score**

## Implementation Notes

This implementation follows the paper's core drafter-verifier structure, with hosted OpenAI models standing in for the paper's models:
- **Drafter**: Uses GPT-4o-mini (smaller model) to generate drafts with rationales (paper uses fine-tuned Mistral-7B)
- **Verifier**: Uses GPT-4o (larger model) to score each draft via logprobs (paper uses Mixtral-8x7B)
- **Document Partitioning**: Creates distinct, non-overlapping subsets as specified in the paper
- **Probability Calculation**: Approximates P(answer | rationale, documents, query) by asking the verifier whether each draft is correct and reading the probability of the `yes` token from the OpenAI API's logprobs

## Setup and Imports

In [74]:
!pip install openai numpy scikit-learn wikipedia-api -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [75]:
import os
import asyncio
from typing import List, Dict, Any, Tuple
from openai import AsyncOpenAI
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json
from datetime import datetime
import wikipediaapi
import re

In [None]:
# Configure OpenAI API
client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])

## Knowledge Base Setup

Fetch and process the Wikipedia article on quantum entanglement.

In [77]:
def fetch_wikipedia_article(page_title: str) -> str:
    """Fetch Wikipedia article content."""
    wiki = wikipediaapi.Wikipedia(
        language='en',
        user_agent='SpeculativeRAG/1.0 (Educational Purpose)'
    )
    
    page = wiki.page(page_title)
    
    if not page.exists():
        raise ValueError(f"Wikipedia page '{page_title}' does not exist")
    
    return page.text


def chunk_text_by_sections(text: str, max_chunk_size: int = 500) -> List[Dict[str, Any]]:
    """
    Chunk Wikipedia text into semantic sections.
    
    Args:
        text: Full Wikipedia article text
        max_chunk_size: Maximum characters per chunk
        
    Returns:
        List of document dictionaries with id, content, and metadata
    """
    # Split before any title line that is followed by a line starting with '='
    sections = re.split(r'\n(?=\w+.*?\n=+)', text)
    
    documents = []
    doc_id = 1
    
    for section in sections:
        # Extract section title if present
        lines = section.strip().split('\n')
        if len(lines) > 1 and '=' in lines[1]:
            title = lines[0].strip()
            content = '\n'.join(lines[2:]).strip()
        else:
            title = "Introduction"
            content = section.strip()
        
        # Skip empty sections
        if not content:
            continue
        
        # If section is too large, chunk it by paragraphs
        if len(content) > max_chunk_size:
            paragraphs = content.split('\n\n')
            current_chunk = ""
            
            for para in paragraphs:
                if len(current_chunk) + len(para) > max_chunk_size and current_chunk:
                    documents.append({
                        "id": doc_id,
                        "content": current_chunk.strip(),
                        "metadata": {"section": title, "source": "wikipedia"}
                    })
                    doc_id += 1
                    current_chunk = para
                else:
                    current_chunk += "\n\n" + para if current_chunk else para
            
            # Add remaining content
            if current_chunk.strip():
                documents.append({
                    "id": doc_id,
                    "content": current_chunk.strip(),
                    "metadata": {"section": title, "source": "wikipedia"}
                })
                doc_id += 1
        else:
            documents.append({
                "id": doc_id,
                "content": content,
                "metadata": {"section": title, "source": "wikipedia"}
            })
            doc_id += 1
    
    return documents


print("✓ Wikipedia fetching and chunking functions defined")

✓ Wikipedia fetching and chunking functions defined


In [78]:
# Fetch and process Wikipedia article on Quantum Entanglement
print("Fetching Wikipedia article: Quantum entanglement")
wiki_text = fetch_wikipedia_article("Quantum entanglement")

print("Chunking article into semantic sections...")
KNOWLEDGE_BASE = chunk_text_by_sections(wiki_text, max_chunk_size=500)

print(f"\n✓ Loaded {len(KNOWLEDGE_BASE)} document chunks from Wikipedia")
print(f"\nFirst 3 chunks:")
for i, doc in enumerate(KNOWLEDGE_BASE[:3]):
    print(f"\n--- Document {doc['id']} (Section: {doc['metadata']['section']}) ---")
    print(f"{doc['content'][:200]}...")
    
print(f"\n... and {len(KNOWLEDGE_BASE) - 3} more chunks")

Fetching Wikipedia article: Quantum entanglement
Chunking article into semantic sections...

✓ Loaded 74 document chunks from Wikipedia

First 3 chunks:

--- Document 1 (Section: Introduction) ---
Quantum entanglement is the phenomenon where the quantum state of each particle in a group cannot be described independently of the state of the others, even when the particles are separated by a larg...

--- Document 2 (Section: Introduction) ---
History
Albert Einstein and Niels Bohr engaged in a long-running collegial dispute over the interpretation of quantum mechanics, now known as the Bohr–Einstein debates. During these debates, Einstein ...

--- Document 3 (Section: Introduction) ---
Concept
Meaning of entanglement
Just as energy is a resource that facilitates mechanical operations, entanglement is a resource that facilitates performing tasks that involve communication and computa...

... and 71 more chunks


## Embedding and Retrieval Functions

Breaking down the retrieval pipeline into individual components for clarity.

In [79]:
async def get_embedding(text: str, model: str = "text-embedding-3-small") -> List[float]:
    """Get embedding for a text using OpenAI's embedding model."""
    response = await client.embeddings.create(
        input=text,
        model=model
    )
    return response.data[0].embedding

print("✓ get_embedding function defined")

✓ get_embedding function defined


In [80]:
async def get_embeddings_batch(texts: List[str], model: str = "text-embedding-3-small") -> List[List[float]]:
    """Get embeddings for multiple texts in parallel."""
    tasks = [get_embedding(text, model) for text in texts]
    return await asyncio.gather(*tasks)

print("✓ get_embeddings_batch function defined")

✓ get_embeddings_batch function defined
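
The embeddings endpoint also accepts a list of strings, so a batch can be sent in one request instead of one request per text. A minimal sketch, assuming the same `AsyncOpenAI` client is passed in; `get_embeddings_single_request` is a hypothetical helper name, not part of the notebook, and very large batches may still need to be split across several requests.

```python
from typing import Any, List


async def get_embeddings_single_request(
    client: Any,  # assumed to be an AsyncOpenAI instance
    texts: List[str],
    model: str = "text-embedding-3-small",
) -> List[List[float]]:
    """Embed all texts with one API call (hypothetical alternative helper)."""
    response = await client.embeddings.create(input=texts, model=model)
    # Each returned item carries the index of its input; sort to be safe.
    ordered = sorted(response.data, key=lambda item: item.index)
    return [item.embedding for item in ordered]
```

In a notebook cell this would be called as `await get_embeddings_single_request(client, doc_texts)`.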


In [81]:
async def retrieve_documents(query: str, knowledge_base: List[Dict], top_k: int = 3) -> List[Dict]:
    """Retrieve top-k relevant documents based on semantic similarity."""
    # Get query embedding
    query_embedding = await get_embedding(query)
    
    # Get embeddings for all documents (in practice, these would be pre-computed)
    doc_texts = [doc["content"] for doc in knowledge_base]
    doc_embeddings = await get_embeddings_batch(doc_texts)
    
    # Calculate similarities
    similarities = cosine_similarity(
        [query_embedding],
        doc_embeddings
    )[0]
    
    # Get top-k indices
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    
    # Return documents with scores
    results = []
    for idx in top_indices:
        doc = knowledge_base[idx].copy()
        doc["similarity_score"] = float(similarities[idx])
        results.append(doc)
    
    return results

print("✓ retrieve_documents function defined")

✓ retrieve_documents function defined
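
As the comment in `retrieve_documents` notes, re-embedding every chunk on each query is wasteful (74 chunks in this knowledge base). A minimal sketch of an in-memory cache, assuming an async embedding function such as `get_embedding` is passed in; `EmbeddingCache` is a hypothetical name introduced for illustration.

```python
from typing import Awaitable, Callable, Dict, List


class EmbeddingCache:
    """Memoise embeddings by text so each chunk is embedded at most once."""

    def __init__(self, embed_fn: Callable[[str], Awaitable[List[float]]]):
        self._embed_fn = embed_fn
        self._store: Dict[str, List[float]] = {}

    async def get(self, text: str) -> List[float]:
        # Only call the embedding function on a cache miss.
        if text not in self._store:
            self._store[text] = await self._embed_fn(text)
        return self._store[text]

    async def get_many(self, texts: List[str]) -> List[List[float]]:
        # Sequential on purpose: keeps the sketch simple and avoids duplicate
        # requests when the same text appears twice in one batch.
        return [await self.get(text) for text in texts]
```

With `cache = EmbeddingCache(get_embedding)`, `retrieve_documents` could call `await cache.get_many(doc_texts)` so that only the first query pays the embedding cost.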


## Speculative RAG Components

Breaking down the two-stage drafter-verifier architecture into individual components.

### Part 1: RAGDrafter Class - Specialist Model for Draft Generation

In [82]:
class RAGDrafter:
    """
    Specialist drafter using a smaller model to generate answer drafts with rationales.
    Corresponds to the fine-tuned Mistral-7B in the paper.
    """
    
    def __init__(self, client: AsyncOpenAI, model: str = "gpt-4o-mini"):
        self.client = client
        self.model = model

print("✓ RAGDrafter class initialized")

✓ RAGDrafter class initialized


In [83]:
async def _rag_drafter_generate_draft(self, query: str, documents: List[Dict], draft_id: int) -> Dict[str, Any]:
    """
    Generate a single answer draft with rationale from a document subset.
    
    Args:
        query: The user's question
        documents: Subset of retrieved documents for this draft
        draft_id: Identifier for this draft
        
    Returns:
        Dict containing answer, rationale, and metadata
    """
    # Format documents as context
    context_text = "\n\n".join([
        f"Document {i+1}:\n{doc['content']}" 
        for i, doc in enumerate(documents)
    ])
    
    prompt = f"""You are a specialist RAG drafter. Generate a concise answer to the question based ONLY on the provided documents.

Context Documents:
{context_text}

Question: {query}

Provide your response in this JSON format:
{{
    "answer": "Your concise answer based on the documents",
    "rationale": "Brief explanation of which documents you used and why this answer is correct",
    "confidence": "A brief assessment of how well the documents support this answer"
}}

Important: Only use information from the provided documents. If the documents don't contain enough information, acknowledge this in your rationale."""
    
    response = await self.client.chat.completions.create(
        model=self.model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,  # Low temperature for consistency
        response_format={"type": "json_object"}
    )
    
    draft_data = json.loads(response.choices[0].message.content)
    
    return {
        "draft_id": draft_id,
        "answer": draft_data.get("answer", ""),
        "rationale": draft_data.get("rationale", ""),
        "confidence_note": draft_data.get("confidence", ""),
        "documents_used": [doc['id'] for doc in documents],
        "tokens": response.usage.total_tokens,
        "model": self.model
    }

# Add method to RAGDrafter class
RAGDrafter.generate_draft = _rag_drafter_generate_draft

print("✓ RAGDrafter.generate_draft method defined")

✓ RAGDrafter.generate_draft method defined


In [84]:
async def _rag_drafter_generate_drafts_parallel(self, query: str, document_subsets: List[List[Dict]]) -> List[Dict[str, Any]]:
    """
    Generate multiple drafts in parallel, each from a different document subset.
    This is the key to Speculative RAG's latency improvement.
    
    Args:
        query: The user's question
        document_subsets: List of document subsets, one for each draft
        
    Returns:
        List of draft dictionaries
    """
    tasks = [
        self.generate_draft(query, subset, i)
        for i, subset in enumerate(document_subsets)
    ]
    
    drafts = await asyncio.gather(*tasks)
    return drafts

# Add method to RAGDrafter class
RAGDrafter.generate_drafts_parallel = _rag_drafter_generate_drafts_parallel

print("✓ RAGDrafter.generate_drafts_parallel method defined")

✓ RAGDrafter.generate_drafts_parallel method defined
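
By default `asyncio.gather` propagates the first exception, so one failed API call loses the drafts that succeeded. A hedged sketch of a more forgiving variant using `return_exceptions=True`; `gather_successful` and `fake_draft` are hypothetical names, and the `asyncio.sleep` call stands in for the API request.

```python
import asyncio
from typing import Any, Awaitable, Dict, Iterable, List


async def gather_successful(awaitables: Iterable[Awaitable[Any]]) -> List[Any]:
    """Run awaitables concurrently and keep only the results that succeeded."""
    results = await asyncio.gather(*awaitables, return_exceptions=True)
    return [r for r in results if not isinstance(r, BaseException)]


async def fake_draft(draft_id: int, fail: bool = False) -> Dict[str, Any]:
    """Stand-in for RAGDrafter.generate_draft; sleeps instead of calling an API."""
    await asyncio.sleep(0.01)
    if fail:
        raise RuntimeError(f"draft {draft_id} failed")
    return {"draft_id": draft_id, "answer": f"answer {draft_id}"}


async def demo() -> List[Dict[str, Any]]:
    # The middle draft fails; the other two are still returned, in order.
    return await gather_successful(
        [fake_draft(0), fake_draft(1, fail=True), fake_draft(2)]
    )


drafts = asyncio.run(demo())  # inside a notebook, use `await demo()` instead
print([d["draft_id"] for d in drafts])
```

Note that `verify_and_select` looks up the winning draft by using `draft_id` as a list index, so surviving drafts would need to be renumbered before verification if any are dropped.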


### Part 2: RAGVerifier - Generalist Model for Draft Selection

In [85]:
class RAGVerifier:
    """
    Generalist verifier using a larger model to score and select the best draft.
    Corresponds to the Mixtral-8x7B model in the paper.
    """
    
    def __init__(self, client: AsyncOpenAI, model: str = "gpt-4o"):
        self.client = client
        self.model = model

print("✓ RAGVerifier class initialized")

✓ RAGVerifier class initialized


In [86]:
async def _rag_verifier_verify_and_select(self, query: str, drafts: List[Dict], all_documents: List[Dict]) -> Dict[str, Any]:
    """
    Verify drafts and select the best one using a verifier confidence score.
    
    The paper scores drafts with P(answer | rationale, documents, query). This
    implementation approximates that score: for each draft, the verifier is asked
    whether the answer is correct, and the probability of 'yes' as the first
    output token is used as the draft's score. The highest-scoring draft is selected.
    
    Args:
        query: The user's question
        drafts: List of draft answers with rationales
        all_documents: All retrieved documents for reference
        
    Returns:
        Selected draft with verification metadata
    """
    # Calculate conditional probability for each draft
    draft_scores = []
    
    for draft in drafts:
        # Format the context: documents + rationale + query
        docs_for_draft = [doc for doc in all_documents if doc['id'] in draft['documents_used']]
        context_text = "\n\n".join([
            f"Document {doc['id']}: {doc['content']}"
            for doc in docs_for_draft
        ])
        
        # Build verification prompt: Given documents, query, and rationale, 
        # calculate probability of the answer
        verification_prompt = f"""Given the following context, question, and reasoning, verify this answer.

Context Documents:
{context_text}

Question: {query}

Reasoning: {draft['rationale']}

Proposed Answer: {draft['answer']}

Is this answer correct and well-supported by the documents and reasoning? Respond with only 'yes' or 'no'."""
        
        # Get model's probability assessment
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": verification_prompt}],
            temperature=0.0,
            max_tokens=10,
            logprobs=True,
            top_logprobs=5
        )
        
        # Calculate confidence score from logprobs
        # Higher logprob = higher confidence in the answer
        if response.choices[0].logprobs and response.choices[0].logprobs.content:
            first_token_logprobs = response.choices[0].logprobs.content[0].top_logprobs
            
            # Find 'yes' token probability
            yes_prob = 0.0
            for token_logprob in first_token_logprobs:
                if token_logprob.token.lower().strip() in ['yes', 'y']:
                    yes_prob = np.exp(token_logprob.logprob)
                    break
            
            confidence = yes_prob
        else:
            # Fallback if logprobs not available
            confidence = 0.5
        
        draft_scores.append({
            "draft_id": draft['draft_id'],
            "score": float(confidence),
            "answer": draft['answer'],
            "rationale": draft['rationale']
        })
    
    # Select draft with highest conditional probability
    best_draft = max(draft_scores, key=lambda x: x['score'])
    selected_idx = best_draft['draft_id']
    
    return {
        "selected_draft": drafts[selected_idx],
        "confidence_score": best_draft['score'],
        "verification_reasoning": f"Selected draft {selected_idx + 1} with conditional probability score of {best_draft['score']:.3f}",
        "all_draft_scores": draft_scores,
        "verifier_tokens": len(draft_scores) * 50,  # Rough estimate: ~50 tokens per verification call
        "verifier_model": self.model
    }

# Add method to RAGVerifier class
RAGVerifier.verify_and_select = _rag_verifier_verify_and_select

print("✓ RAGVerifier.verify_and_select method defined (with conditional probability scoring)")

✓ RAGVerifier.verify_and_select method defined (with conditional probability scoring)
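
The scoring above uses the raw probability of the first `yes` token it finds. One possible refinement, shown here as a sketch and not taken from the paper, is to normalise P(yes) against P(no) so that probability mass on unrelated tokens does not depress the score. The `(token, logprob)` pairs are an assumed simplification of the `top_logprobs` entries, and `yes_probability` is a hypothetical helper.

```python
import math
from typing import Iterable, Tuple


def yes_probability(top_logprobs: Iterable[Tuple[str, float]]) -> float:
    """Return P(yes) / (P(yes) + P(no)) from first-token (token, logprob) pairs."""
    p_yes = 0.0
    p_no = 0.0
    for token, logprob in top_logprobs:
        normalized = token.strip().lower()
        if normalized == "yes":
            p_yes += math.exp(logprob)
        elif normalized == "no":
            p_no += math.exp(logprob)
    total = p_yes + p_no
    # Neither token among the candidates: treat as no support for the draft.
    return p_yes / total if total > 0.0 else 0.0


candidates = [("Yes", math.log(0.6)), (" yes", math.log(0.1)), ("No", math.log(0.2))]
print(round(yes_probability(candidates), 3))
```

The pairs could be built from the API response with `[(t.token, t.logprob) for t in response.choices[0].logprobs.content[0].top_logprobs]`.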


## Speculative RAG Orchestrator

Coordinates the full drafter-verifier pipeline with document partitioning and parallel processing.
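
The cells below attach `partition_documents` to a `SpeculativeRAG` class and call `spec_rag.query(...)`, but the class definition itself does not appear in this export. The following is a minimal reconstruction, inferred only from the result keys and progress messages used in the examples. The constructor signature, the `retriever` parameter, the skipping of empty subsets, and the wording of `speedup_from_parallel` are all assumptions.

```python
import time
from typing import Any, Awaitable, Callable, Dict, List


class SpeculativeRAG:
    """Hypothetical orchestrator: retrieve -> partition -> draft -> verify."""

    def __init__(
        self,
        drafter: Any,   # assumed: RAGDrafter-like (.model, .generate_drafts_parallel)
        verifier: Any,  # assumed: RAGVerifier-like (.model, .verify_and_select)
        retriever: Callable[..., Awaitable[List[Dict]]],  # assumed: retrieve_documents
    ):
        self.drafter = drafter
        self.verifier = verifier
        self.retriever = retriever

    async def query(self, question: str, knowledge_base: List[Dict],
                    top_k: int = 6, num_drafts: int = 3) -> Dict[str, Any]:
        t0 = time.perf_counter()
        print(f"Retrieving top {top_k} documents...")
        retrieved = await self.retriever(question, knowledge_base, top_k=top_k)
        t1 = time.perf_counter()

        print(f"Partitioning documents into {num_drafts} subsets...")
        # partition_documents is attached to the class in a later cell;
        # empty subsets are skipped here (an assumption of this sketch).
        subsets = [s for s in self.partition_documents(retrieved, num_drafts) if s]

        print(f"Generating {len(subsets)} answer drafts in parallel "
              f"using {self.drafter.model}...")
        drafts = await self.drafter.generate_drafts_parallel(question, subsets)
        t2 = time.perf_counter()

        print(f"Verifying drafts and selecting best answer "
              f"using {self.verifier.model}...")
        verdict = await self.verifier.verify_and_select(question, drafts, retrieved)
        t3 = time.perf_counter()

        best = verdict["selected_draft"]
        drafter_tokens = sum(d["tokens"] for d in drafts)
        verifier_tokens = verdict["verifier_tokens"]
        return {
            "question": question,
            "answer": best["answer"],
            "rationale": best["rationale"],
            "confidence_score": verdict["confidence_score"],
            "verification_reasoning": verdict["verification_reasoning"],
            "all_drafts": drafts,
            "draft_scores": verdict["all_draft_scores"],
            "retrieved_documents": retrieved,
            "document_subsets": [[d["id"] for d in s] for s in subsets],
            "timing": {
                "retrieval": t1 - t0,
                "drafting": t2 - t1,
                "verification": t3 - t2,
                "total": t3 - t0,
                "speedup_from_parallel":
                    f"{len(drafts)} drafts generated concurrently in {t2 - t1:.2f}s",
            },
            "tokens": {
                "drafter_total": drafter_tokens,
                "verifier": verifier_tokens,
                "total": drafter_tokens + verifier_tokens,
            },
            "models_used": {
                "drafter": self.drafter.model,
                "verifier": self.verifier.model,
            },
        }
```

Under these assumptions, the instance used in the examples would be created with `spec_rag = SpeculativeRAG(RAGDrafter(client), RAGVerifier(client), retrieve_documents)` once `partition_documents` has been attached.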

In [87]:
def _speculative_rag_partition_documents(self, documents: List[Dict], num_partitions: int = 3) -> List[List[Dict]]:
    """
    Partition documents into DISTINCT subsets for parallel draft generation.
    
    Per the paper: "Each draft is generated from a distinct subset of retrieved documents,
    providing diverse perspectives."
    
    Strategy: Create non-overlapping subsets to ensure each draft sees different documents.
    This provides diverse perspectives and reduces token counts per draft.
    
    Args:
        documents: Retrieved documents sorted by relevance
        num_partitions: Number of distinct subsets to create
        
    Returns:
        List of non-overlapping document subsets
    """
    if len(documents) < num_partitions:
        # Fewer documents than requested partitions: give each document its own
        # subset and return fewer subsets, since generate_draft has no handling
        # for an empty document list.
        return [[doc] for doc in documents]
    
    # Create DISTINCT (non-overlapping) partitions as per the paper
    partition_size = len(documents) // num_partitions
    remainder = len(documents) % num_partitions
    
    subsets = []
    start_idx = 0
    
    for i in range(num_partitions):
        # Distribute remainder documents across first partitions
        current_size = partition_size + (1 if i < remainder else 0)
        end_idx = start_idx + current_size
        
        subset = documents[start_idx:end_idx]
        subsets.append(subset)
        
        start_idx = end_idx
    
    return subsets

# Add method to SpeculativeRAG class
SpeculativeRAG.partition_documents = _speculative_rag_partition_documents

print("✓ SpeculativeRAG.partition_documents method defined (distinct, non-overlapping subsets)")

✓ SpeculativeRAG.partition_documents method defined (distinct, non-overlapping subsets)
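
To make the split concrete: with 6 retrieved documents and 4 partitions, the subset sizes come out as 2, 2, 1, 1, which matches the subsets shown in the partitioning analysis output later in this notebook. A standalone sketch of the same slicing arithmetic; `partition_sizes` is a hypothetical helper for illustration.

```python
from typing import List


def partition_sizes(num_items: int, num_partitions: int) -> List[int]:
    """Sizes of contiguous partitions; earlier partitions absorb the remainder."""
    base, remainder = divmod(num_items, num_partitions)
    return [base + (1 if i < remainder else 0) for i in range(num_partitions)]


print(partition_sizes(6, 4))  # [2, 2, 1, 1]
print(partition_sizes(5, 3))  # [2, 2, 1]
```

Because the slices are contiguous over a relevance-sorted list, the top-ranked documents land together in the first subset. A round-robin split (`documents[i::num_partitions]`) would spread them across drafts instead.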


## Initialize Speculative RAG System

In [88]:
# Run Speculative RAG query
question = "What is quantum entanglement and how does it work?"

result = await spec_rag.query(
    question=question,
    knowledge_base=KNOWLEDGE_BASE,
    top_k=6,
    num_drafts=3
)

print(f"\n{'='*80}")
print(f"SPECULATIVE RAG RESULT")
print(f"{'='*80}\n")

print(f"Question: {result['question']}\n")

print(f"SELECTED ANSWER:")
print(f"{result['answer']}\n")

print(f"Rationale: {result['rationale']}\n")

print(f"Confidence Score: {result['confidence_score']:.2f}")
print(f"Verification Reasoning: {result['verification_reasoning']}\n")

print(f"{'='*80}")
print(f"PERFORMANCE METRICS")
print(f"{'='*80}\n")

print(f"Total Time: {result['timing']['total']:.2f}s")
print(f"  - Retrieval: {result['timing']['retrieval']:.2f}s")
print(f"  - Drafting ({len(result['all_drafts'])} drafts in parallel): {result['timing']['drafting']:.2f}s")
print(f"  - Verification: {result['timing']['verification']:.2f}s")
print(f"\n{result['timing']['speedup_from_parallel']}\n")

print(f"Total Tokens: {result['tokens']['total']}")
print(f"  - Drafter ({result['models_used']['drafter']}): {result['tokens']['drafter_total']}")
print(f"  - Verifier ({result['models_used']['verifier']}): {result['tokens']['verifier']}\n")

print(f"{'='*80}")
print(f"ALL DRAFTS GENERATED")
print(f"{'='*80}\n")

for i, draft in enumerate(result['all_drafts']):
    score_info = next((s for s in result['draft_scores'] if s['draft_id'] == i), {})
    print(f"\nDraft {i+1}:")
    print(f"Documents Used: {draft['documents_used']}")
    print(f"Score: {score_info.get('score', 'N/A')}")
    print(f"\nAnswer: {draft['answer']}")
    print(f"\nRationale: {draft['rationale']}")
    if score_info.get('issues'):
        print(f"Issues Noted: {score_info['issues']}")
    print(f"\n{'-'*80}")

Retrieving top 6 documents...




Partitioning documents into 3 subsets...
Generating 3 answer drafts in parallel using gpt-4o-mini...
Verifying drafts and selecting best answer using gpt-4o...

SPECULATIVE RAG RESULT

Question: What is quantum entanglement and how does it work?

SELECTED ANSWER:
Quantum entanglement is a phenomenon where the quantum state of particles cannot be described independently, leading to strong correlations in measurements of their properties, even when separated by large distances. It occurs when particles interact in such a way that their states become interdependent, and measurements on one particle instantaneously affect the state of the other. This phenomenon is distinct from classical correlations and is a fundamental feature of quantum mechanics.

Rationale: This answer is based on Document 1, which explains the nature of quantum entanglement, its implications, and the correlations observed in measurements. Document 2 further elaborates on the concept of entanglement, describing how it

## Example 1: Basic Speculative RAG Query

Demonstrating the drafter-verifier pipeline with document partitioning.

In [89]:
# Analyze how document partitioning affects draft quality
question2 = "What is the EPR paradox and how does it relate to quantum entanglement?"

result2 = await spec_rag.query(
    question=question2,
    knowledge_base=KNOWLEDGE_BASE,
    top_k=6,
    num_drafts=4  # Generate 4 drafts for comparison
)

print(f"\n{'='*80}")
print(f"DOCUMENT PARTITIONING ANALYSIS")
print(f"{'='*80}\n")

print(f"Question: {result2['question']}\n")

print("Document Subsets Created:")
for i, subset_ids in enumerate(result2['document_subsets']):
    print(f"\nSubset {i+1}: Documents {subset_ids}")
    # Show which documents these are
    docs_in_subset = [doc for doc in result2['retrieved_documents'] if doc['id'] in subset_ids]
    for doc in docs_in_subset:
        print(f"  - Doc {doc['id']}: {doc['content'][:80]}...")

print(f"\n{'='*80}")
print(f"DRAFT COMPARISON")
print(f"{'='*80}\n")

for i, draft in enumerate(result2['all_drafts']):
    score_info = next((s for s in result2['draft_scores'] if s['draft_id'] == i), {})
    
    print(f"\nDraft {i+1} (Documents: {draft['documents_used']})")
    print(f"Score: {score_info.get('score', 'N/A')}")
    print(f"Answer: {draft['answer'][:200]}...")
    print(f"Rationale: {draft['rationale'][:150]}...")
    
print(f"\n{'='*80}")
print(f"FINAL SELECTION")
print(f"{'='*80}\n")

print(f"Selected Answer:\n{result2['answer']}\n")
print(f"Confidence: {result2['confidence_score']:.2f}")
print(f"Why: {result2['verification_reasoning']}")

Retrieving top 6 documents...




Partitioning documents into 4 subsets...
Generating 4 answer drafts in parallel using gpt-4o-mini...
Verifying drafts and selecting best answer using gpt-4o...

DOCUMENT PARTITIONING ANALYSIS

Question: What is the EPR paradox and how does it relate to quantum entanglement?

Document Subsets Created:

Subset 1: Documents [4, 1]
  - Doc 4: Paradox
The singlet state described above is the basis for one version of the EP...
  - Doc 1: Quantum entanglement is the phenomenon where the quantum state of each particle ...

Subset 2: Documents [2, 14]
  - Doc 2: History
Albert Einstein and Niels Bohr engaged in a long-running collegial dispu...
  - Doc 14: If the composite system is in this state, it is impossible to attribute to eithe...

Subset 3: Documents [3]
  - Doc 3: Concept
Meaning of entanglement
Just as energy is a resource that facilitates me...

Subset 4: Documents [5]
  - Doc 5: Failure of local hidden-variable theories
A possible resolution to the paradox i...

DRAFT COMPARISON




## Example 2: Comparing Draft Quality

Examining how different document subsets lead to different draft answers.

In [90]:
# Test multiple queries to analyze performance patterns
test_questions = [
    "What is Bell's theorem and its significance?",
    "How was quantum entanglement experimentally verified?",
    "What are the applications of quantum entanglement in quantum computing?"
]

print(f"Running Speculative RAG on {len(test_questions)} questions...\n")

results = []
for i, question in enumerate(test_questions):
    print(f"[{i+1}/{len(test_questions)}] {question}")
    result = await spec_rag.query(question, KNOWLEDGE_BASE, top_k=5, num_drafts=3)
    results.append(result)
    print(f"  ✓ Completed in {result['timing']['total']:.2f}s\n")

# Performance summary
print(f"\n{'='*80}")
print(f"PERFORMANCE SUMMARY")
print(f"{'='*80}\n")

total_tokens = 0
total_duration = 0
total_drafts = 0

print(f"{'Query':<50} {'Time (s)':<12} {'Tokens':<10} {'Confidence'}")
print("-" * 90)

for i, result in enumerate(results):
    tokens = result['tokens']['total']
    duration = result['timing']['total']
    conf = result['confidence_score']
    
    total_tokens += tokens
    total_duration += duration
    total_drafts += len(result['all_drafts'])
    
    q_short = result['question'][:47] + "..." if len(result['question']) > 50 else result['question']
    print(f"{q_short:<50} {duration:<12.2f} {tokens:<10} {conf:.2f}")

print("-" * 90)
print(f"{'TOTALS':<50} {total_duration:<12.2f} {total_tokens:<10}")
print(f"\nAverage per query: {total_duration/len(results):.2f}s, {total_tokens//len(results)} tokens")
print(f"Total drafts generated: {total_drafts}")
print(f"Average drafting time: {sum(r['timing']['drafting'] for r in results)/len(results):.2f}s")
print(f"Average verification time: {sum(r['timing']['verification'] for r in results)/len(results):.2f}s")

Running Speculative RAG on 3 questions...

[1/3] What is Bell's theorem and its significance?
Retrieving top 5 documents...




Partitioning documents into 3 subsets...
Generating 3 answer drafts in parallel using gpt-4o-mini...
Verifying drafts and selecting best answer using gpt-4o...
  ✓ Completed in 12.20s

[2/3] How was quantum entanglement experimentally verified?
Retrieving top 5 documents...




Partitioning documents into 3 subsets...
Generating 3 answer drafts in parallel using gpt-4o-mini...
Verifying drafts and selecting best answer using gpt-4o...
  ✓ Completed in 11.55s

[3/3] What are the applications of quantum entanglement in quantum computing?
Retrieving top 5 documents...




Partitioning documents into 3 subsets...
Generating 3 answer drafts in parallel using gpt-4o-mini...
Verifying drafts and selecting best answer using gpt-4o...
  ✓ Completed in 13.70s


PERFORMANCE SUMMARY

Query                                              Time (s)     Tokens     Confidence
------------------------------------------------------------------------------------------
What is Bell's theorem and its significance?       12.20        5606       0.90
How was quantum entanglement experimentally ver... 11.55        5610       0.90
What are the applications of quantum entangleme... 13.70        6930       0.90
------------------------------------------------------------------------------------------
TOTALS                                             37.44        18146     

Average per query: 12.48s, 6048 tokens
Total drafts generated: 9
Average drafting time: 2.94s
Average verification time: 7.10s


## Example 3: Performance Analysis Across Multiple Queries

In [91]:
# Try your own question with detailed analysis
custom_question = "What is quantum entanglement swapping and how does it work?"

result = await spec_rag.query(custom_question, KNOWLEDGE_BASE, top_k=6, num_drafts=3)

print(f"\n{'='*80}")
print(f"DETAILED SPECULATIVE RAG ANALYSIS")
print(f"{'='*80}\n")

print(f"Question: {result['question']}\n")

print(f"{'='*80}")
print("STEP 1: DOCUMENT RETRIEVAL")
print(f"{'='*80}\n")

print(f"Retrieved {len(result['retrieved_documents'])} documents:\n")
for doc in result['retrieved_documents']:
    print(f"Doc {doc['id']} (similarity: {doc['similarity_score']:.3f}):")
    print(f"  {doc['content'][:120]}...\n")

print(f"{'='*80}")
print("STEP 2: DOCUMENT PARTITIONING")
print(f"{'='*80}\n")

for i, subset_ids in enumerate(result['document_subsets']):
    print(f"Subset {i+1}: Documents {subset_ids}")

print(f"\n{'='*80}")
print("STEP 3: PARALLEL DRAFT GENERATION")
print(f"{'='*80}\n")

print(f"Generated {len(result['all_drafts'])} drafts in {result['timing']['drafting']:.2f}s")
print(f"(Sequential would take ~{result['timing']['drafting'] * len(result['all_drafts']):.2f}s)\n")

for draft in result['all_drafts']:
    print(f"\nDraft {draft['draft_id'] + 1}:")
    print(f"  Documents: {draft['documents_used']}")
    print(f"  Tokens: {draft['tokens']}")
    print(f"  Answer: {draft['answer'][:100]}...")

print(f"\n{'='*80}")
print("STEP 4: VERIFICATION & SELECTION")
print(f"{'='*80}\n")

print(f"Verification completed in {result['timing']['verification']:.2f}s")
print(f"Tokens used: {result['tokens']['verifier']}\n")

print("Draft Scores:")
for score in result['draft_scores']:
    print(f"  Draft {score['draft_id'] + 1}: {score['score']:.2f}")
    if score.get('issues'):
        print(f"    Issues: {score['issues']}")

print(f"\n{'='*80}")
print("FINAL ANSWER")
print(f"{'='*80}\n")

print(f"Selected: Draft {result['draft_scores'].index(max(result['draft_scores'], key=lambda x: x['score'])) + 1}")
print(f"Confidence: {result['confidence_score']:.2f}\n")

print(f"Answer:\n{result['answer']}\n")

print(f"Rationale:\n{result['rationale']}\n")

print(f"Verification Reasoning:\n{result['verification_reasoning']}")

Retrieving top 6 documents...




Partitioning documents into 3 subsets...
Generating 3 answer drafts in parallel using gpt-4o-mini...
Verifying drafts and selecting best answer using gpt-4o...

DETAILED SPECULATIVE RAG ANALYSIS

Question: What is quantum entanglement swapping and how does it work?

STEP 1: DOCUMENT RETRIEVAL

Retrieved 6 documents:

Doc 34 (similarity: 0.731):
  Entanglement swapping is variant of teleportation that allows two parties that have never interacted to share an entangl...

Doc 1 (similarity: 0.599):
  Quantum entanglement is the phenomenon where the quantum state of each particle in a group cannot be described independe...

Doc 33 (similarity: 0.590):
  Entanglement as a resource
In quantum information theory, entangled states are considered a 'resource', i.e., something ...

Doc 56 (similarity: 0.578):
  Applications
Entanglement has many applications in quantum information theory. With the aid of entanglement, otherwise i...

Doc 65 (similarity: 0.572):
  Methods of creating entanglement

## Example 4: Custom Question with Detailed Analysis

In [92]:
async def standard_rag(question: str, documents: List[Dict]) -> Dict[str, Any]:
    """
    Standard RAG baseline: feed all documents directly to a single LLM.
    """
    start_time = datetime.now()
    
    context_text = "\n\n".join([
        f"Document {doc['id']}:\n{doc['content']}" 
        for doc in documents
    ])
    
    prompt = f"""Based on the following context documents, answer the question.

Context:
{context_text}

Question: {question}

Provide a comprehensive answer based on the documents."""
    
    response = await client.chat.completions.create(
        model="gpt-4o",  # Use same model as verifier for fair comparison
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3
    )
    
    duration = (datetime.now() - start_time).total_seconds()
    
    return {
        "answer": response.choices[0].message.content,
        "tokens": response.usage.total_tokens,
        "duration": duration,
        "model": "gpt-4o"
    }


# Compare both approaches
test_question = "What is the relationship between quantum entanglement and nonlocality?"

print("Running Standard RAG...")
retrieved = await retrieve_documents(test_question, KNOWLEDGE_BASE, top_k=6)
standard_result = await standard_rag(test_question, retrieved)

print("Running Speculative RAG...")
spec_result = await spec_rag.query(test_question, KNOWLEDGE_BASE, top_k=6, num_drafts=3)

print(f"\n{'='*80}")
print("STANDARD RAG vs SPECULATIVE RAG COMPARISON")
print(f"{'='*80}\n")

print(f"Question: {test_question}\n")

print(f"{'='*80}")
print("STANDARD RAG")
print(f"{'='*80}\n")
print(f"Model: {standard_result['model']}")
print(f"Duration: {standard_result['duration']:.2f}s")
print(f"Tokens: {standard_result['tokens']}")
print(f"\nAnswer:\n{standard_result['answer']}\n")

print(f"{'='*80}")
print("SPECULATIVE RAG")
print(f"{'='*80}\n")
print(f"Models: {spec_result['models_used']['drafter']} (drafter) + {spec_result['models_used']['verifier']} (verifier)")
print(f"Duration: {spec_result['timing']['total']:.2f}s")
print(f"  - Drafting (parallel): {spec_result['timing']['drafting']:.2f}s")
print(f"  - Verification: {spec_result['timing']['verification']:.2f}s")
print(f"Tokens: {spec_result['tokens']['total']}")
print(f"  - Drafter: {spec_result['tokens']['drafter_total']}")
print(f"  - Verifier: {spec_result['tokens']['verifier']}")
print(f"Confidence: {spec_result['confidence_score']:.2f}")
print(f"\nAnswer:\n{spec_result['answer']}\n")
print(f"Rationale:\n{spec_result['rationale']}\n")

print(f"{'='*80}")
print("ANALYSIS")
print(f"{'='*80}\n")

speedup = ((standard_result['duration'] - spec_result['timing']['total']) / standard_result['duration']) * 100
token_diff = spec_result['tokens']['total'] - standard_result['tokens']

print(f"Latency: Speculative RAG is {abs(speedup):.1f}% {'faster' if speedup > 0 else 'slower'}")
print(f"Token Usage: Speculative RAG uses {abs(token_diff)} {'more' if token_diff > 0 else 'fewer'} tokens")
print(f"\nKey Advantages of Speculative RAG:")
print("  1. Multiple draft perspectives increase robustness")
print("  2. Parallel drafting improves latency (when network-bound)")
print("  3. Verification step provides confidence scoring")
print("  4. Drafter specialization can improve accuracy with fine-tuning")

Running Standard RAG...


Running Speculative RAG...
Retrieving top 6 documents...


Partitioning documents into 3 subsets...
Generating 3 answer drafts in parallel using gpt-4o-mini...
Verifying drafts and selecting best answer using gpt-4o...

STANDARD RAG vs SPECULATIVE RAG COMPARISON

Question: What is the relationship between quantum entanglement and nonlocality?

STANDARD RAG

Model: gpt-4o
Duration: 14.57s
Tokens: 2993

Answer:
Quantum entanglement and nonlocality are closely related concepts in quantum mechanics, but they are not identical. Here's a comprehensive explanation of their relationship based on the provided documents:

1. **Quantum Entanglement**: 
   - Quantum entanglement is a phenomenon where the quantum state of each particle in a group cannot be described independently of the state of the others, even when the particles are separated by large distances (Document 1). It is a fundamental feature of quantum mechanics that distinguishes it from classical mechanics.
   - Entanglement results in strong correlations between measurements on entangled pa

## Comparison: Speculative RAG vs Standard RAG

Demonstrating the advantages of the drafter-verifier architecture.

In [93]:
async def standard_rag(question: str, documents: List[Dict]) -> Dict[str, Any]:
    """
    Standard RAG baseline: feed all documents directly to a single LLM.
    """
    start_time = datetime.now()
    
    context_text = "\n\n".join([
        f"Document {doc['id']}:\n{doc['content']}" 
        for doc in documents
    ])
    
    prompt = f"""Based on the following context documents, answer the question.

Context:
{context_text}

Question: {question}

Provide a comprehensive answer based on the documents."""
    
    response = await client.chat.completions.create(
        model="gpt-4o",  # Use same model as verifier for fair comparison
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3
    )
    
    duration = (datetime.now() - start_time).total_seconds()
    
    return {
        "answer": response.choices[0].message.content,
        "tokens": response.usage.total_tokens,
        "duration": duration,
        "model": "gpt-4o"
    }


# Compare both approaches
test_question = "What are transformers and how do they process data?"

print("Running Standard RAG...")
retrieved = await retrieve_documents(test_question, KNOWLEDGE_BASE, top_k=6)
standard_result = await standard_rag(test_question, retrieved)

print("Running Speculative RAG...")
spec_result = await spec_rag.query(test_question, KNOWLEDGE_BASE, top_k=6, num_drafts=3)

print(f"\n{'='*80}")
print("STANDARD RAG vs SPECULATIVE RAG COMPARISON")
print(f"{'='*80}\n")

print(f"Question: {test_question}\n")

print(f"{'='*80}")
print("STANDARD RAG")
print(f"{'='*80}\n")
print(f"Model: {standard_result['model']}")
print(f"Duration: {standard_result['duration']:.2f}s")
print(f"Tokens: {standard_result['tokens']}")
print(f"\nAnswer:\n{standard_result['answer']}\n")

print(f"{'='*80}")
print("SPECULATIVE RAG")
print(f"{'='*80}\n")
print(f"Models: {spec_result['models_used']['drafter']} (drafter) + {spec_result['models_used']['verifier']} (verifier)")
print(f"Duration: {spec_result['timing']['total']:.2f}s")
print(f"  - Drafting (parallel): {spec_result['timing']['drafting']:.2f}s")
print(f"  - Verification: {spec_result['timing']['verification']:.2f}s")
print(f"Tokens: {spec_result['tokens']['total']}")
print(f"  - Drafter: {spec_result['tokens']['drafter_total']}")
print(f"  - Verifier: {spec_result['tokens']['verifier']}")
print(f"Confidence: {spec_result['confidence_score']:.2f}")
print(f"\nAnswer:\n{spec_result['answer']}\n")
print(f"Rationale:\n{spec_result['rationale']}\n")

print(f"{'='*80}")
print("ANALYSIS")
print(f"{'='*80}\n")

speedup = ((standard_result['duration'] - spec_result['timing']['total']) / standard_result['duration']) * 100
token_diff = spec_result['tokens']['total'] - standard_result['tokens']

print(f"Latency: Speculative RAG is {abs(speedup):.1f}% {'faster' if speedup > 0 else 'slower'}")
print(f"Token Usage: Speculative RAG uses {abs(token_diff)} {'more' if token_diff > 0 else 'fewer'} tokens")
print(f"\nKey Advantages of Speculative RAG:")
print("  1. Multiple draft perspectives increase robustness")
print("  2. Parallel drafting improves latency (when network-bound)")
print("  3. Verification step provides confidence scoring")
print("  4. Drafter specialization can improve accuracy with fine-tuning")

Running Standard RAG...


Running Speculative RAG...
Retrieving top 6 documents...


Partitioning documents into 3 subsets...
Generating 3 answer drafts in parallel using gpt-4o-mini...
Verifying drafts and selecting best answer using gpt-4o...

STANDARD RAG vs SPECULATIVE RAG COMPARISON

Question: What are transformers and how do they process data?

STANDARD RAG

Model: gpt-4o
Duration: 1.19s
Tokens: 1832

Answer:
The provided documents do not contain any information about transformers or how they process data. The documents focus on quantum entanglement, its applications, and its role as a resource in quantum information theory. If you have any other questions or need information on a different topic, feel free to ask!

SPECULATIVE RAG

Models: gpt-4o-mini (drafter) + gpt-4o (verifier)
Duration: 9.36s
  - Drafting (parallel): 2.59s
  - Verification: 3.03s
Tokens: 4743
  - Drafter: 2356
  - Verifier: 2387
Confidence: 0.95

Answer:
The documents do not provide information on transformers or how they process data.

Rationale:
The provided documents focus on entanglement

## Summary

This notebook implements **Speculative RAG** following Google Research's paper architecture:

### Architecture

1. **RAGDrafter**: Smaller specialist model (gpt-4o-mini) that generates multiple answer drafts with rationales, each from different document subsets
2. **RAGVerifier**: Larger generalist model (gpt-4o) that scores drafts and selects the best answer based on quality and relevance
3. **Parallel Processing**: Documents are partitioned and processed in parallel for latency reduction
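
The three pieces above can be sketched end to end without calling any model. This is a minimal illustration, not the notebook's `RAGDrafter` or `RAGVerifier`: `draft_answer`, `score_draft`, and `speculative_pipeline` are hypothetical names, both coroutines are stand-ins that only sleep briefly, and the "score" is simply the rationale length so the selection step has something to compare.

```python
import asyncio
from typing import Dict, List


async def draft_answer(draft_id: int, docs: List[str]) -> Dict:
    """Hypothetical stand-in for the drafter call (no real LLM request is made)."""
    await asyncio.sleep(0.01)  # simulates network latency
    return {
        "draft_id": draft_id,
        "answer": f"answer from {len(docs)} docs",
        "rationale": "; ".join(docs),
    }


async def score_draft(draft: Dict) -> float:
    """Hypothetical stand-in for the verifier: longer rationales score higher."""
    await asyncio.sleep(0.01)
    return float(len(draft["rationale"]))


async def speculative_pipeline(subsets: List[List[str]]) -> Dict:
    # Launch one drafter task per document subset and wait for all of them together.
    drafts = await asyncio.gather(*(draft_answer(i, s) for i, s in enumerate(subsets)))
    # Score every draft, then keep the highest-scoring one.
    scores = await asyncio.gather(*(score_draft(d) for d in drafts))
    best = max(range(len(drafts)), key=lambda i: scores[i])
    return {"selected": drafts[best], "scores": list(scores)}


demo_subsets = [["doc A"], ["doc B", "doc C"], ["doc D"]]
demo_result = asyncio.run(speculative_pipeline(demo_subsets))
print(demo_result["selected"]["draft_id"])
```

Inside a notebook cell, where an event loop is already running, `await speculative_pipeline(demo_subsets)` would replace the `asyncio.run(...)` call.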

### Key Benefits from the Paper

- **12.97% accuracy improvement** on PubHealth dataset
- **51% latency reduction** compared to standard RAG
- Better handling of diverse document sets through parallel draft generation
- Confidence scoring via verification step
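
For comparison with the verification step, the paper describes scoring each draft by a conditional generation probability. The sketch below shows only the arithmetic, assuming per-token log-probabilities for each draft's answer are already available; how they are obtained is outside this sketch, and `sequence_logprob`, `select_by_probability`, and the sample values are hypothetical.

```python
import math
from typing import Dict, List, Tuple


def sequence_logprob(token_logprobs: List[float]) -> float:
    """Log-probability of a whole sequence: the sum of its per-token log-probs."""
    return sum(token_logprobs)


def select_by_probability(drafts: Dict[str, List[float]]) -> Tuple[str, float]:
    """Return the draft with the highest sequence probability, and that probability."""
    name = max(drafts, key=lambda k: sequence_logprob(drafts[k]))
    return name, math.exp(sequence_logprob(drafts[name]))


# Made-up log-probs for two drafts (illustrative values only).
sample_logprobs = {
    "draft_1": [-0.1, -0.2, -0.1],
    "draft_2": [-1.0, -0.9],
}
best_name, best_prob = select_by_probability(sample_logprobs)
print(best_name, round(best_prob, 3))
```

Summing raw log-probs favors shorter answers, so a length-normalized variant (dividing the sum by the token count) may be preferable in practice; that choice is an assumption here, not something this notebook specifies.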

### Implementation Details

- **Drafter model**: Uses low temperature (0.3) for consistent drafting
- **Verifier model**: Uses very low temperature (0.1) for reliable evaluation
- **Document partitioning**: Creates distinct (non-overlapping) subsets so that each draft draws on different documents
- **Rationale-based scoring**: Each draft includes reasoning that the verifier uses for assessment
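
The pipeline visualization in this notebook shows six retrieved documents split into three subsets of two (`[34, 1]`, `[33, 56]`, `[65, 2]`). A minimal sketch of a partition that reproduces that split is shown below, assuming the ranked list is cut into contiguous chunks with no document shared between subsets. `partition_documents` is a hypothetical name, and the notebook's own implementation may differ.

```python
from typing import List


def partition_documents(doc_ids: List[int], num_subsets: int) -> List[List[int]]:
    """Split ranked doc IDs into contiguous, non-overlapping subsets of near-equal size."""
    if num_subsets <= 0:
        raise ValueError("num_subsets must be positive")
    # Earlier subsets absorb the remainder when the split is uneven.
    base, extra = divmod(len(doc_ids), num_subsets)
    subsets, start = [], 0
    for i in range(num_subsets):
        size = base + (1 if i < extra else 0)
        subsets.append(doc_ids[start:start + size])
        start += size
    return subsets


print(partition_documents([34, 1, 33, 56, 65, 2], 3))
```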

### Potential Improvements

1. **Fine-tune drafter**: Train a specialized smaller model on RAG tasks (as in the paper)
2. **Smarter partitioning**: Use relevance scores to create better document subsets
3. **Adaptive drafting**: Adjust number of drafts based on query complexity
4. **Caching**: Cache document embeddings to speed up retrieval
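
Two of these ideas are small enough to sketch. The block below is illustrative only: `round_robin_partition` deals documents out by descending similarity so that every subset receives a mix of strong and weak matches, and `EmbeddingCache` memoizes embeddings by text. Both names are hypothetical, the embedding function is a toy stand-in, and the sample pairs are the five (ID, similarity) values visible in the retrieval output above.

```python
from typing import Callable, Dict, List, Tuple


def round_robin_partition(
    scored_docs: List[Tuple[int, float]], num_subsets: int
) -> List[List[int]]:
    """Deal documents out by descending similarity so each subset gets a mix of ranks."""
    ranked = sorted(scored_docs, key=lambda pair: pair[1], reverse=True)
    subsets: List[List[int]] = [[] for _ in range(num_subsets)]
    for rank, (doc_id, _score) in enumerate(ranked):
        subsets[rank % num_subsets].append(doc_id)
    return subsets


class EmbeddingCache:
    """Memoizes embeddings by text so repeated lookups skip the embedding call."""

    def __init__(self, embed_fn: Callable[[str], List[float]]):
        self._embed_fn = embed_fn
        self._store: Dict[str, List[float]] = {}
        self.misses = 0  # number of times the underlying embed_fn was invoked

    def get(self, text: str) -> List[float]:
        if text not in self._store:
            self.misses += 1
            self._store[text] = self._embed_fn(text)
        return self._store[text]


scored = [(34, 0.731), (1, 0.599), (33, 0.590), (56, 0.578), (65, 0.572)]
print(round_robin_partition(scored, 3))

# Toy embedding function: a one-element vector holding the text length.
cache = EmbeddingCache(lambda text: [float(len(text))])
cache.get("entanglement")
cache.get("entanglement")
print(cache.misses)
```

Whether round-robin dealing actually improves draft quality over contiguous chunks would need to be measured; it is offered here only as one possible reading of "smarter partitioning".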

In [94]:
# Utility: Visualize the Speculative RAG pipeline for a result
def visualize_pipeline(result: Dict[str, Any]):
    """Visualize the Speculative RAG pipeline flow."""
    print(f"\n{'='*80}")
    print("SPECULATIVE RAG PIPELINE VISUALIZATION")
    print(f"{'='*80}\n")
    
    print(f"Question: {result['question']}\n")
    
    # Step 1: Retrieval
    print("┌─ STEP 1: RETRIEVAL")
    print(f"│  Retrieved {len(result['retrieved_documents'])} documents")
    print(f"│  Time: {result['timing']['retrieval']:.2f}s")
    print("│")
    
    # Step 2: Partitioning
    print("├─ STEP 2: PARTITIONING")
    for i, subset in enumerate(result['document_subsets']):
        print(f"│  Subset {i+1}: Docs {subset}")
    print("│")
    
    # Step 3: Parallel Drafting
    print("├─ STEP 3: PARALLEL DRAFTING (gpt-4o-mini)")
    print(f"│  Generated {len(result['all_drafts'])} drafts in parallel")
    print(f"│  Time: {result['timing']['drafting']:.2f}s")
    print(f"│  Tokens: {result['tokens']['drafter_total']}")
    for i, draft in enumerate(result['all_drafts']):
        score = next((s['score'] for s in result['draft_scores'] if s['draft_id'] == i), 'N/A')
        print(f"│    Draft {i+1}: {draft['answer'][:60]}... (score: {score})")
    print("│")
    
    # Step 4: Verification
    print("├─ STEP 4: VERIFICATION (gpt-4o)")
    print(f"│  Evaluated all drafts")
    print(f"│  Time: {result['timing']['verification']:.2f}s")
    print(f"│  Tokens: {result['tokens']['verifier']}")
    # Index of the highest-scoring draft (the first one wins on ties)
    best_idx = max(range(len(result['draft_scores'])), key=lambda i: result['draft_scores'][i]['score'])
    print(f"│  Selected: Draft {best_idx + 1}")
    print(f"│  Confidence: {result['confidence_score']:.2f}")
    print("│")
    
    # Final Answer
    print("└─ FINAL ANSWER")
    print(f"   {result['answer'][:100]}...")
    print()
    
    # Totals
    print(f"Total Time: {result['timing']['total']:.2f}s")
    print(f"Total Tokens: {result['tokens']['total']}")
    

# Example usage with a previous result
if 'result' in locals():
    visualize_pipeline(result)


SPECULATIVE RAG PIPELINE VISUALIZATION

Question: What is quantum entanglement swapping and how does it work?

┌─ STEP 1: RETRIEVAL
│  Retrieved 6 documents
│  Time: 2.18s
│
├─ STEP 2: PARTITIONING
│  Subset 1: Docs [34, 1]
│  Subset 2: Docs [33, 56]
│  Subset 3: Docs [65, 2]
│
├─ STEP 3: PARALLEL DRAFTING (gpt-4o-mini)
│  Generated 3 drafts in parallel
│  Time: 3.18s
│  Tokens: 3212
│    Draft 1: Quantum entanglement swapping is a process that allows two p... (score: 0.9)
│    Draft 2: The provided documents do not contain information about quan... (score: 0.2)
│    Draft 3: Quantum entanglement swapping is a method of creating entang... (score: 0.7)
│
├─ STEP 4: VERIFICATION (gpt-4o)
│  Evaluated all drafts
│  Time: 6.06s
│  Tokens: 3256
│  Selected: Draft 1
│  Confidence: 0.90
│
└─ FINAL ANSWER
   Quantum entanglement swapping is a process that allows two parties, who have never interacted, to sh...

Total Time: 11.42s
Total Tokens: 6468


## References

- **Paper**: [Speculative RAG: Enhancing Retrieval Augmented Generation through Drafting](https://arxiv.org/abs/2407.08223) ([Google Research blog post](https://research.google/blog/speculative-rag-enhancing-retrieval-augmented-generation-through-drafting/))
- **Key Innovation**: Drafter-Verifier architecture where a smaller specialist model generates multiple answer drafts in parallel, and a larger model verifies and selects the best one
- **Performance**: 12.97% accuracy improvement on PubHealth and 51% latency reduction compared to standard RAG, as reported in the paper