# MAIR Dataset Download and Organization

## Download MAIR Datasets and Queries from Hugging Face to Google Drive

In [None]:
# 1Ô∏è‚É£ Mount Google Drive to a new empty folder
# This step allows Colab to access files stored in your Google Drive.
# A new directory '/content/gdrive' is created, which will contain your Drive files.
# If you've previously mounted your Drive, Colab might remember the authentication.
from google.colab import drive
drive.mount('/content/gdrive')  # Use a new mount point to avoid conflicts

# 2Ô∏è‚É£ Install Hugging Face hub CLI
# The Hugging Face Hub Command Line Interface (CLI) is used to interact with the Hugging Face Hub,
# which hosts various datasets and models. We upgrade it to ensure we have the latest features.
# This installation only needs to be done once per Colab session.
!pip install huggingface-hub --upgrade

# 3Ô∏è‚É£ Set local directories in your Google Drive
# These variables define the paths within your Google Drive where the downloaded datasets will be stored.
# You can customize these paths if you prefer a different location within your 'My Drive' folder.
# For example: "/content/drive/MyDrive/MyProject/MAIR_Datasets/MAIR-Queries"
# ‚û°Ô∏è USER CONFIGURATION: Modify these paths if you want to store the data elsewhere in your Google Drive.
queries_dir = "/content/drive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Queries"
docs_dir = "/content/drive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Docs"

# 4Ô∏è‚É£ Download datasets using HF CLI
# This command uses the Hugging Face CLI to download two datasets from the 'MAIR-Bench' organization:
# - 'MAIR-Bench/MAIR-Queries': Contains query data for various benchmarks.
# - 'MAIR-Bench/MAIR-Docs': Contains document data that corresponds to the queries.
# The '--repo-type dataset' flag specifies that we are downloading datasets.
# The '--local-dir' flag specifies the local path where the datasets will be saved within your mounted Google Drive.
# This might take some time depending on your internet connection and the size of the datasets.
# Progress will be displayed in the output.
!hf download MAIR-Bench/MAIR-Queries --repo-type dataset --local-dir "{queries_dir}"
!hf download MAIR-Bench/MAIR-Docs --repo-type dataset --local-dir "{docs_dir}"

print("‚úÖ Download completed!")

## Combine MAIR Docs and Queries from Google Drive into a Single Directory

In [None]:
from google.colab import drive
import os
import shutil

# Step 1: Mount Google Drive
# Gives Colab access to the files stored in your Google Drive.
# force_remount=True requests a remount even when the Drive is already mounted,
# which is handy after restarting the Colab runtime.
drive.mount('/content/drive', force_remount=True)

# Step 2: Locations of the raw downloads and of the combined output
# SAVE_PATH can be changed to store the combined datasets somewhere else in your Drive.
MAIR_DOCS_PATH = "/content/drive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Docs"
MAIR_QUERIES_PATH = "/content/drive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Queries"
SAVE_PATH = "/content/drive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Combined"

# Step 3: Make sure the output folder exists (no error if it is already there)
os.makedirs(SAVE_PATH, exist_ok=True)

# Step 4: Collect the dataset subfolders found under the Docs and Queries roots
docs_folders = []
for entry in os.listdir(MAIR_DOCS_PATH):
    if os.path.isdir(os.path.join(MAIR_DOCS_PATH, entry)):
        docs_folders.append(entry)

queries_folders = []
for entry in os.listdir(MAIR_QUERIES_PATH):
    if os.path.isdir(os.path.join(MAIR_QUERIES_PATH, entry)):
        queries_folders.append(entry)

# Step 5: Keep only datasets present on both sides, in sorted order
# Datasets that lack either documents or queries are left out.
common_datasets = sorted(set(docs_folders).intersection(queries_folders))
print("Common MAIR Datasets found:", common_datasets)

# Step 6: Copy every common dataset into <SAVE_PATH>/<dataset>/{docs,queries}
# Documents are copied first, then queries. An existing destination folder is
# removed beforehand so the copy always starts from a clean state.
for dataset in common_datasets:
    dataset_save_path = os.path.join(SAVE_PATH, dataset)
    os.makedirs(dataset_save_path, exist_ok=True)

    copy_plan = (
        ("documents", MAIR_DOCS_PATH, "docs"),
        ("queries", MAIR_QUERIES_PATH, "queries"),
    )
    for label, source_root, subfolder in copy_plan:
        source_dir = os.path.join(source_root, dataset)
        target_dir = os.path.join(dataset_save_path, subfolder)
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        print(f"Copying {label} for {dataset} from {source_dir} to {target_dir}")
        shutil.copytree(source_dir, target_dir)

print(f"\n‚úÖ All common datasets copied successfully with their original 'docs' and 'queries' subfolders to:\n{SAVE_PATH}")

# Vector (Binary) Benchmarking Based on MAIR Dataset

## Vector (Binary) Search in Moorcheh + Pinecone (with Cohere) + Elasticsearch

In [None]:
# ============================================================
# BEIR + MAIR Benchmark - Binary Embeddings Only
# Moorcheh vs Pinecone vs Elasticsearch Vector Comparison
# Sign-based 1-bit Binarization (>=0 -> 1, <0 -> 0)
# ============================================================

# -------------------- 1. Install Necessary Libraries --------------------
# This section ensures all required Python packages are installed in the Colab environment.
# The installation might take a few moments, especially on the first run.

# `beir`: A comprehensive benchmarking framework for information retrieval tasks.
# `moorcheh-sdk`: The official SDK for interacting with the Moorcheh vector database.
# `cohere`: Used for generating high-quality text embeddings and potentially reranking results.
# `pinecone`: The client library for connecting to the Pinecone vector database.
# `elasticsearch`: The client library for interacting with Elasticsearch, used here for vector search capabilities.
# `numpy`: A fundamental library for numerical computing in Python, essential for handling embedding arrays.
!pip install beir moorcheh-sdk cohere pinecone elasticsearch numpy

import os
import gc
import time
import statistics
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval

# -------------------- 2. User / Environment Setup --------------------
# This section configures paths and retrieves API keys, adapting to either Google Colab or a local environment.

# DRIVE_PATH: Defines where benchmark results will be saved. Default is the current directory.
# If running in Google Colab, it is replaced below by a path in your mounted Google Drive.
DRIVE_PATH = "."

# MAIR_COMBINED_PATH: **User Customizable Path**
# This is the path to your combined MAIR datasets in Google Drive. It should match the SAVE_PATH from the
# 'Combine MAIR Docs and Queries' step. Its '/content/drive' prefix must match the Drive mount point used below.
MAIR_COMBINED_PATH = "/content/drive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Combined"

# BEIR_PATH: Local directory where BEIR datasets will be downloaded and stored.
# You typically don't need to change this unless you want BEIR datasets stored elsewhere locally.
BEIR_PATH = "./datasets"

# Initialize an empty dictionary to store API keys and configurations for various services.
api_keys = {}

try:
    # This block attempts to configure for Google Colab, leveraging its `drive` and `userdata` (Secrets) features.
    from google.colab import drive, userdata as colab_userdata

    def _get_colab_secret(name):
        """Return the Colab secret `name`, or None if it is not defined or not shared with this notebook.

        `userdata.get` raises instead of returning None for a missing secret, which would
        otherwise abort the whole cell as soon as one optional key is absent.
        """
        try:
            return colab_userdata.get(name)
        except (colab_userdata.SecretNotFoundError, colab_userdata.NotebookAccessError):
            return None

    try:
        # Mount Google Drive at '/content/drive' so that MAIR_COMBINED_PATH (above) and
        # DRIVE_PATH (below) both resolve inside the mounted Drive.
        drive.mount('/content/drive')
        print("‚úÖ Google Drive mounted successfully at /content/drive.")
    except Exception as e:
        # Best effort: keep going so the benchmark can still run without Drive.
        print(f"‚ö†Ô∏è Drive mount attempt raised: {e}. Continuing without Drive mount, results will be saved locally if Drive path is not writable.")

    # DRIVE_PATH: **User Customizable Path**
    # If running in Colab, results will be saved here within your Google Drive.
    # You can customize this path (e.g., '/content/drive/MyDrive/MyProject/BenchmarkResults')
    # to organize your benchmark outputs effectively.
    DRIVE_PATH = '/content/drive/MyDrive/Moorcheh/Benchmark_Results/Pinecone.Binary'
    os.makedirs(DRIVE_PATH, exist_ok=True) # Creates the directory if it doesn't exist.
    print(f"‚úÖ Running in Colab. Results will be saved to: {DRIVE_PATH}")

    # Retrieve API keys from Colab secrets. **User Action Required**:
    # Add your API keys to Colab's "Secrets" panel (key icon on the left sidebar of the notebook).
    # Name your secrets EXACTLY as follows:
    # - `MOORCHEH_API_KEY` for Moorcheh
    # - `COHERE_API_KEY` for Cohere (essential for embedding generation)
    # - `PINECONE_API_KEY` or `PINECONE_API_KEY2` for Pinecone
    # - `ELASTIC_URL` for Elasticsearch endpoint (e.g., 'https://your-es-cluster.es.io:9243')
    # - `ELASTIC_API_KEY` for Elasticsearch API Key (preferred) OR
    # - `ELASTIC_USERNAME` and `ELASTIC_PASSWORD` for Elasticsearch Basic Auth.
    # Secrets that are not defined come back as None.
    api_keys = {
        'moorcheh': _get_colab_secret('MOORCHEH_API_KEY'),
        'cohere': _get_colab_secret('COHERE_API_KEY'),
        'pinecone': _get_colab_secret('PINECONE_API_KEY') or _get_colab_secret('PINECONE_API_KEY2'), # Fallback for Pinecone
        'elasticsearch': {
            'url': _get_colab_secret('ELASTIC_URL'),
            'api_key': _get_colab_secret('ELASTIC_API_KEY'),
            'username': _get_colab_secret('ELASTIC_USERNAME'),
            'password': _get_colab_secret('ELASTIC_PASSWORD')
        }
    }
    # If ALL Elasticsearch settings are missing, replace the dict with None so later code
    # can test a single value to see whether Elasticsearch is configured.
    if not any(api_keys['elasticsearch'].values()):
        api_keys['elasticsearch'] = None

except ImportError:
    # This block executes if not in Google Colab (e.g., a local Python environment).
    DRIVE_PATH = "." # Results will be saved in the current working directory.
    # In a local environment, API keys are read from environment variables.
    # **User Action Required**: Set these environment variables before running the script.
    # - `MOORCHEH_API_KEY`
    # - `COHERE_API_KEY`
    # - `PINECONE_API_KEY`
    # - `ELASTIC_URL` (e.g., "http://localhost:9200" for local ES)
    # - `ELASTIC_API_KEY` OR `ELASTIC_USERNAME`, `ELASTIC_PASSWORD`
    api_keys = {
        'moorcheh': os.environ.get('MOORCHEH_API_KEY'),
        'cohere': os.environ.get('COHERE_API_KEY'),
        'pinecone': os.environ.get('PINECONE_API_KEY'),
        'elasticsearch': {
            'url': os.environ.get('ELASTIC_URL') or "http://localhost:9200", # Defaults to localhost for local setup
            'api_key': os.environ.get('ELASTIC_API_KEY'),
            'username': os.environ.get('ELASTIC_USERNAME') or "elastic",
            'password': os.environ.get('ELASTIC_PASSWORD')
        }
    }
    # Mirrors the Colab branch. NOTE: because 'url' and 'username' always fall back to
    # non-empty defaults above, this condition cannot be true here, so Elasticsearch is
    # always treated as configured in a local run.
    if not any(api_keys['elasticsearch'].values()):
        api_keys['elasticsearch'] = None
    print("‚ö†Ô∏è Not running in Google Colab. Saving results locally. Ensure environment variables are set.")

# -------------------- 3. General Benchmark Configuration --------------------
# Tunable parameters that control the benchmark's behavior, performance and resource usage.

# BEIR datasets offered by the benchmark, paired with their estimated corpus sizes and
# ordered from smallest to largest. The two public lists below are derived from this table.
_BEIR_CORPUS_SIZES = (
    ("nfcorpus", 3633),
    ("scifact", 5183),
    ("arguana", 8674),
    ("scidocs", 25657),
    ("fiqa", 57638),
    ("trec-covid", 171332),
    ("webis-touche2020", 382545),
    ("quora", 522931),
)

# BEIR_DATASETS_SORTED_DISPLAY: Numbered, human-friendly labels ("<rank>. <name> (<size>)")
# shown to the user so the corpus size can guide dataset selection.
BEIR_DATASETS_SORTED_DISPLAY = [
    f"{rank}. {name} ({size:,})"
    for rank, (name, size) in enumerate(_BEIR_CORPUS_SIZES, start=1)
]

# BEIR_DATASETS: Programmatic dataset names, in the same order as the display labels above.
BEIR_DATASETS = [name for name, _size in _BEIR_CORPUS_SIZES]

# TOP_K_SEARCH: **User Customizable Value**
# How many top-ranked results are requested from the vector database per query.
# Larger values may improve recall but generally cost more latency and resources.
TOP_K_SEARCH = 100

# K_VALUES: **User Customizable Value**
# Cut-off points at which NDCG, MAP, Recall and Precision are reported
# (e.g. NDCG@1, MAP@10, Recall@100). Add or remove values as needed.
K_VALUES = [1, 3, 5, 10, 100]

# DATA_ROOT: Relative local directory into which BEIR datasets are downloaded.
DATA_ROOT = "./datasets"

# MAX_UPLOAD_DOCS: **User Customizable Value**
# Upper bound on the number of documents uploaded to the vector databases; useful to cap
# cost and run time on very large datasets. Use a small number (e.g. 10000) for quick
# tests, or `None` / a value above the dataset size to upload everything.
MAX_UPLOAD_DOCS = 700_000  # Process up to 700K documents; adjust as needed.

# BATCH_SIZE: **User Customizable Value**
# Number of embeddings processed or uploaded per API request/batch. Larger batches reduce
# per-request overhead but use more RAM and may hit API rate limits.
BATCH_SIZE = 100

# EMBEDDING_MODEL: **User Customizable Value**
# Cohere model that produces the initial dense float embeddings. Alternatives include
# 'embed-english-v3.0' and 'embed-multilingual-v3.0'; switching models requires updating
# VECTOR_DIMENSION accordingly.
EMBEDDING_MODEL = "embed-v4.0"

# INPUT_TYPE_CORPUS / INPUT_TYPE_QUERY: Input-type hints passed to Cohere's embedding
# model for corpus documents and for queries respectively.
INPUT_TYPE_CORPUS = "search_document"
INPUT_TYPE_QUERY = "search_query"

# VECTOR_DIMENSION: **User Customizable Value (Must match EMBEDDING_MODEL)**
# Dimensionality of the generated embeddings: 1536 for 'embed-v4.0', 1024 for the
# 'embed-v3.0' family. A mismatch leads to errors in the vector databases.
VECTOR_DIMENSION = 1536

# DATASET_SOURCE: Selects BEIR or MAIR datasets; set interactively later in the script.
DATASET_SOURCE = "beir"

# CSV_PATH: **User Customizable Path**
# Full path of the CSV file that receives the final benchmark results. It is built from
# DRIVE_PATH, so it points into Google Drive when running in Colab and into the current
# working directory otherwise.
# NOTE(review): whether the file is overwritten or appended to is decided by the code that
# writes it, which is not part of this section - confirm there.
CSV_PATH = os.path.join(DRIVE_PATH, "BEIR.MAIR.Binary.Embeddings.All.Providers.csv")

# -------------------- 4. MAIR Dataset Helper Functions --------------------
# These functions are specifically designed to assist in loading and processing
# MAIR datasets, handling their diverse structure and relevance judgments (qrels).

def load_jsonl(filepath):
    """Load a JSONL (JSON Lines) file into a dict keyed by stringified item ID.

    For every non-blank line the ID is taken from the first of the fields
    '_id', 'id', 'query_id', 'doc_id' that holds a usable value. A value of
    ``None`` or an empty string counts as missing; numeric IDs such as ``0``
    are valid. Items without any usable ID receive a positional fallback key
    (the current item count, bumped forward if that key is already in use) so
    that they never overwrite an entry stored earlier.

    Args:
        filepath: Path of the JSONL file (read as UTF-8).

    Returns:
        dict mapping ``str(item_id)`` to the parsed JSON object. If reading or
        parsing fails, the error is printed and the items loaded up to that
        point are returned (an empty dict if nothing was loaded).
    """
    id_fields = ('_id', 'id', 'query_id', 'doc_id')
    data = {}
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():  # Skip blank lines.
                    continue
                item = json.loads(line)

                # Pick the first ID field with a usable value. Compare against None/''
                # explicitly: a plain truthiness test would reject the legitimate ID 0.
                item_id = None
                for field in id_fields:
                    candidate = item.get(field)
                    if candidate is not None and candidate != '':
                        item_id = candidate
                        break

                if item_id is None:
                    # Fallback for entries without an ID: use a positional key, skipping
                    # any key that is already taken to avoid clobbering stored items.
                    position = len(data)
                    while str(position) in data:
                        position += 1
                    item_id = position

                data[str(item_id)] = item  # String keys keep lookups consistent.
        print(f"   Loaded {len(data)} items from {os.path.basename(filepath)}")
    except Exception as e:
        # Deliberately best-effort: report the problem and return what was loaded so far.
        print(f"‚ö†Ô∏è Error loading {filepath}: {e}")
    return data

def extract_qrels_from_queries(queries):
    """Build a qrels mapping from relevance labels embedded in query records.

    For each query that is a dict, the labels are read from the first non-empty
    field among 'labels', 'relevance' and 'qrels'. Supported label shapes:

    * list of dicts - document ID from 'id' (or 'doc_id'), relevance from
      'score' (defaults to 1);
    * list of plain values - each value is a document ID with relevance 1;
    * dict - keys are document IDs, values are relevance scores.

    Document IDs of ``None`` or ``''`` are skipped; numeric IDs such as ``0``
    are kept.

    Args:
        queries: Mapping of query ID to query record.

    Returns:
        dict of ``{str(query_id): {str(doc_id): score}}``. Queries that are not
        dicts or carry no labels are omitted.
    """
    qrels = {}
    for qid, q in queries.items():
        if not isinstance(q, dict):
            continue
        # First label-bearing field wins.
        labels = q.get('labels') or q.get('relevance') or q.get('qrels')
        if not labels:
            continue

        judged = qrels[str(qid)] = {}
        if isinstance(labels, list):
            for label_item in labels:
                if not isinstance(label_item, dict):
                    # Bare document ID: assume relevance 1.
                    judged[str(label_item)] = 1
                    continue
                # Explicit None/'' checks so that a document ID of 0 is not discarded.
                doc_id = label_item.get('id')
                if doc_id is None or doc_id == '':
                    doc_id = label_item.get('doc_id')
                if doc_id is None or doc_id == '':
                    continue
                judged[str(doc_id)] = label_item.get('score', 1)
        elif isinstance(labels, dict):
            for doc_id, score in labels.items():
                judged[str(doc_id)] = score
    return qrels

def get_mair_datasets():
    """Discover the MAIR datasets available under MAIR_COMBINED_PATH.

    A subdirectory counts as a dataset when it contains both a 'docs' and a
    'queries' directory, each holding at least one '.jsonl' file. Hidden entries
    (names starting with '.') are ignored. The document count of a dataset is the
    number of non-empty lines across all '.jsonl' files in its 'docs' directory
    (read as UTF-8). A dataset that cannot be read is reported and skipped
    without interrupting the scan of the remaining ones.

    Returns:
        tuple ``(all_datasets, datasets_by_category, dataset_sizes)``:
        * all_datasets - dataset names in sorted order;
        * datasets_by_category - category name -> list of dataset names
          (datasets not listed in any category go to "Others");
        * dataset_sizes - dataset name -> document count.
        All three are empty (apart from the category keys) when the combined
        path does not exist or cannot be listed.
    """
    # Predefined categories used to organize the datasets shown to the user.
    DATASET_CATEGORIES = {
        "Legal & Regulatory": ["ACORDAR", "AILA2019-Case", "AILA2019-Statutes", "CUAD", "LeCaRDv2", "LegalQuAD"],
        "Medical & Clinical": ["CliniDS-2014", "CliniDS-2015", "CliniDS-2016", "ClinicalTrials-2021", "NFCorpus"],
        "Code & Programming": ["APPS", "CodeEditSearch", "CodeSearchNet", "Conala", "LeetCode", "MBPP"],
        "Financial": ["ConvFinQA", "FiQA", "FinQA", "FinanceBench"],
        "Academic & Scientific": ["ArguAna", "LitSearch", "ProofWiki-Proof", "Competition-Math"],
        "Conversational & Dialog": ["CAsT-2019", "CAsT-2020", "CAsT-2021", "ProCIS-Dialog", "SParC", "Quora"],
        "News & Social Media": ["ChroniclingAmericaQA", "Microblog-2011", "Microblog-2012", "News21"],
        "API Documentation": ["Apple", "FoodAPI", "HuggingfaceAPI", "PytorchAPI"],
        "Others": ["BSARD", "BillSum", "CARE", "CPCD", "CQADupStack", "DD", "ELI5"]
    }

    # Reverse lookup: dataset name -> category (first category listing the name wins).
    category_of = {}
    for cat, names in DATASET_CATEGORIES.items():
        for name in names:
            category_of.setdefault(name, cat)

    datasets_by_category = {category: [] for category in DATASET_CATEGORIES}
    all_datasets = []
    dataset_sizes = {}

    # MAIR_COMBINED_PATH is configured in the "User / Environment Setup" section.
    if not os.path.exists(MAIR_COMBINED_PATH):
        print(f"‚ö†Ô∏è MAIR combined path not found: {MAIR_COMBINED_PATH}. Please ensure the previous step to combine datasets was executed successfully.")
        return all_datasets, datasets_by_category, dataset_sizes

    print(f"üîç Scanning MAIR datasets in: {MAIR_COMBINED_PATH}")
    try:
        candidates = sorted(os.listdir(MAIR_COMBINED_PATH))  # Sorted for a stable display order.
    except OSError as e:
        print(f"   ‚ùå Error during MAIR dataset scanning: {e}")
        return all_datasets, datasets_by_category, dataset_sizes

    for item in candidates:
        if item.startswith('.'):  # Hidden files/directories (e.g. .ipynb_checkpoints).
            continue
        dataset_path = os.path.join(MAIR_COMBINED_PATH, item)
        if not os.path.isdir(dataset_path):
            continue
        docs_path = os.path.join(dataset_path, 'docs')
        queries_path = os.path.join(dataset_path, 'queries')
        if not (os.path.isdir(docs_path) and os.path.isdir(queries_path)):
            continue

        try:
            doc_files = [f for f in os.listdir(docs_path) if f.endswith('.jsonl')]
            has_query_file = any(f.endswith('.jsonl') for f in os.listdir(queries_path))
            if not doc_files or not has_query_file:
                continue
            # Count documents first, so a dataset is only registered once it has a size.
            doc_count = 0
            for file in doc_files:
                # Explicit UTF-8 keeps the count independent of the platform default encoding.
                with open(os.path.join(docs_path, file), 'r', encoding='utf-8') as f:
                    doc_count += sum(1 for line in f if line.strip())
        except (OSError, UnicodeDecodeError) as e:
            # Skip only the unreadable dataset; keep scanning the others.
            print(f"   ‚ùå Error during MAIR dataset scanning ({item}): {e}")
            continue

        all_datasets.append(item)
        datasets_by_category[category_of.get(item, "Others")].append(item)
        dataset_sizes[item] = doc_count
        print(f"   ‚úÖ Found MAIR dataset: {item} - {doc_count} docs")

    return all_datasets, datasets_by_category, dataset_sizes

# -------------------- 5. Binarization and Utility Functions --------------------
# These functions are central to the binary embedding benchmark, handling the
# conversion of dense float embeddings to their binary representation and providing
# common utilities for performance measurement and result management.

def binarize_embeddings(embeddings):
    """Convert float embeddings to a sign-based 0/1 representation.

    Every component that is >= 0 maps to 1.0 and every component that is < 0
    maps to 0.0. The result keeps the input's shape and is returned as a
    float32 NumPy array, so it can be sent to vector stores that expect float
    vectors while still carrying only binary information.

    Args:
        embeddings: Array-like of numeric values.

    Returns:
        np.ndarray of dtype float32 containing only 0.0 and 1.0.
    """
    as_float32 = np.asarray(embeddings, dtype=np.float32)
    # Element-wise sign test yields booleans; the cast turns True/False into 1.0/0.0.
    non_negative = np.greater_equal(as_float32, 0)
    return non_negative.astype(np.float32)

def save_binary_embeddings(doc_ids, corpus_embeddings_binary, query_ids, query_embeddings_binary, dataset_name, output_dir):
    """Write binarized corpus and query embeddings to `output_dir` in several formats.

    Files are named after `dataset_name`. NumPy (.npy), CSV (.csv) and JSON (.json)
    files are always written. An HDF5 (.h5) file is written only if `h5py` can be
    imported, and Parquet (.parquet) files are attempted last; any exception raised
    while writing Parquet is printed and otherwise ignored.

    Args:
        doc_ids: Document IDs, paired positionally with the rows of `corpus_embeddings_binary`.
        corpus_embeddings_binary: Corpus embeddings. The code reads `.nbytes` and calls
            `.tolist()` on each row, so NumPy arrays are expected.
        query_ids: Query IDs, paired positionally with the rows of `query_embeddings_binary`.
        query_embeddings_binary: Query embeddings (same expectations as the corpus array).
        dataset_name: Used as the file-name prefix and stored in the JSON/HDF5 metadata.
        output_dir: Destination directory; created if it does not exist.

    Returns:
        dict with the paths of the .npy, .csv and .json files plus 'output_dir'.
        The HDF5 and Parquet paths are NOT included in the returned dict.

    Note:
        The 'dimension' value stored in the JSON and HDF5 metadata is the module-level
        VECTOR_DIMENSION constant; it is not derived from the arrays passed in.
    """
    os.makedirs(output_dir, exist_ok=True) # Ensures the target output directory exists.

    print(f"\nüíæ Saving Binarized Embeddings for {dataset_name} to {output_dir}...")

    # ========== NumPy Format (.npy) ==========
    # NumPy's native binary format: the arrays are stored as-is, one file each
    # for corpus and queries. The reported size is the in-memory array size.
    print(f"\n  1Ô∏è‚É£  Numpy Format (.npy)")
    corpus_npy_path = os.path.join(output_dir, f"{dataset_name}_corpus_binary.npy")
    query_npy_path = os.path.join(output_dir, f"{dataset_name}_query_binary.npy")

    np.save(corpus_npy_path, corpus_embeddings_binary)
    np.save(query_npy_path, query_embeddings_binary)
    print(f"     ‚úÖ Corpus: {corpus_npy_path} ({corpus_embeddings_binary.nbytes / (1024*1024):.2f} MB)")
    print(f"     ‚úÖ Query:  {query_npy_path} ({query_embeddings_binary.nbytes / (1024*1024):.2f} MB)")

    # ========== CSV Format (.csv) ==========
    # Two columns per file: the ID and the whole vector serialized as one
    # comma-separated string (so each vector occupies a single cell).
    print(f"\n  2Ô∏è‚É£  CSV Format (.csv)")
    corpus_csv_path = os.path.join(output_dir, f"{dataset_name}_corpus_binary.csv")
    query_csv_path = os.path.join(output_dir, f"{dataset_name}_query_binary.csv")

    # Corpus table: one row per document.
    corpus_df = pd.DataFrame({
        'doc_id': doc_ids,
        'embedding_binary': [','.join(map(str, row)) for row in corpus_embeddings_binary]
    })
    corpus_df.to_csv(corpus_csv_path, index=False)
    print(f"     ‚úÖ Corpus: {corpus_csv_path}")

    # Query table: one row per query.
    query_df = pd.DataFrame({
        'query_id': query_ids,
        'embedding_binary': [','.join(map(str, row)) for row in query_embeddings_binary]
    })
    query_df.to_csv(query_csv_path, index=False)
    print(f"     ‚úÖ Query:  {query_csv_path}")

    # ========== JSON Format (.json) ==========
    # One JSON document per file with a small metadata header and an
    # 'embeddings' mapping of ID -> vector (as a plain list).
    print(f"\n  3Ô∏è‚É£  JSON Format (.json)")
    corpus_json_path = os.path.join(output_dir, f"{dataset_name}_corpus_binary.json")
    query_json_path = os.path.join(output_dir, f"{dataset_name}_query_binary.json")

    # Corpus payload. Duplicate doc_ids would collapse into a single key here.
    corpus_json = {
        'dataset': dataset_name,
        'type': 'corpus',
        'count': len(doc_ids),
        'dimension': VECTOR_DIMENSION,
        'embeddings': {doc_id: vec.tolist() for doc_id, vec in zip(doc_ids, corpus_embeddings_binary)}
    }

    # Query payload, same layout as the corpus payload.
    query_json = {
        'dataset': dataset_name,
        'type': 'query',
        'count': len(query_ids),
        'dimension': VECTOR_DIMENSION,
        'embeddings': {qid: vec.tolist() for qid, vec in zip(query_ids, query_embeddings_binary)}
    }

    with open(corpus_json_path, 'w') as f:
        json.dump(corpus_json, f) # Saves the corpus JSON data.
    print(f"     ‚úÖ Corpus: {corpus_json_path}")

    with open(query_json_path, 'w') as f:
        json.dump(query_json, f) # Saves the query JSON data.
    print(f"     ‚úÖ Query:  {query_json_path}")

    # ========== HDF5 Format (.h5) (if available) ==========
    # Single file holding both embedding arrays (gzip-compressed), the ID lists and
    # two metadata attributes. Skipped with a notice when `h5py` is not installed;
    # only ImportError is handled here, other errors propagate to the caller.
    try:
        import h5py
        print(f"\n  4Ô∏è‚É£  HDF5 Format (.h5)")
        h5_path = os.path.join(output_dir, f"{dataset_name}_binary.h5")

        with h5py.File(h5_path, 'w') as f:
            f.create_dataset('corpus_embeddings', data=corpus_embeddings_binary, compression='gzip') # Stores with gzip compression.
            f.create_dataset('query_embeddings', data=query_embeddings_binary, compression='gzip')
            # IDs are stored with h5py's string dtype; called without arguments it is a
            # variable-length string type (not fixed-length).
            f.create_dataset('doc_ids', data=np.array(doc_ids, dtype=h5py.string_dtype()), dtype=h5py.string_dtype())
            f.create_dataset('query_ids', data=np.array(query_ids, dtype=h5py.string_dtype()), dtype=h5py.string_dtype())
            f.attrs['dimension'] = VECTOR_DIMENSION # Adds metadata attributes.
            f.attrs['dataset'] = dataset_name

        print(f"     ‚úÖ HDF5: {h5_path} ({os.path.getsize(h5_path) / (1024*1024):.2f} MB)")
    except ImportError:
        print(f"     ‚ö†Ô∏è  HDF5 not available (install 'h5py' package to save in this format: `pip install h5py`)")

    # ========== Parquet Format (.parquet) (if available) ==========
    # Two columns per file: the ID and the vector as a list. Written through pandas,
    # which needs a Parquet engine to be installed. Any failure (missing engine or
    # otherwise) is reported and does not abort the function.
    try:
        print(f"\n  5Ô∏è‚É£  Parquet Format (.parquet)")
        corpus_parquet_path = os.path.join(output_dir, f"{dataset_name}_corpus_binary.parquet")
        query_parquet_path = os.path.join(output_dir, f"{dataset_name}_query_binary.parquet")

        # Vectors are converted to Python lists so each cell holds one full vector.
        corpus_parquet_df = pd.DataFrame({
            'doc_id': doc_ids,
            'embedding': [list(vec) for vec in corpus_embeddings_binary]
        })
        corpus_parquet_df.to_parquet(corpus_parquet_path, compression='snappy') # Uses Snappy compression.

        query_parquet_df = pd.DataFrame({
            'query_id': query_ids,
            'embedding': [list(vec) for vec in query_embeddings_binary]
        })
        query_parquet_df.to_parquet(query_parquet_path, compression='snappy')

        print(f"     ‚úÖ Corpus: {corpus_parquet_path}")
        print(f"     ‚úÖ Query:  {query_parquet_path}")
    except Exception as e:
        print(f"     ‚ö†Ô∏è  Parquet not available or failed to save (install 'pyarrow' and 'fastparquet', or check error: {e})")

    # ========== SUMMARY OF SAVED EMBEDDINGS ==========
    # Prints the counts, the configured dimension and a short description of each format.
    # The format list is static text: it is printed even if HDF5/Parquet were skipped above.
    print(f"\nüìä Binary Embeddings Download Summary:")
    print(f"   Dataset: {dataset_name}")
    print(f"   Corpus documents: {len(doc_ids)}")
    print(f"   Query embeddings: {len(query_ids)}")
    print(f"   Dimension: {VECTOR_DIMENSION}")
    print(f"   Saved to directory: {output_dir}")
    print(f"\n   Format Summary (choose the format best suited for your needs):")
    print(f"   ‚Ä¢ .npy   - Fast loading with numpy (often smallest file size)")
    print(f"   ‚Ä¢ .csv   - Human-readable, spreadsheet compatible (vectors as strings)")
    print(f"   ‚Ä¢ .json  - Human-readable, portable (vectors as lists)")
    print(f"   ‚Ä¢ .h5    - HDF5 with compression (for large numerical datasets, if h5py installed)")
    print(f"   ‚Ä¢ .parquet - Columnar storage (efficient for analytical queries, if pyarrow installed)")

    # Only the always-written formats are reported back to the caller.
    return {
        'corpus_npy': corpus_npy_path,
        'query_npy': query_npy_path,
        'corpus_csv': corpus_csv_path,
        'query_csv': query_csv_path,
        'corpus_json': corpus_json_path,
        'query_json': query_json_path,
        'output_dir': output_dir
    }

def calculate_timing_stats(timing_list):
    """Summarize a list of timing measurements.

    Args:
        timing_list: Numeric durations, e.g. per-query search times.

    Returns:
        A dict with the keys "mean", "median", "min", "max", "std" and
        "total". Every value is 0.0 when ``timing_list`` is empty, and "std"
        is 0.0 when there is only a single sample.
    """
    stat_names = ("mean", "median", "min", "max", "std", "total")
    if not timing_list:
        return dict.fromkeys(stat_names, 0.0)

    # statistics.stdev() needs at least two data points.
    spread = 0.0
    if len(timing_list) > 1:
        spread = statistics.stdev(timing_list)

    stat_values = (
        statistics.mean(timing_list),
        statistics.median(timing_list),
        min(timing_list),
        max(timing_list),
        spread,
        sum(timing_list),
    )
    return dict(zip(stat_names, stat_values))

def format_and_print_metrics(ndcg, _map, recall, precision, ks=K_VALUES):
    """Print NDCG, MAP, Recall and Precision for every cutoff in ``ks``.

    Args:
        ndcg: Mapping keyed by "NDCG@<k>".
        _map: Mapping keyed by "MAP@<k>".
        recall: Mapping keyed by "Recall@<k>".
        precision: Mapping keyed by "P@<k>".
        ks: Cutoff values to report; one output row per value.

    A metric that is missing from its mapping is printed as 0.0000.
    """
    print("\nRetrieval Metrics:")
    print("-------------------")

    # (label prefix, source mapping) in the column order of each printed row.
    columns = (("NDCG", ndcg), ("MAP", _map), ("Recall", recall), ("P", precision))
    for cutoff in ks:
        cells = []
        for prefix, scores in columns:
            label = f"{prefix}@{cutoff}"
            cells.append(f"{label}: {scores.get(label, 0.0):.4f}")
        print(" | ".join(cells))

def extract_all_metrics(ndcg, _map, recall, precision, ks=K_VALUES):
    """Flatten the four metric mappings into one dict of floats.

    Args:
        ndcg: Mapping keyed by "NDCG@<k>".
        _map: Mapping keyed by "MAP@<k>".
        recall: Mapping keyed by "Recall@<k>".
        precision: Mapping keyed by "P@<k>".
        ks: Cutoff values to extract.

    Returns:
        A dict with "NDCG@<k>", "MAP@<k>", "Recall@<k>" and "P@<k>" entries
        for each cutoff (in that order), each converted with float(). Missing
        metrics become 0.0.
    """
    sources = (("NDCG", ndcg), ("MAP", _map), ("Recall", recall), ("P", precision))
    return {
        f"{prefix}@{cutoff}": float(scores.get(f"{prefix}@{cutoff}", 0.0))
        for cutoff in ks
        for prefix, scores in sources
    }

def save_results_to_csv(new_result: dict, csv_path: str):
    """Append one benchmark result row to a CSV file, creating it if needed.

    The existing file (if any) is read, the new row is concatenated, and the
    whole table is written back so the header always matches the combined
    columns.

    Args:
        new_result: Column name -> value mapping for the new row.
        csv_path: Destination file. May be a bare filename, in which case the
            file is written to the current working directory.

    Read/write failures are reported on stdout and not re-raised, so a failed
    save does not abort the benchmark run. Errors while creating the parent
    directory are not caught.
    """
    new_df = pd.DataFrame([new_result])

    # os.path.dirname() returns "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(csv_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # A missing or zero-byte file has no header/rows to merge with; reading a
    # zero-byte file with pandas would fail and the new row would be lost.
    start_fresh = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0

    try:
        if start_fresh:
            new_df.to_csv(csv_path, mode='w', header=True, index=False)
        else:
            existing_df = pd.read_csv(csv_path)
            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
            combined_df.to_csv(csv_path, mode='w', header=True, index=False)
        print(f"üíæ Results saved to: {csv_path}")
    except Exception as e:
        print(f"‚ùå CSV save failed: {e}. Please check file permissions or path validity.")

def clean_memory():
    """Run a garbage-collection pass via ``gc.collect()``.

    Intended to be called between memory-heavy steps (e.g. after processing a
    large dataset) to lower the risk of running out of memory in constrained
    environments such as Colab. Returns None.
    """
    gc.collect()
    return None

def should_cleanup_namespace(provider_name, dataset_name):
    """Ask the user whether the provider's namespace/index should be deleted.

    Args:
        provider_name: Provider label shown in the prompt.
        dataset_name: Dataset label shown in the prompt.

    Returns:
        True only when the answer, ignoring case and surrounding whitespace,
        is 'y' or 'yes'; False for anything else, including an empty answer.
    """
    prompt = f"\n‚ùì Delete {provider_name} namespace/index for {dataset_name} after benchmarking? (y/n): "
    answer = input(prompt)
    normalized = answer.strip().lower()
    return normalized == 'y' or normalized == 'yes'

# -------------------- 6. Interactive Dataset Source Selection --------------------
# Ask which benchmark family to run. The answer is stored in DATASET_SOURCE as
# "mair" when the user enters "2" and as "beir" for any other input (an empty
# answer defaults to "1").
print("\nüìö Select Dataset Source for Benchmarking:")
print("  1) BEIR - Benchmark for Information Retrieval (standard datasets like nfcorpus, scifact)")
# MAIR stands for "Massive Instructed Retrieval"; the menu text uses the benchmark's real name.
print("  2) MAIR - Massive Instructed Retrieval benchmark (diverse datasets from various domains)")
dataset_source_choice = input("‚û°Ô∏è Enter your choice [1/2] (default is 1): ").strip() or "1"
DATASET_SOURCE = "mair" if dataset_source_choice == "2" else "beir"
print(f"‚úÖ Selected dataset source: {DATASET_SOURCE.upper()}")

# -------------------- 7. Interactive Provider Selection and API Key Status --------------------
# Lists the providers that can be benchmarked, reports which credentials were
# found in `api_keys`, and echoes the active benchmark configuration. The
# values `es_config` and `es_configured` computed here are reused further down.
print("\nüîß Available Providers for BINARY Embeddings Benchmarking:")
print("\n".join([
    f"  1. Moorcheh (Vector namespace - {VECTOR_DIMENSION}D binary vector search)",
    f"  2. Pinecone (Vector index - {VECTOR_DIMENSION}D binary vector search)",
    f"  3. Elasticsearch (Dense vector - {VECTOR_DIMENSION}D binary vector search)",
]))

print("\nüîë API Keys Status (ensure these are set in Colab secrets or environment variables):")
for provider in ('moorcheh', 'cohere', 'pinecone'):
    key = api_keys.get(provider)
    if key:
        status = "‚úÖ Found"
    else:
        status = "‚ùå Missing"
    display_name = provider.capitalize()
    if provider == 'cohere':
        # The Cohere key is not a vector store credential; say what it is for.
        display_name = f"{display_name} (for embeddings)"
    print(f"  {display_name}: {status}")

# Elasticsearch counts as configured only with a URL plus either an API key
# or a username/password pair.
es_config = api_keys.get('elasticsearch', {})
es_configured = False
if es_config and es_config.get('url'):
    if es_config.get('api_key'):
        es_configured = True
    else:
        es_configured = bool(es_config.get('username') and es_config.get('password'))
if es_configured:
    print("  Elasticsearch: ‚úÖ Configured")
else:
    print("  Elasticsearch: ‚ùå Not fully configured (URL or credentials missing)")

print("\n‚öôÔ∏è  Current Benchmark Configuration (BINARY Embeddings):")
print("\n".join([
    f"  ‚Ä¢ Embedding Model: Cohere {EMBEDDING_MODEL} (used to generate initial float embeddings)",
    f"  ‚Ä¢ Vector Dimension: {VECTOR_DIMENSION}D (float) will be converted to {VECTOR_DIMENSION} bits (binary)",
    "  ‚Ä¢ Binarization Method: Sign-based (values >= 0 become 1, values < 0 become 0)",
    f"  ‚Ä¢ Batch Size for API calls/uploads: {BATCH_SIZE}",
    "  ‚Ä¢ This benchmark focuses on BINARY embeddings only (no float comparison in this run)",
    "  ‚Ä¢ Expected Space Savings after binarization: ~32x compression (from float32 to binary representation)",
]))

# Let the user pick providers: a comma-separated list of menu numbers
# (e.g. '1,2') or 'all' for every provider whose credentials were found.
# An empty answer means 'all'; unknown menu entries are ignored.
provider_choice = input("\n‚û°Ô∏è Select providers to test (e.g., '1,2,3' or 'all') (default is 'all'): ").strip().lower() or "all"
if provider_choice == 'all':
    selected_providers = []
    if api_keys.get('moorcheh'): selected_providers.append('moorcheh')
    if api_keys.get('pinecone'): selected_providers.append('pinecone')
    if es_configured: selected_providers.append('elasticsearch')
else:
    provider_map = {'1': 'moorcheh', '2': 'pinecone', '3': 'elasticsearch'}
    # dict.fromkeys() drops repeated selections such as '1,1' while keeping
    # the order the user typed, so no provider ends up in the list twice.
    selected_providers = list(dict.fromkeys(
        provider_map[p.strip()] for p in provider_choice.split(',') if p.strip() in provider_map
    ))

# Drop explicitly requested providers whose credentials are missing, telling
# the user why each one was skipped.
selected_providers_filtered = []
for p in selected_providers:
    if p == 'moorcheh' and not api_keys.get('moorcheh'):
        print(f"‚ö†Ô∏è Moorcheh not included: API key missing.")
    elif p == 'pinecone' and not api_keys.get('pinecone'):
        print(f"‚ö†Ô∏è Pinecone not included: API key missing.")
    elif p == 'elasticsearch' and not es_configured:
        print(f"‚ö†Ô∏è Elasticsearch not included: Configuration incomplete.")
    else:
        selected_providers_filtered.append(p)

selected_providers = selected_providers_filtered

if not selected_providers:
    print("‚ùå No providers selected or configured. Please check your API keys and try again.")
    exit(1) # Nothing to benchmark.

# Capitalize each provider name individually; str.capitalize() on the joined
# string would upper-case only the first name and lower-case everything else.
print(f"‚úÖ Selected providers for benchmarking: {', '.join(name.capitalize() for name in selected_providers)}")

# -------------------- 8. Initialize API Clients --------------------
# Builds the Cohere client (always required: it generates the embeddings) and
# one client per selected provider. Moorcheh and Pinecone clients are stored
# in `clients`; the Elasticsearch client is kept separately in `es_client`
# and is left as None when it cannot be created or does not answer a ping.
clients = {}
es_client = None

# .get() rather than indexing, as in the key-status report: a missing 'cohere'
# entry must reach the friendly message below instead of raising KeyError.
if api_keys.get('cohere'):
    import cohere
    cohere_client = cohere.Client(api_keys['cohere'])
    print(f"\nüß† Cohere client initialized successfully for embedding generation (model: {EMBEDDING_MODEL})")
else:
    print("\n‚ùå Cohere API key required! Please set 'COHERE_API_KEY' in Colab secrets or environment variables.")
    exit(1) # Embeddings cannot be generated without Cohere.

# Moorcheh: only when selected and a key is present.
if 'moorcheh' in selected_providers and api_keys.get('moorcheh'):
    from moorcheh_sdk import MoorchehClient, ConflictError
    clients['moorcheh'] = MoorchehClient(api_key=api_keys['moorcheh'])
    print(f"‚úÖ Moorcheh client initialized.")

# Pinecone: only when selected and a key is present.
if 'pinecone' in selected_providers and api_keys.get('pinecone'):
    try:
        from pinecone import Pinecone, ServerlessSpec
    except ImportError as import_err:
        # The previous fallback read the same two names from the same package,
        # so it could never succeed; fail with an actionable message instead.
        raise ImportError(
            "The Pinecone provider requires a Pinecone SDK that exports "
            "'Pinecone' and 'ServerlessSpec'. Install or upgrade the Pinecone package."
        ) from import_err

    clients['pinecone'] = Pinecone(api_key=api_keys['pinecone'])
    print(f"‚úÖ Pinecone client initialized.")

# Elasticsearch: only when selected and fully configured.
if 'elasticsearch' in selected_providers and es_configured:
    try:
        from elasticsearch import Elasticsearch
        es_config = api_keys['elasticsearch']

        # Prefer API-key authentication, then basic auth.
        if es_config.get('api_key'):
            es_client = Elasticsearch(es_config['url'], api_key=es_config['api_key'], request_timeout=60)
        elif es_config.get('username') and es_config.get('password'):
            es_client = Elasticsearch(es_config['url'], basic_auth=(es_config['username'], es_config['password']), request_timeout=60)
        else:
            # Unauthenticated connection (local testing only).
            es_client = Elasticsearch(es_config['url'], request_timeout=60)

        # ping() confirms the cluster is actually reachable with these settings.
        if es_client and es_client.ping():
            print("‚úÖ Elasticsearch connected successfully.")
        else:
            es_client = None
            print("‚ö†Ô∏è Elasticsearch connection failed. Please check URL and credentials.")
    except Exception as e:
        print(f"‚ö†Ô∏è Elasticsearch client initialization failed: {e}. Please ensure Elasticsearch is running and accessible.")
        es_client = None


# -------------------- 9. Vector Database Provider Classes --------------------
# These classes encapsulate the specific logic for interacting with each vector
# database provider (Moorcheh, Pinecone, Elasticsearch). Each class handles
# operations such as vector upload, search queries, and resource cleanup,
# tailored to the provider's API and specifically for binary embeddings.

class MoorchehBinaryProvider:
    """Benchmark adapter for a Moorcheh vector namespace holding binarized vectors.

    Wraps namespace creation, batched vector upload, similarity search and
    namespace deletion, and records the timings reported for each step.

    Attributes:
        client: Moorcheh SDK client used for every API call.
        namespace_name: Name of the namespace this provider creates and queries.
        precomputed_vectors: Items to upload; each is read with the keys
            'id' and 'vector' when a batch is logged.
        query_embeddings: Query vectors, addressed by position in search().
        upload_timings: Accumulated server upload time plus one record per
            successfully uploaded batch.
        search_timings: One timing record per search() call.
    """
    def __init__(self, client, namespace_name, precomputed_vectors, query_embeddings):
        self.client = client
        self.namespace_name = namespace_name
        self.precomputed_vectors = precomputed_vectors
        self.query_embeddings = query_embeddings
        self.upload_timings = {
            "server_upload_time_s": 0.0, # Sum of the server times reported per batch.
            "batch_details": [] # One dict per batch that uploaded successfully.
        }
        self.search_timings = []

    def upload(self):
        """Create the namespace (or reuse an existing one) and upload all vectors in batches.

        A batch that fails to upload is reported on stdout and skipped; the
        remaining batches are still attempted.

        Returns:
            len(self.precomputed_vectors), i.e. the number of vectors that were
            meant to be uploaded, regardless of batch failures.

        Raises:
            Exception: Any error from namespace creation other than
                ConflictError is printed and re-raised.
        """
        from moorcheh_sdk import ConflictError
        try:
            self.client.create_namespace(
                namespace_name=self.namespace_name,
                type="vector",
                vector_dimension=VECTOR_DIMENSION
            )
            print(f"‚úÖ Created Moorcheh namespace: {self.namespace_name}")
        except ConflictError:
            # An existing namespace is reused as-is.
            print(f"‚ö†Ô∏è Moorcheh namespace '{self.namespace_name}' already exists, proceeding to use it.")
        except Exception as e:
            print(f"‚ùå Error creating Moorcheh namespace: {e}")
            raise

        print(f"\nüìä Binary Upload Chunks Details for Moorcheh:")
        print(f"   Total vectors to upload: {len(self.precomputed_vectors)}")
        print(f"   Batch size for uploads: {BATCH_SIZE}")
        print(f"   Total number of batches: {(len(self.precomputed_vectors) + BATCH_SIZE - 1) // BATCH_SIZE}")
        print(f"\n   Batch Breakdown:")
        print(f"   {'Batch':<8} {'Vectors':<12} {'First ID':<15} {'Sample Vector (first 10 dims)':<50} {'Server Time (s)':<8}")
        print(f"   {'-'*100}")

        batch_num = 0
        for i in tqdm(range(0, len(self.precomputed_vectors), BATCH_SIZE), desc="Uploading binary vectors to Moorcheh"):
            batch = self.precomputed_vectors[i:i+BATCH_SIZE]
            batch_num += 1

            try:
                response = self.client.upload_vectors(
                    namespace_name=self.namespace_name,
                    vectors=batch
                )

                # Prefer the nested timings["total"] value when the response
                # carries one, otherwise fall back to "execution_time".
                server_time = 0.0
                if isinstance(response, dict):
                    server_time = response.get("execution_time", 0.0)
                    if "timings" in response:
                        server_time = response["timings"].get("total", server_time)

                # Keep a short sample of the first vector for the log line.
                first_id = batch[0]['id']
                sample_vector = batch[0]['vector'][:10]
                sample_str = f"[{', '.join([f'{v:.1f}' for v in sample_vector])}...]".ljust(50)

                self.upload_timings["server_upload_time_s"] += server_time
                self.upload_timings["batch_details"].append({
                    "batch_num": batch_num,
                    "batch_size": len(batch),
                    "first_id": first_id,
                    "server_time_s": server_time,
                    "sample_vector": sample_vector
                })

                print(f"   {batch_num:<8} {len(batch):<12} {first_id:<15} {sample_str:<50} {server_time:<8.4f}")

            except Exception as e:
                # Best effort: report the failed batch and continue with the next one.
                print(f"\n‚ùå Batch {batch_num} failed to upload to Moorcheh: {e}")

        print(f"\n‚è±Ô∏è  Moorcheh Upload Summary (Server-Side Timings):")
        print(f"    Total server-side upload time: {self.upload_timings['server_upload_time_s']:.4f}s")
        # Average over batches that actually uploaded: failed batches add no
        # server time, so counting them would understate the per-batch figure.
        uploaded_batches = len(self.upload_timings["batch_details"])
        if uploaded_batches > 0:
            print(f"    Average server-side time per batch: {self.upload_timings['server_upload_time_s'] / uploaded_batches:.4f}s")

        return len(self.precomputed_vectors)

    def search(self, query_idx, top_k=100):
        """Search the namespace with the query embedding at ``query_idx``.

        Appends one record to ``self.search_timings`` containing
        "server_time_s" (the response's "execution_time", 0.0 if absent) and a
        "moorcheh_<name>_s" entry for every numeric value in the response's
        "timings" mapping.

        Args:
            query_idx: Position of the query in ``self.query_embeddings``.
            top_k: Number of results requested.

        Returns:
            Dict mapping str(document id) -> float(score); empty when the
            response is not a dict or has no "results".
        """
        query_embedding = self.query_embeddings[query_idx]

        resp = self.client.search(
            namespaces=[self.namespace_name],
            query=query_embedding,
            top_k=top_k
        )

        # Treat a non-dict response as one without timings or results.
        payload = resp if isinstance(resp, dict) else {}

        timing_detail = {"server_time_s": payload.get("execution_time", 0.0)}
        for key, value in payload.get("timings", {}).items():
            if isinstance(value, (int, float)): # Skip non-numeric entries.
                timing_detail[f"moorcheh_{key}_s"] = value

        self.search_timings.append(timing_detail)

        return {str(r["id"]): float(r["score"]) for r in payload.get("results", [])}

    def get_search_stats(self):
        """Summarize the recorded search timings.

        Returns:
            Dict with an "overall" entry (statistics of "server_time_s") and
            one entry per detailed timing component seen in any search, each
            produced by calculate_timing_stats(). A component's statistics
            cover only the searches that reported it.
        """
        server_times = [t["server_time_s"] for t in self.search_timings]
        stats = {"overall": calculate_timing_stats(server_times)}

        # Collect component keys across ALL searches (first-seen order), so a
        # component missing from the first response is still reported.
        component_keys = dict.fromkeys(
            key
            for timing in self.search_timings
            for key in timing
            if key != "server_time_s"
        )
        for key in component_keys:
            values = [t[key] for t in self.search_timings if key in t]
            stats[key] = calculate_timing_stats(values)

        return stats

    def cleanup(self):
        """Delete the benchmark namespace; failures are reported, not raised."""
        try:
            self.client.delete_namespace(self.namespace_name)
            print(f"üßπ Deleted Moorcheh namespace: {self.namespace_name}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to delete Moorcheh namespace '{self.namespace_name}': {e}. You may need to delete it manually via the Moorcheh dashboard if it persists.")


class PineconeBinaryProvider:
    """Benchmark adapter for a Pinecone index holding binarized vectors.

    Wraps index (re)creation, batched upserts, similarity queries and index
    deletion, and records client-side timings for each step.

    Attributes:
        client: Pinecone SDK client used for index management.
        index_name: Name of the index this provider creates and queries.
        index: Index handle; None until upload() has created the index.
        precomputed_vectors: Items to upsert; each is read with the keys
            'id' and 'vector'.
        query_embeddings: Query vectors, addressed by position in search().
        upload_timings: Index creation time, total upsert time and one record
            per upserted batch.
        search_timings: One timing record per search() call.
    """
    def __init__(self, client, index_name, precomputed_vectors, query_embeddings):
        self.client = client
        self.index_name = index_name
        self.index = None
        self.precomputed_vectors = precomputed_vectors
        self.query_embeddings = query_embeddings
        self.upload_timings = {
            "index_creation_s": 0.0, # Creation request plus readiness polling.
            "upsert_time_s": 0.0, # Sum of client-side upsert durations.
            "batch_details": [] # One dict per upserted batch.
        }
        self.search_timings = []

    def upload(self):
        """Recreate the index and upsert all vectors in batches.

        An existing index with the same name is deleted first so the run
        starts from an empty index. The new index uses the cosine metric and a
        serverless spec (cloud 'aws', region 'us-east-1'). Errors from the
        Pinecone client are not caught here.

        Returns:
            len(self.precomputed_vectors).
        """
        from pinecone import ServerlessSpec

        # Remove a stale index from a previous run, then give the service a
        # moment to finish the deletion.
        if self.index_name in [idx.name for idx in self.client.list_indexes()]:
            print(f"üóëÔ∏è Deleting existing Pinecone index: {self.index_name}")
            self.client.delete_index(self.index_name)
            time.sleep(5)

        # Start the clock only now, so "index_creation_s" measures creation and
        # readiness polling and not the stale-index deletion with its fixed
        # 5 s pause (which would inflate the figure on re-runs only).
        t0 = time.perf_counter()

        print(f"‚ú® Creating new Pinecone index: {self.index_name} (Dimension: {VECTOR_DIMENSION}, Metric: cosine)")
        self.client.create_index(
            name=self.index_name,
            dimension=VECTOR_DIMENSION,
            metric='cosine',
            spec=ServerlessSpec(cloud='aws', region='us-east-1') # Adjust cloud/region as needed.
        )

        # Poll once per second until the index reports ready.
        print(f"‚è≥ Waiting for Pinecone index '{self.index_name}' to be ready...")
        while not self.client.describe_index(self.index_name).status['ready']:
            time.sleep(1)

        t1 = time.perf_counter()
        self.upload_timings["index_creation_s"] = t1 - t0

        self.index = self.client.Index(self.index_name)

        print(f"\nüìä Binary Upload Chunks Details for Pinecone:")
        print(f"   Total vectors to upsert: {len(self.precomputed_vectors)}")
        print(f"   Batch size for upserts: {BATCH_SIZE}")
        print(f"   Total number of batches: {(len(self.precomputed_vectors) + BATCH_SIZE - 1) // BATCH_SIZE}")
        print(f"\n   Batch Breakdown:")
        print(f"   {'Batch':<8} {'Vectors':<12} {'First ID':<15} {'Sample Vector (first 10 dims)':<50} {'Client Time (s)':<8}")
        print(f"   {'-'*100}")

        batch_num = 0
        for i in tqdm(range(0, len(self.precomputed_vectors), BATCH_SIZE), desc="Uploading binary vectors to Pinecone"):
            batch = self.precomputed_vectors[i:i+BATCH_SIZE]
            # upsert() is given (id, vector) tuples.
            vectors = [(v['id'], v['vector']) for v in batch]

            batch_num += 1
            # Client-side wall time around the upsert call.
            t_batch_start = time.perf_counter()
            self.index.upsert(vectors=vectors)
            t_batch_end = time.perf_counter()
            batch_time = t_batch_end - t_batch_start

            # Keep a short sample of the first vector for the log line.
            first_id = batch[0]['id']
            sample_vector = batch[0]['vector'][:10]
            sample_str = f"[{','.join([f'{v:.1f}' for v in sample_vector])}...]".ljust(50)

            self.upload_timings["upsert_time_s"] += batch_time
            self.upload_timings["batch_details"].append({
                "batch_num": batch_num,
                "batch_size": len(batch),
                "first_id": first_id,
                "upsert_time_s": batch_time,
                "sample_vector": sample_vector
            })

            print(f"   {batch_num:<8} {len(batch):<12} {first_id:<15} {sample_str:<50} {batch_time:<8.4f}")

        print(f"\n‚è±Ô∏è  Pinecone Upload Summary:")
        print(f"    Index Creation Time: {self.upload_timings['index_creation_s']:.4f}s")
        print(f"    Total Upsert Time: {self.upload_timings['upsert_time_s']:.4f}s")
        if batch_num > 0:
            print(f"    Average Upsert Time per batch: {self.upload_timings['upsert_time_s'] / batch_num:.4f}s")

        return len(self.precomputed_vectors)

    def search(self, query_idx, top_k=100):
        """Query the index with the query embedding at ``query_idx``.

        Requires upload() to have run, since that is where ``self.index`` is
        set. Appends the client-side query duration to ``self.search_timings``
        in both seconds and milliseconds.

        Args:
            query_idx: Position of the query in ``self.query_embeddings``.
            top_k: Number of matches requested.

        Returns:
            Dict mapping str(match id) -> float(match score).
        """
        query_embedding = self.query_embeddings[query_idx]

        t0 = time.perf_counter()
        results = self.index.query(vector=query_embedding, top_k=top_k)
        t1 = time.perf_counter()

        query_time = t1 - t0
        self.search_timings.append({
            "query_time_s": query_time,
            "query_time_ms": query_time * 1000
        })

        return {str(match['id']): float(match['score']) for match in results.get('matches', [])}

    def get_search_stats(self):
        """Summarize the recorded client-side query times.

        Returns:
            Dict with "query_time_s" and "query_time_ms" entries, each produced
            by calculate_timing_stats().
        """
        query_times_s = [t["query_time_s"] for t in self.search_timings]
        query_times_ms = [t["query_time_ms"] for t in self.search_timings]
        return {
            "query_time_s": calculate_timing_stats(query_times_s),
            "query_time_ms": calculate_timing_stats(query_times_ms)
        }

    def cleanup(self):
        """Delete the benchmark index; failures are reported, not raised."""
        try:
            self.client.delete_index(self.index_name)
            print(f"üßπ Deleted Pinecone index: {self.index_name}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to delete Pinecone index '{self.index_name}': {e}. You may need to delete it manually via the Pinecone console if it persists.")


class ElasticsearchBinaryProvider:
    """Manages interaction with Elasticsearch for binary vector benchmarking,
    using its `dense_vector` field type. This class handles index creation,
    bulk ingestion of binarized vectors, k-Nearest Neighbor (kNN) searches,
    and index cleanup.

    Timing data is accumulated on the instance:
      * `upload_timings` - index creation time, summed server-side bulk time
        (the `took` value of each bulk response), total client-side upload
        time, and a per-batch breakdown.
      * `search_timings` - one record per `search` call with server-side and
        client-side durations in both seconds and milliseconds.
    """
    def __init__(self, client, index_name, precomputed_vectors, query_embeddings):
        """Stores the client, target index name, and the data to benchmark.

        Args:
            client: Elasticsearch API client instance.
            index_name: Name of the Elasticsearch index to create and query.
            precomputed_vectors: Sequence of dicts with an "id" and a "vector" entry.
            query_embeddings: Sequence of query vectors, addressed by position in `search`.
        """
        self.client = client # Elasticsearch API client instance.
        self.index_name = index_name # Unique name for the Elasticsearch index to be used.
        self.precomputed_vectors = precomputed_vectors # List of documents with their IDs and binarized vectors.
        self.query_embeddings = query_embeddings # List of binarized query embeddings.
        self.upload_timings = {
            "index_creation_s": 0.0, # Time taken to create the Elasticsearch index.
            "server_bulk_time_ms": 0.0, # Total server-side time for bulk ingestion.
            "client_total_time_s": 0.0, # Total client-side time for sending bulk requests.
            "batch_details": [] # Detailed timings for each bulk batch.
        }
        self.search_timings = [] # Timings for individual search queries.

    def upload(self):
        """Uploads binarized vectors to an Elasticsearch index using a `dense_vector` field.
        It first attempts to delete any existing index, then creates a new one with a specific mapping,
        and finally ingests vectors in batches using Elasticsearch's bulk API.

        Returns:
            int: Number of documents whose bulk item reported status 200 or 201.

        Raises:
            Exception: Whatever the client raises when index creation fails (re-raised).
            Failures of individual bulk requests or of the final refresh are printed, not raised.
        """
        try:
            # Deletes an existing index if it's found, ensuring a fresh start for the benchmark.
            print(f"üóëÔ∏è Attempting to delete existing Elasticsearch index: {self.index_name}")
            # `ignore=[400, 404]` prevents errors if the index doesn't exist or deletion fails for non-critical reasons.
            self.client.indices.delete(index=self.index_name, ignore=[400, 404])
            time.sleep(2) # Pauses to allow Elasticsearch to process the deletion.
        except Exception as e:
            print(f"‚ö†Ô∏è Could not delete Elasticsearch index '{self.index_name}' (it might not exist or there was an error): {e}")

        t0 = time.perf_counter() # Starts timer for index creation.

        # Defines the index mapping with a `dense_vector` field for storing embeddings.
        # 'dims' specifies the vector dimension. 'similarity' can be 'cosine', 'dot_product', or 'l2_norm'.
        # 'index': True enables vector indexing for efficient kNN search (HNSW by default in recent ES versions).
        index_config = {
            "mappings": {
                "properties": {
                    "doc_id": {"type": "keyword"}, # Stores document ID as a keyword for exact matching.
                    "embedding": {
                        "type": "dense_vector",
                        "dims": VECTOR_DIMENSION, # Sets the dimension to match the binarized embeddings.
                        "index": True, # Enables vector indexing for kNN search for performance.
                        "similarity": "cosine", # Uses cosine similarity, suitable for binarized float vectors.
                    }
                }
            },
            "settings": {
                "number_of_shards": 1, # **User Customizable**: Number of primary shards (1 is common for single-node).
                "number_of_replicas": 0 # **User Customizable**: Number of replica shards (0 for dev, >0 for prod HA).
            }
        }

        try:
            self.client.indices.create(index=self.index_name, body=index_config) # Creates the index with the defined mapping.
            print(f"‚úÖ Created Elasticsearch index: {self.index_name} (Dimension: {VECTOR_DIMENSION}D, Type: BINARY)")
        except Exception as e:
            print(f"‚ùå Elasticsearch index creation error: {e}")
            raise # Re-raises critical errors (e.g., invalid mapping) to halt execution.

        t1 = time.perf_counter() # Ends timer for index creation.
        self.upload_timings["index_creation_s"] = t1 - t0 # Records index creation time.

        print(f"\nüìä Binary Upload Chunks Details for Elasticsearch:")
        print(f"   Total vectors to index: {len(self.precomputed_vectors)}")
        print(f"   Batch size for bulk requests: {BATCH_SIZE}")
        print(f"   Total number of batches: {(len(self.precomputed_vectors) + BATCH_SIZE - 1) // BATCH_SIZE}")

        batch_num = 0
        total_uploaded = 0
        t_upload_start = time.perf_counter() # Starts timer for total client-side upload duration.

        # Uses Elasticsearch's bulk API for efficient ingestion of many documents.
        # The bulk API expects newline-delimited JSON for action/metadata and source data.
        for i in tqdm(range(0, len(self.precomputed_vectors), BATCH_SIZE), desc="Uploading binary vectors to Elasticsearch"):
            batch = self.precomputed_vectors[i:i+BATCH_SIZE] # Gets a slice of vectors for the current batch.
            batch_num += 1

            bulk_body = []
            for vec in batch:
                # Each item in a bulk request consists of an action (index) and the document itself.
                bulk_body.append(json.dumps({"index": {"_index": self.index_name, "_id": vec["id"]}})) # Index action with document ID.
                bulk_body.append(json.dumps({"doc_id": vec["id"], "embedding": vec["vector"]})) # Document source data.

            bulk_data = "\n".join(bulk_body) + "\n" # Bulk API requires a newline at the end of each JSON object and a final newline.

            try:
                t_batch_start = time.perf_counter() # Starts client-side timer for this bulk request.
                # `refresh=False` improves ingestion performance; the index will be refreshed manually later.
                response = self.client.bulk(body=bulk_data, refresh=False)
                t_batch_end = time.perf_counter() # Ends client-side timer.
                batch_client_time = t_batch_end - t_batch_start # Records client-side time for this batch.

                server_took_ms = response.get("took", 0) # Time taken by Elasticsearch server to process the bulk request.
                self.upload_timings["server_bulk_time_ms"] += server_took_ms # Accumulates total server-side bulk time.

                # Counts successfully indexed items by checking the status codes in the response.
                batch_success = sum(1 for item in response.get("items", [])
                                  if "index" in item and item["index"].get("status") in (200, 201))
                total_uploaded += batch_success

                # A bulk request can succeed as a whole while individual items are rejected;
                # report those instead of dropping them silently.
                batch_failed = len(batch) - batch_success
                if batch_failed > 0:
                    print(f"\n   Batch {batch_num}: {batch_failed} of {len(batch)} documents were not indexed by Elasticsearch.")

                self.upload_timings["batch_details"].append({
                    "batch_num": batch_num,
                    "batch_size": len(batch),
                    "server_time_ms": server_took_ms,
                    "client_time_s": batch_client_time
                }) # Stores details for each batch.
            except Exception as e:
                print(f"\n‚ùå Batch {batch_num} failed to upload to Elasticsearch: {e}") # Logs any bulk upload errors.

        t_upload_end = time.perf_counter() # Ends timer for total client-side upload duration.
        self.upload_timings["client_total_time_s"] = t_upload_end - t_upload_start # Records total client-side upload time.

        try:
            # Refreshes the index to make newly ingested documents searchable.
            # This is critical after `refresh=False` was used during bulk ingestion.
            self.client.indices.refresh(index=self.index_name)
            print(f"‚úÖ Elasticsearch index '{self.index_name}' refreshed, documents are now searchable.")
        except Exception as e:
            print(f"‚ö†Ô∏è Error refreshing Elasticsearch index: {e}")

        print(f"\n‚è±Ô∏è  Elasticsearch Upload Timing Summary:")
        print(f"    Index Creation Time: {self.upload_timings['index_creation_s']:.4f}s")
        print(f"    Bulk Upload (SERVER processing time): {self.upload_timings['server_bulk_time_ms']/1000:.4f}s (Total time Elasticsearch spent processing all bulk requests)")
        print(f"    Bulk Upload (CLIENT total request time): {self.upload_timings['client_total_time_s']:.4f}s (Total time spent by client making bulk requests)")
        print(f"    Total Upload Time (Client-side, incl. index creation): {self.upload_timings['index_creation_s'] + self.upload_timings['client_total_time_s']:.4f}s")

        return total_uploaded # Returns the number of documents successfully uploaded.

    def search(self, query_idx, top_k=100):
        """Performs a k-Nearest Neighbor (kNN) search with a binarized query vector against Elasticsearch.
        It leverages Elasticsearch's `knn` query type to find the most similar documents.

        Args:
            query_idx: Position of the query vector inside `self.query_embeddings`.
            top_k: Number of nearest neighbors to request (default 100).

        Returns:
            dict: `{doc_id (str): score (float)}` for the returned hits. An empty dict is
            returned when the request fails; in that case a zero-valued timing record is
            appended so `search_timings` keeps one entry per call.
        """
        query_embedding = self.query_embeddings[query_idx] # Retrieves the specific query embedding.

        # Constructs the kNN search query body for Elasticsearch.
        # 'k': The number of nearest neighbors to return as top hits.
        # 'num_candidates': The number of candidates to consider from each shard (higher = more accurate but slower).
        query_body = {
            "knn": {
                "field": "embedding", # The `dense_vector` field containing document embeddings.
                "query_vector": query_embedding, # The binarized query vector for similarity search.
                "k": top_k, # Number of results to return from the total kNN search.
                "num_candidates": min(top_k * 10, 10000) # **User Customizable**: Adjust `num_candidates` for speed/accuracy trade-off.
            },
            "size": top_k, # Ensures only the top_k results are returned in the overall response.
            "_source": ["doc_id"] # Only retrieves the 'doc_id' field to minimize data transfer.
        }

        try:
            t_start = time.perf_counter() # Starts client-side timer for the search request.
            response = self.client.search(index=self.index_name, body=query_body) # Executes the Elasticsearch search.
            t_end = time.perf_counter() # Ends client-side timer.
            client_time_s = t_end - t_start # Calculates client-side search time.

            server_took_ms = response.get("took", 0) # Extracts server-side processing time in milliseconds.
            server_took_s = server_took_ms / 1000.0 # Converts server time to seconds.

            self.search_timings.append({
                "server_time_ms": server_took_ms,
                "server_time_s": server_took_s,
                "client_time_s": client_time_s,
                "client_time_ms": client_time_s * 1000
            }) # Records detailed timings for this search.

            hits = response.get("hits", {}).get("hits", []) # Extracts the actual search hits.
            results = {}
            for rank, hit in enumerate(hits):
                doc_id = hit.get("_id") # Document ID is available in the `_id` field in ES.
                score = hit.get("_score")
                # Falls back to a rank-based decreasing score when `_score` is missing
                # OR explicitly null (a plain `.get` default only covers the missing key).
                if score is None:
                    score = 1.0 / (1.0 + rank)
                # Scores are always stored as floats, matching the {doc_id: float} shape
                # produced by the Pinecone provider's search results.
                results[str(doc_id)] = float(score)

            return results
        except Exception as e:
            print(f"\n‚ùå Elasticsearch search error: {e}")
            # Logs zero timings for failed searches to avoid breaking statistics calculations.
            # NOTE(review): these zero entries are included in the statistics computed by
            # `get_search_stats`, so failed queries pull the reported timings down.
            self.search_timings.append({
                "server_time_ms": 0,
                "server_time_s": 0.0,
                "client_time_s": 0.0,
                "client_time_ms": 0.0
            })
            return {} # Returns empty results on error.

    def get_search_stats(self):
        """Calculates and returns search timing statistics for Elasticsearch (both server-side and client-side).
        This provides a detailed breakdown of where time is spent during search operations.

        Returns:
            dict: Keys "server_s", "server_ms", "client_s" and "client_ms", each holding the
            output of `calculate_timing_stats` for the matching field of `self.search_timings`.
        """
        server_times_s = [t["server_time_s"] for t in self.search_timings] # Server times in seconds.
        server_times_ms = [t["server_time_ms"] for t in self.search_timings] # Server times in milliseconds.
        client_times_s = [t["client_time_s"] for t in self.search_timings] # Client times in seconds.
        client_times_ms = [t["client_time_ms"] for t in self.search_timings] # Client times in milliseconds.
        return {
            "server_s": calculate_timing_stats(server_times_s),
            "server_ms": calculate_timing_stats(server_times_ms),
            "client_s": calculate_timing_stats(client_times_s),
            "client_ms": calculate_timing_stats(client_times_ms)
        }

    def cleanup(self):
        """Deletes the Elasticsearch index created for the benchmark.
        This is crucial for managing cloud resources and ensuring a clean state for future runs.
        Deletion errors are printed rather than raised.
        """
        try:
            self.client.indices.delete(index=self.index_name, ignore=[400, 404]) # Deletes index; ignores errors if it's already gone.
            print(f"üßπ Deleted Elasticsearch index: {self.index_name}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to delete Elasticsearch index '{self.index_name}': {e}. You may need to delete it manually via Kibana or ES API if it persists.")


# -------------------- 10. Main Benchmark Orchestration Function --------------------
# The `run_binary_benchmark` function orchestrates the entire benchmarking workflow
# for a given dataset and selected providers. It encompasses data loading, embedding
# generation, binarization, uploading to various vector databases, performing searches,
# evaluating retrieval quality, and saving detailed results.

def run_binary_benchmark(dataset_name, provider_names, source="beir"):
    """Runs a full benchmark with BINARY embeddings for specified providers and a chosen dataset.
    This function coordinates data loading, embedding, binarization, upload, search, and evaluation.
    """
    print(f"\n{'='*70}")
    print(f"üöÄ Starting Benchmark for Dataset: {dataset_name} (Source: {source.upper()})")
    print(f"üìä Using: BINARY Embeddings (Sign-based Binarization)")
    print(f"Providers to test: {', '.join(provider_names).capitalize()}")
    print(f"{'='*70})")

    # Loads dataset based on the selected source (BEIR or MAIR).
    if source == "beir":
        dataset_path = f"{DATA_ROOT}/{dataset_name}"

        # Downloads BEIR dataset if it's not already present locally.
        if not os.path.exists(dataset_path):
            print(f"üì¶ Downloading BEIR dataset: {dataset_name}. This may take some time.")
            url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
            util.download_and_unzip(url, DATA_ROOT) # Uses BEIR's utility for download and extraction.

        # Loads BEIR corpus, queries, and qrels using their `GenericDataLoader`.
        corpus, queries, qrels = GenericDataLoader(dataset_path).load(split="test")
        print(f"‚úÖ Loaded BEIR dataset: {len(corpus)} documents, {len(queries)} queries, {len(qrels)} qrels.")
    else:  # Logic for MAIR dataset loading.
        dataset_path = os.path.join(MAIR_COMBINED_PATH, dataset_name)
        docs_path = os.path.join(dataset_path, 'docs')
        queries_path = os.path.join(dataset_path, 'queries')

        # Checks if MAIR dataset files exist; if not, skips this dataset and logs a warning.
        if not os.path.exists(docs_path) or not os.path.exists(queries_path):
            print(f"‚ö†Ô∏è MAIR dataset '{dataset_name}' not found at {dataset_path}. Skipping this dataset.")
            return []

        print(f"\nüì• Loading MAIR dataset: {dataset_name} from {dataset_path}")
        corpus = {}
        # MAIR datasets can have multiple .jsonl files for documents; all are loaded.
        docs_files = [f for f in os.listdir(docs_path) if f.endswith('.jsonl')]
        for file in docs_files:
            corpus.update(load_jsonl(os.path.join(docs_path, file))) # Uses helper to load JSONL.

        queries = {}
        # Similarly, all query .jsonl files are loaded.
        query_files = [f for f in os.listdir(queries_path) if f.endswith('.jsonl')]
        for file in query_files:
            queries.update(load_jsonl(os.path.join(queries_path, file)))

        # Extracts qrels (relevance judgments) from MAIR query data using the helper function.
        qrels = extract_qrels_from_queries(queries)
        print(f"‚úÖ Loaded MAIR dataset: {len(corpus)} documents, {len(queries)} queries, {len(qrels)} qrels.")

    # Early exit if any critical component (corpus, queries, or qrels) is missing or empty.
    if not corpus or not queries or not qrels:
        print(f"‚ö†Ô∏è Essential data (corpus, queries, or qrels) is empty for {dataset_name}. Skipping benchmark.")
        return []

    # ========== STEP 1: Generate Float Embeddings with Cohere (Once per dataset) ==========
    # This step uses the Cohere API to generate dense float embeddings for both the corpus
    # documents and the queries. These float embeddings are then binarized in the next step.
    print(f"\nüß† Step 1: Generating float embeddings with Cohere {EMBEDDING_MODEL} (Dimension: {VECTOR_DIMENSION}D)...")

    # Limits the number of documents to embed and upload, based on `MAX_UPLOAD_DOCS`.
    # This is a user-configurable parameter to manage benchmark scale and cost.
    docs = list(corpus.items())[:MAX_UPLOAD_DOCS]
    print(f"  Processing up to {len(docs)} documents for embedding (limited by MAX_UPLOAD_DOCS={MAX_UPLOAD_DOCS}).")

    # Extracts text content from corpus documents for embedding. It handles various common
    # field names (`text`, `contents`, `body`, etc.) and falls back to string conversion.
    texts = []
    corpus_texts = {} # Stores original texts, which can be useful for potential reranking or debugging.
    for doc_id, doc_content in docs:
        text = None
        if isinstance(doc_content, dict):
            # Prioritizes common fields for text content in documents.
            for field in ['text', 'contents', 'content', 'body', 'passage', 'document', 'title', 'abstract']:
                if field in doc_content and doc_content[field]:
                    val = doc_content[field]
                    if isinstance(val, str) and val.strip(): # Ensures it's a non-empty string.
                        text = val
                        break
            # Fallback to combining title and text if both are present.
            if not text and 'title' in doc_content and 'text' in doc_content:
                text = f"{doc_content['title']}. {doc_content['text']}".strip()
            # Last resort: converts the entire dictionary to a string if no specific text field is found.
            # This might not be ideal for embedding but prevents errors.
            if not text:
                text = str(doc_content)
        else:
            text = str(doc_content) # For non-dict corpus entries (e.g., if corpus directly contains strings).

        final_text = text if text else "document" # Provides a default if text extraction still yields empty.
        texts.append(final_text)
        corpus_texts[str(doc_id)] = final_text # Stores the processed text associated with its ID.

    doc_ids = [str(d[0]) for d in docs] # Ensures all document IDs are strings for consistency.

    print(f"  üìÑ Generating corpus embeddings for {len(texts)} documents using Cohere {EMBEDDING_MODEL}...")
    t_embed_start = time.perf_counter() # Starts timer for corpus embedding generation.

    corpus_embeddings_float = []
    # Processes corpus embeddings in batches to optimize API calls to Cohere and manage memory.
    for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="Embedding corpus documents"):
        batch_texts = texts[i:i+BATCH_SIZE]
        response = cohere_client.embed(
            texts=batch_texts,
            model=EMBEDDING_MODEL,
            input_type=INPUT_TYPE_CORPUS # Specifies input type for optimal embedding generation.
        )
        corpus_embeddings_float.extend(response.embeddings) # Collects generated embeddings.

    t_embed_end = time.perf_counter() # Ends timer for corpus embedding generation.
    embedding_time = t_embed_end - t_embed_start # Calculates total corpus embedding time.

    # Generates query embeddings in a similar batched manner.
    print(f"  üîç Generating query embeddings for {len(queries)} queries using Cohere {EMBEDDING_MODEL}...")
    query_ids = list(queries.keys())
    query_texts = []

    for qid in query_ids:
        q = queries[qid]
        text = None
        if isinstance(q, dict):
            # Prioritizes common query text fields.
            text = (q.get('text') or q.get('query') or q.get('instruction') or
                   q.get('question') or q.get('query_text') or str(q)) # Falls back to string conversion.
        else:
            text = str(q) # For non-dict query entries.
        query_texts.append(text if text else "query") # Provides a default if query text is empty.

    t_query_start = time.perf_counter() # Starts timer for query embedding generation.
    query_embeddings_float = []
    for i in tqdm(range(0, len(query_texts), BATCH_SIZE), desc="Embedding queries"):
        batch_texts = [str(t) if t else "query" for t in query_texts[i:i+BATCH_SIZE]] # Ensures texts are strings.
        response = cohere_client.embed(
            texts=batch_texts,
            model=EMBEDDING_MODEL,
            input_type=INPUT_TYPE_QUERY # Specifies input type for optimal embedding generation.
        )
        query_embeddings_float.extend(response.embeddings) # Collects generated embeddings.

    t_query_end = time.perf_counter() # Ends timer for query embedding generation.
    query_embedding_time = t_query_end - t_query_start # Calculates total query embedding time.

    print(f"\n‚è±Ô∏è  Embedding Generation Time Summary:")
    print(f"    Corpus embeddings ({len(corpus_embeddings_float)} documents): {embedding_time:.4f}s")
    print(f"    Query embeddings ({len(query_embeddings_float)} queries): {query_embedding_time:.4f}s")

    # ========== STEP 2: Binarize Embeddings (Sign-based) ==========
    # This is the core step for binary embedding benchmarking. It converts the
    # dense float embeddings generated by Cohere into a binary representation
    # (0s and 1s) based on the sign of each component.
    print(f"\nüîÑ Step 2: Binarizing embeddings (sign-based method: values >= 0 -> 1, values < 0 -> 0)...")

    t_binarize_start = time.perf_counter() # Starts timer for binarization.

    # Binarizes corpus and query embeddings. The `binarize_embeddings` function
    # converts them to float32 (0.0 or 1.0) to maintain compatibility with vector databases.
    corpus_embeddings_binary = binarize_embeddings(corpus_embeddings_float)
    query_embeddings_binary = binarize_embeddings(query_embeddings_float)
    t_binarize_end = time.perf_counter() # Ends timer for binarization.
    binarization_time = t_binarize_end - t_binarize_start # Calculates total binarization time.

    # Calculates space savings achieved by binarization. This provides an estimate
    # of memory reduction, assuming the binary data is stored as float32 for compatibility.
    # Actual savings can be higher if the database supports native bit storage.
    float_size_mb = (len(corpus_embeddings_float) * VECTOR_DIMENSION * 4) / (1024 * 1024) # Size of original float embeddings (float32 = 4 bytes).
    binary_size_mb = (corpus_embeddings_binary.nbytes) / (1024 * 1024) # Size of binarized embeddings in memory (still float32).
    space_savings = (1 - (binary_size_mb / float_size_mb)) * 100 # Percentage of space saved.

    print(f"\nüíæ Space and Performance Efficiency After Binarization:")
    print(f"    Original Float Embeddings Size: {float_size_mb:.2f} MB (if stored as float32)")
    print(f"    Binarized Embeddings Size (in memory as float32): {binary_size_mb:.2f} MB")
    print(f"    Achieved Space Savings: {space_savings:.1f}% (effectively {float_size_mb/binary_size_mb:.1f}x compression when stored as float32 representing binary)")
    print(f"    Binarization processing time: {binarization_time:.4f}s")
    print(f"    Note: Actual storage savings can be much higher (up to 32x) if the vector database supports native bit/binary types.")

    # ========== STEP 3: Prepare Vectors for Upload to Vector Databases ==========
    # This step formats the binarized vectors into a list of dictionaries, which is
    # the standard format expected by most vector database APIs for bulk uploads.
    print(f"\nüì¶ Step 3: Preparing vectors for upload to selected vector databases...")

    # Creates a list of dictionaries, each containing an 'id' and a 'vector' (list of float 0.0s or 1.0s).
    binary_vectors_for_upload = []
    for doc_id, vec in zip(doc_ids, corpus_embeddings_binary):
        binary_vectors_for_upload.append({
            "id": doc_id,
            "vector": vec.tolist() if isinstance(vec, np.ndarray) else list(vec), # Ensures vector is a standard Python list.
        })

    # Converts query embeddings to a list of lists format for consistency with API calls.
    query_embeddings_list_for_search = [
        q.tolist() if isinstance(q, np.ndarray) else list(q) for q in query_embeddings_binary
    ]

    # ========== OPTIONAL: Save Binarized Embeddings to Disk ==========
    # This section provides an interactive option for the user to save the generated
    # binary embeddings locally. This can be useful for inspection, debugging, or
    # for using the embeddings with other tools outside of this notebook.
    download_choice = input(f"\nüíæ Do you want to download the binarized embeddings (saved in multiple formats to {DRIVE_PATH})? (y/n): ").strip().lower()
    if download_choice in ['y', 'yes']:
        # Defines a subdirectory within DRIVE_PATH to store the embeddings.
        embeddings_output_dir = os.path.join(DRIVE_PATH, f"{source}_{dataset_name}_binary_embeddings")
        save_binary_embeddings(doc_ids, corpus_embeddings_binary, query_ids, query_embeddings_binary,
                              dataset_name, embeddings_output_dir) # Uses the helper function to save.
    else:
        print(f"‚è≠Ô∏è  Skipping binary embedding download.")

    results_all = [] # A list to collect benchmark results from all provider runs for this dataset.

    # ========== STEP 4: Benchmark Each Selected Provider with Binary Vectors ==========
    # This is the core loop of the benchmark. For each selected vector database provider,
    # it initializes the provider, uploads the binarized data, performs searches using
    # binarized queries, evaluates the retrieval quality, and records performance metrics.
    for provider_name in provider_names:
        print(f"\n{'‚îÄ'*70}")
        print(f"üîß Running benchmark for Provider: {provider_name.upper()} | Using Binarized Vectors")
        print(f"{'‚îÄ'*70}")

        try:
            provider = None
            # Initializes the appropriate provider class based on the `provider_name`.
            # This ensures that provider-specific API calls and logic are used.
            if provider_name == 'moorcheh':
                if 'moorcheh' not in clients: # Skips if Moorcheh client was not initialized (e.g., missing API key).
                    print(f"‚ö†Ô∏è Skipping {provider_name}: Moorcheh client not initialized (API key missing?).")
                    continue
                safe_name = dataset_name.replace('/', '-').replace('.', '_') # Creates a URL-safe namespace name.
                namespace_name = f"{source}-{safe_name}-binary-v4"[:63] # Moorcheh namespace names max 63 chars.
                provider = MoorchehBinaryProvider(
                    clients['moorcheh'], namespace_name,
                    binary_vectors_for_upload, query_embeddings_list_for_search
                )
            elif provider_name == 'pinecone':
                if 'pinecone' not in clients: # Skips if Pinecone client was not initialized.
                    print(f"‚ö†Ô∏è Skipping {provider_name}: Pinecone client not initialized (API key missing?).")
                    continue
                safe_name = dataset_name.replace('/', '-').replace('_', '-').lower() # Pinecone index names have specific rules.
                index_name = f"{source}-{safe_name}-binary"[:45] # Pinecone index names max 45 chars.
                provider = PineconeBinaryProvider(
                    clients['pinecone'], index_name,
                    binary_vectors_for_upload, query_embeddings_list_for_search
                )
            elif provider_name == 'elasticsearch':
                if es_client is None: # Skips if ES client failed to connect.
                    print(f"‚ö†Ô∏è Skipping {provider_name}: Elasticsearch client not connected or configured.")
                    continue
                safe_name = dataset_name.replace('/', '-').replace('_', '-').lower() # ES index names are typically lowercase.
                index_name = f"{source}-{safe_name}-binary" # ES index names can be longer.
                provider = ElasticsearchBinaryProvider(
                    es_client, index_name,
                    binary_vectors_for_upload, query_embeddings_list_for_search
                )

            if provider is None: # Catches cases where no provider object was created due to configuration issues.
                print(f"‚ùå Could not initialize provider: {provider_name}. Check configuration and API keys.")
                continue

            # Uploads binarized data to the current provider's vector database.
            print(f"\nüì§ Uploading binary vectors to {provider_name}...")
            num_uploaded = provider.upload()
            if num_uploaded == 0: # If no documents were uploaded, skips search and evaluation for this provider.
                print(f"‚ö†Ô∏è No documents uploaded to {provider_name}. Skipping search and evaluation for this provider.")
                del provider # Cleans up the provider object.
                clean_memory() # Forces garbage collection.
                continue

            # Performs searches for all queries against the current provider.
            print(f"\nüîç Performing searches with binary queries on {provider_name} (top_k={TOP_K_SEARCH})...")
            results_per_provider = {} # Stores search results for this specific provider.

            for i, qid in enumerate(tqdm(query_ids, desc=f"Searching {provider_name} for {dataset_name}")):
                try:
                    results_per_provider[qid] = provider.search(i, top_k=TOP_K_SEARCH)
                except Exception as e:
                    print(f"\n‚ùå Query {i+1} ('{query_ids[i]}') failed for {provider_name}: {e}")
                    results_per_provider[qid] = {} # Logs empty results for failed queries to prevent further errors.

            # Retrieves and prints search timing statistics from the provider object.
            search_stats = provider.get_search_stats()

            if provider_name == 'moorcheh':
                print(f"\n‚è±Ô∏è  {provider_name.capitalize()} Search Timing Summary (Server-Side):")
                overall = search_stats.get('overall', {}) # Retrieves overall timing statistics.
                print(f"    Overall Mean Query Time: {overall.get('mean', 0)*1000:.2f}ms")
                print(f"    Overall Median Query Time: {overall.get('median', 0)*1000:.2f}ms")
                print(f"    Min: {overall.get('min', 0)*1000:.2f}ms, Max: {overall.get('max', 0)*1000:.2f}ms")
                # Optionally prints detailed component timings for Moorcheh if available.
                if len(search_stats) > 1: # Checks if detailed components exist beyond 'overall'.
                    print(f"    Detailed Components (Mean):")
                    for key, stats_comp in search_stats.items():
                        if key != 'overall':
                            print(f"      - {key.replace('moorcheh_', '').replace('_s', '')}: {stats_comp.get('mean',0)*1000:.2f}ms")
            elif provider_name == 'pinecone':
                print(f"\n‚è±Ô∏è  {provider_name.capitalize()} Search Timing Summary (Client-Side):")
                query_stats = search_stats.get('query_time_ms', {}) # Retrieves query timing statistics in milliseconds.
                print(f"    Mean Query Time: {query_stats.get('mean', 0):.2f}ms")
                print(f"    Median Query Time: {query_stats.get('median', 0):.2f}ms")
                print(f"    Min: {query_stats.get('min', 0):.2f}ms, Max: {query_stats.get('max', 0):.2f}ms")
                print(f"    Std Dev: {query_stats.get('std', 0):.2f}ms")
            elif provider_name == 'elasticsearch':
                print(f"\n‚è±Ô∏è  {provider_name.capitalize()} Search Timing Summary:")
                # Displays both server-side and client-side timings for Elasticsearch.
                server_stats_ms = search_stats.get('server_ms', {})
                client_stats_ms = search_stats.get('client_ms', {})
                print(f"    Server-side Mean Query Time: {server_stats_ms.get('mean', 0):.2f}ms")
                print(f"    Client-side Mean Query Time: {client_stats_ms.get('mean', 0):.2f}ms")

            # Evaluates retrieval quality using BEIR's evaluation suite or a custom equivalent.
            print(f"\nüìä Evaluating retrieval quality for {provider_name} (metrics at K={K_VALUES})...")
            # The `EvaluateRetrieval` class from BEIR is a robust tool for calculating standard IR metrics.
            evaluator = EvaluateRetrieval()
            ndcg, _map, recall, precision = evaluator.evaluate(qrels, results_per_provider, k_values=K_VALUES)
            format_and_print_metrics(ndcg, _map, recall, precision, ks=K_VALUES)

            # Stores all metrics and timing data into a result dictionary for CSV output.
            metrics = extract_all_metrics(ndcg, _map, recall, precision, ks=K_VALUES)
            result_entry = {
                "Dataset": dataset_name,
                "Source": source.upper(),
                "Provider": provider_name,
                "Num_Corpus": len(corpus), # Original number of corpus documents.
                "Num_Uploaded": num_uploaded, # Actual number of documents successfully uploaded to the provider.
                "Num_Queries": len(queries),
                "Vector_Dimension": VECTOR_DIMENSION,
                "Embedding_Model": EMBEDDING_MODEL,
                "Binarization_Method": "sign-based",
                "Embedding_Time_s": round(embedding_time, 4), # Time to generate initial float embeddings.
                "Query_Embedding_Time_s": round(query_embedding_time, 4), # Time to generate query float embeddings.
                "Binarization_Time_s": round(binarization_time, 4), # Time to convert to binary.
                "Float_Size_MB": round(float_size_mb, 2), # Estimated size of original float embeddings.
                "Binary_Size_MB": round(binary_size_mb, 2), # Estimated size of binarized embeddings (as float32).
                "Space_Savings_Pct": round(space_savings, 1), # Percentage of space saved.
                "Compression_Ratio": round(float_size_mb / binary_size_mb, 1), # Compression ratio.
            }

            # Adds provider-specific timing details to the result dictionary.
            if provider_name == 'moorcheh':
                overall_stats = search_stats.get('overall', {}) # Retrieves overall search statistics.
                result_entry["Upload_Time_s"] = round(provider.upload_timings["server_upload_time_s"], 4)
                result_entry["Search_Server_Total_s"] = round(overall_stats.get('total', 0), 4)
                result_entry["Search_Server_Mean_s"] = round(overall_stats.get('mean', 0), 4)
                result_entry["Search_Server_Median_s"] = round(overall_stats.get('median', 0), 4)
                result_entry["Search_Server_Min_s"] = round(overall_stats.get('min', 0), 4)
                result_entry["Search_Server_Max_s"] = round(overall_stats.get('max', 0), 4)
                result_entry["Search_Server_Std_s"] = round(overall_stats.get('std', 0), 4)
                result_entry["Search_Server_Mean_ms"] = round(overall_stats.get('mean', 0) * 1000, 2)

                # Adds detailed Moorcheh timing breakdowns in milliseconds (if available).
                for key, stats_comp in search_stats.items():
                    if key != 'overall' and isinstance(stats_comp, dict):
                        clean_key = key.replace('moorcheh_', '').replace('_s', '') # Cleans key name for CSV column.
                        result_entry[f"Moorcheh_{clean_key}_mean_s"] = round(stats_comp.get('mean', 0), 4)
                        result_entry[f"Moorcheh_{clean_key}_mean_ms"] = round(stats_comp.get('mean', 0) * 1000, 4)

            elif provider_name == 'pinecone':
                # Pinecone upload timing includes index creation and upsert time.
                result_entry["Upload_Index_Creation_s"] = round(provider.upload_timings["index_creation_s"], 4)
                result_entry["Upload_Upsert_s"] = round(provider.upload_timings["upsert_time_s"], 4)
                result_entry["Upload_Total_s"] = round(
                    provider.upload_timings["index_creation_s"] + provider.upload_timings["upsert_time_s"], 4
                )
                # Pinecone search timing (client-side).
                query_stats_s = search_stats.get('query_time_s', {})
                query_stats_ms = search_stats.get('query_time_ms', {})

                result_entry["Search_Total_s"] = round(query_stats_s.get('total', 0), 4)
                result_entry["Search_Mean_s"] = round(query_stats_s.get('mean', 0), 4)
                result_entry["Search_Median_s"] = round(query_stats_s.get('median', 0), 4)
                result_entry["Search_Min_s"] = round(query_stats_s.get('min', 0), 4)
                result_entry["Search_Max_s"] = round(query_stats_s.get('max', 0), 4)
                result_entry["Search_Std_s"] = round(query_stats_s.get('std', 0), 4)

                result_entry["Search_Mean_ms"] = round(query_stats_ms.get('mean', 0), 2)
                result_entry["Search_Median_ms"] = round(query_stats_ms.get('median', 0), 2)
                result_entry["Search_Min_ms"] = round(query_stats_ms.get('min', 0), 2)
                result_entry["Search_Max_ms"] = round(query_stats_ms.get('max', 0), 2)
                result_entry["Search_Std_ms"] = round(query_stats_ms.get('std', 0), 2)

            elif provider_name == 'elasticsearch':
                # Elasticsearch upload timing, differentiating server and client times.
                result_entry["Upload_Index_Creation_s"] = round(provider.upload_timings["index_creation_s"], 4)
                result_entry["Upload_Bulk_Server_s"] = round(provider.upload_timings["server_bulk_time_ms"] / 1000, 4)
                result_entry["Upload_Bulk_Client_s"] = round(provider.upload_timings["client_total_time_s"], 4)
                result_entry["Upload_Total_s"] = round(
                    provider.upload_timings["index_creation_s"] + provider.upload_timings["client_total_time_s"], 4
                )
                # Elasticsearch search timing, differentiating server and client times.
                server_stats_s = search_stats.get('server_s', {})
                server_stats_ms = search_stats.get('server_ms', {})
                client_stats_s = search_stats.get('client_s', {})
                client_stats_ms = search_stats.get('client_ms', {})

                result_entry["Search_Server_Total_s"] = round(server_stats_s.get('total', 0), 4)
                result_entry["Search_Server_Mean_s"] = round(server_stats_s.get('mean', 0), 4)
                result_entry["Search_Server_Median_s"] = round(server_stats_s.get('median', 0), 4)
                result_entry["Search_Server_Min_s"] = round(server_stats_s.get('min', 0), 4)
                result_entry["Search_Server_Max_s"] = round(server_stats_s.get('max', 0), 4)
                result_entry["Search_Server_Std_s"] = round(server_stats_s.get('std', 0), 4)

                result_entry["Search_Server_Mean_ms"] = round(server_stats_ms.get('mean', 0), 2)
                result_entry["Search_Server_Median_ms"] = round(server_stats_ms.get('median', 0), 2)
                result_entry["Search_Server_Min_ms"] = round(server_stats_ms.get('min', 0), 2)
                result_entry["Search_Server_Max_ms"] = round(server_stats_ms.get('max', 0), 2)
                result_entry["Search_Server_Std_ms"] = round(server_stats_ms.get('std', 0), 2)

                result_entry["Search_Client_Total_s"] = round(client_stats_s.get('total', 0), 4)
                result_entry["Search_Client_Mean_s"] = round(client_stats_s.get('mean', 0), 4)
                result_entry["Search_Client_Median_s"] = round(client_stats_s.get('median', 0), 4)
                result_entry["Search_Client_Min_s"] = round(client_stats_s.get('min', 0), 4)
                result_entry["Search_Client_Max_s"] = round(client_stats_s.get('max', 0), 4)
                result_entry["Search_Client_Std_s"] = round(client_stats_s.get('std', 0), 4)

                result_entry["Search_Client_Mean_ms"] = round(client_stats_ms.get('mean', 0), 2)
                result_entry["Search_Client_Median_ms"] = round(client_stats_ms.get('median', 0), 2)
                result_entry["Search_Client_Min_ms"] = round(client_stats_ms.get('min', 0), 2)
                result_entry["Search_Client_Max_ms"] = round(client_stats_ms.get('max', 0), 2)
                result_entry["Search_Client_Std_ms"] = round(client_stats_ms.get('std', 0), 2)

            result_entry.update(metrics) # Adds retrieval quality metrics to the result entry.

            save_results_to_csv(result_entry, CSV_PATH) # Saves results to CSV after each provider/dataset run.
            results_all.append(result_entry) # Appends to the list of all results for final summary.

            # Cleanup: Asks the user whether to delete the created index/namespace to manage cloud resources.
            should_delete = should_cleanup_namespace(provider_name, dataset_name)
            if should_delete:
                provider.cleanup() # Calls the provider's cleanup method.
            else:
                print(f"üíæ Keeping {provider_name} resources for {dataset_name}. Remember to delete them manually if no longer needed to avoid charges.")

            del provider # Explicitly deletes the provider object to free up resources.
            del results_per_provider # Clears search results for the next iteration.
            clean_memory() # Forces garbage collection.

        except Exception as e:
            print(f"‚ùå An unexpected error occurred during {provider_name} benchmark for {dataset_name}: {e}")
            import traceback
            traceback.print_exc() # Prints full traceback for debugging.
            clean_memory()

    # Clears large embedding arrays from memory after all providers for a dataset have been processed.
    del corpus_embeddings_float
    del query_embeddings_float
    del corpus_embeddings_binary
    del query_embeddings_binary
    del binary_vectors_for_upload
    del query_embeddings_list_for_search
    clean_memory()

    return results_all


# -------------------- 11. Interactive Dataset Selection Loop --------------------
# This section allows the user to interactively select which datasets to run the
# benchmark on. The benchmark will execute for each selected dataset and provider combination.

# Lists the datasets that can be benchmarked and builds `dataset_map`, which maps the
# 1-based number shown to the user onto the actual dataset name.
if DATASET_SOURCE == "beir":
    print("\nüìä Available BEIR datasets for benchmarking:")
    for ds in BEIR_DATASETS_SORTED_DISPLAY:
        print(ds) # Displays user-friendly list of BEIR datasets.
    available_datasets = BEIR_DATASETS
    dataset_map = dict(enumerate(BEIR_DATASETS, start=1)) # Maps display number to actual dataset name.
else:  # Logic for MAIR dataset selection.
    print("\nüìä Discovering MAIR datasets in your combined path...")
    # get_mair_datasets() returns: all dataset names, names grouped by category, and a
    # per-dataset size lookup (used below as a document count).
    mair_datasets, datasets_by_category, dataset_sizes = get_mair_datasets()

    if not mair_datasets: # Nothing to benchmark: tell the user and stop.
        print("‚ùå No MAIR datasets found! Please ensure you have run the previous step to download and combine them into the specified MAIR_COMBINED_PATH.")
        # Raise SystemExit directly instead of calling `exit(1)`: `exit` is an interactive
        # helper, and inside IPython/Jupyter it is an autocall object that is not guaranteed
        # to interrupt the running cell, which would let execution fall through to the
        # selection loop with no datasets. Raising SystemExit stops execution in both a
        # plain script and a notebook cell.
        raise SystemExit(1)

    print(f"\nüìÇ Available MAIR datasets by category ({len(mair_datasets)} total found):")
    dataset_index = 1
    dataset_map = {}
    # Displays MAIR datasets grouped by category (alphabetical), largest dataset first.
    for category in sorted(datasets_by_category):
        datasets_in_category = datasets_by_category[category]
        if not datasets_in_category:
            continue # Skips empty categories so no empty heading is printed.
        # Sorts datasets within each category by document count in descending order.
        sorted_datasets_in_category = sorted(datasets_in_category, key=lambda d: dataset_sizes.get(d, 0), reverse=True)
        print(f"\n  üìÅ {category} ({len(sorted_datasets_in_category)} datasets):")
        for dataset in sorted_datasets_in_category:
            doc_count = dataset_sizes.get(dataset, 0) # Unknown sizes are shown as 0.
            # Below one million: full number with thousands separators; otherwise "<x.y>M".
            size_str = f"{doc_count:,}" if doc_count < 1000000 else f"{doc_count/1000000:.1f}M"
            print(f"     {dataset_index}. {dataset} ({size_str} docs)")
            dataset_map[dataset_index] = dataset # Maps the display index to the actual dataset name.
            dataset_index += 1
    available_datasets = mair_datasets # Keeps a list of all discovered MAIR datasets.

print(f"\nüí° Binary Embedding Benchmark Overview (Key Configuration):")
print(f"  ‚Ä¢ Embeddings will be generated once per dataset using Cohere {EMBEDDING_MODEL}")
print(f"  ‚Ä¢ Embeddings are then binarized with a sign-based method (values >= 0 become 1, values < 0 become 0)")
print(f"  ‚Ä¢ Benchmarking will be performed on the selected providers: {', '.join(selected_providers).capitalize()} (if configured)")
print(f"  ‚Ä¢ Expected space compression: ~32x (from float32 to binary representation, if database supports native bit types)")
print(f"  ‚Ä¢ Evaluation includes retrieval quality (NDCG, MAP, Recall, Precision) and performance (upload/search times)")

# Collects the result rows of every benchmark run (all datasets, all providers).
all_results = []

# Keeps asking for a dataset number and benchmarking it until the user types 'stop'.
while True:
    # The prompt text is the same for both sources; only the upper bound differs.
    if DATASET_SOURCE == "beir":
        max_choice = len(BEIR_DATASETS)
    else:
        max_choice = len(dataset_map)
    choice = input(f"\n‚û°Ô∏è Enter dataset number (1-{max_choice}) to benchmark, or type 'stop' to finish: ").strip().lower()

    if choice == "stop":
        print("üèÅ Stopping benchmark process as requested.")
        break

    # Anything that is not 'stop' has to parse as an integer.
    try:
        idx = int(choice)
    except ValueError:
        print("‚ö†Ô∏è Invalid input. Please enter a number corresponding to a dataset, or type 'stop'.")
        continue

    # Translates the number into a dataset name; out-of-range numbers re-prompt.
    dataset_name = None
    if DATASET_SOURCE == "beir":
        if not (1 <= idx <= len(BEIR_DATASETS)):
            print(f"‚ö†Ô∏è Invalid number. Please enter a number between 1 and {len(BEIR_DATASETS)} for BEIR datasets.")
            continue
        dataset_name = BEIR_DATASETS[idx - 1] # Display numbers are 1-based.
    else:
        if idx not in dataset_map:
            print(f"‚ö†Ô∏è Invalid number. Please enter a valid number from the MAIR list above.")
            continue
        dataset_name = dataset_map[idx]

    # Benchmarks the chosen dataset on every selected provider and keeps the rows.
    results_for_current_dataset = run_binary_benchmark(dataset_name, selected_providers, source=DATASET_SOURCE)
    all_results.extend(results_for_current_dataset)

# -------------------- 12. Final Benchmark Summary and Analysis --------------------
# After all selected datasets are processed, this section provides an overall summary
# of the benchmark results. It includes tables of key metrics, average performance
# comparisons by provider, and insights into space efficiency and timing.
#
# NOTE: every provider stores a different set of timing keys in its result rows, so the
# combined DataFrame holds the union of all columns, with NaN wherever a provider never
# reported a value. Column *presence* therefore says nothing about one particular
# provider; the helpers below look for actual (non-NaN) values instead, so that no
# "nan" timings are printed and no provider is silently left out of a comparison.

# Search-time columns in order of preference (general mean first, then server-side mean).
_SEARCH_TIME_COLUMNS = ('Search_Mean_ms', 'Search_Server_Mean_ms')
# Upload-time columns in order of preference (single total first, then summed total).
_UPLOAD_TIME_COLUMNS = ('Upload_Time_s', 'Upload_Total_s')


def _has_values(frame, column):
    """Return True when `column` exists in `frame` and holds at least one non-NaN value."""
    return column in frame.columns and bool(frame[column].notna().any())


def _first_populated_column(frame, columns):
    """Return the first name in `columns` that has non-NaN values in `frame`, else None."""
    for column in columns:
        if _has_values(frame, column):
            return column
    return None


def _first_populated_field(row, columns):
    """Return the first name in `columns` whose value in the Series `row` is not NaN, else None."""
    for column in columns:
        if column in row.index and pd.notna(row[column]):
            return column
    return None


def _coalesce_columns(frame, columns):
    """Return a float Series holding, per row, the first non-NaN value among `columns`.

    Columns that do not exist in `frame` are ignored; rows with no value stay NaN.
    """
    merged = pd.Series(float("nan"), index=frame.index, dtype="float64")
    for column in columns:
        if column in frame.columns:
            merged = merged.fillna(frame[column])
    return merged


if all_results: # Checks if any results were generated from the benchmark runs.
    print(f"\n{'='*70}")
    print("üèÅ BENCHMARK COMPLETE - BINARY EMBEDDINGS (Overall Summary)")
    print(f"{'='*70}")

    df = pd.DataFrame(all_results) # One row per (dataset, provider) run; missing keys become NaN.

    print(f"\nüíæ All detailed results have been saved to: {CSV_PATH}")

    # Summary table for key metrics, providing a concise overview of each benchmark run.
    print("\nüìä Overall Results Summary (Key Metrics per Run):")
    print("‚îÄ" * 70)
    # Both search-time columns are listed so that providers reporting only a server-side
    # time still show a timing in the table.
    summary_cols = ['Dataset', 'Source', 'Provider', 'NDCG@10', 'MAP@10', 'Recall@100',
                    'Search_Mean_ms', 'Search_Server_Mean_ms']
    available_cols = [col for col in summary_cols if col in df.columns] # Keeps only columns that exist.
    if available_cols:
        print(df[available_cols].to_string(index=False))

    # Calculates and displays average performance by provider across all tested datasets.
    print("\nüìà Average Performance by Provider (across all tested datasets and runs):")
    print("‚îÄ" * 70)
    for provider in df['Provider'].unique():
        provider_data = df[df['Provider'] == provider] # Rows of the current provider only.
        print(f"\n{provider.upper()}:")
        print(f"  Average NDCG@10:     {provider_data['NDCG@10'].mean():.4f}")
        print(f"  Average MAP@10:      {provider_data['MAP@10'].mean():.4f}")
        print(f"  Average Recall@100:  {provider_data['Recall@100'].mean():.4f}")
        # Each timing line is printed only if this provider actually reported that timing.
        if _has_values(provider_data, 'Search_Mean_ms'):
            print(f"  Average Search Mean: {provider_data['Search_Mean_ms'].mean():.2f}ms")
        if _has_values(provider_data, 'Search_Server_Mean_ms'):
            print(f"  Average Search Server Mean: {provider_data['Search_Server_Mean_ms'].mean():.2f}ms")
        if _has_values(provider_data, 'Search_Client_Mean_ms'):
            print(f"  Average Search Client Mean: {provider_data['Search_Client_Mean_ms'].mean():.2f}ms")

    # Compares speed (average search time per query) across providers.
    print("\n‚ö° Speed Comparison (Average Search Time per Query across providers):")
    print("‚îÄ" * 70)
    for provider in df['Provider'].unique():
        provider_data = df[df['Provider'] == provider]
        if _has_values(provider_data, 'Search_Mean_ms'): # General mean search time.
            avg_time = provider_data['Search_Mean_ms'].mean()
            print(f"  {provider.upper():15} {avg_time:8.2f}ms")
        elif _has_values(provider_data, 'Search_Server_Mean_ms'): # Server/client split.
            avg_server = provider_data['Search_Server_Mean_ms'].mean()
            # A client time that was never measured is shown as N/A, not as 0.00ms or nan.
            if _has_values(provider_data, 'Search_Client_Mean_ms'):
                client_str = f"{provider_data['Search_Client_Mean_ms'].mean():8.2f}ms"
            else:
                client_str = f"{'N/A':>8}"
            print(f"  {provider.upper():15} Server: {avg_server:8.2f}ms | Client: {client_str}")

    # Compares upload times (total average upload duration) across providers.
    print("\nüì§ Upload Time Comparison (Average across providers and datasets):")
    print("‚îÄ" * 70)
    for provider in df['Provider'].unique():
        provider_data = df[df['Provider'] == provider]
        if provider == 'moorcheh' and _has_values(provider_data, 'Upload_Time_s'):
            avg_upload = provider_data['Upload_Time_s'].mean()
            print(f"  {provider.upper():15} {avg_upload:8.2f}s (server-side total upload)")
        elif provider == 'pinecone' and _has_values(provider_data, 'Upload_Total_s'):
            total_upload = provider_data['Upload_Total_s'].mean()
            index_creation = provider_data['Upload_Index_Creation_s'].mean()
            upsert = provider_data['Upload_Upsert_s'].mean()
            print(f"  {provider.upper():15}")
            print(f"    Index Creation: {index_creation:8.2f}s")
            print(f"    Upsert:         {upsert:8.2f}s")
            print(f"    Total:          {total_upload:8.2f}s (Index Creation + Upsert)")
        elif provider == 'elasticsearch' and _has_values(provider_data, 'Upload_Total_s'):
            total_upload = provider_data['Upload_Total_s'].mean()
            index_creation = provider_data['Upload_Index_Creation_s'].mean()
            bulk_client = provider_data['Upload_Bulk_Client_s'].mean()
            bulk_server = provider_data['Upload_Bulk_Server_s'].mean()
            print(f"  {provider.upper():15}")
            print(f"    Index Creation:  {index_creation:8.8f}s") # Increased precision for small times.
            print(f"    Bulk (Server):   {bulk_server:8.8f}s")
            print(f"    Bulk (Client):   {bulk_client:8.8f}s")
            print(f"    Total:           {total_upload:8.8f}s (Index Creation + Client Bulk)")

    # Summarizes space efficiency for binary embeddings.
    print("\nüíæ Space Efficiency (Binary Embeddings, Average across all runs):")
    print("‚îÄ" * 70)
    if 'Float_Size_MB' in df.columns: # Checks if space efficiency metrics exist.
        avg_float = df['Float_Size_MB'].mean()
        avg_binary = df['Binary_Size_MB'].mean()
        avg_savings = df['Space_Savings_Pct'].mean()
        avg_compression = df['Compression_Ratio'].mean()
        print(f"  Average Original Float Size (per dataset): {avg_float:.2f} MB")
        print(f"  Average Binarized Size (per dataset):      {avg_binary:.2f} MB")
        print(f"  Average Space Savings:                     {avg_savings:.1f}%")
        print(f"  Average Compression Ratio:                 {avg_compression:.1f}x")

    # Provides a timing breakdown for embedding generation and binarization.
    print("\n‚è±Ô∏è  Embedding Generation & Binarization Time (Average per dataset):")
    print("‚îÄ" * 70)
    avg_embed_time = df['Embedding_Time_s'].mean()
    avg_query_time = df['Query_Embedding_Time_s'].mean()
    avg_binarize_time = df['Binarization_Time_s'].mean()
    total_prep_time = avg_embed_time + avg_query_time + avg_binarize_time
    print(f"  Corpus embedding generation:     {avg_embed_time:.4f}s")
    print(f"  Query embedding generation:      {avg_query_time:.4f}s")
    print(f"  Binarization process:            {avg_binarize_time:.4f}s")
    print(f"  Total Embedding Prep Time:       {total_prep_time:.4f}s")

    # Details quality metrics for all K values tested.
    print("\nüéØ Detailed Quality Metrics (Average across datasets, all K values):")
    print("‚îÄ" * 70)
    for provider in df['Provider'].unique():
        provider_data = df[df['Provider'] == provider]
        print(f"\n{provider.upper()}:")
        for k in K_VALUES:
            ndcg_col = f'NDCG@{k}'
            map_col = f'MAP@{k}'
            recall_col = f'Recall@{k}'
            precision_col = f'P@{k}'
            if ndcg_col in df.columns: # Checks if metric column exists.
                print(f"  @{k:3d} - NDCG: {provider_data[ndcg_col].mean():.4f} | "
                      f"MAP: {provider_data[map_col].mean():.4f} | "
                      f"Recall: {provider_data[recall_col].mean():.4f} | "
                      f"Precision: {provider_data[precision_col].mean():.4f}")

    # Provides a per-dataset breakdown for quick comparison of key metrics.
    print("\nüìã Per-Dataset Breakdown (Summary of key metrics for each run):")
    print("‚îÄ" * 70)
    # Suffix shown after the search time, depending on which column supplied it.
    search_suffix = {
        'Search_Mean_ms': "",
        'Search_Server_Mean_ms': " (server)",
        'Search_Client_Mean_ms': " (client)",
    }
    for dataset in df['Dataset'].unique():
        dataset_data = df[df['Dataset'] == dataset]
        source = dataset_data['Source'].iloc[0] if 'Source' in dataset_data.columns else 'UNKNOWN'
        print(f"\nDataset: {dataset.upper()} (Source: {source}):")
        for provider in dataset_data['Provider'].unique():
            provider_dataset_data = dataset_data[dataset_data['Provider'] == provider]
            row = provider_dataset_data.iloc[0] # Assumes one entry per dataset-provider combination.

            # Picks the first timing this row really has (general, then server, then client).
            search_col = _first_populated_field(row, tuple(search_suffix))
            if search_col is None:
                search_time_str = "N/A"
            else:
                search_time_str = f"{row[search_col]:.2f}ms{search_suffix[search_col]}"

            upload_col = _first_populated_field(row, _UPLOAD_TIME_COLUMNS)
            upload_time_str = "N/A" if upload_col is None else f"{row[upload_col]:.2f}s"

            print(f"  {provider.capitalize():15} | "
                  f"NDCG@10: {row['NDCG@10']:.4f} | "
                  f"Search Time: {search_time_str:15} | "
                  f"Upload Time: {upload_time_str}")

    # Identifies top performers for key metrics (e.g., best NDCG, fastest search).
    print("\nüèÜ Performance Leaders (Across all runs for averaged metrics):")
    print("‚îÄ" * 70)

    # Best provider average and best single run for each headline quality metric.
    for metric in ('NDCG@10', 'MAP@10', 'Recall@100'):
        if metric not in df.columns:
            continue
        avg_by_provider = df.groupby('Provider')[metric].mean()
        print(f"  Best Average {metric}: {avg_by_provider.idxmax().upper()} - {avg_by_provider.max():.4f}")
        # The single best run may belong to a different provider/dataset than the best average.
        best_row = df.loc[df[metric].idxmax()]
        print(f"  Highest Single {metric}: {best_row['Provider'].upper()} - {best_row[metric]:.4f} on {best_row['Dataset']} ({best_row['Source']})")

    # Fastest search. The per-row coalesced time lets every provider take part, whichever
    # of the two search-time columns it reports.
    search_ms = _coalesce_columns(df, _SEARCH_TIME_COLUMNS)
    if search_ms.notna().any():
        avg_search_by_provider = search_ms.groupby(df['Provider']).mean().dropna()
        fastest_provider_avg = avg_search_by_provider.idxmin()
        fastest_time_avg = avg_search_by_provider.min()
        # Marks the result as server-side when the winner has no general search mean.
        winner_rows = df[df['Provider'] == fastest_provider_avg]
        tag = "" if _has_values(winner_rows, 'Search_Mean_ms') else " (Server-side)"
        print(f"  Fastest Average Search{tag}: {fastest_provider_avg.upper()} - {fastest_time_avg:.2f}ms")
        fastest_idx = search_ms.idxmin()
        fastest_row = df.loc[fastest_idx]
        tag = "" if _first_populated_field(fastest_row, ('Search_Mean_ms',)) else " (Server-side)"
        print(f"  Fastest Single Search{tag}: {fastest_row['Provider'].upper()} - {search_ms.loc[fastest_idx]:.2f}ms on {fastest_row['Dataset']} ({fastest_row['Source']})")

    # Fastest upload, again using whichever upload-time column each row provides.
    upload_s = _coalesce_columns(df, _UPLOAD_TIME_COLUMNS)
    if upload_s.notna().any():
        avg_upload_by_provider = upload_s.groupby(df['Provider']).mean().dropna()
        fastest_upload_provider_avg = avg_upload_by_provider.idxmin()
        fastest_upload_time_avg = avg_upload_by_provider.min()
        print(f"  Fastest Average Upload: {fastest_upload_provider_avg.upper()} - {fastest_upload_time_avg:.2f}s")
        fastest_upload_idx = upload_s.idxmin()
        fastest_upload_row = df.loc[fastest_upload_idx]
        print(f"  Fastest Single Upload: {fastest_upload_row['Provider'].upper()} - {upload_s.loc[fastest_upload_idx]:.2f}s on {fastest_upload_row['Dataset']} ({fastest_upload_row['Source']})")

    # Provides key insights and overall conclusions derived from the benchmark results.
    print("\nüí° Key Insights from the Benchmark:")
    print("‚îÄ" * 70)
    print(f"  ‚Ä¢ Dataset source(s) tested: {', '.join(df['Source'].unique())} (This includes both BEIR and MAIR data if selected)")
    print(f"  ‚Ä¢ Total unique datasets benchmarked: {len(df['Dataset'].unique())}")
    print(f"  ‚Ä¢ Total queries processed across all runs: {df['Num_Queries'].sum():,}")
    print(f"  ‚Ä¢ Total documents indexed across all successful runs: {df['Num_Uploaded'].sum():,}")
    if 'Space_Savings_Pct' in df.columns:
        print(f"  ‚Ä¢ Average space compression achieved by binarization (when stored as float32): {df['Space_Savings_Pct'].mean():.1f}%")
    print(f"  ‚Ä¢ Binarization method used: {df['Binarization_Method'].iloc[0]} (Sign-based is standard for this type of binary embedding)")
    print(f"  ‚Ä¢ All initial embeddings generated using Cohere's {df['Embedding_Model'].iloc[0]} model for consistency.")
    print(f"  ‚Ä¢ This benchmark provides valuable insights into the quality, speed, and space trade-offs when using binary embeddings with different vector databases.")
    print(f"  ‚Ä¢ Remember that actual performance can vary based on specific dataset characteristics, cloud region, network latency, and vector database configuration.")

    # Presents a more detailed provider comparison table if multiple providers were tested.
    if len(df['Provider'].unique()) > 1:
        print("\nüìä Provider Comparison Table (Averages across all datasets for key metrics):")
        print("‚îÄ" * 70)

        providers = sorted(df['Provider'].unique())
        metrics_to_compare = ['NDCG@10', 'MAP@10', 'Recall@100', 'P@10']

        # Header: a 20-wide label column, then one space plus a 15-wide cell per provider.
        header = f"{'Metric':<20}"
        for provider in providers:
            header += f" {provider.capitalize():<15}"
        print(header)
        print("‚îÄ" * (20 + 16 * len(providers))) # 16 = separating space + 15-wide cell.

        # Populates the metrics rows for the comparison table.
        for metric in metrics_to_compare:
            if metric in df.columns:
                line = f"{metric:<20}"
                for provider in providers:
                    provider_data = df[df['Provider'] == provider]
                    val = provider_data[metric].mean() # Average metric value for the provider.
                    line += f" {val:<15.4f}" # Formats to 4 decimal places.
                print(line)

        # Search-time row: general mean if the provider has one, otherwise server-side mean.
        line = f"{'Search Time (ms)':<20}"
        for provider in providers:
            provider_data = df[df['Provider'] == provider]
            search_col = _first_populated_column(provider_data, _SEARCH_TIME_COLUMNS)
            if search_col is None:
                line += f" {'N/A':<15}" # No search time recorded for this provider.
            else:
                line += f" {provider_data[search_col].mean():<15.2f}"
        print(line)

        # Upload-time row: whichever total upload column the provider populated.
        line = f"{'Upload Time (s)':<20}"
        for provider in providers:
            provider_data = df[df['Provider'] == provider]
            upload_col = _first_populated_column(provider_data, _UPLOAD_TIME_COLUMNS)
            if upload_col is None:
                line += f" {'N/A':<15}"
            else:
                line += f" {provider_data[upload_col].mean():<15.2f}"
        print(line)

    print("\n‚úÖ Benchmark Conclusion and Key Takeaways:")
    print("  ‚Ä¢ This benchmark provides a comparative analysis of retrieval quality and performance for various vector databases using binary embeddings.")
    print("  ‚Ä¢ The results highlight the potential for significant space savings through binarization, along with the trade-offs in retrieval quality and search speed across different vector database implementations.")
    print(f"  ‚Ä¢ The complete, raw results in CSV format can be found at: {CSV_PATH}. This file contains all detailed metrics for further analysis.")
    print("  ‚Ä¢ Consider your specific application needs (e.g., latency, storage costs, retrieval accuracy) when choosing a vector database and embedding strategy.")

else:
    print("\n‚ö†Ô∏è No benchmark results were generated. Please ensure at least one dataset and one provider were successfully processed. Check for API key issues or errors during execution.")

## Vector (Binary) Search with PGVector and PostgreSQL in Google Colab

In [None]:
# ============================================================
# MAIR Benchmark - PGVector Binary Edition (Colab Optimized)
# This notebook section focuses on benchmarking binary embeddings (1-bit)
# using PostgreSQL with the pgvector extension for vector search.
# ============================================================

# -------------------- STEP 1: Install PostgreSQL & pgvector --------------------
print("üîß Installing PostgreSQL and pgvector in Google Colab environment...")
print("This process involves installing system packages and compiling pgvector from source. It will take approximately 2-3 minutes on the first run.")
print("Subsequent runs in the same Colab session will be faster as packages might be cached.")

import os
import subprocess
import time

# Install PostgreSQL server and client utilities.
# `apt-get update -qq`: Updates package lists quietly.
# `apt-get install -y postgresql postgresql-contrib`: Installs PostgreSQL server and extensions.
print("üì¶ Installing PostgreSQL server...")
os.system('apt-get update -qq > /dev/null 2>&1')
os.system('apt-get install -y postgresql postgresql-contrib > /dev/null 2>&1')

# Install build dependencies for pgvector (needed to compile from source).
# `build-essential`: Provides compilers (gcc, g++).
# `git`: To clone the pgvector repository.
# `postgresql-server-dev-14`: Development headers for PostgreSQL 14 (adjust version if needed).
print("üì¶ Installing build tools and PostgreSQL development headers for pgvector compilation...")
os.system('apt-get install -y build-essential git postgresql-server-dev-14 > /dev/null 2>&1')

# Install pgvector extension by cloning its repository and compiling from source.
# This ensures we get the latest features, including BIT type support for binary vectors.
print("üì¶ Cloning and installing pgvector extension from GitHub...")
os.system('cd /tmp && rm -rf pgvector && git clone --quiet https://github.com/pgvector/pgvector.git')
os.system('cd /tmp/pgvector && make > /dev/null 2>&1 && make install > /dev/null 2>&1')

# Start PostgreSQL service.
print("üöÄ Starting PostgreSQL service...")
os.system('service postgresql start > /dev/null 2>&1')
time.sleep(2) # Give the service a moment to fully start up.

# Configure the PostgreSQL database.
# `sudo -u postgres`: Executes commands as the 'postgres' user, which has administrative privileges.
# `DROP DATABASE IF EXISTS vectordb`: Ensures a clean slate for the 'vectordb' database.
# `CREATE DATABASE vectordb`: Creates a new database named 'vectordb' for our vector data.
# `ALTER USER postgres PASSWORD 'postgres'`: Sets a simple password for the default 'postgres' user.
# `CREATE EXTENSION IF NOT EXISTS vector`: Enables the pgvector extension within the 'vectordb'.
print("üîß Configuring 'vectordb' database and enabling pgvector extension...")
os.system('sudo -u postgres psql -c "DROP DATABASE IF EXISTS vectordb;" > /dev/null 2>&1')
os.system('sudo -u postgres psql -c "CREATE DATABASE vectordb;" > /dev/null 2>&1')
os.system('sudo -u postgres psql -c "ALTER USER postgres PASSWORD \'postgres\';" > /dev/null 2>&1')
os.system('sudo -u postgres psql -d vectordb -c "CREATE EXTENSION IF NOT EXISTS vector;" > /dev/null 2>&1')

print("‚úÖ PostgreSQL with pgvector is ready for use!
")

# -------------------- STEP 2: Install Python Dependencies --------------------
# Install Python packages required for connecting to PostgreSQL, generating embeddings, and data analysis.
# `psycopg2-binary`: PostgreSQL adapter for Python.
# `pgvector`: Python client for pgvector.
# `cohere`: For generating embeddings.
# `pandas`, `numpy`, `tqdm`: Standard data science libraries.
print("üì¶ Installing Python packages: psycopg2-binary, pgvector, cohere, pandas, numpy, tqdm...")
# Keep pip's exit status so a failed install is not reported as a success.
_pip_status = os.system('pip install -q psycopg2-binary pgvector cohere pandas numpy tqdm')
if _pip_status != 0:
    print(f"‚ö†Ô∏è pip install exited with status {_pip_status}. Some packages may be missing; the imports below may fail.\n")
else:
    print("‚úÖ Python packages installed!\n")

# -------------------- STEP 3: Import Libraries --------------------
# Import all necessary Python libraries for the benchmark script.
import gc # For garbage collection to manage memory.
import json # For handling JSONL dataset files.
import statistics # For calculating performance statistics.
import numpy as np # For numerical operations, especially with embeddings.
import pandas as pd # For data manipulation and CSV output.
from tqdm import tqdm # For displaying progress bars during long operations.
import psycopg2 # The Python PostgreSQL adapter.
import psycopg2.extras # For advanced psycopg2 features like execute_values.
from pgvector.psycopg2 import register_vector # Utility to register vector type with psycopg2.

# -------------------- Environment Setup --------------------
# DRIVE_PATH: Directory where the benchmark results are saved. It stays at "." (local)
#             unless Google Drive was mounted successfully.
# MAIR_COMBINED_PATH: Path to the combined MAIR datasets in your Google Drive.
DRIVE_PATH = "."
MAIR_COMBINED_PATH = "/content/drive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Combined"

try:
    # Importing the Colab-only modules doubles as the "are we running in Colab?" check.
    from google.colab import drive, userdata as colab_userdata

    _drive_mounted = False
    try:
        # Mount Google Drive to allow Colab to access your files.
        drive.mount('/content/drive', force_remount=True)
        _drive_mounted = True
        print("‚úÖ Google Drive mounted successfully.")
    except Exception as e:
        print(f"‚ö†Ô∏è Drive mount warning: {e}. Continuing without Drive mount, results will be saved locally.")

    if _drive_mounted:
        # Only point the results at Drive when the mount worked; otherwise the path below
        # would just be an ordinary local directory that disappears with the Colab session,
        # contradicting the warning printed above.
        # You can customize this path to organize your benchmark results.
        DRIVE_PATH = '/content/drive/MyDrive/Moorcheh/Benchmark_Results/MAIR.PGVector.Binary'
        os.makedirs(DRIVE_PATH, exist_ok=True) # Ensure the results directory exists.
    print(f"‚úÖ Running in Colab. Benchmark results will be saved to: {DRIVE_PATH}")

    # Retrieve COHERE_API_KEY from Colab secrets.
    # To use this, add your API key to Colab's "Secrets" panel (key icon in the left sidebar)
    # under the name `COHERE_API_KEY`.
    try:
        COHERE_API_KEY = colab_userdata.get('COHERE_API_KEY')
    except Exception as e:
        # NOTE(review): userdata.get appears to raise when the secret is missing or notebook
        # access was not granted -- confirm against the Colab runtime. Fall back to the
        # environment variable instead of aborting the whole cell.
        print(f"‚ö†Ô∏è Could not read COHERE_API_KEY from Colab secrets: {e}. Falling back to the COHERE_API_KEY environment variable.")
        COHERE_API_KEY = os.environ.get('COHERE_API_KEY')

except ImportError:
    # This block runs if not in Google Colab (e.g., local environment).
    DRIVE_PATH = "."
    print("‚ö†Ô∏è Not running in Google Colab. Results will be saved locally. Ensure COHERE_API_KEY is set as an environment variable.")
    # Retrieve COHERE_API_KEY from environment variables.
    COHERE_API_KEY = os.environ.get('COHERE_API_KEY')

# Connection settings for the local PostgreSQL instance.
# The database name, user and password are the ones configured by the Step 1 setup commands.
# The dictionary is unpacked into `psycopg2.connect(**...)` by the provider class.
PG_CONN_PARAMS = dict(
    host='localhost',
    port='5432',
    database='vectordb',
    user='postgres',
    password='postgres',
)

# ---- User-tunable benchmark settings ----

# Retrieval depth and evaluation cut-offs.
TOP_K_SEARCH = 100             # Number of results requested from pgvector per query.
K_VALUES = [1, 3, 5, 10, 100]  # Cut-offs k at which NDCG / MAP / Recall / Precision are reported.

# Upload limits.
MAX_UPLOAD_DOCS = 700_000      # Upper bound on documents uploaded per dataset (keeps huge corpora manageable).
BATCH_SIZE = 100               # Items handled per request/batch; trades speed against memory use.

# Cohere embedding settings.
EMBEDDING_MODEL = "embed-v4.0"
INPUT_TYPE_CORPUS = "search_document"  # Cohere input type used when embedding corpus documents.
INPUT_TYPE_QUERY = "search_query"      # Cohere input type used when embedding queries.
VECTOR_DIMENSION = 1536                # Embedding dimensionality == length of the PostgreSQL BIT string.

# Destination of the final results file (CSV), placed inside DRIVE_PATH.
CSV_PATH = os.path.join(DRIVE_PATH, "MAIR.PGVector.Binary.Cohere.V4.csv")

# Table cleanup policy: asked once, up front, and applied after every dataset run.
# An empty or unrecognised answer falls back to "ask_each_time".
CLEANUP_POLICY_MAP = {"1": "ask_each_time", "2": "always_delete", "3": "always_keep"}

# Show the menu in a single print call (same output as one print per line).
print("\n".join([
    "\nüßπ PGVector table cleanup policy options:",
    "  1) ask_each_time  -> Prompt after each dataset (default behaviour)",
    "  2) always_delete  -> Automatically delete tables after benchmarking each dataset",
    "  3) always_keep    -> Never delete tables (keep for future searches or debugging)",
]))

cleanup_choice = input("‚û°Ô∏è Choose a cleanup policy [1/2/3] (default is 1): ").strip() or "1"
CLEANUP_POLICY = CLEANUP_POLICY_MAP.get(cleanup_choice, "ask_each_time")
print(f"‚úÖ Selected cleanup policy: {CLEANUP_POLICY}")

# -------------------- Helper Functions --------------------
# These functions facilitate data loading, processing, and metric calculation.

def load_jsonl(filepath):
    """Load a JSONL file into a dictionary keyed by record ID (always a string).

    The ID of a record is the first of ``_id``, ``id``, ``query_id``, ``doc_id`` that holds
    a value other than ``None`` or the empty string, so a legitimate ID of ``0`` is kept.
    Records without any ID are stored under their position (``str(len(data))``); if that key
    is already taken by a real ID, the next free number is used so nothing is overwritten.

    Args:
        filepath: Path of the JSONL file; blank lines are ignored.

    Returns:
        dict mapping ``str(id)`` to the parsed record. On any error (missing file, malformed
        line, ...) a warning is printed and the records loaded so far are returned.
    """
    data = {}
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip(): # Skip empty lines
                    continue
                item = json.loads(line)

                # Pick the first usable ID. Checking for None/'' instead of truthiness keeps
                # falsy-but-valid IDs such as 0.
                item_id = None
                for key in ('_id', 'id', 'query_id', 'doc_id'):
                    candidate = item.get(key)
                    if candidate is not None and candidate != '':
                        item_id = candidate
                        break

                if item_id is None:
                    # Positional fallback; step past keys that are already in use so an
                    # ID-less record never replaces an earlier record.
                    position = len(data)
                    while str(position) in data:
                        position += 1
                    item_id = position

                data[str(item_id)] = item
        print(f"   Loaded {len(data)} items from {os.path.basename(filepath)}")
    except Exception as e:
        print(f"‚ö†Ô∏è Error loading {filepath}: {e}")
    return data

def extract_qrels_from_queries(queries):
    """Build qrels (relevance judgments) from the label field embedded in each query.

    For every query that is a dict, the first non-empty field among ``labels``,
    ``relevance`` and ``qrels`` is used. Supported label shapes:

    * list of dicts: document ID from ``id`` (or ``doc_id``), score from ``score`` (default 1);
    * list of plain values: each value is a document ID with score 1;
    * dict: maps document ID to score.

    Document IDs are compared against ``None``/empty string rather than by truthiness, so
    an ID of ``0`` is kept.

    Args:
        queries: dict mapping query ID to the query record.

    Returns:
        dict ``{str(query_id): {str(doc_id): score}}``. A query whose label field is
        non-empty but yields no usable entry still gets an (empty) inner dict.
    """
    qrels = {}
    for qid, q in queries.items():
        if not isinstance(q, dict):
            continue
        labels = q.get('labels') or q.get('relevance') or q.get('qrels') # Check common qrels field names
        if not labels:
            continue

        judged = qrels[str(qid)] = {} # Judgments for the current query
        if isinstance(labels, list):
            for label_item in labels:
                if isinstance(label_item, dict):
                    doc_id = label_item.get('id')
                    if doc_id is None or doc_id == '':
                        doc_id = label_item.get('doc_id')
                    # Explicit None/'' test: a document ID of 0 is valid and must not be dropped.
                    if doc_id is not None and doc_id != '':
                        judged[str(doc_id)] = label_item.get('score', 1) # Default score to 1 if not provided
                else:
                    judged[str(label_item)] = 1 # Bare document ID: assume score of 1
        elif isinstance(labels, dict):
            for doc_id, score in labels.items():
                judged[str(doc_id)] = score
    return qrels

def binarize_embeddings(embeddings):
    """
    Turn float embeddings into bit strings by sign: a component >= 0 becomes '1', a negative
    one becomes '0'. Each vector is rendered as one string of '0'/'1' characters
    (e.g., '10100...'), the textual form used for PostgreSQL's `BIT` type.

    Returns: (list with one bit string per input vector,
              size in MB of the input when held as a float32 array).
    """
    matrix = np.array(embeddings, dtype=np.float32)

    # Footprint of the float32 representation, reported for the space-savings comparison.
    size_mb = matrix.nbytes / (1024 * 1024)

    # Boolean mask of the non-negative components, one row per vector.
    sign_mask = matrix >= 0

    # Map every row of the mask to its characters and glue them into a single string.
    bit_strings = ["".join(np.where(row, "1", "0")) for row in sign_mask]

    return bit_strings, size_mb

# Category -> member datasets. Drives the grouped display of the MAIR datasets; any dataset
# that is not listed under a category is reported as "Others".
DATASET_CATEGORIES = {
    "Legal & Regulatory": [
        "ACORDAR", "AILA2019-Case", "AILA2019-Statutes", "CUAD",
        "LeCaRDv2", "LegalQuAD", "REGIR-EU2UK", "REGIR-UK2EU",
    ],
    "Medical & Clinical": [
        "CliniDS-2014", "CliniDS-2015", "CliniDS-2016",
        "ClinicalTrials-2021", "ClinicalTrials-2022", "ClinicalTrials-2023",
        "NFCorpus", "PrecisionMedicine", "Genomics-AdHoc",
    ],
    "Code & Programming": [
        "APPS", "CodeEditSearch", "CodeSearchNet", "Conala", "HumanEval-X",
        "LeetCode", "MBPP", "RepoBench", "SWE-Bench-Lite",
    ],
    "Financial": [
        "ConvFinQA", "FiQA", "FinQA", "FinanceBench", "HC3Finance",
    ],
    "Academic & Scientific": [
        "ArguAna", "LitSearch", "ProofWiki-Proof", "ProofWiki-Reference", "Competition-Math",
    ],
    "Conversational & Dialog": [
        "CAsT-2019", "CAsT-2020", "CAsT-2021", "CAsT-2022",
        "ProCIS-Dialog", "ProCIS-Turn", "SParC", "Quora",
    ],
    "News & Social Media": [
        "ChroniclingAmericaQA", "Microblog-2011", "Microblog-2012",
        "Microblog-2013", "Microblog-2014", "News21",
    ],
    "API Documentation": [
        "Apple", "FoodAPI", "HuggingfaceAPI", "PytorchAPI",
    ],
    "Others": [
        "BSARD", "BillSum", "CARE", "CPCD", "CQADupStack", "DD", "ELI5", "ExcluIR",
        "FairRanking", "Fever", "GerDaLIR", "IFEval", "InstructIR", "MISeD", "Monant",
        "NTCIR", "NeuCLIR", "NevIR", "PointRec", "ProductSearch_2023", "QuanTemp", "Robust04",
    ],
}

def get_dataset_category(dataset_name):
    """Return the first category that lists `dataset_name`, or "Others" when none does."""
    return next(
        (label for label, members in DATASET_CATEGORIES.items() if dataset_name in members),
        "Others",
    )

def get_dataset_size(dataset_name):
    """Return how many non-blank lines the dataset's `docs/*.jsonl` files contain.

    Returns 0 when the dataset has no `docs` directory. If reading fails part-way, a warning
    is printed and the lines counted up to that point are returned.
    """
    docs_dir = os.path.join(MAIR_COMBINED_PATH, dataset_name, 'docs')
    total = 0
    if not os.path.exists(docs_dir):
        return total

    try:
        for entry in os.listdir(docs_dir):
            if not entry.endswith('.jsonl'):
                continue # Only JSONL files hold documents
            with open(os.path.join(docs_dir, entry), 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    # One non-blank line == one document.
                    if raw_line.strip():
                        total += 1
    except Exception as e:
        print(f"‚ö†Ô∏è Error counting documents in {dataset_name}: {e}")
    return total

def format_size(num):
    """Render a count compactly: one decimal plus 'M' from a million up, 'K' from a thousand up, else as-is."""
    # Largest unit first, so a value is matched against 'M' before 'K'.
    for threshold, suffix in ((1_000_000, "M"), (1_000, "K")):
        if num >= threshold:
            return f"{num/threshold:.1f}{suffix}"
    return str(num)

def get_mair_datasets():
    """Scan MAIR_COMBINED_PATH for usable datasets.

    A sub-directory counts as a dataset when it has both a `docs` and a `queries` folder and
    each of them holds at least one `.jsonl` file. Entries whose name starts with '.' or '_'
    are ignored.

    Returns a tuple `(all_datasets, datasets_by_category, dataset_sizes)`: the dataset names
    in sorted order, the names grouped per category, and the document count per dataset.
    All three are empty (with every category present) when the root path does not exist.
    """
    datasets_by_category = {category: [] for category in DATASET_CATEGORIES}
    all_datasets, dataset_sizes = [], {}

    if not os.path.exists(MAIR_COMBINED_PATH):
        return all_datasets, datasets_by_category, dataset_sizes

    print(f"üîç Scanning for MAIR datasets in: {MAIR_COMBINED_PATH}")
    try:
        for name in sorted(os.listdir(MAIR_COMBINED_PATH)):
            # Hidden/system entries are not datasets.
            if name.startswith(('.', '_')):
                continue
            root = os.path.join(MAIR_COMBINED_PATH, name)
            if not os.path.isdir(root):
                continue

            docs_dir = os.path.join(root, 'docs')
            queries_dir = os.path.join(root, 'queries')
            if not (os.path.exists(docs_dir) and os.path.exists(queries_dir)):
                continue

            # Both folders must actually contain JSONL data.
            has_docs = any(f.endswith('.jsonl') for f in os.listdir(docs_dir))
            has_queries = any(f.endswith('.jsonl') for f in os.listdir(queries_dir))
            if not (has_docs and has_queries):
                continue

            all_datasets.append(name)
            category = get_dataset_category(name)
            datasets_by_category[category].append(name)
            n_docs = get_dataset_size(name)
            dataset_sizes[name] = n_docs
            print(f"   ‚úÖ Found dataset: {name} [Category: {category}] - {format_size(n_docs)} docs")
    except Exception as e:
        print(f"   ‚ùå Error scanning directory {MAIR_COMBINED_PATH}: {e}")
    return all_datasets, datasets_by_category, dataset_sizes

def calculate_timing_stats(timing_list):
    """Summarise a list of timings as mean, median, min, max, std and total.

    An empty list gives 0.0 for every statistic; `std` is 0.0 for a single sample because
    the sample standard deviation needs at least two values.
    """
    if not timing_list:
        return dict.fromkeys(("mean", "median", "min", "max", "std", "total"), 0.0)

    return dict(
        mean=statistics.mean(timing_list),
        median=statistics.median(timing_list),
        min=min(timing_list),
        max=max(timing_list),
        std=0.0 if len(timing_list) <= 1 else statistics.stdev(timing_list),
        total=sum(timing_list),
    )

def calculate_retrieval_metrics(retrieved_results, qrels, k_values=K_VALUES):
    """Compute mean NDCG@k, MAP@k, Recall@k and P@k over the evaluated queries.

    `retrieved_results` maps query ID -> {doc ID: score} in ranked order; `qrels` maps
    query ID -> {doc ID: relevance}. A query is evaluated only when it has at least one
    judged document. Relevance is treated as binary: every document listed in the qrels of a
    query counts as relevant, whatever its score.
    NOTE(review): a judged document with score 0 is therefore counted as relevant too --
    confirm that the datasets never carry non-positive scores.

    Returns four dicts (NDCG, MAP, Recall, Precision), each keyed like "NDCG@10"; a metric
    with no evaluated query is reported as 0.0.
    """
    # One list of per-query values for every metric family and every cut-off.
    per_query = {
        family: {f"{family}@{k}": [] for k in k_values}
        for family in ("NDCG", "MAP", "Recall", "P")
    }

    for qid in retrieved_results:
        if qid not in qrels:
            continue # No judgments for this query
        relevant_docs = set(qrels[qid].keys())
        ranking = list(retrieved_results[qid].keys()) # Ranked document IDs
        n_relevant = len(relevant_docs)
        if n_relevant == 0:
            continue # Nothing to measure against

        for k in k_values:
            # Single pass over the top-k: count hits and accumulate the AP and DCG sums.
            hits = 0
            ap_sum = 0.0
            dcg = 0.0
            for rank, doc_id in enumerate(ranking[:k], 1):
                if doc_id in relevant_docs:
                    hits += 1
                    ap_sum += hits / rank           # Precision at this hit
                    dcg += 1.0 / np.log2(rank + 1)  # Log-discounted gain of 1

            # Recall@k: share of the relevant documents that were retrieved.
            per_query["Recall"][f"Recall@{k}"].append(hits / n_relevant)

            # Precision@k: share of the k slots filled with relevant documents.
            per_query["P"][f"P@{k}"].append(hits / k if k > 0 else 0.0)

            # AP@k, normalised by the best achievable number of hits within k.
            per_query["MAP"][f"MAP@{k}"].append(ap_sum / min(n_relevant, k))

            # NDCG@k: DCG divided by the DCG of a perfect ranking (all gains equal to 1).
            idcg = sum(1.0 / np.log2(rank + 1) for rank in range(1, min(n_relevant, k) + 1))
            per_query["NDCG"][f"NDCG@{k}"].append(dcg / idcg if idcg > 0 else 0.0)

    def _averaged(family):
        # Mean over queries; 0.0 when no query contributed a value.
        return {name: np.mean(values) if values else 0.0 for name, values in per_query[family].items()}

    return _averaged("NDCG"), _averaged("MAP"), _averaged("Recall"), _averaged("P")

def format_and_print_metrics(ndcg, _map, recall, precision, k_values=K_VALUES):
    """Print one line per cut-off k with NDCG, MAP, Recall and P (4 decimals; missing entries show as 0.0000)."""
    print("\nRetrieval Metrics:")
    print("‚îÄ" * 80)
    for k in k_values:
        cells = (
            f"NDCG@{k}: {ndcg.get(f'NDCG@{k}', 0.0):.4f}",
            f"MAP@{k}: {_map.get(f'MAP@{k}', 0.0):.4f}",
            f"Recall@{k}: {recall.get(f'Recall@{k}', 0.0):.4f}",
            f"P@{k}: {precision.get(f'P@{k}', 0.0):.4f}",
        )
        print(" | ".join(cells))

def extract_all_metrics(ndcg, _map, recall, precision, k_values=K_VALUES):
    """Flatten the four metric dicts into one dict of plain floats.

    Keys are emitted per cut-off in the order NDCG@k, MAP@k, Recall@k, P@k; a metric that is
    missing from its source dict is stored as 0.0.
    """
    sources = (("NDCG", ndcg), ("MAP", _map), ("Recall", recall), ("P", precision))
    metrics = {}
    for k in k_values:
        for prefix, table in sources:
            key = f"{prefix}@{k}"
            metrics[key] = float(table.get(key, 0.0))
    return metrics

def save_results_to_csv(new_result: dict, csv_path: str):
    """Append one benchmark result row to a CSV file, creating the file when needed.

    The existing file is read back and re-written together with the new row, so a result
    with additional columns widens the table instead of corrupting it.

    Args:
        new_result: Mapping of column name to value for the new row.
        csv_path: Destination file. It may be a bare file name (written to the current
            directory); missing parent directories are created.

    Failures (unwritable path, unreadable existing file, ...) are reported with a printed
    message rather than raised, so a failed save does not abort the benchmark.
    """
    new_df = pd.DataFrame([new_result]) # Convert the new result dictionary into a one-row DataFrame
    try:
        # os.path.dirname() is '' for a bare file name and os.makedirs('') raises,
        # so only create the directory when there is one.
        target_dir = os.path.dirname(csv_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # A missing or zero-byte file has no rows to keep: start a fresh file with a header.
        if os.path.exists(csv_path) and os.path.getsize(csv_path) > 0:
            existing_df = pd.read_csv(csv_path) # Read existing data
            combined_df = pd.concat([existing_df, new_df], ignore_index=True) # Concatenate new data
        else:
            combined_df = new_df

        combined_df.to_csv(csv_path, mode='w', header=True, index=False) # (Re)write the file with its header
        print(f"üíæ Results saved to: {csv_path}")
    except Exception as e:
        print(f"‚ùå CSV save failed: {e}. Please check file permissions or path validity.")

def clean_memory():
    """Run a full garbage-collection pass so unreachable objects are released between benchmark stages."""
    gc.collect()
    return None

def should_cleanup_table(table_name, dataset_name):
    """Decide whether the PGVector table should be dropped, following CLEANUP_POLICY.

    "always_delete" -> True, "always_keep" -> False; any other policy asks the user and
    returns True only for an answer of 'y' or 'yes' (case-insensitive, surrounding spaces ignored).
    """
    # Policies that need no user interaction.
    fixed_answer = {"always_delete": True, "always_keep": False}.get(CLEANUP_POLICY)
    if fixed_answer is not None:
        return fixed_answer

    # "ask_each_time" (or anything unexpected): prompt for confirmation.
    answer = input(f"\n‚ùì Delete PGVector table '{table_name}' for dataset '{dataset_name}'? (y/n): ")
    return answer.strip().lower() in ('y', 'yes')

# -------------------- PGVector Provider Class (BINARY) --------------------

class PGVectorBinaryProvider:
    """Manages interaction with PostgreSQL and pgvector for binary embedding benchmarking.

    Typical life cycle: ``upload()`` creates the table, inserts the bit strings and builds an
    HNSW index; ``search()`` is then called once per query; ``get_search_stats()`` summarises
    the recorded latencies; ``cleanup()`` drops the table and closes the connection.
    """

    # pgvector's default and maximum for the `hnsw.ef_search` setting.
    _DEFAULT_EF_SEARCH = 40
    _MAX_EF_SEARCH = 1000

    def __init__(self, conn_params, table_name, precomputed_vectors, query_embeddings, vector_dim=1536):
        """Store the benchmark inputs; no database access happens here.

        Args:
            conn_params: Keyword arguments for ``psycopg2.connect``.
            table_name: Table to create/use. It is interpolated into SQL unquoted, so it
                must be a trusted, valid identifier.
            precomputed_vectors: Sequence of dicts with keys ``'id'`` and ``'vector'``
                (the bit string of the document).
            query_embeddings: Sequence of query bit strings, indexed by ``search``.
            vector_dim: Number of bits per vector (length of the ``bit(N)`` column).
        """
        self.conn_params = conn_params # Dictionary of PostgreSQL connection parameters
        self.table_name = table_name # Name of the table to create/use in PostgreSQL
        self.precomputed_vectors = precomputed_vectors # List of documents with their IDs and binary vector strings
        self.query_embeddings = query_embeddings # List of query binary vector strings
        self.vector_dim = vector_dim # Dimension of the vectors (length of the BIT string)
        self.conn = None # Placeholder for the database connection object
        self.upload_timings = {
            "table_creation_s": 0.0, # Time to create the database table
            "insert_time_s": 0.0, # Total time for inserting vectors
            "index_creation_s": 0.0, # Time to create the HNSW index
            "batch_details": [] # Detailed timings for each insert batch
        }
        self.search_timings = [] # Timings for individual search queries
        self._ef_search = None # hnsw.ef_search last requested on the current session (None = not set yet)

    def connect(self):
        """Establishes a connection to the PostgreSQL database using the provided parameters.

        Raises:
            Exception: whatever ``psycopg2.connect`` / ``register_vector`` raised, after
                printing a diagnostic message.
        """
        try:
            self.conn = psycopg2.connect(**self.conn_params) # Connect to DB
            register_vector(self.conn) # Register the pgvector extension with the psycopg2 connection
            self._ef_search = None # Session settings do not carry over to a new connection
            print(f"‚úÖ Connected to PostgreSQL database: {self.conn_params['database']}")
        except Exception as e:
            print(f"‚ùå PostgreSQL connection failed: {e}. Please check PG_CONN_PARAMS and server status.")
            raise # Re-raise to stop execution if connection fails

    def upload(self):
        """Uploads binary vectors to the PGVector table using the `BIT` data type.

        Drops and recreates the table, inserts the vectors in batches of ``BATCH_SIZE`` and
        builds an HNSW index with Hamming distance. A failed batch is reported, rolled back
        and skipped; a failed index build is reported and the upload continues without index.

        Returns:
            Number of vectors in the batches that were inserted successfully.

        Raises:
            Exception: re-raised when the table cannot be created.
        """
        if not self.conn: # Ensure connection is established
            self.connect()

        cur = self.conn.cursor() # Create a cursor object for executing SQL commands

        # Drop table if it exists to ensure a clean benchmark run.
        try:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
            self.conn.commit() # Commit the transaction
            print(f"üóëÔ∏è Dropped existing table: {self.table_name} (if it existed).")
        except Exception as e:
            self.conn.rollback() # Rollback on error
            print(f"‚ö†Ô∏è Could not drop table {self.table_name}: {e}")

        # Create table with BIT type instead of regular vector type for binary embeddings.
        # `bit(N)` specifies a fixed-length bit string of N bits.
        t0 = time.time() # Start timer for table creation
        try:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector") # Ensure pgvector extension is enabled
            cur.execute(f"""
                CREATE TABLE {self.table_name} (
                    id TEXT PRIMARY KEY,                       -- Document ID
                    embedding bit({self.vector_dim})           -- Binary embedding stored as a bit string
                )
            """)
            self.conn.commit() # Commit the table creation
            t1 = time.time() # End timer
            self.upload_timings["table_creation_s"] = t1 - t0
            print(f"‚úÖ Created table: {self.table_name} with column 'embedding' of type bit({self.vector_dim})")
        except Exception as e:
            print(f"‚ùå Table creation failed: {e}. Check PostgreSQL logs for details.")
            self.conn.rollback() # Rollback on error
            cur.close() # Do not leak the cursor when aborting the upload
            raise

        # Insert vectors in batches using `execute_values` for efficiency.
        batch_num = 0
        total_inserted = 0
        t_insert_start = time.time() # Start timer for total insert time

        for i in tqdm(range(0, len(self.precomputed_vectors), BATCH_SIZE), desc="Uploading Binary Vectors to PGVector"):
            batch = self.precomputed_vectors[i:i+BATCH_SIZE] # Get a batch of vectors
            batch_num += 1

            t_batch_start = time.time() # Start timer for current batch
            try:
                # Prepare data for `execute_values`: list of tuples (id, vector_string).
                values = [(v['id'], v['vector']) for v in batch]
                # `psycopg2.extras.execute_values` is much faster than individual INSERTs.
                # `template="(%s, %s::bit varying)"` specifies how the values are formatted for the SQL query.
                # `::bit varying` casts the string to a bit string type, ensuring correct interpretation.
                psycopg2.extras.execute_values(
                    cur,
                    f"INSERT INTO {self.table_name} (id, embedding) VALUES %s",
                    values,
                    template="(%s, %s::bit varying)"
                )
                self.conn.commit() # Commit the batch insert

                t_batch_end = time.time() # End timer for current batch
                batch_time = t_batch_end - t_batch_start
                total_inserted += len(batch)

                self.upload_timings["batch_details"].append({
                    "batch_num": batch_num,
                    "batch_size": len(batch),
                    "insert_time_s": batch_time
                })

            except Exception as e:
                print(f"\n‚ùå Batch {batch_num} failed to insert to PGVector: {e}")
                self.conn.rollback() # Rollback the current batch on error

        t_insert_end = time.time() # End timer for total insert time
        self.upload_timings["insert_time_s"] = t_insert_end - t_insert_start

        # Create HNSW index using Hamming distance (`bit_hamming_ops`).
        # HNSW (Hierarchical Navigable Small World) is a popular approximate nearest neighbor (ANN) algorithm.
        # `bit_hamming_ops` tells pgvector to use Hamming distance for this index type.
        # `m` and `ef_construction` are HNSW parameters influencing index quality and build time.
        print(f"\nüîß Creating HNSW index on 'embedding' column using Hamming distance ('bit_hamming_ops')...")
        t_index_start = time.time() # Start timer for index creation
        try:
            cur.execute(f"""
                CREATE INDEX ON {self.table_name}
                USING hnsw (embedding bit_hamming_ops)
                WITH (m = 16, ef_construction = 64)
            """) # SQL command to create HNSW index
            self.conn.commit() # Commit the index creation
            t_index_end = time.time() # End timer
            self.upload_timings["index_creation_s"] = t_index_end - t_index_start
            print(f"‚úÖ HNSW (Hamming) index created in {self.upload_timings['index_creation_s']:.4f}s")
        except Exception as e:
            print(f"‚ö†Ô∏è Index creation failed: {e}. Indexing is important for search performance. Proceeding without index.")
            self.conn.rollback() # Rollback on error

        cur.close() # Close the cursor

        print(f"\n‚è±Ô∏è  PGVector Upload Timing Summary:")
        print(f"    Table Creation: {self.upload_timings['table_creation_s']:.4f}s")
        print(f"    Total Data Insert: {self.upload_timings['insert_time_s']:.4f}s")
        print(f"    Index Creation: {self.upload_timings['index_creation_s']:.4f}s")
        print(f"    Total Upload Process Time: {sum([self.upload_timings['table_creation_s'], self.upload_timings['insert_time_s'], self.upload_timings['index_creation_s']]):.4f}s")

        return total_inserted # Return count of successfully inserted vectors

    def _ensure_ef_search(self, top_k):
        """Raise `hnsw.ef_search` for this session so an HNSW scan can return `top_k` rows.

        An HNSW index scan returns at most `hnsw.ef_search` rows (40 by default), so asking
        for LIMIT 100 with the default silently truncates the result list. The setting is
        committed because a later rollback would otherwise undo it. Values are clamped to
        pgvector's allowed maximum. Runs at most once per required value and connection.
        """
        wanted = min(max(int(top_k), self._DEFAULT_EF_SEARCH), self._MAX_EF_SEARCH)
        current = getattr(self, "_ef_search", None)
        if current is not None and current >= wanted:
            return

        # Remember the request even if it fails, so a broken setting is not retried
        # (and re-reported) on every single query.
        self._ef_search = wanted
        cur = self.conn.cursor()
        try:
            cur.execute("SET hnsw.ef_search = %s", (wanted,))
            self.conn.commit()
        except Exception as e:
            self.conn.rollback()
            print(f"‚ö†Ô∏è Could not set hnsw.ef_search to {wanted}: {e}. HNSW scans may return fewer than {top_k} rows.")
        finally:
            cur.close()

    def search(self, query_idx, top_k=100):
        """Performs a vector search for similar vectors using Hamming distance in pgvector.

        Args:
            query_idx: Index into ``self.query_embeddings`` of the query bit string.
            top_k: Maximum number of rows to return.

        Returns:
            dict mapping ``str(doc_id)`` to a similarity score
            ``1 - hamming_distance / vector_dim``, in ranked order (closest first).
            An empty dict is returned when the query fails; the failure is recorded with a
            latency of 0.0 in ``self.search_timings``.
        """
        if not self.conn: # Ensure connection is established
            self.connect()

        query_embedding_str = self.query_embeddings[query_idx] # Get the binary query vector string

        # Done before the timer starts so the one-off session setting does not count as latency.
        self._ensure_ef_search(top_k)

        cur = self.conn.cursor() # Create a cursor

        t_start = time.time() # Start timer for search query
        try:
            # Use `<~>` operator for Hamming distance comparison in pgvector.
            # The query bit string is sent as a plain string parameter; there is no explicit
            # cast, PostgreSQL resolves it against the `bit` column.
            # We select `id` and the Hamming distance as `hamming_dist`.
            # Results are ordered by `hamming_dist` (lower distance = more similar).
            cur.execute(f"""
                SELECT id, (embedding <~> %s) as hamming_dist
                FROM {self.table_name}
                ORDER BY embedding <~> %s
                LIMIT %s
            """, (query_embedding_str, query_embedding_str, top_k))

            results = cur.fetchall() # Fetch all results
            t_end = time.time() # End timer

            search_time = t_end - t_start
            self.search_timings.append({"search_time_s": search_time}) # Record search time

            # Convert Hamming distance to a similarity score [0, 1] for evaluation compatibility.
            # A common way is: Score = 1 - (HammingDistance / VectorDimension).
            formatted_results = {}
            for row in results:
                doc_id = str(row[0])
                dist = float(row[1])
                score = 1.0 - (dist / self.vector_dim) # Calculate similarity score
                formatted_results[doc_id] = score

            return formatted_results

        except Exception as e:
            print(f"\n‚ùå PGVector search error for query {query_idx}: {e}")
            # A failed statement leaves the transaction aborted; without a rollback every
            # following query on this connection would fail as well.
            try:
                self.conn.rollback()
            except Exception:
                pass # Connection itself is unusable; the next call will report that
            self.search_timings.append({"search_time_s": 0.0}) # Log 0 time for failed searches
            return {} # Return empty results on error
        finally:
            cur.close() # Always close the cursor

    def get_search_stats(self):
        """Calculates and returns search timing statistics for PGVector."""
        search_times = [t["search_time_s"] for t in self.search_timings] # Extract all search times
        return calculate_timing_stats(search_times) # Use helper to calculate stats

    def cleanup(self):
        """Drops the PostgreSQL table created for the benchmark and closes the database connection.

        Does nothing when no connection is open. A failed drop is reported, not raised; the
        connection is closed in either case.
        """
        if not self.conn: # If no connection, nothing to clean up
            return

        try:
            cur = self.conn.cursor() # Create cursor
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") # Drop table
            self.conn.commit() # Commit transaction
            cur.close() # Close cursor
            print(f"üßπ Deleted PGVector table: {self.table_name}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to delete PGVector table '{self.table_name}': {e}. You may need to drop it manually.")
        finally:
            self.conn.close() # Always close the database connection
            self.conn = None

# -------------------- Main Benchmark Function --------------------
# This function orchestrates the entire benchmarking process for a single dataset.

def run_benchmark_pgvector(dataset_name, pg_conn_params):
    """Run the full PGVector binary-embedding benchmark for one MAIR dataset.

    Pipeline: load docs/queries from `MAIR_COMBINED_PATH/<dataset_name>`,
    embed them with Cohere, binarize the embeddings into bit strings, upload
    them through `PGVectorBinaryProvider`, search every query, compute the
    retrieval metrics, append the result row to `CSV_PATH`, and apply the
    table cleanup policy.

    Args:
        dataset_name: Name of the dataset folder under `MAIR_COMBINED_PATH`.
        pg_conn_params: Connection parameters forwarded to the provider.

    Returns:
        The result dictionary for this run, or None when the dataset is
        missing or empty, nothing could be uploaded, or an error occurred
        during upload/search/evaluation.
    """
    print(f"\n{'='*70}")
    print(f"üöÄ Starting PGVector Binary Benchmark for Dataset: {dataset_name}")
    print(f"üìä Using Cohere {EMBEDDING_MODEL} embeddings binarized for pgvector BIT type.")
    # Fixed: the original f-string had an unbalanced parenthesis inside the
    # replacement field, which is a SyntaxError.
    print(f"{'='*70}")

    # Construct paths to the dataset within the combined MAIR directory.
    dataset_path = os.path.join(MAIR_COMBINED_PATH, dataset_name)
    docs_path = os.path.join(dataset_path, 'docs')
    queries_path = os.path.join(dataset_path, 'queries')

    # Verify dataset files exist before proceeding.
    if not os.path.exists(docs_path) or not os.path.exists(queries_path):
        print(f"‚ö†Ô∏è Dataset '{dataset_name}' not found at {dataset_path}. Skipping this benchmark run.")
        return None

    print(f"\nüì• Loading MAIR dataset: {dataset_name} from {dataset_path}...")
    corpus = {}
    docs_files = [f for f in os.listdir(docs_path) if f.endswith('.jsonl')]
    for file in docs_files:
        corpus.update(load_jsonl(os.path.join(docs_path, file)))  # Load document JSONL files

    queries = {}
    query_files = [f for f in os.listdir(queries_path) if f.endswith('.jsonl')]
    for file in query_files:
        queries.update(load_jsonl(os.path.join(queries_path, file)))  # Load query JSONL files

    qrels = extract_qrels_from_queries(queries)  # Relevance judgments embedded in the query records
    print(f"‚úÖ Loaded: {len(corpus)} documents, {len(queries)} queries, {len(qrels)} qrels.")

    if not corpus or not queries or not qrels:  # Early exit if essential data is missing
        print(f"‚ö†Ô∏è Essential data (corpus, queries, or qrels) is empty for {dataset_name}. Skipping benchmark.")
        return None

    print(f"\nüß† Generating float embeddings with Cohere {EMBEDDING_MODEL} (Dimension: {VECTOR_DIMENSION}D)...")
    # Limit the number of documents to embed and upload, based on MAX_UPLOAD_DOCS configuration.
    docs = list(corpus.items())[:MAX_UPLOAD_DOCS]
    print(f"  Processing up to {len(docs)} documents for embedding (limited by MAX_UPLOAD_DOCS={MAX_UPLOAD_DOCS}).")

    # Extract text content from corpus documents for embedding.
    texts = []
    for doc_id, doc_content in docs:
        if isinstance(doc_content, dict):
            text = None
            # Preferred fields, in priority order.
            for field in ['doc', 'text', 'content', 'body', 'passage', 'document', 'title', 'abstract']:
                if field in doc_content and doc_content[field]:
                    val = doc_content[field]
                    if isinstance(val, str) and val.strip():
                        text = val
                        break
            if not text:  # Fallback: first reasonably long string value of any field
                for key, val in doc_content.items():
                    if isinstance(val, str) and val.strip() and len(val) > 10:
                        text = val
                        break
            if not text:
                text = str(doc_content)  # Last resort, convert entire dict to string
        else:
            text = str(doc_content)

        texts.append(text if text else "document")  # Default text if extraction yields empty string

    doc_ids = [str(d[0]) for d in docs]  # Ensure all document IDs are strings

    print(f"  üìÑ Generating corpus embeddings for {len(texts)} documents...")
    t_embed_start = time.time()  # Start timer for corpus embedding

    corpus_embeddings_float = []
    # Batch processing for Cohere API calls for corpus embeddings.
    for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="Embedding corpus documents"):
        batch_texts = texts[i:i+BATCH_SIZE]
        response = cohere_client.embed(texts=batch_texts, model=EMBEDDING_MODEL, input_type=INPUT_TYPE_CORPUS)
        corpus_embeddings_float.extend(response.embeddings)

    t_embed_end = time.time()  # End timer
    embedding_time = t_embed_end - t_embed_start

    print(f"  üîç Generating query embeddings for {len(queries)} queries...")
    query_ids = list(queries.keys())
    query_texts = []
    for qid in query_ids:
        q = queries[qid]
        if isinstance(q, dict):
            text = (q.get('text') or q.get('query') or q.get('instruction') or q.get('question') or q.get('query_text') or str(q))
        else:
            text = str(q)
        query_texts.append(text if text else "query")

    t_query_start = time.time()  # Start timer for query embedding
    query_embeddings_float = []
    # Batch processing for Cohere API calls for query embeddings.
    for i in tqdm(range(0, len(query_texts), BATCH_SIZE), desc="Embedding queries"):
        batch_texts = [str(t) if t else "query" for t in query_texts[i:i+BATCH_SIZE]]
        response = cohere_client.embed(texts=batch_texts, model=EMBEDDING_MODEL, input_type=INPUT_TYPE_QUERY)
        query_embeddings_float.extend(response.embeddings)

    t_query_end = time.time()  # End timer
    query_embedding_time = t_query_end - t_query_start

    # ------------------ BINARIZATION ------------------
    print(f"\nüîÑ Binarizing float embeddings (Sign-based method: >=0 -> '1', <0 -> '0')...")
    t_bin_start = time.time()  # Start timer for binarization

    corpus_binary_strs, float_size_mb = binarize_embeddings(corpus_embeddings_float)  # Corpus embeddings -> bit strings
    query_binary_strs, _ = binarize_embeddings(query_embeddings_float)  # Query embeddings -> bit strings

    t_bin_end = time.time()  # End timer
    binarize_time = t_bin_end - t_bin_start

    # The float embeddings are no longer needed once binarized; drop them to
    # free RAM before the upload/search phase.
    del corpus_embeddings_float
    del query_embeddings_float
    gc.collect()  # Force garbage collection

    print(f"    Original Float Embeddings Size: {float_size_mb:.2f} MB")
    print(f"    Binarization to Bit Strings Time: {binarize_time:.4f}s")
    print(f"    Note: `BIT` type in PostgreSQL natively stores 1 bit per dimension, offering significant storage savings.")
    # --------------------------------------------------

    # Prepare precomputed vectors in the format expected by the PGVector provider.
    precomputed_vectors_for_upload = [
        {"id": doc_id, "vector": bin_vec_str}
        for doc_id, bin_vec_str in zip(doc_ids, corpus_binary_strs)
    ]

    print(f"\n‚è±Ô∏è  Total Embedding & Binarization Preparation Time:")
    print(f"    Float Embedding Generation: {embedding_time + query_embedding_time:.4f}s")
    print(f"    Binary String Conversion: {binarize_time:.4f}s")

    # Generate a safe table name for PostgreSQL from the dataset name.
    safe_dataset_name = dataset_name.replace('/', '_').replace('-', '_').lower()
    table_name = f"mair_bin_{safe_dataset_name}"[:63]  # PostgreSQL table names have a max length of 63 characters.

    try:
        # Initialize the PGVector provider class.
        provider = PGVectorBinaryProvider(
            conn_params=pg_conn_params,
            table_name=table_name,
            precomputed_vectors=precomputed_vectors_for_upload,
            query_embeddings=query_binary_strs,
            vector_dim=VECTOR_DIMENSION
        )

        print(f"\nüì§ Uploading binary vectors to PGVector table: {table_name}...")
        num_uploaded = provider.upload()  # Upload vectors to PostgreSQL
        if num_uploaded == 0:  # If no documents were uploaded, skip search and evaluation.
            print(f"‚ö†Ô∏è No documents uploaded to PGVector. Skipping search and evaluation.")
            # Proceed to cleanup, but don't try to search.
            should_delete = should_cleanup_table(table_name, dataset_name)
            if should_delete:
                provider.cleanup()
            else:
                print(f"üíæ Keeping PGVector table: {table_name}. Remember to delete it manually if no longer needed.")
            del provider
            clean_memory()
            return None

        print(f"\nüîç Performing binary vector search (Hamming distance) in PGVector...")
        results_per_query = {}  # Retrieved results: {query_id: {doc_id: score}}
        for i, qid in enumerate(tqdm(query_ids, desc="Searching PGVector")):
            try:
                # The provider looks queries up by position, so pass the index.
                results_per_query[qid] = provider.search(i, top_k=TOP_K_SEARCH)
            except Exception as e:
                print(f"\n‚ùå Search for query {i+1} failed: {e}")
                results_per_query[qid] = {}  # Log empty results for failed queries

        search_stats = provider.get_search_stats()  # Get search timing statistics
        print(f"\n‚è±Ô∏è  PGVector Search Timing Summary:")
        print(f"    Total search time across all queries: {search_stats['total']:.4f}s")
        print(f"    Mean query search time: {search_stats['mean']:.4f}s ({search_stats['mean']*1000:.2f}ms)")
        print(f"    Median query search time: {search_stats['median']:.4f}s")
        print(f"    Min query search time: {search_stats['min']:.4f}s, Max query search time: {search_stats['max']:.4f}s")

        print(f"\nüìä Evaluating retrieval quality (NDCG, MAP, Recall, Precision @K={K_VALUES})...")
        ndcg, _map, recall, precision = calculate_retrieval_metrics(results_per_query, qrels, K_VALUES)  # Calculate metrics
        format_and_print_metrics(ndcg, _map, recall, precision)  # Print formatted metrics

        metrics = extract_all_metrics(ndcg, _map, recall, precision, K_VALUES)  # Flatten metrics into a dictionary
        dataset_category = get_dataset_category(dataset_name)  # Category label for the result row

        result_entry = {  # Compile all results and metadata into a dictionary
            "Dataset": dataset_name,
            "Category": dataset_category,
            "Provider": "pgvector_binary",  # Indicate provider
            "Num_Corpus": len(corpus),
            "Num_Uploaded": num_uploaded,
            "Num_Queries": len(queries),
            "Embedding_Model": EMBEDDING_MODEL,
            "Vector_Dimension": VECTOR_DIMENSION,
            "Embedding_Type": "Binary (1-bit)",  # Explicitly state embedding type
            "Batch_Size": BATCH_SIZE,
            "Cleanup_Policy": CLEANUP_POLICY,
            "Embedding_Generation_Time_s": round(embedding_time, 4),  # Time to generate corpus float embeddings
            "Query_Embedding_Generation_Time_s": round(query_embedding_time, 4),  # Time to generate query float embeddings
            "Binarization_Time_s": round(binarize_time, 4),  # Time to convert to binary strings
            "Upload_Table_Creation_s": round(provider.upload_timings["table_creation_s"], 4),
            "Upload_Insert_Total_s": round(provider.upload_timings["insert_time_s"], 4),
            "Upload_Index_Creation_s": round(provider.upload_timings["index_creation_s"], 4),
            "Upload_Total_s": round(
                provider.upload_timings["table_creation_s"] +
                provider.upload_timings["insert_time_s"] +
                provider.upload_timings["index_creation_s"], 4
            ),
            "Search_Total_s": round(search_stats['total'], 4),
            "Search_Mean_s": round(search_stats['mean'], 4),
            "Search_Median_s": round(search_stats['median'], 4),
            "Search_Min_s": round(search_stats['min'], 4),
            "Search_Max_s": round(search_stats['max'], 4),
            "Search_Std_s": round(search_stats['std'], 4),
            "Search_Mean_ms": round(search_stats['mean'] * 1000, 2),  # Mean search time in milliseconds
            "Index_Type": "HNSW (bit_hamming_ops)",  # Type of index used
            "Similarity_Metric": "Hamming"  # Similarity metric used for search
        }

        result_entry.update(metrics)  # Add the retrieval metrics to the result dictionary
        save_results_to_csv(result_entry, CSV_PATH)  # Save this run's results to the CSV file

        should_delete = should_cleanup_table(table_name, dataset_name)  # Determine cleanup action based on policy
        if should_delete:
            provider.cleanup()  # Delete table and close connection
        else:
            print(f"üíæ Keeping PGVector table: {table_name}. Remember to delete it manually if no longer needed to avoid resource consumption.")

        del provider  # Explicitly delete provider object
        del results_per_query  # Clear search results
        clean_memory()  # Force garbage collection

        print(f"\n‚úÖ PGVector Binary benchmark completed successfully for dataset: {dataset_name}")
        return result_entry  # Return the result entry for this dataset

    except Exception as e:
        print(f"‚ùå An error occurred during the PGVector benchmark for dataset {dataset_name}: {e}")
        import traceback
        traceback.print_exc()  # Print full traceback for debugging
        clean_memory()
        return None  # Return None on error
    finally:
        # Clean up large objects from memory regardless of success or failure.
        del precomputed_vectors_for_upload
        del corpus
        clean_memory()


# -------------------- MAIN EXECUTION BLOCK --------------------
# This block handles the overall flow of the benchmark, including API key checks,
# dataset discovery, and the interactive loop for running benchmarks.

print("\n" + "="*70)
print("üéØ STARTING MAIR BENCHMARK - PGVECTOR BINARY EDITION")
print("="*70)

print("\nüîë Checking API Keys and Database Connection Status...")

# Check Cohere API key availability.
if not COHERE_API_KEY:
    print("\n‚ùå Cohere API key required for embedding generation!")
    print("   Please add your API key to Colab Secrets (look for the key icon üîë on the left sidebar)")
    print("   and name it 'COHERE_API_KEY'. Alternatively, set it as an environment variable.")
    exit(1)  # Stop execution if key is missing

import cohere
cohere_client = cohere.Client(COHERE_API_KEY)  # Initialize Cohere client.
print(f"‚úÖ Cohere client initialized for embedding generation (model: {EMBEDDING_MODEL}).")

# Test PostgreSQL connection: connect, read the server version and the
# installed pgvector extension version (if any), then disconnect.
try:
    test_conn = psycopg2.connect(**PG_CONN_PARAMS)  # Attempt to connect to DB
    register_vector(test_conn)  # Register vector type
    cur = test_conn.cursor()  # Create cursor
    cur.execute("SELECT version();")  # Get PostgreSQL version
    pg_version = cur.fetchone()[0]
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector';")  # Get pgvector version
    pgvector_version = cur.fetchone()  # None when the extension is not installed in this database
    cur.close()  # Close cursor
    test_conn.close()  # Close connection
    print(f"‚úÖ PostgreSQL connection successful to '{PG_CONN_PARAMS['database']}'.")
    print(f"   PostgreSQL version: {pg_version.split(',')[0]}")
    if pgvector_version:
        print(f"   pgvector extension version: v{pgvector_version[0]}")
    else:
        # Report the database that was actually checked rather than a hard-coded name.
        print(f"   ‚ö†Ô∏è pgvector extension not found in '{PG_CONN_PARAMS['database']}'. Please check installation in Step 1.")
except Exception as e:
    print(f"‚ùå PostgreSQL connection failed: {e}. Please ensure PostgreSQL is running and configured correctly.")
    exit(1)  # Stop execution if DB connection fails

# Discover available MAIR datasets.
print("\nüìä Discovering MAIR datasets from the combined path...")
mair_datasets, datasets_by_category, dataset_sizes = get_mair_datasets()  # Use helper function

if not mair_datasets:  # If no datasets are found, inform the user and exit.
    print("‚ùå No MAIR datasets found in the combined path!")
    print(f"   Please ensure datasets are downloaded and combined into: {MAIR_COMBINED_PATH}")
    print("   Run the previous notebook sections to prepare the datasets.")
    exit(1)

print(f"\nüìÇ Available MAIR datasets for benchmarking ({len(mair_datasets)} total):")
dataset_index = 1
dataset_map = {}  # Map the displayed menu number to the actual dataset name

# Display datasets grouped by category and sorted by size for easy selection.
for category in sorted(datasets_by_category.keys()):
    datasets_in_category = datasets_by_category[category]
    if datasets_in_category:
        # Sort datasets within each category by document count (descending).
        sorted_datasets_in_category = sorted(datasets_in_category, key=lambda d: dataset_sizes.get(d, 0), reverse=True)
        print(f"\n  üìÅ {category} ({len(sorted_datasets_in_category)} datasets):")
        for dataset in sorted_datasets_in_category:
            size_str = format_size(dataset_sizes.get(dataset, 0))  # Format document count for display
            print(f"     {dataset_index}. {dataset} ({size_str} docs)")
            dataset_map[dataset_index] = dataset
            dataset_index += 1

all_results = []  # List to store all benchmark results.

# Banner shown before the interactive benchmarking loop starts.
print("\n" + "="*70)
print("üöÄ READY TO START PGVECTOR BINARY EMBEDDINGS BENCHMARK!")
# Fixed: this line needs the f-prefix, otherwise the literal text "{CSV_PATH}" is printed.
print(f"   Choose datasets to test. Results will be saved to: {CSV_PATH}")
print("="*70)

# Interactive loop: the user picks dataset numbers from the menu printed
# above until they type 'stop'. Each successful run appends one result row.
while True:
    choice = input(f"\n‚û°Ô∏è Enter dataset number (1-{len(dataset_map)}) to benchmark, or type 'stop' to finish: ").strip().lower()
    if choice == "stop":
        break  # Exit loop if user types 'stop'

    try:
        idx = int(choice)
        if idx in dataset_map:
            dataset_name = dataset_map[idx]  # Get the selected dataset name
        else:
            print(f"‚ö†Ô∏è Invalid number '{choice}'. Please choose a number between 1 and {len(dataset_map)} from the list above.")
            continue  # Ask for input again
    except ValueError:
        print("‚ö†Ô∏è Invalid input. Please enter a number or 'stop'.")
        continue  # Ask for input again

    # Run the benchmark for the selected dataset.
    result = run_benchmark_pgvector(dataset_name, PG_CONN_PARAMS)
    if result:  # run_benchmark_pgvector returns None for skipped or failed runs.
        all_results.append(result)

# Final Summary after all chosen datasets are processed.
if all_results:
    print(f"\n{'='*70}")
    print("üèÅ PGVECTOR BINARY BENCHMARK COMPLETE! (Overall Summary)")
    # Fixed: the original f-string had an unbalanced parenthesis inside the
    # replacement field, which is a SyntaxError.
    print(f"{'='*70}")

    df = pd.DataFrame(all_results)  # Create a DataFrame from collected results.
    print(f"\nüíæ All detailed results have been saved to: {CSV_PATH}")

    print("\nüìä Summary Table of Key Metrics (per dataset):")
    # Display key retrieval metrics and performance times.
    summary_cols = ['Dataset', 'Category', 'NDCG@10', 'MAP@10', 'Recall@100', 'Search_Mean_ms', 'Upload_Total_s']
    available_cols = [col for col in summary_cols if col in df.columns]  # Only show columns that exist
    print(df[available_cols].to_string(index=False))

    print("\nüìà Average Performance Across All Benchmarked Datasets:")
    # Calculate averages for numeric metrics.
    numeric_cols = ['NDCG@10', 'MAP@10', 'Recall@100', 'Search_Mean_ms', 'Upload_Total_s']
    valid_cols = [col for col in numeric_cols if col in df.columns]
    if valid_cols:
        avg_vals = df[valid_cols].mean()
        print("\n  Retrieval Quality (Averages):")
        for col in ['NDCG@10', 'MAP@10', 'Recall@100']:
            if col in valid_cols:
                print(f"    {col}: {avg_vals[col]:.4f}")
        print("\n  Performance (Averages):")
        if 'Search_Mean_ms' in valid_cols:
            print(f"    Average Search Time per Query: {avg_vals['Search_Mean_ms']:.2f}ms")
        if 'Upload_Total_s' in valid_cols:
            print(f"    Average Total Upload Time: {avg_vals['Upload_Total_s']:.2f}s")

    print("\n" + "="*70)
    print("‚ú® PGVector Binary Benchmark Concluded Successfully!")
    print("   Detailed results are available in the CSV file for further analysis.")
    print("   Remember to manage your cloud resources if you chose to keep tables.")
    print("="*70)
else:
    print("\n‚ö†Ô∏è No results to display. Please ensure you successfully ran at least one benchmark for a dataset.")

# Vector (Non-Binary) Benchmarking Based on MAIR Dataset

## Vector (Non-Binary) Search in Moorcheh + Pinecone (with Cohere) + Elasticsearch

In [None]:
# ============================================================
# MAIR Benchmark - COMPLETE: Moorcheh vs Pinecone vs Elasticsearch
# Apples-to-Apples Comparison: Embed ONCE, Test ALL providers
# ============================================================

# 1Ô∏è‚É£ Install Necessary Libraries
# This section ensures all required Python packages are installed in the Colab environment.
# The installation might take a few moments, especially on the first run.
!pip install moorcheh-sdk cohere pinecone elasticsearch datasets

import os
import gc
import json
import time
import statistics
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

# 2Ô∏è‚É£ Environment Setup (Colab / Local Compatibility)
# This section configures paths and retrieves API keys, adapting to either Google Colab or a local environment.

# DRIVE_PATH: Defines where benchmark results will be saved. Default is the current directory.
# If running in Google Colab, it will be automatically updated to a path in your mounted Google Drive.
DRIVE_PATH = "."

# MAIR_COMBINED_PATH: **User Customizable Path**
# This is the path to your combined MAIR datasets in Google Drive. It should match the SAVE_PATH from the
# 'Combine MAIR Docs and Queries' step. Ensure this path is correct for your setup.
MAIR_COMBINED_PATH = "/content/gdrive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Combined"


def get_secret_or_none(getter, name):
    """Look up a secret without crashing when it is missing.

    Colab's `userdata.get` raises (instead of returning None) when a secret
    does not exist or the notebook has not been granted access to it. Several
    secrets here are optional (e.g. Elasticsearch API key vs. username and
    password), so a failed lookup must not abort the cell.

    Args:
        getter: Callable taking the secret name and returning its value.
        name: Name of the secret / environment variable.

    Returns:
        The value returned by `getter`, or, if `getter` raises, the value of
        the environment variable `name` (None when that is unset too).
    """
    try:
        return getter(name)
    except Exception as e:
        print(f"   Secret '{name}' not available from Colab Secrets ({type(e).__name__}); falling back to environment variable.")
        return os.environ.get(name)


try:
    # This block attempts to configure for Google Colab, leveraging its `drive` and `userdata` (Secrets) features.
    from google.colab import drive, userdata as colab_userdata

    try:
        # Mount Google Drive to allow Colab to access your files. If you've mounted it recently,
        # Colab might remember the authentication. (No `force_remount` is requested here.)
        drive.mount('/content/gdrive')
        print("‚úÖ Google Drive mounted successfully")
    except Exception as e:
        print(f"‚ö†Ô∏è Drive mount warning: {e}")
        print("Continuing without Drive mount...")

    # DRIVE_PATH: **User Customizable Path**
    # If running in Colab, results will be saved here within your Google Drive.
    # You can customize this path (e.g., '/content/gdrive/MyDrive/MyProject/BenchmarkResults')
    # to organize your benchmark outputs effectively.
    DRIVE_PATH = '/content/gdrive/MyDrive/Moorcheh/Benchmark_Results/MAIR.Moorcheh.vs.Elasticsearch.Results'
    os.makedirs(DRIVE_PATH, exist_ok=True)  # Creates the directory if it doesn't exist.
    print(f"‚úÖ Running in Colab. Results will be saved to: {DRIVE_PATH}")

    # Retrieve API keys from Colab secrets. **User Action Required**:
    # Add your API keys to Colab's "Secrets" panel (key icon on the left sidebar of the notebook).
    # Name your secrets EXACTLY as follows:
    # - `MOORCHEH_API_KEY` for Moorcheh
    # - `COHERE_API_KEY` for Cohere (essential for embedding generation)
    # - `PINECONE_API_KEY` for Pinecone
    # - `ELASTIC_URL` for Elasticsearch endpoint (e.g., 'https://your-es-cluster.es.io:9243')
    # - `ELASTIC_API_KEY` for Elasticsearch API Key (preferred) OR
    # - `ELASTIC_USERNAME` and `ELASTIC_PASSWORD` for Elasticsearch Basic Auth.
    # Secrets that are not defined resolve to None (or to the same-named environment variable).
    api_keys = {
        'moorcheh': get_secret_or_none(colab_userdata.get, 'MOORCHEH_API_KEY'),
        'cohere': get_secret_or_none(colab_userdata.get, 'COHERE_API_KEY'),
        'pinecone': get_secret_or_none(colab_userdata.get, 'PINECONE_API_KEY'),
        'elasticsearch': {
            'url': get_secret_or_none(colab_userdata.get, 'ELASTIC_URL'),
            'api_key': get_secret_or_none(colab_userdata.get, 'ELASTIC_API_KEY'),
            'username': get_secret_or_none(colab_userdata.get, 'ELASTIC_USERNAME'),
            'password': get_secret_or_none(colab_userdata.get, 'ELASTIC_PASSWORD')
        }
    }
    # If EVERY Elasticsearch credential is empty, replace the dict with None to indicate missing config.
    # This simplifies checks later for whether Elasticsearch is configured.
    if not any(api_keys['elasticsearch'].values()):
        api_keys['elasticsearch'] = None

except ImportError:
    # This block executes if not in Google Colab (e.g., a local Python environment).
    DRIVE_PATH = "."  # Results will be saved in the current working directory.
    # In a local environment, API keys are read from environment variables.
    # **User Action Required**: Set these environment variables before running the script.
    # - `MOORCHEH_API_KEY`
    # - `COHERE_API_KEY`
    # - `PINECONE_API_KEY`
    # - `ELASTIC_URL` (e.g., "http://localhost:9200" for local ES)
    # - `ELASTIC_API_KEY` OR `ELASTIC_USERNAME`, `ELASTIC_PASSWORD`
    api_keys = {
        'moorcheh': os.environ.get('MOORCHEH_API_KEY'),
        'cohere': os.environ.get('COHERE_API_KEY'),
        'pinecone': os.environ.get('PINECONE_API_KEY'),
        'elasticsearch': {
            'url': os.environ.get('ELASTIC_URL') or "http://localhost:9200",  # Defaults to localhost for local setup
            'api_key': os.environ.get('ELASTIC_API_KEY'),
            'username': os.environ.get('ELASTIC_USERNAME') or "elastic",
            'password': os.environ.get('ELASTIC_PASSWORD')
        }
    }
    # Kept for symmetry with the Colab branch. Because `url` and `username` always fall back to
    # non-empty defaults above, this condition cannot be true in the local branch.
    if not any(api_keys['elasticsearch'].values()):
        api_keys['elasticsearch'] = None
    print("‚ö†Ô∏è Not running in Google Colab. Saving results locally. Ensure environment variables are set.")

# 3Ô∏è‚É£ Benchmark Configuration
# These parameters control various aspects of the benchmark. Users can adjust these values
# to customize the benchmark's behavior, performance, and resource usage.

# TOP_K_SEARCH: **User Customizable Value**
# The number of top-ranked results to retrieve from the vector database for each query.
# A higher value might improve recall but generally increases search latency and resource usage.
TOP_K_SEARCH = 100

# K_VALUES: **User Customizable Value**
# A list of 'k' values at which retrieval metrics (NDCG, MAP, Recall, Precision) will be calculated.
# These values define the cut-off points for evaluation (e.g., NDCG@1, MAP@10, Recall@100).
# You can add or remove values based on your evaluation needs.
K_VALUES = [1, 3, 5, 10, 100]

# MAX_UPLOAD_DOCS: **User Customizable Value**
# Limits the number of documents uploaded to vector databases. This is crucial for managing costs
# and execution time, especially with very large datasets. Set to a lower number (e.g., 10000)
# for quick tests or a very high number (e.g., 700000) for comprehensive runs.
# NOTE(review): `None` is presumably accepted as "no limit" (a `[:None]` slice keeps everything);
# confirm against the code that consumes this value before relying on it.
MAX_UPLOAD_DOCS = 700000

# BATCH_SIZE: **User Customizable Value**
# The number of embeddings to process or upload in a single API request/batch.
# Adjusting this value can significantly impact performance, memory usage, and API rate limits.
# Larger batches are generally faster due to reduced overhead but consume more RAM.
BATCH_SIZE = 100

# EMBEDDING_MODEL: **User Customizable Value**
# The Cohere model used to generate the initial dense float embeddings.
# 'embed-v4.0' is recommended for its performance and higher dimensionality. Other options include
# 'embed-english-v3.0', 'embed-multilingual-v3.0', etc. Changing this will require adjusting VECTOR_DIMENSION.
EMBEDDING_MODEL = "embed-v4.0"

# RERANK_MODEL: **User Customizable Value (Pinecone only)**
# The Cohere reranking model used to re-order the initial search results from Pinecone.
# 'rerank-english-v3.0' or 'rerank-multilingual-v3.0' can be used. Reranking adds latency on top
# of the vector search.
RERANK_MODEL = "rerank-english-v3.0" # Alternative: 'rerank-multilingual-v3.0' for non-English corpora.

# INPUT_TYPE_CORPUS: Specifies the input type for corpus documents to Cohere's embedding model.
# Passed as `input_type` when embedding documents ('search_document').
INPUT_TYPE_CORPUS = "search_document"

# INPUT_TYPE_QUERY: Specifies the input type for queries to Cohere's embedding model.
# Passed as `input_type` when embedding queries ('search_query').
INPUT_TYPE_QUERY = "search_query"

# VECTOR_DIMENSION: **User Customizable Value (Must match EMBEDDING_MODEL)**
# The dimensionality of the generated embeddings. This value is critical and MUST match the output
# dimension of the `EMBEDDING_MODEL` you choose. Cohere's 'embed-v4.0' has 1536 dimensions;
# 'embed-v3.0' has 1024 dimensions. Incorrect dimension will lead to errors in vector databases.
VECTOR_DIMENSION = 1536

# CSV_PATH: **User Customizable Path**
# The full path where the final benchmark results will be saved in CSV format.
# The file lives under DRIVE_PATH (Google Drive in Colab, the current directory locally).
CSV_PATH = os.path.join(DRIVE_PATH, "MAIR.Moorcheh.vs.Elasticsearch.Cohere.V4.csv")

# 4Ô∏è‚É£ Cleanup Policy
# The user chooses once, up front, what happens to the vector database
# resources (namespaces/indexes) after each benchmark run. Unrecognized or
# empty answers fall back to asking after every dataset.
CLEANUP_POLICY_MAP = {"1": "ask_each_time", "2": "always_delete", "3": "always_keep"}

print("\n".join([
    "\nüßπ Namespace cleanup policy options:",
    "  1) ask_each_time  -> Prompt after each dataset (default)",
    "  2) always_delete  -> Automatically delete after benchmarking",
    "  3) always_keep    -> Never delete (keep for future searches)",
]))

cleanup_choice = input("Choose policy [1/2/3] (default 1): ").strip() or "1"
CLEANUP_POLICY = CLEANUP_POLICY_MAP.get(cleanup_choice, "ask_each_time")
print(f"‚úÖ Selected cleanup policy: {CLEANUP_POLICY}")

# 5Ô∏è‚É£ Helper Functions
# These functions are designed to assist in loading and processing MAIR datasets,
# extracting qrels, calculating retrieval metrics, and managing benchmark results.

def load_jsonl(filepath):
    """Load a JSONL (JSON Lines) file into a dictionary keyed by item ID.

    Each non-blank line is parsed as JSON. The key is the first usable value
    among the fields `_id`, `id`, `query_id`, `doc_id`, converted to `str`.
    A numeric ID of 0 counts as a real ID. Items without a usable ID (and
    items that are not JSON objects) are stored under their position, i.e.
    `str(len(data))` at the time they are read.

    Args:
        filepath: Path of the JSONL file to read (UTF-8).

    Returns:
        Dict mapping string IDs to parsed items. On any error (missing file,
        malformed JSON, ...) a warning is printed and whatever was loaded up
        to that point is returned.
    """
    id_fields = ('_id', 'id', 'query_id', 'doc_id')

    def usable(value):
        # Truthy values are IDs; so is a numeric zero (but not False/None/"").
        return bool(value) or (isinstance(value, (int, float)) and not isinstance(value, bool))

    data = {}
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # Skip blank lines
                item = json.loads(line)
                item_id = None
                if isinstance(item, dict):
                    for field in id_fields:
                        candidate = item.get(field)
                        if usable(candidate):
                            item_id = candidate
                            break
                key = str(item_id) if item_id is not None else str(len(data))
                data[key] = item
        print(f"   Loaded {len(data)} items from {os.path.basename(filepath)}")
    except Exception as e:
        print(f"‚ö†Ô∏è Error loading {filepath}: {e}")
    return data

def extract_qrels_from_queries(queries):
    """Extract qrels (query-relevance judgments) embedded in query records.

    For every query that is a dict, the first non-empty field among
    `labels`, `relevance`, `qrels` is read. Supported shapes:

    * list of dicts: document ID from `id` (or `doc_id`), relevance from
      `score` (default 1);
    * list of plain values: each value is a relevant document ID with score 1;
    * dict: maps document ID to score.

    A numeric document ID of 0 is treated as a valid ID. Queries that are not
    dicts, or that carry no labels, are omitted from the result.

    Args:
        queries: Mapping of query ID to query record.

    Returns:
        Dict of the form `{query_id_str: {doc_id_str: score}}`.
    """
    def usable(value):
        # Truthy values are IDs; so is a numeric zero (but not False/None/"").
        return bool(value) or (isinstance(value, (int, float)) and not isinstance(value, bool))

    qrels = {}
    for qid, q in queries.items():
        if not isinstance(q, dict):
            continue
        labels = q.get('labels') or q.get('relevance') or q.get('qrels')
        if not labels:
            continue

        judged = {}
        qrels[str(qid)] = judged
        if isinstance(labels, list):
            for label_item in labels:
                if isinstance(label_item, dict):
                    doc_id = label_item.get('id')
                    if not usable(doc_id):
                        doc_id = label_item.get('doc_id')
                    if usable(doc_id):
                        judged[str(doc_id)] = label_item.get('score', 1)
                else:
                    judged[str(label_item)] = 1
        elif isinstance(labels, dict):
            for doc_id, score in labels.items():
                judged[str(doc_id)] = score
    return qrels

DATASET_CATEGORIES = {
    "Legal & Regulatory": ["ACORDAR", "AILA2019-Case", "AILA2019-Statutes", "CUAD", "LeCaRDv2", "LegalQuAD", "REGIR-EU2UK", "REGIR-UK2EU"],
    "Medical & Clinical": ["CliniDS-2014", "CliniDS-2015", "CliniDS-2016", "ClinicalTrials-2021", "ClinicalTrials-2022", "ClinicalTrials-2023", "NFCorpus", "PrecisionMedicine", "Genomics-AdHoc"],
    "Code & Programming": ["APPS", "CodeEditSearch", "CodeSearchNet", "Conala", "HumanEval-X", "LeetCode", "MBPP", "RepoBench", "SWE-Bench-Lite"],
    "Financial": ["ConvFinQA", "FiQA", "FinQA", "FinanceBench", "HC3Finance"],
    "Academic & Scientific": ["ArguAna", "LitSearch", "ProofWiki-Proof", "ProofWiki-Reference", "Competition-Math"],
    "Conversational & Dialog": ["CAsT-2019", "CAsT-2020", "CAsT-2021", "CAsT-2022", "ProCIS-Dialog", "ProCIS-Turn", "SParC", "Quora"],
    "News & Social Media": ["ChroniclingAmericaQA", "Microblog-2011", "Microblog-2012", "Microblog-2013", "Microblog-2014", "News21"],
    "API Documentation": ["Apple", "FoodAPI", "HuggingfaceAPI", "PytorchAPI"],
    "Others": ["BSARD", "BillSum", "CARE", "CPCD", "CQADupStack", "DD", "ELI5", "ExcluIR", "FairRanking", "Fever", "GerDaLIR", "IFEval", "InstructIR", "MISeD", "Monant", "NTCIR", "NeuCLIR", "NevIR", "PointRec", "ProductSearch_2023", "QuanTemp", "Robust04"]
}

def get_dataset_category(dataset_name):
    """Return the display category that lists ``dataset_name``.

    Categories are checked in the order they appear in ``DATASET_CATEGORIES``;
    a name found in none of the lists is reported as "Others".
    """
    matching = (
        label
        for label, members in DATASET_CATEGORIES.items()
        if dataset_name in members
    )
    return next(matching, "Others")

def get_dataset_size(dataset_name):
    """Count the non-blank lines across a dataset's ``docs/*.jsonl`` files.

    Returns 0 when the ``docs`` directory does not exist. If reading fails
    part-way through, the error is printed and the count accumulated so far
    is returned.
    """
    docs_dir = os.path.join(MAIR_COMBINED_PATH, dataset_name, 'docs')
    if not os.path.exists(docs_dir):
        return 0

    total = 0
    try:
        for entry in os.listdir(docs_dir):
            if not entry.endswith('.jsonl'):
                continue
            with open(os.path.join(docs_dir, entry), 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    # Whitespace-only lines are not counted as documents.
                    if raw_line.strip():
                        total += 1
    except Exception as e:
        print(f"‚ö†Ô∏è Error counting docs in {dataset_name}: {e}")
    return total

def format_size(num):
    """Format a count with a K (thousands) or M (millions) suffix.

    Values below 1,000 are returned as ``str(num)``; larger values are shown
    with one decimal place. A count whose thousands value would round up to
    "1000.0K" (for example 999,999) is promoted to the M suffix instead.
    """
    # Decide on the *rounded* thousands value: checking only `num >= 1_000_000`
    # lets counts just under a million render as the out-of-range "1000.0K".
    if num >= 1_000_000 or round(num / 1_000, 1) >= 1_000:
        return f"{num/1_000_000:.1f}M"
    if num >= 1_000:
        return f"{num/1_000:.1f}K"
    return str(num)

def get_mair_datasets():
    """Scan ``MAIR_COMBINED_PATH`` for usable datasets.

    A sub-directory counts as a dataset when it has both a ``docs`` and a
    ``queries`` folder and each contains at least one ``.jsonl`` file. Entries
    whose name starts with ``.`` or ``_`` are ignored.

    Returns:
        tuple: ``(all_datasets, datasets_by_category, dataset_sizes)`` - the
        list of dataset names found, a dict of category -> dataset names, and
        a dict of dataset name -> document count. If the scan raises, the
        error is printed and whatever was collected so far is returned.
    """
    by_category = {label: [] for label in DATASET_CATEGORIES}
    discovered = []
    sizes = {}

    if not os.path.exists(MAIR_COMBINED_PATH):
        return discovered, by_category, sizes

    print(f"üîç Scanning: {MAIR_COMBINED_PATH}")
    try:
        for entry in sorted(os.listdir(MAIR_COMBINED_PATH)):
            if entry.startswith(('.', '_')):
                continue
            root = os.path.join(MAIR_COMBINED_PATH, entry)
            if not os.path.isdir(root):
                continue

            docs_dir = os.path.join(root, 'docs')
            queries_dir = os.path.join(root, 'queries')
            if not (os.path.exists(docs_dir) and os.path.exists(queries_dir)):
                continue

            has_docs = any(name.endswith('.jsonl') for name in os.listdir(docs_dir))
            has_queries = any(name.endswith('.jsonl') for name in os.listdir(queries_dir))
            if not (has_docs and has_queries):
                continue

            discovered.append(entry)
            label = get_dataset_category(entry)
            by_category[label].append(entry)
            count = get_dataset_size(entry)
            sizes[entry] = count
            print(f"   ‚úÖ {entry} [Category: {label}] - {format_size(count)} docs")
    except Exception as e:
        print(f"   ‚ùå Error scanning directory: {e}")
    return discovered, by_category, sizes

def calculate_timing_stats(timing_list):
    """Summarise a list of numeric timings.

    Returns a dict with the keys ``mean``, ``median``, ``min``, ``max``,
    ``std`` and ``total``. An empty list yields 0.0 for every key, and a
    single-element list reports ``std`` as 0.0 because a sample standard
    deviation needs at least two values.
    """
    if not timing_list:
        return dict.fromkeys(("mean", "median", "min", "max", "std", "total"), 0.0)

    summary = {}
    summary["mean"] = statistics.mean(timing_list)
    summary["median"] = statistics.median(timing_list)
    summary["min"] = min(timing_list)
    summary["max"] = max(timing_list)
    if len(timing_list) > 1:
        summary["std"] = statistics.stdev(timing_list)
    else:
        summary["std"] = 0.0
    summary["total"] = sum(timing_list)
    return summary

def calculate_retrieval_metrics(retrieved_results, qrels, k_values=K_VALUES):
    """Compute mean NDCG@k, MAP@k, Recall@k and P@k over the evaluated queries.

    Relevance is binary: every doc id present in ``qrels[qid]`` counts as
    relevant, whatever score is stored for it. The ranking for a query is the
    iteration order of ``retrieved_results[qid]``. Queries missing from
    ``qrels`` or with no relevant documents are skipped.

    Returns:
        tuple: ``(ndcg, map, recall, precision)`` dicts keyed by
        ``"NDCG@k"``, ``"MAP@k"``, ``"Recall@k"`` and ``"P@k"``; a metric with
        no evaluated queries is reported as 0.0.
    """
    ndcg_at = {f"NDCG@{k}": [] for k in k_values}
    map_at = {f"MAP@{k}": [] for k in k_values}
    recall_at = {f"Recall@{k}": [] for k in k_values}
    precision_at = {f"P@{k}": [] for k in k_values}

    for qid, ranking in retrieved_results.items():
        if qid not in qrels:
            continue
        relevant = set(qrels[qid])
        ranked_ids = list(ranking)
        if not relevant:
            continue
        n_relevant = len(relevant)

        for k in k_values:
            cutoff = ranked_ids[:k]
            n_hits = len(set(cutoff) & relevant)

            recall_at[f"Recall@{k}"].append(n_hits / n_relevant)
            precision_at[f"P@{k}"].append(n_hits / k if k > 0 else 0.0)

            # One pass over the cutoff accumulates both the precision-at-hit
            # sum (for AP) and the discounted gain (for DCG).
            hits_seen = 0
            precision_sum = 0.0
            dcg = 0.0
            for rank, doc_id in enumerate(cutoff, 1):
                if doc_id not in relevant:
                    continue
                hits_seen += 1
                precision_sum += hits_seen / rank
                dcg += 1.0 / np.log2(rank + 1)

            # AP is normalised by the best achievable number of hits within k.
            map_at[f"MAP@{k}"].append(precision_sum / min(n_relevant, k))

            ideal_dcg = sum(1.0 / np.log2(pos + 2) for pos in range(min(n_relevant, k)))
            ndcg_at[f"NDCG@{k}"].append(dcg / ideal_dcg if ideal_dcg > 0 else 0.0)

    def _average(buckets):
        return {name: np.mean(values) if values else 0.0 for name, values in buckets.items()}

    return _average(ndcg_at), _average(map_at), _average(recall_at), _average(precision_at)

def format_and_print_metrics(ndcg, _map, recall, precision, k_values=K_VALUES):
    """Print one line per K with NDCG, MAP, Recall and P to four decimals.

    Each argument is a dict keyed by ``"<Metric>@<k>"``; a missing key is
    printed as 0.0000.
    """
    print("\nRetrieval Metrics:")
    print("‚îÄ" * 80)
    for k in k_values:
        cells = [
            f"NDCG@{k}: {ndcg.get(f'NDCG@{k}', 0.0):.4f}",
            f"MAP@{k}: {_map.get(f'MAP@{k}', 0.0):.4f}",
            f"Recall@{k}: {recall.get(f'Recall@{k}', 0.0):.4f}",
            f"P@{k}: {precision.get(f'P@{k}', 0.0):.4f}",
        ]
        print(" | ".join(cells))

def extract_all_metrics(ndcg, _map, recall, precision, k_values=K_VALUES):
    """Flatten the four metric dicts into one dict of plain floats.

    Args:
        ndcg, _map, recall, precision: Dicts keyed by ``"NDCG@k"``,
            ``"MAP@k"``, ``"Recall@k"`` and ``"P@k"`` respectively.
        k_values: The K cut-offs to extract.

    Returns:
        dict: For every k, the keys ``NDCG@k``, ``MAP@k``, ``Recall@k`` and
        ``P@k`` (in that order) mapped to ``float`` values. A key that is
        missing or whose value is ``None`` is reported as 0.0.
    """
    sources = (("NDCG", ndcg), ("MAP", _map), ("Recall", recall), ("P", precision))
    metrics = {}
    for k in k_values:
        for prefix, scores in sources:
            key = f"{prefix}@{k}"
            value = scores.get(key)
            # Treat None the same for every metric; previously only NDCG was
            # guarded and a None MAP/Recall/P value raised TypeError in float().
            metrics[key] = float(value) if value is not None else 0.0
    return metrics

def save_results_to_csv(new_result: dict, csv_path: str):
    """Append one benchmark result row to a CSV file, creating it if needed.

    When the file already exists it is read back, the new row is concatenated
    (so differing column sets are merged), and the whole file is rewritten
    with a header.

    Args:
        new_result: Mapping of column name -> value for the new row.
        csv_path: Destination file; may be a bare filename in the current
            working directory.

    Read/write failures are caught and reported with a printed message rather
    than raised; errors while creating the parent directory propagate.
    """
    new_df = pd.DataFrame([new_result])
    parent_dir = os.path.dirname(csv_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    write_header = not os.path.exists(csv_path)
    try:
        if write_header:
            new_df.to_csv(csv_path, mode='w', header=True, index=False)
        else:
            existing_df = pd.read_csv(csv_path)
            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
            combined_df.to_csv(csv_path, mode='w', header=True, index=False)
        print(f"üíæ Saved to: {csv_path}")
    except Exception as e:
        print(f"‚ùå CSV save failed: {e}")

def clean_memory():
    """Run a garbage-collection pass via ``gc.collect()``.

    Intended for memory-constrained sessions (e.g. Colab) that process large
    datasets; the collector's return value is discarded.
    """
    gc.collect()

def should_cleanup_namespace(provider_name, dataset_name):
    """Decide whether the provider's index/namespace should be deleted.

    ``CLEANUP_POLICY`` values "always_delete" and "always_keep" answer without
    prompting; any other value asks the user on stdin and treats "y" or "yes"
    (case-insensitive, surrounding whitespace ignored) as consent.
    """
    if CLEANUP_POLICY == "always_delete":
        return True
    if CLEANUP_POLICY == "always_keep":
        return False

    answer = input(f"\n‚ùì Delete {provider_name} namespace for {dataset_name}? (y/n): ")
    return answer.strip().lower() in ('y', 'yes')

# 6️⃣ Provider Classes
# These classes encapsulate the specific logic for interacting with each vector
# database provider (Moorcheh, Pinecone, Elasticsearch). Each class handles
# operations such as vector upload, search queries, and resource cleanup, tailored
# to the provider's API.

class MoorchehProvider:
    """Manages interaction with the Moorcheh vector database for benchmarking.

    Upload and search timings are taken from the values found in the API
    responses (``execution_time`` / ``timings``), not measured on the client.
    """
    def __init__(self, client, namespace_name, precomputed_vectors, query_embeddings):
        self.client = client # Moorcheh API client instance.
        self.namespace_name = namespace_name # Unique name for the Moorcheh namespace.
        self.precomputed_vectors = precomputed_vectors # List of documents with their IDs and vectors.
        self.query_embeddings = query_embeddings # List of query embeddings.
        self.upload_timings = {"server_upload_time_s": 0.0, "batch_details": []}
        self.search_timings = []

    def upload(self):
        """Create the namespace if needed and upload all vectors in batches.

        A failed batch is reported and skipped; the remaining batches are
        still attempted.

        Returns:
            int: Number of vectors in the batches whose upload call returned
            without raising (not simply ``len(self.precomputed_vectors)``).

        Raises:
            Exception: Any error from namespace creation other than
            ``ConflictError`` (namespace already exists) is re-raised.
        """
        from moorcheh_sdk import ConflictError # Raised when the namespace already exists.
        try:
            # Attempts to create a new namespace. Moorcheh organizes vectors into namespaces.
            self.client.create_namespace(namespace_name=self.namespace_name, type="vector", vector_dimension=VECTOR_DIMENSION)
            print(f"‚úÖ Created VECTOR namespace: {self.namespace_name} (dim: {VECTOR_DIMENSION})")
        except ConflictError:
            print(f"‚ö†Ô∏è Namespace exists, using: {self.namespace_name}")
        except Exception as e:
            print(f"‚ùå Error creating Moorcheh namespace: {e}")
            raise

        batch_num = 0
        uploaded_count = 0
        # Iterates through the precomputed vectors in batches, displaying a progress bar.
        for i in tqdm(range(0, len(self.precomputed_vectors), BATCH_SIZE), desc="Uploading to Moorcheh"):
            batch = self.precomputed_vectors[i:i+BATCH_SIZE] # Gets a slice of vectors for the current batch.
            batch_num += 1
            try:
                # Calls the Moorcheh SDK to upload the batch of vectors.
                response = self.client.upload_vectors(namespace_name=self.namespace_name, vectors=batch)
                # The call returned without raising, so this batch counts as uploaded.
                uploaded_count += len(batch)
                server_time = 0.0
                if isinstance(response, dict):
                    server_time = response.get("execution_time", 0.0)
                    timings = response.get("timings")
                    # Prefer the detailed "total" timing when the response carries one.
                    if isinstance(timings, dict):
                        server_time = timings.get("total", server_time)
                self.upload_timings["server_upload_time_s"] += server_time # Accumulates total server upload time.
                self.upload_timings["batch_details"].append({"batch_num": batch_num, "batch_size": len(batch), "server_time_s": server_time})
            except Exception as e:
                print(f"\n‚ùå Batch {batch_num} failed: {e}")

        print(f"\n‚è±Ô∏è  Moorcheh Upload Timing (SERVER-SIDE): {self.upload_timings['server_upload_time_s']:.4f}s")
        return uploaded_count

    def search(self, query_idx, top_k=100):
        """Search the namespace with the query embedding at ``query_idx``.

        Appends one entry to ``self.search_timings`` per call.

        Returns:
            dict: ``{result id: result score}`` built from the response's
            ``results`` list; empty when the response is not a dict.
        """
        query_embedding = self.query_embeddings[query_idx] # Retrieves the specific query embedding.
        resp = self.client.search(namespaces=[self.namespace_name], query=query_embedding, top_k=top_k)
        is_dict = isinstance(resp, dict)
        server_time = resp.get("execution_time", 0.0) if is_dict else 0.0
        timings = resp.get("timings", {}) if is_dict else {}
        self.search_timings.append({"server_time_s": server_time, "timings": timings}) # Records timings for this search operation.
        hits = resp.get("results", []) if is_dict else []
        return {r["id"]: r["score"] for r in hits}

    def get_search_stats(self):
        """Return ``(overall_stats, component_stats)`` for recorded searches.

        ``overall_stats`` summarises the per-search server time;
        ``component_stats`` maps each key seen in the per-search ``timings``
        dicts to its own summary.
        """
        server_times = [t["server_time_s"] for t in self.search_timings] # Extracts overall server times.
        stats = calculate_timing_stats(server_times) # Calculates stats for overall server time.
        component_totals = {}
        for timing in self.search_timings:
            for component, time_val in timing.get("timings", {}).items():
                component_totals.setdefault(component, []).append(time_val)
        component_stats = {component: calculate_timing_stats(times) for component, times in component_totals.items()}
        return stats, component_stats

    def cleanup(self):
        """Deletes the Moorcheh namespace created for the benchmark (best-effort)."""
        try:
            self.client.delete_namespace(self.namespace_name)
            print(f"üßπ Deleted namespace: {self.namespace_name}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to delete: {e}")


class PineconeProvider:
    """Manages interaction with the Pinecone vector database for benchmarking.

    Each search is followed by a Cohere rerank of the returned matches. All
    timings recorded here are client-side ``time.perf_counter`` durations.
    """
    def __init__(self, client, index_name, precomputed_vectors, query_embeddings, corpus_texts, cohere_client):
        self.client = client # Pinecone API client instance.
        self.index_name = index_name # Unique name for the Pinecone index.
        self.index = None # Will store the Pinecone Index object after creation.
        self.precomputed_vectors = precomputed_vectors # List of documents with their IDs and vectors.
        self.query_embeddings = query_embeddings # List of query embeddings.
        self.corpus_texts = corpus_texts # Original texts of the corpus for reranking.
        self.cohere_client = cohere_client # Cohere client for reranking.
        self.upload_timings = {"index_creation_s": 0.0, "upsert_time_s": 0.0, "batch_details": []}
        self.search_timings = []
        self.rerank_timings = []

    def upload(self):
        """Recreate the index and upsert every precomputed vector in batches.

        Returns:
            int: ``len(self.precomputed_vectors)``.

        Raises:
            Exception: Errors from index creation, readiness polling or
            upsert are not caught here and propagate to the caller.
        """
        from pinecone import ServerlessSpec # Imports Pinecone's specification for index creation.

        t0 = time.perf_counter()
        # Checks if an index with the same name already exists and deletes it.
        try:
            if self.index_name in [idx.name for idx in self.client.list_indexes()]:
                self.client.delete_index(self.index_name)
                time.sleep(5) # Pauses to allow Pinecone to complete the index deletion.
        except Exception as e:
            # Best-effort cleanup: report and continue, since create_index
            # below surfaces any real problem. Catching Exception (not a bare
            # except) lets KeyboardInterrupt/SystemExit propagate.
            print(f"Warning: could not remove existing Pinecone index {self.index_name}: {e}")

        # Creates a new Pinecone index.
        self.client.create_index(name=self.index_name, dimension=VECTOR_DIMENSION, metric='cosine', spec=ServerlessSpec(cloud='aws', region='us-east-1')) # **User Customizable**: Adjust cloud/region as needed for `ServerlessSpec`.
        # Waits for the newly created index to be ready.
        while not self.client.describe_index(self.index_name).status['ready']:
            time.sleep(1)

        t1 = time.perf_counter()
        self.upload_timings["index_creation_s"] = t1 - t0 # Records index creation time (includes the delete/wait above).
        self.index = self.client.Index(self.index_name) # Gets the Pinecone Index object.

        batch_num = 0
        # Upserts vectors in batches to Pinecone.
        for i in tqdm(range(0, len(self.precomputed_vectors), BATCH_SIZE), desc="Uploading to Pinecone"):
            batch = self.precomputed_vectors[i:i+BATCH_SIZE]
            vectors = [(v['id'], v['vector']) for v in batch]
            batch_num += 1
            t_batch_start = time.perf_counter()
            self.index.upsert(vectors=vectors)
            t_batch_end = time.perf_counter()
            batch_time = t_batch_end - t_batch_start
            self.upload_timings["upsert_time_s"] += batch_time # Accumulates total upsert time.
            self.upload_timings["batch_details"].append({"batch_num": batch_num, "batch_size": len(batch), "upsert_time_s": batch_time})

        print(f"\n‚è±Ô∏è  Pinecone Upload Timing: Index Creation: {self.upload_timings['index_creation_s']:.4f}s, Upsert Total: {self.upload_timings['upsert_time_s']:.4f}s")
        return len(self.precomputed_vectors)

    def search(self, query_idx, query_text, top_k=100):
        """Run a vector query, then rerank the matches with Cohere.

        If reranking raises, the original Pinecone scores are returned and the
        rerank time is recorded as 0.0. When the query returns no matches the
        rerank call is skipped entirely.

        Returns:
            dict: ``{doc id: score}`` - rerank relevance scores on success,
            otherwise the Pinecone match scores.
        """
        query_embedding = self.query_embeddings[query_idx] # Retrieves the specific query embedding.
        t0 = time.perf_counter()
        results = self.index.query(vector=query_embedding, top_k=top_k) # Executes the Pinecone query.
        t1 = time.perf_counter()
        search_time = t1 - t0

        doc_ids = [match['id'] for match in results['matches']] # Extracts document IDs from initial search.

        if not doc_ids:
            # Nothing to rerank: avoid sending an empty document list to Cohere.
            reranked_results = {}
            rerank_time = 0.0
        else:
            doc_texts = [self.corpus_texts.get(doc_id, "") for doc_id in doc_ids] # Retrieves original texts for reranking.
            t_rerank_start = time.perf_counter()
            try:
                # Reranks the top_k results using Cohere's rerank model.
                rerank_response = self.cohere_client.rerank(model=RERANK_MODEL, query=query_text, documents=doc_texts, top_n=top_k, return_documents=False)
                # result.index points back into doc_texts / doc_ids.
                reranked_results = {doc_ids[result.index]: result.relevance_score for result in rerank_response.results}
                t_rerank_end = time.perf_counter()
                rerank_time = t_rerank_end - t_rerank_start
            except Exception as e:
                print(f"\n‚ö†Ô∏è Reranking failed: {e}, using original results")
                reranked_results = {match['id']: match['score'] for match in results['matches']}
                rerank_time = 0.0

        total_time = search_time + rerank_time # Total time including initial search and reranking.
        self.search_timings.append({"query_time_s": search_time, "total_time_s": total_time}) # Records timings.
        self.rerank_timings.append({"rerank_time_s": rerank_time})
        return reranked_results

    def get_search_stats(self):
        """Return timing summaries keyed ``query``, ``total`` and ``rerank`` (client-side durations)."""
        query_times = [t["query_time_s"] for t in self.search_timings]
        total_times = [t["total_time_s"] for t in self.search_timings]
        rerank_times = [t["rerank_time_s"] for t in self.rerank_timings]
        return {"query": calculate_timing_stats(query_times), "total": calculate_timing_stats(total_times), "rerank": calculate_timing_stats(rerank_times)}

    def cleanup(self):
        """Deletes the Pinecone index created for the benchmark (best-effort)."""
        try:
            self.client.delete_index(self.index_name)
            print(f"üßπ Deleted Pinecone index: {self.index_name}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to delete: {e}")


class ElasticsearchProvider:
    """Manages interaction with Elasticsearch for benchmarking, using its `dense_vector` field type.

    Both the server-reported ``took`` time and the client-side
    ``time.perf_counter`` duration are recorded for uploads and searches.
    """
    def __init__(self, client, index_name, precomputed_vectors, query_embeddings):
        self.client = client # Elasticsearch API client instance.
        self.index_name = index_name # Unique name for the Elasticsearch index.
        self.precomputed_vectors = precomputed_vectors # List of documents with their IDs and vectors.
        self.query_embeddings = query_embeddings # List of query embeddings.
        self.upload_timings = {
            "index_creation_s": 0.0,
            "server_bulk_time_ms": 0.0,
            "client_total_time_s": 0.0,
            "batch_details": []
        }
        self.search_timings = []

    def upload(self):
        """Recreate the index and bulk-load every precomputed vector.

        A failed batch is reported and skipped; the remaining batches are
        still attempted.

        Returns:
            int: Number of bulk items the server acknowledged with status
            200 or 201.

        Raises:
            Exception: Index creation errors are printed and re-raised.
        """
        try:
            # Deletes an existing index if it's found, ensuring a fresh start for the benchmark.
            self.client.indices.delete(index=self.index_name)
            time.sleep(2) # Brief pause after a successful delete before recreating the index.
        except Exception:
            # Best-effort: the delete is expected to fail when the index does
            # not exist yet. Catching Exception (not a bare except) lets
            # KeyboardInterrupt/SystemExit propagate.
            pass

        t0 = time.perf_counter()
        # Defines the index mapping with a `dense_vector` field for storing embeddings.
        index_config = {
            "mappings": {"properties": {"doc_id": {"type": "keyword"}, "embedding": {"type": "dense_vector", "dims": VECTOR_DIMENSION, "index": True, "similarity": "cosine"}}},
            "settings": {"number_of_shards": 1, "number_of_replicas": 0} # **User Customizable**: Adjust shard/replica settings.
        }

        try:
            self.client.indices.create(index=self.index_name, body=index_config)
            print(f"‚úÖ Created Elasticsearch index: {self.index_name} (dim: {VECTOR_DIMENSION})")
        except Exception as e:
            print(f"‚ö†Ô∏è Index creation error: {e}")
            raise

        t1 = time.perf_counter()
        self.upload_timings["index_creation_s"] = t1 - t0 # Records index creation time.

        print(f"\nüì§ Uploading vectors to Elasticsearch...")
        batch_num = 0
        total_uploaded = 0
        t_upload_start = time.perf_counter()

        # Uses Elasticsearch's bulk API for efficient ingestion of many documents.
        for i in tqdm(range(0, len(self.precomputed_vectors), BATCH_SIZE), desc="Uploading to Elasticsearch"):
            batch = self.precomputed_vectors[i:i+BATCH_SIZE]
            batch_num += 1

            # Bulk payload: one action line followed by one document line per vector.
            bulk_body = []
            for vec in batch:
                bulk_body.append(json.dumps({"index": {"_index": self.index_name, "_id": vec["id"]}}))
                bulk_body.append(json.dumps({"doc_id": vec["id"], "embedding": vec["vector"]}))

            bulk_data = "\n".join(bulk_body) + "\n"

            try:
                t_batch_start = time.perf_counter()
                response = self.client.bulk(body=bulk_data, refresh=False) # `refresh=False` improves ingestion performance.
                t_batch_end = time.perf_counter()

                batch_client_time = t_batch_end - t_batch_start
                server_took_ms = response.get("took", 0)
                self.upload_timings["server_bulk_time_ms"] += server_took_ms # Accumulates server-side bulk time.

                # Count only the items the server reports as indexed (200/201).
                batch_success = 0
                if "items" in response:
                    for item in response["items"]:
                        if "index" in item and item["index"].get("status") in [200, 201]:
                            batch_success += 1

                total_uploaded += batch_success
                self.upload_timings["batch_details"].append({
                    "batch_num": batch_num,
                    "batch_size": len(batch),
                    "server_time_ms": server_took_ms,
                    "client_time_s": batch_client_time
                })
            except Exception as e:
                print(f"\n‚ùå Batch {batch_num} failed: {e}")

        t_upload_end = time.perf_counter()
        self.upload_timings["client_total_time_s"] = t_upload_end - t_upload_start # Records client-side total upload time.

        try:
            self.client.indices.refresh(index=self.index_name) # Refreshes the index to make documents searchable.
        except Exception as e:
            print(f"‚ö†Ô∏è Refresh error: {e}")

        print(f"\n‚è±Ô∏è  Elasticsearch Upload Timing:")
        print(f"    Index Creation: {self.upload_timings['index_creation_s']:.4f}s")
        print(f"    Bulk Upload (SERVER): {self.upload_timings['server_bulk_time_ms']/1000:.4f}s")
        print(f"    Bulk Upload (CLIENT): {self.upload_timings['client_total_time_s']:.4f}s")
        return total_uploaded

    def search(self, query_idx, top_k=100):
        """Run a kNN search with the query embedding at ``query_idx``.

        One timing entry is appended to ``self.search_timings`` per call; a
        failed search records zeros and returns an empty dict.

        Returns:
            dict: ``{str(hit _id): score}`` for the returned hits.
        """
        query_embedding = self.query_embeddings[query_idx] # Retrieves the specific query embedding.
        query_body = {
            "knn": {
                "field": "embedding",
                "query_vector": query_embedding,
                "k": top_k,
                "num_candidates": min(top_k * 10, 10000) # **User Customizable**: Adjust `num_candidates` for speed/accuracy trade-off.
            },
            "size": top_k,
            "_source": ["doc_id"]
        }

        try:
            t_start = time.perf_counter()
            response = self.client.search(index=self.index_name, body=query_body)
            t_end = time.perf_counter()

            client_time_s = t_end - t_start
            server_took_ms = response.get("took", 0)
            server_took_s = server_took_ms / 1000.0

            self.search_timings.append({
                "server_time_ms": server_took_ms,
                "server_time_s": server_took_s,
                "client_time_s": client_time_s
            }) # Records detailed timings for this search.

            hits = response.get("hits", {}).get("hits", [])
            results = {}
            for i, hit in enumerate(hits):
                doc_id = hit.get("_id")
                score = hit.get("_score", 1.0 / (1.0 + i)) # Retrieval score; falls back to a decreasing score if `_score` is missing.
                results[str(doc_id)] = score
            return results

        except Exception as e:
            print(f"\n‚ùå Search error: {e}")
            self.search_timings.append({
                "server_time_ms": 0,
                "server_time_s": 0.0,
                "client_time_s": 0.0
            })
            return {}

    def get_search_stats(self):
        """Return timing summaries keyed ``server`` and ``client`` for recorded searches."""
        server_times_s = [t["server_time_s"] for t in self.search_timings]
        client_times_s = [t["client_time_s"] for t in self.search_timings]

        server_stats = calculate_timing_stats(server_times_s)
        client_stats = calculate_timing_stats(client_times_s)

        return {
            "server": server_stats,
            "client": client_stats
        }

    def cleanup(self):
        """Deletes the Elasticsearch index created for the benchmark (best-effort)."""
        try:
            self.client.indices.delete(index=self.index_name)
            print(f"üßπ Deleted Elasticsearch index: {self.index_name}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to delete index: {e}")


# 7️⃣ Main Benchmark Orchestration Function
# The `run_benchmark_all_providers` function orchestrates the entire benchmarking workflow
# for a given dataset and selected providers. It encompasses data loading, embedding
# generation, uploading to various vector databases, performing searches,
# evaluating retrieval quality, and saving detailed results.

def _extract_document_text(doc_content):
    """Return the text to embed for one corpus entry.

    For dict entries the first non-blank string found under a known text
    field wins; failing that, the first string value longer than 10
    characters; failing that, the whole entry is stringified. Non-dict
    entries are stringified directly. An empty result is replaced by the
    placeholder "document".
    """
    if isinstance(doc_content, dict):
        text = None
        for field in ('doc', 'text', 'content', 'body', 'passage', 'document', 'title', 'abstract'):
            val = doc_content.get(field)
            if isinstance(val, str) and val.strip():
                text = val
                break
        if not text:
            for val in doc_content.values():
                if isinstance(val, str) and val.strip() and len(val) > 10:
                    text = val
                    break
        if not text:
            text = str(doc_content)
    else:
        text = str(doc_content)
    return text if text else "document"


def _extract_query_text(query):
    """Return the text to embed for one query entry.

    Dict queries use the first truthy value among the known query fields and
    fall back to the stringified dict; anything else is stringified. An empty
    result is replaced by the placeholder "query".
    """
    if isinstance(query, dict):
        text = (
            query.get('text') or query.get('query') or query.get('instruction')
            or query.get('question') or query.get('query_text') or str(query)
        )
    else:
        text = str(query)
    return text if text else "query"


def _embed_in_batches(texts, input_type, desc):
    """Embed `texts` with Cohere in chunks of BATCH_SIZE, returning embeddings in input order."""
    embeddings = []
    for start in tqdm(range(0, len(texts), BATCH_SIZE), desc=desc):
        response = cohere_client.embed(
            texts=texts[start:start + BATCH_SIZE],
            model=EMBEDDING_MODEL,
            input_type=input_type,
        )
        embeddings.extend(response.embeddings)
    return embeddings


def run_benchmark_all_providers(dataset_name, provider_names, es_client=None):
    """Runs a full benchmark for specified providers and a chosen dataset.

    Workflow: load every ``.jsonl`` file under the dataset's ``docs`` and
    ``queries`` folders, derive qrels from the queries, embed corpus and
    queries once with Cohere, then for each provider upload the vectors,
    run every query, compute retrieval metrics, append one row to the CSV
    and optionally clean up the provider's index/namespace.

    Args:
        dataset_name: Folder name of the dataset under ``MAIR_COMBINED_PATH``.
        provider_names: Iterable of provider keys; 'moorcheh', 'pinecone'
            and 'elasticsearch' are recognised.
        es_client: Connected Elasticsearch client, or None to skip that provider.

    Returns:
        list[dict]: One result row per provider that completed. Empty when the
        dataset folders are missing or the corpus, queries or qrels are empty.
        A provider that raises is reported and skipped; the exception is not
        propagated.
    """
    print(f"\n{'='*70}")
    print(f"üöÄ Dataset: {dataset_name}")
    print(f"üìä Testing providers: {', '.join([p.upper() for p in provider_names])}")
    print(f"{'='*70}")

    dataset_path = os.path.join(MAIR_COMBINED_PATH, dataset_name)
    docs_path = os.path.join(dataset_path, 'docs')
    queries_path = os.path.join(dataset_path, 'queries')

    if not os.path.exists(docs_path) or not os.path.exists(queries_path):
        print(f"‚ö†Ô∏è Dataset not found")
        return []

    print(f"\nüì• Loading dataset...")
    corpus = {}
    for file in os.listdir(docs_path):
        if file.endswith('.jsonl'):
            corpus.update(load_jsonl(os.path.join(docs_path, file)))

    queries = {}
    for file in os.listdir(queries_path):
        if file.endswith('.jsonl'):
            queries.update(load_jsonl(os.path.join(queries_path, file)))

    qrels = extract_qrels_from_queries(queries)
    print(f"‚úÖ Loaded: {len(corpus)} docs, {len(queries)} queries, {len(qrels)} qrels")

    if not corpus or not queries or not qrels:
        print(f"‚ö†Ô∏è Empty dataset")
        return []

    # Generate Float Embeddings with Cohere (Once per dataset for all providers)
    print(f"\nüß† Step 1: Generating embeddings ONCE with Cohere {EMBEDDING_MODEL}...")
    docs = list(corpus.items())[:MAX_UPLOAD_DOCS]  # Limit documents based on MAX_UPLOAD_DOCS.

    doc_ids = [str(doc_id) for doc_id, _ in docs]
    texts = [_extract_document_text(doc_content) for _, doc_content in docs]
    corpus_texts = dict(zip(doc_ids, texts))  # doc id -> text that was embedded

    print(f"  üìÑ Generating corpus embeddings for {len(texts)} documents...")
    t_embed_start = time.perf_counter()
    corpus_embeddings = _embed_in_batches(texts, INPUT_TYPE_CORPUS, "Embedding corpus")
    t_embed_end = time.perf_counter()
    embedding_time = t_embed_end - t_embed_start

    precomputed_vectors = [
        {"id": doc_id, "vector": embedding}
        for doc_id, embedding in zip(doc_ids, corpus_embeddings)
    ]

    print(f"  üîç Generating query embeddings for {len(queries)} queries...")
    query_ids = list(queries.keys())
    # Raw query texts are kept as-is: Pinecone's search receives them alongside the vector index.
    query_texts = [_extract_query_text(queries[qid]) for qid in query_ids]

    t_query_start = time.perf_counter()
    # Query fields are not type-checked above, so coerce to str for the embedding call.
    query_inputs = [str(t) if t else "query" for t in query_texts]
    query_embeddings = _embed_in_batches(query_inputs, INPUT_TYPE_QUERY, "Embedding queries")
    t_query_end = time.perf_counter()
    query_embedding_time = t_query_end - t_query_start

    print(f"\n‚è±Ô∏è  Embedding Generation Time (SHARED):")
    print(f"    Corpus: {embedding_time:.4f}s, Queries: {query_embedding_time:.4f}s, Total: {embedding_time + query_embedding_time:.4f}s")

    results_all_providers = []  # Collects one result row per provider for this dataset.

    # Benchmark Each Selected Provider
    for provider_name in provider_names:
        print(f"\n{'‚îÄ'*70}")
        print(f"üîß Testing Provider: {provider_name.upper()}")
        print(f"{'‚îÄ'*70}")

        try:
            provider = None
            # Initializes the appropriate provider class based on the `provider_name`.
            if provider_name == 'moorcheh':
                if 'moorcheh' not in clients:  # Skips if Moorcheh client was not initialized.
                    print(f"‚ö†Ô∏è Skipping {provider_name}")
                    continue
                namespace_name = f"mair-{dataset_name.replace('/', '-')}-cohere-v4"[:63]
                provider = MoorchehProvider(clients['moorcheh'], namespace_name, precomputed_vectors, query_embeddings)
            elif provider_name == 'pinecone':
                if 'pinecone' not in clients:  # Skips if Pinecone client was not initialized.
                    print(f"‚ö†Ô∏è Skipping {provider_name}")
                    continue
                safe_dataset_name = dataset_name.replace('/', '-').replace('_', '-').lower()
                index_name = f"mair-{safe_dataset_name}-v4"[:45]
                provider = PineconeProvider(clients['pinecone'], index_name, precomputed_vectors, query_embeddings, corpus_texts, cohere_client)
            elif provider_name == 'elasticsearch':
                if es_client is None:  # Skips if ES client failed to connect.
                    print(f"‚ö†Ô∏è Skipping {provider_name}")
                    continue
                safe_dataset_name = dataset_name.replace('/', '-').replace('_', '-').lower()
                index_name = f"mair-{safe_dataset_name}-v4"
                provider = ElasticsearchProvider(es_client, index_name, precomputed_vectors, query_embeddings)

            if provider is None:  # Unknown provider name: nothing was constructed.
                print(f"‚ùå Could not initialize provider: {provider_name}. Check configuration and API keys.")
                continue

            print(f"\nüì§ Uploading pre-computed vectors to {provider_name}...")
            num_uploaded = provider.upload()  # Uploads data to the current provider.

            print(f"\nüîç Searching with pre-computed query vectors...")

            results = {}  # Stores search results for this specific provider.
            for i, qid in enumerate(tqdm(query_ids, desc=f"Searching {provider_name}")):
                try:
                    if provider_name == 'pinecone':
                        # Pinecone's search also takes the raw query text.
                        results[qid] = provider.search(i, query_texts[i], top_k=TOP_K_SEARCH)
                    else:
                        results[qid] = provider.search(i, top_k=TOP_K_SEARCH)
                except Exception as e:
                    print(f"\n‚ùå Query {i+1} failed: {e}")
                    results[qid] = {}  # Logs empty results for failed queries.

            # Timing statistics are fetched once here and reused when building the result row.
            if provider_name == 'moorcheh':
                search_stats, _ = provider.get_search_stats()
                print(f"\n‚è±Ô∏è  Moorcheh Search Timing (SERVER-SIDE): Total: {search_stats['total']:.4f}s, Mean: {search_stats['mean']:.4f}s")
            elif provider_name == 'pinecone':
                search_stats = provider.get_search_stats()
                print(f"\n‚è±Ô∏è  Pinecone Search Timing: Total: {search_stats['total']['total']:.4f}s, Mean: {search_stats['total']['mean']:.4f}s")
            elif provider_name == 'elasticsearch':
                search_stats = provider.get_search_stats()
                print(f"\n‚è±Ô∏è  Elasticsearch Search Timing:")
                print(f"    SERVER - Total: {search_stats['server']['total']:.4f}s, Mean: {search_stats['server']['mean']:.4f}s")
                print(f"    CLIENT - Total: {search_stats['client']['total']:.4f}s, Mean: {search_stats['client']['mean']:.4f}s")

            print(f"\nüìä Evaluating retrieval quality for {provider_name}...")
            ndcg, _map, recall, precision = calculate_retrieval_metrics(results, qrels, K_VALUES)
            format_and_print_metrics(ndcg, _map, recall, precision)

            metrics = extract_all_metrics(ndcg, _map, recall, precision)
            dataset_category = get_dataset_category(dataset_name)

            result = {  # Fields shared by every provider.
                "Dataset": dataset_name,
                "Category": dataset_category,
                "Provider": provider_name,
                "Num_Corpus": len(corpus),
                "Num_Uploaded": num_uploaded,
                "Num_Queries": len(queries),
                "Embedding_Model": EMBEDDING_MODEL,
                "Vector_Dimension": VECTOR_DIMENSION,
                "Batch_Size": BATCH_SIZE,
                "Cleanup_Policy": CLEANUP_POLICY,
                "Embedding_Generation_Time_s": round(embedding_time, 4),
                "Query_Embedding_Time_s": round(query_embedding_time, 4),
                "Total_Embedding_Time_s": round(embedding_time + query_embedding_time, 4),
            }

            # Provider-specific timing columns; each provider exposes a different timing schema.
            if provider_name == 'moorcheh':
                result.update({
                    "Rerank_Model": "built-in",  # Moorcheh has built-in scoring/reranking.
                    "Upload_Server_Total_s": round(provider.upload_timings["server_upload_time_s"], 4),
                    "Upload_Total_s": round(provider.upload_timings["server_upload_time_s"], 4),
                    "Search_Server_Total_s": round(search_stats['total'], 4),
                    "Search_Server_Mean_s": round(search_stats['mean'], 4),
                    "Search_Server_Median_s": round(search_stats['median'], 4),
                    "Search_Server_Mean_ms": round(search_stats['mean'] * 1000, 2),
                })

            elif provider_name == 'pinecone':
                result.update({
                    "Rerank_Model": RERANK_MODEL,
                    "Upload_Index_Creation_s": round(provider.upload_timings["index_creation_s"], 4),
                    "Upload_Upsert_Total_s": round(provider.upload_timings["upsert_time_s"], 4),
                    "Upload_Total_s": round(provider.upload_timings["index_creation_s"] + provider.upload_timings["upsert_time_s"], 4),
                    "Search_Total_s": round(search_stats['total']['total'], 4),
                    "Search_Mean_s": round(search_stats['total']['mean'], 4),
                    "Search_Mean_ms": round(search_stats['total']['mean'] * 1000, 2),
                    "Rerank_Mean_ms": round(search_stats['rerank']['mean'] * 1000, 2),
                })

            elif provider_name == 'elasticsearch':
                result.update({
                    "Rerank_Model": "built-in (Cosine KNN)",  # ES uses KNN with cosine similarity.
                    "Upload_Index_Creation_s": round(provider.upload_timings["index_creation_s"], 4),
                    "Upload_Server_Bulk_s": round(provider.upload_timings["server_bulk_time_ms"] / 1000, 4),
                    "Upload_Client_Total_s": round(provider.upload_timings["client_total_time_s"], 4),
                    "Upload_Total_s": round(provider.upload_timings["index_creation_s"] + provider.upload_timings["client_total_time_s"], 4),
                    "Search_Server_Total_s": round(search_stats['server']['total'], 4),
                    "Search_Server_Mean_s": round(search_stats['server']['mean'], 4),
                    "Search_Server_Median_s": round(search_stats['server']['median'], 4),
                    "Search_Server_Min_s": round(search_stats['server']['min'], 4),
                    "Search_Server_Max_s": round(search_stats['server']['max'], 4),
                    "Search_Server_Std_s": round(search_stats['server']['std'], 4),
                    "Search_Server_Mean_ms": round(search_stats['server']['mean'] * 1000, 2),
                    "Search_Client_Total_s": round(search_stats['client']['total'], 4),
                    "Search_Client_Mean_s": round(search_stats['client']['mean'], 4),
                    "Search_Client_Median_s": round(search_stats['client']['median'], 4),
                    "Search_Client_Mean_ms": round(search_stats['client']['mean'] * 1000, 2),
                })

            result.update(metrics)  # Adds retrieval quality metrics to the result entry.
            save_results_to_csv(result, CSV_PATH)  # Persist after every provider so partial runs survive.
            results_all_providers.append(result)

            # Cleanup: the policy helper decides whether the index/namespace is deleted.
            should_delete = should_cleanup_namespace(provider_name, dataset_name)
            if should_delete:
                provider.cleanup()
            else:
                print(f"üíæ Keeping {provider_name} namespace/index")

            del provider  # Drop references before the next provider is built.
            del results
            clean_memory()

            print(f"\n‚úÖ {provider_name.upper()} completed for {dataset_name}")

        except Exception as e:
            # One failing provider must not stop the remaining providers.
            print(f"‚ùå Error with {provider_name}: {e}")
            import traceback
            traceback.print_exc()  # Prints full traceback for debugging.
            clean_memory()

    # Clears large embedding arrays from memory after all providers for a dataset have been processed.
    del precomputed_vectors
    del query_embeddings
    del corpus_embeddings
    del corpus_texts
    clean_memory()

    return results_all_providers


# 8️⃣ Main Execution Block
# This block handles the overall flow of the benchmark, including API key checks,
# client initialization, provider selection, dataset discovery, and the interactive
# loop for running benchmarks.

print("üîë Checking API Keys...")
for provider in ['moorcheh', 'cohere', 'pinecone']:
    key = api_keys.get(provider)
    status = "‚úÖ" if key else "‚ùå"
    print(f"  {status} {provider.capitalize()}")

clients = {}
es_client = None

# Initializes Cohere client, which is essential for generating embeddings for ALL providers.
# The script will exit if the Cohere API key is not found.
if api_keys['cohere']:
    import cohere
    cohere_client = cohere.Client(api_keys['cohere'])
    print(f"‚úÖ Cohere client initialized")
else:
    print("\n‚ùå Cohere API key required!")
    exit(1)

# Initializes Moorcheh client if it was selected and its API key is present.
if 'moorcheh' in api_keys and api_keys['moorcheh']:
    from moorcheh_sdk import MoorchehClient, ConflictError
    clients['moorcheh'] = MoorchehClient(api_key=api_keys['moorcheh'])
    print(f"‚úÖ Moorcheh client initialized")

# Initializes Pinecone client if it was selected and its API key is present.
if 'pinecone' in api_keys and api_keys['pinecone']:
    from pinecone import Pinecone, ServerlessSpec
    clients['pinecone'] = Pinecone(api_key=api_keys['pinecone'])
    print(f"‚úÖ Pinecone client initialized")

# Initializes Elasticsearch client if it was selected and properly configured.
try:
    from elasticsearch import Elasticsearch
    es_config = api_keys['elasticsearch']

    if es_config.get('api_key'):
        try:
            es_client = Elasticsearch(es_config['url'], api_key=es_config['api_key'], request_timeout=60)
            if es_client.ping():
                print("‚úÖ Elasticsearch connected (API Key)")
            else:
                es_client = None
        except Exception as e:
            es_client = None

    if not es_client and es_config.get('password'):
        try:
            es_client = Elasticsearch(es_config['url'], basic_auth=(es_config['username'], es_config['password']), request_timeout=60)
            if es_client.ping():
                print("‚úÖ Elasticsearch connected (Basic Auth)")
            else:
                es_client = None
        except Exception as e:
            es_client = None

    if not es_client:
        try:
            es_client = Elasticsearch(es_config['url'], request_timeout=60)
            if es_client.ping():
                print("‚úÖ Elasticsearch connected (No Auth)")
            else:
                es_client = None
        except Exception as e:
            es_client = None

    if not es_client:
        print(f"‚ö†Ô∏è Elasticsearch connection failed - URL: {es_config['url']}")

except ImportError:
    print("‚ö†Ô∏è Elasticsearch library not installed")
except Exception as e:
    print(f"‚ö†Ô∏è Elasticsearch error: {e}")

# 9️⃣ Provider Selection
# User selects which providers to test interactively.
print("\nüîß Available Providers (ALL using pre-computed vectors):")
print(f"  1. Moorcheh (Vector namespace - {EMBEDDING_MODEL} - {VECTOR_DIMENSION}D)")
print(f"  2. Pinecone (Vector index - {EMBEDDING_MODEL} - {VECTOR_DIMENSION}D + Cohere {RERANK_MODEL})")
print(f"  3. Elasticsearch (Dense vector - {EMBEDDING_MODEL} - {VECTOR_DIMENSION}D - Cosine KNN)")

provider_choice = input("\n‚û°Ô∏è Select providers (e.g., '1,2,3' or 'all'): ").strip().lower()
if provider_choice == 'all':
    selected_providers = [p for p in ['moorcheh', 'pinecone', 'elasticsearch'] if (p in clients or p == 'elasticsearch' and es_client)]
else:
    provider_map = {'1': 'moorcheh', '2': 'pinecone', '3': 'elasticsearch'}
    selected_providers = [provider_map[p.strip()] for p in provider_choice.split(',') if p.strip() in provider_map and (provider_map[p.strip()] in clients or (provider_map[p.strip()] == 'elasticsearch' and es_client))]

print(f"\n‚úÖ Selected providers: {', '.join(selected_providers)}")

# 10. Dataset Discovery and Selection Loop
# This section allows the user to interactively select which datasets to run the
# benchmark on. The benchmark will execute for each selected dataset and provider combination.
print("\nüìä Discovering datasets...")
mair_datasets, datasets_by_category, dataset_sizes = get_mair_datasets()

if not mair_datasets:
    print("‚ùå No datasets found!")
    exit(1)

print(f"\nüìÇ Available datasets by category ({len(mair_datasets)} total):")
dataset_index = 1
dataset_map = {}

for category in sorted(datasets_by_category.keys()):
    datasets = datasets_by_category[category]
    if datasets:
        sorted_datasets = sorted(datasets, key=lambda d: dataset_sizes.get(d, 0), reverse=True)
        print(f"\n  üìÅ {category} ({len(sorted_datasets)} datasets):")
        for dataset in sorted_datasets:
            size_str = format_size(dataset_sizes.get(dataset, 0))
            print(f"     {dataset_index}. {dataset} ({size_str} docs)")
            dataset_map[dataset_index] = dataset
            dataset_index += 1

all_results = []

while True:
    choice = input(f"\n‚û°Ô∏è Enter dataset number (1-{len(dataset_map)}) or 'stop': ").strip().lower()
    if choice == "stop":
        break

    try:
        idx = int(choice)
        if idx in dataset_map:
            dataset_name = dataset_map[idx]
        else:
            print(f"‚ö†Ô∏è Invalid number")
            continue
    except ValueError:
        print("‚ö†Ô∏è Invalid input")
        continue

    results = run_benchmark_all_providers(dataset_name, selected_providers, es_client=es_client)
    all_results.extend(results)

# 11. Final Summary and Analysis
# After all selected datasets are processed, this section provides an overall summary
# of the benchmark results. It includes tables of key metrics, average performance
# comparisons by provider, and insights into space efficiency and timing.

if all_results:
    print(f"\n{'='*70}")
    print("üèÅ BENCHMARK COMPLETE!")
    print(f"{'='*70}")

    df = pd.DataFrame(all_results)
    print(f"\nüíæ Results saved to: {CSV_PATH}")

    print("\nüìä Summary Table:")
    # Timing columns differ per provider (Pinecone rows have no 'Search_Server_Mean_ms'),
    # so keep only columns that exist; selecting a missing label would raise KeyError.
    summary_cols = [
        col
        for col in ['Dataset', 'Category', 'Provider', 'NDCG@10', 'MAP@10', 'Recall@100', 'Search_Server_Mean_ms']
        if col in df.columns
    ]
    print(df[summary_cols].to_string(index=False))

    print("\nüìà Average Performance by Provider:")
    numeric_cols = ['NDCG@10', 'MAP@10', 'Recall@100', 'Search_Server_Mean_ms', 'Search_Mean_ms']
    valid_cols = [col for col in numeric_cols if col in df.columns]
    avg_df = df.groupby('Provider')[valid_cols].mean()
    print(avg_df.to_string())

    # For each quality metric, the provider with the highest mean across datasets wins.
    print("\nüèÜ Overall Performance Winners:")
    for metric in ['NDCG@10', 'MAP@10', 'Recall@100']:
        if metric in df.columns:
            winner_data = df.groupby('Provider')[metric].mean()
            if len(winner_data) > 0:
                winner = winner_data.idxmax()
                winner_score = winner_data.max()
                print(f"  {metric}: {winner.upper()} ({winner_score:.4f})")

    print("\n" + "="*70)
    print("‚ú® Benchmark Complete - Results saved to CSV!")
    print("="*70)
else:
    print("\n‚ö†Ô∏è No results to display")

## Vector (Non-Binary) Search with PGVector and PostgreSQL in Google Colab

In [None]:
# ============================================================
# MAIR Benchmark - PGVector Edition (Google Colab Optimized)
# Comprehensive vector search benchmarking using PostgreSQL + pgvector
# ============================================================

# -------------------- STEP 1: Install PostgreSQL & pgvector in Colab ---------------------
# This section sets up the PostgreSQL database server and installs the pgvector extension
# directly within the Google Colab environment. This is a one-time setup for the session.
print("üîß Installing PostgreSQL and pgvector in Google Colab...")
print("This will take 2-3 minutes on first run...
")

import os
import subprocess
import time

# Install PostgreSQL server and client utilities.
# `apt-get update -qq`: Updates package lists quietly.
# `apt-get install -y postgresql postgresql-contrib`: Installs PostgreSQL server and extensions.
print("üì¶ Installing PostgreSQL...")
os.system('apt-get update -qq > /dev/null 2>&1')
os.system('apt-get install -y postgresql postgresql-contrib > /dev/null 2>&1')

# Install build dependencies for pgvector (needed to compile from source).
# `build-essential`: Provides compilers (gcc, g++).
# `git`: To clone the pgvector repository.
# `postgresql-server-dev-14`: Development headers for PostgreSQL 14 (adjust version if using a different PG version).
print("üì¶ Installing build tools...")
os.system('apt-get install -y build-essential git postgresql-server-dev-14 > /dev/null 2>&1')

# Clone and install pgvector extension by compiling from source.
# This ensures compatibility and access to the latest features.
print("üì¶ Installing pgvector extension...")
os.system('cd /tmp && rm -rf pgvector && git clone --quiet https://github.com/pgvector/pgvector.git')
os.system('cd /tmp/pgvector && make > /dev/null 2>&1 && make install > /dev/null 2>&1')

# Start PostgreSQL service.
print("üöÄ Starting PostgreSQL service...")
os.system('service postgresql start > /dev/null 2>&1')
time.sleep(2) # Give the service a moment to fully start up.

# Configure the PostgreSQL database.
# `sudo -u postgres`: Executes commands as the 'postgres' user, which has administrative privileges.
# `DROP DATABASE IF EXISTS vectordb`: Ensures a clean slate for the 'vectordb' database.
# `CREATE DATABASE vectordb`: Creates a new database named 'vectordb' for our vector data.
# `ALTER USER postgres PASSWORD 'postgres'`: Sets a simple password for the default 'postgres' user.
# `CREATE EXTENSION IF NOT EXISTS vector`: Enables the pgvector extension within the 'vectordb'.
print("üîß Configuring database...")
os.system('sudo -u postgres psql -c "DROP DATABASE IF EXISTS vectordb;" > /dev/null 2>&1')
os.system('sudo -u postgres psql -c "CREATE DATABASE vectordb;" > /dev/null 2>&1')
os.system('sudo -u postgres psql -c "ALTER USER postgres PASSWORD \'postgres\';" > /dev/null 2>&1')
os.system('sudo -u postgres psql -d vectordb -c "CREATE EXTENSION IF NOT EXISTS vector;" > /dev/null 2>&1')

print("‚úÖ PostgreSQL with pgvector is ready!
")

# -------------------- STEP 2: Install Python Dependencies ---------------------
# Install Python packages required for connecting to PostgreSQL, generating embeddings, and data analysis.
# `psycopg2-binary`: PostgreSQL adapter for Python.
# `pgvector`: Python client for pgvector.
# `cohere`: For generating embeddings.
# `pandas`, `numpy`, `tqdm`: Standard data science libraries.
print("üì¶ Installing Python packages...")
os.system('pip install -q psycopg2-binary pgvector cohere pandas numpy tqdm')
print("‚úÖ Python packages installed!
")

# -------------------- STEP 3: Import Libraries ---------------------
# Import all necessary Python libraries for the benchmark script.
import gc # For garbage collection to manage memory.
import json # For handling JSONL dataset files.
import statistics # For calculating performance statistics.
import numpy as np # For numerical operations, especially with embeddings.
import pandas as pd # For data manipulation and CSV output.
from tqdm import tqdm # For displaying progress bars during long operations.
import psycopg2 # The Python PostgreSQL adapter.
import psycopg2.extras # For advanced psycopg2 features like execute_values.
from pgvector.psycopg2 import register_vector # Utility to register vector type with psycopg2.

# -------------------- Environment Setup ---------------------
# Mounts Google Drive and reads the Cohere API key, using Colab facilities when
# `google.colab` is importable and falling back to local settings otherwise.

# DRIVE_PATH: where benchmark results are saved. Defaults to the current directory
# and is replaced by a Google Drive path below when running in Colab.
DRIVE_PATH = "."

# MAIR_COMBINED_PATH: **User Customizable Path**
# This is the path to your combined MAIR datasets in Google Drive. It should match the SAVE_PATH from the
# 'Combine MAIR Docs and Queries' step. Ensure this path is correct for your setup.
MAIR_COMBINED_PATH = "/content/gdrive/MyDrive/Moorcheh/MAIR_Datasets/MAIR-Combined"

try:
    # Only importable inside Google Colab; the ImportError branch below handles everything else.
    from google.colab import drive, userdata as colab_userdata

    try:
        # Mount Google Drive so the notebook can read datasets and write results.
        # `force_remount` is not passed here; add `force_remount=True` to this call
        # if a stale mount needs to be refreshed.
        drive.mount('/content/gdrive')
        print("‚úÖ Google Drive mounted successfully")
    except Exception as e:
        print(f"‚ö†Ô∏è Drive mount warning: {e}")
        print("Continuing without Drive mount...")

    # DRIVE_PATH: **User Customizable Path**
    # If running in Colab, results will be saved here within your Google Drive.
    # You can customize this path (e.g., '/content/gdrive/MyDrive/MyProject/BenchmarkResults').
    # NOTE: the directory is created even when the mount above failed.
    DRIVE_PATH = '/content/gdrive/MyDrive/Moorcheh/Benchmark_Results/MAIR.PGVector.Results'
    os.makedirs(DRIVE_PATH, exist_ok=True)  # Creates the directory if it doesn't exist.
    print(f"‚úÖ Running in Colab. Benchmark results will be saved to: {DRIVE_PATH}")

    # Retrieve COHERE_API_KEY from Colab secrets. **User Action Required**:
    # add a secret named EXACTLY `COHERE_API_KEY` in Colab's "Secrets" panel
    # (key icon in the left sidebar) and grant this notebook access to it.
    try:
        COHERE_API_KEY = colab_userdata.get('COHERE_API_KEY')
    except Exception as e:
        # A missing or inaccessible secret raises; fall back to the environment
        # variable so the key can still be supplied that way.
        print(f"‚ö†Ô∏è Could not read COHERE_API_KEY from Colab secrets: {e}")
        COHERE_API_KEY = os.environ.get('COHERE_API_KEY')

except ImportError:
    # This block runs if not in Google Colab (e.g., local Python environment).
    DRIVE_PATH = "."  # Results will be saved in the current working directory.
    print("‚ö†Ô∏è Not running in Google Colab. Results will be saved locally. Ensure COHERE_API_KEY is set as an environment variable.")
    # Retrieve COHERE_API_KEY from environment variables.
    COHERE_API_KEY = os.environ.get('COHERE_API_KEY')

# PostgreSQL connection parameters.
# These match the local instance configured in Step 1: database 'vectordb',
# user 'postgres' with password 'postgres'.
PG_CONN_PARAMS = {
    'host': 'localhost',
    'port': '5432',
    'database': 'vectordb',
    'user': 'postgres',
    'password': 'postgres'
}

# -------------------- Configuration Parameters ---------------------
# Tunable settings for the benchmark. The code that consumes them lies further
# down in this cell, so usage notes below are marked where they are assumptions.

# TOP_K_SEARCH: **User Customizable Value**
# The number of top-ranked results to retrieve from the vector database for each query.
# A higher value might improve recall but generally increases search latency and resource usage.
TOP_K_SEARCH = 100

# K_VALUES: **User Customizable Value**
# Cut-off points at which retrieval metrics (NDCG, MAP, Recall, Precision) are calculated
# (e.g., NDCG@1, MAP@10, Recall@100). Add or remove values based on your evaluation needs.
# NOTE(review): presumably the largest k should not exceed TOP_K_SEARCH, since no more
# than TOP_K_SEARCH results are retrieved per query - confirm against the metric code.
K_VALUES = [1, 3, 5, 10, 100]

# MAX_UPLOAD_DOCS: **User Customizable Value**
# Upper bound on the number of documents uploaded to pgvector, to keep cost and run time
# manageable on very large datasets. Use a low number (e.g., 10000) for quick tests or a
# very high number (e.g., 700000) for comprehensive runs. Set to `None` or a number
# greater than your dataset size to upload all documents.
MAX_UPLOAD_DOCS = 700000

# BATCH_SIZE: **User Customizable Value**
# The number of embeddings to process or upload in a single API request/batch.
# Larger batches reduce per-request overhead but use more RAM and may hit API limits.
BATCH_SIZE = 100

# EMBEDDING_MODEL: **User Customizable Value**
# The Cohere model used to generate the dense float embeddings. Other options include
# 'embed-english-v3.0', 'embed-multilingual-v3.0', etc.
# Changing the model requires adjusting VECTOR_DIMENSION to match.
EMBEDDING_MODEL = "embed-v4.0"

# INPUT_TYPE_CORPUS: Cohere `input_type` used when embedding corpus documents.
INPUT_TYPE_CORPUS = "search_document"

# INPUT_TYPE_QUERY: Cohere `input_type` used when embedding queries.
INPUT_TYPE_QUERY = "search_query"

# VECTOR_DIMENSION: **User Customizable Value (Must match EMBEDDING_MODEL)**
# The dimensionality of the generated embeddings. It MUST match the output dimension of
# `EMBEDDING_MODEL`: 1536 for 'embed-v4.0' as used here, 1024 for the v3.0 models.
# A mismatch will lead to errors in pgvector.
VECTOR_DIMENSION = 1536

# CSV_PATH: **User Customizable Path**
# Full path of the CSV file that receives the benchmark results; it lives under
# DRIVE_PATH (Google Drive in Colab, the current directory otherwise).
CSV_PATH = os.path.join(DRIVE_PATH, "MAIR.PGVector.Cohere.V4.csv")

# -------------------- Cleanup Policy ---------------------
# Ask once, up front, what should happen to the PostgreSQL tables created by
# pgvector after each benchmark run; the answer is stored in CLEANUP_POLICY.
CLEANUP_POLICY_MAP = {"1": "ask_each_time", "2": "always_delete", "3": "always_keep"}

for _menu_line in (
    "\nüßπ Table cleanup policy options:",
    "  1) ask_each_time  -> Prompt after each dataset (default)",
    "  2) always_delete  -> Automatically delete after benchmarking",
    "  3) always_keep    -> Never delete (keep for future searches or debugging)",
):
    print(_menu_line)

# USER CONFIGURATION: choose your preferred cleanup policy.
# An empty answer counts as "1"; an unrecognised answer also falls back to ask_each_time.
cleanup_choice = input("Choose policy [1/2/3] (default 1): ").strip()
if not cleanup_choice:
    cleanup_choice = "1"

if cleanup_choice in CLEANUP_POLICY_MAP:
    CLEANUP_POLICY = CLEANUP_POLICY_MAP[cleanup_choice]
else:
    CLEANUP_POLICY = "ask_each_time"
print(f"‚úÖ Selected cleanup policy: {CLEANUP_POLICY}")

# -------------------- Helper Functions ---------------------
# These functions facilitate data loading, processing, and metric calculation for the benchmark.

def load_jsonl(filepath):
    """Load a JSONL (JSON Lines) file into a dict keyed by item ID.

    Each non-blank line is parsed as JSON. For dict items the ID is taken from the
    first of '_id', 'id', 'query_id', 'doc_id' that holds a usable value; numeric
    IDs such as 0 are valid. Items without an ID (including non-dict items) are
    keyed by their position, i.e. the number of items stored so far.

    Args:
        filepath: Path of the JSONL file to read (UTF-8).

    Returns:
        dict mapping ``str(item_id)`` to the parsed item. Malformed lines are
        skipped and reported; if the file cannot be read at all, whatever was
        loaded up to that point (possibly ``{}``) is returned.
    """
    def _usable_id(value):
        # Truthy values are IDs; so is a numeric zero, which a plain `or` chain
        # would mistake for "missing". Booleans are not treated as IDs.
        return bool(value) or (
            isinstance(value, (int, float)) and not isinstance(value, bool)
        )

    data = {}
    skipped = 0
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    # One bad record should not discard the rest of the file.
                    skipped += 1
                    continue

                item_id = None
                if isinstance(item, dict):
                    for key in ('_id', 'id', 'query_id', 'doc_id'):
                        candidate = item.get(key)
                        if _usable_id(candidate):
                            item_id = candidate
                            break
                if item_id is None:
                    item_id = len(data)  # positional fallback key
                data[str(item_id)] = item
        if skipped:
            print(f"   Skipped {skipped} malformed line(s) in {os.path.basename(filepath)}")
        print(f"   Loaded {len(data)} items from {os.path.basename(filepath)}")
    except Exception as e:
        print(f"‚ö†Ô∏è Error loading {filepath}: {e}")
    return data

def extract_qrels_from_queries(queries):
    """Build qrels (query -> {doc_id: score}) from labels embedded in query records.

    For every query that is a dict, the first non-empty field among 'labels',
    'relevance' and 'qrels' is used. Supported shapes:

    * list of dicts  -> doc ID from 'id' (or 'doc_id'), score from 'score' (default 1)
    * list of scalars -> each element is a doc ID with score 1
    * dict           -> ``{doc_id: score}`` taken as-is

    Numeric doc IDs such as 0 are valid and are kept. All query and document IDs
    are converted to ``str``.

    Args:
        queries: Mapping of query ID to query record.

    Returns:
        dict mapping ``str(query_id)`` to ``{str(doc_id): score}``. A query whose
        label field is non-empty but yields no usable entries maps to ``{}``.
    """
    def _usable_id(value):
        # Accept truthy IDs and numeric zero; None, "" and booleans are not IDs.
        return bool(value) or (
            isinstance(value, (int, float)) and not isinstance(value, bool)
        )

    qrels = {}
    for qid, q in queries.items():
        if not isinstance(q, dict):
            continue
        labels = q.get('labels') or q.get('relevance') or q.get('qrels')
        if not labels:
            continue

        judged = {}
        if isinstance(labels, list):
            for label_item in labels:
                if isinstance(label_item, dict):
                    doc_id = label_item.get('id')
                    if not _usable_id(doc_id):
                        doc_id = label_item.get('doc_id')
                    if _usable_id(doc_id):
                        judged[str(doc_id)] = label_item.get('score', 1)
                else:
                    # A bare element is a doc ID with an implicit score of 1.
                    judged[str(label_item)] = 1
        elif isinstance(labels, dict):
            for doc_id, score in labels.items():
                judged[str(doc_id)] = score
        qrels[str(qid)] = judged
    return qrels

# Category -> dataset names. Drives the grouped dataset listing shown to the user;
# any dataset not listed here is reported under "Others".
DATASET_CATEGORIES = {
    "Legal & Regulatory": ["ACORDAR", "AILA2019-Case", "AILA2019-Statutes", "CUAD", "LeCaRDv2", "LegalQuAD", "REGIR-EU2UK", "REGIR-UK2EU"],
    "Medical & Clinical": ["CliniDS-2014", "CliniDS-2015", "CliniDS-2016", "ClinicalTrials-2021", "ClinicalTrials-2022", "ClinicalTrials-2023", "NFCorpus", "PrecisionMedicine", "Genomics-AdHoc"],
    "Code & Programming": ["APPS", "CodeEditSearch", "CodeSearchNet", "Conala", "HumanEval-X", "LeetCode", "MBPP", "RepoBench", "SWE-Bench-Lite"],
    "Financial": ["ConvFinQA", "FiQA", "FinQA", "FinanceBench", "HC3Finance"],
    "Academic & Scientific": ["ArguAna", "LitSearch", "ProofWiki-Proof", "ProofWiki-Reference", "Competition-Math"],
    "Conversational & Dialog": ["CAsT-2019", "CAsT-2020", "CAsT-2021", "CAsT-2022", "ProCIS-Dialog", "ProCIS-Turn", "SParC", "Quora"],
    "News & Social Media": ["ChroniclingAmericaQA", "Microblog-2011", "Microblog-2012", "Microblog-2013", "Microblog-2014", "News21"],
    "API Documentation": ["Apple", "FoodAPI", "HuggingfaceAPI", "PytorchAPI"],
    "Others": ["BSARD", "BillSum", "CARE", "CPCD", "CQADupStack", "DD", "ELI5", "ExcluIR", "FairRanking", "Fever", "GerDaLIR", "IFEval", "InstructIR", "MISeD", "Monant", "NTCIR", "NeuCLIR", "NevIR", "PointRec", "ProductSearch_2023", "QuanTemp", "Robust04"]
}

def get_dataset_category(dataset_name):
    """Return the category that lists `dataset_name`, or "Others" if none does.

    Categories are checked in the order they appear in DATASET_CATEGORIES and the
    first match wins. The comparison is an exact, case-sensitive name match.
    """
    matches = (
        label
        for label, members in DATASET_CATEGORIES.items()
        if dataset_name in members
    )
    return next(matches, "Others")

def get_dataset_size(dataset_name):
    """Return the number of non-blank lines across a dataset's `docs/*.jsonl` files.

    The dataset is looked up under MAIR_COMBINED_PATH. A missing `docs` folder
    yields 0. If reading fails part-way, the error is printed and the count
    accumulated so far is returned.
    """
    docs_dir = os.path.join(MAIR_COMBINED_PATH, dataset_name, 'docs')
    if not os.path.exists(docs_dir):
        return 0

    total = 0
    try:
        for entry in os.listdir(docs_dir):
            if not entry.endswith('.jsonl'):
                continue
            with open(os.path.join(docs_dir, entry), 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    if raw_line.strip():
                        total += 1
    except Exception as e:
        print(f"‚ö†Ô∏è Error counting docs in {dataset_name}: {e}")
    return total

def format_size(num):
    """Render a count compactly: one decimal with an 'M' or 'K' suffix, else as-is.

    Values of at least 1,000,000 become e.g. "2.5M", values of at least 1,000
    become e.g. "1.5K", and anything smaller is returned via ``str(num)``.
    """
    for threshold, suffix in ((1_000_000, "M"), (1_000, "K")):
        if num >= threshold:
            return f"{num / threshold:.1f}{suffix}"
    return str(num)

def get_mair_datasets():
    """Scan MAIR_COMBINED_PATH for usable datasets and report them by category.

    A sub-directory counts as a dataset when it has both a `docs` and a `queries`
    folder and each contains at least one `.jsonl` file. Entries whose name starts
    with '.' or '_' are ignored.

    Returns:
        (all_datasets, datasets_by_category, dataset_sizes) where
        ``all_datasets`` is the list of dataset names in sorted order,
        ``datasets_by_category`` maps every known category to its dataset names,
        and ``dataset_sizes`` maps dataset name to its document count.
        If the root path is missing or scanning fails, whatever was collected so
        far is returned.
    """
    datasets_by_category = {category: [] for category in DATASET_CATEGORIES.keys()}
    all_datasets = []
    dataset_sizes = {}

    if not os.path.exists(MAIR_COMBINED_PATH):
        return all_datasets, datasets_by_category, dataset_sizes

    print(f"üîç Scanning: {MAIR_COMBINED_PATH}")
    try:
        for name in sorted(os.listdir(MAIR_COMBINED_PATH)):
            if name.startswith('.') or name.startswith('_'):
                continue
            root = os.path.join(MAIR_COMBINED_PATH, name)
            if not os.path.isdir(root):
                continue

            docs_dir = os.path.join(root, 'docs')
            queries_dir = os.path.join(root, 'queries')
            if not (os.path.exists(docs_dir) and os.path.exists(queries_dir)):
                continue

            doc_files = [f for f in os.listdir(docs_dir) if f.endswith('.jsonl')]
            query_files = [f for f in os.listdir(queries_dir) if f.endswith('.jsonl')]
            if not (doc_files and query_files):
                continue

            all_datasets.append(name)
            category = get_dataset_category(name)
            datasets_by_category[category].append(name)
            size = get_dataset_size(name)
            dataset_sizes[name] = size
            print(f"   ‚úÖ {name} [{category}] - {format_size(size)} docs")
    except Exception as e:
        print(f"   ‚ùå Error scanning directory: {e}")
    return all_datasets, datasets_by_category, dataset_sizes

def calculate_timing_stats(timing_list):
    """Summarise a list of timings as mean, median, min, max, std and total.

    An empty input yields all zeros. The standard deviation is the sample
    standard deviation and is reported as 0.0 when fewer than two values exist.
    """
    if not timing_list:
        return {"mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0, "total": 0.0}

    summary = {}
    summary["mean"] = statistics.mean(timing_list)
    summary["median"] = statistics.median(timing_list)
    summary["min"] = min(timing_list)
    summary["max"] = max(timing_list)
    if len(timing_list) > 1:
        summary["std"] = statistics.stdev(timing_list)
    else:
        summary["std"] = 0.0
    summary["total"] = sum(timing_list)
    return summary

def calculate_retrieval_metrics(retrieved_results, qrels, k_values=K_VALUES):
    """Compute mean NDCG@k, MAP@k, Recall@k and P@k over the evaluated queries.

    Relevance is binary: a document is relevant for a query when it is judged in
    `qrels` with a positive score. Judgments graded 0 (or below) are explicit
    non-relevance and are NOT counted as relevant. Scores that cannot be read as
    a number keep the document relevant.

    The ranking of each query is the key order of its `retrieved_results` entry,
    so callers must insert documents best-first. Queries that have no entry in
    `qrels`, or no relevant document, are skipped and do not affect the averages.

    Args:
        retrieved_results: ``{query_id: {doc_id: score}}`` in rank order.
        qrels: ``{query_id: {doc_id: relevance_score}}``.
        k_values: Cut-offs to evaluate.

    Returns:
        Tuple ``(ndcg, map, recall, precision)`` of dicts keyed "NDCG@k", "MAP@k",
        "Recall@k" and "P@k"; a metric with no evaluated query is 0.0.
    """
    def _is_relevant(score):
        try:
            return float(score) > 0
        except (TypeError, ValueError):
            # Non-numeric judgment: being listed is the only signal available.
            return True

    ndcg_scores = {f"NDCG@{k}": [] for k in k_values}
    map_scores = {f"MAP@{k}": [] for k in k_values}
    recall_scores = {f"Recall@{k}": [] for k in k_values}
    precision_scores = {f"P@{k}": [] for k in k_values}

    for qid, retrieved in retrieved_results.items():
        if qid not in qrels:
            continue
        relevant_docs = {
            doc_id for doc_id, score in qrels[qid].items() if _is_relevant(score)
        }
        if not relevant_docs:
            continue
        retrieved_docs = list(retrieved.keys())

        for k in k_values:
            top_k = retrieved_docs[:k]
            # Best achievable number of hits within the cut-off; normalises AP and NDCG.
            ideal_hits = min(len(relevant_docs), k)

            hits = 0
            ap = 0.0
            dcg = 0.0
            for rank, doc_id in enumerate(top_k, 1):
                if doc_id in relevant_docs:
                    hits += 1
                    ap += hits / rank
                    dcg += 1.0 / np.log2(rank + 1)

            recall_scores[f"Recall@{k}"].append(hits / len(relevant_docs))
            precision_scores[f"P@{k}"].append(hits / k if k > 0 else 0.0)
            # Guard the normaliser so a cut-off of 0 yields 0.0 instead of raising.
            map_scores[f"MAP@{k}"].append(ap / ideal_hits if ideal_hits > 0 else 0.0)

            idcg = sum(1.0 / np.log2(i + 2) for i in range(ideal_hits))
            ndcg_scores[f"NDCG@{k}"].append(dcg / idcg if idcg > 0 else 0.0)

    ndcg_avg = {k: np.mean(v) if v else 0.0 for k, v in ndcg_scores.items()}
    map_avg = {k: np.mean(v) if v else 0.0 for k, v in map_scores.items()}
    recall_avg = {k: np.mean(v) if v else 0.0 for k, v in recall_scores.items()}
    precision_avg = {k: np.mean(v) if v else 0.0 for k, v in precision_scores.items()}
    return ndcg_avg, map_avg, recall_avg, precision_avg

def format_and_print_metrics(ndcg, _map, recall, precision, k_values=K_VALUES):
    """Print one line per cut-off with NDCG, MAP, Recall and P to four decimals.

    Each argument is a dict keyed like "NDCG@10" / "MAP@10" / "Recall@10" / "P@10";
    a missing key is shown as 0.0000.
    """
    print("\nRetrieval Metrics:")
    print("‚îÄ" * 80)
    columns = (("NDCG", ndcg), ("MAP", _map), ("Recall", recall), ("P", precision))
    for k in k_values:
        cells = [
            f"{label}@{k}: {table.get(f'{label}@{k}', 0.0):.4f}"
            for label, table in columns
        ]
        print(" | ".join(cells))

def extract_all_metrics(ndcg, _map, recall, precision, k_values=K_VALUES):
    """Flatten the four metric dicts into one ``{name: float}`` dict.

    For every cut-off in `k_values` the result holds "NDCG@k", "MAP@k",
    "Recall@k" and "P@k", in that order. A metric that is missing or ``None``
    in its source dict is stored as 0.0; this applies uniformly to all four
    metric families.

    Returns:
        dict of plain Python floats, suitable for a CSV row.
    """
    def _as_float(table, key):
        value = table.get(key)
        return 0.0 if value is None else float(value)

    families = (("NDCG", ndcg), ("MAP", _map), ("Recall", recall), ("P", precision))
    metrics = {}
    for k in k_values:
        for label, table in families:
            key = f"{label}@{k}"
            metrics[key] = _as_float(table, key)
    return metrics

def save_results_to_csv(new_result: dict, csv_path: str):
    """Add one benchmark result row to the CSV at `csv_path`.

    If the file is missing (or empty) it is created with a header. Otherwise the
    existing rows are read, the new row is concatenated and the file is rewritten,
    so columns that only some runs provide are aligned by name.

    Args:
        new_result: Flat mapping of column name to value for this run.
        csv_path: Destination file. May be a bare filename, in which case the
            file is written to the current working directory.

    Read/write errors are reported on stdout rather than raised.
    """
    new_df = pd.DataFrame([new_result])
    parent_dir = os.path.dirname(csv_path)
    if parent_dir:
        # os.makedirs("") raises, so only create a directory when the path has one.
        os.makedirs(parent_dir, exist_ok=True)
    try:
        # A zero-byte file has no header to read; treat it like a missing file.
        if os.path.exists(csv_path) and os.path.getsize(csv_path) > 0:
            existing_df = pd.read_csv(csv_path)
            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
        else:
            combined_df = new_df
        combined_df.to_csv(csv_path, mode='w', header=True, index=False)
        print(f"üíæ Saved to: {csv_path}")
    except Exception as e:
        print(f"‚ùå CSV save failed: {e}")

def clean_memory():
    """Run a garbage-collection pass via `gc.collect()`.

    The benchmark calls this after each run (including its error path) so that
    large embedding lists it has just released can be reclaimed promptly, which
    matters in memory-constrained environments such as Colab.
    """
    gc.collect()

def should_cleanup_table(table_name, dataset_name):
    """Tell the caller whether the PGVector table should be dropped.

    "always_delete" returns True and "always_keep" returns False without asking.
    Any other CLEANUP_POLICY value prompts the user; only "y" or "yes"
    (case-insensitive, surrounding whitespace ignored) means delete.
    """
    if CLEANUP_POLICY == "always_delete":
        return True
    if CLEANUP_POLICY == "always_keep":
        return False

    # CLEANUP_POLICY == "ask_each_time"
    answer = input(f"\n‚ùì Delete PGVector table {table_name} for {dataset_name}? (y/n): ")
    return answer.strip().lower() in ('y', 'yes')

# -------------------- PGVector Provider Class ---------------------
# This class encapsulates the specific logic for interacting with PostgreSQL
# and the pgvector extension for dense vector benchmarking.

class PGVectorProvider:
    """Manage one pgvector table for a dense (float) vector benchmark run.

    Lifecycle: `upload()` (re)creates the table, inserts the precomputed document
    vectors in batches and builds an HNSW cosine index; `search()` runs one
    nearest-neighbour query per call and records its latency; `cleanup()` drops
    the table and closes the connection.

    NOTE: `table_name` is interpolated into SQL as an unquoted identifier, so it
    must come from trusted code and be a valid PostgreSQL identifier.
    """

    # pgvector documents 40 as the default `hnsw.ef_search` and 1000 as its upper
    # bound. An HNSW index scan returns at most `ef_search` rows, so a query with
    # LIMIT > 40 is silently truncated unless the setting is raised.
    DEFAULT_EF_SEARCH = 40
    MAX_EF_SEARCH = 1000

    def __init__(self, conn_params, table_name, precomputed_vectors, query_embeddings, vector_dim=1536):
        """Store the run configuration; no database connection is opened here.

        Args:
            conn_params: Keyword arguments for `psycopg2.connect`.
            table_name: Table to create/use for this run.
            precomputed_vectors: List of ``{"id": ..., "vector": ...}`` documents.
            query_embeddings: List of query vectors, indexed by query position.
            vector_dim: Dimension declared for the `vector(N)` column.
        """
        self.conn_params = conn_params
        self.table_name = table_name
        self.precomputed_vectors = precomputed_vectors
        self.query_embeddings = query_embeddings
        self.vector_dim = vector_dim
        self.conn = None  # opened lazily by connect()
        self.upload_timings = {
            "table_creation_s": 0.0,  # time to create the table
            "insert_time_s": 0.0,     # wall time of the whole insert loop
            "index_creation_s": 0.0,  # time to build the HNSW index
            "batch_details": []       # per-batch insert timings
        }
        self.search_timings = []  # one {"search_time_s": ...} entry per search() call
        self._ef_search_requested = None  # ef_search value already requested on this connection

    def connect(self):
        """Open the PostgreSQL connection and register the pgvector type adapter.

        Raises:
            Exception: Whatever `psycopg2.connect` / `register_vector` raised,
                after printing it.
        """
        try:
            self.conn = psycopg2.connect(**self.conn_params)
            register_vector(self.conn)
            # Session settings do not carry over to a new connection.
            self._ef_search_requested = None
            print(f"‚úÖ Connected to PostgreSQL database")
        except Exception as e:
            print(f"‚ùå Connection failed: {e}")
            raise

    def upload(self):
        """Recreate the table, insert all precomputed vectors and build the HNSW index.

        A failed batch is reported and rolled back, then the upload continues
        with the next batch. A failed index build is reported and the upload
        still completes (searches then run without the index).

        Returns:
            Number of vectors in successfully committed batches.

        Raises:
            Exception: If the table cannot be created.
        """
        if not self.conn:
            self.connect()

        cur = self.conn.cursor()
        try:
            # Start from a clean slate so earlier runs cannot affect this one.
            try:
                cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
                self.conn.commit()
            except Exception as e:
                print(f"Warning: could not drop existing table {self.table_name}: {e}")
                self.conn.rollback()

            # Create the table with a `vector(N)` column for the dense embeddings.
            t0 = time.time()
            try:
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
                cur.execute(f"""
                    CREATE TABLE {self.table_name} (
                        id TEXT PRIMARY KEY,                       -- Document ID
                        embedding vector({self.vector_dim})           -- Float embedding stored as a vector array
                    )
                """)
                self.conn.commit()
                t1 = time.time()
                self.upload_timings["table_creation_s"] = t1 - t0
                print(f"‚úÖ Created table: {self.table_name} (dim: {self.vector_dim})")
            except Exception as e:
                print(f"‚ùå Table creation failed: {e}")
                self.conn.rollback()
                raise

            # Insert in batches; `execute_values` sends each batch as one statement.
            batch_num = 0
            total_inserted = 0
            t_insert_start = time.time()

            for i in tqdm(range(0, len(self.precomputed_vectors), BATCH_SIZE), desc="Uploading to PGVector"):
                batch = self.precomputed_vectors[i:i+BATCH_SIZE]
                batch_num += 1

                t_batch_start = time.time()
                try:
                    values = [(v['id'], v['vector']) for v in batch]
                    # `::vector` in the template casts each embedding to the pgvector type.
                    psycopg2.extras.execute_values(
                        cur,
                        f"INSERT INTO {self.table_name} (id, embedding) VALUES %s",
                        values,
                        template="(%s, %s::vector)"
                    )
                    self.conn.commit()

                    batch_time = time.time() - t_batch_start
                    total_inserted += len(batch)

                    self.upload_timings["batch_details"].append({
                        "batch_num": batch_num,
                        "batch_size": len(batch),
                        "insert_time_s": batch_time
                    })

                except Exception as e:
                    print(f"\n‚ùå Batch {batch_num} failed: {e}")
                    self.conn.rollback()  # discard only the failed batch

            t_insert_end = time.time()
            self.upload_timings["insert_time_s"] = t_insert_end - t_insert_start

            # Build the HNSW index for approximate nearest-neighbour search with
            # cosine distance (`vector_cosine_ops`).
            print(f"\nüîß Creating HNSW index...")
            t_index_start = time.time()
            try:
                # ‚û°Ô∏è USER CONFIGURATION: You can adjust `m` and `ef_construction` for different performance/accuracy tradeoffs.
                # Higher `m` and `ef_construction` lead to better accuracy but longer build times and more memory.
                cur.execute(f"""
                    CREATE INDEX ON {self.table_name}
                    USING hnsw (embedding vector_cosine_ops)
                    WITH (m = 16, ef_construction = 64)
                """)
                self.conn.commit()
                t_index_end = time.time()
                self.upload_timings["index_creation_s"] = t_index_end - t_index_start
                print(f"‚úÖ HNSW index created in {self.upload_timings['index_creation_s']:.4f}s")
            except Exception as e:
                print(f"‚ö†Ô∏è Index creation warning: {e}. Indexing is important for search performance. Proceeding without index.")
                self.conn.rollback()

            print(f"\n‚è±Ô∏è  PGVector Upload Timing:")
            print(f"    Table Creation: {self.upload_timings['table_creation_s']:.4f}s")
            print(f"    Insert Total: {self.upload_timings['insert_time_s']:.4f}s")
            print(f"    Index Creation: {self.upload_timings['index_creation_s']:.4f}s")
            print(f"    Total: {sum([self.upload_timings['table_creation_s'], self.upload_timings['insert_time_s'], self.upload_timings['index_creation_s']]):.4f}s")

            return total_inserted
        finally:
            # Release the cursor on every path, including a failed table creation.
            cur.close()

    def _ensure_ef_search(self, cur, top_k):
        """Raise `hnsw.ef_search` for this session so that `top_k` rows can be returned.

        Does nothing when the value already in effect is large enough. The
        setting is committed so that a later rollback cannot undo it. Failures
        are reported once and do not stop the search.
        """
        try:
            needed = min(int(top_k), self.MAX_EF_SEARCH)
        except (TypeError, ValueError):
            return
        current = getattr(self, "_ef_search_requested", None) or self.DEFAULT_EF_SEARCH
        if needed <= current:
            return
        # Remember the request even if it fails, so it is not retried per query.
        self._ef_search_requested = needed
        try:
            cur.execute("SET hnsw.ef_search = %s", (needed,))
            self.conn.commit()
        except Exception as e:
            print(f"Warning: could not set hnsw.ef_search={needed}: {e}")
            self.conn.rollback()

    def search(self, query_idx, top_k=100):
        """Return the `top_k` nearest documents for one query by cosine similarity.

        Before the timed section, `hnsw.ef_search` is raised to `top_k` when
        needed, because pgvector's HNSW scan returns at most `ef_search` rows.

        Args:
            query_idx: Index into `query_embeddings`.
            top_k: Maximum number of results.

        Returns:
            ``{doc_id: similarity}`` in rank order (best first), where similarity
            is ``1 - cosine distance``. On a database error the error is printed,
            the transaction is rolled back so later searches still work, a 0.0
            timing is recorded and ``{}`` is returned.
        """
        if not self.conn:
            self.connect()

        query_embedding = self.query_embeddings[query_idx]
        cur = self.conn.cursor()
        try:
            self._ensure_ef_search(cur, top_k)

            t_start = time.time()
            # `<=>` is pgvector's cosine-distance operator; ordering by it puts the
            # most similar rows first.
            cur.execute(f"""
                SELECT id, 1 - (embedding <=> %s::vector) as similarity
                FROM {self.table_name}
                ORDER BY embedding <=> %s::vector
                LIMIT %s
            """, (query_embedding, query_embedding, top_k))
            results = cur.fetchall()
            t_end = time.time()

            self.search_timings.append({"search_time_s": t_end - t_start})
            return {str(row[0]): float(row[1]) for row in results}

        except Exception as e:
            print(f"\n‚ùå Search error: {e}")
            # A failed statement aborts the open transaction; without a rollback
            # every following query on this connection would fail as well.
            try:
                self.conn.rollback()
            except Exception as rollback_error:
                print(f"Warning: rollback after search error failed: {rollback_error}")
            self.search_timings.append({"search_time_s": 0.0})  # failed searches are logged as 0
            return {}
        finally:
            cur.close()

    def get_search_stats(self):
        """Return summary statistics over all recorded search timings."""
        search_times = [t["search_time_s"] for t in self.search_timings]
        return calculate_timing_stats(search_times)

    def cleanup(self):
        """Drop the benchmark table and close the connection.

        Does nothing when no connection is open. A failed DROP is reported; the
        connection is closed and forgotten either way.
        """
        if not self.conn:
            return

        try:
            cur = self.conn.cursor()
            try:
                cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
                self.conn.commit()
            finally:
                cur.close()
            print(f"üßπ Deleted table: {self.table_name}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to delete table: {e}")
        finally:
            try:
                self.conn.close()
            finally:
                self.conn = None
                self._ef_search_requested = None

# -------------------- Main Benchmark Function ---------------------
# This function orchestrates the entire benchmarking process for a single dataset
# with pgvector, including data loading, embedding generation, upload, search, and evaluation.

def _extract_doc_text(doc_content):
    """Pick the text to embed for one corpus item; never returns an empty value."""
    if isinstance(doc_content, dict):
        text = None
        # Prefer well-known text fields, in priority order.
        for field in ['doc', 'text', 'content', 'body', 'passage', 'document', 'title', 'abstract']:
            val = doc_content.get(field)
            if isinstance(val, str) and val.strip():
                text = val
                break
        # Otherwise take the first string value that is longer than 10 characters.
        if not text:
            for val in doc_content.values():
                if isinstance(val, str) and val.strip() and len(val) > 10:
                    text = val
                    break
        if not text:
            text = str(doc_content)
    else:
        text = str(doc_content)
    return text if text else "document"


def _extract_query_text(q):
    """Pick the text to embed for one query record; never returns an empty value."""
    if isinstance(q, dict):
        text = (q.get('text') or q.get('query') or q.get('instruction') or q.get('question') or q.get('query_text') or str(q))
    else:
        text = str(q)
    return text if text else "query"


def run_benchmark_pgvector(dataset_name, pg_conn_params):
    """Run the full PGVector benchmark for one MAIR dataset.

    Steps: load docs/queries/qrels from MAIR_COMBINED_PATH, embed up to
    MAX_UPLOAD_DOCS documents and all queries with Cohere, upload the vectors to
    a fresh pgvector table, search every query, compute retrieval metrics, append
    the result row to CSV_PATH and apply the table cleanup policy.

    Args:
        dataset_name: Name of the dataset folder under MAIR_COMBINED_PATH.
        pg_conn_params: Keyword arguments for the PostgreSQL connection.

    Returns:
        The result dict written to the CSV, or ``None`` when the dataset is
        missing/empty or the upload/search/evaluation phase fails. Errors raised
        while embedding propagate to the caller.

    The database connection opened for the run is always closed before
    returning, whether or not the table is kept.
    """
    print(f"\n{'='*70}")
    print(f"üöÄ Dataset: {dataset_name}")
    print(f"üìä Testing PGVector with Cohere {EMBEDDING_MODEL}")
    print(f"{'='*70}")

    dataset_path = os.path.join(MAIR_COMBINED_PATH, dataset_name)
    docs_path = os.path.join(dataset_path, 'docs')
    queries_path = os.path.join(dataset_path, 'queries')

    if not os.path.exists(docs_path) or not os.path.exists(queries_path):
        print(f"‚ö†Ô∏è Dataset not found")
        return None

    print(f"\nüì• Loading dataset...")
    corpus = {}
    docs_files = [f for f in os.listdir(docs_path) if f.endswith('.jsonl')]
    for file in docs_files:
        corpus.update(load_jsonl(os.path.join(docs_path, file)))

    queries = {}
    query_files = [f for f in os.listdir(queries_path) if f.endswith('.jsonl')]
    for file in query_files:
        queries.update(load_jsonl(os.path.join(queries_path, file)))

    qrels = extract_qrels_from_queries(queries)
    print(f"‚úÖ Loaded: {len(corpus)} docs, {len(queries)} queries, {len(qrels)} qrels")

    if not corpus or not queries or not qrels:
        print(f"‚ö†Ô∏è Empty dataset")
        return None

    print(f"\nüß† Generating embeddings with Cohere {EMBEDDING_MODEL}...")
    # Only the first MAX_UPLOAD_DOCS documents are embedded and uploaded.
    docs = list(corpus.items())[:MAX_UPLOAD_DOCS]
    texts = [_extract_doc_text(doc_content) for _, doc_content in docs]
    doc_ids = [str(d[0]) for d in docs]

    print(f"  üìÑ Generating corpus embeddings for {len(texts)} documents...")
    t_embed_start = time.time()

    corpus_embeddings = []
    for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="Embedding corpus"):
        batch_texts = texts[i:i+BATCH_SIZE]
        response = cohere_client.embed(texts=batch_texts, model=EMBEDDING_MODEL, input_type=INPUT_TYPE_CORPUS)
        corpus_embeddings.extend(response.embeddings)

    t_embed_end = time.time()
    embedding_time = t_embed_end - t_embed_start

    # Shape expected by PGVectorProvider: one {"id", "vector"} dict per document.
    precomputed_vectors = [
        {"id": doc_id, "vector": embedding}
        for doc_id, embedding in zip(doc_ids, corpus_embeddings)
    ]

    print(f"  üîç Generating query embeddings for {len(queries)} queries...")
    query_ids = list(queries.keys())
    query_texts = [_extract_query_text(queries[qid]) for qid in query_ids]

    t_query_start = time.time()
    query_embeddings = []
    for i in tqdm(range(0, len(query_texts), BATCH_SIZE), desc="Embedding queries"):
        # Query text fields may hold non-string values; the API needs strings.
        batch_texts = [str(t) if t else "query" for t in query_texts[i:i+BATCH_SIZE]]
        response = cohere_client.embed(texts=batch_texts, model=EMBEDDING_MODEL, input_type=INPUT_TYPE_QUERY)
        query_embeddings.extend(response.embeddings)

    t_query_end = time.time()
    query_embedding_time = t_query_end - t_query_start

    print(f"\n‚è±Ô∏è  Embedding Generation Time:")
    print(f"    Corpus: {embedding_time:.4f}s, Queries: {query_embedding_time:.4f}s, Total: {embedding_time + query_embedding_time:.4f}s")

    # PostgreSQL identifiers are limited to 63 characters.
    safe_dataset_name = dataset_name.replace('/', '_').replace('-', '_').lower()
    table_name = f"mair_{safe_dataset_name}_v4"[:63]

    provider = None
    try:
        provider = PGVectorProvider(
            conn_params=pg_conn_params,
            table_name=table_name,
            precomputed_vectors=precomputed_vectors,
            query_embeddings=query_embeddings,
            vector_dim=VECTOR_DIMENSION
        )

        print(f"\nüì§ Uploading vectors to PGVector table: {table_name}...")
        num_uploaded = provider.upload()

        print(f"\nüîç Performing vector similarity search...")
        results = {}  # {query_id: {doc_id: score}}
        for i, qid in enumerate(tqdm(query_ids, desc="Searching PGVector")):
            try:
                results[qid] = provider.search(i, top_k=TOP_K_SEARCH)
            except Exception as e:
                print(f"\n‚ùå Query {i+1} failed: {e}")
                results[qid] = {}

        search_stats = provider.get_search_stats()
        print(f"\n‚è±Ô∏è  PGVector Search Timing:")
        print(f"    Total: {search_stats['total']:.4f}s")
        print(f"    Mean: {search_stats['mean']:.4f}s ({search_stats['mean']*1000:.2f}ms)")
        print(f"    Median: {search_stats['median']:.4f}s")
        print(f"    Min: {search_stats['min']:.4f}s, Max: {search_stats['max']:.4f}s")

        print(f"\nüìä Evaluating retrieval quality...")
        ndcg, _map, recall, precision = calculate_retrieval_metrics(results, qrels, K_VALUES)
        format_and_print_metrics(ndcg, _map, recall, precision)

        metrics = extract_all_metrics(ndcg, _map, recall, precision)
        dataset_category = get_dataset_category(dataset_name)

        result = {
            "Dataset": dataset_name,
            "Category": dataset_category,
            "Provider": "pgvector",
            "Num_Corpus": len(corpus),
            "Num_Uploaded": num_uploaded,
            "Num_Queries": len(queries),
            "Embedding_Model": EMBEDDING_MODEL,
            "Vector_Dimension": VECTOR_DIMENSION,
            "Batch_Size": BATCH_SIZE,
            "Cleanup_Policy": CLEANUP_POLICY,
            "Embedding_Generation_Time_s": round(embedding_time, 4),
            "Query_Embedding_Time_s": round(query_embedding_time, 4),
            "Total_Embedding_Time_s": round(embedding_time + query_embedding_time, 4),
            "Upload_Table_Creation_s": round(provider.upload_timings["table_creation_s"], 4),
            "Upload_Insert_Total_s": round(provider.upload_timings["insert_time_s"], 4),
            "Upload_Index_Creation_s": round(provider.upload_timings["index_creation_s"], 4),
            "Upload_Total_s": round(
                provider.upload_timings["table_creation_s"] +
                provider.upload_timings["insert_time_s"] +
                provider.upload_timings["index_creation_s"], 4
            ),
            "Search_Total_s": round(search_stats['total'], 4),
            "Search_Mean_s": round(search_stats['mean'], 4),
            "Search_Median_s": round(search_stats['median'], 4),
            "Search_Min_s": round(search_stats['min'], 4),
            "Search_Max_s": round(search_stats['max'], 4),
            "Search_Std_s": round(search_stats['std'], 4),
            "Search_Mean_ms": round(search_stats['mean'] * 1000, 2),
            "Index_Type": "HNSW",
            "Similarity_Metric": "cosine"
        }

        result.update(metrics)
        save_results_to_csv(result, CSV_PATH)

        should_delete = should_cleanup_table(table_name, dataset_name)
        if should_delete:
            provider.cleanup()
        else:
            print(f"üíæ Keeping PGVector table: {table_name}")

        del results
        clean_memory()

        print(f"\n‚úÖ PGVector benchmark completed for {dataset_name}")
        return result

    except Exception as e:
        print(f"‚ùå Error during benchmark: {e}")
        import traceback
        traceback.print_exc()
        clean_memory()
        return None
    finally:
        # Release the connection on every path. cleanup() has already closed it
        # when the table was dropped; the keep-table and error paths have not.
        # Closing does not drop the table: all uploaded data was committed.
        if provider is not None and provider.conn is not None:
            try:
                provider.conn.close()
            except Exception as close_error:
                print(f"Warning: failed to close PostgreSQL connection: {close_error}")
            provider.conn = None
        # Drop the large embedding lists regardless of success or failure.
        del provider
        del precomputed_vectors
        del query_embeddings
        del corpus_embeddings
        clean_memory()


# -------------------- MAIN EXECUTION BLOCK ---------------------
# This block handles the overall flow of the benchmark, including API key checks,
# client initialization, dataset discovery, and the interactive loop for running benchmarks.

# Banner announcing the start of the benchmark script.
print("\n" + "="*70)
print("üéØ MAIR BENCHMARK - PGVECTOR EDITION (Google Colab)")
print("="*70)

print("\nüîë Checking API Keys and Database Connection...")

# Check Cohere API key availability.
if not COHERE_API_KEY:
    print("\n‚ùå Cohere API key required!")
    print("Add it to Colab Secrets with key 'COHERE_API_KEY'")
    # Raise SystemExit directly rather than calling exit(): exit() is a helper
    # injected by the `site` module for interactive shells and is not guaranteed
    # to halt execution at this line in every environment (e.g. notebook kernels).
    raise SystemExit(1) # Stop execution if key is missing

import cohere
cohere_client = cohere.Client(COHERE_API_KEY) # Initialize Cohere client.
print("‚úÖ Cohere client initialized")

# Test PostgreSQL connection.
# Opens a throwaway connection with PG_CONN_PARAMS, reports the server version and
# the installed pgvector extension version (if any), and aborts the script when
# any step of the probe raises.
test_conn = None
try:
    test_conn = psycopg2.connect(**PG_CONN_PARAMS) # Attempt to connect to DB
    register_vector(test_conn) # Register vector type
    # The cursor is closed automatically when the `with` block exits, even on error.
    with test_conn.cursor() as cur:
        cur.execute("SELECT version();") # Get PostgreSQL version
        pg_version = cur.fetchone()[0]
        cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector';") # Get pgvector version
        pgvector_version = cur.fetchone() # Falsy when the query returned no row
    print("‚úÖ PostgreSQL connection successful")
    print(f"   PostgreSQL: {pg_version.split(',')[0]}")
    if pgvector_version:
        print(f"   pgvector: v{pgvector_version[0]}")
    else:
        print("   ‚ö†Ô∏è pgvector extension not found")
except Exception as e:
    print(f"‚ùå PostgreSQL connection failed: {e}")
    # SystemExit is raised directly; exit() is an interactive-shell helper and
    # is not guaranteed to stop execution here in every environment.
    raise SystemExit(1) # Stop execution if DB connection fails
finally:
    # Always release the probe connection, including when a query above failed.
    if test_conn is not None:
        test_conn.close()

# Discover available MAIR datasets.
# get_mair_datasets() returns three values, unpacked here as: the collection of
# dataset names, a mapping of category -> dataset names, and a mapping of
# dataset name -> size (used below for sorting and display).
print("\nüìä Discovering datasets...")
mair_datasets, datasets_by_category, dataset_sizes = get_mair_datasets() # Use helper function

if not mair_datasets: # If no datasets are found, inform the user and exit.
    print("‚ùå No datasets found!")
    print(f"Please ensure datasets are in: {MAIR_COMBINED_PATH}")
    # Raise SystemExit directly; exit() is an interactive-shell helper from the
    # `site` module and is not guaranteed to halt execution at this line.
    raise SystemExit(1)

print(f"\nüìÇ Available datasets by category ({len(mair_datasets)} total):")
dataset_map = {} # Menu number -> dataset name, consumed by the interactive loop
dataset_index = 1 # Next menu number to hand out (numbering is global across categories)

# Walk the categories alphabetically; inside each one list the largest datasets first.
for category, members in sorted(datasets_by_category.items(), key=lambda item: item[0]):
    if not members:
        continue # Nothing to show for an empty category

    ranked = sorted(members, key=lambda name: dataset_sizes.get(name, 0), reverse=True)
    print(f"\n  üìÅ {category} ({len(ranked)} datasets):")

    for dataset in ranked:
        doc_count = dataset_sizes.get(dataset, 0) # Unknown sizes are shown as 0
        print(f"     {dataset_index}. {dataset} ({format_size(doc_count)} docs)")
        dataset_map[dataset_index] = dataset
        dataset_index += 1

all_results = [] # List to store all benchmark results.

# Main interactive benchmarking loop.
# Repeatedly prompts for a dataset number, runs the pgvector benchmark for it, and
# collects each truthy result. The loop ends when the user types 'stop', or when
# the prompt itself is interrupted / stdin is exhausted.
print("\n" + "="*70)
print("üöÄ READY TO BENCHMARK!")
print("="*70)

while True:
    try:
        choice = input(f"\n‚û°Ô∏è Enter dataset number (1-{len(dataset_map)}) or 'stop': ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        # Treat a closed stdin or an interrupt at the prompt like 'stop', so the
        # summary for runs that already completed is still printed below.
        break

    if choice == "stop":
        break # Exit loop if user types 'stop'

    try:
        idx = int(choice)
    except ValueError:
        print("‚ö†Ô∏è Invalid input. Enter a number or 'stop'")
        continue # Ask for input again

    if idx not in dataset_map:
        print(f"‚ö†Ô∏è Invalid number. Please choose 1-{len(dataset_map)}")
        continue # Ask for input again

    dataset_name = dataset_map[idx] # Get the selected dataset name

    # Run the benchmark for the selected dataset.
    result = run_benchmark_pgvector(dataset_name, PG_CONN_PARAMS)
    if result: # If the benchmark run was successful, add its result.
        all_results.append(result)

# -------------------- Final Summary and Analysis ---------------------
# Runs once the interactive loop has ended. When at least one benchmark produced a
# result, prints a per-dataset table, the mean of the main metrics across runs, and
# a short recap; otherwise prints a single notice. A closing tip is always printed.

if not all_results:
    print("\n‚ö†Ô∏è No results to display")
else:
    rule = "=" * 70
    print("\n" + rule)
    print("üèÅ BENCHMARK COMPLETE!")
    print(rule)

    df = pd.DataFrame(all_results) # One row per completed benchmark run.
    print(f"\nüíæ Results saved to: {CSV_PATH}")

    # Per-dataset table, restricted to the columns actually present in the results.
    print("\nüìä Summary Table:")
    table_cols = [
        name
        for name in ('Dataset', 'Category', 'NDCG@10', 'MAP@10', 'Recall@100', 'Search_Mean_ms')
        if name in df.columns
    ]
    print(df[table_cols].to_string(index=False))

    # Mean of each available numeric metric across all runs.
    print("\nüìà Average Performance:")
    quality_metrics = ('NDCG@10', 'MAP@10', 'Recall@100')
    timing_metrics = (
        ('Search_Mean_ms', 'Avg Search Time', 'ms'),
        ('Upload_Total_s', 'Avg Upload Time', 's'),
    )
    wanted = list(quality_metrics) + [column for column, _, _ in timing_metrics]
    present = [column for column in wanted if column in df.columns]
    if present:
        means = df[present].mean()

        print("\n  Retrieval Quality:")
        for column in quality_metrics:
            if column in present:
                print(f"    {column}: {means[column]:.4f}")

        print("\n  Performance:")
        for column, label, unit in timing_metrics:
            if column in present:
                print(f"    {label}: {means[column]:.2f}{unit}")

    print("\n" + rule)
    print("‚ú® PGVector Benchmark Complete!")
    print(rule)
    print("\nüìä Key Insights:")
    insights = (
        f"  ‚Ä¢ Tested {len(all_results)} dataset(s)",
        f"  ‚Ä¢ Using {EMBEDDING_MODEL} embeddings ({VECTOR_DIMENSION}D)",
        "  ‚Ä¢ HNSW index for fast similarity search",
        f"  ‚Ä¢ Results: {CSV_PATH}",
    )
    for line in insights:
        print(line)

# Closing reminder, shown whether or not any benchmark produced results.
print("\nüí° Tip: PostgreSQL + pgvector is running locally in this Colab instance.")
print("   Tables are temporary and will be lost when the runtime disconnects.")