In [1]:
import os, glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import jax.numpy as jnp

In [2]:
# k-means colour quantisation of an image (usable as lossy compression):
# small private helpers first, then the `kmeans` entry point.


def _clear_previous_outputs(save_dir: str, name: str) -> None:
    """Create `save_dir` and delete earlier `{name}_<digits>.png` outputs inside it."""
    os.makedirs(save_dir, exist_ok=True)
    pattern_prefix = os.path.join(save_dir, f"{name}_")
    file_prefix = os.path.basename(pattern_prefix)
    # glob.escape makes names/dirs containing [, *, ? match literally; the digit check
    # stops e.g. name="ali" from deleting the outputs of an image named "ali_2x"
    for path in glob.glob(glob.escape(pattern_prefix) + "*.png"):
        index = os.path.basename(path)[len(file_prefix):-len(".png")]
        if index.isdigit():
            os.remove(path)


def _assign_clusters(pixels: jnp.ndarray, means: jnp.ndarray) -> jnp.ndarray:
    """Return, for every pixel (n x 3), the index of its nearest mean (k x 3) -> shape (n,)."""
    # n x 1 x 3 minus 1 x k x 3 broadcasts to n x k x 3; squared distances have the
    # same argmin as Euclidean distances, so the sqrt is skipped
    sq_dists = jnp.sum((pixels[:, None, :] - means[None, :, :]) ** 2, axis=2)
    return jnp.argmin(sq_dists, axis=1)


def _update_means(pixels: jnp.ndarray, clusters: jnp.ndarray, old_means: jnp.ndarray) -> jnp.ndarray:
    """Return the mean colour of each cluster (k x 3).

    A cluster that currently owns no pixels keeps its previous mean instead of
    collapsing to (0, 0, 0).
    """
    k = old_means.shape[0]
    counts = jnp.bincount(clusters, length=k)  # k, integer pixel counts
    sums = jnp.zeros_like(old_means).at[clusters].add(pixels)  # k x 3
    # float32 sums of 0..255 values cannot overflow, so sum first and divide once;
    # max(counts, 1) only avoids 0/0 for empty clusters, which where() then discards
    new_means = sums / jnp.maximum(counts, 1)[:, None]
    return jnp.where(counts[:, None] > 0, new_means, old_means)


def kmeans(
    img: Image.Image,
    name: str,
    save_dir: str = "outputs",
    num_iters: int = 3,
    seed: int | None = None,
):
    """Quantise an image's colours with k-means and save the result of every iteration.

    Each pixel is a point in RGB space. The number of clusters is
    k = 3 * int(||per-channel std||), clamped to [1, number of pixels], so images
    with more colour variance get more clusters. After every update step the
    quantised image (each pixel replaced by its cluster's mean colour, rounded
    to uint8) is written to `{save_dir}/{name}_{iteration}.png`.

    Args:
        img: input image. PIL images in a non-RGB mode (RGBA, L, P, ...) are
            converted to RGB; other array-likes must already be H x W x 3.
        name: filename prefix for the outputs; existing `{name}_<digits>.png`
            files in `save_dir` are deleted first to avoid confusion.
        save_dir: output directory (created if missing).
        num_iters: maximum number of assignment/update rounds.
        seed: seed for picking the initial means. None keeps the previous
            behaviour of drawing from the global `np.random` state.

    Raises:
        AssertionError: if the pixel array is not H x W x 3.
        ValueError: if the image contains no pixels.
    """
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")
    # float32 rather than float16: means/sums need more than float16's ~3 significant digits
    pixels = jnp.asarray(np.asarray(img), dtype=jnp.float32)
    shape = pixels.shape  # original H x W x 3, used to rebuild the quantised image
    assert len(shape) == 3 and shape[2] == 3, f"Invalid shape: {shape}"

    _clear_previous_outputs(save_dir, name)

    pixels = pixels.reshape((-1, 3))  # n x 3
    num_pixels = pixels.shape[0]
    if num_pixels == 0:
        raise ValueError(f"image {name!r} has no pixels")

    # use more clusters for images with higher variance; clamp so a flat image still
    # gets one cluster and we never sample more distinct pixels than exist
    k = int(jnp.linalg.norm(jnp.std(pixels, axis=0))) * 3
    k = min(max(k, 1), num_pixels)
    print(f"k is {k} for the image {name}")

    # init with distinct random pixels from the image (k x 3)
    rng = np.random if seed is None else np.random.default_rng(seed)
    means = pixels[rng.choice(num_pixels, size=k, replace=False)]

    # None (not an all-zeros array) so the first assignment can never look "converged"
    clusters = None
    for iteration in range(num_iters):
        new_clusters = _assign_clusters(pixels, means)
        if clusters is not None and jnp.array_equal(new_clusters, clusters):
            print(f"The algorithm converged after {iteration} iteration(s)")
            break
        clusters = new_clusters
        means = _update_means(pixels, clusters, means)

        # visualise the current iteration; round to nearest instead of truncating
        quantised = np.asarray(jnp.clip(jnp.round(means[clusters]), 0, 255)).astype(np.uint8)
        plt.imsave(os.path.join(save_dir, f"{name}_{iteration}.png"), quantised.reshape(shape))

In [3]:
# Quantise one example image; the `with` block closes the file handle that Image.open keeps open.
with Image.open("example_imgs/ali.webp") as ali_img:
    kmeans(ali_img, "ali")

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


k is 318 for the image pink_fields
