In [3]:
from flytracker.utils import FourArenasQRCodeMask
from torch.utils.data import DataLoader
from flytracker.dataset import VideoDataset

from sklearn.cluster import KMeans

import jax
from jax import lax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from functools import partial

In [4]:
# Getting image sample
mask = FourArenasQRCodeMask().mask
path = "/home/gert-jan/Documents/flyTracker/data/movies/4arenas_QR.h264"

dataset = VideoDataset(path, mask)
loader = DataLoader(dataset, batch_size=1, pin_memory=True)
iterator = iter(loader)
data = iterator.next()
data = data.numpy().squeeze()

In [5]:
# First we standardize data
X = (data - jnp.mean(data, axis=0)) / jnp.std(data, axis=0)
n_clusters = 40 
n_samples = X.shape[0]

In [6]:
X.shape

(2549, 2)

In [7]:
# Let's set a common starting point
init = KMeans(n_clusters=n_clusters, n_init=1, algorithm='full', max_iter=1, random_state=20).fit(X).cluster_centers_

In [8]:
cluster_centers = init

In [9]:
%%time
# Let's use sklearn api to get a baseline for speed etc
# Just 1 initiliazation etc to get the implementation right
kmeans = KMeans(n_clusters=40, n_init=1, algorithm='full', init=init).fit(X).cluster_centers_

CPU times: user 51.9 ms, sys: 0 ns, total: 51.9 ms
Wall time: 3.3 ms


In [25]:
@jax.jit
def cdist(x, y):
    return jnp.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)

@partial(jax.jit, static_argnums=(0))
def kmeans_step(X, mu):  
    n_clusters = mu.shape[0]
    n_samples = X.shape[0]

    # E step
    dist_matrix = cdist(X, mu)
    labels = jnp.argmin(dist_matrix, axis=1)
    
    # Mstep
    M = jnp.zeros((n_clusters, n_samples))
    M = jax.ops.index_update(M, (labels, jnp.arange(n_samples)), 1.0)
    new_mu= (M  / jnp.sum(M, axis=1, keepdims=True)) @ X
    return new_mu, labels

In [26]:
step = partial(kmeans_step, X)

In [27]:
step(cluster_centers);

In [28]:
%timeit step(cluster_centers)

85.6 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [50]:
%%time
cluster_centers = init
cluster_centers, old_cluster_centers = step(cluster_centers)[0], cluster_centers

CPU times: user 87 ms, sys: 48.1 ms, total: 135 ms
Wall time: 203 ms


In [62]:
%timeit step(cluster_centers)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [44]:
@jax.jit
def cdist(x, y):
    return jnp.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)

@jax.jit
def kmeans(X, init):
    def cond_fun(carry):
        new_centers, old_centers = carry
        return jnp.linalg.norm(new_centers - old_centers) > 1e-4
    
    def kmeans_step(X, mu):  
        # E step
        dist_matrix = cdist(X, mu)
        labels = jnp.argmin(dist_matrix, axis=1)

        # Mstep
        M = jnp.zeros((n_clusters, n_samples))
        M = jax.ops.index_update(M, (labels, jnp.arange(n_samples)), 1.0)
        new_mu= (M  / jnp.sum(M, axis=1, keepdims=True)) @ X
        return new_mu, labels

    def body_fun(carry):
        new, _ = carry 
        return step(new)[0], new
    
    n_clusters = init.shape[0]
    n_samples = X.shape[0]
    step = partial(kmeans_step, X)
    
    init_carry = (step(init)[0], init)
    mu, _ = lax.while_loop(cond_fun, body_fun, init_carry)
    _, labels = step(mu)
    return mu, labels

In [45]:
kmeans(X, init);

In [46]:
%%timeit
kmeans(X, init)

997 µs ± 15.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
from jax.scipy.