In [1]:
from time import time
from numba import jit, njit, prange, float32
import numpy as np
from joblib import Memory
from sklearn.cluster.k_means_ import _k_init
from sklearn.datasets import make_blobs
from sklearn.utils import check_random_state

In [2]:
# Cache the expensive steps (blob generation and k-means++ seeding) on disk so
# that re-running the notebook is cheap.
m = Memory(location='/tmp/joblib')
make_blobs = m.cache(make_blobs)
_k_init = m.cache(_k_init)
n_clusters = 1000
seed = 42
kmeanspp_size = int(1e4)

# make_blobs receives the integer seed instead of a shared RandomState: on a
# joblib cache hit the function body is not executed, so a shared generator
# would be advanced on a cold cache but left untouched on a warm one, and the
# seeding below would then start from a different state from run to run.
data, true_labels = make_blobs(n_samples=int(1e5), centers=100,
                               n_features=100, cluster_std=30,
                               random_state=seed)
data = data.astype(np.float32)
# Squared L2 norm of each sample of the subset used for k-means++ seeding.
data_squared_norms = np.sum(data[:kmeanspp_size] * data[:kmeanspp_size],
                            axis=1)
# Freshly seeded generator used only for the seeding step, so its state (and
# therefore the cache key and the result) is the same on every run.
rng = check_random_state(seed)
centroids = _k_init(data[:kmeanspp_size], n_clusters,
                    data_squared_norms, rng)

In [3]:
@njit('void(f4[:, ::1], f4[:, ::1], f4[:, ::1], u4[::1])',
      locals={'best_dist': float32, 'dist': float32},
      fastmath=True)
def kmeans_kernel(data, centroids, centroids_sum, centroids_pop):
    """Assign each sample to its nearest centroid and accumulate statistics.

    For every row of ``data`` the squared Euclidean distance to every row of
    ``centroids`` is computed; the sample is then added to the running sum
    and the running count of the closest centroid (the first one on ties).

    Parameters
    ----------
    data : float32 C-contiguous array, shape (n_samples, n_features)
    centroids : float32 C-contiguous array, shape (n_centroids, n_features)
    centroids_sum : float32 C-contiguous array, shape (n_centroids, n_features)
        Updated in place: row ``j`` receives the sum of the samples assigned
        to centroid ``j``. It is not reset by this function.
    centroids_pop : uint32 C-contiguous array, shape (n_centroids,)
        Updated in place: entry ``j`` is incremented once per assigned sample.
        It is not reset by this function.
    """
    n_samples, n_features = data.shape
    n_centroids = centroids.shape[0]
    if n_centroids == 0:
        # Nothing to assign to; writing to row 0 would be out of bounds.
        return
    for i in range(n_samples):
        best_dist = 0.
        best_j = 0
        for j in range(n_centroids):
            dist = 0.
            for k in range(n_features):
                dist += (data[i, k] - centroids[j, k]) ** 2
            # The first centroid always seeds the running minimum, so the
            # result does not depend on a hard-coded "large" sentinel distance.
            if j == 0 or dist < best_dist:
                best_dist = dist
                best_j = j
        for k in range(n_features):
            centroids_sum[best_j, k] += data[i, k]
        centroids_pop[best_j] += 1

In [4]:
# Benchmark input: the first 1000 samples of `data`.
data_chunk = data[:1000]

In [5]:
%%timeit
# The accumulators are re-created inside the timed code because the kernel
# adds to them in place.
centroids_sum = np.zeros_like(centroids)
centroids_pop = np.zeros(centroids.shape[0], dtype=np.uint32)
kmeans_kernel(data_chunk, centroids, centroids_sum, centroids_pop)

13.6 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


The most arithmetic-intensive part is the innermost for loop, which computes the distance between one sample and one centroid. Each iteration of that loop performs 3 floating point operations:

- one difference
- one multiplication (with self to compute a square)
- one addition (accumulation in the dist variable)

In [6]:
# Floating point operations per innermost-loop iteration:
# one subtraction, one multiplication and one accumulation.
n_flop = 3
n_samples, n_features = data_chunk.shape
n_centroids = centroids.shape[0]
# Hard-coded wall-clock time in seconds; it has to be updated by hand when the
# timing changes.
duration = 0.012  # measured by timeit

# Use `n_flop` rather than repeating the literal so the two cannot diverge.
gflop = (n_samples * n_centroids * n_features * n_flop) / 1e9
print(f"{gflop / duration:0.3} GFLOP/s")

25.0 GFLOP/s


Roughly 23–25 GFLOP/s on a single thread (depending on the run) is very good. On this CPU (Skylake), I think this is rather close to peak performance.

However, the same experiment with `n_features` between 2 and 30 gives much lower results, from 2 to 5 GFLOP/s. Note that we did not count the bookkeeping of `best_dist`, though, and it becomes relatively important for low values of `n_features`.

In [7]:
# kmeans_kernel.inspect_types(pretty=True)

In [8]:
# inspect_asm() returns a mapping of generated assembly; its length is the
# number of entries (presumably one per compiled signature — confirm against
# the numba documentation).
len(kmeans_kernel.inspect_asm())

1

If everything went well, the generated assembly should include lines with `vfmadd231ps` on `ymm` 256-bit AVX registers, meaning that one such instruction performs 8 x (one multiplication and one addition) on float32 values (each `ymm` register can hold 8 float32 values).

The arithmetic-intensive calculation that can benefit from `vfmadd231ps` is this line of the Python source:

```python
    dist += (data[i, k] - centroids[j, k]) ** 2
```



In [9]:
# print(next(iter(kmeans_kernel.inspect_asm().values())))

In [10]:
@jit('u4(f4[:, ::1], f4[:, ::1], u4, u4)',
     fastmath=True, parallel=True)
def kmeans(data, centroids, chunk_size=1000, n_iter=10):
    """Run ``n_iter`` Lloyd iterations over ``data``, processed chunk by chunk.

    Parameters
    ----------
    data : float32 C-contiguous array, shape (n_samples, n_features)
    centroids : float32 C-contiguous array, shape (n_centroids, n_features)
        Updated in place at the end of every iteration.
    chunk_size : int
        Number of samples handed to ``kmeans_kernel`` at a time; the last
        chunk may be shorter.
    n_iter : int
        Number of assignment/update iterations.

    Returns
    -------
    n_chunks : int
        Number of chunks ``data`` was split into.
    """
    n_chunks, remainder = divmod(data.shape[0], chunk_size)
    if remainder:
        n_chunks += 1
    n_centroids = centroids.shape[0]

    for iteration in range(n_iter):
        # Start every iteration from empty accumulators; otherwise the new
        # centroids would average assignments made with stale centroids.
        centroids_sum = np.zeros_like(centroids)
        centroids_pop = np.zeros(n_centroids, dtype=np.uint32)

        for i in prange(n_chunks):
            data_chunk = data[i * chunk_size:(i + 1) * chunk_size]
            # Each chunk accumulates into private buffers so that concurrent
            # prange workers never write to the same memory. The `+=` on the
            # outer arrays is the form numba documents as an array reduction.
            chunk_sum = np.zeros_like(centroids)
            chunk_pop = np.zeros(n_centroids, dtype=np.uint32)
            kmeans_kernel(data_chunk, centroids, chunk_sum, chunk_pop)
            centroids_sum += chunk_sum
            centroids_pop += chunk_pop

        for j in range(n_centroids):
            # An empty cluster keeps its previous centroid instead of
            # becoming 0 / 0.
            if centroids_pop[j] > 0:
                centroids[j] = centroids_sum[j] / centroids_pop[j]
    return n_chunks

In [11]:
%%timeit
# kmeans writes the updated centroids back into `centroids`, so every timed
# repetition starts from the centroids left by the previous one.
kmeans(data, centroids, chunk_size=1000, n_iter=1)

1.23 s ± 59.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%load_ext cython
import os
# Overrides any CC already set in the environment; presumably this selects the
# C compiler used to build the %%cython cell below (confirm on your platform).
os.environ["CC"] = 'gcc-8'

In [13]:
%%cython -c=-Ofast -c=-march=native -f
#cython: initializedcheck=False, nonecheck=False, boundscheck=False, wraparound=False, cdivision=True, overflowcheck=False, overflowcheck.fold=False
import numpy as np
cimport numpy as np
cimport cython
from scipy.linalg.cython_blas cimport sgemm

# Fused floating point type: functions declared with `floating_t` accept
# either float32 or float64 data.
ctypedef fused floating_t:
    np.float32_t
    np.float64_t

# numpy array
cpdef void kmeans_chunk_np(np.ndarray[floating_t, ndim=2, mode='c'] X_chunk,
                           np.ndarray[floating_t, ndim=2, mode='c'] C,
                           np.ndarray[floating_t, ndim=2, mode='c'] sums,
                           np.ndarray[np.int32_t, ndim=1, mode='c'] pops):
    # Assign every row of X_chunk to its nearest row of C (squared Euclidean
    # distance, first centroid wins ties) using typed ndarray buffers.
    #
    # X_chunk : (n_samples_chunk, n_features) samples
    # C       : (n_clusters, n_features) centroids
    # sums    : (n_clusters, n_features), updated in place with the sum of the
    #           samples assigned to each centroid (not reset here)
    # pops    : (n_clusters,), updated in place with the number of samples
    #           assigned to each centroid (not reset here)
    cdef:
        Py_ssize_t n_samples_chunk = X_chunk.shape[0]
        Py_ssize_t n_clusters = C.shape[0]
        Py_ssize_t n_features = C.shape[1]

        floating_t x, sq_dist, min_sq_dist = 0.0
        Py_ssize_t best_cluster = -1

        Py_ssize_t si, ci, fi

    if n_clusters == 0:
        # No centroid to assign to; indexing with -1 would be an unchecked
        # out-of-bounds write (boundscheck and wraparound are disabled).
        return

    for si in range(n_samples_chunk):
        best_cluster = -1
        for ci in range(n_clusters):
            sq_dist = 0.0
            for fi in range(n_features):
                x = X_chunk[si, fi] - C[ci, fi]
                sq_dist += x * x
            # The first centroid always seeds the minimum, so the result does
            # not depend on a hard-coded "large" sentinel distance.
            if best_cluster < 0 or sq_dist < min_sq_dist:
                min_sq_dist = sq_dist
                best_cluster = ci

        pops[best_cluster] += 1
        for fi in range(n_features):
            sums[best_cluster, fi] += X_chunk[si, fi]

# pointer
cdef void kmeans_chunk_ptr(floating_t *X_chunk,
                           floating_t *C,
                           floating_t *sums,
                           np.int32_t *pops,
                           Py_ssize_t n_samples_chunk,
                           Py_ssize_t n_clusters,
                           Py_ssize_t n_features) nogil:
    # Raw-pointer variant of the chunk kernel. All buffers are indexed as
    # row-major blocks with `n_features` values per row:
    #
    # X_chunk : n_samples_chunk rows of samples
    # C       : n_clusters rows of centroids
    # sums    : n_clusters rows, updated in place with the sum of the samples
    #           assigned to each centroid (not reset here)
    # pops    : n_clusters counters, incremented once per assigned sample
    #
    # No bounds are checked here; the caller must pass consistent sizes.
    cdef:
        floating_t x, sq_dist, min_sq_dist = 0.0
        Py_ssize_t best_cluster = -1

        Py_ssize_t si, ci, fi

    if n_clusters == 0:
        # No centroid to assign to; pops[-1] would write before the buffer.
        return

    for si in range(n_samples_chunk):
        best_cluster = -1
        for ci in range(n_clusters):
            sq_dist = 0.0
            for fi in range(n_features):
                x = X_chunk[si * n_features + fi] - C[ci * n_features + fi]
                sq_dist += x * x
            # The first centroid always seeds the minimum, so the result does
            # not depend on a hard-coded "large" sentinel distance.
            if best_cluster < 0 or sq_dist < min_sq_dist:
                min_sq_dist = sq_dist
                best_cluster = ci

        pops[best_cluster] += 1
        for fi in range(n_features):
            sums[best_cluster * n_features + fi] += X_chunk[si * n_features + fi]
        

cpdef kmeans_chunk_ptrw(floating_t[:, ::1] X_chunk,
                        floating_t[:, ::1] C,
                        floating_t[:, ::1] sums,
                        np.int32_t[::1] pops):
    # Python-callable wrapper around kmeans_chunk_ptr: takes C-contiguous
    # memoryviews and forwards pointers to their first elements plus the
    # dimensions read from X_chunk and C.
    #
    # Raises ValueError when the shapes of the four buffers do not agree,
    # because the pointer kernel performs no bounds checking.
    cdef:
        Py_ssize_t n_samples_chunk = X_chunk.shape[0]
        Py_ssize_t n_features = X_chunk.shape[1]
        Py_ssize_t n_clusters = C.shape[0]

    if C.shape[1] != n_features:
        raise ValueError("X_chunk and C must have the same number of features")
    if sums.shape[0] != n_clusters or sums.shape[1] != n_features:
        raise ValueError("sums must have the same shape as C")
    if pops.shape[0] != n_clusters:
        raise ValueError("pops must have one entry per row of C")

    if n_samples_chunk == 0 or n_clusters == 0:
        # Nothing to assign; also avoids taking the address of element
        # [0, 0] of an empty buffer.
        return

    kmeans_chunk_ptr(&X_chunk[0, 0], &C[0, 0], &sums[0, 0], &pops[0],
                     n_samples_chunk, n_clusters, n_features)
    
    
# memoryview
cpdef void kmeans_chunk_mv(floating_t[:, ::1] X_chunk,
                           floating_t[:, ::1] C,
                           floating_t[:, ::1] sums,
                           np.int32_t[::1] pops) nogil:
    # Assign every row of X_chunk to its nearest row of C (squared Euclidean
    # distance, first centroid wins ties) using typed memoryviews.
    #
    # X_chunk : (n_samples_chunk, n_features) samples
    # C       : (n_clusters, n_features) centroids
    # sums    : (n_clusters, n_features), updated in place with the sum of the
    #           samples assigned to each centroid (not reset here)
    # pops    : (n_clusters,), updated in place with the number of samples
    #           assigned to each centroid (not reset here)
    cdef:
        Py_ssize_t n_samples_chunk = X_chunk.shape[0]
        Py_ssize_t n_clusters = C.shape[0]
        Py_ssize_t n_features = X_chunk.shape[1]

        floating_t x, sq_dist, min_sq_dist = 0.0
        Py_ssize_t best_cluster = -1

        Py_ssize_t si, ci, fi

    if n_clusters == 0:
        # No centroid to assign to; indexing with -1 would be an unchecked
        # out-of-bounds write (boundscheck and wraparound are disabled).
        return

    for si in range(n_samples_chunk):
        best_cluster = -1
        for ci in range(n_clusters):
            sq_dist = 0.0
            for fi in range(n_features):
                x = X_chunk[si, fi] - C[ci, fi]
                sq_dist += x * x
            # The first centroid always seeds the minimum, so the result does
            # not depend on a hard-coded "large" sentinel distance.
            if best_cluster < 0 or sq_dist < min_sq_dist:
                min_sq_dist = sq_dist
                best_cluster = ci

        pops[best_cluster] += 1
        for fi in range(n_features):
            sums[best_cluster, fi] += X_chunk[si, fi]
            

# memoryview 2
cpdef void kmeans_chunk_gemm(np.float32_t[:, ::1] X_chunk,
                             np.float32_t[:, ::1] C,
                             np.float32_t[::1] C_snorms,
                             np.float32_t[:, ::1] sums,
                             np.int32_t[::1] pops):
    # GEMM-based chunk kernel (float32 only).
    #
    # X_chunk  : (n_samples_chunk, n_features) samples
    # C        : (n_clusters, n_features) centroids
    # C_snorms : (n_clusters,), expected to hold -0.5 * ||c||^2 per centroid
    # sums     : (n_clusters, n_features), updated in place with the sum of
    #            the samples assigned to each centroid (not reset here)
    # pops     : (n_clusters,), updated in place with the number of samples
    #            assigned to each centroid (not reset here)
    #
    # Since ||x - c||^2 = ||x||^2 - 2 * (x.c - 0.5 * ||c||^2) and ||x||^2 does
    # not depend on the centroid, the nearest centroid is the one that
    # MAXIMISES score = x.c + C_snorms[c]. All dot products of the chunk are
    # computed with a single sgemm call.
    cdef:
        int n_samples_chunk = X_chunk.shape[0]
        int n_clusters = C.shape[0]
        int n_features = X_chunk.shape[1]

        np.float32_t score, best_score = 0.0
        int best_cluster = -1

        int si, ci, fi

        np.float32_t[:, ::1] dots

        np.float32_t alpha = 1.0
        np.float32_t beta = 0.0
        char *transa = 'n'
        char *transb = 't'

    if n_samples_chunk == 0 or n_clusters == 0:
        # Nothing to assign; also avoids taking the address of element
        # [0, 0] of an empty buffer.
        return

    # beta == 0, so sgemm overwrites `dots` without reading it: no need to
    # zero-initialise the buffer.
    dots = np.empty((n_samples_chunk, n_clusters), dtype=np.float32)

    # BLAS is column-major: computing C . X_chunk^T there yields, in row-major
    # memory, dots[si, ci] = <X_chunk[si], C[ci]>.
    sgemm(transb, transa, &n_clusters, &n_samples_chunk, &n_features,
          &alpha, &C[0, 0], &n_features, &X_chunk[0, 0], &n_features,
          &beta, &dots[0, 0], &n_clusters)

    for si in range(n_samples_chunk):
        best_cluster = -1
        for ci in range(n_clusters):
            score = C_snorms[ci] + dots[si, ci]
            # Largest score == smallest distance. The first centroid always
            # seeds the maximum, so no sentinel value is needed.
            if best_cluster < 0 or score > best_score:
                best_score = score
                best_cluster = ci

        pops[best_cluster] += 1
        for fi in range(n_features):
            sums[best_cluster, fi] += X_chunk[si, fi]

In [14]:
# Make sure both inputs are float32 before benchmarking the Cython kernels.
data = data.astype(np.float32)
centroids = centroids.astype(np.float32)
# -0.5 * squared L2 norm of each centroid; passed as `C_snorms` to
# kmeans_chunk_gemm, which adds it to the sample/centroid dot products.
centroids_snorms = (centroids**2).sum(axis=1) * -0.5
data_chunk = data[:1000]

In [15]:
%%timeit -n 100
# Typed-ndarray kernel; accumulators are re-created inside the timed code
# because the kernel adds to them in place.
centroids_sum = np.zeros_like(centroids)
centroids_pop = np.zeros(centroids.shape[0], dtype=np.int32)
kmeans_chunk_np(data_chunk,
                centroids,
                centroids_sum,
                centroids_pop)

16.4 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


numpy ndarray

In [16]:
%%timeit -n 100
# Raw-pointer kernel, called through its memoryview wrapper.
centroids_sum = np.zeros_like(centroids)
centroids_pop = np.zeros(centroids.shape[0], dtype=np.int32)
kmeans_chunk_ptrw(data_chunk,
                  centroids,
                  centroids_sum,
                  centroids_pop)

16.8 ms ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


pointer

In [17]:
%%timeit -n 100
# Typed-memoryview kernel.
centroids_sum = np.zeros_like(centroids)
centroids_pop = np.zeros(centroids.shape[0], dtype=np.int32)
kmeans_chunk_mv(data_chunk,
                centroids,
                centroids_sum,
                centroids_pop)

19.6 ms ± 1.46 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


memoryview

In [23]:
%%timeit -n 100
# GEMM-based kernel; `centroids_snorms` is computed once outside the timed
# code.
centroids_sum = np.zeros_like(centroids)
centroids_pop = np.zeros(centroids.shape[0], dtype=np.int32)
kmeans_chunk_gemm(data_chunk,
                 centroids,
                 centroids_snorms,
                 centroids_sum,
                 centroids_pop)

3.63 ms ± 148 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


3.5ms with gemm !

n_samples * n_clusters * n_features * 2 = 2 × 10⁸ flop

~ 55 Gflops

In [None]:
# Switch the inputs to float64 for the remaining timings.
# `centroids_snorms` is not recomputed here, so it keeps its earlier dtype.
data = data.astype(np.float64)
centroids = centroids.astype(np.float64)
data_chunk = data[:1000]

In [None]:
%%timeit -n 100
# Typed-ndarray kernel on the float64 inputs.
centroids_sum = np.zeros_like(centroids)
centroids_pop = np.zeros(centroids.shape[0], dtype=np.int32)
kmeans_chunk_np(data_chunk,
                centroids,
                centroids_sum,
                centroids_pop)

In [None]:
%%timeit -n 100
# Raw-pointer kernel (through its wrapper) on the float64 inputs.
centroids_sum = np.zeros_like(centroids)
centroids_pop = np.zeros(centroids.shape[0], dtype=np.int32)
kmeans_chunk_ptrw(data_chunk,
                  centroids,
                  centroids_sum,
                  centroids_pop)

In [None]:
%%timeit -n 100
# Typed-memoryview kernel on the float64 inputs.
centroids_sum = np.zeros_like(centroids)
centroids_pop = np.zeros(centroids.shape[0], dtype=np.int32)
kmeans_chunk_mv(data_chunk,
                centroids,
                centroids_sum,
                centroids_pop)