In [None]:
import dask.array as da
import numpy as np
from distributed import Client

def init(X, k):
    '''Choose the values of the means for the first iteration.

    Picks ``k`` distinct rows of ``X`` uniformly at random.

    Args:
        X: 2-D array exposing ``shape`` and integer-array row indexing;
            each row is one point.
        k: Number of initial means to pick.

    Returns:
        The ``k`` selected rows of ``X``.

    Raises:
        ValueError: If ``k`` is larger than the number of rows of ``X``
            (raised by ``np.random.choice`` when sampling without
            replacement).
    '''
    # Sample without replacement: drawing the same row twice would start
    # two clusters on exactly the same point.
    row_ids = np.random.choice(X.shape[0], size=k, replace=False)
    return X[row_ids]
    
def assign(X, means, x_chunks, means_chunks):
    '''Label every point with the index of its closest mean.

    Builds the point-to-mean matrix with ``euclidean_distance`` (points on
    rows, means on columns) and takes, for each row, the column index of
    the smallest entry.
    '''
    distances = euclidean_distance(X, means, x_chunks, means_chunks)
    nearest = distances.argmin(axis=1)
    return nearest

def update(X, labels):
    '''Recompute the means from the labels produced by assign.

    Both ``X`` and ``labels`` are converted with ``to_dask_dataframe``; the
    rows of ``X`` are grouped by label, averaged per group, and the result
    is handed back through ``.values``.
    '''
    points = X.to_dask_dataframe()
    keys = labels.to_dask_dataframe()
    per_label_mean = points.groupby(keys).mean()
    return per_label_mean.values


def euclidean_distance(X, Y, x_chunks, y_chunks):
    '''Return the (n, m) matrix of squared Euclidean distances between rows.

    X is an (n, p) matrix and Y an (m, p) matrix. Entry (i, j) of the result
    is ``|X[i]|^2 - 2 * X[i].Y[j] + |Y[j]|^2``, i.e. the *squared* distance:
    no square root is taken, which leaves the argmin along either axis
    unchanged.

    Args:
        X: (n, p) array, one point per row.
        Y: (m, p) array, one point per row.
        x_chunks: Unused; kept so existing callers keep working.
        y_chunks: Unused; kept so existing callers keep working.

    Returns:
        An (n, m) array of squared distances.
    '''
    # The squared norms are combined with the cross term by broadcasting
    # ((n, 1) and (1, m) against (n, m)), so no explicit arrays of ones --
    # and therefore no chunk specification -- are needed.
    x_sq = (X * X).sum(axis=1, keepdims=True)
    y_sq = (Y * Y).sum(axis=1, keepdims=True).transpose()
    # The original referenced an undefined lowercase `y` here (NameError).
    cross = X.dot(Y.transpose())
    return x_sq - 2 * cross + y_sq

def main(X, k, max_iter, x_chunks, means_chunks, cores):
    '''Run k-means on X and return the final means and labels.

    Args:
        X: 2-D dask array of points, one point per row.
        k: Number of clusters.
        max_iter: Maximum number of assign/update iterations.
        x_chunks: Chunk specification forwarded to ``assign`` for X.
        means_chunks: Chunk specification forwarded to ``assign`` for the
            means.
        cores: Number of workers requested for the local cluster started
            for this run.

    Returns:
        A tuple ``(means, labels)``. ``means`` is a NumPy array; ``labels``
        is whatever ``assign`` returned in the last iteration (still lazy),
        or ``None`` when ``max_iter`` is 0.
    '''
    # Defined up front so the return below is valid even with max_iter == 0.
    labels = None
    # `cores` sizes the cluster itself (Client.compute has no `num_workers`
    # option); the context manager shuts the client down again on exit.
    with Client(n_workers=cores) as client:
        # Materialise the initial means so every iteration compares
        # concrete arrays of the same kind.
        means = np.asarray(init(X, k))
        for _ in range(max_iter):
            old_means = means
            labels = assign(X, means, x_chunks, means_chunks)
            # client.compute returns a Future; wait for the actual values.
            means = np.asarray(client.compute(update(X, labels)).result())
            # `means == old_means` is element-wise and cannot be used as a
            # boolean; array_equal is also False when the shapes differ.
            if np.array_equal(means, old_means):
                break
    return (means, labels)

In [None]:
# Create a distributed Client with default arguments and keep it in the
# module-level name `client`.
client = Client()

In [None]:
# Smoke test: run the k-means driver on a random square matrix.
size = 1000
cores = 2
x_chunks = (100, 100)
means_chunks = x_chunks

k = 5
max_iter = 100
X = da.random.random((size, size), chunks=x_chunks)

means, labels = main(
    X=X,
    k=k,
    max_iter=max_iter,
    x_chunks=x_chunks,
    means_chunks=means_chunks,
    cores=cores,
)

In [None]:
# Bare expression: presumably a notebook cell that displays the `means`
# value returned by main above.
means