# 12 K-means

Read about the algorithm: [K-means Clustering](https://en.wikipedia.org/wiki/K-means_clustering).

## Implement Functions

In [3]:
import numpy as np

In [None]:
def pairwise_squared_distances(x: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """
    Compute squared Euclidean distances from each point to each center.

    Args:
      x:       [N, D] float
      centers: [K, D] float

    Returns:
      d2: [N, K] float where d2[n, k] = ||x[n] - centers[k]||^2

    Raises:
      ValueError: if x or centers is not 2-D, or their D dimensions differ
        (checked explicitly because e.g. a [K, 1] centers array would
        otherwise broadcast silently against [N, D]).

    Constraints:
      - Do NOT loop over N or K.
      - Use broadcasting / vectorization.
    """
    x = np.asarray(x)
    centers = np.asarray(centers)
    if x.ndim != 2 or centers.ndim != 2 or x.shape[1] != centers.shape[1]:
        raise ValueError(
            f"expected x [N, D] and centers [K, D], got {x.shape} and {centers.shape}"
        )
    # [N, 1, D] - [1, K, D] broadcasts to [N, K, D]. Computing the differences
    # directly (rather than ||x||^2 - 2 x.c + ||c||^2) cannot go negative from
    # cancellation; the cost is an O(N * K * D) temporary.
    diffs = x[:, None, :] - centers[None, :, :]
    return np.square(diffs).sum(axis=2)

In [None]:
def assign_clusters(d2: np.ndarray) -> np.ndarray:
    """
    Assign each point to the nearest center.

    Args:
      d2: [N, K] float distances squared

    Returns:
      labels: [N] int, labels[n] in {0..K-1}. On ties the smallest center
      index wins, because np.argmin returns the first occurrence of the minimum.

    Constraints:
      - No loops over N.
    """
    return np.argmin(np.asarray(d2), axis=1)

In [None]:
def update_centers(x: np.ndarray, labels: np.ndarray, K: int, *, prev_centers=None) -> np.ndarray:
    """
    Update centers as the mean of assigned points.

    Args:
      x:            [N, D] float
      labels:       [N] int in {0..K-1}
      K:            number of clusters
      prev_centers: optional [K, D] centers from the previous iteration.

    Returns:
      new_centers: [K, D] float64

    Empty-cluster policy (deterministic):
      A cluster with no assigned points keeps its row of `prev_centers` when
      that argument is given; otherwise its row is all zeros.

    Raises:
      ValueError: if x is not 2-D, labels is not shaped [N], or a label lies
        outside [0, K).

    Constraints:
      - No Python loops over N (or K): vectorized group-by via np.add.at.
    """
    x = np.asarray(x)
    labels = np.asarray(labels)
    if x.ndim != 2:
        raise ValueError(f"x must be [N, D], got shape {x.shape}")
    if labels.shape != (x.shape[0],):
        raise ValueError(f"labels must have shape ({x.shape[0]},), got {labels.shape}")
    if labels.size and (labels.min() < 0 or labels.max() >= K):
        raise ValueError(f"labels must lie in [0, {K}), got range [{labels.min()}, {labels.max()}]")

    # Points per cluster; minlength keeps trailing empty clusters in the result.
    counts = np.bincount(labels, minlength=K)
    # Unbuffered scatter-add: sums[labels[n]] += x[n] for every n (repeated
    # indices accumulate). Summing in float64 avoids float32 / int precision loss.
    sums = np.zeros((K, x.shape[1]), dtype=np.float64)
    np.add.at(sums, labels, x)
    # Dividing by max(count, 1) leaves empty rows at exactly 0 instead of NaN.
    new_centers = sums / np.maximum(counts, 1)[:, None]

    if prev_centers is not None:
        empty = counts == 0
        new_centers[empty] = np.asarray(prev_centers)[empty]
    return new_centers

## Check Test Cases

In [7]:
def kmeans(
    x: np.ndarray,
    K: int,
    *,
    num_iters: int = 20,
    seed: int = 0,
    init: str = "random",  # "random" or "kmeans++" (optional)
    return_history: bool = False,
) -> dict:
    """
    Run K-means clustering (assign / update) for a fixed number of iterations.

    Args:
      x:              [N, D] data points
      K:              number of clusters, 1 <= K <= N
      num_iters:      number of assign/update iterations (0 returns the
                      initial centers with their assignment)
      seed:           seed for the random center initialization
      init:           "random" picks K distinct data points as initial centers;
                      "kmeans++" is not implemented
      return_history: also return the per-iteration inertia values

    Returns dict with:
      centers: [K, D] final centers
      labels:  [N] index of the nearest final center for each point
      inertia: float, sum of squared distances to the assigned final center
      history (optional): list of inertia values, one per iteration, each
        measured after that iteration's assignment step (before its center update)

    Empty clusters keep their previous center.

    Raises:
      ValueError: if K is not in [1, N], or non-finite centers appear.
      NotImplementedError: for any init other than "random".
    """
    rng = np.random.default_rng(seed)
    N, D = x.shape
    if not 1 <= K <= N:
        raise ValueError(f"K must be in [1, N={N}], got {K}")

    if init == "random":
        # Initialize by sampling K distinct points
        idx = rng.choice(N, size=K, replace=False)
        centers = x[idx].copy()
    else:
        raise NotImplementedError("Optional: implement kmeans++ init")

    history = []
    for _ in range(num_iters):
        # Assignment step against the current centers.
        d2 = pairwise_squared_distances(x, centers)         # [N, K]
        labels = assign_clusters(d2)                        # [N]
        history.append(float(d2[np.arange(N), labels].sum()))

        # Update step. np.array(...) copies, so the fix-up below never writes
        # into an array that update_centers might share with someone else.
        new_centers = np.array(update_centers(x, labels, K))  # [K, D]
        # Keep-previous-center policy for empty clusters, enforced here so it
        # holds regardless of what update_centers puts in empty rows.
        empty = np.bincount(labels, minlength=K) == 0
        new_centers[empty] = centers[empty]

        if not np.all(np.isfinite(new_centers)):
            raise ValueError("Non-finite centers encountered (likely empty cluster / divide-by-zero).")

        centers = new_centers

    # Final assignment so the returned labels and inertia describe the returned
    # centers (inside the loop they always lag one center update behind).
    d2 = pairwise_squared_distances(x, centers)
    labels = assign_clusters(d2)
    inertia = float(d2[np.arange(N), labels].sum())

    out = {"centers": centers, "labels": labels, "inertia": inertia}
    if return_history:
        out["history"] = history
    return out

In [8]:
def _assert_close(a, b, atol=1e-6, rtol=1e-6, msg=""):
    """Fail with AssertionError unless `a` and `b` match elementwise.

    Thin wrapper over numpy.testing.assert_allclose. The check is
    |a - b| <= atol + rtol * |b|, and `msg` is included in the failure report.
    """
    np.testing.assert_allclose(actual=a, desired=b, rtol=rtol, atol=atol, err_msg=msg)

def run_basic_kmeans_tests():
    """
    Smoke-test the K-means pieces defined above.

    Checks, in order:
      1. pairwise_squared_distances on a hand-computed example.
      2. assign_clusters, including argmin's lowest-index tie-break.
      3. update_centers on two non-empty clusters.
      4. kmeans end-to-end on two well-separated Gaussian blobs.

    Raises AssertionError on the first failure; prints a success line otherwise.
    """
    f32 = np.float32

    # ---- Check 1: squared distances for 3 points vs 2 centers in 2-D.
    pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]], dtype=f32)
    ctr = np.array([[0.0, 0.0], [1.0, 1.0]], dtype=f32)
    # Column k holds squared distances to center k:
    #   center [0, 0] -> 0, 1, 4;   center [1, 1] -> 2, 1, 2
    want_d2 = np.array([[0.0, 2.0], [1.0, 1.0], [4.0, 2.0]], dtype=f32)
    got_d2 = pairwise_squared_distances(pts, ctr)
    _assert_close(got_d2, want_d2, msg="pairwise_squared_distances failed basic check")

    # ---- Check 2: nearest-center labels. Row 1 is a 1-vs-1 tie that must
    # resolve to index 0, as np.argmin does.
    labels = assign_clusters(want_d2)
    expected_labels = np.array([0, 0, 1], dtype=np.int64)
    if labels.dtype.kind != "i":
        raise AssertionError("labels must be integer dtype")
    if labels.shape != (3,):
        raise AssertionError(f"labels shape expected (3,), got {labels.shape}")
    if not np.array_equal(labels, expected_labels):
        raise AssertionError(f"assign_clusters expected {expected_labels}, got {labels}")

    # ---- Check 3: per-cluster means for points on a line, no empty clusters.
    line_pts = np.array(
        [[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [11.0, 0.0]],
        dtype=f32,
    )
    line_labels = np.array([0, 0, 1, 1], dtype=np.int64)
    got_centers = update_centers(line_pts, line_labels, K=2)
    _assert_close(
        got_centers,
        np.array([[0.5, 0.0], [10.5, 0.0]], dtype=f32),
        msg="update_centers failed basic mean aggregation",
    )

    # ---- Check 4: kmeans on 50 + 50 points around [0, 0] and [5, 5].
    gen = np.random.default_rng(0)
    # Both blobs are drawn from one generator stream, so the draw order matters.
    blobs = [
        gen.normal(loc=mu, scale=0.2, size=(50, 2)).astype(f32)
        for mu in ([0.0, 0.0], [5.0, 5.0])
    ]
    data = np.concatenate(blobs, axis=0)

    result = kmeans(data, K=2, num_iters=15, seed=0, return_history=True)
    found = result["centers"]
    inertias = result["history"]

    # Phrased as `not (last <= first + eps)` so a NaN inertia also fails.
    if not (inertias[-1] <= inertias[0] + 1e-6):
        raise AssertionError("inertia should not increase overall on this simple dataset")

    # Cluster order is arbitrary; sort centers by x-coordinate before comparing.
    found = found[np.argsort(found[:, 0])]
    targets = (
        ([0.0, 0.0], "center near [0,0] off"),
        ([5.0, 5.0], "center near [5,5] off"),
    )
    for row, (target, message) in enumerate(targets):
        _assert_close(found[row], np.array(target, dtype=f32), atol=0.5, msg=message)

    print("All basic K-means scaffold tests passed ✅")

In [9]:
# Run every scaffold check: any failure raises, and success prints a confirmation line.
run_basic_kmeans_tests()

All basic K-means scaffold tests passed ✅


## Solutions

This is what I came up with (2026-02-04):

```python
def pairwise_squared_distances(x: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """
    Squared Euclidean distance between every point and every center.

    Args:
      x:       [N, D] float points
      centers: [K, D] float centers

    Returns:
      [N, K] array whose entry (n, k) is ||x[n] - centers[k]||^2.

    Fully vectorized: no Python loop over N or K; the work is done by
    broadcasting a [N, 1, D] view against a [1, K, D] view.
    """
    # One difference vector per (point, center) pair: shape [N, K, D].
    delta = x[:, np.newaxis, :] - centers[np.newaxis, :, :]
    # Square and reduce over the feature axis.
    return (delta * delta).sum(axis=-1)
```

```python
def assign_clusters(d2: np.ndarray) -> np.ndarray:
    """
    Label each point with the index of its closest center.

    Args:
      d2: [N, K] squared distances (row n = point n, column k = center k)

    Returns:
      [N] integer labels in {0..K-1}. When several centers are equally close,
      the lowest index is chosen (argmin returns the first minimum).

    Vectorized over N: a single argmin along the center axis.
    """
    nearest = d2.argmin(axis=1)
    return nearest
```

```python
def update_centers(x: np.ndarray, labels: np.ndarray, K: int, *, prev_centers=None) -> np.ndarray:
    """
    Update centers as the mean of assigned points.

    Args:
      x:            [N, D] float
      labels:       [N] int in {0..K-1}
      K:            number of clusters
      prev_centers: optional [K, D] centers from the previous iteration.

    Returns:
      new_centers: [K, D] float64

    Empty-cluster policy (deterministic):
      A cluster with no assigned points keeps its row of `prev_centers` when
      that argument is given; otherwise its row is all zeros.

    Raises:
      ValueError: if x is not 2-D, labels is not shaped [N], or a label lies
        outside [0, K).

    Constraints:
      - No Python loops over N (or K): vectorized group-by via np.add.at.
    """
    x = np.asarray(x)
    labels = np.asarray(labels)
    if x.ndim != 2:
        raise ValueError(f"x must be [N, D], got shape {x.shape}")
    if labels.shape != (x.shape[0],):
        raise ValueError(f"labels must have shape ({x.shape[0]},), got {labels.shape}")
    if labels.size and (labels.min() < 0 or labels.max() >= K):
        raise ValueError(f"labels must lie in [0, {K}), got range [{labels.min()}, {labels.max()}]")

    # Points per cluster; minlength keeps trailing empty clusters in the result.
    counts = np.bincount(labels, minlength=K)
    # Unbuffered scatter-add: sums[labels[n]] += x[n] for every n (repeated
    # indices accumulate). Summing in float64 avoids float32 / int precision loss.
    sums = np.zeros((K, x.shape[1]), dtype=np.float64)
    np.add.at(sums, labels, x)
    # Dividing by max(count, 1) leaves empty rows at exactly 0 instead of NaN.
    new_centers = sums / np.maximum(counts, 1)[:, None]

    if prev_centers is not None:
        empty = counts == 0
        new_centers[empty] = np.asarray(prev_centers)[empty]
    return new_centers
```

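The for-loop over `K` seems slightly suboptimal: the constraints only forbid loops over `N`, but a vectorized group-by (e.g. `np.add.at` or a one-hot matmul) is cleaner. I also don't handle the empty-cluster case properly — an empty cluster silently ends up with an all-zero center. I should fix that.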

This is what ChatGPT came up with:

```python
def update_centers(x: np.ndarray, labels: np.ndarray, K: int, *, prev_centers=None) -> np.ndarray:
    """
    Update centers as the mean of assigned points.

    Args:
      x:            [N, D] float
      labels:       [N] int in {0..K-1}
      K:            number of clusters
      prev_centers: optional [K, D] centers from the previous iteration.

    Returns:
      new_centers: [K, D] float64

    Empty-cluster policy (deterministic):
      A cluster with no assigned points keeps its row of `prev_centers` when
      that argument is given; otherwise its row is all zeros.

    Raises:
      ValueError: if x is not 2-D, labels is not shaped [N], or a label lies
        outside [0, K).
    """
    x = np.asarray(x)
    labels = np.asarray(labels)
    if x.ndim != 2:
        raise ValueError(f"x must be [N, D], got shape {x.shape}")
    if labels.shape != (x.shape[0],):
        raise ValueError(f"labels must have shape ({x.shape[0]},), got {labels.shape}")
    if labels.size and (labels.min() < 0 or labels.max() >= K):
        raise ValueError(f"labels must lie in [0, {K}), got range [{labels.min()}, {labels.max()}]")

    # Points per cluster; minlength keeps trailing empty clusters in the result.
    counts = np.bincount(labels, minlength=K)
    # sums[k] += x[n] for all n with labels[n] == k. Accumulate in float64
    # rather than x.dtype to avoid float32 precision loss and integer overflow.
    sums = np.zeros((K, x.shape[1]), dtype=np.float64)
    np.add.at(sums, labels, x)

    # Avoid divide-by-zero: empty rows stay at exactly 0 instead of NaN.
    new_centers = sums / np.maximum(counts, 1)[:, None]

    # Deterministic empty-cluster handling: keep previous if available.
    if prev_centers is not None:
        empty = counts == 0
        new_centers[empty] = np.asarray(prev_centers)[empty]

    return new_centers

```