# Enhancing RAG with Contextual Retrieval

> Note: For more background information on Contextual Retrieval, including additional performance evaluations on various datasets, we recommend reading our accompanying [blog post](https://www.anthropic.com/news/contextual-retrieval).

Retrieval Augmented Generation (RAG) enables Claude to leverage your internal knowledge bases, codebases, or any other corpus of documents when providing a response. Enterprises are increasingly building RAG applications to improve workflows in customer support, Q&A over internal company documents, financial & legal analysis, code generation, and much more.

In a [separate guide](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/retrieval_augmented_generation/guide.ipynb), we walked through setting up a basic retrieval system, demonstrated how to evaluate its performance, and then outlined a few techniques to improve performance. In this guide, we present a technique for improving retrieval performance: Contextual Embeddings.

In traditional RAG, documents are typically split into smaller chunks for efficient retrieval. While this approach works well for many applications, it can lead to problems when individual chunks lack sufficient context. Contextual Embeddings solve this problem by adding relevant context to each chunk before embedding. This method improves the quality of each embedded chunk, allowing for more accurate retrieval and thus better overall performance. Averaged across all data sources we tested, Contextual Embeddings reduced the top-20-chunk retrieval failure rate by 35%.

The same chunk-specific context can also be used with BM25 search to further improve retrieval performance. We introduce this technique in the “Contextual BM25” section.

In this guide, we'll demonstrate how to build and optimize a Contextual Retrieval system using a dataset of 9 codebases as our knowledge base. We'll walk through:

1) Setting up a basic retrieval pipeline to establish a baseline for performance.

2) Contextual Embeddings: what it is, why it works, and how prompt caching makes it practical for production use cases.

3) Implementing Contextual Embeddings and demonstrating performance improvements.

4) Contextual BM25: improving performance with *contextual* BM25 hybrid search.

5) Improving performance with reranking.

### Evaluation Metrics & Dataset:

We use a pre-chunked dataset of 9 codebases, all of which have been chunked according to a basic character splitting mechanism. Our evaluation dataset contains 248 queries, each of which is paired with a 'golden chunk.' We'll use a metric called Pass@k to evaluate performance. Pass@k checks whether the golden chunk was present in the first k chunks retrieved for each query. In this case, Contextual Embeddings improved Pass@10 performance from ~87% to ~95%.
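
As a rough illustration of the metric, Pass@k can be computed along the following lines. This is a minimal sketch, not the evaluation code behind the numbers above: the `pass_at_k` helper and the toy chunk IDs are hypothetical, and it assumes each query has exactly one golden chunk ID.

```python
from typing import Dict, List


def pass_at_k(retrieved: Dict[str, List[str]], golden: Dict[str, str], k: int) -> float:
    """Fraction of queries whose golden chunk ID appears in the top-k retrieved IDs.

    `retrieved` maps each query to its ranked list of chunk IDs (best first);
    `golden` maps each query to the single chunk ID that should be found.
    """
    if not golden:
        raise ValueError("golden must contain at least one query")
    hits = sum(
        1
        for query, gold_id in golden.items()
        if gold_id in retrieved.get(query, [])[:k]
    )
    return hits / len(golden)


# Toy data (hypothetical IDs, for illustration only)
toy_retrieved = {
    "q1": ["c3", "c7", "c1"],
    "q2": ["c2", "c9", "c4"],
}
toy_golden = {"q1": "c7", "q2": "c4"}

print(pass_at_k(toy_retrieved, toy_golden, k=2))  # q1 is a hit, q2 is a miss
print(pass_at_k(toy_retrieved, toy_golden, k=3))  # both queries are hits
```

Raising k can only keep the score the same or increase it, so scores are only comparable at the same k.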

You can find the code files and their chunks in `data/codebase_chunks.json` and the evaluation dataset in `data/evaluation_set.jsonl`.

#### Additional Notes:

Prompt caching is helpful in managing costs when using this retrieval method. This feature is currently available on Anthropic's 1P API, and is coming soon to our 3P partner environments in AWS Bedrock and GCP Vertex. We know that many of our customers leverage AWS Knowledge Bases and GCP Vertex AI APIs when building RAG solutions, and this method can be used on either platform with a bit of customization. Consider reaching out to Anthropic or your AWS/GCP account team for guidance on this!

To make it easier to use this method on Bedrock, the AWS team has provided us with code that you can use to implement a Lambda function that adds context to each document. If you deploy this Lambda function, you can select it as a custom chunking option when configuring a [Bedrock Knowledge Base](https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base-create.html). You can find this code in `contextual-rag-lambda-function`. The main lambda function code is in `lambda_function.py`.

## Table of Contents

1) Setup

2) Basic RAG

3) Contextual Embeddings

4) Contextual BM25

5) Reranking


In [1]:
!pip install anthropic
!pip install openai
!pip install cohere
!pip install elasticsearch
!pip install pandas
!pip install numpy
!pip install matplotlib
!pip install tqdm
!pip install python-dotenv

In [2]:
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Get API keys from environment variables
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') 
ANTHROPIC_API_KEY = os.getenv('ANTHROPIC_API_KEY')
COHERE_API_KEY = os.getenv('COHERE_API_KEY')  # I'm running on Cohere's free student tier

In [3]:
import anthropic

client = anthropic.Anthropic(
    # This is the default and can be omitted
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)

## Level 1: Naive RAG

### Initialize a Vector DB Class

This is a simple in-memory vector DB, for example purposes.
For production workloads, you can use a hosted vector database instead.


In [4]:
import os
import pickle
import json
import numpy as np
from openai import OpenAI
from typing import List, Dict, Any
from tqdm import tqdm

class VectorDB:
    def __init__(self, name: str, api_key = None):
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/vector_db.pkl"

    def load_data(self, dataset: List[Dict[str, Any]]):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)
        
        with tqdm(total=total_chunks, desc="Processing chunks") as pbar:
            for doc in dataset:
                for chunk in doc['chunks']:
                    texts_to_embed.append(chunk['content'])
                    metadata.append({
                        'doc_id': doc['doc_id'],
                        'original_uuid': doc['original_uuid'],
                        'chunk_id': chunk['chunk_id'],
                        'original_index': chunk['original_index'],
                        'content': chunk['content']
                    })
                    pbar.update(1)

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()
        
        print(f"Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")
    
    # Helper for splitting long texts to fit OpenAI's embedding input limit (not called elsewhere in this class)
    def _chunk_text(self, text, max_tokens=7000, overlap_tokens=500):
        """Split text into overlapping chunks that fit within token limits"""
        # Rough estimate: 1 token ≈ 4 characters
        max_chars = max_tokens * 4
        overlap_chars = overlap_tokens * 4
        
        if len(text) <= max_chars:
            return [text]
        
        chunks = []
        start = 0
        
        while start < len(text):
            end = start + max_chars
            
            # If this isn't the last chunk, try to break at a sentence or paragraph
            if end < len(text):
                # Look for sentence endings in the last 500 characters
                break_point = text.rfind('.', end - 500, end)
                if break_point == -1:
                    break_point = text.rfind('\n', end - 500, end)
                if break_point != -1:
                    end = break_point + 1
            
            chunk = text[start:end]
            chunks.append(chunk)
            
            # Move start position, accounting for overlap
            if end >= len(text):
                break
            start = end - overlap_chars
            
        print(f"Split text of {len(text)} characters into {len(chunks)} chunks")
        return chunks


    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 100
        with tqdm(total=len(texts), desc="Embedding chunks") as pbar:
            result = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                response = self.client.embeddings.create(
                    input = batch,
                    model='text-embedding-3-small'
                )
                
                batch_embeddings = [item.embedding for item in response.data]
                result.extend(batch_embeddings)
                pbar.update(len(batch))
        
        self.embeddings = result
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.client.embeddings.create(
                input = [query], 
                model="text-embedding-3-small"
                ).data[0].embedding
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]
        
        top_results = []

        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)
        
        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

    def validate_embedded_chunks(self):
        unique_contents = set()
        for meta in self.metadata:
            unique_contents.add(meta['content'])
    
        print(f"Validation results:")
        print(f"Total embedded chunks: {len(self.metadata)}")
        print(f"Unique embedded contents: {len(unique_contents)}")
    
        if len(self.metadata) != len(unique_contents):
            print("Warning: There may be duplicate chunks in the embedded data.")
        else:
            print("All embedded chunks are unique.")

## Level 1 - Basic RAG

To get started, we'll set up a basic RAG pipeline using a bare-bones approach, often called 'Naive RAG' in the industry. A basic RAG pipeline includes the following 3 steps:

1) Chunk documents by heading, so that each chunk contains only the content under one subheading

2) Embed each chunk

3) Use cosine similarity to retrieve the chunks most relevant to the query
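
Step 3 can be sketched in a few lines. The example below is illustrative only: the three-dimensional vectors are made up (real embeddings have far more dimensions), and `cosine_similarity` and `top_k` are hypothetical helper names rather than part of the `VectorDB` class.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # treat a zero vector as matching nothing
    return dot / (norm_a * norm_b)


def top_k(
    query_vec: Sequence[float], chunk_vecs: List[Sequence[float]], k: int
) -> List[Tuple[int, float]]:
    """Return (index, similarity) pairs for the k most similar chunks, best first."""
    scored = [(i, cosine_similarity(query_vec, v)) for i, v in enumerate(chunk_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Made-up vectors standing in for chunk and query embeddings
toy_chunks = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.7, 0.7, 0.0],
]
toy_query = [1.0, 0.1, 0.0]

ranked = top_k(toy_query, toy_chunks, k=2)
print([index for index, _ in ranked])
```

The `VectorDB.search` method above uses a plain dot product instead, which produces the same ranking as cosine similarity whenever the embedding vectors are normalized to unit length.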

In [6]:
# Load and transform the benefits_wellbeing data
# Create the corresponding Pickle file for the VectorDB

import json
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from tqdm import tqdm
import logging
from typing import Callable, List, Dict, Any, Tuple, Set

with open('data/employee_handbook.json', 'r') as f:
    employee_handbook_raw = json.load(f)

def transform_data_for_vectordb(raw_data):
    """Transform the benefits_wellbeing data to match VectorDB structure"""
    transformed = []
    
    # Create a single document containing all chunks
    doc = {
        "doc_id": "benefits_wellbeing",
        "original_uuid": "benefits_wellbeing_doc",
        "chunks": []
    }
    
    for i, item in enumerate(raw_data):
        chunk = {
            "chunk_id": f"chunk_{i}",
            "original_index": i,
            "content": item["text"],
            "heading": item["chunk_heading"],
            "link": item["chunk_link"]
        }
        doc["chunks"].append(chunk)
    
    transformed.append(doc)
    return transformed

# Transform the data to match VectorDB expected structure
employee_handbook = transform_data_for_vectordb(employee_handbook_raw)

# Initialize the VectorDB
db = VectorDB("employee_handbook")
db.load_data(employee_handbook)

print(f"Loaded {len(employee_handbook[0]['chunks'])} chunks into VectorDB")


Processing chunks: 100%|██████████| 77/77 [00:00<00:00, 819699.01it/s]
Embedding chunks: 100%|██████████| 77/77 [00:01<00:00, 56.71it/s]

Vector database loaded and saved. Total chunks processed: 77
Loaded 77 chunks into VectorDB





In [56]:
# RAG Functions

def retrieve_base(query, db, k=5):
    """Retrieve relevant documents and format context"""
    results = db.search(query, k=k)
    context = ""
    for result in results:
        chunk = result['metadata']
        context += f"\n{chunk['content']}\n"
    return results, context

def answer_query_base(query, db):
    """Answer a query using the RAG pipeline"""
    documents, context = retrieve_base(query, db)
    prompt = f""" ### SYSTEM ###
    You are **Uniswap Benefits Assistant**.

    You have been provided with relevant company documents to answer employee questions.

    **Workflow for this query:**
    1. **Analyze the user question**: {query}
    2. **Review the provided context** below for relevant information
    3. **Write a natural-language answer** following these rules:
    • Use only facts that appear verbatim in the provided context
    • If the information isn't in the context, reply: "I don't have that information in the provided context."
    • Compose the most **expansive, detailed answer possible** by weaving together **every relevant fact** found in the context—rephrasing, grouping, and elaborating on those facts for clarity and flow
    • You may explain terms, list related details, and provide a logical structure
    • **Never introduce information that is not stated verbatim in the context**
    • **Always aim to reduce verbosity**
    4. **Do not** provide preamble such as "Here is the answer" or "Based on the documents"
    5. **Always append exactly**: "Double-check with Julian or Megan for any of this information!"

    ### USER QUESTION ###
    {query}

    ### CONTEXT ###
    {context}

    ### RESPONSE ###"""
    
    response = client.messages.create(
        model="claude-sonnet-4-20250514",  # use a stronger model for answering than for context enrichment
        max_tokens=2500,
        messages=[
            {"role": "user", "content": prompt}
        ],
        temperature=0  # minimizes sampling randomness for more consistent, factual answers
    )
    # Handle the response content properly
    try:
        return response.content[0].text
    except AttributeError:
        return str(response.content[0])



In [57]:
# Test the naive RAG system
test_query = "Do we get any sort of wfh-setup-refresh benefits? I'm tormented by the fact that I didn't get a proper chair within my first year of joining"
answer = answer_query_base(test_query, db)
print(f"Query: {test_query}")
print(f"Answer: {answer}")

Query: Do we get any sort of wfh-setup-refresh benefits? I'm tormented by the fact that I didn't get a proper chair within my first year of joining
Answer: Uniswap Labs provides a Home Office Set up benefit for remote employees, which reimburses up to $2,000 USD to cover the purchase of office supplies, productivity items, and anything else you might need to get your home office set up. This benefit is designed to help you create a proper workspace at home.

However, I don't have information in the provided context about whether this is a one-time benefit or if there are periodic refresh opportunities for work-from-home setup expenses beyond the initial $2,000 reimbursement.

Additionally, if you prefer working from a co-working space instead of your home office, Uniswap Labs reimburses the cost up to $500 USD per month for co-working space expenses.

Double-check with Julian or Megan for any of this information!


## Level 2: Contextual Embeddings

With basic RAG, each embedded chunk contains a potentially useful piece of information, but these chunks lack context. With Contextual Embeddings, we create a variation on the embedding itself by adding more context to each text chunk before embedding it. Specifically, we use Claude to create a concise context that explains the chunk using the context of the overall document. In the case of our codebases dataset, we can provide both the chunk and the full file that each chunk was found within to an LLM, then produce the context. Then, we will combine this 'context' and the raw text chunk together into a single text block prior to creating each embedding.

### Additional Considerations: Cost and Latency

The extra work we're doing to 'situate' each document happens only at ingestion time: it's a cost you'll pay once when you store each document (and periodically in the future if you have a knowledge base that updates over time). Many other approaches, such as HyDE (hypothetical document embeddings), instead improve the representation of the query before executing a search. These techniques have been shown to be moderately effective, but they add significant latency at runtime.

[Prompt caching](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching) also makes this much more cost-effective. Creating contextual embeddings requires us to pass the same document to the model for every chunk we want to generate extra context for. With prompt caching, we can write the overall doc to the cache once, and then, because we're doing our ingestion job all in sequence, we can just read the document from cache as we generate context for each chunk within that document (the information you write to the cache has a 5-minute time to live). This means that the first time we pass a document to the model, we pay a bit more to write it to the cache, but for each subsequent API call that contains that doc, we receive a 90% discount on all of the input tokens read from the cache. Assuming 800-token chunks, 8k-token documents, 50-token context instructions, and 100 tokens of context per chunk, the cost to generate contextualized chunks is $1.02 per million document tokens.
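
The arithmetic behind that estimate can be reproduced roughly as follows. This is a back-of-the-envelope sketch, not an official calculation: the per-token prices are assumptions (they are not stated in this guide, so check the current pricing page for your model), and as a simplification it counts one cache write per document plus a cache read on every per-chunk call.

```python
# Assumed prices in USD per million tokens (assumptions, not taken from this guide)
PRICE_INPUT = 0.25        # regular (uncached) input tokens
PRICE_OUTPUT = 1.25       # output tokens
PRICE_CACHE_WRITE = 0.30  # input tokens written to the cache
PRICE_CACHE_READ = 0.03   # input tokens read from the cache

# Workload figures taken from the paragraph above
DOC_TOKENS = 8_000
CHUNK_TOKENS = 800
INSTRUCTION_TOKENS = 50
CONTEXT_TOKENS = 100


def cost_per_million_doc_tokens() -> float:
    """Estimated USD cost to contextualize one million document tokens."""
    chunks_per_doc = DOC_TOKENS // CHUNK_TOKENS  # 10 chunks, so 10 API calls

    # Token counts multiplied by per-million prices; divided by 1e6 at the end.
    cache_write = DOC_TOKENS * PRICE_CACHE_WRITE
    cache_read = chunks_per_doc * DOC_TOKENS * PRICE_CACHE_READ
    uncached_input = chunks_per_doc * (CHUNK_TOKENS + INSTRUCTION_TOKENS) * PRICE_INPUT
    output = chunks_per_doc * CONTEXT_TOKENS * PRICE_OUTPUT

    cost_per_doc = (cache_write + cache_read + uncached_input + output) / 1_000_000
    docs_per_million_tokens = 1_000_000 / DOC_TOKENS
    return cost_per_doc * docs_per_million_tokens


print(f"${cost_per_million_doc_tokens():.2f} per million document tokens")
```

Under these assumed prices the result rounds to about $1.02. Different prices, chunk sizes, or cache hit rates will shift the figure.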

When you load data into your ContextualVectorDB below, you'll see in logs just how big this impact is. 

Warning: some smaller embedding models have a fixed input token limit. Contextualizing the chunk makes it longer, so if you notice much worse performance from contextualized embeddings, the contextualized chunk is likely getting truncated.
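
One quick way to spot chunks at risk of truncation is to estimate their token count before embedding. The sketch below is a heuristic only: it reuses the rough "1 token ≈ 4 characters" estimate from `_chunk_text` above, `EMBEDDING_TOKEN_LIMIT` is an assumed value that you should replace with your embedding model's documented limit, and `estimate_tokens` and `find_oversized_chunks` are hypothetical helpers.

```python
from typing import List

CHARS_PER_TOKEN = 4           # rough heuristic, not a real tokenizer
EMBEDDING_TOKEN_LIMIT = 8191  # assumed limit; check your embedding model's docs


def estimate_tokens(text: str) -> int:
    """Approximate token count from character length."""
    return len(text) // CHARS_PER_TOKEN + 1


def find_oversized_chunks(texts: List[str], limit: int = EMBEDDING_TOKEN_LIMIT) -> List[int]:
    """Return the indices of texts whose estimated token count exceeds `limit`."""
    return [i for i, text in enumerate(texts) if estimate_tokens(text) > limit]


# Toy example with a deliberately small limit so that one text is flagged
contextualized = ["short chunk\n\nshort context", "x" * 500]
print(find_oversized_chunks(contextualized, limit=100))
```

For exact counts, use the tokenizer that matches your embedding model rather than a character-based estimate.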

#### Create the contextual chunks

In [25]:
import json
from anthropic import Anthropic
from tqdm import tqdm

# Initialize Anthropic client
client = Anthropic()

# Template for the full benefits document
BENEFITS_DOCUMENT_CONTEXT_PROMPT = """
<document>
This is Uniswap Labs' comprehensive Benefits & Wellbeing documentation for employees. 
It covers health insurance, leave policies, perks, financial benefits, and time-off policies.

{doc_content}
</document>
"""

# Template for contextualizing individual chunks
CHUNK_CONTEXT_PROMPT = """
Here is a specific section from Uniswap's benefits documentation that we want to situate within the overall benefits package:

<chunk>
Section: {chunk_heading}
{chunk_content}
</chunk>

Please provide a short, succinct context to situate this benefits section within Uniswap's overall employee benefits package. This context will help employees find this information when searching for related benefits topics.

Answer only with the succinct context and nothing else.
"""

def situate_benefits_context(full_benefits_doc: str, chunk_heading: str, chunk_content: str) -> str:
    """
    Generate contextual information for a benefits chunk within the full benefits document
    """
    response = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=150,  # Keep context concise
        temperature=0.0,
        messages=[
            {
                "role": "user", 
                "content": [
                    {
                        "type": "text",
                        "text": BENEFITS_DOCUMENT_CONTEXT_PROMPT.format(doc_content=full_benefits_doc),
                        "cache_control": {"type": "ephemeral"}  # Cache the full benefits doc
                    },
                    {
                        "type": "text",
                        "text": CHUNK_CONTEXT_PROMPT.format(
                            chunk_heading=chunk_heading,
                            chunk_content=chunk_content
                        ),
                    }
                ]
            }
        ],
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}
    )
    return response.content[0].text.strip()

def add_contextual_information(input_file, output_file, original_markdown_file):
    """
    Process all benefits chunks to add contextual information
    """
    # Load the chunked JSON data
    with open(input_file, 'r') as f:
        chunks = json.load(f)
    
    # Load the original full markdown document for context
    with open(original_markdown_file, 'r') as f:
        full_benefits_doc = f.read()
    
    print(f"Processing {len(chunks)} benefits chunks...")
    
    enhanced_chunks = []
    
    for chunk in tqdm(chunks, desc="Adding contextual information"):
        try:
            # Generate situational context
            situational_context = situate_benefits_context(
                full_benefits_doc=full_benefits_doc,
                chunk_heading=chunk['chunk_heading'],
                chunk_content=chunk['text']
            )
            
            # Create enhanced chunk
            enhanced_chunk = {
                "chunk_link": chunk["chunk_link"],
                "chunk_heading": chunk["chunk_heading"],
                "text": chunk["text"],
                "situational_context": situational_context
            }
            
            enhanced_chunks.append(enhanced_chunk)
            
        except Exception as e:
            print(f"Error processing chunk '{chunk['chunk_heading']}': {e}")
            # Add chunk without context if there's an error
            enhanced_chunks.append(chunk)
    
    # Save the enhanced chunks
    with open(output_file, 'w') as f:
        json.dump(enhanced_chunks, f, indent=2)
    
    print(f"Contextual information added! Enhanced chunks saved to {output_file}")
    return enhanced_chunks

def preview_contextual_chunks(chunks, num_chunks=2):
    """Preview the contextual information added to chunks"""
    print(f"\nPreview of contextual information for first {num_chunks} chunks:")
    print("=" * 60)
    
    for i, chunk in enumerate(chunks[:num_chunks]):
        print(f"\nChunk {i+1}: {chunk['chunk_heading']}")
        print(f"Original text length: {len(chunk['text'])} characters")
        if 'situational_context' in chunk:
            print(f"Situational Context: {chunk['situational_context']}")
        print("-" * 40)

# Example usage
if __name__ == "__main__":
    # File paths - update these to match your setup
    input_json = "data/benefits_wellbeing.json"  # Your chunked JSON from the markdown converter
    original_markdown = "data/Benefits & Wellbeing.md"  # Your original markdown file
    output_json = "data/benefits_wellbeing_with_context.json"  # Output with contextual info
    
    try:
        # Process the chunks to add contextual information
        enhanced_chunks = add_contextual_information(
            input_file=input_json,
            output_file=output_json,
            original_markdown_file=original_markdown
        )
        
        # Preview the results
        preview_contextual_chunks(enhanced_chunks)
        
        print(f"\n✅ Complete! Your enhanced chunks are ready for RAG at: {output_json}")
        
    except FileNotFoundError as e:
        print(f"❌ Error: Could not find file - {e}")
        print("Please make sure your file paths are correct.")
    except Exception as e:
        print(f"❌ Error during processing: {e}")

Processing 5 benefits chunks...


Adding contextual information: 100%|██████████| 5/5 [00:07<00:00,  1.56s/it]

Contextual information added! Enhanced chunks saved to data/benefits_wellbeing_with_context.json

Preview of contextual information for first 2 chunks:

Chunk 1: Health Benefits
Original text length: 2176 characters
Situational Context: The Health Benefits section outlines the various health insurance plans, including medical, dental, and vision, that Uniswap offers to its employees as part of their comprehensive benefits package. This section provides details on the specific providers, plan options, and resources available to help employees navigate and access these health-related benefits.
----------------------------------------

Chunk 2: Leaves
Original text length: 15766 characters
Situational Context: This section on Leaves, including Parental Leave and Unpaid Leaves of Absence, is part of Uniswap's comprehensive Benefits & Wellbeing documentation for employees. It outlines the company's policies and processes around various types of leaves, providing details on eligibility, proc




In [None]:
def create_contextual_dataset():
    """
    Transform benefits_wellbeing_with_context data to match ContextualVectorDB expected format
    """
    
    # Load the original chunked data
    with open('data/benefits_wellbeing_with_context.json', 'r') as f:
        chunks_data = json.load(f)
    
    # Load the full markdown document content
    with open('data/Benefits & Wellbeing.md', 'r') as f:
        full_document_content = f.read()
    
    print(f"Loaded {len(chunks_data)} chunks from benefits_wellbeing.json")
    print(f"Full document length: {len(full_document_content)} characters")
    
    # Create the properly formatted dataset
    contextual_dataset = [{
        "doc_id": "benefits_wellbeing",
        "original_uuid": "benefits_wellbeing_uuid", 
        "content": full_document_content,  # This is what ContextualVectorDB needs!
        "chunks": []
    }]
    
    # Transform each chunk to the expected format
    for i, chunk in enumerate(chunks_data):
        formatted_chunk = {
            "chunk_id": f"chunk_{i}",
            "original_index": i,
            "content": chunk["text"],  # Change from 'text' to 'content'
            "heading": chunk["chunk_heading"],
            "link": chunk["chunk_link"]
        }
        contextual_dataset[0]["chunks"].append(formatted_chunk)
    
    # Save the transformed dataset
    output_file = 'data/benefits_wellbeing_contextualDBformat.json'
    with open(output_file, 'w') as f:
        json.dump(contextual_dataset, f, indent=2)
    
    print(f"✅ Created contextual dataset with {len(contextual_dataset[0]['chunks'])} chunks")
    print(f"✅ Saved to: {output_file}")
    
    return contextual_dataset

# Create the properly formatted dataset
contextual_dataset = create_contextual_dataset()


Loaded 5 chunks from benefits_wellbeing.json
Full document length: 41272 characters
✅ Created contextual dataset with 5 chunks
✅ Saved to: data/benefits_wellbeing_contextual_format.json


#### New DB, ContextualDB

In [12]:
import os
import pickle
import json
import numpy as np
from openai import OpenAI
from typing import List, Dict, Any
from tqdm import tqdm
import anthropic
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

class ContextualVectorDB:
    def __init__(self, name: str, anthropic_api_key=None, openai_api_key=None):
        if anthropic_api_key is None:
            anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
        if openai_api_key is None:
            openai_api_key = os.getenv("OPENAI_API_KEY")
        
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)
        self.openai_client = OpenAI(api_key=openai_api_key)
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/contextual_vector_db.pkl"

        self.token_counts = {
            'input': 0,
            'output': 0,
            'cache_read': 0,
            'cache_creation': 0
        }
        self.token_lock = threading.Lock()

    def situate_context(self, doc: str, chunk: str) -> tuple[str, Any]:
        DOCUMENT_CONTEXT_PROMPT = """
        <document>
        {doc_content}
        </document>
        """

        CHUNK_CONTEXT_PROMPT = """
        Here is the chunk we want to situate within the whole document
        <chunk>
        {chunk_content}
        </chunk>

        Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
        Answer only with the succinct context and nothing else.
        """

        response = self.anthropic_client.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1000,
            temperature=0.0,
            messages=[
                {
                    "role": "user", 
                    "content": [
                        {
                            "type": "text",
                            "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                            "cache_control": {"type": "ephemeral"} #we will make use of prompt caching for the full documents
                        },
                        {
                            "type": "text",
                            "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                        },
                    ]
                },
            ],
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}
        )
        # Handle the response content properly
        try:
            response_text = response.content[0].text
        except AttributeError:
            response_text = str(response.content[0])
        return response_text, response.usage

    def load_data(self, dataset: List[Dict[str, Any]], parallel_threads: int = 1):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)

        def process_chunk(doc, chunk):
            #for each chunk, produce the context
            contextualized_text, usage = self.situate_context(doc['content'], chunk['content'])
            with self.token_lock:
                self.token_counts['input'] += usage.input_tokens
                self.token_counts['output'] += usage.output_tokens
                self.token_counts['cache_read'] += usage.cache_read_input_tokens
                self.token_counts['cache_creation'] += usage.cache_creation_input_tokens
            
            return {
                #append the context to the original text chunk
                'text_to_embed': f"{chunk['content']}\n\n{contextualized_text}",
                'metadata': {
                    'doc_id': doc['doc_id'],
                    'original_uuid': doc['original_uuid'],
                    'chunk_id': chunk['chunk_id'],
                    'original_index': chunk['original_index'],
                    'original_content': chunk['content'],
                    'contextualized_content': contextualized_text
                }
            }

        print(f"Processing {total_chunks} chunks with {parallel_threads} threads")
        with ThreadPoolExecutor(max_workers=parallel_threads) as executor:
            futures = []
            for doc in dataset:
                for chunk in doc['chunks']:
                    futures.append(executor.submit(process_chunk, doc, chunk))
            
            for future in tqdm(as_completed(futures), total=total_chunks, desc="Processing chunks"):
                result = future.result()
                texts_to_embed.append(result['text_to_embed'])
                metadata.append(result['metadata'])

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()

        #logging token usage
        print(f"Contextual Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")
        print(f"Total input tokens without caching: {self.token_counts['input']}")
        print(f"Total output tokens: {self.token_counts['output']}")
        print(f"Total input tokens written to cache: {self.token_counts['cache_creation']}")
        print(f"Total input tokens read from cache: {self.token_counts['cache_read']}")
        
        total_tokens = self.token_counts['input'] + self.token_counts['cache_read'] + self.token_counts['cache_creation']
        savings_percentage = (self.token_counts['cache_read'] / total_tokens) * 100 if total_tokens > 0 else 0
        print(f"Total input token savings from prompt caching: {savings_percentage:.2f}% of all input tokens used were read from cache.")
        print("Tokens read from cache come at a 90 percent discount!")

    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 100
        result = []
        with tqdm(total=len(texts), desc="Embedding chunks") as pbar:
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                response = self.openai_client.embeddings.create(
                    input = batch,
                    model = "text-embedding-3-small"
                )
                batch_embeddings = [item.embedding for item in response.data]
                result.extend(batch_embeddings)
                pbar.update(len(batch))

        self.embeddings = result
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        # Fail fast before spending an embedding API call on an empty database
        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            response = self.openai_client.embeddings.create(
                input = [query],
                model = "text-embedding-3-small"
            )
            query_embedding = response.data[0].embedding
            self.query_cache[query] = query_embedding

        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]
        
        top_results = []
        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)
        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

print("✅ ContextualVectorDB class ready!")


✅ ContextualVectorDB class ready!
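
To make the embedding input concrete: `load_data` embeds each chunk together with its generated context, while the original chunk is kept in the metadata for answer generation. Below is a minimal sketch of that pairing. The helper name `build_embedding_record` and the sample strings are made up for illustration and are not part of the class above.

```python
from typing import Any, Dict


def build_embedding_record(chunk_text: str, context_text: str) -> Dict[str, Any]:
    """Pair a chunk with its generated context (hypothetical helper)."""
    return {
        # Text sent to the embedding model: chunk, blank line, then context.
        "text_to_embed": f"{chunk_text}\n\n{context_text}",
        # The untouched chunk is stored separately for the answer prompt.
        "original_content": chunk_text,
        "contextualized_content": context_text,
    }


record = build_embedding_record(
    "Members get a monthly discount.",               # made-up chunk
    "From the gym section of a benefits document.",  # made-up context
)
print(record["text_to_embed"])
```

Keeping `original_content` separate means the answer prompt can quote the source text without the generated context mixed in.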


In [49]:
# Now let's test the ContextualVectorDB with our properly formatted data
print("Loading the contextual dataset...")

# Initialize the ContextualVectorDB
contextual_db = ContextualVectorDB("benefits_wellbeing_contextual")

# Load and process the data with contextual embeddings
# Note: This will use Claude to generate context for each chunk and then create enhanced embeddings
print("Processing with ContextualVectorDB (this may take a few minutes)...")
contextual_db.load_data(contextual_dataset, parallel_threads=1)

print("✅ ContextualVectorDB ready for enhanced retrieval!")

# Validate the database loaded correctly
print(f"Database contains {len(contextual_db.metadata)} contextualized chunks")
if contextual_db.metadata:
    print("✅ Contextual embeddings have been created and cached successfully!")
else:
    print("❌ No data found in contextual database")

Loading the contextual dataset...
Processing with ContextualVectorDB (this may take a few minutes)...
Loading vector database from disk.
✅ ContextualVectorDB ready for enhanced retrieval!
Database contains 5 contextualized chunks
✅ Contextual embeddings have been created and cached successfully!
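
Because the database was loaded from disk here, the token statistics were not printed. The arithmetic behind them is simple enough to sketch on its own. `cache_read_share` mirrors the percentage that `load_data` reports. `relative_input_cost` is a hypothetical extension: `read_multiplier=0.1` follows the 90 percent discount mentioned in the class, while `write_multiplier=1.25` is an assumption about cache-write pricing that you should check against current pricing. The token counts are made up.

```python
def cache_read_share(input_tokens: int, cache_read: int, cache_creation: int) -> float:
    """Percentage of all input tokens that were read from the cache."""
    total = input_tokens + cache_read + cache_creation
    return (cache_read / total) * 100 if total > 0 else 0.0


def relative_input_cost(
    input_tokens: int,
    cache_read: int,
    cache_creation: int,
    read_multiplier: float = 0.1,    # 90 percent discount on cache reads
    write_multiplier: float = 1.25,  # ASSUMPTION: cache writes cost more than base input
) -> float:
    """Input cost as a fraction of sending every token uncached."""
    total = input_tokens + cache_read + cache_creation
    if total == 0:
        return 1.0
    billed = (
        input_tokens
        + cache_read * read_multiplier
        + cache_creation * write_multiplier
    )
    return billed / total


# Made-up token counts for illustration only
share = cache_read_share(1_000, 8_000, 1_000)
cost = relative_input_cost(1_000, 8_000, 1_000)
print(f"{share:.2f}% of input tokens read from cache")
print(f"Relative input cost vs. no caching: {cost:.3f}")
```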


In [50]:
# Contextual RAG Functions

def retrieve_contextual(query, contextual_db, k=3):
    """Retrieve relevant documents and format context from contextual database"""
    results = contextual_db.search(query, k=k)
    context = ""
    for result in results:
        chunk = result['metadata']
        # Use the original content for context (the contextualized content was used for better retrieval)
        context += f"\n{chunk['original_content']}\n"
    return results, context

def answer_query_contextual(query, contextual_db):
    """Answer a query using the Contextual RAG pipeline"""
    documents, context = retrieve_contextual(query, contextual_db)
    prompt = f""" ### SYSTEM ###
    You are **Uniswap Benefits Assistant**.

    You have been provided with relevant company documents to answer employee questions.

    **Workflow for this query:**
    1. **Analyze the user question**: {query}
    2. **Review the provided context** below for relevant information
    3. **Write a natural-language answer** following these rules:
    • Use only facts that appear verbatim in the provided context
    • If the information isn't in the context, reply: "I don't have that information in the provided context."
    • Compose the most **expansive, detailed answer possible** by weaving together **every relevant fact** found in the context—rephrasing, grouping, and elaborating on those facts for clarity and flow
    • You may explain terms, list related details, and provide a logical structure
    • **Never introduce information that is not stated verbatim in the context**
    4. **Do not** provide preamble such as "Here is the answer" or "Based on the documents"
    5. **Always append exactly**: "Double-check with Julian or Megan for any of this information!"

    ### USER QUESTION ###
    {query}

    ### CONTEXT ###
    {context}

    ### RESPONSE ###"""
    
    response = client.messages.create(
        model="claude-sonnet-4-20250514", #output should use a stronger model than context enrichment models
        max_tokens=2500,
        messages=[
            {"role": "user", "content": prompt}
        ],
        temperature=0  # minimizes sampling randomness for more consistent answers
    )
    # Handle the response content properly
    try:
        if hasattr(response.content[0], 'text'):
            return response.content[0].text
        else:
            return str(response.content[0])
    except (AttributeError, IndexError):
        return str(response.content)


In [51]:
# Test the Contextual RAG pipeline on a set of sample queries
print("=== TESTING CONTEXTUAL RAG SYSTEM ===\n")

# Test queries that should benefit from contextual understanding
test_queries = [
    "Do we have any sort of discounts on rental cars?",
    "What gym benefits are available?", 
    "How does parental leave work?",
    "What learning opportunities does the company provide?",
    "Are there any financial wellbeing benefits?"
]

for i, query in enumerate(test_queries, 1):
    print(f"{i}. Query: {query}")
    print("-" * 50)
    
    # Get contextual RAG answer
    try:
        contextual_answer = answer_query_contextual(query, contextual_db)
        print(f"Contextual RAG Answer: {contextual_answer[:200]}...")
        print("✅ Contextual RAG working!\n")
    except Exception as e:
        print(f"❌ Contextual RAG Error: {e}\n")

print("\n=== CONTEXTUAL RAG SYSTEM STATUS ===")
print("Your contextual RAG system is set up to:")
print("• Use Claude to generate situational context for each chunk")
print("• Create enhanced embeddings with contextual information")
print("• Provide more accurate retrieval for complex queries")
print("• Maintain original content for final answers")
print("\n✅ Ready to query the benefits_wellbeing_contextual dataset!")

=== TESTING CONTEXTUAL RAG SYSTEM ===

1. Query: Do we have any sort of discounts on rental cars?
--------------------------------------------------
Contextual RAG Answer: Yes, Uniswap offers discounts on rental cars through partnerships with two major rental car companies:

**Hertz Discounted Car Rentals**: Uniswap has partnered with Hertz to provide savings on rental ...
✅ Contextual RAG working!

2. Query: What gym benefits are available?
--------------------------------------------------
Contextual RAG Answer: Uniswap offers gym benefits through a partnership with Equinox. The company has partnered with Equinox to offer you $165 off your monthly membership fees.

If you're not near a club location, you can ...
✅ Contextual RAG working!

3. Query: How does parental leave work?
--------------------------------------------------
Contextual RAG Answer: Parental leave at Uniswap is designed to support new parents through a comprehensive process managed in partnership with Cocoon. Here's

# COMBINED NAIVE RAG

In [None]:
## Creating a Combined VectorDB with Multiple Documents

# Let's create a combined dataset that includes both benefits_wellbeing and employee_handbook
def create_combined_dataset():
    """
    Create a combined dataset with both benefits_wellbeing and employee_handbook documents
    """
    
    # Load both JSON files
    with open('data/benefits_wellbeing.json', 'r') as f:
        benefits_data = json.load(f)
    
    with open('data/employee_handbook.json', 'r') as f:
        handbook_data = json.load(f)
    
    print(f"Loaded {len(benefits_data)} benefits chunks")
    print(f"Loaded {len(handbook_data)} handbook chunks")
    
    def transform_to_combined_format(raw_data, doc_id, doc_type):
        """Transform data to VectorDB format with document identification"""
        doc = {
            "doc_id": doc_id,
            "original_uuid": f"{doc_id}_uuid",
            "doc_type": doc_type,  # Add document type for easier filtering
            "chunks": []
        }
        
        for i, item in enumerate(raw_data):
            chunk = {
                "chunk_id": f"{doc_id}_chunk_{i}",
                "original_index": i,
                "content": item["text"],
                "heading": item["chunk_heading"],
                "link": item["chunk_link"],
                "doc_type": doc_type,  # Also add to chunk metadata
                "source_doc": doc_id  # Clear source identification
            }
            doc["chunks"].append(chunk)
        
        return doc
    
    # Transform both datasets
    benefits_doc = transform_to_combined_format(benefits_data, "benefits_wellbeing", "benefits")
    handbook_doc = transform_to_combined_format(handbook_data, "employee_handbook", "handbook")
    
    # Combine into a single dataset
    combined_dataset = [benefits_doc, handbook_doc]
    
    # Save the combined dataset
    output_file = 'data/combined_documents.json'
    with open(output_file, 'w') as f:
        json.dump(combined_dataset, f, indent=2)
    
    total_chunks = len(benefits_doc["chunks"]) + len(handbook_doc["chunks"])
    print(f"✅ Created combined dataset with {total_chunks} total chunks")
    print(f"   - Benefits & Wellbeing: {len(benefits_doc['chunks'])} chunks")
    print(f"   - Employee Handbook: {len(handbook_doc['chunks'])} chunks")
    print(f"✅ Saved to: {output_file}")
    
    return combined_dataset

# Create the combined dataset
combined_dataset = create_combined_dataset()


Loaded 5 benefits chunks
Loaded 77 handbook chunks
✅ Created combined dataset with 82 total chunks
   - Benefits & Wellbeing: 5 chunks
   - Employee Handbook: 77 chunks
✅ Saved to: data/combined_documents.json
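
Before embedding a combined dataset, it can help to confirm that chunk IDs do not collide across documents, since both documents share one index. The following is a sketch of such a check on a toy dataset. `summarize_dataset` and the toy document names are hypothetical; only the `doc_id`, `chunks`, and `chunk_id` keys come from the format above.

```python
from collections import Counter
from typing import Any, Dict, List


def summarize_dataset(dataset: List[Dict[str, Any]]) -> Dict[str, int]:
    """Count chunks per doc_id and raise on any duplicate chunk_id."""
    seen = set()
    counts: Counter = Counter()
    for doc in dataset:
        for chunk in doc["chunks"]:
            chunk_id = chunk["chunk_id"]
            if chunk_id in seen:
                raise ValueError(f"Duplicate chunk_id: {chunk_id}")
            seen.add(chunk_id)
            counts[doc["doc_id"]] += 1
    return dict(counts)


# Toy dataset in the same shape as the combined dataset (names are made up)
toy_dataset = [
    {"doc_id": "doc_a", "chunks": [{"chunk_id": "doc_a_chunk_0"},
                                   {"chunk_id": "doc_a_chunk_1"}]},
    {"doc_id": "doc_b", "chunks": [{"chunk_id": "doc_b_chunk_0"}]},
]
print(summarize_dataset(toy_dataset))
```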


In [39]:
# Initialize a combined VectorDB
print("Creating combined VectorDB with both documents...")

# Initialize the VectorDB for combined documents
combined_db = VectorDB("combined_documents")
combined_db.load_data(combined_dataset)

total_chunks = sum(len(doc['chunks']) for doc in combined_dataset)
print(f"✅ Combined VectorDB ready! Total chunks: {total_chunks}")


Creating combined VectorDB with both documents...
Loading vector database from disk.
✅ Combined VectorDB ready! Total chunks: 82


In [40]:
# Enhanced RAG Functions for Combined Database

def retrieve_combined(query, combined_db, k=5):
    """Retrieve relevant documents from combined database with source information"""
    results = combined_db.search(query, k=k)
    context = ""
    sources = []
    
    for result in results:
        chunk = result['metadata']
        
        # Extract source information - handle the actual metadata structure
        doc_id = chunk.get('doc_id', 'unknown_doc')
        chunk_id = chunk.get('chunk_id', 'unknown_chunk')
        
        # Determine source document from doc_id
        if 'benefits' in doc_id.lower():
            source_doc = 'benefits_wellbeing'
        elif 'handbook' in doc_id.lower() or 'employee' in doc_id.lower():
            source_doc = 'employee_handbook'
        else:
            source_doc = doc_id
        
        # Try to extract heading from content (first line if it starts with #)
        content = chunk.get('content', '')
        heading = 'Unknown Section'
        content_lines = content.split('\n')
        for line in content_lines[:3]:  # Check first 3 lines
            if line.strip().startswith('#'):
                heading = line.strip().lstrip('#').strip()
                break
        
        # If no heading found, use chunk_id as fallback
        if heading == 'Unknown Section':
            heading = chunk_id
        
        # Include source information in context
        source_info = f"[Source: {source_doc} - {heading}]"
        context += f"\n{source_info}\n{content}\n"
        sources.append({
            'source_doc': source_doc,
            'heading': heading,
            'similarity': result['similarity']
        })
    
    return results, context, sources

def answer_query_combined(query, combined_db):
    """Answer a query using the Combined RAG pipeline"""
    documents, context, sources = retrieve_combined(query, combined_db)
    
    # Create source summary
    source_summary = "Sources consulted:\n"
    for source in sources:
        source_summary += f"• {source['source_doc']}: {source['heading']} (similarity: {source['similarity']:.3f})\n"
    
    prompt = f""" ### SYSTEM ###
    You are **Uniswap Employee Assistant**.

    You have been provided with relevant company documents from multiple sources to answer employee questions.

    **Workflow for this query:**
    1. **Analyze the user question**: {query}
    2. **Review the provided context** below for relevant information
    3. **Write a natural-language answer** following these rules:
    • Use only facts that appear verbatim in the provided context
    • If the information isn't in the context, reply: "I don't have that information in the provided context."
    • Compose the most **expansive, detailed answer possible** by weaving together **every relevant fact** found in the context—rephrasing, grouping, and elaborating on those facts for clarity and flow
    • You may explain terms, list related details, and provide a logical structure
    • **Never introduce information that is not stated verbatim in the context**
    • When referencing information, mention which document type it comes from (e.g., "According to the benefits documentation..." or "The employee handbook states...")
    4. **Do not** provide preamble such as "Here is the answer" or "Based on the documents"
    5. **Always append exactly**: "Double-check with Julian or Megan for any of this information!"

    ### USER QUESTION ###
    {query}

    ### CONTEXT ###
    {context}

    ### RESPONSE ###"""
    
    response = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=500,
        messages=[
            {"role": "user", "content": prompt}
        ],
        temperature=0
    )
    
    # Handle the response content properly
    try:
        answer = response.content[0].text
    except AttributeError:
        answer = str(response.content[0])
    
    return answer, source_summary


In [41]:
# First, let's check what the actual metadata structure looks like
print("Checking metadata structure in combined_db...")
test_results = combined_db.search("test", k=1)
if test_results:
    print("Sample metadata:", test_results[0]['metadata'])
    print("Metadata keys:", list(test_results[0]['metadata'].keys()))

Checking metadata structure in combined_db...
Sample metadata: {'doc_id': 'employee_handbook', 'original_uuid': 'employee_handbook_uuid', 'chunk_id': 'employee_handbook_chunk_58', 'original_index': 58, 'content': '##'}
Metadata keys: ['doc_id', 'original_uuid', 'chunk_id', 'original_index', 'content']
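
The stored metadata has no `heading` key, so any heading has to be recovered from the chunk text, and a chunk such as `'##'` has no usable heading at all. As a sketch of one way to handle this, the hypothetical helper below strips only the leading `#` markers, removes a trailing `{anchor}` suffix, and falls back to a default when nothing is left. The regular expression is an assumption about how anchors appear in the headings.

```python
import re


def clean_heading(content: str, fallback: str = "Unknown Section") -> str:
    """Extract a readable heading from the first three lines of a chunk."""
    for line in content.split("\n")[:3]:
        stripped = line.strip()
        if stripped.startswith("#"):
            # Remove only the leading markdown markers, not every '#'
            heading = stripped.lstrip("#").strip()
            # Drop a trailing {anchor} suffix such as {employee-benefits}
            heading = re.sub(r"\s*\{[^{}]*\}\s*$", "", heading)
            return heading or fallback
    return fallback


print(clean_heading("## Employee Benefits {employee-benefits}\n\nBody text"))
print(clean_heading("##"))
```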


In [42]:
# Test the combined RAG system with a query that might span both documents
test_query = "Tell me about the company's gym benefits"
answer, sources = answer_query_combined(test_query, combined_db)

print(f"Query: {test_query}")
print(f"\n{sources}")
print(f"Combined Answer: {answer}")

Query: Tell me about the company's gym benefits

Sources consulted:
• employee_handbook: Employee Benefits {employee-benefits} (similarity: 0.470)
• benefits_wellbeing: Health Benefits (similarity: 0.424)
• benefits_wellbeing: 401k & Financial Benefits (similarity: 0.393)
• employee_handbook: Phones (similarity: 0.356)
• employee_handbook: Reasonable Accommodations:  Disability, Nursing Mothers and Religious  {reasonable-accommodations:-disability,-nursing-mothers-and-religious} (similarity: 0.339)

Combined Answer: Based on the provided company documents, I don't have information about gym benefits in the provided context. 

The employee handbook and benefits documentation detail various wellness and health benefits including medical coverage through Anthem Blue Cross and Kaiser, dental plans through Guardian, vision insurance through VSP, access to One Medical, Maven for fertility and family planning, and an Employee Assistance Program for counseling services. The company also provid

In [43]:
# Let's first debug what keys are available in the metadata
test_query_debug = "What is our at-will employment policy?"

# Search combined_db, which holds both documents
results = combined_db.search(test_query_debug, k=3)

print("Debug: Available metadata keys in search results:")
print("=" * 50)
if results:
    first_result = results[0]
    print(f"Available keys in result: {list(first_result.keys())}")
    print(f"Available keys in metadata: {list(first_result['metadata'].keys())}")
    print(f"Sample metadata: {first_result['metadata']}")
else:
    print("No results found!")

# Let's test some specific queries that demonstrate the power of combined search

test_queries = [
    "What is our at-will employment policy?",  # Should find employee handbook
    "What gym benefits do we have?",           # Should find benefits document  
    "What are our company principles?",        # Should find employee handbook
    "How does parental leave work?",           # Should find benefits document
    "What does it mean to be 'people first'?" # Should find employee handbook values
]

print("\n\nTesting Combined VectorDB with various queries:")
print("=" * 60)

for i, query in enumerate(test_queries, 1):
    print(f"\n{i}. Query: {query}")
    print("-" * 40)
    
    # Use the working retrieve_combined function that properly handles metadata
    results, context, sources = retrieve_combined(query, combined_db, k=3)
    
    print("Top sources found:")
    for j, source in enumerate(sources[:2]):  # Show top 2 results
        print(f"   {j+1}. {source['source_doc']} - {source['heading']} (similarity: {source['similarity']:.3f})")
    
    print()


Debug: Available metadata keys in search results:
Available keys in result: ['metadata', 'similarity']
Available keys in metadata: ['doc_id', 'original_uuid', 'chunk_id', 'original_index', 'content']
Sample metadata: {'doc_id': 'employee_handbook', 'original_uuid': 'employee_handbook_uuid', 'chunk_id': 'employee_handbook_chunk_0', 'original_index': 0, 'content': '## For our team members working outside of New York State:\n\nThis Handbook is intended for use by team members in all states where Uniswap Labs has team members, but it also provides certain information applicable only to team members working in New York State. Team members working outside of New York State may be covered by different state-specific policies and benefits, and team members working in certain states may be provided with a state handbook supplement summarizing specific policies or benefits applicable in those states.  \n\nNeither this handbook nor any other communication whether oral or written, is intended in a

In [None]:
# Let's test some specific queries that demonstrate the power of combined search

test_queries = [
    "What is our at-will employment policy?",  # Should find employee handbook
    "What gym benefits do we have?",           # Should find benefits document  
    "What are our company principles?",        # Should find employee handbook
    "How does parental leave work?",           # Should find benefits document
    "What does it mean to be 'people first'?" # Should find employee handbook values
]

print("Testing Combined VectorDB with various queries:")
print("=" * 60)

for i, query in enumerate(test_queries, 1):
    print(f"\n{i}. Query: {query}")
    print(f"Answer: {answer_query_base(query, db)}")
    print("-" * 40)
    
    # Use the working retrieve_combined function that properly handles metadata
    results, context, sources = retrieve_combined(query, combined_db, k=5)
    
    print("Top sources found:")
    for j, source in enumerate(sources[:3]):  # Show top 3 results
        print(f"   {j+1}. {source['source_doc']} - {source['heading']} (similarity: {source['similarity']:.3f})")
    print("-" * 40)
    print()


Testing Combined VectorDB with various queries:

1. Query: What is our at-will employment policy?
Answer: Employment with Uniswap Labs is "at-will". This means that employment may be terminated at any time by the employee or Uniswap Labs with or without cause, and with or without advance notice. The at-will relationship includes the right to hire, transfer, promote, reclassify, layoff, discipline, terminate or change any term or condition of employment (except for at-will employment itself) at any time with or without cause or advance notice.

The policy of "at will" employment cannot be changed except by a formal written agreement specifically entered into for this purpose signed by a member of the executive team. Neither the employee handbook, nor any policy or practice of Uniswap Labs, is intended to imply continued or guaranteed employment, or otherwise limit in any way the policy of at-will employment. In describing Uniswap Labs' policies and procedures, the handbook does not obli

In [45]:
# FIX for the KeyError issue - Helper function to safely extract metadata
def safe_extract_source_info(chunk_metadata):
    """Safely extract source info from metadata, handling different structures"""
    # Try to get source_doc and heading directly (new format)
    source_doc = chunk_metadata.get('source_doc')
    heading = chunk_metadata.get('heading')
    
    if source_doc and heading:
        return source_doc, heading
    
    # Fallback: extract from other fields
    doc_id = chunk_metadata.get('doc_id', 'unknown_doc')
    chunk_id = chunk_metadata.get('chunk_id', 'unknown_chunk')
    
    # Determine source document from doc_id
    if 'benefits' in doc_id.lower():
        source_doc = 'benefits_wellbeing'
    elif 'handbook' in doc_id.lower() or 'employee' in doc_id.lower():
        source_doc = 'employee_handbook'
    else:
        source_doc = doc_id
    
    # Try to extract heading from content
    if not heading:
        content = chunk_metadata.get('content', '')
        heading = 'Unknown Section'
        content_lines = content.split('\n')
        for line in content_lines[:3]:
            if line.strip().startswith('#'):
                heading = line.strip().lstrip('#').strip()
                break
        
        if heading == 'Unknown Section':
            heading = chunk_id
    
    return source_doc, heading

# FIXED version of the test queries
test_queries = [
    "What is our at-will employment policy?",  # Should find employee handbook
    "What gym benefits do we have?",           # Should find benefits document  
    "What are our company principles?",        # Should find employee handbook
    "How does parental leave work?",           # Should find benefits document
    "What does it mean to be 'people first'?" # Should find employee handbook values
]

print("Testing Combined VectorDB with various queries (FIXED VERSION):")
print("=" * 70)

for i, query in enumerate(test_queries, 1):
    print(f"\n{i}. Query: {query}")
    print("-" * 50)
    
    # Get search results to see which documents are retrieved
    try:
        results = combined_db.search(query, k=3)
        
        print("Top sources found:")
        for j, result in enumerate(results[:2]):  # Show top 2 results
            chunk = result['metadata']
            source_doc, heading = safe_extract_source_info(chunk)
            print(f"   {j+1}. {source_doc} - {heading} (similarity: {result['similarity']:.3f})")
    except Exception as e:
        print(f"   Error: {e}")
    
    print()

print("✅ Combined search tests completed!")
print("\nThe combined VectorDB successfully searches across both:")
print("• Benefits & Wellbeing document (5 chunks)")  
print("• Employee Handbook document (77 chunks)")
print("\nThis approach allows you to:")
print("• Search all company documents in one query")
print("• Get source attribution for each result") 
print("• Combine information from multiple documents in answers")


Testing Combined VectorDB with various queries (FIXED VERSION):

1. Query: What is our at-will employment policy?
--------------------------------------------------
Top sources found:
   1. employee_handbook - For our team members working outside of New York State: (similarity: 0.446)
   2. employee_handbook - General Anti-Retaliation Policy (similarity: 0.411)


2. Query: What gym benefits do we have?
--------------------------------------------------
Top sources found:
   1. employee_handbook - Employee Benefits {employee-benefits} (similarity: 0.415)
   2. benefits_wellbeing - Health Benefits (similarity: 0.411)


3. Query: What are our company principles?
--------------------------------------------------
Top sources found:
   1. employee_handbook - **Uniswap Principles | Uni-code** (similarity: 0.510)
   2. employee_handbook - 4.5       	How the Company Complies with Anti-Corruption Laws (similarity: 0.455)


4. Query: How does parental leave work?
--------------------------------

## Creating a Combined Contextual VectorDB

For even better performance, you can also create a **Contextual** VectorDB that combines multiple documents. This approach gives you the benefits of:

1. **Multiple document sources** - Search across all your company documents
2. **Contextual embeddings** - Better retrieval quality through situational context
3. **Source attribution** - Know which document provided each piece of information

### Key Advantages of Combined Multi-Document RAG:

- **Comprehensive answers**: Pull information from multiple sources to give complete responses
- **Cross-document insights**: Find connections between different types of company information
- **Efficient search**: One search across all documents instead of separate searches
- **Broader coverage**: Questions about company policies, benefits, and culture are all answered from one index

This is especially useful for employee-facing applications, where a single question may span HR policies, benefits, company culture, and procedures.
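
To illustrate how one search over a shared index can still report where each hit came from, here is a minimal sketch with hand-written 2-D vectors in place of real embeddings. `search_with_sources` is a hypothetical stand-in for the `search` method, written in plain Python so it runs without any API calls. The vectors and the query are made up.

```python
from typing import Any, Dict, List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def search_with_sources(
    query_vec: Sequence[float],
    embeddings: List[Sequence[float]],
    metadata: List[Dict[str, Any]],
    k: int = 2,
) -> List[Dict[str, Any]]:
    """Rank all chunks by similarity and attach each chunk's source document."""
    scored = sorted(
        ((dot(vec, query_vec), idx) for idx, vec in enumerate(embeddings)),
        reverse=True,
    )
    return [
        {
            "source_doc": metadata[idx]["doc_id"],
            "chunk_id": metadata[idx]["chunk_id"],
            "similarity": score,
        }
        for score, idx in scored[:k]
    ]


# Toy index: one benefits chunk and two handbook chunks (vectors are made up)
toy_embeddings = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
toy_metadata = [
    {"doc_id": "benefits_wellbeing", "chunk_id": "benefits_wellbeing_chunk_0"},
    {"doc_id": "employee_handbook", "chunk_id": "employee_handbook_chunk_0"},
    {"doc_id": "employee_handbook", "chunk_id": "employee_handbook_chunk_1"},
]

results = search_with_sources([1.0, 0.0], toy_embeddings, toy_metadata, k=2)
for hit in results:
    print(f"{hit['source_doc']} - {hit['chunk_id']} ({hit['similarity']:.3f})")
```

Because every chunk carries its `doc_id`, the ranking can mix documents freely and the caller can still group or filter hits by source afterwards.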


In [46]:
# Optional: Create a Contextual Combined VectorDB for even better performance
# This combines contextual embeddings with multiple document sources

def create_combined_contextual_dataset():
    """
    Create a combined dataset for contextual embeddings with both documents
    """
    
    # Load both JSON files
    with open('data/benefits_wellbeing.json', 'r') as f:
        benefits_data = json.load(f)
    
    with open('data/employee_handbook.json', 'r') as f:
        handbook_data = json.load(f)
    
    # Load the full documents for context
    with open('data/Benefits & Wellbeing.md', 'r') as f:
        full_benefits_doc = f.read()
    
    with open('data/Employee Handbook.md', 'r') as f:
        full_handbook_doc = f.read()
    
    def transform_to_contextual_format(raw_data, doc_id, doc_type, full_content):
        """Transform data to ContextualVectorDB format with document identification"""
        doc = {
            "doc_id": doc_id,
            "original_uuid": f"{doc_id}_uuid",
            "doc_type": doc_type,
            "content": full_content,  # Full document content for contextual embeddings
            "chunks": []
        }
        
        for i, item in enumerate(raw_data):
            chunk = {
                "chunk_id": f"{doc_id}_chunk_{i}",
                "original_index": i,
                "content": item["text"],
                "heading": item["chunk_heading"],
                "link": item["chunk_link"],
                "doc_type": doc_type,
                "source_doc": doc_id
            }
            doc["chunks"].append(chunk)
        
        return doc
    
    # Transform both datasets
    benefits_doc = transform_to_contextual_format(
        benefits_data, "benefits_wellbeing", "benefits", full_benefits_doc
    )
    handbook_doc = transform_to_contextual_format(
        handbook_data, "employee_handbook", "handbook", full_handbook_doc
    )
    
    # Combine into a single dataset
    combined_contextual_dataset = [benefits_doc, handbook_doc]
    
    # Save the combined dataset
    output_file = 'data/combined_contextual_documents.json'
    with open(output_file, 'w') as f:
        json.dump(combined_contextual_dataset, f, indent=2)
    
    total_chunks = len(benefits_doc["chunks"]) + len(handbook_doc["chunks"])
    print(f"✅ Created combined contextual dataset with {total_chunks} total chunks")
    print(f"   - Benefits & Wellbeing: {len(benefits_doc['chunks'])} chunks")
    print(f"   - Employee Handbook: {len(handbook_doc['chunks'])} chunks")
    print(f"✅ Saved to: {output_file}")
    
    return combined_contextual_dataset

# Uncomment to create the contextual combined dataset
# combined_contextual_dataset = create_combined_contextual_dataset()

# To initialize a ContextualVectorDB with combined documents:
# combined_contextual_db = ContextualVectorDB("combined_contextual_documents")
# combined_contextual_db.load_data(combined_contextual_dataset, parallel_threads=1)

print("✅ Combined document RAG setup complete!")
print("\nYou now have the tools to:")
print("1. Combine multiple documents into a single VectorDB")
print("2. Search across all documents simultaneously") 
print("3. Get source attribution for each retrieved chunk")
print("4. Optionally use contextual embeddings for better retrieval")
print("\nThis approach scales to any number of documents!")


✅ Combined document RAG setup complete!

You now have the tools to:
1. Combine multiple documents into a single VectorDB
2. Search across all documents simultaneously
3. Get source attribution for each retrieved chunk
4. Optionally use contextual embeddings for better retrieval

This approach scales to any number of documents!
