In [1]:
from time import time
from numba import jit, njit, prange, float32
import numpy as np
from joblib import Memory
from sklearn.cluster.k_means_ import _k_init
from sklearn.datasets import make_blobs
from sklearn.utils import check_random_state

In [2]:
# Cache the expensive dataset generation and k-means++ seeding on disk so that
# re-running the notebook does not recompute them.
m = Memory(location='/tmp/joblib')
make_blobs = m.cache(make_blobs)
_k_init = m.cache(_k_init)

seed = 42
n_clusters = 1000
kmeanspp_size = int(1e4)  # number of leading samples used for the seeding
rng = check_random_state(seed)

# The cached functions must not share one mutable RandomState: on a cache hit
# make_blobs is not executed and therefore does not advance the generator, so
# the state (and the joblib cache key) seen by _k_init would depend on whether
# the cache was warm or cold. Each call gets its own deterministic seed source.
data, true_labels = make_blobs(n_samples=int(1e5), centers=100,
                               n_features=100, cluster_std=30,
                               random_state=seed)
data = data.astype(np.float32)

# k-means++ seeding on a subset of the data, with precomputed squared norms.
kmeanspp_data = data[:kmeanspp_size]
data_squared_norms = np.sum(kmeanspp_data * kmeanspp_data, axis=1)
centroids = _k_init(kmeanspp_data, n_clusters,
                    data_squared_norms, check_random_state(seed))

In [3]:
@njit('void(f4[:, ::1], f4[:, ::1], f4[:, ::1], u4[::1])',
      locals={'best_dist': float32, 'dist': float32},
      fastmath=True,
#       parallel=True,
)
def kmeans_kernel(data, centroids, centroids_sum, centroids_pop):
    """Assign each sample to its nearest centroid and accumulate the result.

    For every row of ``data`` the squared Euclidean distance to every row of
    ``centroids`` is computed; the sample is then added to the running sum and
    the population counter of the closest centroid (first one wins on ties).

    Parameters
    ----------
    data : C-contiguous 2D float32 array, one sample per row.
    centroids : C-contiguous 2D float32 array, one centroid per row.
    centroids_sum : C-contiguous 2D float32 array, one row per centroid.
        Updated in place: the assigned samples are added to the matching row.
    centroids_pop : C-contiguous 1D uint32 array, one counter per centroid.
        Updated in place: incremented once per assigned sample.

    Returns
    -------
    None. The accumulators are not reset by this function.
    """
    n_samples, n_features = data.shape
    n_centroids = centroids.shape[0]
    if n_centroids == 0:
        # Nothing to assign to; writing to row 0 would be out of bounds.
        return
    for i in range(n_samples):
        # Largest finite float32 value. A small hard-coded sentinel would
        # silently map every sample whose distances all exceed it to
        # centroid 0. A finite value is used rather than infinity because
        # fastmath lets the compiler assume that no infinities occur.
        best_dist = 3.4028235e38
        best_j = 0
        for j in range(n_centroids):
            dist = 0.
            for k in range(n_features):
                dist += (data[i, k] - centroids[j, k]) ** 2
            if dist < best_dist:
                best_dist = dist
                best_j = j
        for k in range(n_features):
            centroids_sum[best_j, k] += data[i, k]
        centroids_pop[best_j] += 1

In [4]:
# Inputs for benchmarking a single kernel call: zero-initialised accumulators
# (one counter and one sum row per centroid) and the first 1000 samples.
centroids_pop = np.zeros(centroids.shape[0], dtype=np.uint32)
centroids_sum = np.zeros_like(centroids)
data_chunk = data[:1000]

In [5]:
%%timeit
# Time one kernel call on the 1000-sample chunk. The kernel only adds to
# centroids_sum / centroids_pop, so the accumulators keep growing across the
# timed repetitions; their content is not meaningful after this cell.
kmeans_kernel(data_chunk, centroids, centroids_sum, centroids_pop)

13.3 ms ± 1.94 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


The most arithmetically intensive part is the innermost for loop, which computes the distance between one sample and one centroid. Each iteration of that loop performs 3 floating point operations:

- one difference
- one multiplication (with self to compute a square)
- one addition (accumulation in the dist variable)

In [6]:
# Estimate the arithmetic throughput of one kernel call on data_chunk.
n_flop = 3  # per inner-loop step: one subtraction, one multiply, one add
n_samples, n_features = data_chunk.shape
n_centroids = centroids.shape[0]
duration = 0.013  # measured by timeit

# Total floating point operations of the distance computation, in units of
# 1e9 operations; the best_dist bookkeeping is not counted.
gflop = (n_samples * n_centroids * n_features * n_flop) / 1e9
print(f"{gflop / duration:0.3} GFLOP/s")

25.0 GFLOP/s


About 23 GFLOP/s on a single thread is very good. On this CPU (Skylake), I think this is rather close to peak performance.

However, the same experiment with `n_features` between 2 and 30 gives much lower results, from 2 to 5 GFLOP/s. Note, though, that we did not count the bookkeeping of `best_dist`, which becomes relatively important for low values of `n_features`.

In [7]:
# kmeans_kernel.inspect_types(pretty=True)

In [8]:
# inspect_asm() returns a mapping with one assembly listing per compiled
# signature; the kernel was compiled for a single explicit signature, so a
# length of 1 is expected.
len(kmeans_kernel.inspect_asm())

1

If everything went well, the generated assembly should include `vfmadd231ps` instructions operating on 256-bit `ymm` AVX registers. Each `ymm` register holds 8 float32 values, so one such instruction performs 8 fused multiply-adds (one multiplication and one addition each) at once.

The arithmetic-heavy computation that can benefit from `vfmadd231ps` is this line of the Python source:

```python
    dist += (data[i, k] - centroids[j, k]) ** 2
```



In [9]:
# print(next(iter(kmeans_kernel.inspect_asm().values())))

In [10]:
@jit('u4(f4[:, ::1], f4[:, ::1], u4, u4)',
     fastmath=True, parallel=True)
def kmeans(data, centroids, chunk_size=1000, n_iter=10):
    """Run chunked k-means iterations, updating ``centroids`` in place.

    The samples are split into consecutive chunks of ``chunk_size`` rows (the
    last chunk may be shorter). In every iteration each chunk is passed to
    ``kmeans_kernel`` and the centroids are then replaced by the mean of the
    samples assigned to them.

    Parameters
    ----------
    data : C-contiguous 2D float32 array, one sample per row.
    centroids : C-contiguous 2D float32 array, one centroid per row.
        Overwritten in place with the updated centroids.
    chunk_size : number of samples handed to one kernel call.
    n_iter : number of assignment / update iterations.

    Returns
    -------
    The number of chunks the data was split into.
    """
    n_chunks, remainder = divmod(data.shape[0], chunk_size)
    if remainder:
        n_chunks += 1
    n_centroids, n_features = centroids.shape

    # One private accumulator per chunk: the kernel does "+=" on its output
    # arrays, so letting every prange iteration write into the same arrays
    # would be a data race whenever the loop really runs in parallel.
    chunk_sums = np.zeros((n_chunks, n_centroids, n_features),
                          dtype=np.float32)
    chunk_pops = np.zeros((n_chunks, n_centroids), dtype=np.uint32)

    for iteration in range(n_iter):
        # Start every iteration from empty accumulators; otherwise the update
        # would average over the assignments of all previous iterations too.
        chunk_sums[:] = 0
        chunk_pops[:] = 0
        for i in prange(n_chunks):
            data_chunk = data[i * chunk_size:(i + 1) * chunk_size]
            kmeans_kernel(data_chunk, centroids, chunk_sums[i], chunk_pops[i])

        # Sequential reduction of the per-chunk partial results.
        centroids_sum = np.zeros_like(centroids)
        centroids_pop = np.zeros(n_centroids, dtype=np.uint32)
        for i in range(n_chunks):
            centroids_sum += chunk_sums[i]
            centroids_pop += chunk_pops[i]

        for j in range(n_centroids):
            pop = centroids_pop[j]
            # A centroid that received no sample keeps its previous position
            # instead of becoming NaN through a division by zero.
            if pop > 0:
                for k in range(n_features):
                    centroids[j, k] = centroids_sum[j, k] / pop
    return n_chunks

In [11]:
%%timeit
# One full pass over the dataset in chunks of 1000 samples. kmeans() writes
# the updated centroids back into `centroids`, so every timed repetition
# starts from the centroids left by the previous one.
kmeans(data, centroids, chunk_size=1000, n_iter=1)

1.2 s ± 16.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
