# In [None]:  (Jupyter cell marker left over from the notebook export)
"""
# PRO-HNSW: Static Graph Optimization Experiment (No Updates)

This notebook evaluates the effectiveness of the PRO-HNSW repair modules
(repair_disconnected_nodes and resolve_edge_asymmetry) on a freshly built,
static HNSW index without any delete or update operations.

The workflow is as follows:
1.  Build a standard HNSW index.
2.  Evaluate its performance ("baseline_initial").
3.  Apply the PRO-HNSW repair functions to the *same* index.
4.  Evaluate its performance again ("optimized_static").
5.  Save results to a CSV file for comparison.
"""

# ===================================================================
# 📝 1. CONFIGURATION SECTION
# All user-configurable parameters are defined here.
# ===================================================================

# ────────── Global Experiment Settings ──────────
# Row caps applied to the base / query arrays before indexing; a falsy value
# (None) keeps every row.
BASE_SIZE_LIMIT = None   # e.g. 10000 for a quick test run
QUERY_SIZE_LIMIT = None
RECALCULATE_GT = False   # Not typically needed if HDF5 provides GT.
TRIALS_LATENCY = 3       # number of timed query-batch repetitions per ef_search value

# ────────── HNSW Parameter Tiers ──────────
# "efs_range" is a (start, stop, step) triple that is expanded with range()
# into the list of ef_search values to sweep.
HNSW_PARAMS_TIERS = dict(
    small=dict(M=8, Efc=50, efs_range=(25, 76, 10)),
    medium=dict(M=12, Efc=75, efs_range=(50, 101, 10)),
    large=dict(M=16, Efc=100, efs_range=(75, 126, 10)),
)
DEFAULT_HNSW_TIER = "medium"  # used for datasets missing from the profile below

# ────────── Dataset Profiles ──────────
# Datasets grouped by the parameter tier they use; flattened below into the
# dataset-name -> tier-name lookup.
_DATASETS_BY_TIER = (
    ("large", ("deep-image-96-angular", "gist-960-euclidean")),
    ("medium", ("sift-128-euclidean", "nytimes-256-angular")),
    ("small", (
        "fashion-mnist-784-euclidean",
        "mnist-784-euclidean",
        "coco-i2i-512-angular",
        "glove-25-angular",
    )),
)
DATASET_HNSW_TIER_PROFILE = {
    dataset: tier
    for tier, datasets in _DATASETS_BY_TIER
    for dataset in datasets
}

# ────────── File Paths and Constants ──────────
DATA_ROOT_HDF5 = "../data"               # directory holding the <dataset>.hdf5 files
RESULTS_DIR = "../results/no_updates"    # directory the result CSV is written to
N_THREADS = 18                           # thread count handed to the index
RECALL_K = 10                            # neighbours per query used for recall@k

# ===================================================================
# 🛠️ 2. SETUP & LIBRARY IMPORTS
# ===================================================================
import os
import time
import csv
import numpy as np
import hnswlib
import h5py
import gc
from datetime import datetime
# No external 'dataset_loader.py' module is required: the HDF5 loader
# (load_hdf5_data) is defined locally in the helper-function section below.

# Create the output directory up front and stamp this run; TIMESTAMP is
# reused in the CSV file name and in every result row.
os.makedirs(RESULTS_DIR, exist_ok=True)
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")

# ===================================================================
# ⚙️ 3. DYNAMIC CONFIGURATION GENERATION
# ===================================================================

# Distance space per dataset. Insertion order here is the order in which the
# datasets are processed by the experiment loop.
CFG_BASE_CONFIG = {
    dataset: {"space": metric}
    for dataset, metric in (
        ("sift-128-euclidean", "l2"),
        ("gist-960-euclidean", "l2"),
        ("deep-image-96-angular", "cosine"),
        ("glove-25-angular", "cosine"),
        ("mnist-784-euclidean", "l2"),
        ("fashion-mnist-784-euclidean", "l2"),
        ("nytimes-256-angular", "cosine"),
        ("coco-i2i-512-angular", "cosine"),
    )
}

# Merge each dataset's base settings with the HNSW parameters of its tier.
# Unknown datasets and unknown tier names both fall back to the default tier.
CFG = {}
for dname_key, base_settings in CFG_BASE_CONFIG.items():
    tier_name = DATASET_HNSW_TIER_PROFILE.get(dname_key, DEFAULT_HNSW_TIER)
    tier_params = HNSW_PARAMS_TIERS.get(tier_name, HNSW_PARAMS_TIERS[DEFAULT_HNSW_TIER])
    dataset_cfg = dict(base_settings)
    dataset_cfg.update(
        M=tier_params["M"],
        Efc=tier_params["Efc"],
        # Expand the (start, stop, step) triple into concrete ef_search values.
        efs_list=list(range(*tier_params["efs_range"])),
        base_size_limit=BASE_SIZE_LIMIT,
        query_size_limit=QUERY_SIZE_LIMIT,
    )
    CFG[dname_key] = dataset_cfg

print("✅ Configuration generated for the following datasets:")
for dname in CFG:
    print(f"- {dname}")

# ===================================================================
# 🔬 4. CORE HELPER FUNCTIONS
# ===================================================================

def normalize_vectors(vectors: np.ndarray, space_type: str) -> np.ndarray:
    """Return *vectors* as float32, scaled to unit row length for cosine space.

    Args:
        vectors: 2-D array with one vector per row, or None.
        space_type: Distance space name; only 'cosine' triggers normalisation.

    Returns:
        None when *vectors* is None. Otherwise a float32 array: unchanged
        values for non-cosine spaces or empty input, and rows divided by their
        L2 norm for 'cosine'.
    """
    if vectors is None:
        return None

    as_float32 = vectors.astype(np.float32)
    if space_type != 'cosine' or vectors.size == 0:
        return as_float32

    row_norms = np.linalg.norm(as_float32, axis=1, keepdims=True)
    # Zero-length rows get a tiny divisor so they stay all-zero instead of
    # producing NaN.
    safe_norms = np.where(row_norms == 0, 1e-9, row_norms)
    return as_float32 / safe_norms

def load_hdf5_data(data_root, dataset_name, recall_k):
    """Load base vectors, queries and ground truth from ``<dataset_name>.hdf5``.

    The file is read from *data_root* and must contain the 'train', 'test' and
    'neighbors' datasets. The distance space is derived from the dataset name:
    names containing 'angular' or 'cosine' map to 'cosine', everything else
    to 'l2'.

    Args:
        data_root: Directory that holds the HDF5 file.
        dataset_name: File stem of the dataset.
        recall_k: Ground-truth columns to keep; wider 2-D ground truth is
            truncated to its first *recall_k* columns.

    Returns:
        ``(base, query, gt, space)`` on success. ``(None, None, None, None)``
        when the file does not exist or anything fails while reading it; the
        problem is printed rather than raised.
    """
    nothing_loaded = (None, None, None, None)

    file_path = os.path.join(data_root, f"{dataset_name}.hdf5")
    if not os.path.exists(file_path):
        print(f"ERROR: Dataset file not found at {file_path}")
        return nothing_loaded

    try:
        with h5py.File(file_path, 'r') as f:
            # Copy everything into memory while the file is still open.
            base = np.array(f['train'], dtype=np.float32)
            query = np.array(f['test'], dtype=np.float32)
            gt = np.array(f['neighbors'])

            lowered_name = dataset_name.lower()
            is_angular = 'angular' in lowered_name or 'cosine' in lowered_name
            space = 'cosine' if is_angular else 'l2'

            has_extra_columns = gt.ndim == 2 and gt.shape[1] > recall_k
            if has_extra_columns:
                gt = gt[:, :recall_k]

            print(f"Dataset '{dataset_name}' loaded: Base {base.shape}, Query {query.shape}, GT {gt.shape}, Space: {space}")
            return base, query, gt, space
    except Exception as e:
        # Deliberately broad: a dataset that cannot be read is reported and
        # skipped by the caller instead of aborting the whole run.
        print(f"Error loading HDF5 file {file_path}: {e}")
        return nothing_loaded

def build_index(base_vectors: np.ndarray, labels_arr: np.ndarray, M, Efc, space, threads):
    """Create an hnswlib index sized for *base_vectors* and insert every row.

    Args:
        base_vectors: 2-D array of vectors to index; its second dimension sets
            the index dimensionality and its length sets the capacity.
        labels_arr: Labels passed to ``add_items`` alongside the vectors.
        M: HNSW ``M`` parameter forwarded to ``init_index``.
        Efc: ``ef_construction`` value forwarded to ``init_index``.
        space: Distance space name forwarded to ``hnswlib.Index``.
        threads: Thread count, applied only if the index object exposes
            ``set_num_threads``.

    Returns:
        The populated index.
    """
    dimensionality = base_vectors.shape[1]
    print(f"  Building HNSW index (Space: {space}, M: {M}, efC: {Efc})...")

    index = hnswlib.Index(space=space, dim=dimensionality)
    index.init_index(
        max_elements=len(base_vectors),
        ef_construction=Efc,
        M=M,
    )

    # Not every build of the bindings is assumed to offer this setter.
    if hasattr(index, 'set_num_threads'):
        index.set_num_threads(threads)

    contiguous_vectors = np.ascontiguousarray(base_vectors)
    index.add_items(contiguous_vectors, labels_arr)
    print("  HNSW index build complete.")
    return index

def evaluate_performance(idx: "hnswlib.Index", query_vectors: np.ndarray, gt_indices: np.ndarray, ef_search, num_trials, recall_k):
    """Measure recall@k (percent) and queries-per-second of *idx*.

    The whole query batch is searched *num_trials* times. Only the
    ``knn_query`` calls are timed; the recall bookkeeping is excluded.

    Args:
        idx: Index object offering ``set_ef(ef)`` and ``knn_query(data, k)``.
            The annotation is a string so the function can be defined without
            hnswlib being importable.
        query_vectors: 2-D array of queries, one per row.
        gt_indices: 2-D array of ground-truth neighbour labels, one row per
            query.
        ef_search: Value passed to ``idx.set_ef`` before querying.
        num_trials: Number of timed repetitions of the query batch.
        recall_k: Neighbours per query to evaluate; capped at the number of
            ground-truth columns.

    Returns:
        ``(recall, qps)``. Recall is the percentage of ground-truth neighbours
        found, accumulated over all trials. Both values are 0.0 when the
        inputs are missing, when the query and ground-truth row counts
        differ, or when nothing could be measured.
    """
    if query_vectors is None or gt_indices is None or query_vectors.shape[0] != gt_indices.shape[0]:
        return 0.0, 0.0

    idx.set_ef(ef_search)
    num_queries = query_vectors.shape[0]
    k_eval = min(recall_k, gt_indices.shape[1])
    # Built once: the ground truth does not change between trials.
    gt_sets = [set(row[:k_eval]) for row in gt_indices]

    total_time, total_correct_hits = 0.0, 0
    for _ in range(num_trials):
        # perf_counter() is monotonic and high-resolution; time.time() can
        # jump with clock adjustments and is too coarse for short batches.
        start_time = time.perf_counter()
        labels_batch, _distances = idx.knn_query(query_vectors, k=k_eval)
        total_time += time.perf_counter() - start_time

        # A batch with an unexpected row count contributes no hits but still
        # counts in the denominator below, so it lowers the reported recall.
        if labels_batch.shape[0] == num_queries:
            total_correct_hits += sum(
                len(gt_set.intersection(found))
                for gt_set, found in zip(gt_sets, labels_batch)
            )

    qps = (num_trials * num_queries / total_time) if total_time > 0 else 0.0
    total_gt_count = sum(len(s) for s in gt_sets) * num_trials
    recall = (total_correct_hits / total_gt_count) * 100.0 if total_gt_count > 0 else 0.0
    return recall, qps


# ===================================================================
# 🚀 5. MAIN EXPERIMENT LOOP
# ===================================================================
def _exact_ground_truth(base_vectors, query_vectors, space, k):
    """Return the exact top-k neighbour row indices for every query (brute force).

    Needed whenever the ground truth shipped in the HDF5 file no longer
    matches the indexed data, e.g. after the base set has been truncated.

    Args:
        base_vectors: Non-empty 2-D float array of indexed vectors. For
            'cosine' they must already be unit-normalised, as done by
            normalize_vectors.
        query_vectors: 2-D float array of queries, prepared the same way.
        space: 'cosine' ranks by largest inner product; anything else ranks
            by squared Euclidean distance.
        k: Neighbours per query; capped at the number of base vectors.

    Returns:
        int64 array of shape ``(num_queries, k)`` holding base row indices,
        nearest first. Row indices equal the labels used for the index.
    """
    n_base = base_vectors.shape[0]
    k = min(k, n_base)

    base_sq_norms = None
    if space != 'cosine':
        base_sq_norms = np.einsum('ij,ij->i', base_vectors, base_vectors)

    # Cap each score matrix at roughly 64M floats to keep memory bounded.
    rows_per_chunk = max(1, (1 << 26) // n_base)
    neighbours = np.empty((query_vectors.shape[0], k), dtype=np.int64)
    for start in range(0, query_vectors.shape[0], rows_per_chunk):
        chunk = query_vectors[start:start + rows_per_chunk]
        # Smaller score = closer. Cosine: -<q, b>.
        scores = -(chunk @ base_vectors.T)
        if base_sq_norms is not None:
            # L2: ||b||^2 - 2<q, b>; the per-query constant ||q||^2 does not
            # affect the ranking and is dropped.
            scores = 2.0 * scores + base_sq_norms
        candidates = np.argpartition(scores, k - 1, axis=1)[:, :k]
        ranking = np.argsort(np.take_along_axis(scores, candidates, axis=1), axis=1)
        neighbours[start:start + rows_per_chunk] = np.take_along_axis(candidates, ranking, axis=1)
    return neighbours


def _append_csv_row(path, row):
    """Append one row to the results CSV, reopening the file for each write."""
    with open(path, "a", newline="") as fp:
        csv.writer(fp).writerow(row)


csv_path = f"{RESULTS_DIR}/experiment_results_static_{TIMESTAMP}.csv"
csv_header = [
    "timestamp", "dataset", "variant", "ef_search",
    "M", "Efc", "space", "recall_k",
    "qps", "recall",
    "time_idx_build", "time_apply_rdn", "metric_rdn", "time_apply_rea", "metric_rea"
]
# Start a fresh file containing only the header; rows are appended later.
with open(csv_path, "w", newline="") as fp:
    csv.writer(fp).writerow(csv_header)

print(f"\n🚀 Starting static optimization experiments. Results will be saved to: {csv_path}\n")

for dname_key, current_cfg in CFG.items():
    print(f"\n{'='*20} Processing Dataset: {dname_key.upper()} ({current_cfg['space']}) {'='*20}")

    # --- Data Preparation ---
    base_data, query_data, gt_data, space = load_hdf5_data(DATA_ROOT_HDF5, dname_key, RECALL_K)
    if base_data is None:
        continue

    full_base_rows = base_data.shape[0]
    if current_cfg["base_size_limit"]:
        base_data = base_data[:current_cfg["base_size_limit"]]
    if current_cfg["query_size_limit"]:
        query_data = query_data[:current_cfg["query_size_limit"]]
    gt_data = gt_data[:query_data.shape[0]]

    if base_data.size == 0 or query_data.size == 0 or gt_data.size == 0:
        print(f"  Data is empty for {dname_key} after slicing. Skipping.")
        continue

    base_data_processed = normalize_vectors(base_data, space)
    query_data_processed = normalize_vectors(query_data, space)
    # The processed arrays are separate float32 copies, so drop the raw ones
    # now instead of holding two copies of the dataset.
    del base_data, query_data
    labels_for_index = np.arange(base_data_processed.shape[0], dtype=np.int32)

    print(f"  Data Ready: Base={base_data_processed.shape[0]}, Query={query_data_processed.shape[0]}")

    # The HDF5 ground truth refers to the full base set. Once the base is
    # truncated those neighbour ids may not exist in the index, so recall
    # against them would be meaningless; recompute it exactly instead.
    base_truncated = base_data_processed.shape[0] < full_base_rows
    if RECALCULATE_GT or base_truncated:
        print(f"  Recomputing exact ground truth (k={RECALL_K}) for the indexed subset...")
        gt_data = _exact_ground_truth(base_data_processed, query_data_processed, space, RECALL_K)

    # --- Build a single, shared index ---
    print(f"  Building a single HNSW index for both variants...")
    t_build_start = time.time()
    idx_shared = build_index(base_data_processed, labels_for_index, current_cfg['M'], current_cfg['Efc'], space, N_THREADS)
    time_build_common = time.time() - t_build_start
    print(f"  Index built in {time_build_common:.4f}s")

    # --- 1. Evaluate Baseline (Initial Build) ---
    print("\n  Evaluating: Baseline HNSW (as-built)...")
    for ef_search in current_cfg["efs_list"]:
        rec, qps = evaluate_performance(idx_shared, query_data_processed, gt_data, ef_search, TRIALS_LATENCY, RECALL_K)
        print(f"    baseline ef={ef_search:3d}  R={rec:5.2f}%  QPS={qps:9.1f}")
        # Repair columns are zero for the baseline: nothing has been applied yet.
        row = [TIMESTAMP, dname_key, "baseline_initial", ef_search, current_cfg["M"], current_cfg["Efc"], space, RECALL_K,
               f"{qps:.3f}", f"{rec:.4f}", f"{time_build_common:.4f}", 0, 0, 0, 0]
        _append_csv_row(csv_path, row)

    # --- 2. Apply Optimizations and Evaluate Again ---
    print("\n  Applying PRO-HNSW static optimizations...")

    time_rdn, metric_rdn, time_rea, metric_rea = 0.0, 0, 0.0, 0

    # The repair methods are looked up dynamically; if a method is absent the
    # step is skipped and the "optimized_static" rows then measure the
    # unmodified index.
    if hasattr(idx_shared, 'repair_disconnected_nodes'):
        t_rdn_start = time.time()
        metric_rdn = idx_shared.repair_disconnected_nodes()
        time_rdn = time.time() - t_rdn_start
        print(f"    repair_disconnected_nodes completed in {time_rdn:.4f}s, processed {metric_rdn} nodes.")
    else:
        print("    WARNING: index has no repair_disconnected_nodes(); step skipped.")

    if hasattr(idx_shared, 'resolve_edge_asymmetry'):
        t_rea_start = time.time()
        metric_rea = idx_shared.resolve_edge_asymmetry()
        time_rea = time.time() - t_rea_start
        print(f"    resolve_edge_asymmetry completed in {time_rea:.4f}s, fixed {metric_rea} edges.")
    else:
        print("    WARNING: index has no resolve_edge_asymmetry(); step skipped.")

    print("\n  Evaluating: PRO-HNSW (statically optimized)...")
    for ef_search in current_cfg["efs_list"]:
        rec, qps = evaluate_performance(idx_shared, query_data_processed, gt_data, ef_search, TRIALS_LATENCY, RECALL_K)
        print(f"    optimized ef={ef_search:3d}  R={rec:5.2f}%  QPS={qps:9.1f}")
        row = [TIMESTAMP, dname_key, "optimized_static", ef_search, current_cfg["M"], current_cfg["Efc"], space, RECALL_K,
               f"{qps:.3f}", f"{rec:.4f}", f"{time_build_common:.4f}",
               f"{time_rdn:.4f}", metric_rdn, f"{time_rea:.4f}", metric_rea]
        _append_csv_row(csv_path, row)

    # Release the index and this dataset's arrays before the next dataset is
    # loaded, to keep peak memory down.
    del idx_shared, base_data_processed, query_data_processed, gt_data, labels_for_index
    gc.collect()

print(f"\n✅ All 'No Update' experiments complete. Results saved to: {csv_path}")