# Visualize BV within-category kNN diversity

**Purpose:** Reads the summary CSVs from `bv_within_category_knn_diversity.ipynb` and optionally the centroid summary from `bv_to_things_centroid_distances.ipynb` to produce:

1. **Bar plot:** Categories ranked by mean_knn_dist (low = more micro-structure). Options: all categories or top/bottom N.
2. **k=5 vs k=10:** Scatter of mean_knn_dist at k=5 vs k=10 (from multi_k summary).
3. **kNN vs centroid spread:** Scatter mean_knn_dist vs mean_bv_to_bv_centroid — categories with high spread but low kNN (micro-structure) vs high spread and high kNN (no local consistency).
4. **Per-exemplar distribution:** Violin/box of mean_knn_dist for selected categories (requires per-exemplar CSV).

## Parameters

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Current working directory, resolved to an absolute path; used below as the
# default location for both inputs and outputs.
SCRIPT_DIR = Path(".").resolve()
data_dir = SCRIPT_DIR  # directory the summary / per-exemplar CSVs are read from
out_dir = SCRIPT_DIR  # directory the PNG figures are written to

embedding = "clip"  # embedding name interpolated into input and output filenames
k = 5  # neighbor count selecting the `_k{k}_` CSVs and shown in plot titles
bar_n = 40  # passed as n_show to the ranked bar plot
# When left as None, a default path under data_dir is filled in before loading.
centroid_summary_path = None  # or e.g. data_dir / "bv_to_things_centroid_clip_summary.csv"
violin_categories = None      # e.g. ["crayon", "cat", "zebra", "chair"] or None to skip
no_multi_k = False  # True skips the k=5 vs k=10 scatter even if the multi-k CSV exists

## Load data and define plot helpers

In [None]:
def load_knn_summary(embedding, k, data_dir):
    """Load the per-category kNN summary CSV for one embedding and one k.

    Args:
        embedding: Embedding name interpolated into the filename (e.g. "clip").
        k: Neighbor count interpolated into the filename.
        data_dir: Directory holding the CSV; a ``str`` or ``pathlib.Path``.

    Returns:
        The CSV contents as a ``pandas.DataFrame``.

    Raises:
        FileNotFoundError: If the expected summary CSV does not exist.
    """
    # Coerce to Path so a plain-string directory works too, matching how
    # load_centroid_summary treats its path argument.
    path = Path(data_dir) / f"bv_within_category_knn_{embedding}_k{k}_summary.csv"
    if not path.exists():
        raise FileNotFoundError(f"kNN summary not found: {path}")
    return pd.read_csv(path)

def load_multi_k_summary(embedding, data_dir):
    """Load the optional multi-k kNN summary CSV for one embedding.

    Args:
        embedding: Embedding name interpolated into the filename (e.g. "clip").
        data_dir: Directory holding the CSV; a ``str`` or ``pathlib.Path``.

    Returns:
        The CSV contents as a ``pandas.DataFrame``, or ``None`` when the file
        does not exist (the multi-k summary is optional).
    """
    # Coerce to Path so a plain-string directory works too, matching how
    # load_centroid_summary treats its path argument.
    path = Path(data_dir) / f"bv_within_category_knn_{embedding}_multi_k_summary.csv"
    if not path.exists():
        return None
    return pd.read_csv(path)

def load_centroid_summary(path):
    """Read the optional centroid summary CSV.

    Args:
        path: Location of the CSV, or ``None`` when no summary is configured.

    Returns:
        A ``pandas.DataFrame`` with the CSV contents, or ``None`` when ``path``
        is ``None`` or nothing exists at that location.
    """
    if path is not None and Path(path).exists():
        return pd.read_csv(path)
    return None

# No explicit centroid summary configured: fall back to the per-embedding
# default filename inside data_dir.
centroid_summary_path = (
    centroid_summary_path
    if centroid_summary_path is not None
    else data_dir / f"bv_to_things_centroid_{embedding}_summary.csv"
)

# The kNN summary is required (its loader raises when the file is absent);
# the multi-k and centroid tables are optional and load as None when missing.
knn_df = load_knn_summary(embedding=embedding, k=k, data_dir=data_dir)
multi_df = load_multi_k_summary(embedding=embedding, data_dir=data_dir)
centroid_df = load_centroid_summary(path=centroid_summary_path)
print(f"Loaded kNN summary: {len(knn_df)} categories. Multi-k: {multi_df is not None}. Centroid: {centroid_df is not None}.")

## 1. Bar plot: categories ranked by mean kNN distance

In [None]:
def plot_rank_bars(df, out_path, n_show=40, show="both", title_suffix=""):
    """Save a horizontal bar chart of per-category mean kNN distance.

    Rows are drawn in the order they appear in ``df``; nothing is sorted here,
    so ``df`` is presumably already ranked by ``mean_knn_dist`` (TODO confirm
    against the notebook that writes the summary CSV).

    Args:
        df: Summary table with ``category``, ``mean_knn_dist`` and
            ``std_knn_dist`` columns.
        out_path: Destination path of the saved figure.
        n_show: Maximum number of categories to draw.
        show: Which end of the table to draw. ``"top"`` takes the first
            ``n_show`` rows; ``"bottom"`` takes the last ``n_show`` rows,
            last row first; ``"both"`` takes the first and last halves when
            the table has more than ``n_show`` rows (otherwise every row).
            Any other value behaves like ``"top"``.
        title_suffix: Text appended to the figure title.
    """
    split_at = None  # row index where the tail half starts (only for "both")
    if show == "bottom":
        plot_df = df.tail(n_show).iloc[::-1].reset_index(drop=True)
    elif show == "both" and len(df) > n_show:
        # Previously "both" was indistinguishable from "top"; draw both ends
        # of the table, giving the head the extra row when n_show is odd.
        n_head = (n_show + 1) // 2
        n_tail = n_show - n_head
        plot_df = pd.concat([df.head(n_head), df.tail(n_tail)], ignore_index=True)
        split_at = n_head
    else:
        plot_df = df.head(n_show)

    fig, ax = plt.subplots(figsize=(10, max(6, len(plot_df) * 0.22)))
    y_pos = np.arange(len(plot_df))
    ax.barh(y_pos, plot_df["mean_knn_dist"], xerr=plot_df["std_knn_dist"], capsize=2, alpha=0.85)
    if split_at is not None and 0 < split_at < len(plot_df):
        # Mark the gap between the two halves so the omitted middle rows
        # are not mistaken for adjacent entries.
        ax.axhline(split_at - 0.5, color="gray", linestyle="--", linewidth=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(plot_df["category"], fontsize=9)
    ax.set_xlabel("Mean kNN distance (lower = more micro-structure)")
    ax.set_ylabel("Category")
    ax.set_title(f"Within-category kNN diversity (ranked){title_suffix}")
    # Put the first row of plot_df at the top of the chart.
    ax.invert_yaxis()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"  Saved: {out_path}")

# Make sure the output directory exists before any figure is written.
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

# Filename stem shared by every figure this notebook saves.
prefix = f"bv_within_category_knn_{embedding}"
plot_rank_bars(
    df=knn_df,
    out_path=out_dir / f"{prefix}_k{k}_rank_bars.png",
    n_show=bar_n,
    title_suffix=f" (k={k})",
)

## 2. k=5 vs k=10 scatter (if multi_k summary exists)

In [None]:
# Scatter each category's mean kNN distance at k=5 against k=10. Runs only
# when the multi-k table was loaded, the comparison is enabled, and the table
# contains k=10 rows.
if not no_multi_k and multi_df is not None and (multi_df["k_used"] == 10).any():
    k5 = multi_df[multi_df["k_used"] == 5].set_index("category")["mean_knn_dist"]
    k10 = multi_df[multi_df["k_used"] == 10].set_index("category")["mean_knn_dist"]
    # Only categories measured at both k values can be placed on the scatter.
    common = k5.index.intersection(k10.index)
    if common.empty:
        # No k=5 rows, or no shared categories: min()/max() on the empty
        # arrays below would raise, so skip the plot explicitly.
        print("  Skipping k5 vs k10 (no categories with both k=5 and k=10).")
    else:
        x, y = k5.loc[common].values, k10.loc[common].values
        fig, ax = plt.subplots(figsize=(7, 7))
        ax.scatter(x, y, alpha=0.6, s=25)
        # Shared limits so the y=x reference line spans the whole data range.
        lims = [min(x.min(), y.min()), max(x.max(), y.max())]
        ax.plot(lims, lims, "k--", alpha=0.5, label="y=x")
        ax.set_xlabel("Mean kNN distance (k=5)")
        ax.set_ylabel("Mean kNN distance (k=10)")
        ax.set_title(f"k=5 vs k=10 within-category kNN diversity ({embedding})")
        ax.legend()
        ax.set_aspect("equal")
        plt.tight_layout()
        plt.savefig(out_dir / f"{prefix}_k5_vs_k10_scatter.png", dpi=150, bbox_inches="tight")
        plt.close()
        print("  Saved k5 vs k10 scatter.")
else:
    print("  Skipping k5 vs k10 (no multi_k or k=10).")

## 3. kNN diversity vs centroid spread

In [None]:
def plot_knn_vs_centroid_spread(knn_df, centroid_df, out_path, embedding, k, label_n=12):
    """Save a scatter of mean kNN distance against mean BV-to-BV centroid distance.

    The two tables are inner-joined on ``category``. When no category is
    shared, a message is printed and nothing is saved.

    Args:
        knn_df: kNN summary table with ``category`` and ``mean_knn_dist``.
        centroid_df: Centroid summary table with ``category`` and
            ``mean_bv_to_bv_centroid``.
        out_path: Destination path of the saved figure.
        embedding: Embedding name shown in the title.
        k: Neighbor count shown in the title.
        label_n: Labelling budget; ``label_n // 2`` points are annotated at
            each extreme of ``mean_knn_dist``.
    """
    spread_col = "mean_bv_to_bv_centroid"
    joined = knn_df.merge(centroid_df[["category", spread_col]], on="category", how="inner")
    if joined.empty:
        print("  No overlap; skip kNN vs centroid plot.")
        return

    fig, ax = plt.subplots(figsize=(8, 7))
    ax.scatter(joined["mean_knn_dist"], joined[spread_col], alpha=0.6, s=30)

    # Annotate only the extremes of the kNN axis; a row that lands in both
    # selections is labelled once.
    half = label_n // 2
    extremes = pd.concat(
        [joined.nsmallest(half, "mean_knn_dist"), joined.nlargest(half, "mean_knn_dist")]
    ).drop_duplicates()
    for name, knn_val, spread_val in zip(
        extremes["category"], extremes["mean_knn_dist"], extremes[spread_col]
    ):
        ax.annotate(
            name,
            (knn_val, spread_val),
            fontsize=8,
            alpha=0.9,
            xytext=(4, 4),
            textcoords="offset points",
        )

    ax.set_xlabel("Mean kNN distance (lower = more micro-structure)")
    ax.set_ylabel("Mean BV-to-BV centroid distance (overall spread)")
    ax.set_title(f"kNN diversity vs centroid spread ({embedding}, k={k})")
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"  Saved: {out_path}")

# The spread comparison needs the optional centroid summary; without it this
# section produces no output.
if centroid_df is not None:
    plot_knn_vs_centroid_spread(
        knn_df=knn_df,
        centroid_df=centroid_df,
        out_path=out_dir / f"{prefix}_k{k}_vs_centroid_spread.png",
        embedding=embedding,
        k=k,
    )

## 4. Violin plot for selected categories (per-exemplar CSV)

In [None]:
# Draw one violin per requested category from the per-exemplar CSV. Skipped
# entirely when violin_categories is None or empty.
if violin_categories:
    exemplar_path = data_dir / f"{prefix}_k{k}_per_exemplar.csv"
    if exemplar_path.exists():
        df = pd.read_csv(exemplar_path)
        df = df[df["category"].isin(violin_categories)]
        # Keep only requested categories that have rows. An absent category
        # would otherwise hand violinplot an empty array, which may raise
        # depending on the matplotlib version.
        present = set(df["category"])
        plotted = [c for c in violin_categories if c in present]
        missing = [c for c in violin_categories if c not in present]
        if missing:
            print(f"  No per-exemplar rows for: {missing}; omitted from violins.")
        if plotted:
            fig, ax = plt.subplots(figsize=(max(6, len(plotted) * 1.2), 5))
            ax.violinplot(
                [df[df["category"] == c]["mean_knn_dist"].values for c in plotted],
                positions=range(len(plotted)), showmeans=True, showmedians=True
            )
            ax.set_xticks(range(len(plotted)))
            ax.set_xticklabels(plotted, rotation=45, ha="right")
            ax.set_ylabel("Per-exemplar mean kNN distance")
            ax.set_xlabel("Category")
            ax.set_title(f"Within-category distribution of kNN distance (k={k}, {embedding})")
            plt.tight_layout()
            plt.savefig(out_dir / f"{prefix}_k{k}_violins_selected.png", dpi=150, bbox_inches="tight")
            plt.close()
            print("  Saved violins.")
        else:
            print("  No data for violin categories.")
    else:
        print(f"  Per-exemplar CSV not found: {exemplar_path}. Run bv_within_category_knn_diversity with save_exemplar_csv=True.")
else:
    print("  Skipping violins (violin_categories not set).")

print("Done.")