In [1]:

from pyspark.mllib.clustering import KMeans
from numpy import array, random
from math import sqrt
from pyspark import SparkConf, SparkContext
from sklearn.preprocessing import scale

In [2]:
K = 5

In [3]:

conf = SparkConf().setMaster('local').setAppName('spark_k_means')
sc = SparkContext(conf = conf)

In [4]:

# Generate fake income/age clusters for N people in k clusters
def gen_clustered_data(N, k):
    random.seed(105)
    points_per_cluster = float(N)/k
    X = []
    for i in range(k):
        income_centroid = random.uniform(20000.0, 200000.0)
        age_centroid = random.uniform(20.0, 70.0)
        for j in range(int(points_per_cluster)):
            X.append([random.normal(income_centroid, 10000.0), random.normal(age_centroid, 2.0)])
    
    X = array(X)
    return X

In [5]:

# Load the data into Spark
# NB: VERY IMPORTANT to normalise it with scale() before clustering
data = sc.parallelize(scale(gen_clustered_data(100, K)))

In [6]:

# Build model i.e. cluster the data
clusters = KMeans.train(data, K, maxIterations=10, runs=10, initializationMode='random')



In [7]:

# Print out the cluster assignments
result_RDD = data.map(lambda x: clusters.predict(x)).cache()

counts = result_RDD.countByValue()
print ("Counts by value: ", counts)

results = result_RDD.collect()
print ("Cluster assignments: ", results)

Counts by value:  defaultdict(<class 'int'>, {1: 10, 2: 14, 4: 20, 0: 36, 3: 20})
Cluster assignments:  [1, 2, 1, 1, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]


In [10]:

# Evaluate clustering by computing Within-Set-Sum-of_Squared_Errors
def error(point):
    centre = clusters.centers[clusters.predict(point)]
    return sqrt(sum([x**2 for x in (point - centre)]))

In [11]:

WSSSE = data.map(lambda point: error(point)).reduce(lambda x, y: (x+y))
print ("Within Set Sum of Squared Error = ", WSSSE)

Within Set Sum of Squared Error =  28.17351362191744
