<a href="https://colab.research.google.com/github/chaoweii/Complete-Python-3-Bootcamp/blob/master/k_means.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Assuming data is [(x0, y0), (x1, y1), .......]

In [2]:
import random

In [3]:
def random_sample(lb, ub):
  return lb + (ub-lb) * random.random()

In [18]:
def initiate_centroids(data, k):
  x_min = y_min = float('inf')
  x_max = y_max = float('-inf')

  for point in data:
    x_min = min(point[0], x_min)
    x_max = max(point[0], x_max)
    y_min = min(point[1], y_min)
    y_max = max(point[1], y_max)

  centroids = []

  for i in range(k):
    centroids.append((random_sample(x_min, x_max),
                     random_sample(y_min, y_max)))
  
  return centroids

In [19]:
def calc_distance(x, y):
  return ((x[0] - y[0]) ** 2 + (x[1] - y[1]) ** 2) ** 0.5

In [28]:
def get_labels(data, centroids):
  labels = []

  for point in data:
    min_dist = float('inf')
    label = None
    for i, centroid in enumerate(centroids):
      dist = calc_distance(point, centroid)
      if dist < min_dist:
        min_dist = dist
        label = i
      labels.append(i)

  return labels

In [29]:
def update_centroids(data, labels, k):
  centroids = [[0, 0] for i in range(k)]
  n_points = [0 for i in range(k)]

  for point, label in zip(data, labels):
    centroids[label][0] += point[0]
    centroids[label][1] += point[1]
    n_points[label] += 1
  
  for label, (x, y) in enumerate(centroids):
    centroids[label][0] = x / n_points[label]
    centroids[label][1] = y / n_points[label]

  return centroids
    

In [66]:
def should_stop(centroids, prev_centroids, threshold=1e-5):
  diff = 0
  for centroid, prev_centroid in zip(centroids, prev_centroids):
    diff += calc_distance(centroid, prev_centroid)
  if diff < threshold:
    return True

In [67]:
def kmeans(data, k, max_iter=10000):
  centroids = initiate_centroids(data, k)

  if max_iter:
    i = 1
  # update centroids until convergence
    while i<max_iter:
      prev_centroids = centroids
      labels = get_labels(data, prev_centroids)
      centroids = update_centroids(data, labels, k)

      if should_stop(centroids, prev_centroids):
        break
      i += 1
  
  else:
      while True:
        prev_centroids = centroids
        labels = get_labels(data, prev_centroids)
        centroids = update_centroids(data, labels, k)

        if should_stop(centroids, prev_centroids):
          break
      # assign labels given stable centroids

  return labels, centroids

In [63]:
# Test Case 1
seg1 = [(random.random(), random.random()) for i in range(100)]
seg2 = [(5 + random.random(), 5 + random.random()) for i in range(100)]
points = seg1 + seg2

In [64]:
case1_seg, case1_centroids = kmeans(points, 2, max_iter=False)

In [68]:
# Expect them to be (0.5, 0.5) and (5.5, 5.5)
case1_centroids

[[2.970264290075089, 2.988611390957705],
 [2.9753011359817987, 2.997055145122652]]

In [54]:
points_shuffle = random.shuffle(points)

In [56]:
points_shuffle

In [55]:
case1_seg, case1_centroids = kmeans(points_shuffle, 2, max_iter=False)

TypeError: ignored

In [None]:
# Test Case 1
seg1 = [(random.random(), random.random()) for i in range(100)]
seg2 = [(random.randrange(4, 5), random.randrange(1, 5)) for i in range(100)]
points = seg1 + seg2

In [47]:
points

[(0.3219527766097078, 0.27764702419838416),
 (0.08172259895703593, 0.15445626097548915),
 (0.8477251591005451, 0.5487159066652361),
 (0.3229076898438561, 0.5671730126380002),
 (0.9293209284535524, 0.059135833917233316),
 (0.766772209135096, 0.17321506781530127),
 (0.4679646327388548, 0.3865309611725689),
 (0.8183828518047049, 0.5483484936719898),
 (0.2920441994312043, 0.4247748640959572),
 (0.6671647336302325, 0.1931592165037267),
 (0.2790306713365611, 0.38476668929962354),
 (0.5856461405124052, 0.5515404217028171),
 (0.51654012061361, 0.9119296148342096),
 (0.48864504040762113, 0.803772335885906),
 (0.499744617230959, 0.9721465645531927),
 (0.2687789092145042, 0.38005333043769585),
 (0.10303956657999902, 0.7508721794240741),
 (0.16490551539950327, 0.9794496881688972),
 (0.3688315677454983, 0.6077075876127552),
 (0.8733539531081431, 0.3291399784666743),
 (0.8560199659852179, 0.9746096859776783),
 (0.8951154408370732, 0.7063956278933833),
 (0.5108320731631972, 0.9853029835438368),
 (0.8