In [None]:
!pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
from functools import partial

import jax
import jax.numpy as jnp

from jax import jit
from jax import random
from jax import vmap

In [None]:
# Number of clusters.
K = 2
# Number of random embeddings generated for the demo run.
VECTORS = 32
# Number of components in each generated embedding.
VECTOR_LENGTH = 64

In [None]:
@partial(jit, static_argnames=("k",))
def initialize_centroids(embeddings, k, key):
    """
    This function initializes k centroids by picking k distinct embeddings at random.

    Args:
        embeddings (jax.numpy.ndarray): The input embeddings, one per row.
        k (int): The number of clusters. It determines the output shape, so it is
            treated as a static (compile-time) argument and must be a Python int.
        key (jax.random.PRNGKey): The random key.

    Returns:
        jax.numpy.ndarray: The initialized centroids, i.e. k rows of `embeddings`.
    """
    # Sample row indices without replacement so no embedding is chosen twice.
    # The sample size comes from the `k` argument, not from a module-level constant.
    indices = random.choice(key, jnp.arange(embeddings.shape[0]), shape=(k,), replace=False)
    return jnp.take(embeddings, indices, axis=0)

In [None]:
@jit
def compute_distances(embedding, centroids):
    """
    Measure the Euclidean distance between one embedding and every centroid.

    Args:
        embedding (jax.numpy.ndarray): The embedding to measure from.
        centroids (jax.numpy.ndarray): The centroids to measure to.

    Returns:
        jax.numpy.ndarray: The distances, reduced over the last axis.
    """
    offsets = embedding - centroids
    squared_lengths = jnp.square(offsets).sum(axis=-1)
    return jnp.sqrt(squared_lengths)

In [None]:
@jit
def assign_clusters(embeddings, centroids):
    """
    Label every embedding with the index of its closest centroid.

    Args:
        embeddings (jax.numpy.ndarray): The input embeddings.
        centroids (jax.numpy.ndarray): The centroids.

    Returns:
        jax.numpy.ndarray: The index of the nearest centroid for each embedding.
    """
    # Map the per-embedding distance computation over the leading axis of
    # `embeddings`; the centroids are shared by every mapped call.
    per_embedding = vmap(lambda row: compute_distances(row, centroids))
    distance_table = per_embedding(embeddings)
    return distance_table.argmin(axis=-1)

In [None]:
@partial(jit, static_argnames=("k",))
def update_centroids(embeddings, assignments, k):
    """
    This function updates the centroids by computing the mean of all embeddings in each cluster.

    Args:
        embeddings (jax.numpy.ndarray): The input embeddings, one per row.
        assignments (jax.numpy.ndarray): The cluster assignments for each embedding.
        k (int): The number of clusters. It determines the output shape, so it is
            treated as a static (compile-time) argument and must be a Python int.

    Returns:
        jax.numpy.ndarray: The updated centroids, one row per cluster. A cluster with
            no assigned embeddings gets an all-zero centroid rather than NaN.
    """
    def update_centroid(i):
        # Zero out rows that do not belong to cluster i, then average the rest.
        mask = jnp.equal(assignments, i)
        masked_embeddings = jnp.where(mask[:, None], embeddings, 0)
        # Clamp the member count to 1 so an empty cluster yields 0 / 1 = 0, not 0 / 0 = NaN.
        return jnp.sum(masked_embeddings, axis=0) / jnp.maximum(jnp.sum(mask), 1)

    # The cluster count comes from the `k` argument, not from a module-level constant.
    return jax.vmap(update_centroid)(jnp.arange(k))

In [None]:
def kmeans(embeddings, k, num_iters=100, seed=0):
    """
    This function applies the K-Means algorithm to input embeddings.

    Centroids are seeded with `initialize_centroids` and then refined by alternating
    `update_centroids` and `assign_clusters` for a fixed number of iterations; there
    is no convergence check.

    Args:
        embeddings (jax.numpy.ndarray): The input embeddings.
        k (int): The number of clusters.
        num_iters (int, optional): The number of iterations to run the K-Means algorithm. Default is 100.
        seed (int, optional): The random seed for centroid initialization. Default is 0.

    Returns:
        tuple: The final centroids and the cluster assignments for each embedding.
            The assignments are computed against the returned centroids.
    """
    key = random.PRNGKey(seed)
    centroids = initialize_centroids(embeddings, k, key)

    # Assign once before the loop so `assignments` is defined even when
    # num_iters == 0, and re-assign after every centroid update so the returned
    # assignments always match the returned centroids.
    assignments = assign_clusters(embeddings, centroids)
    for _ in range(num_iters):
        centroids = update_centroids(embeddings, assignments, k)
        assignments = assign_clusters(embeddings, centroids)

    return centroids, assignments

In [None]:
key = random.PRNGKey(0)
embeddings = []
for i in range(VECTORS):
    embedding = []
    for j in range(VECTOR_LENGTH):
        key, subkey = random.split(key)
        embedding.append(random.uniform(subkey, minval=0, maxval=1))
    embeddings.append(embedding)
embeddings = jnp.array(embeddings)
%timeit -n1 -r1 kmeans(embeddings, 2)