<h1>Distributed K-means</h1>

In [0]:
# Load the sample points (one numeric column per coordinate) from the mounted storage path.
DATA_PATH = "/mnt/ddscoursedatabricksstg/ddscoursedatabricksdata/random_data.parquet"
data = spark.read.parquet(DATA_PATH)
display(data)

_1,_2,_3
6.496,6.313,6.15
8.443,8.412,8.417
0.196,0.278,0.383
0.335,0.123,0.496
6.676,6.057,6.773
4.8,4.373,4.259
0.043,0.473,0.797
2.869,2.121,2.051
4.482,4.074,4.932
8.566,8.539,8.698


In [0]:
# K-means hyper-parameters passed to kmeans_fit below
k = 5          # number of clusters / centroids to fit
max_iter = 80  # upper bound on the number of assign/update iterations

In [0]:
def kmeans_fit(data, k, max_iter, tol=1e-3, seed=42):
  """Fit K-means on a Spark DataFrame and return the final centroids.

  Each iteration assigns every row to its nearest centroid (Euclidean distance)
  with a Python UDF, then recomputes each centroid as the mean of the rows
  assigned to it. Uses the notebook-global `spark` session.

  Args:
    data: Spark DataFrame whose columns are all treated as numeric coordinates.
    k: number of clusters; must be >= 1 and no larger than the row count.
    max_iter: maximum number of assign/update iterations.
    tol: convergence threshold -- stop once no centroid moves by more than
      this Euclidean distance between two iterations (default 1e-3).
    seed: seed used when sampling the k initial centroids from `data`.

  Returns:
    Spark DataFrame with k rows (one per centroid) and the column names of `data`.

  Raises:
    ValueError: if k < 1 or `data` has fewer than k rows.
  """
  import numpy as np
  import pyspark.sql.functions as F
  from pyspark.sql.types import IntegerType

  if k < 1:
    raise ValueError("k must be >= 1, got {}".format(k))

  feature_cols = data.schema.names

  def centroids_converged(new_c, old_c):
    # Converged when every centroid moved by at most `tol` (Euclidean distance).
    return all(
      np.linalg.norm(np.asarray(a, dtype=float) - np.asarray(b, dtype=float)) <= tol
      for a, b in zip(new_c, old_c)
    )

  def make_assign_udf(current):
    # Snapshot the current centroids into a (k, d) array captured by the UDF closure.
    centers = np.asarray(current, dtype=float)

    def nearest(*coords):
      dists = np.linalg.norm(centers - np.asarray(coords, dtype=float), axis=1)
      # argmin picks the lowest index on ties; int() because PySpark does not
      # convert numpy scalars returned from a UDF.
      return int(np.argmin(dists))

    # Explicit IntegerType: udf() defaults to StringType, which would turn the
    # cluster ids into strings and break the id -> centroid lookup below.
    return F.udf(nearest, IntegerType())

  # Cache for the repeated passes; release it at the end only if the caller
  # had not already persisted `data` themselves.
  level = data.storageLevel
  was_cached = level.useMemory or level.useDisk or level.useOffHeap
  data.cache()
  try:
    # Choose k distinct sample rows as the initial centroids.
    seeds = data.rdd.takeSample(False, k, seed=seed)
    if len(seeds) < k:
      raise ValueError(
        "cannot choose {} initial centroids from {} rows".format(k, len(seeds)))
    centroids = [tuple(float(v) for v in row) for row in seeds]

    feature_exprs = [F.col(c) for c in feature_cols]
    # Average only the feature columns; a bare .avg() would also average c_id.
    mean_exprs = [F.avg(c).alias(c) for c in feature_cols]

    for _ in range(max_iter):
      # Assign each point to its closest centroid (ids 0..k-1).
      assign = make_assign_udf(centroids)
      clustered = data.withColumn("c_id", assign(*feature_exprs))

      # Recompute centroids; only k small rows come back to the driver.
      rows = clustered.groupBy("c_id").agg(*mean_exprs).collect()
      by_id = {r["c_id"]: tuple(float(r[c]) for c in feature_cols) for r in rows}
      # A cluster that received no points keeps its previous centroid, so the
      # list always has k entries and position j always means cluster j.
      new_centroids = [by_id.get(j, centroids[j]) for j in range(k)]

      converged = centroids_converged(new_centroids, centroids)
      centroids = new_centroids  # keep the latest estimate even when converged
      if converged:
        break
  finally:
    if not was_cached:
      data.unpersist()

  return spark.createDataFrame(centroids, feature_cols)


In [0]:
# Fit K-means with the parameters above and show the resulting centroids
res = kmeans_fit(data, k=k, max_iter=max_iter)
display(res)

_1,_2,_3
2.499594368999991,2.4999649660000243,2.4999454049999823
5.4998204009987965,5.500029650132301,5.499504803338366
8.500455036853216,8.498855505546683,8.248949580437364
8.498583878391518,8.499608077385913,8.750093865246585
0.5000486420000094,0.4997565690000028,0.5000142620000091
