In [None]:
# Install every third-party package the notebook imports or relies on (sentence-transformers, pymilvus, tqdm, nltk are used in later cells).
%pip install -q transformers datasets torch sentence-transformers pymilvus tqdm nltk

## Chunk and index data into the DB ##

In [None]:
from datasets import load_dataset
# RAGBench subsets to chunk and index.
datasets = [
    'hagrid',
    'hotpotqa',
    'msmarco',
]

# Embedding model used for indexing and retrieval.
# Smaller/faster alternative: "sentence-transformers/all-MiniLM-L6-v2"
retrieval_model = "BAAI/LLM-Embedder"

# Chunking parameters, counted in tokenizer tokens.
TOKEN_LIMIT = 512             # upper bound on tokens per chunk
SLIDING_WINDOW_OVERLAP = 100  # tokens carried over between consecutive chunks

# Function for chunking with token limit and sliding window
def chunk_with_token_limit(text, token_limit, overlap, tokenize=None, split_sentences=None):
    """Split ``text`` into sentence-aligned chunks of at most ``token_limit`` tokens.

    Sentences are packed greedily into a chunk until the next one would exceed
    ``token_limit``. The following chunk then starts with the longest run of
    trailing sentences from the previous chunk whose token total is at most
    ``overlap`` (sliding-window overlap). That run is trimmed further if needed,
    so the carried sentences plus the incoming sentence still fit in
    ``token_limit``. Sentences are never split, so a single sentence longer than
    ``token_limit`` becomes its own oversized chunk.

    :param text: Text to chunk.
    :param token_limit: Maximum number of tokens per chunk.
    :param overlap: Maximum number of tokens carried over between consecutive chunks.
    :param tokenize: Callable ``str -> list`` of tokens. Defaults to the
        notebook-global ``tokenizer.tokenize``, which must be defined in an earlier cell.
    :param split_sentences: Callable ``str -> list[str]`` of sentences. Defaults to
        the notebook-global ``sent_tokenize`` (presumably NLTK's), which must be in scope.
    :return: List of chunk strings (sentences joined with single spaces). The list is
        empty when ``text`` yields no sentences, and no chunk is ever empty.
    """
    if tokenize is None:
        tokenize = tokenizer.tokenize
    if split_sentences is None:
        split_sentences = sent_tokenize

    chunks = []
    current = []  # (sentence, token_count) pairs making up the chunk being built
    current_tokens = 0

    for sentence in split_sentences(text):
        num_tokens = len(tokenize(sentence))

        # Flush only when the chunk is non-empty; otherwise an oversized first
        # sentence would emit an empty chunk.
        if current and current_tokens + num_tokens > token_limit:
            chunks.append(" ".join(sent for sent, _ in current))

            # Carry over whole trailing sentences, bounded both by `overlap` and by
            # the room left for the incoming sentence. Because the full chunk did not
            # fit, this never carries every sentence, so chunks cannot repeat or grow
            # without bound.
            carry, carry_tokens = [], 0
            for sent, n in reversed(current):
                if carry_tokens + n > overlap or carry_tokens + n + num_tokens > token_limit:
                    break
                carry.append((sent, n))
                carry_tokens += n
            carry.reverse()
            current, current_tokens = carry, carry_tokens

        current.append((sentence, num_tokens))
        current_tokens += num_tokens

    # Emit whatever is left as the final chunk.
    if current:
        chunks.append(" ".join(sent for sent, _ in current))

    return chunks

def _advance_letters(letters):
    """Advance a list of letter code points in place: a..z, then aa, ab, ..., zz, aaa, ..."""
    i = len(letters) - 1
    while i >= 0:
        letters[i] += 1
        if letters[i] <= ord('z'):
            return
        # Wrap this position back to 'a' and carry into the position to its left.
        letters[i] = ord('a')
        if i == 0:
            letters.insert(0, ord('a'))  # grow the identifier by one character
        i -= 1


def process_document_with_identifiers(document, verbose=False):
    """Split each section of a document into identified, token-limited chunks.

    Every sentence of section ``i`` gets an identifier ``f"{i}{letters}"``, where
    ``letters`` runs ``a, b, ..., z, aa, ab, ...``. A sentence starting with
    ``"Title:"`` restarts the sequence at ``a``. It then advances the sequence like
    any other sentence, so a section's title is ``"{i}a"`` and the following
    sentence is ``"{i}b"``; previously both got ``"{i}a"``.

    Each sentence is passed through ``chunk_with_token_limit`` using
    ``TOKEN_LIMIT`` and ``SLIDING_WINDOW_OVERLAP``. All chunks produced from one
    sentence share that sentence's identifier.

    :param document: Iterable of section strings (e.g. ``row["documents"]``).
    :param verbose: If True, print the input document and each section's chunks.
    :return: One list per section, each holding ``[identifier, chunk_text]`` pairs.
    """
    processed_data = []
    if verbose:
        print("document>>>>>>>", document)

    for title_count, section in enumerate(document):
        section_chunks = []
        passage_count = [ord('a')]  # letter code points, most significant first

        for sentence in sent_tokenize(section):
            if sentence.startswith("Title:"):
                # New document detected: the title always takes the first letter.
                passage_count = [ord('a')]

            identifier = f"{title_count}{''.join(map(chr, passage_count))}"
            for chunk in chunk_with_token_limit(sentence, TOKEN_LIMIT, SLIDING_WINDOW_OVERLAP):
                section_chunks.append([identifier, chunk])

            _advance_letters(passage_count)

        if verbose:
            print("section_chunks>>>>>>>", section_chunks)
        processed_data.append(section_chunks)

    return processed_data

## **Check uniqueness of data before insertion** ##

In [None]:
import hashlib

# Function to generate a hash based on content and key metadata
def generate_hash(content, metadata):
    """Fingerprint a record by its text plus its ``item_index`` and ``prefix`` metadata.

    The three values are joined with ``|`` and hashed with MD5. MD5 serves only as
    a de-duplication fingerprint here, not as a security primitive.

    :param content: Document text.
    :param metadata: Mapping that may contain ``item_index`` and ``prefix``
        (a missing key contributes the text ``None``).
    :return: 32-character lowercase hex digest.
    """
    fingerprint = "{}|{}|{}".format(
        content,
        metadata.get('item_index'),
        metadata.get('prefix'),
    )
    digest = hashlib.md5(fingerprint.encode('utf-8'))
    return digest.hexdigest()

# Function to retrieve existing hashes from the database
def get_existing_hashes(collection):
    """Return the ``generate_hash`` digest of every record already in ``collection``.

    ``collection.get(include=["documents", "metadatas"])`` is expected to return
    parallel ``"documents"`` and ``"metadatas"`` lists (presumably a Chroma
    collection; confirm against caller).
    """
    records = collection.get(include=["documents", "metadatas"])
    return {
        generate_hash(text, meta)
        for text, meta in zip(records["documents"], records["metadatas"])
    }

# Function to retrieve existing hashes from the database
def get_existing_hashes_milvus(all_records):
    """Compute de-duplication hashes for records already stored in Milvus.

    :param all_records: Result of ``MilvusClient.query`` (e.g. via
        ``VectorDataStore.get_all_records``). This is a list of row dicts, each with
        ``"documents"`` and ``"metadata"`` keys. A column-oriented dict
        ``{"documents": [...], "metadata": [...]}`` is also accepted. ``None`` or an
        empty container yields an empty set.
    :return: Set of ``generate_hash`` digests, one per record.
    """
    existing_hashes = set()
    # Check for None before calling len(), which would raise on None.
    if all_records is None:
        return existing_hashes

    print(f"all records >>> {len(all_records)}")
    if len(all_records) == 0:
        return existing_hashes

    if isinstance(all_records, dict):
        pairs = zip(all_records["documents"], all_records["metadata"])
    else:
        # Milvus query results are row-oriented: one dict per entity.
        pairs = ((record["documents"], record["metadata"]) for record in all_records)

    for doc, metadata in pairs:
        existing_hashes.add(generate_hash(doc, metadata))
    return existing_hashes

## **Store and retrieve data from Milvus** ##

In [None]:
import time
import numpy as np
from pymilvus import connections
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
from pymilvus import MilvusClient
from pymilvus import utility

class VectorDataStore:
    """Thin wrapper around ``MilvusClient`` for storing and searching chunk embeddings.

    Collections created by :meth:`create_collection` hold these fields:
    ``pk`` (VARCHAR primary key), ``metadata`` (JSON), ``documents`` (VARCHAR),
    ``embedding`` (FLOAT_VECTOR with an HNSW / COSINE index) and ``timestamp``
    (INT64, ``int(time.time())`` at insert).
    """

    # NOTE(review): not referenced by any method; the client is opened on `path` in __init__.
    db_url = "http://localhost:19530"

    # Collections searched by `search` when `collection_names` is not given.
    default_search_collections = ["ragbench_collection_techqa_v09"]

    def __init__(self, path="/content/ragbench.db"):
        """Open a Milvus client on ``path`` (a local ``.db`` file by default)."""
        self.client = MilvusClient(path)

    def create_collection(self, name, vec_dim=128, max_doc_length=65535):
        """Create collection ``name`` with an HNSW / COSINE index on ``embedding``.

        :param name: Collection name.
        :param vec_dim: Embedding dimensionality; must match the embedder's output size.
        :param max_doc_length: ``max_length`` of the ``documents`` VARCHAR field.
            512-token chunks plus their ``[id]`` prefix are far longer than 512
            characters, so the default is Milvus's VARCHAR maximum (65535).
        :return: The value returned by ``MilvusClient.create_collection``.
        """
        if self.client.has_collection(name):
            self.default_collection_name = name

        self.description = f"collection to store {name}"

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            index_type="HNSW",
            metric_type="COSINE",
            params={
                "M": 16,  # Number of bidirectional links created for each element
                "efConstruction": 200,  # Size of the dynamic candidate list while building the index
            },
        )

        schema = self.client.create_schema(
            auto_id=False,
            # The documented pymilvus keyword is singular: `enable_dynamic_field`.
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.VARCHAR, max_length=64, is_primary=True)
        schema.add_field(field_name="metadata", datatype=DataType.JSON)
        schema.add_field(field_name="documents", datatype=DataType.VARCHAR, max_length=max_doc_length)
        schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=vec_dim)
        schema.add_field(field_name="timestamp", datatype=DataType.INT64)

        collection = self.client.create_collection(
            collection_name=name,
            schema=schema,
            index_params=index_params,
        )
        self.current_collection = collection
        return collection

    def get_collection(self, name):
        """Return an ORM ``Collection`` handle for ``name``; raise ValueError if it does not exist."""
        if not self.client.has_collection(name):
            raise ValueError(f"Collection '{name}' does not exist.")
        # NOTE(review): the ORM `Collection` resolves its connection through pymilvus
        # `connections` (default alias), not through `self.client` — confirm one is open.
        self.current_collection = Collection(name)
        return self.current_collection

    def get_all_records(self, collection, limit=10000):
        """Fetch up to ``limit`` records (``documents`` and ``metadata``) from ``collection``.

        :param collection: Collection name.
        :param limit: Maximum number of rows to return; rows beyond it are not fetched.
        :return: List of row dicts (empty list if the client returns ``None``).
        """
        all_records = self.client.query(
            collection_name=collection,
            filter="",  # empty filter string (the pymilvus default) selects all rows up to `limit`
            output_fields=["documents", "metadata"],
            limit=limit,
        )
        if all_records is None:
            all_records = []

        return all_records

    def has_entities(self, name):
        """Return the ``row_count`` reported by Milvus for collection ``name`` (0 if absent)."""
        if not self.client.has_collection(name):
            raise ValueError(f"Collection '{name}' does not exists.")
        self.default_collection = name
        collection_stats = self.client.get_collection_stats(collection_name=name)
        return collection_stats.get("row_count", 0)

    def insert(self, collection_name: str, metadata: list[dict[str, any]],
               documents: list[str], embeddings: np.ndarray, ids: list[str]):
        """Insert one row per (metadata, document, embedding, id) tuple.

        :raises ValueError: if the collection does not exist or the four inputs
            differ in length.
        """
        if not self.client.has_collection(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist. Create it first.")

        # A chained `a != b != c != d` only compares neighbours; compare all four lengths.
        if len({len(metadata), len(embeddings), len(documents), len(ids)}) != 1:
            raise ValueError("Metadata, documents, ids and embeddings must have the same length.")

        now = int(time.time())
        data = [
            {
                "pk": pk,
                "metadata": meta,
                "documents": doc,
                # Accept ndarray rows or plain sequences.
                "embedding": np.asarray(emb).tolist(),
                "timestamp": now,
            }
            for meta, doc, emb, pk in zip(metadata, documents, embeddings, ids)
        ]

        self.client.insert(collection_name, data)
        print(f"Inserted {len(metadata)} records into collection '{collection_name}'.")

    def drop_collection(self, collection_name: str):
        """Drop ``collection_name``; raise ValueError if it does not exist."""
        if not self.client.has_collection(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist.")
        self.client.drop_collection(collection_name)
        print(f"Dropped collection '{collection_name}'.")

    def delete_all(self, collection_name: str):
        """Delete every row of ``collection_name`` and flush."""
        if not self.client.has_collection(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist.")
        # `pk` is VARCHAR, so a numeric comparison such as `pk >= 0` is not a valid
        # filter. MilvusClient.delete takes `filter=`, and MilvusClient.flush takes a
        # single collection name.
        self.client.delete(collection_name, filter='pk != ""')
        self.client.flush(collection_name)

    def search(self, query_embedding: np.ndarray, top_k: int = 10,
               collection_names=None) -> list[dict[str, any]]:
        """
        Search the given collections for the top-k most similar embeddings.

        :param query_embedding: The embedding vector to search for.
        :param top_k: Number of results to return overall.
        :param collection_names: Collections to search; defaults to
            ``default_search_collections``. Names that do not exist are skipped.
        :return: Dicts with collection name, id, metadata, distance and documents,
            ordered from most to least similar. With the COSINE metric Milvus
            reports a similarity in ``distance`` (larger means closer), so results
            are sorted in descending order.
        """
        if collection_names is None:
            collection_names = self.default_search_collections

        # COSINE matches the index built in create_collection (and chromadb).
        # HNSW requires the search-time `ef` to be at least the number of results requested.
        search_params = {
            "metric_type": "COSINE",
            "params": {"ef": max(64, top_k)},
        }

        results = []
        start_time = time.time()
        for collection_name in collection_names:
            if not self.client.has_collection(collection_name):
                continue

            search_results = self.client.search(
                collection_name=collection_name,
                data=[query_embedding],
                anns_field="embedding",
                search_params=search_params,
                limit=top_k,
                output_fields=["metadata", "documents"],
            )

            for hits in search_results:
                for hit in hits:
                    print(f"Collection: {collection_name}, data: {str(hit)}")
                    results.append({
                        "collection": collection_name,
                        "id": hit["id"],
                        "metadata": hit["entity"]["metadata"],
                        "distance": hit["distance"],
                        "documents": hit["entity"]["documents"],
                    })

        # Larger COSINE score means more similar: sort descending before truncating.
        results = sorted(results, key=lambda r: r["distance"], reverse=True)[:top_k]
        end_time = time.time()
        print(f"Search completed. Found {len(results)} results. in {end_time - start_time} secs")
        return results

    def extract_documents(self, search_results: list[dict[str, any]]) -> list[np.ndarray]:
        """
        Extract the document text from search results.

        :param search_results: Result dicts as returned by :meth:`search`.
        :return: One ``np.array(result["documents"])`` per result that has a
            ``"documents"`` key (a 0-d array wrapping the text).
        """
        return [np.array(result["documents"]) for result in search_results if "documents" in result]

In [None]:
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

datasets = ['hagrid', 'hotpotqa', 'msmarco']

all_documents = []
all_ids = []
all_metadatas = []

# Process each dataset
doc_idx = 0  # Global document index for unique IDs
for dataset in datasets:
    data = load_dataset("rungalileo/ragbench", dataset, split="train")
    # For quick debugging, restrict to a few rows, e.g.: data = data.select(range(2))
    # Iterating the dataset directly (not enumerate) lets tqdm show a total.
    for row in tqdm(data, desc=f"Processing {dataset}"):
        doc_text = row.get('documents', '')

        # Skip if no documents found
        if not doc_text:
            continue

        processed_output = process_document_with_identifiers(doc_text)

        # Records already emitted for this row, keyed by their full text. Keying on
        # item_idx alone is wrong: it restarts at 0 in every section, so all sections
        # after the first would lose their leading items.
        seen_documents = set()

        for section_idx, section in enumerate(processed_output):
            for item_idx, (prefix, content) in enumerate(section):
                document = f"[{prefix}] {content}"
                if document in seen_documents:
                    continue
                seen_documents.add(document)

                all_documents.append(document)

                # Globally unique ID: dataset + document + section + item position.
                all_ids.append(f"{dataset}_{doc_idx}_{section_idx}_{item_idx}")

                # Only a bare "a" letter suffix marks a title; "aa", "ba", ... are later
                # passages, which `prefix.endswith("a")` would mislabel.
                letters = prefix.lstrip("0123456789")
                all_metadatas.append({
                    "dataset": dataset,
                    "global_index": doc_idx,
                    "section_index": section_idx,
                    "item_index": item_idx,
                    "prefix": prefix,
                    "type": "Title" if letters == "a" else "Passage",
                })

        doc_idx += 1  # Increment global document index

assert len(all_documents) == len(all_ids) == len(all_metadatas)
assert len(set(all_ids)) == len(all_ids), "document IDs must be unique"

# Step 4: Generate Embeddings
# NOTE(review): BAAI/LLM-Embedder documents task-specific instruction prefixes for
# queries/keys; confirm whether documents should be prefixed before encoding.
embedder = SentenceTransformer(retrieval_model)  # Pretrained sentence transformer
batch_size = 2500  # Adjust based on available memory

# Generate embeddings in batches
all_embeddings = []
for i in tqdm(range(0, len(all_documents), batch_size), desc="Generating embeddings"):
    batch_docs = all_documents[i:i + batch_size]
    all_embeddings.extend(embedder.encode(batch_docs, show_progress_bar=True))

assert len(all_embeddings) == len(all_documents)

In [None]:
# Sample evaluation queries, in three groups of five.
# NOTE(review): the groups appear to follow `datasets` order (hagrid, hotpotqa, msmarco) — confirm.
questions = [
    # Group 1: short factoid questions
    'When was Rolex founded?',
    'How large is the region of Macedonia?',
    'Where is GMT Games headquartered?',
    'What state is directly north of North Carolina?',
    'When was Brown v. Board of Education?',
    # Group 2: multi-hop questions
    'What star of Parks and Recreation appeared in November?',
    'What is the capacity of the Stadium, other than Kauffman Stadium, designed by Charles Deaton ?',
    'What was the island, on which Marinelli Glacier is located, formerly known as?',
    'The American Sweetgum is the hostplant of what kind of bug?',
    'The name of the Japanese rock band T-Bolan was inspired by the name of an English rock band formed in what year?',
    # Group 3: lowercase web-search style queries
    'symptoms of pregnancy before a missed period',
    'monoclonal antibodies biology definition',
    'what is iron sulfate',
    "who sang one day i'll fly away",
    'describe the antebellum reform movement period',
]

## **Retrieve candidates from the DB** ##