<a href="https://colab.research.google.com/github/codeby3/searchable-encryption/blob/resource-utilization/resource-utilization/dcpe/subset_zilliz_dcpe_vector_search_resource_Util.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up

In [None]:
!pip install -qU beir sentence-transformers pymilvus datasets pycryptodome

In [None]:
!pip install --upgrade --quiet torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118

In [None]:
import os
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
import pandas as pd
from pymilvus import MilvusClient, FieldSchema, DataType, CollectionSchema, Collection
from sentence_transformers import SentenceTransformer
import torch
from google.colab import userdata
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import base64
import time
import numpy as np
import sys
import os
import random

In [None]:
# Add the current working directory to the import path before importing `dcpe`.
# NOTE(review): presumably `dcpe` is a local module/package that must sit in the
# directory the notebook is run from - confirm it has been uploaded/cloned there.
sys.path.append(os.getcwd())
import dcpe

In [None]:
# Key shared by every dcpe.encrypt_vector call in this notebook (documents and queries).
# NOTE(review): generate_random presumably yields a new key each time this cell runs,
# so data encrypted under an earlier key would need to be re-ingested - confirm.
DCPE_KEY = dcpe.DCPEKey.generate_random(scaling_factor=1.2) # Using a scaling factor of 1.2
# Approximation factor passed unchanged to every dcpe.encrypt_vector call below.
DCPE_APPROXIMATION_FACTOR = 1.0 # Using an approximation factor of 1.0 for a good tradeoff between security and accuracy

# Load Datasets
We use three datasets from the BEIR benchmark: nfcorpus, fiqa and scidocs (https://huggingface.co/datasets/BeIR/beir).

In [None]:
# Names of the BEIR datasets processed by every loop in this notebook.
datasets_to_load = ["nfcorpus", "fiqa", "scidocs"]
beir_data_path = "./beir_datasets" # Local directory to store BEIR data
# Create the download directory up front; exist_ok makes re-running this cell safe.
os.makedirs(beir_data_path, exist_ok=True)
# Filled by the loading loop: dataset name -> {"corpus", "queries", "qrels"}.
loaded_beir_data = {}

In [None]:
# Download (when missing) and load the "test" split of every configured BEIR dataset
# into `loaded_beir_data`. A dataset that fails to load is reported and skipped.
for dataset_name in datasets_to_load:
    print(f"\nProcessing dataset: {dataset_name}")

    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
    out_dir = os.path.join(beir_data_path, dataset_name)
    # BEIR's download_and_unzip extracts the archive *inside* out_dir and returns
    # "<out_dir>/<dataset_name>". The cached-data branch must resolve to that same
    # nested folder, otherwise a re-run points the loader at a directory that does
    # not contain corpus.jsonl.
    extracted_dir = os.path.join(out_dir, dataset_name)

    if os.path.isdir(extracted_dir):
        # Layout produced by a previous download_and_unzip call.
        print(f"Dataset {dataset_name} already exists at {out_dir}. Skipping download.")
        data_path = extracted_dir
    elif os.path.isfile(os.path.join(out_dir, "corpus.jsonl")):
        # Dataset files were placed directly in out_dir (e.g. copied in manually).
        print(f"Dataset {dataset_name} already exists at {out_dir}. Skipping download.")
        data_path = out_dir
    else:
        # Nothing usable on disk yet (or only a partial download): fetch and extract.
        print(f"Downloading {dataset_name} from {url} to {out_dir}...")
        data_path = util.download_and_unzip(url, out_dir)
        print(f"Downloaded {dataset_name} to: {data_path}")

    # Load the corpus, queries, and qrels of the test split.
    try:
        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
        loaded_beir_data[dataset_name] = {
            "corpus": corpus,
            "queries": queries,
            "qrels": qrels
        }
        print(f"Successfully loaded data for {dataset_name}.")
        print(f"  Corpus size: {len(corpus)} documents")
        print(f"  Queries size: {len(queries)} queries")
        print(f"  Qrels size: {len(qrels)} relevance judgments")

        # Print a sample document and query to verify the load.
        if len(corpus) > 0:
            sample_doc_id = next(iter(corpus))
            print(f"  Sample corpus entry ({sample_doc_id}): {corpus[sample_doc_id]['title']} - {corpus[sample_doc_id]['text'][:100]}...")
        if len(queries) > 0:
            sample_query_id = next(iter(queries))
            print(f"  Sample query entry ({sample_query_id}): {queries[sample_query_id][:100]}...")

    except Exception as e:
        # Best effort: keep going with the remaining datasets. Downstream loops only
        # iterate over what actually made it into `loaded_beir_data`.
        print(f"Error loading {dataset_name}: {e}")

### Using subsets of data

In [None]:
# Fraction of each dataset's queries kept in the subset (0.05 = 5%).
SUBSET_QUERY_PERCENTAGE = 0.05
RANDOM_SEED = 42 # VERY IMPORTANT: Use a fixed seed for reproducibility!

# Seed the global `random` module so any sampling that uses it is repeatable.
random.seed(RANDOM_SEED)
# Filled by the subset loop: dataset name -> {"corpus", "queries", "qrels"} subsets.
subset_beir_data = {}

In [None]:
# Build a reduced copy of every loaded dataset: a random sample of queries, the
# documents judged relevant to those queries, and the matching qrels.
#
# A dedicated generator seeded here (instead of the global `random` state) makes this
# cell idempotent: re-running it reproduces exactly the same subsets. On a fresh run
# the draws are the same as with `random.seed(RANDOM_SEED)` followed by `random.sample`.
subset_rng = random.Random(RANDOM_SEED)

for dataset_name, data in loaded_beir_data.items():
    print(f"\nCreating subset for dataset: {dataset_name}")

    corpus = data["corpus"]
    queries = data["queries"]
    qrels = data["qrels"]

    # --- Query subset ---
    query_ids = list(queries.keys())
    num_queries_subset = int(len(query_ids) * SUBSET_QUERY_PERCENTAGE)
    if 0 < num_queries_subset <= len(query_ids):
        subset_query_ids = set(subset_rng.sample(query_ids, num_queries_subset))
    else:
        # The percentage rounds down to zero queries (tiny dataset) or exceeds 100%:
        # fall back to using every query.
        subset_query_ids = set(query_ids)

    # Keep the dataset's own query order so downstream batching does not depend on
    # set iteration order (which varies between kernels for string keys).
    ordered_query_ids = [qid for qid in query_ids if qid in subset_query_ids]
    subset_queries = {qid: queries[qid] for qid in ordered_query_ids}
    print(f"  Selected {len(subset_queries)} queries (out of {len(queries)})")

    # --- Corpus subset: only documents relevant to the sampled queries ---
    # This keeps every sampled query answerable, unlike random corpus sampling.
    relevant_corpus_ids = {
        doc_id
        for qid in ordered_query_ids
        for doc_id, score in qrels.get(qid, {}).items()
        if score > 0  # A score > 0 indicates relevance
    }
    # Iterate the corpus (not the id set) to get a deterministic document order.
    subset_corpus = {cid: doc for cid, doc in corpus.items() if cid in relevant_corpus_ids}
    print(f"  Selected {len(subset_corpus)} documents relevant to the query subset.")

    # --- Qrels subset: sampled queries restricted to documents that were kept ---
    subset_qrels = {}
    for qid in ordered_query_ids:
        filtered_scores = {
            did: score for did, score in qrels.get(qid, {}).items() if did in subset_corpus
        }
        if filtered_scores:  # Drop queries left without any judged document
            subset_qrels[qid] = filtered_scores
    print(f"  Filtered qrels to {len(subset_qrels)} queries with relevant documents in the new corpus subset.")

    # Store the new subset data.
    subset_beir_data[dataset_name] = {
        "corpus": subset_corpus,
        "queries": subset_queries,
        "qrels": subset_qrels
    }

# Ingest data into the Zilliz vector DB

In [None]:
# Connection details are read from Colab's `userdata` secret store (secrets named
# ZILLIZ_ENDPOINT and ZILLIZ_TOKEN) instead of being hardcoded in the notebook.
ZILLIZ_CLOUD_URI = userdata.get("ZILLIZ_ENDPOINT")
ZILLIZ_CLOUD_API_KEY = userdata.get("ZILLIZ_TOKEN")

In [None]:
# Client used by the ingestion and search cells below for all collection operations.
zilliz_client = MilvusClient(
        uri=ZILLIZ_CLOUD_URI,
        token=ZILLIZ_CLOUD_API_KEY
    )

We use the all-MiniLM-L6-v2 embedding model from the Hugging Face SentenceTransformers library. It generates 384-dimensional embeddings.



In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the sentence-embedding model and run one test encode as a sanity check.
# Any failure in loading or in the test encode is reported rather than raised.
try:
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
    print(f"Model 'all-MiniLM-L6-v2' loaded successfully on {device}.")
    # The length of a single embedding is the model's output dimension
    # (384 is expected for all-MiniLM-L6-v2).
    dummy_embedding = embedding_model.encode("test sentence")
    print(f"Model output dimension: {len(dummy_embedding)}")
except Exception as load_error:
    print(f"Failed to load embedding model: {load_error}")

Batch ingestion

In [None]:
DIMENSION = 384 # Dimension for all-MiniLM-L6-v2
BATCH_SIZE = 64 # Documents embedded, encrypted and inserted per request

# Reuse the model loaded earlier instead of loading a second copy (which costs load
# time and memory and ignores the explicit device choice). Only load here when the
# name is missing, e.g. because the earlier load cell failed or was skipped.
try:
    embedding_model
except NameError:
    embedding_model = SentenceTransformer(
        "all-MiniLM-L6-v2",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

# Filled by the ingestion loop: dataset name -> timing / size / document-count metrics.
ingestion_metrics = {}

In [None]:
# Ingest every dataset subset into its own Zilliz collection:
# embed -> DCPE-encrypt -> insert in batches -> build an HNSW index -> load.
for dataset_name in subset_beir_data:
    print(f"\nProcessing Zilliz ingestion for dataset: {dataset_name}")
    collection_name = f"subset_beir_{dataset_name.replace('-', '_')}_dcpe"

    # Always start from an empty collection so a re-run does not insert documents twice.
    if zilliz_client.has_collection(collection_name=collection_name):
        print(f"Collection '{collection_name}' already exists. Dropping and recreating.")
        zilliz_client.drop_collection(collection_name=collection_name)

    # Schema: plaintext document fields plus the encrypted vector and the two
    # base64-encoded values (IV, auth hash) returned alongside each ciphertext.
    fields = [
        FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=512, is_primary=True),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="encrypted_vector", dtype=DataType.FLOAT_VECTOR, dim=DIMENSION),
        FieldSchema(name="iv", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="auth_hash", dtype=DataType.VARCHAR, max_length=64)
    ]
    schema = CollectionSchema(fields, description=f"BEIR {dataset_name} DCPE encrypted documents")

    print(f"Creating collection '{collection_name}'...")
    zilliz_client.create_collection(
        collection_name=collection_name,
        schema=schema,
        consistency_level="Strong"
    )
    print(f"Collection '{collection_name}' created successfully.")

    corpus = subset_beir_data[dataset_name]["corpus"]
    doc_ids = list(corpus.keys())
    num_documents = len(doc_ids)

    # The ingestion timer covers text preparation, embedding, encryption and inserts;
    # index creation and collection loading happen after it is stopped.
    start_total_ingestion = time.perf_counter()

    documents_to_embed = []
    original_titles = []
    original_texts = []

    for doc_id in doc_ids:
        title = corpus[doc_id].get("title", "")
        text = corpus[doc_id].get("text", "")
        # Title and body are embedded together as one string.
        documents_to_embed.append(f"{title} {text}".strip())
        original_titles.append(title)
        original_texts.append(text)

    print(f"Generating embeddings and inserting encrypted data for '{dataset_name}' corpus ({len(doc_ids)} documents)...")

    for i in tqdm(range(0, len(doc_ids), BATCH_SIZE), desc=f"Ingesting {dataset_name}"):
        batch_current_doc_ids = doc_ids[i:i + BATCH_SIZE]
        batch_contents = documents_to_embed[i:i + BATCH_SIZE]
        batch_titles = original_titles[i:i + BATCH_SIZE]
        batch_texts = original_texts[i:i + BATCH_SIZE]

        # Plaintext embeddings as a NumPy array, one row per document.
        # (`convert_to_numpy` is the documented flag; the previously used
        # `convert_to_list` is not an `encode` parameter.)
        original_batch_vectors = embedding_model.encode(batch_contents, convert_to_numpy=True)

        # Encrypt each vector and assemble the row that is sent to Zilliz.
        entities = []
        for doc_id, title, text, vec in zip(
            batch_current_doc_ids, batch_titles, batch_texts, original_batch_vectors
        ):
            encrypted_result = dcpe.encrypt_vector(DCPE_KEY, vec.tolist(), DCPE_APPROXIMATION_FACTOR)
            entities.append({
                "doc_id": doc_id,
                "title": title,
                "text": text,
                "encrypted_vector": encrypted_result.ciphertext, # Only the ciphertext is stored/searched
                "iv": base64.b64encode(encrypted_result.iv).decode('utf-8'),
                "auth_hash": base64.b64encode(encrypted_result.auth_hash.get_bytes()).decode('utf-8')
            })

        zilliz_client.insert(collection_name=collection_name, data=entities)

    end_total_ingestion = time.perf_counter()
    total_ingestion_time = end_total_ingestion - start_total_ingestion
    print(f"Finished inserting encrypted data for '{dataset_name}'. Total ingestion time: {total_ingestion_time:.2f} seconds")

    # Upper-bound size estimate per record:
    #   - encrypted vector: DIMENSION float32 values = DIMENSION * 4 bytes
    #   - IV and auth hash: stored as base64 text in VARCHAR(64) fields; the declared
    #     max_length (64 bytes each) is used instead of the actual string length.
    #     NOTE(review): the original author states the raw sizes are 12 bytes (IV) and
    #     32 bytes (hash), i.e. 16 and 44 base64 characters - confirm against dcpe.
    # Title/text/doc_id are deliberately excluded: this measures embedding overhead only.
    estimated_single_record_size_bytes = (DIMENSION * 4) + (64) + (64)
    estimated_total_embedding_size_bytes = num_documents * estimated_single_record_size_bytes
    estimated_embedding_size_mb = estimated_total_embedding_size_bytes / (1024 * 1024)
    print(f"Estimated encrypted embedding data size for '{dataset_name}': {estimated_embedding_size_mb:.2f} MB (including IV/AuthHash as VARCHAR)")

    ingestion_metrics[dataset_name] = {
        "total_ingestion_time_seconds": total_ingestion_time,
        "estimated_embedding_size_mb": estimated_embedding_size_mb,
        "num_documents": num_documents
    }

    print(f"Creating HNSW index for '{collection_name}'...")
    index_params = zilliz_client.prepare_index_params()
    # The index is built over the encrypted vectors; queries are encrypted the same way.
    index_params.add_index(field_name="encrypted_vector", index_type="HNSW", metric_type="COSINE")
    zilliz_client.create_index(collection_name=collection_name, index_params=index_params)
    print(f"HNSW index created for '{collection_name}'.")

    print(f"Loading collection '{collection_name}' into memory...")
    zilliz_client.load_collection(collection_name=collection_name)
    print(f"Collection '{collection_name}' loaded successfully.")

# Running queries

In [None]:
# Number of queries embedded, encrypted and sent per search request.
QUERY_BATCH_SIZE = 10 # Process queries in batches for embedding and searching - 10 is the max allowed by Zilliz
# Default number of hits requested per query by the search helper below.
MAX_SEARCH_RESULTS = 100 # Retrieve top 100 for max k_value

In [None]:
def get_zilliz_search_results_dcpe(
    zilliz_client: MilvusClient,
    embedding_model: SentenceTransformer,
    collection_name: str,
    queries: dict,
    dcpe_key: dcpe.DCPEKey, # DCPE key
    dcpe_approx_factor: float, # DCPE approximation factor
    top_k: int = MAX_SEARCH_RESULTS
) -> tuple[dict, list]:
    """Search an encrypted collection with DCPE-encrypted query embeddings.

    Queries are processed in batches of ``QUERY_BATCH_SIZE``: each batch is embedded,
    every embedding is encrypted with ``dcpe.encrypt_vector`` and the ciphertexts are
    sent to Zilliz in a single search call.

    Args:
        zilliz_client: Client used to run the search.
        embedding_model: Model used to embed the query texts.
        collection_name: Collection to search.
        queries: Mapping of query id -> query text.
        dcpe_key: Key passed to ``dcpe.encrypt_vector``.
        dcpe_approx_factor: Approximation factor passed to ``dcpe.encrypt_vector``.
        top_k: Number of hits requested per query.

    Returns:
        ``(search_results, latencies_ms)`` where ``search_results`` maps
        ``str(query_id) -> {str(doc_id): float(distance)}`` and ``latencies_ms`` holds
        one value per successfully searched query. The latency is the duration of the
        batch's search call divided evenly by the batch size; embedding and encryption
        time are not included.

    A batch whose search (or result parsing) fails is reported and skipped; it
    contributes neither results nor latencies.
    """
    def _hit_id_and_distance(hit):
        # Attribute-style hit objects: the access pattern this notebook used originally.
        if hasattr(hit, "id") and hasattr(hit, "distance"):
            return hit.id, hit.distance
        # Dict-shaped hits: the primary key is looked up under "id", otherwise taken
        # from the requested "doc_id" output field.
        if "id" in hit:
            return hit["id"], hit["distance"]
        entity = hit.get("entity") or {}
        return entity.get("doc_id", hit.get("doc_id")), hit["distance"]

    print(f"Retrieving results from '{collection_name}' for {len(queries)} queries, top_k={top_k}...")

    search_results = {}
    query_ids = list(queries.keys())
    query_texts = [queries[qid] for qid in query_ids]

    all_query_latencies_ms = []

    # Loop-invariant: `ef` is set to at least top_k (and no lower than 128).
    search_params = {
        "metric_type": "COSINE",
        "params": {"ef": max(top_k, 128)}
    }

    for i in tqdm(range(0, len(query_ids), QUERY_BATCH_SIZE), desc=f"Searching {collection_name}"):
        batch_query_ids = query_ids[i:i + QUERY_BATCH_SIZE]
        batch_query_texts = query_texts[i:i + QUERY_BATCH_SIZE]

        # Embed the batch (NumPy rows), then encrypt each row; only the ciphertext
        # is sent to the server. `convert_to_numpy` is the documented encode flag.
        original_batch_vectors = embedding_model.encode(batch_query_texts, convert_to_numpy=True)
        encrypted_batch_query_vectors = [
            dcpe.encrypt_vector(dcpe_key, vec.tolist(), dcpe_approx_factor).ciphertext
            for vec in original_batch_vectors
        ]

        try:
            start_search_call = time.perf_counter()
            hits_per_query = zilliz_client.search(
                collection_name=collection_name,
                data=encrypted_batch_query_vectors, # Use encrypted vectors for search
                limit=top_k,
                output_fields=["doc_id"],
                search_params=search_params
            )
            search_call_duration_ms = (time.perf_counter() - start_search_call) * 1000

            # Parse the whole batch before committing anything, so results and
            # latencies always describe the same set of queries.
            batch_results = {}
            for q_idx, query_id in enumerate(batch_query_ids):
                query_hits = {}
                for hit in hits_per_query[q_idx]:
                    doc_id, distance = _hit_id_and_distance(hit)
                    query_hits[str(doc_id)] = float(distance)
                batch_results[str(query_id)] = query_hits

        except Exception as e:
            # Best effort: report the failing batch and move on to the next one.
            print(f"Error during search for batch starting with {batch_query_ids[0]}: {e}")
            continue

        search_results.update(batch_results)
        latency_per_query_in_batch = search_call_duration_ms / len(batch_query_ids)
        all_query_latencies_ms.extend([latency_per_query_in_batch] * len(batch_query_ids))

    print(f"Finished retrieving results from '{collection_name}'.")
    return search_results, all_query_latencies_ms

In [None]:
# Define the list of k-values for evaluation
# Define the list of k-values for evaluation.
# The evaluation loop requests max(k_values) hits per query.
k_values = [1, 3, 5, 10, 50, 100]

# List to store results for all datasets (one dict per evaluated dataset).
all_evaluation_results = []

## Evaluating Results

In [None]:
!pip install -q psutil memory_profiler pyRAPL

In [None]:
from contextlib import contextmanager
import psutil, os, time, tracemalloc

@contextmanager
def resource_probe(label):
    """Measure and print the resource usage of the wrapped ``with`` block.

    Reported on exit (also when the block raises, after which the exception
    propagates unchanged):
      * wall  - elapsed wall-clock time in milliseconds
      * CPU   - user + system CPU seconds consumed by this process
      * dRSS  - change in resident set size of this process, in MB (printed with a
                Greek capital delta)
      * peak  - peak size of Python allocations traced by ``tracemalloc`` during the
                block, relative to the traced size on entry, in MB

    Args:
        label: Text printed in square brackets at the start of the report line.
    """
    proc = psutil.Process(os.getpid())
    cpu_start = proc.cpu_times()
    rss_start = proc.memory_info().rss

    # Only start/stop tracemalloc if nobody else is already tracing; when tracing is
    # already active, just reset the peak so it reflects this block alone.
    started_tracing = not tracemalloc.is_tracing()
    if started_tracing:
        tracemalloc.start()
    else:
        tracemalloc.reset_peak()
    traced_start, _ = tracemalloc.get_traced_memory()

    t0 = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - t0
        # get_traced_memory() returns (current, peak); the peak is the true high-water
        # mark, unlike the size of the largest per-file entry of a snapshot.
        _, traced_peak = tracemalloc.get_traced_memory()
        if started_tracing:
            tracemalloc.stop()

        cpu_end = proc.cpu_times()
        rss_end = proc.memory_info().rss
        cpu_sec = (cpu_end.user + cpu_end.system) - (cpu_start.user + cpu_start.system)
        delta_rss = (rss_end - rss_start) / 1e6
        peak_mb = max(traced_peak - traced_start, 0) / 1e6
        # \u0394 is the Greek capital delta; written as an escape so the label cannot
        # be corrupted by a wrong file encoding again.
        print(f"[{label}] wall={elapsed*1e3:.2f} ms | CPU={cpu_sec:.3f}s | \u0394RSS={delta_rss:.1f} MB | peak={peak_mb:.1f} MB")

In [None]:
# Loop through each dataset
# Run the encrypted search for every dataset subset and collect operational metrics
# (query latency plus the ingestion metrics recorded earlier).
#
# Reset the accumulator first: without this, re-running the cell appends a second
# row per dataset and the tables/plots below show duplicates.
all_evaluation_results.clear()

for dataset_name in subset_beir_data:
    print(f"\nStarting evaluation for dataset: {dataset_name}")

    # Normalise qrels to str ids / int scores. Here they are only used to decide
    # whether the subset is worth evaluating.
    qrels = {
        str(query_id): {str(doc_id): int(score) for doc_id, score in doc_scores_dict.items()}
        for query_id, doc_scores_dict in subset_beir_data[dataset_name]["qrels"].items()
    }
    queries = subset_beir_data[dataset_name]["queries"]

    # Skip evaluation if there are no queries or qrels in the subset
    if not queries or not qrels:
        print(f"  Skipping evaluation for '{dataset_name}' as the query or qrels subset is empty.")
        continue

    collection_name = f"subset_beir_{dataset_name.replace('-', '_')}_dcpe"

    print(f"Retrieving search results from Zilliz for '{dataset_name}'...")

    # The probe covers the whole retrieval call: query embedding, encryption and search.
    with resource_probe(f"{dataset_name}_query_eval"):
        results, query_latencies_ms = get_zilliz_search_results_dcpe(
            zilliz_client=zilliz_client,
            embedding_model=embedding_model,
            collection_name=collection_name,
            queries=queries,
            dcpe_key=DCPE_KEY,
            dcpe_approx_factor=DCPE_APPROXIMATION_FACTOR,
            top_k=max(k_values)
        )

    if not results:
        print(f"No results retrieved for '{dataset_name}'. Skipping evaluation for this dataset.")
        continue

    print(f"Retrieved {len(results)} queries' results for '{dataset_name}'.")

    # Latency statistics; all zero when no latency samples were recorded.
    if query_latencies_ms:
        avg_search_latency_ms = np.mean(query_latencies_ms)
        p90_search_latency_ms = np.percentile(query_latencies_ms, 90)
        p99_search_latency_ms = np.percentile(query_latencies_ms, 99)
    else:
        avg_search_latency_ms = 0
        p90_search_latency_ms = 0
        p99_search_latency_ms = 0

    print(f"  Avg Search Latency: {avg_search_latency_ms:.2f} ms")
    print(f"  P90 Search Latency: {p90_search_latency_ms:.2f} ms")
    print(f"  P99 Search Latency: {p99_search_latency_ms:.2f} ms")

    # Save only operational metrics
    dataset_ingestion = ingestion_metrics[dataset_name]
    dataset_results = {
        "Dataset": f"{dataset_name}_subset",
        "Avg_Search_Latency_ms": avg_search_latency_ms,
        "P90_Search_Latency_ms": p90_search_latency_ms,
        "P99_Search_Latency_ms": p99_search_latency_ms,
        "Total_Ingestion_Time_s": dataset_ingestion["total_ingestion_time_seconds"],
        "Estimated_Embedding_Size_MB": dataset_ingestion["estimated_embedding_size_mb"],
        "Num_Documents": dataset_ingestion["num_documents"]
    }
    all_evaluation_results.append(dataset_results)

    print(f"Evaluation for '{dataset_name}' complete.")


# Visualizing the results

In [None]:
# --- Flatten and build DataFrame (resource metrics only) ---
# --- Flatten and build DataFrame (resource metrics only) ---
operational_columns = [
    "Total_Ingestion_Time_s", "Estimated_Embedding_Size_MB",
    "Avg_Search_Latency_ms", "P90_Search_Latency_ms", "P99_Search_Latency_ms",
    "Num_Documents"
]
ordered_columns = ['Dataset'] + operational_columns

flattened_results = [
    {column: dataset_result[column] for column in ordered_columns}
    for dataset_result in all_evaluation_results
]
# Passing `columns=` keeps the frame well-formed when there are no results; selecting
# the columns from an empty, column-less frame would raise a KeyError instead.
results_df = pd.DataFrame(flattened_results, columns=ordered_columns)

print("\n--- Consolidated Resource Utilization Results (DCPE Encrypted) ---")
print(results_df.to_string())

# --- Resource Utilization Visualization (Bar Plots) ---
# Result key -> human-readable axis label used in the plots.
plot_labels = {
    "Total_Ingestion_Time_s": "Total Ingestion Time (s)",
    "Estimated_Embedding_Size_MB": "Estimated Embedding Size (MB)",
    "Avg_Search_Latency_ms": "Avg Search Latency (ms)",
    "P90_Search_Latency_ms": "P90 Search Latency (ms)",
    "P99_Search_Latency_ms": "P99 Search Latency (ms)",
}
operational_plot_data = [
    {"Dataset": dataset_result["Dataset"],
     **{label: dataset_result[key] for key, label in plot_labels.items()}}
    for dataset_result in all_evaluation_results
]
operational_df = pd.DataFrame(operational_plot_data, columns=["Dataset", *plot_labels.values()])

if operational_df.empty:
    print("No operational metrics data available for plotting.")
else:
    fig, (ax_ingestion, ax_size, ax_latency) = plt.subplots(1, 3, figsize=(20, 7))

    # Per-dataset bars. `hue` repeats the x variable because passing `palette`
    # without `hue` is deprecated in seaborn; the resulting legend would only
    # duplicate the x tick labels, so it is removed when present.
    per_dataset_panels = [
        (ax_ingestion, "Total Ingestion Time (s)", "crest", "Total Ingestion Time", "Seconds"),
        (ax_size, "Estimated Embedding Size (MB)", "mako", "Embedding Size", "MB"),
    ]
    for ax, value_column, palette, title, y_label in per_dataset_panels:
        sns.barplot(x="Dataset", y=value_column, hue="Dataset", dodge=False,
                    data=operational_df, palette=palette, ax=ax)
        redundant_legend = ax.get_legend()
        if redundant_legend is not None:
            redundant_legend.remove()
        ax.set_title(title, fontsize=14)
        ax.set_ylabel(y_label, fontsize=10)
        ax.set_xlabel("")

    # Latency panel: one group per statistic, one bar per dataset.
    latency_melted_df = operational_df.melt(id_vars=['Dataset'],
                                            value_vars=["Avg Search Latency (ms)", "P90 Search Latency (ms)", "P99 Search Latency (ms)"],
                                            var_name="Latency Type", value_name="Latency (ms)")
    sns.barplot(x="Latency Type", y="Latency (ms)", hue="Dataset", data=latency_melted_df,
                palette="flare", ax=ax_latency)
    ax_latency.set_title("Query Latency", fontsize=14)
    ax_latency.set_ylabel("ms", fontsize=10)
    ax_latency.set_xlabel("")
    plt.setp(ax_latency.get_xticklabels(), rotation=20, ha='right')
    ax_latency.legend(title="Dataset", loc='best')

    fig.tight_layout()
    plt.show()