# Implementing KMeans (optimized version)

In [1]:
from __future__ import print_function
import math
from collections import namedtuple

## Parameters

In [2]:
K = 5            # number of clusters to find
THRESHOLD = 0.1  # iterate while the summed squared centroid shift exceeds this value
MAX_ITERS = 20   # hard cap on the number of iterations, even if not converged

## Load data

In [3]:
def parse_coordinates(line, x_index=3, y_index=4):
    """Parse one comma-separated record into a 2-D point.

    Args:
        line: a single text record whose fields are separated by commas.
        x_index: position of the field used as the first coordinate
            (default 3, the column this notebook's dataset uses).
        y_index: position of the field used as the second coordinate
            (default 4).

    Returns:
        A ``(float, float)`` tuple built from the two selected fields.

    Raises:
        IndexError: if the record has too few fields.
        ValueError: if a selected field cannot be converted to float.
    """
    fields = line.split(',')
    # NOTE(review): judging by the final centroids this notebook prints, these
    # columns look like latitude/longitude -- confirm against the dataset.
    return (float(fields[x_index]), float(fields[y_index]))

In [4]:
# Read the input as an RDD of text lines; each line is later parsed as one record.
# NOTE(review): `sc` is not created in this notebook -- presumably the SparkContext
# provided by the PySpark kernel.
data = sc.textFile('datasets/locations')

In [5]:
# Turn every line into a (float, float) point using parse_coordinates.
points = data.map(parse_coordinates)

Let's **cache the points**, because every iteration of the algorithm scans them again:

In [6]:
# Mark the parsed points for caching: the K-Means loop maps over `points` on
# every iteration, so keeping them cached avoids re-reading and re-parsing the
# input each time. As the cell's last expression, its RDD repr is displayed.
points.cache()

PythonRDD[2] at RDD at PythonRDD.scala:53

## Useful functions

In [7]:
def distance(p1, p2):
    """Return the *squared* Euclidean distance between two 2-D points.

    Only the first two coordinates of each point are used. No square root is
    taken; that is enough for comparing distances and for the convergence sum.
    """
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return dx ** 2 + dy ** 2

def closest_centroid(point, centroidsBC):
    """Return the index of the centroid nearest to ``point``.

    ``centroidsBC`` is the broadcast variable created with ``sc.broadcast``;
    its ``.value`` holds the current centroids. The returned index is the id
    of the cluster the point belongs to. When several centroids are equally
    close, the lowest index wins. With no centroids, ValueError is raised.
    """
    candidates = centroidsBC.value
    # min() over indices keyed by squared distance keeps the first minimum,
    # which matches "first occurrence of the shortest distance".
    return min(range(len(candidates)),
               key=lambda idx: distance(point, candidates[idx]))

def add_points(p1, p2):
    """Sum two points of the same cluster coordinate-wise (first two axes only).

    The running sums are later divided by the point count to obtain the new
    centroid. Always returns a new two-element list.
    """
    return [p1[axis] + p2[axis] for axis in (0, 1)]

## Iteratively calculate the centroids

In [8]:
%%time
# Initial centroids: we just take K randomly selected points
centroids = points.takeSample(False, K, 42)
# Broadcast var
centroidsBC = sc.broadcast(centroids)

# Just make sure the first iteration is always run
variation = THRESHOLD + 1
iteration = 0

while variation > THRESHOLD  and iteration < MAX_ITERS:
     # Map each point to (centroid, (point, 1))
    with_centroids = points.map(lambda p : (closest_centroid(p, centroidsBC), (p, 1)))
    # For each centroid reduceByKey adding the coordinates of all the points
    # and keeping track of the number of points
    cluster_stats = with_centroids.reduceByKey(lambda (p1, n1), (p2, n2):  (add_points(p1, p2), n1 + n2))
    # For each existing centroid find the new centroid location calculating the average of each closest point
    new_centroids = cluster_stats.map(lambda (c, ((x, y), n)): (c, [x/n, y/n])).collect()
    # Calculate the variation between old and new centroids
    variation = 0
    for  (c, point) in new_centroids: variation += distance(centroids[c], point)
    print('Variation in iteration {}: {}'.format(iteration, variation))
    # Replace old centroids with the new values
    for (c, point) in new_centroids: centroids[c] = point
    # Replace the centroids broadcast var with the new values
    centroidsBC = sc.broadcast(centroids)
    iteration += 1
        
print('Final centroids: {}'.format(centroids))

Variation in iteration 0: 4989.32008451
Variation in iteration 1: 2081.17551268
Variation in iteration 2: 1.6011620119
Variation in iteration 3: 2.55059475168
Variation in iteration 4: 0.994848416636
Variation in iteration 5: 0.0381850235415
Final centroids: [[35.08592000544936, -112.57643826547803], [0.0, 0.0], [38.05200414101911, -121.20324355675143], [43.891507710205694, -121.32350131512835], [34.28939789970032, -117.77840744773651]]
CPU times: user 84 ms, sys: 26.9 ms, total: 111 ms
Wall time: 18.3 s
