In [1]:
import warnings
warnings.filterwarnings(
    "ignore",
    category=RuntimeWarning,
    message=".*overflow.*|.*invalid value.*"
)


In [2]:
import sys
from pathlib import Path

# Parent of the current working directory. The `src.*` imports further down
# rely on this directory being importable (presumably the project root —
# TODO confirm the notebook is always launched from a sub-folder of it).
ROOT = Path.cwd().parent
# Only register the path once: re-running this cell must not keep appending
# duplicate entries to sys.path.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

import numpy as np
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
from tqdm import tqdm

import faiss
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import STL10
from torchvision import models
import os 
import json

from src.config import load_config
from src.metrics import (
    compute_overlap_stats,
    compute_distance_stats, 
    compute_lid_stats,
    compute_barycenter_stats,
)


from copy import deepcopy
from datetime import datetime

import yaml


In [3]:
# Location of the experiment configuration (relative to the working directory)
CONFIG_PATH = Path("../config/config.yaml")

# Read the file in one go and parse it into plain Python containers
cfg = yaml.safe_load(CONFIG_PATH.read_text())

# Bare expression so the notebook displays the parsed configuration
cfg


{'experiment': {'name': 'neighborhood_analysis',
  'base_seed': 42,
  'max_runs': 25,
  'min_runs': 10,
  'convergence': {'enabled': True,
   'window': 5,
   'tolerance_abs': '1e-3',
   'tolerance_std': '1e-4'}},
 'data': {'dataset': 'STL-10',
  'total_size': 100000,
  'sample': {'enabled': True,
   'subset_size': 10000,
   'sampling': 'random',
   'replacement': False}},
 'embedding': {'model': 'resnet50',
  'pretrained': 'imagenet1k',
  'output_dim': 2048,
  'batch_size': 32,
  'device': 'cuda',
  'normalize': False},
 'search': {'distance_metric': 'l2',
  'exact': {'enabled': True},
  'ann': {'enabled': True,
   'algorithm': 'hnsw',
   'parameters': {'M': 32, 'efConstruction': 200, 'efSearch': 50}}},
 'evaluation': {'min_k': 10,
  'max_k': 500,
  'k_step': 10,
  'metrics': {'neighborhood_overlap': True,
   'average_neighbor_distance': True,
   'barycenter_shift': True,
   'local_intrinsic_dimensionality': True}},
 'ef_search_study': {'enabled': True, 'values': [20, 50, 100]},
 'ef_c

In [4]:
import numpy as np
import json
import yaml
import torch
import faiss
import matplotlib.pyplot as plt

from pathlib import Path
from copy import deepcopy
from datetime import datetime
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import STL10
from torchvision import models, transforms as T
from pathlib import Path
import torch

# -----------------------------
# Filesystem layout
# -----------------------------
# ROOT is the directory the notebook runs from; the dataset cache and the
# run outputs are both addressed one level above it through "..".
ROOT = Path(".")
RUNS_DIR = "../runs"              # plain string, joined onto ROOT when used
DATA_PATH = ROOT / "../data"      # root directory handed to the STL-10 loader

DATA_PATH.mkdir(exist_ok=True)
(ROOT / RUNS_DIR).mkdir(exist_ok=True)

# -----------------------------
# Compute device
# -----------------------------
# Chosen from CUDA availability only; cfg["embedding"]["device"] is not read here.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

# ======================================================
# ---------------- CONFIG PARSING ----------------------
# ======================================================

# Experiment-level settings: seed offset base and number of repetitions
base_seed, max_runs = (
    cfg["experiment"][key] for key in ("base_seed", "max_runs")
)

# Number of images drawn per run
subset_size = cfg["data"]["sample"]["subset_size"]

# Embedding extraction settings
embedding_cfg = cfg["embedding"]
batch_size = embedding_cfg["batch_size"]

# Baseline HNSW parameters
ann_cfg = cfg["search"]["ann"]["parameters"]
M, ef_construction, ef_search = (
    ann_cfg[key] for key in ("M", "efConstruction", "efSearch")
)

# Neighbourhood sizes to evaluate: min_k, min_k + k_step, ... up to and including max_k
eval_cfg = cfg["evaluation"]
K_VALUES = [
    *range(eval_cfg["min_k"], eval_cfg["max_k"] + 1, eval_cfg["k_step"])
]

# Per-study value grids
EF_SEARCH_VALUES = cfg["ef_search_study"]["values"]
EF_CONSTRUCTION_VALUES = cfg["ef_construction_study"]["values"]
ALPHAS = cfg["alpha_study"]["alpha_values"]

# Per-study on/off switches
RUN_EF_SEARCH = cfg["ef_search_study"]["enabled"]
RUN_EF_CONSTRUCTION = cfg["ef_construction_study"]["enabled"]
RUN_ALPHA = cfg["alpha_study"]["enabled"]

# Lower bound for efSearch when it is derived as alpha * k in the alpha study
EF_MIN = 32


# ======================================================
# ---------------- ONE-TIME SETUP ----------------------
# ======================================================

# Image preprocessing: Resize(256) -> CenterCrop(224) -> tensor -> per-channel
# normalisation with fixed mean/std constants.
transform = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# "unlabeled" split of STL-10, downloaded into DATA_PATH when not present
dataset = STL10(
    root=DATA_PATH,
    split="unlabeled",
    transform=transform,
    download=True,
)

FULL_DATASET_LEN = len(dataset)
print("Full STL-10 size:", FULL_DATASET_LEN)

# ResNet-50 with the IMAGENET1K_V2 weights
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Feature extractor: every child module of the network except the last one,
# moved to the selected device and switched to evaluation mode.
embedding_model = torch.nn.Sequential(*[*resnet.children()][:-1])
embedding_model = embedding_model.to(DEVICE).eval()

@torch.no_grad()
def extract_embeddings(batch):
    """Embed one batch of images with ``embedding_model``.

    The batch is moved to ``DEVICE`` and passed through the model with
    gradient tracking disabled. The two trailing axes of the model output are
    squeezed (assumes they are singleton spatial axes — TODO confirm for
    non-default input sizes), and the result is returned as a float32 NumPy
    array on the CPU.
    """
    pooled = embedding_model(batch.to(DEVICE))
    flat = pooled.squeeze(-1).squeeze(-1)
    return flat.cpu().numpy().astype("float32")

def to_json_safe(obj):
    """Recursively convert NumPy values inside ``obj`` to plain Python types.

    Handles:
      * dicts  -> dict with converted values (NumPy scalar keys are unwrapped too)
      * lists and tuples -> list with converted items
      * ``np.ndarray`` -> nested Python lists via ``tolist()``
      * any NumPy scalar (float, int, bool, ...) -> its Python equivalent via ``item()``

    Anything else is returned unchanged, so values that were not
    JSON-serialisable for other reasons still are not.
    """
    if isinstance(obj, dict):
        return {
            # json.dump rejects NumPy integer keys, so unwrap those as well
            (key.item() if isinstance(key, np.generic) else key): to_json_safe(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        # Tuples are emitted as JSON arrays anyway; recursing into them makes
        # sure NumPy values nested inside are converted too.
        return [to_json_safe(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # np.generic is the base of every NumPy scalar, including np.bool_,
        # which np.floating / np.integer do not cover.
        return obj.item()
    return obj


def run_ann_condition(embeddings, index_ann, D_exact_full, I_exact_full):
    """Compare ANN and exact neighbourhoods for every ``k`` in ``K_VALUES``.

    For each ``k`` the function slices columns ``1..k`` from the precomputed
    exact results, queries ``index_ann`` for ``k + 1`` neighbours and drops
    its first column as well (presumably the query point itself — verify for
    the ANN index), then evaluates the overlap, distance, barycenter and LID
    helpers on the two neighbourhoods.

    Parameters
    ----------
    embeddings : array used both as the query set for ``index_ann.search``
        and as input to ``compute_barycenter_stats``.
    index_ann : index object exposing ``search(queries, n)``; its current
        search settings are used as they are.
    D_exact_full, I_exact_full : exact distances / indices, one row per
        embedding, with at least ``max(K_VALUES) + 1`` columns.

    Returns
    -------
    dict mapping each metric name to a list with one entry per ``k``
    (aligned with the ``"k"`` list).
    """
    metric_names = (
        "k",
        "mean_overlap",
        "mean_exact_dist",
        "mean_ann_dist",
        "mean_barycenter_shift",
        "mean_lid_diff",
        "mean_lid_exact",
        "mean_lid_ann",
    )
    results = {name: [] for name in metric_names}

    for k in K_VALUES:
        # Exact neighbourhood: skip column 0, keep the next k columns
        cols = slice(1, k + 1)
        exact_d = D_exact_full[:, cols]
        exact_i = I_exact_full[:, cols]

        # ANN neighbourhood: ask for one extra hit and drop the first column
        ann_d, ann_i = index_ann.search(embeddings, k + 1)
        ann_d, ann_i = ann_d[:, 1:], ann_i[:, 1:]

        overlap, _ = compute_overlap_stats(exact_i, ann_i, k)
        distances = compute_distance_stats(exact_d, ann_d)
        shift, _ = compute_barycenter_stats(embeddings, exact_i, ann_i, exact_d)
        lid = compute_lid_stats(exact_d, ann_d)

        lid_exact_value = lid["mean_lid_exact"]
        lid_ann_value = lid["mean_lid_ann"]

        row = {
            "k": k,
            "mean_overlap": overlap,
            "mean_exact_dist": distances["mean_exact_dist"],
            "mean_ann_dist": distances["mean_ann_dist"],
            "mean_barycenter_shift": shift,
            "mean_lid_exact": lid_exact_value,
            "mean_lid_ann": lid_ann_value,
            "mean_lid_diff": lid_ann_value - lid_exact_value,
        }
        for name, value in row.items():
            results[name].append(value)

    return results


# ======================================================
# ---------------- PER-RUN LOOP ------------------------
# ======================================================

def _write_metrics(out_dir, results):
    """Create ``out_dir`` if needed and write ``results`` to ``metrics.json`` in it."""
    out_dir.mkdir(exist_ok=True)
    # The context manager closes the file deterministically; passing a bare
    # open(...) to json.dump left one unclosed handle per condition.
    with open(out_dir / "metrics.json", "w") as fh:
        json.dump(to_json_safe(results), fh, indent=2)


def _build_hnsw_index(vectors, ef_construction_value):
    """Build an L2 HNSW index over ``vectors`` using the configured ``M`` and ``efSearch``."""
    index = faiss.IndexHNSWFlat(vectors.shape[1], M, faiss.METRIC_L2)
    index.hnsw.efConstruction = ef_construction_value
    index.hnsw.efSearch = ef_search
    index.add(vectors)
    return index


for run_id in range(max_runs):

    print(f"\n================ RUN {run_id} ================")

    # Every run gets a fresh, timestamped directory; exist_ok=False makes a
    # name collision fail loudly instead of mixing two runs.
    ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = ROOT / RUNS_DIR / f"{ts}_run{run_id:02d}"
    run_dir.mkdir(exist_ok=False)

    # ---- Sampling ----
    # Seed is offset by run_id so each run draws a different but reproducible subset.
    rng = np.random.default_rng(base_seed + run_id)
    indices = rng.choice(FULL_DATASET_LEN, subset_size, replace=False)
    subset = Subset(dataset, indices)

    loader = DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=False
    )

    # ---- Embeddings ----
    # NOTE(review): run_dir was just created above, so this cache file cannot
    # exist yet and the load branch is effectively never taken.
    emb_path = run_dir / "embeddings.npy"
    if emb_path.exists():
        embeddings = np.load(emb_path)
    else:
        chunks = []
        for batch, _ in tqdm(loader):
            chunks.append(extract_embeddings(batch))
        embeddings = np.vstack(chunks)
        np.save(emb_path, embeddings)

    d = embeddings.shape[1]

    # ---- Exact index ----
    index_exact = faiss.IndexFlatL2(d)
    index_exact.add(embeddings)

    exact_path = run_dir / "exact_neighbors.npz"
    if exact_path.exists():
        data = np.load(exact_path)
        D_exact_full = data["D"]
        I_exact_full = data["I"]
    else:
        # Every experiment below only reads columns 1..k with k <= max(K_VALUES),
        # so retrieving (and saving) a full N x N ranking is wasted time, memory
        # and disk. Cap the search depth at the deepest rank actually used.
        n_exact = min(len(embeddings), max(K_VALUES, default=0) + 1)
        D_exact_full, I_exact_full = index_exact.search(embeddings, n_exact)
        np.savez(exact_path, D=D_exact_full, I=I_exact_full)

    # ---- Baseline ANN index ----
    index_ann = _build_hnsw_index(embeddings, ef_construction)

    # ==================================================
    # Experiment 1: fixed efSearch
    # ==================================================
    if RUN_EF_SEARCH:
        base_dir = run_dir / "efSearch_study"
        base_dir.mkdir(exist_ok=True)

        for ef in EF_SEARCH_VALUES:
            # efSearch is a query-time setting, so the baseline graph is reused
            index_ann.hnsw.efSearch = ef
            results = run_ann_condition(
                embeddings, index_ann,
                D_exact_full, I_exact_full
            )
            _write_metrics(base_dir / f"ef_{ef}", results)

    # ==================================================
    # Experiment 2: efConstruction sensitivity
    # ==================================================
    if RUN_EF_CONSTRUCTION:
        base_dir = run_dir / "efConstruction_study"
        base_dir.mkdir(exist_ok=True)

        for efc in EF_CONSTRUCTION_VALUES:
            # Build into a separate variable: rebinding index_ann here made the
            # alpha study below run on the last efConstruction index instead of
            # the baseline, and made its results depend on whether this study ran.
            index_efc = _build_hnsw_index(embeddings, efc)

            results = run_ann_condition(
                embeddings, index_efc,
                D_exact_full, I_exact_full
            )
            _write_metrics(base_dir / f"efc_{efc}", results)

    # ==================================================
    # Experiment 3: alpha study
    # ==================================================
    if RUN_ALPHA:
        base_dir = run_dir / "alpha_study"
        base_dir.mkdir(exist_ok=True)

        for alpha in ALPHAS:
            results = {
                "alpha": alpha,
                "k": [],
                "mean_overlap": [],
                "mean_exact_dist": [],
                "mean_ann_dist": [],
                "mean_barycenter_shift": [],
                "mean_lid_diff": [],
                "mean_lid_exact": [],
                "mean_lid_ann": [],
            }

            for k in K_VALUES:
                # efSearch scales with k (alpha * k) but never drops below EF_MIN
                ef = max(EF_MIN, int(alpha * k))
                index_ann.hnsw.efSearch = ef

                # Column 0 is dropped on both sides before comparing
                D_exact = D_exact_full[:, 1:k+1]
                I_exact = I_exact_full[:, 1:k+1]

                D_ann, I_ann = index_ann.search(embeddings, k + 1)
                D_ann = D_ann[:, 1:]
                I_ann = I_ann[:, 1:]

                mean_ov, _ = compute_overlap_stats(I_exact, I_ann, k)
                dist_stats = compute_distance_stats(D_exact, D_ann)
                bary_shift, _ = compute_barycenter_stats(
                    embeddings, I_exact, I_ann, D_exact
                )
                lid_stats = compute_lid_stats(D_exact, D_ann)

                lid_exact = lid_stats["mean_lid_exact"]
                lid_ann = lid_stats["mean_lid_ann"]
                lid_diff = lid_ann - lid_exact

                results["k"].append(k)
                results["mean_overlap"].append(mean_ov)
                results["mean_exact_dist"].append(dist_stats["mean_exact_dist"])
                results["mean_ann_dist"].append(dist_stats["mean_ann_dist"])
                results["mean_barycenter_shift"].append(bary_shift)
                results["mean_lid_diff"].append(lid_diff)
                results["mean_lid_ann"].append(lid_ann)
                # Previously appended lid_diff here, which corrupted the saved
                # "mean_lid_exact" series for every alpha.
                results["mean_lid_exact"].append(lid_exact)

            _write_metrics(base_dir / f"alpha_{alpha}", results)


Using device: cuda
Full STL-10 size: 100000



100%|██████████| 313/313 [00:12<00:00, 25.07it/s]


KeyboardInterrupt: 