# CLAM Clustering: New Algorithm and Recurrence Relations

Given:

- $f$: a distance function
- $C$: a Cluster to partition with $|C| > 2$
- $criteria$: user-defined continuation criteria

The algorithm partitions $C$ into child clusters $L$ and $R$ as follows:

1. $S \leftarrow$ a random sample of $\Big\lceil \sqrt{|C|} \Big\rceil$ points from $C$
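2. $c \leftarrow$ the geometric median of $S$, taken over the points of $S$ (the point of $S$ that minimizes the sum of distances $f$ to the other points of $S$)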
3. Remove $c$ from $C$ and assign it as the center of $C$.
4. $l \leftarrow$ the point in $C$ farthest from $c$
5. $r \leftarrow$ the point in $C$ farthest from $l$
6. $L \leftarrow$ the points in $C$ closer to $l$ than to $r$
7. $R \leftarrow$ the remaining points in $C$
8. If $|L| > 2$ and $criteria(L)$ is true, recursively partition $L$
9. If $|R| > 2$ and $criteria(R)$ is true, recursively partition $R$

The key difference from the previous algorithm is that the center $c$ is not passed down to the child clusters. This also changes the definition of a *leaf* cluster: a leaf contains either 1 point (which is its own center) or 2 points (one is the center, and the other, being a single point, cannot be split into further children).
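The steps above can be sketched in plain Python. This is a minimal illustration, not CLAM's implementation: points are referred to by index, $f$ is assumed to be a symmetric distance on indices, the geometric median in step 2 is taken as the sample medoid, ties in step 6 go to $R$, and the names `Cluster`, `partition`, `count_clusters`, `owned_points`, and `euclidean` are hypothetical.

```python
import math
import random
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class Cluster:
    indices: list[int]  # indices of all points in this cluster
    center: Optional[int] = None  # set only when the cluster is partitioned
    children: list["Cluster"] = field(default_factory=list)


def partition(
    cluster: Cluster,
    f: Callable[[int, int], float],
    criteria: Callable[[Cluster], bool],
    rng: random.Random,
) -> None:
    """Partition a cluster with more than 2 points into L and R, in place (steps 1-9)."""
    points = cluster.indices
    # Step 1: random sample of ceil(sqrt(|C|)) points.
    sample = rng.sample(points, math.ceil(math.sqrt(len(points))))
    # Step 2 (assumption): the sample point with the smallest total distance to the sample.
    c = min(sample, key=lambda s: sum(f(s, t) for t in sample))
    # Step 3: the center stays here and is NOT passed down to the children.
    cluster.center = c
    rest = [i for i in points if i != c]
    # Steps 4-5: the two poles.
    l = max(rest, key=lambda i: f(c, i))
    r = max(rest, key=lambda i: f(l, i))
    # Steps 6-7: points strictly closer to l go left; ties and the rest go right.
    left = [i for i in rest if f(i, l) < f(i, r)]
    right = [i for i in rest if f(i, l) >= f(i, r)]
    for child_points in (left, right):
        if not child_points:
            continue  # only possible if all remaining points coincide
        child = Cluster(child_points)
        cluster.children.append(child)
        # Steps 8-9: recurse only into children with more than 2 points.
        if len(child_points) > 2 and criteria(child):
            partition(child, f, criteria, rng)


def count_clusters(cluster: Cluster) -> int:
    return 1 + sum(count_clusters(child) for child in cluster.children)


def owned_points(cluster: Cluster) -> list[int]:
    """Each parent owns only its center; each leaf owns all of its points."""
    if cluster.center is None:
        return list(cluster.indices)
    return [cluster.center] + [i for child in cluster.children for i in owned_points(child)]


rng = random.Random(0)
coords = [(rng.random(), rng.random()) for _ in range(1_000)]


def euclidean(i: int, j: int) -> float:
    """Hypothetical choice of f: Euclidean distance between two indexed points."""
    return math.dist(coords[i], coords[j])


root = Cluster(list(range(len(coords))))
partition(root, euclidean, lambda cluster: True, rng)
print(count_clusters(root))
```

With $criteria$ always true, every leaf ends up with 1 or 2 points, and `owned_points(root)` lists every index exactly once, because each parent keeps only its center.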

## Recurrence Relations

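Let $T(n)$ be the number of clusters in the tree produced for a dataset of $n$ points, assuming that $criteria$ is always true and that every split is perfectly balanced, i.e., the points left after removing the center are divided as evenly as possible between $L$ and $R$. The recurrence relation for $T(n)$ is:

1. Base Case (the leaf clusters): $T(1) = 1$ and $T(2) = 1$.
2. Recursive Case (the parent clusters, with $1 + 2n$ or $2 + 2n$ points for $n \geq 1$):
    - $T(1 + 2n) = 1 + 2T(n)$ when the parent has an odd number of points
    - $T(2 + 2n) = 1 + T(n + 1) + T(n)$ when the parent has an even number of points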
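A direct transcription of the recursive cases is handy for checking small values by hand. This is a sketch, and `t_balanced` is a hypothetical name; it splits the parent's $m - 1$ non-center points as evenly as possible, which covers both the odd and the even case:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def t_balanced(m: int) -> int:
    """Number of clusters T(m) for perfectly balanced binary splits of m points."""
    if m <= 2:
        return 1  # leaf: 1 point, or a center plus one other point
    n, a = divmod(m - 1, 2)  # m = 1 + a + 2n with a in {0, 1}
    if a == 0:
        return 1 + 2 * t_balanced(n)  # T(1 + 2n) = 1 + 2T(n)
    return 1 + t_balanced(n + 1) + t_balanced(n)  # T(2 + 2n) = 1 + T(n + 1) + T(n)


print([t_balanced(m) for m in range(1, 11)])  # [1, 1, 3, 3, 3, 5, 7, 7, 7, 7]
```

For example, $T(5) = 3$: removing the center leaves four points, which are split into two leaves of two points each.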

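Clearly $T(n) \leq n$: every cluster has its own center, and no point is the center of more than one cluster, so by the pigeonhole principle there cannot be more clusters than points. The two are equal when every leaf cluster contains exactly one point (its own center). At the other extreme, when every leaf cluster contains exactly two points, $T(n)$ is smallest and approaches a lower bound of $\frac{2n}{3}$ for large $n$.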
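For the two-point extreme, a short counting argument makes the bound precise, assuming every parent cluster has exactly two non-empty children. Let $I$ be the number of parent clusters; then there are $\ell = I + 1$ leaf clusters. Each parent owns exactly one point (its center) and each leaf owns at most two, so:

$$
\begin{aligned}
n &\leq I + 2\ell = 3I + 2, \quad\text{so}\quad I \geq \frac{n - 2}{3}, \\
T(n) &= I + \ell = 2I + 1 \geq \frac{2(n - 2)}{3} + 1 = \frac{2n - 1}{3}.
\end{aligned}
$$

Equality holds exactly when every leaf has two points, so $\frac{T(n)}{n} \geq \frac{2n - 1}{3n}$, which tends to $\frac{2}{3}$ as $n$ grows.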

In [None]:
import numpy
import pandas
import plotly.colors as pc
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.stats import beta


In [None]:
# pyright: reportUnknownMemberType=false

def compute_memo(min_n: int, max_n: int, b: int = 2) -> pandas.DataFrame:
    """Compute the memoization table for our recurrence relation.

    For a dataset of size n and a branching factor of b, the number of clusters in the tree T(n)
    is given by the following recurrence relations:
      - T(1) = 1 and T(2) = 1, the leaf clusters
      - T(n) = n for 3 <= n <= b + 1, a parent cluster whose n - 1 children are all single-point leaves
      - T(1 + a + b * n) = 1 + a * T(n + 1) + (b - a) * T(n) for 1 + a + b * n >= b + 2 and 0 <= a < b

    Args:
        min_n: The minimum value of n to include in the output. This is to reduce noise in the output.
        max_n: The maximum value of n to compute.
        b: The branching factor.

    Returns:
        A pandas DataFrame with columns "n", "T(n)", and "T(n)/n" for min_n <= n <= max_n.
    """
    memo = [0] * (max_n + 1)  # memo[n] holds T(n); memo[0] is never read
    memo[1] = 1
    memo[2] = 1

    # The center plus one single-point leaf for each of the other n - 1 points
    for n in range(3, min(b + 1, max_n) + 1):
        memo[n] = n

    # Write n - 1 = a + b * q: a children get q + 1 points and b - a children get q points
    for n in range(b + 2, max_n + 1):
        q, a = divmod(n - 1, b)
        memo[n] = 1 + a * memo[q + 1] + (b - a) * memo[q]

    ns = list(range(min_n, max_n + 1))
    ts = memo[min_n:]  # memo[n] is T(n), so the slice starts at T(min_n)
    return pandas.DataFrame({"n": ns, "T(n)": ts, "T(n)/n": [t / n for n, t in zip(ns, ts)]})


In [None]:
min_n = 10
max_n = 100_000

tree_size_df = compute_memo(min_n, max_n)


In [None]:
# pyright: reportUnknownMemberType=false

type Data = list[tuple[int, pandas.DataFrame]] | pandas.DataFrame


def make_plot(min_n: int, max_n: int, data: Data) -> go.Figure:
    """Plot T(n) and T(n)/n.

    A single DataFrame is treated as b = 2 and is shown with its mean ratio; a list of
    (b, DataFrame) pairs is shown with one color per branching factor and a shared legend.
    """
    if isinstance(data, pandas.DataFrame):
        mean_ratio = data["T(n)/n"].mean()
        data = [(2, data)]
        fig = _make_plot(min_n, max_n, data, False)
        # Dashed horizontal line at the mean of T(n)/n over the plotted range
        fig.add_trace(
            go.Scatter(
                x=[min_n, max_n],
                y=[mean_ratio, mean_ratio],
                mode="lines",
                line=dict(color="Blue", dash="dash"),
                name="Mean",
                legendgroup="Mean",
            ),
            row=1,
            col=2,
        )
        return fig

    return _make_plot(min_n, max_n, data, True)


def _make_plot(min_n: int, max_n: int, data: list[tuple[int, pandas.DataFrame]], reveal: bool) -> go.Figure:
    """Create the plots for the recurrence relations, using the same color for each branching factor and combining the legends."""
    # Assign a color for each branching factor
    bs = [b for b, _ in data]
    palette = pc.qualitative.Set1 if len(bs) <= len(pc.qualitative.Set1) else pc.qualitative.Plotly
    color_map = {b: palette[i % len(palette)] for i, b in enumerate(bs)}

    if reveal:
        titles = ["T(n, b) vs n for various branching factors b", "Ratio of clusters to data points: T(n, b)/n vs n for various branching factors b"]
    else:
        titles = ["T(n) vs n", "Ratio of clusters to data points: T(n)/n vs n"]

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=titles,
        column_widths=[0.33, 0.67],
        horizontal_spacing=0.05
    )
    fig.add_shape(
        type="line",
        x0=min_n,
        y0=min_n,
        x1=max_n,
        y1=max_n,
        line=dict(color="Black", dash="dash"),
        row=1, col=1
    )
    # Reference line: 2/3 is the lower bound for b = 2, and 1/2 bounds every b from below
    lower_bound = 1 / 2 if reveal else 2 / 3
    fig.add_trace(
        go.Scatter(
            x=[min_n, max_n],
            y=[lower_bound, lower_bound],
            mode="lines",
            line=dict(color="Black", dash="dash"),
            name="Lower Bound",
            legendgroup="Lower Bound",
        ),
        row=1,
        col=2,
    )

    for b, tree_size_df in data:
        color = color_map[b]
        name = f"b={b}" if reveal else "T(n)"
        # First plot: T(n) vs n
        fig.add_trace(
            go.Scatter(x=tree_size_df["n"], y=tree_size_df["T(n)"], mode="lines", name=name, legendgroup=name, line=dict(color=color)),
            row=1, col=1
        )
        # Second plot: T(n)/n vs n
        fig.add_trace(
            go.Scatter(x=tree_size_df["n"], y=tree_size_df["T(n)/n"], mode="lines", name=name, legendgroup=name, line=dict(color=color), showlegend=False),
            row=1, col=2
        )

    fig.update_xaxes(type="log", title_text="n (log scale)", row=1, col=1)
    fig.update_yaxes(type="log", title_text="T(n) (log scale)", row=1, col=1)
    fig.update_xaxes(type="log", title_text="n (log scale)", row=1, col=2)
    fig.update_yaxes(title_text="T(n)/n", row=1, col=2)
    fig.update_layout(width=1600, height=600, showlegend=True)
    return fig


In [None]:
fig = make_plot(min_n, max_n, tree_size_df)
fig.show()


## Unbalanced Clustering

To model unbalanced splits, we draw the fraction $p$ of points sent to the left child from a specialization of the Beta distribution.

The PDF of the Beta distribution is $f(x; \alpha, \beta) = \frac{x^{\alpha - 1}(1 - x)^{\beta - 1}}{B(\alpha, \beta)}$, where the Beta function $B(\alpha, \beta)$ is the normalizing constant that makes the integral of $f$ over $[0, 1]$ equal to 1.

We use $f(x; \alpha) = \frac{1}{B(\alpha + 1, 2)} \cdot x^\alpha (1 - x)$, which is the Beta distribution with shape parameters $\alpha + 1$ and $\beta = 2$.
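As a sanity check on this density, note that $B(\alpha + 1, 2) = \frac{\Gamma(\alpha + 1)\,\Gamma(2)}{\Gamma(\alpha + 3)} = \frac{1}{(\alpha + 1)(\alpha + 2)}$, so $f(x; \alpha) = (\alpha + 1)(\alpha + 2)\,x^\alpha (1 - x)$. Setting $f'(x) = 0$ gives the peak at $x = \frac{\alpha}{\alpha + 1}$, and since a Beta$(a, b)$ variable has mean $\frac{a}{a + b}$, the mean here is $\frac{\alpha + 1}{\alpha + 3}$. The sketch below checks the normalization numerically with a midpoint rule, using only the standard library; `specialized_beta_pdf` and `midpoint_integral` are hypothetical helper names.

```python
import math


def specialized_beta_pdf(x: float, alpha: float) -> float:
    """Density f(x; alpha) = x**alpha * (1 - x) / B(alpha + 1, 2), i.e. Beta(alpha + 1, 2)."""
    # 1 / B(alpha + 1, 2) = Gamma(alpha + 3) / (Gamma(alpha + 1) * Gamma(2)) = (alpha + 1)(alpha + 2)
    return (alpha + 1) * (alpha + 2) * x**alpha * (1 - x)


def midpoint_integral(alpha: float, steps: int = 10_000) -> float:
    """Numerically integrate the density over [0, 1] with the midpoint rule."""
    h = 1.0 / steps
    return h * sum(specialized_beta_pdf((k + 0.5) * h, alpha) for k in range(steps))


for alpha in [1, 2, 5, 10]:
    mode = alpha / (alpha + 1)  # f'(x) = 0  <=>  alpha * (1 - x) = x
    mean = (alpha + 1) / (alpha + 3)  # Beta(a, b) has mean a / (a + b)
    print(f"alpha={alpha}: integral={midpoint_integral(alpha):.6f}, mode={mode:.3f}, mean={mean:.3f}")
```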

In [None]:
def plot_beta_specialized() -> go.Figure:
    """Plot the specialized Beta distribution with various alpha values."""
    alphas = [1, 2, 5, 10]
    x = numpy.linspace(0, 1, 500)
    ys = [beta.pdf(x, alpha + 1, 2) for alpha in alphas]
    max_y = max(y.max() for y in ys)
    fig = go.Figure()
    colors = pc.qualitative.Set1 if len(alphas) <= len(pc.qualitative.Set1) else pc.qualitative.Plotly

    for alpha, y, color in zip(alphas, ys, colors):
        m = alpha / (alpha + 1)
        fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name=f"α={alpha}", line=dict(color=color)))  # noqa: RUF001
        # Vertical dashed line at the peak (mode) alpha / (alpha + 1)
        fig.add_trace(
            go.Scatter(
                x=[m, m],
                y=[-0.05, max_y * 1.05],
                mode="lines",
                line=dict(color=color, dash="dash"),
                name=f"Peak (α={alpha})",  # noqa: RUF001
                legendgroup=f"Peak (α={alpha})",  # noqa: RUF001
                showlegend=False,
            )
        )

    fig.update_layout(
        title="Specialized Beta Distribution: The peak moves right as α increases, and is given by α/(α+1)",  # noqa: RUF001
        xaxis_title="x",
        yaxis_title="f(x, α)",  # noqa: RUF001
        legend_title="Parameters",
        width=1600,
        height=800
    )
    return fig


In [None]:
betas = plot_beta_specialized()
betas.show()


## Unbalanced Recurrence Relations

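The recurrence relation for $T(n)$ in an unbalanced clustering is:

1. Base Case (the leaf clusters): $T(1) = 1$ and $T(2) = 1$.
2. Recursive Case (the parent clusters with $n > 2$): $T(n) = 1 + T(n_L) + T(n_R)$, where $n_L = \big\lfloor p \cdot (n - 3) \big\rfloor + 1$ and $n_R = n - 1 - n_L$. The center is removed, each pole goes to its own child, and a fraction $p$ of the remaining $n - 3$ points joins the left child, where $p$ is drawn from the Beta distribution described above.

We draw a new $p$ at every parent cluster, so $T(n)$ is a random variable; we therefore estimate its distribution, along with that of the tree depth, by repeated sampling.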
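The recursive sampler can also be written with an explicit stack, which sidesteps Python's recursion limit if larger $\alpha$ (more lopsided splits, deeper trees) or larger $n$ are explored. This is a hypothetical variant (`sample_tn_unbalanced` is our name): it draws $p$ with the standard library's `random.Random.betavariate` instead of NumPy's generator, so its samples will not match the notebook's seeded results, and it does not track depth.

```python
import random


def sample_tn_unbalanced(n: int, alpha: float, rng: random.Random) -> int:
    """Draw one sample of T(n) from the unbalanced recurrence, using an explicit stack."""
    total = 0
    stack = [n]  # sizes of clusters still to be visited
    while stack:
        m = stack.pop()
        total += 1  # each visited size is one cluster
        if m <= 2:
            continue  # leaf cluster: 1 point, or a center plus one other point
        p = rng.betavariate(alpha + 1, 2)  # specialized Beta with parameters alpha + 1 and 2
        n_left = int((m - 3) * p) + 1  # floor of the left share, plus the left pole
        n_right = m - 1 - n_left  # everything except the center and the left child
        stack.extend((n_left, n_right))
    return total


rng = random.Random(42)
n_points = 10_000
samples = [sample_tn_unbalanced(n_points, 5, rng) / n_points for _ in range(20)]
print(min(samples), max(samples))
```

Because every parent keeps one point and always has two non-empty children, each sample must satisfy $\frac{2n - 1}{3n} \leq \frac{T(n)}{n} \leq 1$, whatever $p$ turns out to be.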

In [None]:
def estimate_tn_unbalanced(n: int, alphas: list[int], n_samples: int, b: int = 2) -> tuple[numpy.ndarray, numpy.ndarray]:
    """Recursively estimate T(n)/n for unbalanced clustering using the specialized Beta distribution, along with the depth of the tree.

    Args:
        n: The number of data points.
        alphas: A list of alpha values to use for the Beta distribution.
        n_samples: The number of random samples of T(n) to compute for each alpha value.
        b: The branching factor. Only b = 2 is currently supported.

    Returns:
        A tuple (ratios, depths) of arrays, each of shape (n_samples, len(alphas)): the sampled T(n)/n values and the corresponding maximum tree depths.
    """
    if b != 2:
        raise NotImplementedError("Currently only b=2 is supported.")

    rng = numpy.random.default_rng(42)
    tns = numpy.zeros((n_samples, len(alphas)), dtype=numpy.int32)
    depths = numpy.zeros((n_samples, len(alphas)), dtype=numpy.int32)
    for i, a in enumerate(alphas):
        for j in range(n_samples):
            tn, depth = _estimate_tn_unbalanced(n, a, rng)
            tns[j, i] = tn
            depths[j, i] = depth
    return tns.astype(numpy.float32) / n, depths


def _estimate_tn_unbalanced(n: int, a: int, rng: numpy.random.Generator) -> tuple[int, int]:
    """Recursively estimate T(n) for unbalanced clustering using the specialized Beta distribution."""
    if n <= 2:
        return 1, 1

    p = rng.beta(a + 1, 2)
    q = (n - 3) * p  # We remove the center and both poles before splitting the remaining points
    left = int(numpy.floor(q)) + 1  # We add back the left pole for the left child
    right = n - 1 - left  # The rest goes to the right child
    assert left > 0, f"left={left}, right={right} for n={n} with p={p}"
    assert right > 0, f"left={left}, right={right} for n={n} with p={p}"
    left_tn, left_depth = _estimate_tn_unbalanced(left, a, rng)
    right_tn, right_depth = _estimate_tn_unbalanced(right, a, rng)
    tn = 1 + left_tn + right_tn
    depth = 1 + max(left_depth, right_depth)
    return tn, depth


In [None]:
n = 100_000
alphas = [1, 2, 5, 10]
n_samples = 100

ratios, depths = estimate_tn_unbalanced(n, alphas, n_samples)


In [None]:
def plot_unbalanced_distributions(n: int, alphas: list[int], ratios: numpy.ndarray, depths: numpy.ndarray) -> go.Figure:
    """Plot the unbalanced distributions of T(n)/n and depths for various alpha values."""
    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=[
            f"T(n)/n Distribution for n={n}",
            f"Max Depth Distribution for n={n}"
        ],
        vertical_spacing=0.15
    )
    colors = pc.qualitative.Set1 if len(alphas) <= len(pc.qualitative.Set1) else pc.qualitative.Plotly

    min_ratio = ratios.min()
    max_ratio = ratios.max()
    max_depth = depths.max()

    for alpha, color, ratio, depth in zip(alphas, colors, ratios.T, depths.T):
        fig.add_trace(
            go.Histogram(
                x=ratio,
                name=f"α={alpha}",  # noqa: RUF001
                marker_color=color,
                opacity=0.75,
                legendgroup=f"α={alpha}",  # noqa: RUF001
            ),
            row=1, col=1
        )
        fig.add_trace(
            go.Histogram(
                x=depth,
                name=f"α={alpha}",  # noqa: RUF001
                marker_color=color,
                opacity=0.75,
                legendgroup=f"α={alpha}",  # noqa: RUF001
                showlegend=False,
            ),
            row=2, col=1
        )

    # Handle the axes
    fig.update_xaxes(title_text="T(n)/n", range=[min_ratio - 0.01, max_ratio + 0.01], row=1, col=1)
    fig.update_yaxes(title_text="Count", row=1, col=1)

    fig.update_xaxes(title_text="Max Depth", range=[0, max_depth + 1], row=2, col=1)
    fig.update_yaxes(title_text="Count", row=2, col=1)

    fig.update_layout(
        title=f"Unbalanced Clustering: Distributions of T(n)/n and Max Depth for Various α, n={n}, and {ratios.shape[0]} samples each",  # noqa: RUF001
        width=1600,
        height=1000,
        barmode="overlay",
        showlegend=True
    )
    return fig


In [None]:
unbalanced_distributions_fig = plot_unbalanced_distributions(n, alphas, ratios, depths)
unbalanced_distributions_fig.show()


## Generalization to Arbitrary Branching Factor `b`

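For a dataset of size $n$ and a branching factor of $b$, assuming perfectly balanced splits, the number of clusters in the tree $T(n)$ is given by the following recurrence relations:

1. Base Case (the leaf clusters): $T(1) = 1$ and $T(2) = 1$.
2. Small parents whose children are all single-point leaves: $T(n) = n$ for $3 \leq n \leq b + 1$, since the center is removed and each of the other $n - 1$ points becomes its own leaf.
3. Recursive Case (the parent clusters with more than $b + 1$ points), where the $a + bn$ non-center points are split into $a$ children of size $n + 1$ and $b - a$ children of size $n$:
   - $T(1 + a + bn) = 1 + aT(n + 1) + (b - a)T(n)$ for $0 \leq a < b$ and $1 + a + bn > b + 1$.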
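As a cross-check of these cases, the sketch below (the name `t_general` is hypothetical) evaluates the recurrence directly with memoization. Taking $T(0) = 0$, i.e. an empty child adds no cluster, lets the recursive formula also cover parents with $3 \leq n \leq b + 1$: each has $n - 1$ single-point children and therefore $n$ clusters.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def t_general(m: int, b: int) -> int:
    """T(m) for perfectly balanced splits with branching factor b."""
    if m == 0:
        return 0  # an empty child contributes no cluster
    if m <= 2:
        return 1  # leaf cluster
    q, a = divmod(m - 1, b)  # m = 1 + a + b*q with 0 <= a < b
    return 1 + a * t_general(q + 1, b) + (b - a) * t_general(q, b)


print([t_general(m, 3) for m in range(1, 11)])  # [1, 1, 3, 4, 4, 4, 4, 6, 8, 10]
```

The counting argument used for $b = 2$ extends to any $b$: with at most $b$ children per parent and at most two points per leaf, $T(n) \geq \frac{bn - 1}{2b - 1}$, which exceeds $\frac{n}{2}$ for $n > 2$.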

In [None]:
bs = list(range(2, 11))
data = [(b, compute_memo(min_n, max_n, b)) for b in bs]


In [None]:
fig = make_plot(min_n, max_n, data)
fig.show()


In [None]:
min_n = 10
max_n = 100_000
bs = list(range(2, 65))
data = [(b, compute_memo(min_n, max_n, b)) for b in bs]


In [None]:
fig = make_plot(min_n, max_n, data)
fig.show()
