In [None]:
"""
# PRO-HNSW: Main Performance Evaluation

This notebook runs the core experiments that compare PRO-HNSW against standard HNSW under dynamic-update (delete + reinsert) scenarios.

The workflow is as follows:
1.  **Configuration**: Set up all parameters for the experiment.
2.  **Core Functions**: Define the functions for building, updating, and evaluating the index.
3.  **Main Loop**: Iterate over datasets and deletion ratios, run the experiments, and save the results to a CSV file.
"""

# ===================================================================
# 📝 1. CONFIGURATION SECTION
# All user-configurable parameters are defined here.
# ===================================================================

# ────────── Global Experiment Settings ──────────
# NOTE: For a quick test run, set BASE_SIZE_LIMIT and QUERY_SIZE_LIMIT to a small number (e.g., 10000).
# For the full experiment, set them to None (use every base/query vector in the dataset).
BASE_SIZE_LIMIT = None
QUERY_SIZE_LIMIT = None
RECALCULATE_GT = False      # Set to True to re-calculate ground truth (automatically enabled if BASE_SIZE_LIMIT is used)
TRIALS_LATENCY = 3          # Number of repeated knn_query passes per ef_search value (QPS and recall are averaged over them)
DELETION_RATIOS = [0.2, 0.8] # Fractions of the base set to delete and then reinsert, one run per value

# ────────── HNSW Parameter Tiers ──────────
# Defines different sets of HNSW parameters for different dataset sizes.
# Keys:
#   "M", "Efc"  -> passed to hnswlib init_index as M and ef_construction.
#   "rein"      -> ef value applied with idx.set_ef() right before deleted elements are reinserted.
#   "efs_range" -> range() arguments (start, stop-exclusive, step) for the ef_search sweep,
#                  e.g. (25, 76, 1) evaluates ef_search = 25..75.
HNSW_PARAMS_TIERS = {
    "small":  {"M": 8,  "Efc": 50,  "rein": 25, "efs_range": (25, 76, 1)},
    "medium": {"M": 12, "Efc": 75,  "rein": 50, "efs_range": (50, 101, 1)},
    "large":  {"M": 16, "Efc": 100, "rein": 75, "efs_range": (75, 126, 1)},
}
DEFAULT_HNSW_TIER = "medium"  # Used for datasets (or tier names) not found in the tables

# ────────── Dataset Profiles ──────────
# Maps each dataset to an HNSW parameter tier.
# Datasets missing from this mapping fall back to DEFAULT_HNSW_TIER.
DATASET_HNSW_TIER_PROFILE = {
    # Large Tier
    "deep-image-96-angular": "large",
    "gist-960-euclidean": "large",
    # Medium Tier
    "sift-128-euclidean": "medium",
    "nytimes-256-angular": "medium",
    # Small Tier
    "fashion-mnist-784-euclidean": "small",
    "mnist-784-euclidean": "small",
    "coco-i2i-512-angular": "small",
    "glove-25-angular": "small",
}

# ────────── File Paths and Constants ──────────
# NOTE: Paths are relative to the current working directory: datasets are read from '../data'
# and results are written to '../results/bulk' (the directory is created if missing).
DATA_ROOT_HDF5 = "../data"
RESULTS_DIR = "../results/bulk"
N_THREADS = 18       # Worker threads for the brute-force ground-truth computation
BF_CHUNK_SIZE = 600  # Number of queries scored per brute-force ground-truth task
RECALL_K = 10        # k used both for the ground truth and for recall@k

# ===================================================================
# 🛠️ 2. SETUP & LIBRARY IMPORTS
# ===================================================================
import os
import time
import csv
import heapq
import numpy as np
import hnswlib
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
# This assumes an 'hdf5_dataset_loader.py' module (providing `load_dataset`) inside a 'data'
# package that is importable from the working directory. Adjust the import path if necessary.
from data.hdf5_dataset_loader import load_dataset

# Create the output directory up front and stamp this run, so every result
# file written by this notebook execution shares one identifier.
os.makedirs(RESULTS_DIR, exist_ok=True)
TIMESTAMP = f"{datetime.now():%Y%m%d_%H%M%S}"

# ===================================================================
# ⚙️ 3. DYNAMIC CONFIGURATION GENERATION
# This block automatically creates the final experiment configurations.
# Do not modify this block directly. Change the settings in Section 1 instead.
# ===================================================================

# Per-dataset metric space. HNSW parameters are attached below from the tier tables.
CFG_BASE_CONFIG = {
    "sift-128-euclidean": {"space": "l2"},
    "gist-960-euclidean": {"space": "l2"},
    "deep-image-96-angular": {"space": "cosine"},
    "glove-25-angular": {"space": "cosine"},
    "mnist-784-euclidean": {"space": "l2"},
    "fashion-mnist-784-euclidean": {"space": "l2"},
    "nytimes-256-angular": {"space": "cosine"},
    "coco-i2i-512-angular": {"space": "cosine"},
}

# Final per-dataset configuration: the base settings, overlaid with the tier's
# HNSW parameters, plus the expanded ef_search sweep and the global size limits.
CFG = {}
for dname_key, base_settings in CFG_BASE_CONFIG.items():
    tier_name = DATASET_HNSW_TIER_PROFILE.get(dname_key, DEFAULT_HNSW_TIER)
    # An unknown tier name falls back to the default tier's parameters.
    tier_params = HNSW_PARAMS_TIERS.get(tier_name, HNSW_PARAMS_TIERS[DEFAULT_HNSW_TIER])

    # Tier parameters override base settings; derived entries are appended last.
    CFG[dname_key] = {**base_settings, **tier_params}
    CFG[dname_key]["efs_list"] = list(range(*tier_params["efs_range"]))
    CFG[dname_key]["base_size_limit"] = BASE_SIZE_LIMIT
    CFG[dname_key]["query_size_limit"] = QUERY_SIZE_LIMIT

print("✅ Configuration generated for the following datasets:")
for dname in CFG:
    print(f"- {dname}")

# ===================================================================
# 🔬 4. CORE FUNCTIONS
# Helper functions for normalization, index building, updating, ground truth calculation, and evaluation.
# ===================================================================

def normalize_vectors(vectors: np.ndarray) -> np.ndarray:
    """Scale a vector (1-D) or each row of a matrix (2-D) to unit L2 length.

    ``None`` and arrays whose first dimension is empty are returned unchanged.
    Norms below 1e-9 are replaced by 1e-9 to avoid division by zero, so
    all-zero vectors stay all-zero.

    Raises:
        ValueError: if ``vectors`` has more than two dimensions.
    """
    # Nothing to scale: pass None / empty input straight through.
    if vectors is None or vectors.shape[0] == 0:
        return vectors

    eps = 1e-9
    if vectors.ndim == 1:
        length = np.linalg.norm(vectors)
        divisor = length if length > eps else eps
        return vectors / divisor

    if vectors.ndim == 2:
        row_lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
        safe_lengths = np.where(row_lengths < eps, eps, row_lengths)
        return vectors / safe_lengths

    raise ValueError("normalize_vectors only supports 1D or 2D arrays.")

def build_index(base_vectors: np.ndarray, labels_arr: np.ndarray, M_val: int, Efc_val: int, space_val: str):
    """Build a fresh hnswlib index over ``base_vectors``.

    Args:
        base_vectors: 2-D array of base vectors; the column count is the index dimension.
        labels_arr: Label for each row of ``base_vectors``.
        M_val: HNSW ``M`` parameter.
        Efc_val: HNSW ``ef_construction`` parameter.
        space_val: hnswlib space name (e.g. ``"l2"``, ``"cosine"``).

    Returns:
        The populated ``hnswlib.Index``, sized exactly to ``len(base_vectors)``.
    """
    dim = base_vectors.shape[1]
    capacity = len(base_vectors)
    print(f"  Building HNSW index with space='{space_val}', M={M_val}, efC={Efc_val}...")

    index = hnswlib.Index(space_val, dim)
    index.init_index(
        max_elements=capacity,
        ef_construction=Efc_val,
        M=M_val,
        # Needed so update_index can reinsert into deleted slots (add_items(..., replace_deleted=True)).
        allow_replace_deleted=True,
    )

    # Hand hnswlib C-contiguous float32 data (no copy if it already is).
    index.add_items(np.ascontiguousarray(base_vectors, dtype=np.float32), labels_arr)
    print("  HNSW index build complete.")
    return index

def update_index(idx: hnswlib.Index, base_vectors: np.ndarray, del_ids_arr: np.ndarray, rein_val: int, optimized=False, dname="", delta_val=0.0):
    """Delete ``del_ids_arr`` from ``idx`` and reinsert the same vectors in place.

    Steps:
      1. Mark every label in ``del_ids_arr`` as deleted.
      2. (``optimized`` only) ``remove_obsolete_edges`` for the deleted labels.
      3. Reinsert ``base_vectors[label]`` for each deleted label, in random order,
         reusing the deleted slots (``replace_deleted=True``).
      4. (``optimized`` only) ``repair_disconnected_nodes``, then ``resolve_edge_asymmetry``.

    Args:
        idx: Index created with ``allow_replace_deleted=True``. It is mutated in place.
        base_vectors: Base matrix; reinserted vectors are looked up as ``base_vectors[label]``.
        del_ids_arr: Labels to delete and reinsert (may be empty).
        rein_val: ef value applied with ``idx.set_ef`` before reinsertion.
        optimized: Whether to run the PRO-HNSW repair steps (2 and 4).
        dname, delta_val: Used only in log messages.

    Returns:
        ``(idx, opt_metrics)``: the same index object, and a dict holding the elapsed
        seconds (``time_roe``/``time_rdn``/``time_rea``) and the integer counts
        (``metric_*``) returned by each PRO-HNSW step. All values are zero when
        ``optimized`` is False.
    """
    opt_metrics = {
        "time_roe": 0.0, "metric_roe": 0,
        "time_rdn": 0.0, "metric_rdn": 0,
        "time_rea": 0.0, "metric_rea": 0,
    }
    has_deletions = len(del_ids_arr) > 0

    if has_deletions:
        print(f"  Marking {len(del_ids_arr)} elements for deletion for {dname} (delta={delta_val})...")
        for d_id in del_ids_arr:
            idx.mark_deleted(d_id)

    # perf_counter is monotonic and high-resolution, unlike the wall-clock time.time();
    # these timings are written to the results CSV.
    if optimized:
        t_start_roe = time.perf_counter()
        removed_count = idx.remove_obsolete_edges(del_ids_arr.tolist() if has_deletions else [])
        opt_metrics["time_roe"] = time.perf_counter() - t_start_roe
        opt_metrics["metric_roe"] = int(removed_count)
        if removed_count > 0 or opt_metrics["time_roe"] > 0.001:
            print(f"      [PRO-HNSW] remove_obsolete_edges: {removed_count} edges removed, time: {opt_metrics['time_roe']:.4f}s")

    if has_deletions:
        print(f"  Reinserting {len(del_ids_arr)} elements with ef_construction_for_reinsert={rein_val}...")
        # NOTE(review): in stock hnswlib, set_ef sets the *search* ef, and insertions use
        # ef_construction. Confirm that the PRO-HNSW fork uses this ef during
        # replace_deleted reinsertion; otherwise rein_val has no effect here.
        idx.set_ef(rein_val)

        shuf = np.random.permutation(del_ids_arr)
        reinsert_data = np.ascontiguousarray(base_vectors[shuf], dtype=np.float32)
        idx.add_items(reinsert_data, shuf.astype(np.int32), replace_deleted=True)
        print("  Reinsertion complete.")
    else:
        print("  No elements marked for deletion, skipping reinsertion.")

    if optimized:
        t_start_rdn = time.perf_counter()
        repaired_nodes_count = idx.repair_disconnected_nodes()
        opt_metrics["time_rdn"] = time.perf_counter() - t_start_rdn
        opt_metrics["metric_rdn"] = int(repaired_nodes_count)
        if repaired_nodes_count > 0 or opt_metrics["time_rdn"] > 0.001:
            print(f"      [PRO-HNSW] repair_disconnected_nodes: {repaired_nodes_count} nodes processed, time: {opt_metrics['time_rdn']:.4f}s")

        t_start_rea = time.perf_counter()
        resolved_edges_count = idx.resolve_edge_asymmetry()
        opt_metrics["time_rea"] = time.perf_counter() - t_start_rea
        opt_metrics["metric_rea"] = int(resolved_edges_count)
        if resolved_edges_count > 0 or opt_metrics["time_rea"] > 0.001:
            print(f"      [PRO-HNSW] resolve_edge_asymmetry: {resolved_edges_count} edges fixed, time: {opt_metrics['time_rea']:.4f}s")

    return idx, opt_metrics

# Ground-truth and evaluation helpers: exact brute-force k-NN (used as the recall ground truth)
# and recall/QPS measurement of an index.

def _calculate_scores_for_chunk(query_vectors_slice: np.ndarray, base_vectors_chunk: np.ndarray, space_type: str) -> np.ndarray:
    if space_type == 'l2':
        diff = query_vectors_slice[:, np.newaxis, :] - base_vectors_chunk[np.newaxis, :, :]
        return np.linalg.norm(diff, axis=2)
    elif space_type == 'ip':
        return -np.dot(query_vectors_slice, base_vectors_chunk.T)
    elif space_type == 'cosine':
        q_norm = normalize_vectors(query_vectors_slice)
        b_norm = normalize_vectors(base_vectors_chunk)
        if q_norm is None or b_norm is None or q_norm.shape[0] == 0 or b_norm.shape[0] == 0:
            return np.full((query_vectors_slice.shape[0], base_vectors_chunk.shape[0]), np.inf, dtype=np.float32)
        similarities = np.dot(q_norm, b_norm.T)
        return 1.0 - np.clip(similarities, -1.0, 1.0)
    else:
        raise ValueError(f"Unsupported space_type for GT calculation: '{space_type}'")

def brute_force_gt_mt(base_vectors: np.ndarray, query_vectors: np.ndarray, k_val: int, space_type: str, dname_for_log: str):
    """Compute exact k-nearest-neighbour ground truth by brute force, using threads.

    Queries are split into chunks of ``BF_CHUNK_SIZE`` rows. Each chunk is scored
    against the whole base set with ``_calculate_scores_for_chunk`` (smaller score =
    closer) on a pool of ``N_THREADS`` worker threads.

    Args:
        base_vectors: 2-D base matrix. Returned ids are row indices into it.
        query_vectors: 2-D query matrix.
        k_val: Number of neighbours per query; capped at the base size.
        space_type: Metric passed to ``_calculate_scores_for_chunk``.
        dname_for_log: Dataset name used only in log messages.

    Returns:
        int32 array of shape (num_queries, min(k_val, num_base)). Row q lists the base
        row indices closest to query q, nearest first. There are zero columns if k <= 0.

    Raises:
        Any exception raised while scoring a chunk (e.g. MemoryError, or ValueError for
        an unsupported ``space_type``) is propagated rather than silently dropped.
    """
    print(f"    Calculating brute-force GT for {dname_for_log} (k={k_val}, space='{space_type}')...")
    num_queries, num_base = query_vectors.shape[0], base_vectors.shape[0]
    k_val = min(k_val, num_base)
    if k_val <= 0:
        return np.array([], dtype=np.int32).reshape(num_queries, 0)

    gt_indices = np.full((num_queries, k_val), -1, dtype=np.int32)

    def process_query_chunk(start_idx, end_idx):
        """Fill gt_indices[start_idx:end_idx] with the sorted top-k base ids of those queries."""
        scores = _calculate_scores_for_chunk(query_vectors[start_idx:end_idx], base_vectors, space_type)
        # argpartition's kth must be a valid column index (< num_base). Partitioning at
        # k_val - 1 puts the k_val smallest scores in the first k_val columns. Unlike
        # kth=k_val, this also works when k_val == num_base.
        top_k = np.argpartition(scores, k_val - 1, axis=1)[:, :k_val]
        # Order each query's k candidates by score, nearest first.
        order = np.argsort(np.take_along_axis(scores, top_k, axis=1), axis=1)
        # Workers write disjoint row ranges of gt_indices.
        gt_indices[start_idx:end_idx] = np.take_along_axis(top_k, order, axis=1)

    query_chunks = [(i, min(i + BF_CHUNK_SIZE, num_queries)) for i in range(0, num_queries, BF_CHUNK_SIZE)]
    with ThreadPoolExecutor(max_workers=N_THREADS) as executor:
        # Executor.map re-raises a worker's exception only when its result is retrieved,
        # so the iterator must be consumed. Otherwise a failed chunk would silently
        # leave its rows at -1.
        list(executor.map(lambda bounds: process_query_chunk(*bounds), query_chunks))

    print(f"    Brute-force GT calculation for {dname_for_log} complete.")
    return gt_indices


def eval_index(idx: hnswlib.Index, query_vectors: np.ndarray, gt_indices: np.ndarray, ef_search: int, num_trials: int, recall_at_k: int):
    """Measure recall@k (in percent) and queries-per-second of ``idx`` at ``ef_search``.

    Runs ``num_trials`` batched ``knn_query`` passes over all queries. Only the
    ``knn_query`` calls are timed; hit counting happens outside the timed region.

    Args:
        idx: Index to evaluate. Its search ef is set to ``ef_search``.
        query_vectors: 2-D query matrix.
        gt_indices: Ground-truth ids, one row per query. If it is ``None`` or has a
            different row count, only QPS is measured and recall is reported as 0.0.
        ef_search: Search-time ef.
        num_trials: Number of timed passes.
        recall_at_k: k for recall@k, capped at the ground-truth width.

    Returns:
        ``(recall, qps)``. ``recall`` is the percentage of ground-truth ids found,
        aggregated over all trials. ``qps`` is total queries / total query time
        (0.0 if nothing was timed).
    """
    if query_vectors is None or query_vectors.shape[0] == 0:
        return 0.0, 0.0

    idx.set_ef(ef_search)
    num_queries = len(query_vectors)

    if gt_indices is None or gt_indices.shape[0] != num_queries:
        # No usable ground truth: recall cannot be computed, so measure QPS only.
        k_query = recall_at_k if recall_at_k > 0 else 1
        total_time = 0.0
        for _ in range(num_trials):
            # perf_counter is monotonic and high-resolution (time.time() is wall-clock).
            t_start = time.perf_counter()
            idx.knn_query(query_vectors, k=k_query)
            total_time += time.perf_counter() - t_start
        qps = (num_trials * num_queries / total_time) if total_time > 0 else 0.0
        return 0.0, qps

    k_to_evaluate = min(recall_at_k, gt_indices.shape[1])
    if k_to_evaluate <= 0:
        return 0.0, 0.0

    # Plain-int sets: cheaper intersections than sets of numpy scalars.
    gt_sets_for_eval = [set(row[:k_to_evaluate].tolist()) for row in gt_indices]

    total_time, total_correct_hits = 0.0, 0
    for _ in range(num_trials):
        t_start = time.perf_counter()
        found_indices_batch, _distances = idx.knn_query(query_vectors, k=k_to_evaluate)
        total_time += time.perf_counter() - t_start
        total_correct_hits += sum(
            len(gt_set.intersection(found_row))
            for gt_set, found_row in zip(gt_sets_for_eval, found_indices_batch.tolist())
        )

    qps = (num_trials * num_queries / total_time) if total_time > 0 else 0.0
    total_possible_hits = num_trials * sum(len(s) for s in gt_sets_for_eval)
    recall = (total_correct_hits / total_possible_hits) * 100.0 if total_possible_hits > 0 else 0.0
    return recall, qps


# ===================================================================
# 🚀 5. MAIN EXPERIMENT LOOP
# ===================================================================

# Prepare CSV file
csv_path = f"{RESULTS_DIR}/experiment_results_{TIMESTAMP}.csv"
csv_header = [
    "timestamp", "dataset", "delta", "variant", "ef_search",
    "M", "Efc", "rein", "space", "recall_k",
    "qps", "recall",
    "time_roe", "metric_roe",
    "time_rdn", "metric_rdn",
    "time_rea", "metric_rea",
    "time_idx_build_total",
    "time_idx_update_total",
    "base_size_used", "query_size_used"
]
with open(csv_path, "w", newline="") as fp:
    csv.writer(fp).writerow(csv_header)

print(f"\n🚀 Starting experiments. Results will be saved to: {csv_path}\n")

import gc


def _append_csv_row(row):
    """Append one row to ``csv_path``.

    The file is reopened for every row so results gathered so far survive a crash.
    """
    with open(csv_path, "a", newline="") as fp:
        csv.writer(fp).writerow(row)


def _run_variant(variant, display_name, optimized, dname_key, current_cfg, delta_val,
                 del_ids_arr, base_data, labels, query_data, gt):
    """Build, update and evaluate one index variant, writing one CSV row per ef_search.

    Both variants are built from scratch on the same base data and delete the same
    ``del_ids_arr``, so their rows are directly comparable.
    """
    print(f"\n  Processing {display_name}...")
    t_build_start = time.perf_counter()
    idx_built = build_index(base_data, labels, current_cfg['M'], current_cfg['Efc'], current_cfg['space'])
    time_idx_build = time.perf_counter() - t_build_start
    print(f"    {display_name} index built in {time_idx_build:.4f}s")

    t_update_start = time.perf_counter()
    idx, opt_metrics = update_index(idx_built, base_data, del_ids_arr, current_cfg['rein'],
                                    optimized=optimized, dname=dname_key, delta_val=delta_val)
    time_idx_update = time.perf_counter() - t_update_start
    print(f"    {display_name} index updated in {time_idx_update:.4f}s")

    # Columns in csv_header order ("M", "Efc", "rein", "space", "recall_k"). current_cfg
    # also holds efs_range, efs_list and the size limits, so splatting its values here
    # would shift every following column.
    param_cols = [current_cfg['M'], current_cfg['Efc'], current_cfg['rein'], current_cfg['space'], RECALL_K]
    if optimized:
        metric_cols = [
            f"{opt_metrics.get('time_roe', 0.0):.4f}", opt_metrics.get('metric_roe', 0),
            f"{opt_metrics.get('time_rdn', 0.0):.4f}", opt_metrics.get('metric_rdn', 0),
            f"{opt_metrics.get('time_rea', 0.0):.4f}", opt_metrics.get('metric_rea', 0),
        ]
    else:
        metric_cols = [0, 0, 0, 0, 0, 0]

    print(f"  Evaluating {display_name} (δ={delta_val})...")
    for ef_search_val in current_cfg["efs_list"]:
        rec, qps = eval_index(idx, query_data, gt, ef_search_val, TRIALS_LATENCY, RECALL_K)
        print(f"    {variant} ef={ef_search_val:3d}  R={rec:5.2f}%  QPS={qps:9.1f}")
        _append_csv_row([
            TIMESTAMP, dname_key, delta_val, variant, ef_search_val, *param_cols,
            f"{qps:.3f}", f"{rec:.2f}", *metric_cols,
            f"{time_idx_build:.4f}", f"{time_idx_update:.4f}",
            len(base_data), len(query_data),
        ])


for dname_key, current_cfg in CFG.items():
    effective_recalculate_gt = RECALCULATE_GT or (current_cfg["base_size_limit"] is not None)

    print(f"\n{'='*20} Processing Dataset: {dname_key.upper()} ({current_cfg['space']}) {'='*20}")
    print(f"  Config: M={current_cfg['M']}, Efc={current_cfg['Efc']}, Rein={current_cfg['rein']}")
    print(f"  Size Limits: Base={current_cfg['base_size_limit']}, Query={current_cfg['query_size_limit']}")
    print(f"  Settings: Recalculate_GT={effective_recalculate_gt}, Latency_Trials={TRIALS_LATENCY}")

    try:
        base_data_full, Q_data_full, gt_provided_hdf5 = load_dataset(dname_key, DATA_ROOT_HDF5)
    except Exception as e:
        print(f"  ERROR: Could not load dataset '{dname_key}': {e}. Skipping.")
        continue

    if base_data_full is None or base_data_full.shape[0] == 0:
        print(f"  WARNING: No base data available for '{dname_key}'. Skipping.")
        continue

    # Slice datasets if limits are set
    base_limit = current_cfg["base_size_limit"]
    base_data_sliced = base_data_full[:base_limit] if base_limit is not None else base_data_full
    actual_base_size = base_data_sliced.shape[0]
    if actual_base_size == 0:
        print(f"  WARNING: Base size limit leaves no data for '{dname_key}'. Skipping.")
        continue
    labels_for_index = np.arange(actual_base_size, dtype=np.int32)

    query_limit = current_cfg["query_size_limit"]
    Q_data_sliced = Q_data_full[:query_limit] if query_limit is not None and Q_data_full is not None else Q_data_full
    actual_query_size = Q_data_sliced.shape[0] if Q_data_sliced is not None else 0

    if actual_query_size == 0:
        print(f"  WARNING: No query data available for '{dname_key}'. Skipping dataset.")
        continue

    # Prepare Ground Truth
    gt_to_use = None
    if effective_recalculate_gt:
        print(f"  Recalculating Ground Truth for '{dname_key}'...")
        gt_to_use = brute_force_gt_mt(base_data_sliced, Q_data_sliced, RECALL_K, current_cfg["space"], dname_key)
    elif gt_provided_hdf5 is not None:
        print(f"  Using pre-computed Ground Truth from HDF5 file for '{dname_key}'.")
        gt_to_use = gt_provided_hdf5[:actual_query_size]
    else:
        print(f"  WARNING: No Ground Truth found and recalculation is off. Recalculating anyway for '{dname_key}'...")
        gt_to_use = brute_force_gt_mt(base_data_sliced, Q_data_sliced, RECALL_K, current_cfg["space"], dname_key)

    if gt_to_use is None or gt_to_use.shape[0] != actual_query_size:
        print(f"  FATAL ERROR: Could not prepare Ground Truth for '{dname_key}'. Skipping.")
        continue

    print(f"  Final sizes: Base={actual_base_size}, Query={actual_query_size}, GT_K={gt_to_use.shape[1] if gt_to_use.ndim == 2 else 'N/A'}")

    for delta_val in DELETION_RATIOS:
        print(f"\n-- Running for Dataset: {dname_key.upper()}, Deletion Ratio (δ) = {delta_val} --")
        num_to_delete = int(actual_base_size * delta_val)
        if num_to_delete > actual_base_size:
            print(f"  Skipping delta={delta_val} as base size {actual_base_size} is too small to delete {num_to_delete} items.")
            continue
        del_ids_arr = np.random.choice(labels_for_index, num_to_delete, replace=False) if num_to_delete > 0 else np.array([], dtype=np.int32)

        # Same deletion set for both variants, original HNSW first, then PRO-HNSW.
        for variant, display_name, optimized in (
            ("original", "Original HNSW", False),
            ("pro-hnsw", "PRO-HNSW", True),
        ):
            _run_variant(variant, display_name, optimized, dname_key, current_cfg, delta_val,
                         del_ids_arr, base_data_sliced, labels_for_index, Q_data_sliced, gt_to_use)

        gc.collect()
        time.sleep(1) # Give a moment for memory to be released

print(f"\n✅ All experiments complete. Final results have been saved to: {csv_path}")