Skip to content
Go to file

Latest commit


Git stats


Failed to load latest commit information.
Latest commit message
Commit time

Multilevel Clustering via Wasserstein Means

This is a Python 2 implementation of MWM and MWMS algorithms of Multilevel Clustering via Wasserstein Means (N. Ho, X. Nguyen, M. Yurochkin, H. Bui, V. Huynh, D. Phung); plus implementation of Algorithms 1, 2 and 3 of Fast Computation of Wasserstein Barycenters (M. Cuturi, A. Doucet). Code written by Mikhail Yurochkin.


First compile Cython code in cython_cuturi folder. Install Anaconda and run (on Ubuntu):

cython algos.pyx
python build_ext --inplace

It implemets Algorithm 3 of Cuturi, which is the main computational routine. Implements Algorithm 1 and Algorithm 2 of Cuturi implements our clustering algoritms as a scikit-learn estimator has some simulated examples

Implementation is designed to be used in the interactive mode (e.g. Python IDE like Spyder).

Usage guide

W_means(K=5, K_a=6, k=4, n_iter=10, weight=True, verbose=0, init = 'kmeans', k_init=5, method = 'NC', n_iter_cuturi = [5, 10, 50])


K: number of global clusters

K_a: number of atoms in global clusters (can be list of len(K_a) = K or integer)

k: number of atoms in local clusters (if method = 'LC' - number of atoms in the constraint set)

n_iter: number of iterations of the main algorithm

weight: Wheather to use a scaling factor 1/M when updating local barycenters. True is recommended

verbose: if 1, objective function value will be printed on every iteration and total running time

init: which initialization to use. 'kmeans' is recommended for 3stage kmeans initialization

k_init: used for 'kmeans' initialization of 'LC' method; Recomended - value of k if you were to fit 'NC' method

method: 'NC' for no constraint; 'LC' for local constraint; '3means' for 3 stage kmeans

n_iter_cuturi: Number of iterations to run for Cuturi [algo2; algo1; algo3]


fit(data, truth=None)

data: list of length M (number of groups) of d x Nm arrays (d - dimensionality; Nm - number of points in group m) truth: can be used by score function later

Returns: H_: list of atoms of global barycenters
a_: list of weights of atoms of global barycenters
Y_: list of atoms of local barycenters (equal to len(Z)*[S] for LC method)
b_: lsit of weights of atoms of local barycneters
S_: array of atoms in the constraint set (only if method = 'LC')
labels_: list of label assignments of groups to global clusters


Note: can only be used on fitted object.
Unless truth provided before, provide either true label assignments or true model parameters (later is used for simulations).

Returns: Wasserstein distance to truth if true model parameters are given AMI score if true label assignemtns are given


No description, website, or topics provided.



No releases published


No packages published


You can’t perform that action at this time.