# Lab 2: 使用 Spark 建立 k-means (Building k-means with Spark)


In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
%matplotlib inline

## Define utility functions: point parsing and nearest-center lookup

In [None]:
def parseVector2(tuple):
    """Turn an ``(x, y)`` pair into the NumPy vector ``array([x, y])``.

    Only the first two items are read; any further items are ignored, and a
    tuple/list with fewer than two items raises ``IndexError``.
    """
    # NOTE: the parameter name shadows the builtin ``tuple``; it is left as-is
    # because callers may pass it by keyword.
    x, y = tuple[0], tuple[1]
    return np.array([x, y])

def closestPoint(p, centers):
    """Return the index of the center in ``centers`` nearest to point ``p``.

    Nearness is measured by squared Euclidean distance.  On a tie the lowest
    index wins (only a strictly smaller distance replaces the current best),
    and an empty ``centers`` sequence yields 0.
    """
    nearest, nearestDist = 0, float("+inf")
    for idx, center in enumerate(centers):
        dist = np.sum((p - center) ** 2)
        if dist < nearestDist:
            nearest, nearestDist = idx, dist
    return nearest

## Generate 90 random 2-D points (three clusters of 30) and plot them

In [None]:
# Build three well-separated clusters of nPerCluster points each, drawing x and
# y uniformly from a different rectangle per cluster:
#   cluster 1 (red):   x in [-3, -1], y in [-0.5, 0.5]
#   cluster 2 (blue):  x in [-1,  1], y in [ 0.5, 1.5]
#   cluster 3 (green): x in [ 1,  3], y in [-0.5, 0.5]
# A dedicated, seeded generator makes the data reproducible from run to run.
# Without it, the fixed seed later given to takeSample cannot reproduce a run,
# because the points it samples from would change every time.
nPerCluster = 30
rng = random.Random(2)

x1 = [rng.uniform(-3, -1) for _ in range(nPerCluster)]
y1 = [rng.uniform(-0.5, 0.5) for _ in range(nPerCluster)]

x2 = [rng.uniform(-1, 1) for _ in range(nPerCluster)]
y2 = [rng.uniform(0.5, 1.5) for _ in range(nPerCluster)]

x3 = [rng.uniform(1, 3) for _ in range(nPerCluster)]
y3 = [rng.uniform(-0.5, 0.5) for _ in range(nPerCluster)]

# Scatter plot of the raw data, one marker style per cluster.
plt.axis([-5, 5, -1, 2])
plt.plot(x1, y1, 'r*')
plt.plot(x2, y2, 'bs')
plt.plot(x3, y3, 'gx')

## Create the data RDD of 2-D points (all three clusters)

In [None]:
# Pair each x list with its y list element-wise to get one RDD of (x, y)
# tuples per cluster.  RDD.zip needs both sides to have the same number of
# partitions and of elements per partition; each pair of lists here has the
# same length and is parallelized with the same (default) settings.
p1x, p1y = sc.parallelize(x1), sc.parallelize(y1)
p1 = p1x.zip(p1y)

p2x, p2y = sc.parallelize(x2), sc.parallelize(y2)
p2 = p2x.zip(p2y)

p3x, p3y = sc.parallelize(x3), sc.parallelize(y3)
p3 = p3x.zip(p3y)

# Concatenate the clusters, convert every (x, y) tuple to a NumPy vector, and
# cache the result because the k-means loop scans it on every iteration.
data = (
    p1.union(p2)
    .union(p3)
    .map(parseVector2)
    .cache()
)

In [None]:
# Number of clusters, and the convergence threshold: the main loop stops once
# the summed squared movement of all centers in one pass is at most this value.
K = 3
convergeDist = 0.01

# Initial centers: K points sampled without replacement from the data, with a
# fixed seed (2) so the draw is repeatable for a given RDD.
kPoints = data.takeSample(withReplacement=False, num=K, seed=2)
tempDist = 1.0  # any value above convergeDist, so the main loop is entered

kPoints

### Main loop: iterate k-means until the centers stop moving

In [None]:
# Reset the convergence measure so this cell works when re-run on its own.
# With the value left over from an earlier run (already <= convergeDist), the
# loop would silently not execute.  Starting at infinity also guarantees at
# least one iteration even when convergeDist >= 1.0.
tempDist = float("inf")
while tempDist > convergeDist:
    # Assign every point to its nearest current center: (centerIndex, (point, 1)).
    closest = data.map(lambda p: (closestPoint(p, kPoints), (p, 1)))
    # Per center index, sum the member points and count them.
    pointStats = closest.reduceByKey(
        lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))
    # New center = mean of its members.  A center that received no points has
    # no entry here, so it keeps its previous position.
    newPoints = pointStats.map(
        lambda st: (st[0], st[1][0] / st[1][1])).collect()

    # Total squared distance the centers moved during this iteration.
    tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints)

    for (iK, p) in newPoints:
        kPoints[iK] = p


print("Final centers: " + str(kPoints))

## Plot the result (final centers drawn as yellow circles)

In [None]:
# Redraw the three input clusters with the same markers as the data cell.
plt.axis([-5, 5, -1, 2])
plt.plot(x1, y1, 'r*')
plt.plot(x2, y2, 'bs')
plt.plot(x3, y3, 'gx')

# Overlay every final center as a yellow circle.  Looping over kPoints,
# instead of indexing 0, 1, 2 explicitly, keeps this correct for any K.
for center in kPoints:
    plt.plot(center[0], center[1], 'yo')