In [84]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import *
#from ipywidgets import HBox, VBox, Label, IntSlider, FloatSlider, Button, Output
from IPython.display import display

def mean_shift(X, bandwidth, max_iter = 5):
    """Shift a copy of every point toward the kernel-weighted mean of ``X``.

    Each row of the working copy is repeatedly replaced by the average of all
    rows of ``X``, weighted by ``kernel(euclid_dist(xi, xj), bandwidth)``.
    The reference data ``X`` stays fixed; only the working copy moves.

    Parameters
    ----------
    X : array-like
        Input points, one per row.
    bandwidth : float
        Bandwidth forwarded to ``kernel``.
    max_iter : int, optional
        Number of passes over all points (default 5).

    Returns
    -------
    numpy.ndarray
        The shifted points, same shape as ``X``, with a floating-point dtype.
    """
    data = np.asarray(X)

    # An integer working copy would silently truncate every shifted coordinate
    # on assignment, so integer/bool input is promoted to float64.
    work_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    new_X = data.astype(work_dtype)  # astype returns a copy; X is not modified

    for iteration in range(max_iter):

        for i, xi in enumerate(new_X):

            numer, denom = 0, 0

            for xj in data:

                weight = kernel(euclid_dist(xi, xj), bandwidth)

                numer += weight * xj
                denom += weight

            new_X[i] = numer / denom

    return new_X
            
def mean_shift2(X, max_dist, bandwidth, max_iter = 5):
    """Mean shift with a distance cut-off in which the whole data set moves.

    In every pass each point is replaced by the weighted mean of the points
    of the *current* data set that lie within ``max_dist`` of it; weights are
    ``kernel(distance, bandwidth)``. The result of one pass becomes the data
    set of the next pass.

    Parameters
    ----------
    X : array-like
        Input points, one per row.
    max_dist : float
        Neighbours farther away than this are ignored.
    bandwidth : float
        Bandwidth forwarded to ``kernel``.
    max_iter : int, optional
        Number of passes (default 5).

    Returns
    -------
    numpy.ndarray
        The shifted points, same shape as ``X``, with a floating-point dtype.
    """
    data = np.asarray(X)

    # Promote integer/bool input so that the shifted coordinates are not
    # truncated when they are stored.
    work_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    current = data.astype(work_dtype)  # astype returns a copy; X is not modified

    for iteration in range(max_iter):

        # A fresh buffer per pass: reusing a single buffer would make the
        # input and output of the pass the same array, so later points would
        # be averaged over already-shifted neighbours.
        shifted = np.zeros_like(current)

        for i, xi in enumerate(current):

            numer, denom = 0, 0

            for xj in current:

                dist = euclid_dist(xi, xj)

                if dist > max_dist: continue

                weight = kernel(dist, bandwidth)

                numer += weight * xj
                denom += weight

            shifted[i] = numer / denom

        current = shifted

    return current
                

In [101]:
## create a global array to store points
## (shape (0, 2) so that 2-D indexing such as points[:, 0] and points.shape[1]
##  already works before any points have been generated)
points = np.empty((0, 2))

def rand_2D_points(num_points, num_centroids, **kwargs):
    """Draw random 2-D points scattered around randomly placed centroids.

    Centroids are drawn uniformly from the square
    ``[-space/2, space/2) x [-space/2, space/2)``. Every point picks one
    centroid uniformly at random and is sampled from a normal distribution
    centred on it with standard deviation ``noise * space``.

    Parameters
    ----------
    num_points : int
        Number of points to generate.
    num_centroids : int
        Number of centroids to scatter the points around (at least 1 when
        ``num_points`` is positive).
    space : float, optional keyword
        Side length of the square the centroids are drawn from (default 20).
    noise : float, optional keyword
        Spread of the points as a fraction of ``space`` (default 0.1).
    seed : optional keyword
        Seed passed to ``numpy.random.default_rng``; ``None`` (default) gives
        a fresh, unpredictable sample.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num_points, 2)``.

    Raises
    ------
    ValueError
        If points are requested but ``num_centroids`` is smaller than 1.
    """
    space = kwargs.get('space', 20)
    noise = kwargs.get('noise' , 0.1)

    # A local generator leaves NumPy's global random state untouched, so a
    # seed set elsewhere in the notebook is not overwritten by this call.
    rng = np.random.default_rng(kwargs.get('seed'))

    if num_points == 0:
        return np.zeros(shape = (0, 2))

    if num_centroids < 1:
        raise ValueError('num_centroids must be at least 1')

    centroids = space * (rng.random((num_centroids, 2)) - 0.5)

    # index of the centroid each point belongs to
    owners = rng.integers(num_centroids, size = num_points)

    return rng.normal(loc = centroids[owners], scale = noise * space)

def plot_2D_points(points, **kwargs):
    """Plot 2-D points as dots on a new 6x6 inch figure with equal axis scaling.

    Parameters
    ----------
    points : numpy.ndarray
        Array indexed as ``points[:, 0]`` (x) and ``points[:, 1]`` (y).
    title : str, optional keyword
        Axes title.
    xlim, ylim : optional keyword
        Axis limits passed to ``set_xlim`` / ``set_ylim``. When omitted the
        limits are left to autoscaling. Because the aspect is 'equal',
        Matplotlib may still adjust one of the ranges when the figure is drawn.

    Returns
    -------
    (fig, ax)
        The created figure and axes.
    """
    fig, ax = plt.subplots(figsize = (6,6))

    ax.plot(points[:,0], points[:,1], '.', markerfacecolor = 'None')

    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_title(kwargs.get('title'))

    # axis('equal') switches autoscaling back on and rescales the view, so it
    # must run BEFORE explicit limits are applied; otherwise they are discarded.
    ax.axis('equal')

    xlim = kwargs.get('xlim')
    ylim = kwargs.get('ylim')

    if xlim is not None: ax.set_xlim(xlim)
    if ylim is not None: ax.set_ylim(ylim)

    fig.tight_layout()

    return fig, ax

def generate_on_click(b):
    """Button callback: sample a new point set from the slider values and plot it.

    Replaces the notebook-global ``points`` and shows the plot inside
    ``output_points``. The argument ``b`` (the clicked button) is not used.
    """
    global points

    # drop the previous plot once the new one is ready
    output_points.clear_output(wait = True)

    # sample using the values currently selected in the widgets
    points = rand_2D_points(
        points_box.value,
        centroids_box.value,
        noise = noise_box.value
    )

    plot_2D_points(points)

    with output_points:
        plt.show()
    
## output area that receives the plot of the generated points
output_points = Output()

## sliders feeding the random point generator
points_box    = IntSlider(value = 100, min = 10, max = 1000)
centroids_box = IntSlider(value = 3, min = 1, max = 8)
noise_box     = FloatSlider(value = 0.1, min = 0.01, max = 0.20, step = 0.01, readout_format = '0.2f')

## button that triggers generate_on_click
generate_button = Button(description = 'Generate Points')
generate_button.on_click(generate_on_click)

## layout: a column of labels beside a column of sliders, then button and plot
display(
    HBox([
        VBox([Label(text) for text in ('Number of points: ', 'Number of centroids: ', 'Noise: ')]),
        VBox([points_box, centroids_box, noise_box])
    ]),
    generate_button,
    output_points
)

HBox(children=(VBox(children=(Label(value='Number of points: '), Label(value='Number of centroids: '), Label(v…

Button(description='Generate Points', style=ButtonStyle())

Output()

In [110]:
import time

## number of entries kept in the coherence history
max_iter = 20

## 0-d placeholder; centroids_on_click replaces it with the centroid history
centroids = np.zeros(shape = ())

## one coherence value per iteration, all zero to start with
coherence = np.zeros(shape = (max_iter,))

def find_coherence(points, current, s):
    """Sum of the distances between each point and its assigned centroid.

    ``s[i]`` is the row of ``current`` that ``points[i]`` is assigned to; the
    distance is ``np.linalg.norm`` of the difference.
    """
    distances = []

    for idx in range(points.shape[0]):
        assigned = current[s[idx]]
        distances.append(np.linalg.norm(points[idx] - assigned))

    return np.sum(distances)

def init_clusters(points):
    """Reset the coherence history and show the clustering for the initial centroids.

    Assigns every point to its closest initial centroid (``centroids[0]``),
    stores the resulting coherence as entry 0 of a fresh global ``coherence``
    array and displays the plot in ``output_clusters`` as iteration 0.
    """
    global centroids, coherence

    start = centroids[0]

    # only the assignment is needed here; the updated centroids are discarded
    labels = kmeans_step(points, start)[0]

    coherence = np.zeros(max_iter)
    coherence[0] = find_coherence(points, start, labels)

    plot_2D_clusters(points, centroids, labels, 0)

    output_clusters.clear_output(wait = True)

    with output_clusters:
        print('Iteration 0')
        print(f'Coherence = {coherence[0]:.2f}')
        plt.show()

def kmeans_step(points, current):
    """Perform one k-means iteration: assign points, then update centroids.

    Parameters
    ----------
    points : numpy.ndarray
        Data points, one per row.
    current : numpy.ndarray
        Current centroids, one per row.

    Returns
    -------
    s : numpy.ndarray of int
        ``s[j]`` is the index of the centroid in ``current`` closest to
        ``points[j]``.
    new_centroids : numpy.ndarray
        Mean of the points assigned to each centroid. A centroid that received
        no points keeps its previous position.
    coherence : float
        ``find_coherence(points, new_centroids, s)``.
    """
    n = points.shape[0]
    s = np.zeros(n, dtype = int)

    k_num = current.shape[0]

    ## assign points to closest centroid
    for j in range(n):
        d = np.array([euclid_dist(points[j], current[k]) for k in range(k_num)])
        s[j] = np.argmin(d)

    ## update centroids
    ## An empty cluster has no mean (np.mean of an empty slice is NaN). A NaN
    ## centroid would make every later distance to it NaN, and argmin would
    ## then pick it for all points, so empty clusters keep their old centroid.
    new_centroids = np.array(current, dtype = float)

    for j in range(k_num):
        members = points[s == j]
        if members.shape[0] > 0:
            new_centroids[j] = np.mean(members, axis = 0)

    return s, new_centroids, find_coherence(points, new_centroids, s)

def plot_2D_clusters(points, centroids, s, num_iter):
    """Plot the points per cluster together with the path of every centroid.

    ``centroids`` is indexed as ``centroids[iteration, cluster, coordinate]``.
    For each cluster the assigned points (``s == k``) are drawn as dots, the
    centroid positions of iterations ``0..num_iter`` are joined by a dotted
    line, earlier positions are marked with dots and the position at
    ``num_iter`` with an 'x'. A new 6x6 inch figure is created; nothing is
    returned.
    """
    fig, ax = plt.subplots(figsize = (6, 6))

    n_clusters = centroids.shape[1]

    for k in range(n_clusters):

        # points currently assigned to cluster k
        members = points[s == k]
        ax.plot(members[:,0], members[:,1], '.', markerfacecolor = 'None')

        # path of centroid k up to and including the current iteration
        trail = centroids[:num_iter + 1,k]
        ax.plot(trail[:,0], trail[:,1], linestyle = ':', color = 'k', alpha = 0.6)

        # earlier centroid positions
        for step in range(num_iter):
            ax.plot(*centroids[step,k,:], '.', color = 'k', alpha = 0.6)

        # current centroid position
        ax.plot(*centroids[num_iter,k,:], 'x', color = 'k')

    ax.set_xticks([])
    ax.set_yticks([])

    ax.axis('equal')
    fig.tight_layout()
    
def centroids_on_click(b):
    """Button callback: pick random initial centroids and show the first assignment.

    Allocates the global centroid history ``centroids`` with shape
    ``(len(coherence), num_k, points.shape[1])``, fills entry 0 with ``num_k``
    distinct rows of ``points`` chosen at random, calls ``init_clusters`` and
    enables the 'Find clusters' button. If there are fewer points than
    requested centroids a message is shown instead. The argument ``b`` (the
    clicked button) is not used.
    """
    global centroids

    num_k = num_k_slider.value

    ## centroids are drawn from the points, so enough points must exist
    if points.ndim != 2 or points.shape[0] < num_k:

        output_clusters.clear_output(wait = True)

        with output_clusters:
            print(f'Need at least {num_k} generated points to pick {num_k} centroids')

        return

    ## create an array to store the centroid history
    centroids = np.zeros((coherence.shape[0], num_k, points.shape[1]))

    ## randomly pick the initial centroids
    ## (without replacement: picking the same point twice would give two
    ##  identical centroids, one of which never receives any points)
    centroids[0] = points[np.random.choice(points.shape[0], num_k, replace = False), :]

    init_clusters(points)

    cluster_button.disabled = False
    
def cluster_on_click(b):
    """Button callback: animate k-means from the current initial centroids.

    Runs ``kmeans_step`` repeatedly, storing the centroids and the coherence
    of every iteration in the global history arrays, and redraws the clusters
    in ``output_clusters`` after each step. Stops early as soon as the
    coherence equals that of the previous iteration. The argument ``b`` (the
    clicked button) is not used.
    """
    global centroids, coherence

    delay = speed_slider.value

    # a non-zero second entry means an earlier run filled the history: restart
    if coherence[1] != 0:
        init_clusters(points)

    for step in range(1, coherence.shape[0]):

        labels, updated, score = kmeans_step(points, centroids[step - 1])
        centroids[step] = updated
        coherence[step] = score

        # unchanged coherence: stop without drawing this step
        converged = coherence[step] == coherence[step - 1]
        if converged:
            break

        plot_2D_clusters(points, centroids, labels, step)

        # pause between frames, then swap the displayed frame
        time.sleep(delay)

        output_clusters.clear_output(wait = True)

        history = " ".join(str(value.round(1)) for value in coherence[:step + 1])

        with output_clusters:
            print(f'Iteration {step}')
            print(f'Coherence = {history}')
            plt.show()
    
## output area for the clustering animation
output_clusters = Output()

## controls: number of clusters and the pause between animation frames
num_k_slider = IntSlider(value = 3, min = 1, max = 8)
speed_slider = FloatSlider(value = 2.0, min = 0, max = 3, step = 0.1, readout_format = '0.1f')

## 'Find clusters' starts disabled; centroids_on_click enables it
centroids_button = Button(description = 'Random centroids')
cluster_button   = Button(description = 'Find clusters', disabled = True)

centroids_button.on_click(centroids_on_click)
cluster_button.on_click(cluster_on_click)

## layout: a column of labels beside a column of sliders, then buttons and plot
display(
    HBox([
        VBox([Label(text) for text in ('Number of centroids: ', 'Speed of Animation: ')]),
        VBox([num_k_slider, speed_slider])
    ]),
    centroids_button,
    cluster_button,
    output_clusters
)

## show the unclustered points as the initial content of the output area
plot_2D_points(points)

with output_clusters:
    print('Iteration N/A')
    print('Coherence = N/A')
    plt.show()

HBox(children=(VBox(children=(Label(value='Number of centroids: '), Label(value='Speed of Animation: '))), VBo…

Button(description='Random centroids', style=ButtonStyle())

Button(description='Find clusters', disabled=True, style=ButtonStyle())

Output()

In [111]:
## working copy of the points; shift_on_click moves this copy, not `points`
new_points = np.array(points, copy = True)

def euclid_dist(a, b):
    """Return the Euclidean distance between ``a`` and ``b``.

    Computed as the square root of the summed squared differences ``a - b``.
    """
    delta = a - b
    return np.sqrt(np.sum(np.square(delta)))

def kernel(d, h):
    """Gaussian kernel value for distance ``d`` and bandwidth ``h``.

    Evaluates ``1 / (h * sqrt(2 * pi)) * exp(-0.5 * (d / h) ** 2)``; ``d`` may
    be a scalar or a NumPy array, in which case the result is element-wise.
    """
    normalisation = 1 / (h * np.sqrt(2 * np.pi))
    exponent = -0.5 * (d / h) ** 2
    return normalisation * np.exp(exponent)

def kernel_density(xi, X, bandwidth):
    """Sum of the kernel values between ``xi`` and every row of ``X``.

    The distance from ``xi`` to each row is measured with ``euclid_dist`` and
    passed through ``kernel`` with the given ``bandwidth``.
    """
    distances = []

    for row in range(X.shape[0]):
        distances.append(euclid_dist(xi, X[row]))

    weights = kernel(np.array(distances), bandwidth)

    return np.sum(weights)

def mean_shift_step(X, bandwidth):
    """Perform a single mean-shift pass over ``X``.

    Every point is replaced by the average of all rows of ``X`` weighted by
    ``kernel(euclid_dist(xi, xj), bandwidth)``. All new positions are computed
    from the unmodified input, which is left untouched.

    Parameters
    ----------
    X : array-like
        Input points, one per row.
    bandwidth : float
        Bandwidth forwarded to ``kernel``.

    Returns
    -------
    numpy.ndarray
        The shifted points, same shape as ``X``, with a floating-point dtype.
    """
    data = np.asarray(X)

    # An integer result buffer would silently truncate the shifted
    # coordinates on assignment, so integer/bool input is promoted to float64.
    work_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    new_X = data.astype(work_dtype)  # astype returns a copy; X is not modified

    for i, xi in enumerate(new_X):

        numer, denom = 0, 0

        for xj in data:

            weight = kernel(euclid_dist(xi, xj), bandwidth)

            numer += weight * xj
            denom += weight

        new_X[i] = numer / denom

    return new_X

def contours_on_click(b):
    """Button callback: compute the kernel density of ``points`` on a grid and plot its contours.

    Rebuilds the global grid ``x1`` / ``y1`` from the axis limits of a fresh
    plot of the current points, fills the global ``kde`` array with
    ``kernel_density`` values for the selected bandwidth, draws the contours
    into ``output_contours`` and enables the shift button. Progress is
    reported through ``progress_box``. The argument ``b`` (the clicked button)
    is not used.
    """
    global kde, x1, y1

    bandwidth = bandwidth_slider.value

    # kernel() divides by the bandwidth, so zero cannot be evaluated
    if bandwidth <= 0:

        output_contours.clear_output(wait = True)

        with output_contours:
            print('Bandwidth must be greater than 0')

        return

    fig, ax = plot_2D_points(points, title = f'KDE contour with bandwidth = {bandwidth}')

    # Rebuild the grid from the axes just created so that it covers the
    # current points even if they were regenerated after this cell first ran.
    x1 = np.linspace(*ax.get_xlim())
    y1 = np.linspace(*ax.get_ylim())

    # contour expects Z indexed as [row = y, column = x]
    kde = np.zeros((len(y1), len(x1)))

    progress_box.value = 0

    for i in range(len(x1)):
        for j in range(len(y1)):

            kde[j, i] = kernel_density((x1[i], y1[j]), points, bandwidth)

        # one update per grid column; reaches 1 after the last column
        progress_box.value = (i + 1) / len(x1)

    ax.contour(x1, y1, kde)

    output_contours.clear_output(wait = True)

    with output_contours:
        print('Iteration 0')
        plt.show()

    shift_button.disabled = False
        
def shift_on_click(b):
    """Button callback: animate mean shift of the points over the KDE contours.

    If the working copy ``new_points`` differs from ``points`` it is first
    reset and the raw points are shown as iteration 0. Then up to 10
    ``mean_shift_step`` passes are applied to ``new_points``; after each pass
    the shifted points are drawn over the stored contours in
    ``output_contours``. The animation stops early once the norm of the change
    between two passes drops below 1e-2. The argument ``b`` (the clicked
    button) is not used.
    """
    global new_points, kde, x1, y1

    # start over from the raw points if the working copy has diverged from them
    needs_reset = not np.array_equal(new_points, points)

    if needs_reset:

        new_points = np.copy(points)
        plot_2D_points(points, title = 'Raw points')
        plt.contour(x1, y1, kde)

        output_contours.clear_output(wait = True)

        with output_contours:
            print('Iteration 0')
            plt.show()

    bandwidth = bandwidth_slider.value

    max_iter, tolerance = 10, 1e-2

    # every frame uses the extent of the KDE grid and the same title
    view = dict(xlim = (x1[0], x1[-1]), ylim = (y1[0], y1[-1]))
    frame_title = f'KDE contour with bandwidth = {bandwidth}'

    for step in range(1, max_iter + 1):

        old_points = np.copy(new_points)

        new_points = mean_shift_step(new_points, bandwidth)

        plot_2D_points(new_points, title = frame_title, **view)

        plt.contour(x1, y1, kde)

        # pause between frames, then swap the displayed frame
        time.sleep(speed_slider.value)

        output_contours.clear_output(wait = True)

        with output_contours:
            print(f'Iteration {step}')
            plt.show()

        # total movement of all points in this pass
        displacement = np.linalg.norm(new_points - old_points)

        if displacement < tolerance:
            break
    
## bandwidth of the kernel; the minimum is one step above zero because
## kernel() divides by the bandwidth
bandwidth_slider = FloatSlider(
    value = 1.5,
    min = 0.1,
    max = 3,
    step = 0.1,
    readout_format = '0.1f'
)

## progress of the KDE grid computation, from 0 to 1
progress_box = FloatProgress(
    value = 0,
    min = 0,
    max = 1,
    #description = 'Progress: '
)

contours_button = Button(
    description = 'Show KDE contours'
)

contours_button.on_click(contours_on_click)

## stays disabled until contours_on_click has computed the contours
shift_button = Button(
    description = 'Means shift',
    disabled = True
)

shift_button.on_click(shift_on_click)

output_contours = Output()

fig, ax = plot_2D_points(points, title = 'Raw points')

## grid on which the kernel density is evaluated, spanning the axis limits
x1 = np.linspace(*ax.get_xlim())
y1 = np.linspace(*ax.get_ylim())

## density values indexed as kde[row = y, column = x], the layout contour expects
kde = np.zeros((len(y1), len(x1)))

with output_contours:
    print('Iteration 0')
    plt.show()

display(
    HBox([VBox([Label('Bandwidth: '), Label('Speed of Animation: ')]),
          VBox([bandwidth_slider, speed_slider])]),
    
    HBox([contours_button, progress_box]),
    
    shift_button,
    output_contours
)


HBox(children=(VBox(children=(Label(value='Bandwidth: '), Label(value='Speed of Animation: '))), VBox(children…

HBox(children=(Button(description='Show KDE contours', style=ButtonStyle()), FloatProgress(value=0.0, max=1.0)…

Button(description='Means shift', disabled=True, style=ButtonStyle())

Output()