# Within-category kNN diversity for BV exemplars

**Purpose:** For each category we have N BV exemplars. For each exemplar we compute the **mean distance to its k nearest neighbors** (within that category, excluding self).

**Interpretation:**
- **Low mean kNN distance** → micro-structure: exemplars form local clusters (e.g. groups of 5–10 exemplars that share a similar viewpoint or format), so each point has close neighbors.
- **High mean kNN distance** → no consistent local structure; exemplars are more uniformly spread.

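This yields a category-level "kNN diversity" score: the mean, over exemplars, of each exemplar's mean kNN distance. **Lower = more micro-structure; higher = more uniform spread.** Caveat: for a fixed underlying spread, kNN distances shrink as the number of exemplars grows. Compare categories with similar `n_exemplars` (or account for n) before reading differences as structure.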

**Outputs:**
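1. Per-category summary CSV (`bv_within_category_knn_{embedding}_k{k}_summary.csv`) with columns:
   - category, n_exemplars, k;
   - effective_k (k capped at n_exemplars − 1);
   - mean_knn_dist, std_knn_dist, median_knn_dist, which summarize the per-exemplar values within the category.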
2. Optional per-exemplar CSV, written when `save_exemplar_csv = True`. Columns: category, subject_id, age_mo, mean_knn_dist, rank_within_cat (1 = smallest mean kNN distance in its category).

Compare with centroid-based spread (`mean_bv_to_bv_centroid`): a category can have large overall spread but low mean_knn_dist (micro-structure) or high spread and high mean_knn_dist (no local consistency). See `visualize_bv_knn_diversity.ipynb`.

## Parameters

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm

# A notebook has no __file__, so the current working directory stands in for the script dir.
SCRIPT_DIR = Path.cwd().resolve()

# Which BV embedding family to load: "clip" or "dinov3".
embedding = "clip"
# Neighbour counts to evaluate; listing more than one value (e.g. [5, 10])
# also writes a combined multi-k summary CSV.
k_list = [5]
# Set True to additionally write one per-exemplar CSV per k.
save_exemplar_csv = False
# Destination folder for every CSV written by this notebook.
out_dir = SCRIPT_DIR

## Load BV embeddings (reuse from centroid script)

We try to import from `bv_to_things_centroid_distances`; if not available (e.g. running notebook alone), we define a minimal loader and config.

In [None]:
try:
    from bv_to_things_centroid_distances import (
        load_allowed_categories,
        load_bv_embeddings,
        EXCLUDED_SUBJECT,
        MIN_EXEMPLARS,
    )
except ImportError:
    # Stand-alone fallback that defines the same four names as the import above,
    # so the rest of the notebook runs without the centroid script on the path.
    # NOTE(review): the imported implementation is not visible here; keep the two in sync.
    GROUPED_EMBEDDINGS_BASE = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings")
    GROUPED_EMBEDDINGS_DIRS = {
        "clip": GROUPED_EMBEDDINGS_BASE / "clip_embeddings_grouped_by_age-mo_normalized",
        "dinov3": GROUPED_EMBEDDINGS_BASE / "facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo_normalized",
    }
    CATEGORIES_FILE = SCRIPT_DIR / "../../../data/things_bv_overlap_categories_exclude_zero_precisions.txt"
    EXCLUDED_SUBJECT = "00270001"
    MIN_EXEMPLARS = 2

    def load_allowed_categories():
        """Return the set of category names to keep, or None when no filter file exists.

        The file is read as one category name per line; blank lines are ignored.
        """
        if not CATEGORIES_FILE.exists():
            return None
        with open(CATEGORIES_FILE, encoding="utf-8") as f:
            return {line.strip() for line in f if line.strip()}

    def _parse_exemplar_stem(stem):
        """Split a '<subject_id>_<age_mo>[_...]' file stem into (subject_id, int age_mo), or None if malformed."""
        parts = stem.split("_")
        if len(parts) < 2:
            return None
        try:
            return parts[0], int(parts[1])
        except ValueError:
            return None

    def load_bv_embeddings(embedding_type, allowed_categories, excluded_subject, min_exemplars=2):
        """Load per-exemplar BV embeddings grouped by category folder.

        Parameters
        ----------
        embedding_type : str
            Key of GROUPED_EMBEDDINGS_DIRS ("clip" or "dinov3").
        allowed_categories : set of str or None
            Category folder names to keep; None keeps every folder.
        excluded_subject : str or None
            Subject id whose files are skipped (falsy disables the exclusion).
        min_exemplars : int
            Categories with fewer loaded exemplars are dropped.

        Returns
        -------
        (category_embeddings, category_exemplar_ids) : tuple of dict
            category -> array with one flattened float64 embedding per row, and
            category -> list of (subject_id, age_mo) tuples in the same row order.

        Raises
        ------
        KeyError
            If embedding_type is not a known embedding family.
        FileNotFoundError
            If the grouped-embeddings directory does not exist.
        """
        import warnings

        try:
            grouped_dir = GROUPED_EMBEDDINGS_DIRS[embedding_type]
        except KeyError:
            raise KeyError(
                f"Unknown embedding_type {embedding_type!r}; expected one of {sorted(GROUPED_EMBEDDINGS_DIRS)}"
            ) from None
        if not grouped_dir.exists():
            raise FileNotFoundError(f"BV grouped dir not found: {grouped_dir}")

        category_embeddings = {}
        category_exemplar_ids = {}
        for cat_folder in sorted(grouped_dir.iterdir()):
            if not cat_folder.is_dir():
                continue
            cat_name = cat_folder.name
            if allowed_categories is not None and cat_name not in allowed_categories:
                continue
            embs, ids = [], []
            # glob() order is filesystem-dependent. Sorting makes row order, CSV
            # output and first-come tie-breaking in downstream ranks reproducible.
            for f in sorted(cat_folder.glob("*.npy")):
                parsed = _parse_exemplar_stem(f.stem)
                if parsed is None:
                    continue
                subject_id, age_mo = parsed
                if excluded_subject and subject_id == excluded_subject:
                    continue
                try:
                    e = np.asarray(np.load(f), dtype=np.float64).flatten()
                except Exception as exc:
                    # Best-effort: skip unreadable/corrupt files, but do not hide them.
                    warnings.warn(f"Skipping unreadable embedding file {f}: {exc}")
                    continue
                embs.append(e)
                ids.append((subject_id, age_mo))
            if len(embs) >= min_exemplars:
                category_embeddings[cat_name] = np.array(embs)
                category_exemplar_ids[cat_name] = ids
        return category_embeddings, category_exemplar_ids

## kNN mean distance per exemplar

For each row in X we compute the mean Euclidean (L2) distance to its k nearest neighbors within the same category, excluding the point itself. For small categories k is capped at n − 1, and the capped value is reported as `effective_k`.

In [None]:
def compute_knn_mean_distances(X, k):
    """Return each row's mean Euclidean distance to its k nearest *other* rows.

    Parameters
    ----------
    X : array-like, shape (n, d)
        One embedding per row.
    k : int
        Requested number of neighbours. Capped at ``n - 1`` because a point
        cannot have more neighbours than there are other points.

    Returns
    -------
    np.ndarray, shape (n,)
        Mean distance from each row to its nearest ``min(k, n - 1)`` neighbours,
        never counting the row itself. All-NaN when no neighbour is available
        (n < 2 or k < 1).
    """
    X = np.asarray(X)
    n = X.shape[0]
    n_neighbors = min(k, n - 1)
    if n_neighbors < 1:
        return np.full(n, np.nan)

    # Imported lazily so the degenerate cases above do not need scikit-learn.
    from sklearn.neighbors import NearestNeighbors

    nn = NearestNeighbors(n_neighbors=n_neighbors, metric="euclidean", algorithm="auto")
    nn.fit(X)
    # kneighbors() without a query excludes each fitted point from its own
    # neighbour list by index. The alternative (query k+1 neighbours for X and
    # drop column 0) assumes self always comes first. With near-duplicate rows,
    # floating-point rounding can rank another row ahead of self, which keeps
    # the ~0 self distance in the mean and drops a real neighbour.
    distances, _ = nn.kneighbors()
    return distances.mean(axis=1)


def run_knn_per_category(bv_embeddings, bv_ids, categories, k):
    """Summarise within-category kNN distances for every category in ``categories``.

    Parameters
    ----------
    bv_embeddings : dict
        Category name -> array whose rows are exemplar embeddings
        (row count read via ``.shape[0]``).
    bv_ids : dict
        Category name -> sequence of ``(subject_id, age_mo)`` pairs, row-aligned
        with ``bv_embeddings[category]``.
    categories : iterable of str
        Categories to process, in output order.
    k : int
        Requested neighbour count; each category uses ``min(k, n - 1)``,
        reported as ``effective_k``.

    Returns
    -------
    summary_rows : list of dict
        One dict per category with the NaN-ignoring mean / std / median of the
        per-exemplar mean kNN distances. When no neighbour is available
        (``min(k, n - 1) < 1``) the statistics are NaN and ``effective_k`` is 0.
    exemplar_rows : list of dict
        One dict per exemplar, only for categories with at least one neighbour.
    """
    summaries = []
    per_exemplar = []
    for category in tqdm(categories, desc="kNN per category"):
        embeddings = bv_embeddings[category]
        ids = bv_ids[category]
        n_rows = embeddings.shape[0]
        k_eff = min(k, n_rows - 1)

        row = {"category": category, "n_exemplars": n_rows, "k": k}
        if k_eff < 1:
            # Nothing to compare against: keep the category, with NaN statistics.
            row.update(effective_k=0, mean_knn_dist=np.nan, std_knn_dist=np.nan, median_knn_dist=np.nan)
            summaries.append(row)
            continue

        dists = compute_knn_mean_distances(embeddings, k_eff)
        row.update(
            effective_k=k_eff,
            mean_knn_dist=float(np.nanmean(dists)),
            std_knn_dist=float(np.nanstd(dists)),
            median_knn_dist=float(np.nanmedian(dists)),
        )
        summaries.append(row)
        per_exemplar.extend(
            {"category": category, "subject_id": subject, "age_mo": age, "mean_knn_dist": float(dists[idx])}
            for idx, (subject, age) in enumerate(ids)
        )
    return summaries, per_exemplar

## Run for each k and save

In [None]:
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

allowed = load_allowed_categories()
# None means "no category filter"; an empty set is a filter that keeps nothing,
# so test for None explicitly instead of relying on truthiness.
print(f"Using {len(allowed) if allowed is not None else 'all'} categories")

print("Loading BV embeddings...")
bv_embeddings, bv_ids = load_bv_embeddings(
    embedding, allowed_categories=allowed, excluded_subject=EXCLUDED_SUBJECT, min_exemplars=MIN_EXEMPLARS
)
categories = sorted(bv_embeddings.keys())
print(f"Categories: {len(categories)}")
if not categories:
    # Fail with a clear message; otherwise sort_values below raises an opaque
    # KeyError on a column-less empty DataFrame.
    raise RuntimeError(
        f"No BV categories with at least {MIN_EXEMPLARS} exemplars were loaded for embedding={embedding!r}; "
        "check the embeddings directory and the category filter."
    )

# Per-exemplar CSV columns, passed explicitly so an empty result is still well-formed.
exemplar_columns = ["category", "subject_id", "age_mo", "mean_knn_dist"]

all_summary = []
for k in k_list:
    prefix = f"bv_within_category_knn_{embedding}_k{k}"
    summary_rows, exemplar_rows = run_knn_per_category(bv_embeddings, bv_ids, categories, k)
    summary_df = pd.DataFrame(summary_rows)
    # Most micro-structured (smallest mean kNN distance) categories first; NaN rows sort last.
    summary_df = summary_df.sort_values("mean_knn_dist", ascending=True).reset_index(drop=True)
    summary_path = out_dir / f"{prefix}_summary.csv"
    summary_df.to_csv(summary_path, index=False)
    print(f"Saved {prefix}: {summary_path}")

    if save_exemplar_csv:
        exemplar_df = pd.DataFrame(exemplar_rows, columns=exemplar_columns)
        # Rank 1 = smallest mean kNN distance within the category; "first" breaks
        # ties by row order. Nullable Int64 keeps the cast from failing on NaN distances.
        exemplar_df["rank_within_cat"] = exemplar_df.groupby("category")["mean_knn_dist"].rank(
            method="first", ascending=True
        ).astype("Int64")
        exemplar_path = out_dir / f"{prefix}_per_exemplar.csv"
        exemplar_df.to_csv(exemplar_path, index=False)
        print(f"Saved per-exemplar: {exemplar_path}")

    all_summary.append(summary_df.assign(k_used=k))

if len(k_list) > 1:
    combined = pd.concat(all_summary, ignore_index=True)
    combined_path = out_dir / f"bv_within_category_knn_{embedding}_multi_k_summary.csv"
    combined.to_csv(combined_path, index=False)
    print(f"Saved combined multi-k summary: {combined_path}")

print("Done. Lower mean_knn_dist = more micro-structure; higher = more uniform spread.")