# Imports

In [1]:
import warnings
warnings.filterwarnings(
    "ignore",
    category=RuntimeWarning,
    message=".*overflow.*|.*invalid value.*"
)


In [None]:
import sys
from pathlib import Path

# Project root = parent of the notebook's working directory, i.e. the notebook
# is expected to be started from a sub-folder of the repository.
ROOT = Path.cwd().parent

# Put the project root on the import path so the `src.*` imports below resolve.
# Guarded so that re-running this cell does not append a duplicate entry.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

import numpy as np
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
from tqdm import tqdm

import faiss
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import STL10
from torchvision import models
import os 
import json

from src.config import load_config
from src.metrics import (
    compute_overlap_stats,
    compute_distance_stats, 
    compute_lid_stats,
    compute_barycenter_stats,
)


from copy import deepcopy
from datetime import datetime

import yaml


# Load Configs

In [3]:
# Config fragments handed to load_config, in this order. The paths are relative
# to the notebook's working directory. The returned object is indexed as a
# nested mapping by the cells that follow.
_CONFIG_FILES = (
    "../config/base.yaml",
    "../config/data.yaml",
    "../config/embedding.yaml",
    "../config/search.yaml",
    "../config/ann.yaml",
)
cfg = load_config(*_CONFIG_FILES)

In [4]:
# --- Experiment-level settings -------------------------------------------
NUM_EXPERIMENTS = cfg["experiment"]["num_experiments"]
BASE_SEED = cfg["experiment"]["seed"]          # per-run seeds are BASE_SEED + run index
SAMPLE_SIZE = cfg["data"]["sample"]["size"]    # number of images drawn per run

# --- HNSW hyperparameters -------------------------------------------------
M = cfg["ann"]["hnsw"]["m"]
EF_CONSTRUCTION = cfg["ann"]["hnsw"]["ef_construction"]
EF_SEARCH = cfg["ann"]["hnsw"]["ef_search"]

# --- Neighbourhood sizes to sweep (K_MAX is included when the step lands on it)
K_MIN = cfg["search"]["knn"]["k_min"]
K_MAX = cfg["search"]["knn"]["k_max"]
K_STEP = cfg["search"]["knn"]["k_step"]
K_VALUES = list(range(K_MIN, K_MAX + 1, K_STEP))

# --- Output location ------------------------------------------------------
RUNS_DIR = ROOT / Path(cfg["paths"]["runs_dir"])
# parents=True: a nested runs dir (e.g. "outputs/runs") must not fail just
# because its parent folder has not been created yet.
RUNS_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_PATH = ROOT / cfg["data"]["data_path"]


# Image preprocessing settings, consumed when the torchvision transform is built.
p = cfg["embedding"]["preprocessing"]

In [5]:
# Image preprocessing applied before embedding, in order:
# resize -> centre crop -> tensor conversion -> normalisation with the
# configured mean / std.
_normalize_cfg = p["normalize"]
_preprocessing_steps = [
    T.Resize(p["resize"]),
    T.CenterCrop(p["center_crop"]),
    T.ToTensor(),
    T.Normalize(mean=_normalize_cfg["mean"], std=_normalize_cfg["std"]),
]
transform = T.Compose(_preprocessing_steps)

# Download The Dataset

In [6]:
# Open the "unlabeled" STL-10 split under DATA_PATH (download=True is passed so
# the files are fetched when needed) with the preprocessing transform attached.
_stl10_options = {
    "root": DATA_PATH,
    "split": "unlabeled",
    "download": True,
    "transform": transform,
}
dataset = STL10(**_stl10_options)

# Size of the pool that the per-run random subsamples are drawn from.
FULL_DATASET_LEN = len(dataset)
print("Full STL-10 size:", FULL_DATASET_LEN)


Full STL-10 size: 100000


# Load the embedding model and remove the classification head. 

In [7]:
# Load pretrained ResNet-50 (ImageNet-trained weights)
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Embedding width = input width of the classification head that is removed
# below. Read from the model instead of hard-coding it so the printed value
# stays correct if the backbone is swapped.
EMBEDDING_DIM = resnet.fc.in_features

# Remove the final classification layer to expose penultimate embeddings
embedding_model = torch.nn.Sequential(*list(resnet.children())[:-1])
embedding_model.to(DEVICE)
embedding_model.eval()


def extract_embeddings(batch):
    """Embed one batch of preprocessed images with the headless ResNet-50.

    Parameters
    ----------
    batch : torch.Tensor
        Image batch as produced by the DataLoader; it is moved to DEVICE here.

    Returns
    -------
    numpy.ndarray
        float32 array on the CPU with the two trailing singleton dimensions
        removed, i.e. (batch_size, 2048, 1, 1) -> (batch_size, 2048).
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        feats = embedding_model(batch.to(DEVICE))
        # Squeeze only the last two axes so a batch of size 1 keeps its batch axis.
        feats = feats.squeeze(-1).squeeze(-1)
    return feats.cpu().numpy().astype("float32")


print("Embedding dimension:", EMBEDDING_DIM)

Embedding dimension: 2048


In [None]:
import numpy as np

# --- Stopping rule for the repeated-experiment loop ------------------------
MAX_NUM_EXPERIMENTS = 100   # hard cap on the number of runs
MIN_RUNS = 10               # never stop before this many runs
WINDOW = 5                  # the last WINDOW mean-curve changes must all be below TOL_ABS
TOL_ABS = 1e-3              # tolerance on the RMS change of the running mean curve
STD_TOL = 1e-3              # tolerance on ||std|| relative to the mean-curve norm

# --- Mutable state filled in by the loop ------------------------------------
curve_history = []          # one ANN mean-distance curve per completed run
delta_history = []          # RMS change of the running mean curve, one per run from the 2nd on
std_norm = float("inf")     # norm of the across-run std curve; infinite until 2 runs exist

def to_json_safe(obj):
    """Recursively replace NumPy values with plain Python ones for ``json.dump``.

    Conversions:
      * dict        -> dict with converted values (NumPy-scalar keys are
                       unwrapped to their Python equivalent as well)
      * list/tuple  -> list with converted items (tuples become lists, which is
                       how JSON would encode them anyway)
      * np.ndarray  -> nested lists via ``tolist()``
      * NumPy scalar (float, int, bool, ...) -> Python scalar via ``item()``

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): to_json_safe(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        # Tuples must be walked too: a NumPy scalar hidden inside one would
        # otherwise reach json.dump unconverted and raise TypeError.
        return [to_json_safe(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # np.generic is the common base of all NumPy scalars, so np.bool_ is
        # covered in addition to the floating / integer types.
        return obj.item()
    return obj
    
def _save_line_plot(path, k_values, series, ylabel, title):
    """Save one "metric vs k" line plot to `path` and close the figure.

    `series` is a list of (y_values, label) pairs. A legend is drawn only when
    at least one label is not None.
    """
    plt.figure()
    for y_values, label in series:
        plt.plot(k_values, y_values, label=label)
    plt.xlabel("k")
    plt.ylabel(ylabel)
    plt.title(title)
    if any(label is not None for _, label in series):
        plt.legend()
    plt.savefig(path)
    plt.close()


# Spread of the per-run curves relative to the mean-curve norm. It stays
# infinite until two runs exist, so the stopping rule below can neither fire
# nor touch an unbound name before that point.
std_rel = np.inf

for i in range(MAX_NUM_EXPERIMENTS):
    print(f"\n================ RUN {i} ================")

    # -----------------------------
    # Run setup
    # -----------------------------
    run_cfg = deepcopy(cfg)
    run_cfg["experiment"]["run_id"] = i

    # BASE_SEED + i gives every run its own, reproducible subsample.
    run_cfg["data"]["sample"]["seed"] = BASE_SEED + i

    run_id = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_run{i:02d}"
    # RUNS_DIR already contains ROOT, so it is not prefixed a second time.
    run_dir = RUNS_DIR / run_id
    # exist_ok=False: every run must get a fresh directory.
    run_dir.mkdir(exist_ok=False)

    with open(run_dir / "config.yaml", "w") as f:
        yaml.safe_dump(run_cfg, f)

    # -----------------------------
    # Dataset subsample (per run)
    # -----------------------------
    rng = np.random.default_rng(run_cfg["data"]["sample"]["seed"])

    # SAMPLE_SIZE distinct indices into the full dataset (replace=False means
    # no image is used more than once).
    indices = rng.choice(FULL_DATASET_LEN, SAMPLE_SIZE, replace=False)

    # Subset is a Dataset view restricted to the chosen indices.
    dataset_subset = Subset(dataset, indices)

    # shuffle=False keeps embedding row r aligned with indices[r].
    loader = DataLoader(
        dataset_subset,
        batch_size=run_cfg["embedding"]["batch_size"],
        shuffle=False,
    )

    # -----------------------------
    # Embeddings (per run)
    # -----------------------------
    emb_file = run_dir / "embeddings.npy"
    emb_chunk_dir = run_dir / "emb_chunks"

    # run_dir was created just above, so this cache branch is only taken if
    # the file appears by other means; it is kept for symmetry with the saves.
    if emb_file.exists():
        embeddings = np.load(emb_file)
    else:
        emb_chunk_dir.mkdir(exist_ok=True)
        chunks = []

        for j, (batch, _) in enumerate(tqdm(loader)):
            emb = extract_embeddings(batch)  # (batch, embedding_dim)
            chunk_path = emb_chunk_dir / f"chunk_{j:05d}.npy"
            np.save(chunk_path, emb)
            chunks.append(chunk_path)

        embeddings = np.vstack([np.load(chunk_file) for chunk_file in chunks])
        np.save(emb_file, embeddings)

    # Dimensionality of the embedding vectors.
    d = embeddings.shape[1]
    print("Embeddings shape:", embeddings.shape)

    # -----------------------------
    # FAISS indices (per run)
    # -----------------------------

    # LINEAR SEARCH INDEX:
    # brute-force exact index that ranks neighbors by Euclidean distance.
    index_exact = faiss.IndexFlatL2(d)
    index_exact.add(embeddings)

    # HNSW ANN SEARCH INDEX:
    # HNSW graph index that ranks neighbors using Euclidean distance.
    index_ann = faiss.IndexHNSWFlat(d, M, faiss.METRIC_L2)

    # graph construction hyperparameter (larger = higher quality, slower build)
    index_ann.hnsw.efConstruction = EF_CONSTRUCTION

    # search-time accuracy hyperparameter (larger = higher recall, slower search)
    index_ann.hnsw.efSearch = EF_SEARCH

    index_ann.add(embeddings)

    # -----------------------------
    # Exact neighbors (per run)
    # -----------------------------
    exact_I_path = run_dir / "exact_I.npy"
    exact_D_path = run_dir / "exact_D.npy"

    # The k-sweep below only ever reads columns 0..K_MAX of the exact result,
    # so K_MAX + 1 neighbours per query are enough. Asking for all N neighbours
    # would build (and save) two N x N matrices per run for nothing.
    # NOTE: exact_I.npy / exact_D.npy therefore have shape (N, min(N, K_MAX + 1)).
    n_exact = min(len(embeddings), K_MAX + 1)

    if exact_I_path.exists() and exact_D_path.exists():
        I_exact_full = np.load(exact_I_path)
        D_exact_full = np.load(exact_D_path)
    else:
        D_exact_full, I_exact_full = index_exact.search(embeddings, n_exact)
        np.save(exact_I_path, I_exact_full)
        np.save(exact_D_path, D_exact_full)

    # -----------------------------
    # k-sweep + metrics (per run)
    # -----------------------------
    results = {
        "k": [],
        "mean_overlap": [],
        "std_overlap": [],
        "mean_exact_dist": [],
        "std_exact_dist": [],
        "mean_ann_dist": [],
        "std_ann_dist": [],
        "mean_barycenter_shift": [],
        "std_barycenter_shift": [],
        "mean_lid_diff": [],
        "std_lid_diff": [],
        "lid_exact": [],
        "lid_ann": [],
    }

    for k in K_VALUES:
        # Every query point is itself part of the index, so k + 1 neighbours
        # are taken and the first column (expected to be the zero-distance
        # self-match) is dropped.
        # NOTE(review): this assumes the self-match is always ranked first.
        # The ANN search (and exact search with duplicate images) does not
        # strictly guarantee that - confirm or filter by index instead.

        # ---- Exact: slice and drop self-match ----
        D_exact = D_exact_full[:, 1:k + 1]
        I_exact = I_exact_full[:, 1:k + 1]

        # ---- ANN: search k+1 and drop self-match ----
        D_ann, I_ann = index_ann.search(embeddings, k + 1)
        D_ann = D_ann[:, 1:]
        I_ann = I_ann[:, 1:]

        mean_ov, std_ov = compute_overlap_stats(I_exact, I_ann, k)
        dist_stats = compute_distance_stats(D_exact, D_ann)
        bary_shift, std_bary_shift = compute_barycenter_stats(
            embeddings, I_exact, I_ann, D_exact
        )
        lid_stats = compute_lid_stats(D_exact, D_ann)

        results["k"].append(k)

        results["mean_overlap"].append(mean_ov)
        results["std_overlap"].append(std_ov)

        results["mean_exact_dist"].append(dist_stats["mean_exact_dist"])
        results["std_exact_dist"].append(dist_stats["std_exact_dist"])
        results["mean_ann_dist"].append(dist_stats["mean_ann_dist"])
        results["std_ann_dist"].append(dist_stats["std_ann_dist"])

        results["mean_barycenter_shift"].append(bary_shift)
        results["std_barycenter_shift"].append(std_bary_shift)

        results["mean_lid_diff"].append(lid_stats["mean_lid_diff"])
        results["std_lid_diff"].append(lid_stats["std_lid_diff"])
        results["lid_exact"].append(lid_stats["mean_lid_exact"])
        results["lid_ann"].append(lid_stats["mean_lid_ann"])

    # -----------------------------
    # Save results + plots (per run)
    # -----------------------------
    with open(run_dir / "metrics.json", "w") as f:
        json.dump(to_json_safe(results), f, indent=2)

    plots_dir = run_dir / "plots"
    plots_dir.mkdir(exist_ok=True)

    _save_line_plot(
        plots_dir / "overlap_vs_k.png",
        results["k"],
        [(results["mean_overlap"], None)],
        ylabel="Mean Overlap",
        title="Overlap vs k",
    )
    _save_line_plot(
        plots_dir / "lid_vs_k.png",
        results["k"],
        [(results["mean_lid_diff"], None)],
        ylabel="Mean LID Difference (ANN − Exact)",
        title="LID Difference vs k",
    )
    _save_line_plot(
        plots_dir / "barycenter_shift_vs_k.png",
        results["k"],
        [(results["mean_barycenter_shift"], None)],
        ylabel="Mean Normalized Barycenter Shift",
        title="Barycenter Shift vs k",
    )
    _save_line_plot(
        plots_dir / "mean_distance_vs_k.png",
        results["k"],
        [(results["mean_exact_dist"], "Exact"), (results["mean_ann_dist"], "ANN")],
        ylabel="Mean k-NN Distance",
        title="Mean Neighborhood Radius vs k",
    )

    # -----------------------------
    # Convergence tracking across runs
    # -----------------------------
    # The tracked quantity is this run's ANN mean-distance curve over k.
    curve = np.asarray(results["mean_ann_dist"], dtype=float)
    curve_history.append(curve)

    if len(curve_history) >= 2:
        # Truncate to the shortest curve so that all runs can be stacked.
        min_len = min(len(c) for c in curve_history)
        curves_aligned = np.stack([c[:min_len] for c in curve_history])

        # Running mean curve without / with the newest run.
        mean_prev = np.mean(curves_aligned[:-1], axis=0)
        mean_curr = np.mean(curves_aligned, axis=0)

        # RMS change of the running mean caused by adding the newest run.
        delta = float(np.sqrt(np.mean((mean_curr - mean_prev) ** 2)))
        delta_history.append(delta)

        # Across-run spread at every k, collapsed to a single norm.
        curve_std = np.std(curves_aligned, axis=0)
        std_norm = float(np.linalg.norm(curve_std))

        # Small epsilon guards against division by zero.
        mean_norm = float(np.linalg.norm(mean_prev)) + 1e-8
        delta_rel = float(delta / mean_norm)
        std_rel = std_norm / (mean_norm + 1e-8)

        print(
            f"Run {len(curve_history)} | "
            f"Δ(mean) = {delta:.6g} | "
            f"Δ_rel = {delta_rel:.6g} | "
            f"||std|| = {std_norm:.6g}"
        )

    # Stop once enough runs exist, the last WINDOW mean-curve updates were all
    # below TOL_ABS, and the relative spread is below STD_TOL.
    stabilized = (
        len(curve_history) >= MIN_RUNS
        and len(delta_history) >= WINDOW
        and all(recent_delta < TOL_ABS for recent_delta in delta_history[-WINDOW:])
        and std_rel < STD_TOL
    )
    if stabilized:
        print(
            f"\nANN mean-distance curve stabilized after "
            f"{len(curve_history)} runs."
        )
        break
else:
    # Reached only when the loop used up every run without hitting `break`.
    print(
        f"\nANN mean-distance curve did not stabilize within "
        f"{MAX_NUM_EXPERIMENTS} runs."
    )






100%|██████████| 313/313 [00:12<00:00, 24.51it/s]


Embeddings shape: (10000, 2048)



100%|██████████| 313/313 [00:12<00:00, 25.97it/s]


Embeddings shape: (10000, 2048)
Run 2 | Δ(mean) = 5.52118e-05 | Δ_rel = 0.00478703 | ||std|| = 0.000281526



100%|██████████| 313/313 [00:12<00:00, 24.91it/s]


Embeddings shape: (10000, 2048)
Run 3 | Δ(mean) = 4.24263e-05 | Δ_rel = 0.00617811 | ||std|| = 0.00033848



100%|██████████| 313/313 [00:12<00:00, 25.16it/s]


Embeddings shape: (10000, 2048)
Run 4 | Δ(mean) = 2.30588e-05 | Δ_rel = 0.00326518 | ||std|| = 0.000347878



100%|██████████| 313/313 [00:12<00:00, 25.97it/s]


Embeddings shape: (10000, 2048)
Run 5 | Δ(mean) = 1.13401e-05 | Δ_rel = 0.00163026 | ||std|| = 0.000328834



100%|██████████| 313/313 [00:12<00:00, 25.63it/s]


Embeddings shape: (10000, 2048)
Run 6 | Δ(mean) = 1.18127e-05 | Δ_rel = 0.00171005 | ||std|| = 0.000324745



100%|██████████| 313/313 [00:12<00:00, 25.84it/s]


Embeddings shape: (10000, 2048)
Run 7 | Δ(mean) = 5.80987e-06 | Δ_rel = 0.00083477 | ||std|| = 0.000307976



100%|██████████| 313/313 [00:12<00:00, 25.14it/s]


Embeddings shape: (10000, 2048)
Run 8 | Δ(mean) = 1.87733e-06 | Δ_rel = 0.000270609 | ||std|| = 0.000289026



100%|██████████| 313/313 [00:12<00:00, 25.91it/s]


Embeddings shape: (10000, 2048)
Run 9 | Δ(mean) = 1.52289e-05 | Δ_rel = 0.00219373 | ||std|| = 0.000339223



100%|██████████| 313/313 [00:13<00:00, 23.99it/s]


Embeddings shape: (10000, 2048)
Run 10 | Δ(mean) = 1.20694e-05 | Δ_rel = 0.00172308 | ||std|| = 0.000363878

Distance-divergence curve stabilized after 10 runs.
