# K-means
A simple, interactive example of the idea behind the k-means algorithm.

## Load libraries

In [1]:
%matplotlib nbagg
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from IPython.display import display

## Functions
### Generate random data

In [2]:
STD = 0.2
N = 10

def get_random_data():
    """Draw N points around each of three fixed cluster centres.

    Both coordinates of every point are sampled from a normal distribution
    with standard deviation STD, centred on (3, 3), (4, 1) and (2.5, 1).

    Returns:
        pandas.DataFrame with the columns 'x', 'y' and 'colour'; every
        point starts out with the colour 'black'.
    """
    cluster_centres = [(3, 3), (4, 1), (2.5, 1)]
    xs, ys = [], []
    for centre_x, centre_y in cluster_centres:
        # Draw x before y, cluster by cluster, so a seeded run stays reproducible.
        xs.append(np.random.normal(centre_x, STD, N))
        ys.append(np.random.normal(centre_y, STD, N))

    x = np.concatenate(xs)
    y = np.concatenate(ys)
    return pd.DataFrame({'x': x, 'y': y, 'colour': ['black'] * len(x)})

### Generate initial positions of the centres

In [3]:
def get_random_centres():
    """Pick three random starting centres.

    Each centre is an [x, y] list with x drawn uniformly from [1.5, 4.5)
    and y drawn uniformly from [0.5, 3.5).

    Returns:
        A (red, green, blue) tuple of [x, y] lists.
    """
    def draw_centre():
        # x is sampled before y; the list literal evaluates left to right.
        return [np.random.uniform(1.5, 4.5), np.random.uniform(0.5, 3.5)]

    return draw_centre(), draw_centre(), draw_centre()

### Simple plotting functions

In [4]:
def plot_points(data):
    """Clear the current figure and scatter-plot the points in their own colours.

    Args:
        data: object exposing `x`, `y` and `colour` attributes (the notebook
            passes the DataFrame built by get_random_data).
    """
    xs, ys, colours = data.x, data.y, data.colour
    plt.clf()
    plt.scatter(xs, ys, c=colours, marker='o')
    plt.show()
    
def plot_points_with_centres(data, red, green, blue):
    """Redraw the figure with the data points and the three centres.

    Args:
        data: object exposing `x`, `y` and `colour` attributes.
        red, green, blue: (x, y) pairs giving the position of each centre.
    """
    plt.clf()
    plt.scatter(data.x, data.y, c=data.colour, marker='o')
    # Centres are drawn as large crosses in the colour of the group they represent.
    for centre, colour in ((red, 'red'), (green, 'green'), (blue, 'blue')):
        plt.scatter(*centre, marker='x', c=colour, s=200)
    plt.show()

### Generate data and centres

In [5]:
# Module-level state shared with the button callbacks, which access these names via `global`.
data = get_random_data()
red, green, blue = get_random_centres()

### Functions to bind to the widgets

In [6]:
def calculate_colours(b):
    """Button callback: recolour every point via get_colour and redraw the figure.

    Writes the result into the 'colour' column of the global `data`.

    Args:
        b: argument supplied by the click handler; unused.
    """
    global data, red, green, blue
    centres = (red, green, blue)
    data['colour'] = data.apply(get_colour, axis=1, args=centres)
    plot_points_with_centres(data, *centres)

def calculate_means(b):
    """Button callback: move each centre to the mean of its colour group and redraw.

    Rebinds the global `red`, `green` and `blue` to 2-element numpy arrays
    holding get_mean of the x and y values of the points with that colour.

    Args:
        b: argument supplied by the click handler; unused.
    """
    global data, red, green, blue

    moved = {}
    # Fixed order (red, green, blue; x before y) keeps any random fallback
    # draws made inside get_mean in the same sequence.
    for colour in ('red', 'green', 'blue'):
        members = data['colour'] == colour
        moved[colour] = np.array((
            get_mean(data['x'][members]),
            get_mean(data['y'][members]),
        ))

    red, green, blue = moved['red'], moved['green'], moved['blue']

    plot_points_with_centres(data, red, green, blue)
    
def generate_new_data(b):
    """Button callback: replace the global data set and centres with fresh random ones.

    Nothing is drawn here; this only rebinds `data`, `red`, `green` and `blue`.

    Args:
        b: argument supplied by the click handler; unused.
    """
    global data, red, green, blue
    data = get_random_data()
    new_centres = get_random_centres()
    red, green, blue = new_centres

def get_mean(x):
    """Return the mean of `x`, or a random fallback when that mean is NaN.

    Args:
        x: values accepted by numpy.mean.

    Returns:
        numpy.mean(x) when it is a number; otherwise a value drawn
        uniformly from [0, 5).
    """
    centre = np.mean(x)
    return np.random.uniform(0, 5) if np.isnan(centre) else centre
    
def get_colour(row, red, green, blue):
    """Return the colour name of the centre closest to a point.

    Args:
        row: mapping-like object with 'x' and 'y' entries (a DataFrame row
            when used through DataFrame.apply).
        red, green, blue: (x, y) positions of the three centres.

    Returns:
        'red', 'green' or 'blue'. When several centres are equally close,
        the first one in that order is returned.
    """
    p = np.array((row['x'], row['y']))
    centres = (('red', red), ('green', green), ('blue', blue))
    # Comparing (name, centre) pairs by distance avoids using float distances
    # as dict keys, where equidistant centres silently overwrote each other.
    # min() returns the first minimal item, so ties resolve deterministically.
    nearest, _ = min(centres, key=lambda item: np.linalg.norm(p - item[1]))
    return nearest

def plot_initial(b):
    """Button callback: draw the current global points together with the three centres.

    Args:
        b: argument supplied by the click handler; unused.
    """
    global data, red, green, blue
    plot_points_with_centres(data=data, red=red, green=green, blue=blue)
    
    

### Widgets with plots
1. `New data set` - generate new data
1. `Plot initial` - plot the data and centres
1. `Calculate colours` - assign each point to a group
1. `Move centres` - move the centres to the centres of mass of each group

After generating new data you can plot it. Afterwards you should calculate new colours, move centres, calculate new colours, ... until you reach the equilibrium state, i.e. the colours and the centres no longer change.

In [7]:
# One button per step of the demo.
btn_new_data = widgets.Button(description="New data set")
btn_plot_initial = widgets.Button(description="Plot initial")
# Label spelling fixed ("colurs" -> "colours") to match the step name used in the instructions.
btn_calculate_colours = widgets.Button(description="Calculate colours")
btn_calculate_means = widgets.Button(description="Move centres")

# Bind each button to its callback; every callback receives the clicked button as its argument.
btn_new_data.on_click(generate_new_data)
btn_plot_initial.on_click(plot_initial)
btn_calculate_colours.on_click(calculate_colours)
btn_calculate_means.on_click(calculate_means)

# Kept as the last expression of the cell so the row of buttons is displayed.
widgets.HBox([btn_new_data, btn_plot_initial, btn_calculate_colours, btn_calculate_means])
