# In [None]:
"""
# HNSW Performance Degradation Experiment under Consecutive Updates

This notebook simulates and evaluates the performance degradation of HNSW over a series of consecutive updates. It compares two variants:
1.  **Original HNSW**: Standard deletion and re-insertion.
2.  **PRO-HNSW (Iterative Repair)**: Applies the ROE, RDN, and REA repair modules within each update iteration.

The primary metric for this experiment is Recall@K over the number of update iterations.
"""

# ===================================================================
# 📝 1. CONFIGURATION SECTION
# All user-configurable parameters are defined here.
# ===================================================================

# ────────── Target Dataset and Paths ──────────
# Base name of the HDF5 file to load (`<DATA_ROOT_HDF5>/<name>.hdf5`). A name containing
# "angular" or "cosine" selects the cosine space in the loader; anything else uses L2.
TARGET_DATASET_NAME = "fashion-mnist-784-euclidean"
# NOTE: Paths are relative to the working directory: datasets are read from '../data'
# and result CSVs are written to '../results/resilience' (created if missing).
DATA_ROOT_HDF5 = "../data"
OUTPUT_DIR = "../results/resilience"

# ────────── HNSW and Experiment Parameters ──────────
M_PARAM = 8  # Passed as `M` to init_index() when the index is built.
EF_CONSTRUCTION = 50  # Passed as `ef_construction` to init_index().
REINSERT_EF = 25  # Value handed to set_ef() right before deleted items are re-inserted.
RECALL_K = 10  # K of Recall@K; ground-truth rows are truncated to this many neighbors.
EF_SEARCH_FOR_EVAL = 30  # A fixed efSearch value used for recall evaluation at each interval.

# ────────── Consecutive Update Simulation Parameters ──────────
PER_ITERATION_DELETE_RATIO = 0.001  # Ratio of the total dataset size to delete and re-insert in each iteration.
TOTAL_UPDATE_ITERATIONS = 1000      # Total number of small-batch update iterations to perform.

# ────────── Evaluation Schedule ──────────
# Defines at which iteration points to stop and evaluate recall.
# The final iteration (TOTAL_UPDATE_ITERATIONS) is always added to the schedule.
EVALUATION_INTERVAL = 50  # Evaluate every 50 iterations.
INITIAL_EVAL_POINTS = [0, 1, 5, 10, 20] # Also evaluate at these specific early iterations.

# ────────── System Parameters ──────────
NUM_TRIALS_FOR_EVAL = 3  # How many times the full query set is run per recall evaluation.
NUM_THREADS_HNSW = 18  # Passed to set_num_threads() on the index, when that method exists.

# ===================================================================
# 🛠️ 2. SETUP & LIBRARY IMPORTS
# ===================================================================
import os
import time
import csv
import numpy as np
import hnswlib
import h5py
import pandas as pd
import gc
from datetime import datetime

# --- Build the evaluation schedule ---
# Union of the hand-picked early checkpoints and the periodic ones, in ascending order.
EVALUATION_ITERATIONS = sorted(
    set(INITIAL_EVAL_POINTS).union(range(0, TOTAL_UPDATE_ITERATIONS + 1, EVALUATION_INTERVAL))
)
# The state after the very last update must always be measured.
if TOTAL_UPDATE_ITERATIONS not in EVALUATION_ITERATIONS:
    EVALUATION_ITERATIONS.append(TOTAL_UPDATE_ITERATIONS)
print(f"Evaluation will be performed at iterations: {EVALUATION_ITERATIONS}")


# --- Ensure the output directory exists and stamp this run ---
os.makedirs(OUTPUT_DIR, exist_ok=True)
TIMESTAMP = f"{datetime.now():%Y%m%d_%H%M%S}"

# ===================================================================
# 🔬 3. CORE HELPER FUNCTIONS
# ===================================================================

def normalize_l2(vectors: np.ndarray) -> np.ndarray:
    """Scale vectors to unit L2 (Euclidean) length.

    Args:
        vectors: Either a single vector of shape ``(dim,)`` or a batch of
            shape ``(n, dim)``. For arrays with two or more dimensions the
            norm is taken along axis 1. ``None`` and empty arrays are
            returned unchanged.

    Returns:
        An array of the same shape as ``vectors`` in which every vector has
        unit norm. All-zero vectors stay all-zero instead of becoming NaN.
    """
    if vectors is None or vectors.size == 0:
        return vectors
    # A lone 1-D vector has no axis 1; normalize it along its only axis.
    axis = 1 if vectors.ndim > 1 else -1
    norms = np.linalg.norm(vectors, axis=axis, keepdims=True)
    # Replace zero norms with a tiny positive divisor to avoid 0/0.
    return vectors / np.where(norms == 0, 1e-9, norms)

def load_hdf5_data(data_root, dataset_name, recall_k):
    """Read a benchmark HDF5 file and return its vectors and ground truth.

    The file ``<data_root>/<dataset_name>.hdf5`` must contain the datasets
    ``train``, ``test`` and ``neighbors``.

    Args:
        data_root: Directory that holds the HDF5 file.
        dataset_name: File name without the ``.hdf5`` suffix. When it
            contains "angular" or "cosine" (case-insensitive), base and
            query vectors are L2-normalized and the space is ``'cosine'``;
            otherwise the space is ``'l2'``.
        recall_k: 2-D ground truth with more than this many columns is
            truncated to its first ``recall_k`` columns.

    Returns:
        ``(base_vectors, query_vectors, gt_indices, space)``. If the file is
        missing or any error occurs while reading it, a message is printed
        and ``(None, None, None, None)`` is returned instead of raising.
    """
    failure = (None, None, None, None)
    hdf5_path = os.path.join(data_root, f"{dataset_name}.hdf5")
    if not os.path.exists(hdf5_path):
        print(f"ERROR: Dataset file not found at {hdf5_path}")
        return failure

    try:
        # np.array() copies the data, so the arrays outlive the file handle.
        with h5py.File(hdf5_path, 'r') as handle:
            base_vectors = np.array(handle['train'], dtype=np.float32)
            query_vectors = np.array(handle['test'], dtype=np.float32)
            gt_indices = np.array(handle['neighbors'])

        lowered_name = dataset_name.lower()
        if 'angular' in lowered_name or 'cosine' in lowered_name:
            space = 'cosine'
            base_vectors = normalize_l2(base_vectors)
            query_vectors = normalize_l2(query_vectors)
        else:
            space = 'l2'

        # Keep only as many ground-truth neighbors as Recall@K needs.
        has_extra_columns = gt_indices.ndim == 2 and gt_indices.shape[1] > recall_k
        if has_extra_columns:
            gt_indices = gt_indices[:, :recall_k]

        print(f"Dataset '{dataset_name}' loaded: Base {base_vectors.shape}, Query {query_vectors.shape}, GT {gt_indices.shape}, Space: {space}")
        return base_vectors, query_vectors, gt_indices, space
    except Exception as e:
        print(f"Error loading HDF5 file {hdf5_path}: {e}")
        return failure

def build_hnsw_index(data: np.ndarray, labels: np.ndarray, m, efc, space, threads):
    """Create an hnswlib index sized for ``data`` and insert every vector.

    Args:
        data: 2-D array of vectors; ``data.shape[1]`` is used as the index
            dimensionality and ``data.shape[0]`` as its capacity.
        labels: Label assigned to each row of ``data``.
        m: Forwarded as ``M`` to ``init_index``.
        efc: Forwarded as ``ef_construction`` to ``init_index``.
        space: Distance space name given to ``hnswlib.Index``.
        threads: Thread count applied through ``set_num_threads`` when the
            index object offers that method.

    Returns:
        The populated index.
    """
    index = hnswlib.Index(space=space, dim=data.shape[1])
    # allow_replace_deleted=True is what lets later add_items(..., replace_deleted=True)
    # calls reuse the slots of elements that were marked deleted.
    index.init_index(
        max_elements=len(data),
        ef_construction=efc,
        M=m,
        allow_replace_deleted=True,
    )
    apply_thread_count = getattr(index, 'set_num_threads', None)
    if apply_thread_count is not None:
        apply_thread_count(threads)
    index.add_items(data, labels)
    return index

def evaluate_recall(idx: hnswlib.Index, query_data: np.ndarray, gt_indices: np.ndarray, ef_search, recall_k, num_trials):
    """Measure Recall@K of ``idx`` on a query set, as a percentage.

    Args:
        idx: Index to query; its ef is set to ``ef_search`` as a side effect.
        query_data: Query vectors, one per row.
        gt_indices: Ground-truth neighbor labels, one row per query.
        ef_search: Value passed to ``idx.set_ef`` before querying.
        recall_k: Requested K; capped at the number of ground-truth columns.
        num_trials: Number of times the whole query set is run; hits are
            pooled over all trials.

    Returns:
        Percentage (0-100) of ground-truth neighbors found. ``0.0`` is
        returned when either input is ``None``, when the number of queries
        and ground-truth rows differ, or when there is nothing to count.
    """
    inputs_usable = (
        query_data is not None
        and gt_indices is not None
        and query_data.shape[0] == gt_indices.shape[0]
    )
    if not inputs_usable:
        return 0.0

    idx.set_ef(ef_search)
    k_eval = min(recall_k, gt_indices.shape[1])
    truth_per_query = [set(gt_row[:k_eval]) for gt_row in gt_indices]

    hits = 0
    for _ in range(num_trials):
        found_labels, _distances = idx.knn_query(query_data, k=k_eval)
        # Entries equal to -1 are treated as "no result" and never count as hits.
        hits += sum(
            len({label for label in found if label != -1} & truth)
            for found, truth in zip(found_labels, truth_per_query)
        )

    denominator = sum(map(len, truth_per_query)) * num_trials
    if denominator <= 0:
        return 0.0
    return (hits / denominator) * 100.0

# ===================================================================
# 🚀 4. MAIN EXPERIMENT LOGIC
# ===================================================================

print(f"Starting HNSW degradation experiment for: {TARGET_DATASET_NAME}")

# --- Load Data ---
# load_hdf5_data() signals failure by returning None for every element.
base_vectors, query_vectors, ground_truth_indices, space_type = load_hdf5_data(DATA_ROOT_HDF5, TARGET_DATASET_NAME, RECALL_K)
if base_vectors is None:
    # `exit` is a helper injected by the `site` module for interactive use and is
    # replaced by IPython/Jupyter; raising SystemExit directly behaves the same in a
    # plain script (message on stderr, exit status 1) and works everywhere.
    raise SystemExit("Exiting due to data loading failure.")

# --- Prepare Update Chunks ---
# The whole pool of labels to update is drawn once, without replacement, so each label
# is deleted and re-inserted at most once per run; the pool is then cut into batches.
# NOTE: no random seed is set in this script, so the selection differs between runs.
initial_base_size = len(base_vectors)
all_possible_labels = np.arange(initial_base_size, dtype=np.int32)
items_per_iter = max(int(initial_base_size * PER_ITERATION_DELETE_RATIO), 1)
total_distinct_items_to_update = min(TOTAL_UPDATE_ITERATIONS * items_per_iter, initial_base_size)
ids_for_entire_process = np.random.choice(
    all_possible_labels,
    size=total_distinct_items_to_update,
    replace=False,
)

if total_distinct_items_to_update > 0:
    # Ceiling division: a final, smaller batch still gets its own iteration.
    effective_iterations = -(-total_distinct_items_to_update // items_per_iter)
    ids_chunks_for_iterations = np.array_split(ids_for_entire_process, effective_iterations)
else:
    ids_chunks_for_iterations = []

print(f"Simulation plan: {len(ids_chunks_for_iterations)} iterations with ~{items_per_iter} updates each.")

# --- Prepare CSV for Results ---
# One CSV per run; the run timestamp in the name keeps earlier results from being overwritten.
csv_path = f"{OUTPUT_DIR}/degradation_results_{TARGET_DATASET_NAME}_{TIMESTAMP}.csv"
# One header entry per value written for each evaluation row, in the same order
# (the rows include the distance space between "rein" and "recall_k").
csv_header = [
    "timestamp", "dataset", "variant", "iteration", "recall", "ef_search",
    "M", "Efc", "rein", "space", "recall_k",
    "total_update_time", "total_roe_time", "total_rdn_time", "total_rea_time",
    "total_removed_edges", "total_repaired_nodes", "total_resolved_edges"
]
# newline="" is what the csv module expects so it can control line endings itself.
with open(csv_path, "w", newline="") as fp:
    csv.writer(fp).writerow(csv_header)

# --- Run Experiment for each Variant ---
# Both variants replay the same pre-drawn update batches, so their recall curves are comparable.
variants_to_run = ["Original HNSW", "PRO-HNSW (Iterative Repair)"]
final_plot_data = {}
# Set lookup keeps the per-iteration "is this an evaluation point?" test O(1).
evaluation_points = set(EVALUATION_ITERATIONS)

for variant_name in variants_to_run:
    print(f"\n{'='*20} Processing Variant: {variant_name} {'='*20}")
    # Only this variant runs the ROE / RDN / REA repair passes around each re-insertion.
    # NOTE(review): remove_obsolete_edges / repair_disconnected_nodes /
    # resolve_edge_asymmetry are presumably provided by a PRO-HNSW build of hnswlib — confirm.
    is_pro_variant = variant_name == "PRO-HNSW (Iterative Repair)"

    # 1. Build a fresh index for each variant
    # Durations use time.perf_counter(): it is monotonic and high-resolution, whereas
    # time.time() follows the wall clock and can jump when the system clock is adjusted.
    t_build_start = time.perf_counter()
    hnsw_idx = build_hnsw_index(base_vectors, all_possible_labels, M_PARAM, EF_CONSTRUCTION, space_type, NUM_THREADS_HNSW)
    build_time = time.perf_counter() - t_build_start
    print(f"  Initial index built in {build_time:.4f}s.")

    # 2. Initialize metrics and result storage
    results_for_this_variant = []
    # Running totals over all iterations done so far; repair entries stay 0 for the original variant.
    cumulative_metrics = {
        "update_time": 0.0, "roe_time": 0.0, "rdn_time": 0.0, "rea_time": 0.0,
        "removed_edges": 0, "repaired_nodes": 0, "resolved_edges": 0
    }

    # 3. Main update and evaluation loop
    # Evaluation happens BEFORE the update of the same index, so the value recorded at
    # iteration k reflects the index after k completed update steps (k = 0: untouched index).
    for iteration_idx in range(TOTAL_UPDATE_ITERATIONS + 1):
        # --- Evaluate at scheduled intervals ---
        if iteration_idx in evaluation_points:
            recall_val = evaluate_recall(hnsw_idx, query_vectors, ground_truth_indices, EF_SEARCH_FOR_EVAL, RECALL_K, NUM_TRIALS_FOR_EVAL)
            print(f"  > Eval at Iteration {iteration_idx}: Recall@{RECALL_K} (efS={EF_SEARCH_FOR_EVAL}) = {recall_val:.2f}%")
            results_for_this_variant.append({'iterations': iteration_idx, 'recall': recall_val})

            # Write current state to CSV. Value order: timestamp, dataset, variant, iteration,
            # recall, ef_search, M, ef_construction, reinsert ef, space, recall_k,
            # cumulative timings (update/ROE/RDN/REA), cumulative repair counters.
            row_data = [
                datetime.now().strftime("%Y%m%d_%H%M%S"), TARGET_DATASET_NAME, variant_name, iteration_idx, f"{recall_val:.4f}", EF_SEARCH_FOR_EVAL,
                M_PARAM, EF_CONSTRUCTION, REINSERT_EF, space_type, RECALL_K,
                f"{cumulative_metrics['update_time']:.4f}", f"{cumulative_metrics['roe_time']:.4f}", f"{cumulative_metrics['rdn_time']:.4f}", f"{cumulative_metrics['rea_time']:.4f}",
                cumulative_metrics['removed_edges'], cumulative_metrics['repaired_nodes'], cumulative_metrics['resolved_edges']
            ]
            # Appending row by row keeps partial results on disk if a long run is interrupted.
            with open(csv_path, "a", newline="") as fp:
                csv.writer(fp).writerow(row_data)

        # --- Perform update step (none left once all batches are consumed) ---
        if iteration_idx >= len(ids_chunks_for_iterations):
            continue
        ids_this_iteration = ids_chunks_for_iterations[iteration_idx]
        if ids_this_iteration.size == 0:
            continue

        update_step_start_time = time.perf_counter()

        # a. Mark elements for deletion
        for d_id in ids_this_iteration:
            hnsw_idx.mark_deleted(int(d_id))

        # b. PRO-HNSW: Run ROE before re-insertion
        if is_pro_variant:
            t_roe_start = time.perf_counter()
            removed_count = hnsw_idx.remove_obsolete_edges(ids_this_iteration.tolist())
            cumulative_metrics["roe_time"] += (time.perf_counter() - t_roe_start)
            cumulative_metrics["removed_edges"] += int(removed_count)

        # c. Re-insert the same vectors under the same labels, reusing the deleted slots
        # NOTE(review): whether set_ef() influences add_items() in the hnswlib build used
        # here is not visible from this file — confirm.
        hnsw_idx.set_ef(REINSERT_EF)
        reinsert_data = np.ascontiguousarray(base_vectors[ids_this_iteration], dtype=np.float32)
        hnsw_idx.add_items(reinsert_data, ids_this_iteration.astype(np.int32), replace_deleted=True)

        # d. PRO-HNSW: Run RDN and REA after re-insertion
        if is_pro_variant:
            t_rdn_start = time.perf_counter()
            repaired_count = hnsw_idx.repair_disconnected_nodes()
            cumulative_metrics["rdn_time"] += (time.perf_counter() - t_rdn_start)
            cumulative_metrics["repaired_nodes"] += int(repaired_count)

            t_rea_start = time.perf_counter()
            resolved_count = hnsw_idx.resolve_edge_asymmetry()
            cumulative_metrics["rea_time"] += (time.perf_counter() - t_rea_start)
            cumulative_metrics["resolved_edges"] += int(resolved_count)

        # update_time spans the whole step, i.e. it includes the repair passes timed above.
        cumulative_metrics["update_time"] += (time.perf_counter() - update_step_start_time)

    final_plot_data[variant_name] = pd.DataFrame(results_for_this_variant)
    # Release the index before the next variant builds its own.
    del hnsw_idx
    gc.collect()

print(f"\n✅ All experiments complete. Results have been saved to: {csv_path}")
print("This CSV file can now be used to plot the performance degradation curves.")