# Cluster Properties by Depth

In [1]:
import enum
import pathlib

import numpy
import pandas
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [2]:
# dtype applied to each column when a tree report CSV is loaded with pandas.
COL_DTYPE = {
    "center_id": numpy.uint32,
    "depth": numpy.uint32,
    "cardinality": numpy.uint32,
    "radius": numpy.float32,
    "lfd": numpy.float32,
    "radial_sum": numpy.float32,
    "span": numpy.float32,
    "num_children": numpy.uint32,
}

# Column names, in the same order as the keys of COL_DTYPE.
EXPECTED_COLUMNS = list(COL_DTYPE)


class PlottableColumns(enum.StrEnum):
    """Columns that may be plotted on the y-axis against depth on the x-axis."""

    Cardinality = "cardinality"
    Radius = "radius"
    LFD = "lfd"
    RadialSum = "radial_sum"
    Span = "span"
    NumChildren = "num_children"

    def should_log_scale(self) -> bool:
        """Whether the column should be plotted on a log scale."""
        return self in {
            PlottableColumns.Cardinality,
            PlottableColumns.Radius,
            PlottableColumns.RadialSum,
            PlottableColumns.Span,
        }


class PartitionStrategy(enum.StrEnum):
    """Partition strategies for clustering."""
    # Branching factor
    Fixed2 = "bf(fixed(2))"
    Logarithmic = "bf(logarithmic)"
    Adaptive = "bf(adaptive(128))"
    # Span reduction factor
    Sqrt2 = "srf(sqrt2)"
    Two = "srf(two)"
    E = "srf(e)"
    Pi = "srf(pi)"
    Phi = "srf(phi)"


class SearchAlgorithm(enum.StrEnum):
    """Search algorithms for clustering."""
    # Depth-first Sieve
    KnnDfs10 = "KnnDfs(k=10)"
    KnnDfs100 = "KnnDfs(k=100)"
    # Breadth-first Sieve
    KnnBfs10 = "KnnBfs(k=10)"
    KnnBfs100 = "KnnBfs(k=100)"


class ReportFilePostfixes(enum.StrEnum):
    """Postfixes for report files."""
    Tree = "tree.csv"
    Performance = "performance.json"
    Neighbors = "neighbors.npy"
    Distances = "distances.npy"


class Datasets(enum.StrEnum):
    """Datasets from ann-benchmarks."""
    FashionMnist = "fashion-mnist"
    Mnist = "mnist"
    Sift = "sift"
    Gist = "gist"
    Glove25 = "glove-25"
    Glove50 = "glove-50"
    Glove100 = "glove-100"
    Glove200 = "glove-200"
    DeepImage = "deep-image"
    LastFM = "lastfm"

In [3]:
# Root directory of the benchmark reports, located relative to the current
# working directory (three levels above its parent).
benchmarks_dir = pathlib.Path.cwd().parents[3].joinpath("data", "ann_data", "cakes-benchmarks")
assert benchmarks_dir.exists()

# Fail early if the directory of any dataset is missing.
# NOTE: `d` and `data_dir` stay bound at notebook scope after this loop
# (to the last member of `Datasets` and its directory).
for d in Datasets:
    data_dir = benchmarks_dir.joinpath(d)
    assert data_dir.exists(), f"Data directory {data_dir} does not exist."

benchmarks_dir

PosixPath('/home/nishaq/Documents/research/data/ann_data/cakes-benchmarks')

In [4]:
def plot_for_data(d: Datasets, p: PlottableColumns) -> go.Figure:
    """Plot per-depth summary statistics of one cluster property for a dataset.

    One subplot is drawn per partition strategy. For each strategy the tree
    report ``<strategy>-tree.csv`` is read from the directory of dataset ``d``
    under ``benchmarks_dir``, rows with no children (leaves) are dropped, and
    the min, 5th/25th/50th/75th/95th percentiles and max of column ``p`` are
    drawn as lines against depth.

    Args:
        d: Dataset whose tree reports are read.
        p: Column summarised on the y-axis.

    Returns:
        The assembled figure. All y-axes are logarithmic when
        ``p.should_log_scale()`` is true.

    Raises:
        AssertionError: If the tree report of any partition strategy is missing.
    """
    column = p.value
    strategies = list(PartitionStrategy)
    n_cols = 4
    n_rows = -(-len(strategies) // n_cols)  # Ceiling division.

    # Resolve the directory from the `d` argument. Relying on a notebook-level
    # `data_dir` would silently plot whichever dataset that name was last
    # bound to, regardless of `d`.
    data_dir = benchmarks_dir / d

    # Statistic name -> (dash style, line width). The median is emphasised,
    # the quartiles are dashed, and the outer statistics are thin dotted lines.
    line_styles = {
        "min": ("dot", 0.75),
        "p5": ("dot", 0.75),
        "p25": ("dash", 1.25),
        "median": ("solid", 2),
        "p75": ("dash", 1.25),
        "p95": ("dot", 0.75),
        "max": ("dot", 0.75),
    }
    palette = px.colors.qualitative.Plotly

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[ps.value for ps in strategies],
        horizontal_spacing=0.025,
        vertical_spacing=0.1,
    )

    for i, ps in enumerate(strategies):
        # `subplot_titles` are assigned in row-major order, so the traces must
        # be placed row-major too or titles end up above the wrong data.
        row, col = divmod(i, n_cols)

        tree_file = data_dir / f"{ps}-{ReportFilePostfixes.Tree}"
        assert tree_file.exists(), f"Tree file {tree_file} does not exist."
        tree_df = pandas.read_csv(tree_file, dtype=COL_DTYPE)
        tree_df = tree_df[tree_df["num_children"] > 0]  # Ignore leaves

        prop_vs_depth = tree_df[["depth", column]].groupby("depth").agg(
            min=(column, "min"),
            p5=(column, lambda x: x.quantile(0.05)),
            p25=(column, lambda x: x.quantile(0.25)),
            median=(column, "median"),
            p75=(column, lambda x: x.quantile(0.75)),
            p95=(column, lambda x: x.quantile(0.95)),
            max=(column, "max"),
        ).reset_index()

        for color, (stat, (dash, width)) in zip(palette, line_styles.items()):
            fig.add_trace(
                go.Scatter(
                    x=prop_vs_depth["depth"],
                    y=prop_vs_depth[stat],
                    mode="lines",
                    name=stat,
                    # Only the first subplot contributes legend entries; the
                    # shared group lets one entry toggle every subplot.
                    legendgroup=stat,
                    showlegend=(i == 0),
                    line=dict(color=color, dash=dash, width=width),
                ),
                row=row + 1,
                col=col + 1,
            )

        del tree_df, prop_vs_depth  # Free memory before the next report is read

    if p.should_log_scale():
        fig.update_yaxes(type="log")

    fig.update_layout(title=f"{d}: {p} vs Depth", height=400 * n_rows, width=2400)
    return fig

In [5]:
# Dataset passed to the `plot_for_data` calls in the cells below.
d = Datasets.FashionMnist

In [6]:
# Per-depth summary statistics of the `lfd` column.
fig = plot_for_data(d, PlottableColumns.LFD)
fig.show()

In [None]:
# Per-depth summary statistics of the `radius` column (log-scaled y-axes).
fig = plot_for_data(d, PlottableColumns.Radius)
fig.show()

In [None]:
# Per-depth summary statistics of the `span` column (log-scaled y-axes).
fig = plot_for_data(d, PlottableColumns.Span)
fig.show()

In [None]:
# Per-depth summary statistics of the `num_children` column.
fig = plot_for_data(d, PlottableColumns.NumChildren)
fig.show()