# In [None]:
"""
# PRO-HNSW: Consecutive Updates Performance Evaluation

This notebook runs the experiments for the consecutive update scenario.
It simulates a continuous stream of small-batch delete-and-insert operations
and evaluates the performance of PRO-HNSW against the standard HNSW.

The workflow is as follows:
1.  **Configuration**: Set up all parameters for the experiment.
2.  **Core Functions**: Define helper functions for the experiment.
3.  **Main Loop**: Iterate through datasets, deletion ratios, and variants to run the
    consecutive update simulation and save results to a CSV file.
"""

# ===================================================================
# 📝 1. CONFIGURATION SECTION
# Every value a user is expected to tune lives in this section.
# ===================================================================

# ────────── Global Experiment Settings ──────────
# Upper bounds on the number of base / query vectors that are used.
# None means "use everything"; a small number (e.g., 10000) gives a quick test run.
# A non-None BASE_SIZE_LIMIT also forces the ground truth to be recomputed.
BASE_SIZE_LIMIT = None
QUERY_SIZE_LIMIT = None
# When True, ground truth is always recomputed by brute force.
RECALCULATE_GT = False
# How many times the whole query batch is timed per ef_search value.
TRIALS_LATENCY = 3
# Fraction of the base set that has been deleted (and re-inserted) once all updates are done.
DELETION_RATIOS = [0.2, 0.8]
# Number of small batches the deletions are split into (capped by the number of deleted items).
N_CONSECUTIVE_UPDATES = 1000

# ────────── HNSW Parameter Tiers ──────────
# M / Efc  : graph degree and ef_construction used when building the index.
# rein     : ef value set on the index right before each re-insertion batch.
# efs_range: (start, stop, step) arguments of range() producing the ef_search sweep.
HNSW_PARAMS_TIERS = {
    "small": dict(M=8, Efc=50, rein=25, efs_range=(25, 76, 10)),
    "medium": dict(M=12, Efc=75, rein=50, efs_range=(50, 101, 10)),
    "large": dict(M=16, Efc=100, rein=75, efs_range=(75, 126, 10)),
}
# Tier used for any dataset that has no entry in the profile below.
DEFAULT_HNSW_TIER = "medium"

# ────────── Dataset Profiles ──────────
# Which parameter tier each dataset runs with.
DATASET_HNSW_TIER_PROFILE = {
    # large tier
    "deep-image-96-angular": "large",
    "gist-960-euclidean": "large",
    # medium tier
    "sift-128-euclidean": "medium",
    "nytimes-256-angular": "medium",
    # small tier
    "fashion-mnist-784-euclidean": "small",
    "mnist-784-euclidean": "small",
    "coco-i2i-512-angular": "small",
    "glove-25-angular": "small",
}

# ────────── File Paths and Constants ──────────
DATA_ROOT_HDF5 = "../data"  # directory handed to the dataset loader
RESULTS_DIR = "../results/consecutive"  # where the result CSV is written
N_THREADS = 18  # threads for the index build and the brute-force GT pool
BF_CHUNK_SIZE = 600  # queries handled per brute-force GT task
RECALL_K = 10  # k of recall@k

# ===================================================================
# 🛠️ 2. SETUP & LIBRARY IMPORTS
# ===================================================================
import os
import time
import csv
import heapq
import numpy as np
import hnswlib
import gc
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
# This assumes a 'dataset' package next to this script containing 'hdf5_dataset_loader.py',
# which provides load_dataset().
from dataset.hdf5_dataset_loader import load_dataset

# Make sure the output directory exists before anything tries to write into it.
os.makedirs(RESULTS_DIR, exist_ok=True)
# Run identifier (local time, second resolution) used in the CSV file name and rows.
TIMESTAMP = f"{datetime.now():%Y%m%d_%H%M%S}"

# ===================================================================
# ⚙️ 3. DYNAMIC CONFIGURATION GENERATION
# ===================================================================

# Distance space of every dataset; the insertion order here is the order
# in which the datasets are processed.
CFG_BASE_CONFIG = {
    dataset_name: {"space": metric}
    for dataset_name, metric in (
        ("sift-128-euclidean", "l2"),
        ("gist-960-euclidean", "l2"),
        ("deep-image-96-angular", "cosine"),
        ("glove-25-angular", "cosine"),
        ("mnist-784-euclidean", "l2"),
        ("fashion-mnist-784-euclidean", "l2"),
        ("nytimes-256-angular", "cosine"),
        ("coco-i2i-512-angular", "cosine"),
    )
}

# Merge, per dataset: its space, the parameters of its HNSW tier, the expanded
# ef_search sweep and the global size limits.
CFG = {}
for dname_key, base_settings in CFG_BASE_CONFIG.items():
    # Unknown dataset -> default tier name; unknown tier name -> default tier parameters.
    tier_name = DATASET_HNSW_TIER_PROFILE.get(dname_key, DEFAULT_HNSW_TIER)
    tier_params = HNSW_PARAMS_TIERS.get(tier_name, HNSW_PARAMS_TIERS[DEFAULT_HNSW_TIER])
    CFG[dname_key] = dict(base_settings)
    CFG[dname_key].update(tier_params)
    CFG[dname_key]["efs_list"] = list(range(*tier_params["efs_range"]))
    CFG[dname_key]["base_size_limit"] = BASE_SIZE_LIMIT
    CFG[dname_key]["query_size_limit"] = QUERY_SIZE_LIMIT

print("✅ Configuration generated for the following datasets:")
for dname in CFG:
    print(f"- {dname}")

# ===================================================================
# 🔬 4. CORE FUNCTIONS
# ===================================================================

def normalize_vectors(vectors: np.ndarray) -> np.ndarray:
    """Scale a vector, or each row of a matrix, to unit L2 length.

    ``None`` and empty arrays are handed back untouched. Norms below 1e-9
    are replaced by 1e-9 before dividing, so all-zero vectors stay zero
    instead of producing NaN.
    """
    floor = 1e-9
    if vectors is None or vectors.size == 0:
        return vectors

    if vectors.ndim == 1:
        # Single vector: one scalar norm.
        length = np.linalg.norm(vectors)
        divisor = length if length > floor else floor
        return vectors / divisor

    # Matrix: one norm per row, kept as a column so the division broadcasts.
    row_lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    safe_lengths = np.where(row_lengths < floor, floor, row_lengths)
    return vectors / safe_lengths

def build_index(base_vectors: np.ndarray, labels_arr: np.ndarray, M_val: int, Efc_val: int, space_val: str, dname_for_log=""):
    """Build an hnswlib index over ``base_vectors`` and return it.

    Args:
        base_vectors: 2-D array of vectors; its second dimension is the index dimension.
        labels_arr: Labels passed to ``add_items`` together with the vectors.
        M_val: ``M`` passed to ``init_index``.
        Efc_val: ``ef_construction`` passed to ``init_index``.
        space_val: Space name passed to ``hnswlib.Index``.
        dname_for_log: Optional tag included in the log lines.

    Returns:
        The populated index, sized for exactly ``len(base_vectors)`` elements
        and initialised with ``allow_replace_deleted=True``.
    """
    dim = base_vectors.shape[1]
    num_elements = len(base_vectors)
    if dname_for_log:
        log_prefix = f"  Build_Index ({dname_for_log}):"
    else:
        log_prefix = "  Build_Index:"
    print(f"{log_prefix} Starting build for {num_elements} elements. Space='{space_val}', M={M_val}, efC={Efc_val}.")

    index = hnswlib.Index(space_val, dim)
    index.init_index(
        max_elements=num_elements,
        ef_construction=Efc_val,
        M=M_val,
        allow_replace_deleted=True,
    )

    # The vectors are handed over as a C-contiguous float32 array.
    packed_vectors = np.ascontiguousarray(base_vectors, dtype=np.float32)

    # Setting the thread count is best effort: a failure is logged and the build continues.
    if hasattr(index, 'set_num_threads'):
        try:
            index.set_num_threads(N_THREADS)
        except Exception as e:
            print(f"{log_prefix} Warning: Failed to set num_threads for HNSW build: {e}")

    index.add_items(packed_vectors, labels_arr)
    print(f"{log_prefix} HNSW index build fully complete.")
    return index

# (brute_force_gt_mt and eval_index functions are assumed to be the same as in the previous script)
def brute_force_gt_mt(base_vectors: np.ndarray, query_vectors: np.ndarray, k_val: int, space_type: str, dname_for_log: str):
    """Compute exact k-nearest-neighbour ground truth with a thread pool.

    Queries are processed in chunks of ``BF_CHUNK_SIZE`` rows on up to
    ``N_THREADS`` worker threads; each chunk is scored against the whole
    base set with ``_calculate_scores_for_chunk``.

    Args:
        base_vectors: 2-D array of base vectors.
        query_vectors: 2-D array of query vectors.
        k_val: Number of neighbours wanted; clamped to the number of base vectors.
        space_type: Space name forwarded to ``_calculate_scores_for_chunk``.
        dname_for_log: Tag used in the progress messages.

    Returns:
        ``int32`` array of shape ``(num_queries, k)`` holding, per query, the
        row positions of the nearest base vectors ordered by increasing score.

    Raises:
        Any exception raised while scoring a chunk (for example ``MemoryError``
        or ``ValueError`` for an unsupported space) is re-raised here instead
        of leaving rows of ``-1`` in the result.
    """
    print(f"    Calculating brute-force GT for {dname_for_log} (k={k_val}, space='{space_type}')...")
    num_queries, num_base = query_vectors.shape[0], base_vectors.shape[0]
    k_val = min(k_val, num_base)
    if k_val == 0:
        return np.array([], dtype=np.int32).reshape(num_queries, 0)

    gt_indices = np.full((num_queries, k_val), -1, dtype=np.int32)
    # argpartition requires kth < num_base. Clamping keeps the original pivot
    # whenever k_val < num_base and makes the k_val == num_base case valid.
    kth = min(k_val, num_base - 1)

    def process_query_chunk(start_idx, end_idx):
        scores = _calculate_scores_for_chunk(query_vectors[start_idx:end_idx], base_vectors, space_type)
        # Unordered k best candidates per row, then order just those k by score.
        candidates = np.argpartition(scores, kth, axis=1)[:, :k_val]
        row_ids = np.arange(candidates.shape[0])[:, np.newaxis]
        order = np.argsort(scores[row_ids, candidates], axis=1)
        # Each task writes a disjoint row range of the shared output array.
        gt_indices[start_idx:end_idx] = candidates[row_ids, order]

    with ThreadPoolExecutor(max_workers=N_THREADS) as executor:
        futures = [
            executor.submit(process_query_chunk, start, min(start + BF_CHUNK_SIZE, num_queries))
            for start in range(0, num_queries, BF_CHUNK_SIZE)
        ]
        # result() re-raises whatever the worker raised; without it a failed
        # chunk would go unnoticed and silently corrupt the ground truth.
        for future in futures:
            future.result()

    print(f"    Brute-force GT calculation for {dname_for_log} complete.")
    return gt_indices
def _calculate_scores_for_chunk(query_vectors_slice: np.ndarray, base_vectors_chunk: np.ndarray, space_type: str) -> np.ndarray:
    if space_type == 'l2':
        diff = query_vectors_slice[:, np.newaxis, :] - base_vectors_chunk[np.newaxis, :, :]
        return np.linalg.norm(diff, axis=2)
    elif space_type == 'cosine':
        q_norm = normalize_vectors(query_vectors_slice)
        b_norm = normalize_vectors(base_vectors_chunk)
        if q_norm.size == 0 or b_norm.size == 0: return np.full((q_norm.shape[0], b_norm.shape[0]), np.inf, dtype=np.float32)
        similarities = np.dot(q_norm, b_norm.T)
        return 1.0 - np.clip(similarities, -1.0, 1.0)
    raise ValueError(f"Unsupported space_type for GT calculation: '{space_type}'")
def eval_index(idx: hnswlib.Index, query_vectors: np.ndarray, gt_indices: np.ndarray, ef_search: int, num_trials: int, recall_at_k: int):
    """Measure recall and query throughput of ``idx`` at one ef_search value.

    The whole query batch is searched ``num_trials`` times; only the time spent
    inside ``knn_query`` is counted towards throughput.

    Args:
        idx: Index to query; its ef is set to ``ef_search`` as a side effect.
        query_vectors: Array of query vectors.
        gt_indices: Ground-truth neighbour ids, one row per query. If it is
            ``None`` or its row count differs from the number of queries, only
            throughput is measured and recall is reported as 0.0.
        ef_search: ef value used for the searches.
        num_trials: Number of timed repetitions of the full batch.
        recall_at_k: k of recall@k; limited to the number of ground-truth columns.

    Returns:
        ``(recall, qps)`` where recall is a percentage (0-100) over all trials
        and qps is queries per second. ``(0.0, 0.0)`` when there are no queries
        or no ground-truth columns.
    """
    if query_vectors.size == 0:
        return 0.0, 0.0
    num_queries = len(query_vectors)
    idx.set_ef(ef_search)

    # No usable ground truth: time the searches only.
    if gt_indices is None or gt_indices.shape[0] != num_queries:
        k_for_timing = max(1, recall_at_k)
        total_time_no_gt = 0.0
        for _ in range(num_trials):
            # perf_counter is monotonic and high resolution, unlike the wall clock.
            start_time = time.perf_counter()
            idx.knn_query(query_vectors, k=k_for_timing)
            total_time_no_gt += time.perf_counter() - start_time
        return 0.0, (num_trials * num_queries / total_time_no_gt) if total_time_no_gt > 0 else 0.0

    k_to_evaluate = min(recall_at_k, gt_indices.shape[1])
    if k_to_evaluate == 0:
        return 0.0, 0.0
    gt_sets_for_eval = [set(row[:k_to_evaluate]) for row in gt_indices]

    total_time_eval, total_correct_hits = 0.0, 0
    for _ in range(num_trials):
        start_time = time.perf_counter()
        found_indices_batch, _ = idx.knn_query(query_vectors, k=k_to_evaluate)
        total_time_eval += time.perf_counter() - start_time
        # Hit counting happens outside the timed region.
        for found_row, gt_set in zip(found_indices_batch, gt_sets_for_eval):
            total_correct_hits += len(set(found_row) & gt_set)

    qps = (num_trials * num_queries / total_time_eval) if total_time_eval > 0 else 0.0
    total_possible_hits = num_trials * sum(len(s) for s in gt_sets_for_eval)
    recall = (total_correct_hits / total_possible_hits) * 100.0 if total_possible_hits > 0 else 0.0
    return recall, qps


# ===================================================================
# 🚀 5. MAIN EXPERIMENT LOOP
# ===================================================================

# Results file for this run; the run timestamp keeps earlier results from being overwritten.
csv_path = f"{RESULTS_DIR}/experiment_results_consecutive_{TIMESTAMP}.csv"
# Column names, grouped by meaning. The order must match the rows written
# by the evaluation loop.
csv_header = [
    # run / sweep identification
    *("timestamp", "dataset", "total_delta", "n_consecutive_updates", "variant", "ef_search"),
    # index and evaluation parameters
    *("M", "Efc", "rein", "space", "recall_k"),
    # search quality
    *("qps", "recall"),
    # PRO-HNSW maintenance steps: elapsed time and returned count
    *("time_roe_total", "metric_roe_total"),
    *("time_rdn_final", "metric_rdn_final"),
    *("time_rea_final", "metric_rea_final"),
    # remaining timings
    "time_reinsert_total",
    "time_idx_build_total",
    "time_idx_update_total",
    # data sizes actually used
    *("base_size_used", "query_size_used"),
]
# Start the file with just the header line; data rows are appended later.
with open(csv_path, mode="w", newline="") as fp:
    csv.writer(fp).writerows([csv_header])

print(f"\n🚀 Starting consecutive update experiments. Results will be saved to: {csv_path}\n")

# For every dataset: load data and ground truth, then for every total deletion
# ratio and every variant build a fresh index, replay the delete/re-insert
# stream, evaluate over the ef_search sweep and append one CSV row per ef value.
# All durations use time.perf_counter(), a monotonic high-resolution clock, so
# the many short per-batch intervals that are summed up are measured reliably.
for dname_key, current_cfg in CFG.items():
    print(f"\n{'='*20} Processing Dataset: {dname_key.upper()} ({current_cfg['space']}) {'='*20}")

    # --- Data and GT Preparation ---
    # A dataset that cannot be loaded is reported and skipped; the others still run.
    try:
        base_data_full, Q_data_full, gt_provided_hdf5 = load_dataset(dname_key, DATA_ROOT_HDF5)
    except Exception as e:
        print(f"  ERROR: Could not load dataset '{dname_key}': {e}. Skipping.")
        continue

    base_data_sliced = base_data_full[:current_cfg["base_size_limit"]] if current_cfg["base_size_limit"] is not None else base_data_full
    actual_base_size = base_data_sliced.shape[0]
    if actual_base_size == 0:
        continue
    # Labels are the row positions of the base vectors, which is also how
    # brute-force ground truth identifies neighbours.
    labels_for_index = np.arange(actual_base_size, dtype=np.int32)

    Q_data_sliced = Q_data_full[:current_cfg["query_size_limit"]] if current_cfg["query_size_limit"] is not None and Q_data_full is not None else Q_data_full
    actual_query_size = Q_data_sliced.shape[0] if Q_data_sliced is not None else 0

    if actual_query_size == 0:
        print(f"  WARNING: No query data for '{dname_key}'. Skipping evaluation part.")
        continue

    # Ground truth shipped with the dataset refers to the full base set, so it
    # is recomputed whenever the base set is truncated (or on explicit request).
    effective_recalculate_gt = RECALCULATE_GT or (current_cfg["base_size_limit"] is not None)
    gt_to_use = None
    if effective_recalculate_gt:
        gt_to_use = brute_force_gt_mt(base_data_sliced, Q_data_sliced, RECALL_K, current_cfg["space"], dname_key)
    elif gt_provided_hdf5 is not None:
        gt_to_use = gt_provided_hdf5[:actual_query_size, :RECALL_K]
    else:
        gt_to_use = brute_force_gt_mt(base_data_sliced, Q_data_sliced, RECALL_K, current_cfg["space"], dname_key)

    if gt_to_use is None or gt_to_use.shape[0] != actual_query_size:
        print(f"  FATAL ERROR: Could not prepare Ground Truth for '{dname_key}'. Skipping.")
        continue

    print(f"  Data Ready: Base={actual_base_size}, Query={actual_query_size}, GT_K={gt_to_use.shape[1]}")

    for total_delta_val in DELETION_RATIOS:
        print(f"\n-- Running for Dataset: {dname_key.upper()}, Total Deletion Ratio (δ) = {total_delta_val} --")

        # --- Prepare deletion chunks ---
        # The chunks are drawn once per ratio, so both variants replay exactly
        # the same update stream.
        num_to_delete_total = int(actual_base_size * total_delta_val)
        if num_to_delete_total <= 0:
            print("  Total items to delete is 0. Will evaluate initial index state only.")
            del_ids_chunks = [np.array([], dtype=np.int32)]
        else:
            del_ids_overall = np.random.choice(labels_for_index, num_to_delete_total, replace=False)
            # Never ask for more chunks than there are items to delete.
            effective_n_updates = min(N_CONSECUTIVE_UPDATES, num_to_delete_total)
            del_ids_chunks = np.array_split(del_ids_overall, effective_n_updates)

        print(f"  Simulating {len(del_ids_chunks)} consecutive updates to reach δ={total_delta_val}.")

        # --- Run for both variants (Original and PRO-HNSW) ---
        for variant in ["original", "pro-hnsw"]:
            print(f"\n  Processing Variant: {variant.upper()}...")

            # 1. Build Initial Index
            t_build_start = time.perf_counter()
            idx = build_index(base_data_sliced, labels_for_index, current_cfg['M'], current_cfg['Efc'], current_cfg['space'], f"{dname_key}-{variant}")
            time_idx_build = time.perf_counter() - t_build_start
            print(f"    {variant.upper()} index built in {time_idx_build:.4f}s")

            # 2. Run Consecutive Updates
            t_update_overall_start = time.perf_counter()
            cumulative_metrics = {"time_roe": 0.0, "metric_roe": 0, "time_reinsert": 0.0}

            for chunk_del_ids in del_ids_chunks:
                if len(chunk_del_ids) == 0:
                    continue

                # Mark for deletion
                for d_id in chunk_del_ids:
                    idx.mark_deleted(int(d_id))

                # If optimized, run ROE for the current chunk.
                # NOTE(review): remove_obsolete_edges (and the two repair calls
                # below) are presumably provided by the PRO-HNSW build of
                # hnswlib — confirm the installed module exposes them.
                if variant == "pro-hnsw":
                    t_roe_iter_start = time.perf_counter()
                    removed_count_iter = idx.remove_obsolete_edges(chunk_del_ids.tolist())
                    cumulative_metrics["time_roe"] += (time.perf_counter() - t_roe_iter_start)
                    cumulative_metrics["metric_roe"] += int(removed_count_iter)

                # Re-insert the items: the same vectors under the same labels.
                idx.set_ef(current_cfg['rein'])
                reinsert_data = np.ascontiguousarray(base_data_sliced[chunk_del_ids], dtype=np.float32)
                # Only the add_items call is timed, not the data preparation above.
                t_reinsert_iter_start = time.perf_counter()
                idx.add_items(reinsert_data, chunk_del_ids.astype(np.int32), replace_deleted=True)
                cumulative_metrics["time_reinsert"] += (time.perf_counter() - t_reinsert_iter_start)

            # 3. Run Final Repairs for PRO-HNSW
            final_repair_metrics = {"time_rdn": 0.0, "metric_rdn": 0, "time_rea": 0.0, "metric_rea": 0}
            if variant == "pro-hnsw" and num_to_delete_total > 0:
                print("    Running final repair functions for PRO-HNSW...")
                t_rdn_start = time.perf_counter()
                final_repair_metrics["metric_rdn"] = int(idx.repair_disconnected_nodes())
                final_repair_metrics["time_rdn"] = time.perf_counter() - t_rdn_start

                t_rea_start = time.perf_counter()
                final_repair_metrics["metric_rea"] = int(idx.resolve_edge_asymmetry())
                final_repair_metrics["time_rea"] = time.perf_counter() - t_rea_start
                print("    Final repairs complete.")

            # Covers the whole update phase: marking, ROE, re-insertion and final repairs.
            time_idx_update_total = time.perf_counter() - t_update_overall_start
            print(f"    {variant.upper()} all updates finished in {time_idx_update_total:.4f}s")

            # 4. Evaluate and Save Results
            print(f"  Evaluating {variant.upper()} (total δ={total_delta_val})...")
            for ef_search_val in current_cfg["efs_list"]:
                rec, qps = eval_index(idx, Q_data_sliced, gt_to_use, ef_search_val, TRIALS_LATENCY, RECALL_K)
                print(f"    {variant} ef={ef_search_val:3d}  R={rec:5.2f}%  QPS={qps:9.1f}")
                # Field order mirrors the CSV header written before the loop.
                row_data = [TIMESTAMP, dname_key, total_delta_val, len(del_ids_chunks), variant, ef_search_val,
                            current_cfg["M"], current_cfg["Efc"], current_cfg["rein"], current_cfg["space"], RECALL_K,
                            f"{qps:.3f}", f"{rec:.2f}",
                            f"{cumulative_metrics['time_roe']:.4f}", cumulative_metrics['metric_roe'],
                            f"{final_repair_metrics['time_rdn']:.4f}", final_repair_metrics['metric_rdn'],
                            f"{final_repair_metrics['time_rea']:.4f}", final_repair_metrics['metric_rea'],
                            f"{cumulative_metrics['time_reinsert']:.4f}",
                            f"{time_idx_build:.4f}", f"{time_idx_update_total:.4f}",
                            actual_base_size, actual_query_size]
                # Append row by row so finished measurements survive a later crash.
                with open(csv_path, "a", newline="") as fp:
                    csv.writer(fp).writerow(row_data)

            # Drop the index before the next variant builds its own.
            del idx
            gc.collect()
            time.sleep(1)

print(f"\n✅ All experiments complete. Final results have been saved to: {csv_path}")