In [11]:
import os
import pickle
import json
import numpy as np
import voyageai
from typing import List, Dict, Any
from tqdm import tqdm
from inference_adapter import InferenceAdapter

class ContextualVectorDB:
    def __init__(self, name: str):
        self.openai_client = InferenceAdapter()
        self.voyage_client = voyageai.Client()
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/contextual_vector_db.pkl"

        self.token_counts = {
            'input': 0,
            'output': 0,
            'cache_read': 0,
            'non_cached': 0
        }

    async def situate_context(self, doc: str, chunk: str) -> tuple[str, Any]:
        from pydantic import BaseModel, Field

        class ContextResponse(BaseModel):
            context: str = Field(..., description="The succinct context for the chunk")

        DOCUMENT_CONTEXT_PROMPT = """
        <document>
        {doc_content}
        </document>
        """

        CHUNK_CONTEXT_PROMPT = """
        Here is the chunk we want to situate within the whole document
        <chunk>
        {chunk_content}
        </chunk>

        Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
        Answer only with the succinct context and nothing else.
        """

        return await self.openai_client.predict_with_parse_async(
            {
                "model": "gpt-4o-2024-08-06",
                "temperature": 0,
            },
            response_format=ContextResponse,
            messages=[
                {
                    "role": "user",
                    "content": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc)
                },
                {
                    "role": "user",
                    "content": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk)
                },
            ],
        )

    async def load_data(self, dataset: List[Dict[str, Any]]):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)

        async def process_chunk(doc, chunk):
            #for each chunk, produce the context
            contextualized_text, usage = await self.situate_context(doc['content'], chunk['content'])
            print(usage)
            self.token_counts['input'] += usage.prompt_tokens
            self.token_counts['output'] += usage.completion_tokens
            self.token_counts['cache_read'] += usage.prompt_tokens_details.get('cached_tokens', 0)
            self.token_counts['non_cached'] += usage.prompt_tokens - usage.prompt_tokens_details.get('cached_tokens', 0)

            return {
                #append the context to the original text chunk
                'text_to_embed': f"{chunk['content']}\n\n{contextualized_text.context}",
                'metadata': {
                    'doc_id': doc['doc_id'],
                    'original_uuid': doc['original_uuid'],
                    'chunk_id': chunk['chunk_id'],
                    'original_index': chunk['original_index'],
                    'original_content': chunk['content'],
                    'contextualized_content': contextualized_text.context
                }
            }

        print(f"Processing {total_chunks} chunks sequentially")
        for doc in tqdm(dataset[:], desc="Processing documents"):
            for chunk in doc['chunks']:
                result = await process_chunk(doc, chunk)
                print(result)
                texts_to_embed.append(result['text_to_embed'])
                metadata.append(result['metadata'])

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()

        #logging token usage
        print(f"Contextual Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")
        print(f"Total input tokens without caching: {self.token_counts['non_cached']}")
        print(f"Total output tokens: {self.token_counts['output']}")
        print(f"Total input tokens read from cache: {self.token_counts['cache_read']}")
        
        total_tokens = self.token_counts['input'] + self.token_counts['output']
        savings_percentage = (self.token_counts['cache_read'] / total_tokens) * 100 if total_tokens > 0 else 0
        print(f"Total input token savings from prompt caching: {savings_percentage:.2f}% of all input tokens used were read from cache.")

    #we use voyage AI here for embeddings. Read more here: https://docs.voyageai.com/docs/embeddings
    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        result = [
            self.voyage_client.embed(
                texts[i : i + batch_size],
                model="voyage-2"
            ).embeddings
            for i in range(0, len(texts), batch_size)
        ]
        self.embeddings = [embedding for batch in result for embedding in batch]
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.voyage_client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]
        
        top_results = []
        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)
        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

In [12]:
# Load the transformed dataset
with open('data/codebase_chunks.json', 'r') as f:
    transformed_dataset = json.load(f)

# Initialize the ContextualVectorDB
contextual_db = ContextualVectorDB("my_contextual_db")

# Load and process the data
#note: consider increasing the number of parallel threads to run this faster, or reducing the number of parallel threads if concerned about hitting your API rate limit
await contextual_db.load_data(transformed_dataset)

Processing 20 chunks sequentially


Processing documents:   0%|          | 0/2 [00:00<?, ?it/s]

CompletionUsage(completion_tokens=46, prompt_tokens=2567, total_tokens=2613, prompt_tokens_details={'cached_tokens': 2176}, completion_tokens_details={'reasoning_tokens': 0})
{'text_to_embed': '//! Executor for differential fuzzing.\n//! It wraps two executors that will be run after each other with the same input.\n//! In comparison to the [`crate::executors::CombinedExecutor`] it also runs the secondary executor in `run_target`.\n//!\nuse core::{cell::UnsafeCell, fmt::Debug, ptr};\n\nuse libafl_bolts::{ownedref::OwnedMutPtr, tuples::MatchName};\nuse serde::{Deserialize, Serialize};\n\nuse crate::{\n    executors::{Executor, ExitKind, HasObservers},\n    inputs::UsesInput,\n    observers::{DifferentialObserversTuple, ObserversTuple, UsesObservers},\n    state::UsesState,\n    Error,\n};\n\n/// A [`DiffExecutor`] wraps a primary executor, forwarding its methods, and a secondary one\n#[derive(Debug)]\npub struct DiffExecutor<A, B, OTA, OTB, DOT> {\n    primary: A,\n    secondary: B,\n   

Processing documents:  50%|█████     | 1/2 [00:21<00:21, 21.96s/it]

CompletionUsage(completion_tokens=43, prompt_tokens=2436, total_tokens=2479, prompt_tokens_details={'cached_tokens': 2304}, completion_tokens_details={'reasoning_tokens': 0})
{'text_to_embed': '    #[inline]\n    fn observers_mut(&mut self) -> &mut ProxyObserversTuple<OTA, OTB, DOT> {\n        unsafe {\n            self.observers\n                .get()\n                .as_mut()\n                .unwrap()\n                .set(self.primary.observers(), self.secondary.observers());\n            self.observers.get().as_mut().unwrap()\n        }\n    }\n}\n\n\nThe chunk is a method implementation within the `DiffExecutor` struct, which provides mutable access to the proxy observers, ensuring they are updated with the current observers from the primary and secondary executors.', 'metadata': {'doc_id': 'doc_1', 'original_uuid': '5e4c01057a10732d34784af2a97bee9d173863f043b9901de8ef7f57bc590145', 'chunk_id': 'doc_1_chunk_12', 'original_index': 12, 'original_content': '    #[inline]\n    fn o

Processing documents: 100%|██████████| 2/2 [00:32<00:00, 16.47s/it]

CompletionUsage(completion_tokens=54, prompt_tokens=1440, total_tokens=1494, prompt_tokens_details={'cached_tokens': 1024}, completion_tokens_details={'reasoning_tokens': 0})
{'text_to_embed': '    // Setup a mutational stage with a basic bytes mutator\n    let mutator = StdScheduledMutator::new(tuple_list!(\n        StringCategoryRandMutator,\n        StringSubcategoryRandMutator,\n        StringSubcategoryRandMutator,\n        StringSubcategoryRandMutator,\n        StringSubcategoryRandMutator\n    ));\n    let mut stages = tuple_list!(\n        StringIdentificationStage::new(),\n        StdMutationalStage::transforming(mutator)\n    );\n\n    fuzzer\n        .fuzz_loop(&mut stages, &mut executor, &mut state, &mut mgr)\n        .expect("Error in the fuzzing loop");\n}\n\n\nThis chunk describes the setup of a mutational stage in a fuzzing process using the libafl library. It involves creating a mutator with various string mutation strategies and integrating it into the fuzzing loop to




Contextual Vector database loaded and saved. Total chunks processed: 20
Total input tokens without caching: 8145
Total output tokens: 1182
Total input tokens read from cache: 35072
Total input token savings from prompt caching: 78.99% of all input tokens used were read from cache.
