In [22]:
'''
Created on May 6, 2013
Last Update on Oct 9, 2014

Wrapper around scikit-learn's k-means clustering algorithm.

@author: HeXiangnan (xiangnan@comp.nus.edu.sg)
'''

from sklearn.cluster import KMeans
from scipy.sparse import csr_matrix
from sklearn.preprocessing import normalize

def norm_data(data, norm):
    """
    Apply a sequence of normalisation steps to a data matrix.

    norm: '+'-separated steps applied left to right, e.g. 'l1', 'l2', 'l0' or 'l2+l0'.
        'l1' / 'l2': rescale every row to unit L1 / L2 norm
                     (sklearn.preprocessing.normalize with axis=1).
        'l0': divide the whole matrix by the sum of all its entries, so that the
              entries sum to 1. This uses the *current* (already normalised)
              matrix, so e.g. 'l2+l0' also ends with entries summing to 1.
              A matrix whose entries sum to 0 is left unchanged.

    The input object is never modified in place.

    Raises:
        ValueError: if a step is not one of 'l1', 'l2', 'l0'.
    """
    data_norm = data
    for step in norm.split("+"):  # steps are applied in the order they are written
        if step in ('l1', 'l2'):
            data_norm = normalize(data_norm, norm=step, axis=1, copy=True)  # don't change the original input data
        elif step == "l0":
            # pandas' .sum() reduces per column; take the grand total from the raw values instead.
            values = data_norm.to_numpy() if hasattr(data_norm, "to_numpy") else data_norm
            total = values.sum()
            if total != 0:  # avoid turning an all-zero matrix into NaNs
                data_norm = data_norm / total
        else:
            raise ValueError(f"unknown normalisation step {step!r} in {norm!r}")
    return data_norm

def kmeans(data, k, norm="l2", n_init=1, random_state=None, max_iter=500, smoothing=0.1):
    """
    Cluster the rows of `data` with scikit-learn's KMeans using random initialisation.

    data: matrix, #item * #feature (anything KMeans.fit accepts).
    k: number of clusters.
    norm: None to cluster the raw data; otherwise a step specification passed
        to norm_data (e.g. 'l1', 'l2', 'l0', 'l2+l0'). It is applied to the data
        before fitting and again to the smoothed centroids afterwards.
    n_init: number of KMeans runs with different random initial centroids.
    random_state: forwarded to KMeans; None keeps the previous behaviour
        (KMeans draws from NumPy's global RNG).
    max_iter: maximum number of iterations per KMeans run.
    smoothing: constant added to every centroid entry before re-normalising
        (only used when norm is not None).

    Returns (labels, centers):
        labels: km_model.labels_, the cluster index assigned to each row.
        centers: a k * #feature ndarray. With norm=None these are the raw KMeans
            centroids; otherwise they are (centroid + smoothing) normalised with `norm`.
    """
    km_model = KMeans(n_clusters=k, init='random', max_iter=max_iter, n_init=n_init,
                      random_state=random_state, verbose=False)
    if norm is None:
        km_model.fit(data)
        return km_model.labels_, km_model.cluster_centers_

    km_model.fit(norm_data(data, norm))
    # cluster_centers_ is already a dense k*N ndarray. The former csr_matrix -> todense()
    # round-trip produced an np.matrix, which sklearn >= 1.2 rejects inside normalize().
    # Adding the constant builds a new array, so the fitted model's centroids stay untouched.
    smoothed_centers = km_model.cluster_centers_ + smoothing
    return km_model.labels_, norm_data(smoothed_centers, norm)

In [23]:
import numpy as np
import pandas as pd

In [24]:
# Feature matrix (one item per row) and its labels; neither CSV has a header row.
data = pd.read_csv('datasets/kmeans_data/data.csv', header=None)
labels = pd.read_csv('datasets/kmeans_data/label.csv', header=None)
# Each item needs exactly one label, otherwise labels and cluster assignments are misaligned.
assert len(data) == len(labels), f"row mismatch: {len(data)} data rows vs {len(labels)} label rows"
print('Data: ', data.shape)
print('Labels: ', labels.shape)

Data:  (10000, 784)
Labels:  (10000, 1)


In [25]:
# kmeans() uses random initialisation and does not pass a random_state by default,
# so KMeans draws from NumPy's global RNG; seed it so Restart & Run All reproduces this output.
SEED = 0
np.random.seed(SEED)
result = kmeans(data=data, k=5)
result



(array([4, 0, 2, ..., 4, 3, 1]),
 array([[0.03055375, 0.03055375, 0.03055375, ..., 0.03055375, 0.03055375,
         0.03055375],
        [0.03065912, 0.03065912, 0.03065912, ..., 0.03065912, 0.03065912,
         0.03065912],
        [0.03141772, 0.03141772, 0.03141772, ..., 0.03141772, 0.03141772,
         0.03141772],
        [0.03093121, 0.03093121, 0.03093121, ..., 0.03093121, 0.03093121,
         0.03093121],
        [0.03093815, 0.03093815, 0.03093815, ..., 0.03093815, 0.03093815,
         0.03093815]]))