## K-means Clustering

In [1]:
# http://benalexkeen.com/k-means-clustering-in-python/

In [11]:
import pandas as pd
import numpy as np

In [36]:
def assignment(df, centroids):
    for i in centroids.keys():
        # sqrt((x1 - x2)^2 - (y1 - y2)^2)
        df['distance_from_{}'.format(i)] = (np.sqrt((df['x'] - centroids[i][0]) ** 2 
                                                    + (df['y'] - centroids[i][1]) ** 2))
    centroid_distance_cols = ['distance_from_{}'.format(i) for i in centroids.keys()]
    df['closest'] = df.loc[:, centroid_distance_cols].idxmin(axis=1)
    df['closest'] = df['closest'].map(lambda x: int(x.lstrip('distance_from_')))
    
    return df

In [37]:
def update(centroids, data):
    for i in centroids.keys():
        centroids[i][0] = np.mean(data[data['closest'] == i]['x'])
        centroids[i][1] = np.mean(data[data['closest'] == i]['y'])
    return centroids

In [38]:
def k_means(data, k):
    # Initialisation – K initial “means” (centroids) are generated at random
    centroids = {
        i+1: [np.random.randint(0, 80), np.random.randint(0, 80)]
        for i in range(k)
                }
    data = assignment(data, centroids) 
    
    while True:
        closest_centroids = data['closest'].copy(deep=True)
        centroids = update(centroids,data)
        data = assignment(data, centroids)
        if closest_centroids.equals(data['closest']):
            break
    return data

In [39]:
df = pd.DataFrame({
    'x': [12, 20, 28, 18, 29, 33, 24, 45, 45, 52, 51, 52, 55, 53, 55, 61, 64, 69, 72],
    'y': [39, 36, 30, 52, 54, 46, 55, 59, 63, 70, 66, 63, 58, 23, 14, 8, 19, 7, 24]
})

In [40]:
k_means(df, 3)

Unnamed: 0,x,y,distance_from_1,distance_from_2,distance_from_3,closest
0,12,39,45.033629,12.714286,55.408834,2
1,20,36,40.472556,9.231711,46.891423,2
2,28,30,39.799846,15.271689,37.141247,2
3,18,52,33.892395,9.20071,57.214266,2
4,29,54,22.913485,10.951656,50.673519,2
5,33,46,24.159769,9.677451,42.07698,2
6,24,55,27.252421,10.444215,54.803943,2
7,45,59,6.508541,25.952075,46.516723,1
8,45,63,5.002777,28.371443,50.25076,1
9,52,70,7.120003,38.248383,55.1435,1
