In [None]:

from scipy.linalg import norm

def label(x, means):
	"""Return the index of the row of `means` nearest to the point `x`.

	`means` holds one candidate mean per row; distances are the default
	vector norm of each row of `means - x`.
	"""
	# one distance per row of `means`
	distances = norm(means - x, axis = 1)
	nearest = argmin(distances)
	return nearest


In [None]:

# clusters is an inhomogeneous array, 
# so must be dtype = object

def assign_clusters(dataset, means):
	"""Partition the rows of `dataset` by their nearest mean.

	Parameters
	----------
	dataset : 2-D array, one sample per row.
	means : 2-D array, one candidate mean per row.

	Returns
	-------
	clusters : 1-D object array of length len(means); clusters[i] is the
		sub-array of rows of `dataset` whose nearest mean is means[i]
		(it has zero rows when no sample is nearest to means[i]).
	labels : array with one entry per sample, the index of its nearest mean.
	"""
	k = len(means)
	labels = array([ label(x, means) for x in dataset ])
	# Fill a pre-allocated object array slot by slot.  array([...], dtype=object)
	# only gives a 1-D array of clusters when the cluster sizes differ; if all
	# clusters have the same number of rows (always true for k == 1) it builds
	# a (k, n, d) object array of scalars instead.
	clusters = empty(k, dtype = object)
	for i in range(k):
		clusters[i] = dataset[labels == i, :]
	return clusters, labels

def update_means_clusters(clusters):
	"""Drop empty clusters and compute the mean of each remaining one.

	Parameters
	----------
	clusters : sequence of 2-D arrays, one cluster per entry, samples as rows.

	Returns
	-------
	means : array with one row per non-empty cluster (its column-wise mean).
	clusters : 1-D object array holding the non-empty clusters, in the
		same order as the rows of `means`.
	"""
	nonempty = [ c for c in clusters if len(c) ]
	means = array([ mean(c, axis=0) for c in nonempty ])
	# Store each cluster as a single element.  array(nonempty, dtype=object)
	# would turn equally sized clusters into one 3-D object array of scalars
	# rather than a 1-D array of cluster arrays.
	kept = empty(len(nonempty), dtype = object)
	for i, c in enumerate(nonempty):
		kept[i] = c
	return means, kept

def nearest_means(dataset, k):
	"""Cluster the rows of `dataset` by alternating assignment and mean updates.

	Starts from k initial means drawn with the module-level `samples`
	callable, then repeats: assign every sample to its nearest mean,
	replace each mean by the mean of its cluster.  Means whose cluster is
	empty are dropped, so fewer than k means may remain.  The loop stops
	once the number of means is unchanged and the means no longer move
	(per `allclose`).  The cluster sizes are printed at every iteration.

	Parameters
	----------
	dataset : 2-D array, one sample per row.
	k : number of initial means.

	Returns
	-------
	means : array with one row per surviving mean at convergence.
	"""
	_, d = shape(dataset)
	# random initial guess
	means = samples((k,d))
	close_enough = False
	n = 0
	while not close_enough:
		clusters, _ = assign_clusters(dataset, means)
		newmeans, clusters = update_means_clusters(clusters)
		# only check closeness if len(means) is unchanged
		if len(newmeans) == len(means):
			close_enough = allclose(means, newmeans)
		means = newmeans
		print(f'iteration {n}: ', [ len(c) for c in clusters ])
		n += 1
	# hand the result back instead of discarding it
	return means
	
from numpy import *
from numpy.random import default_rng as rng
# sampler bound to a fresh, unseeded generator
samples = rng().random

# problem size: dimension, number of means, number of points
d = 2
k = 7
N = 100

# N random points in d dimensions, then cluster them
dataset = samples(size = (N, d))
nearest_means(dataset, k)


In [None]:

from matplotlib.pyplot import *

# k evenly spaced positions along the colormap.  'hsv' is cyclic (both ends
# are red), so the endpoint is excluded to keep the k colors distinct.
scalars = linspace(0, 1, k, endpoint = False)
cmap = colormaps['hsv']

# array of RGBA colors, one row per scalar
colors = cmap(scalars)

from matplotlib.colors import rgb2hex

# array of hex colors #rrggbb
for color in colors: print(rgb2hex(color))


In [None]:

from scipy.spatial import ConvexHull
from io import BytesIO

# import Python image library pillow
from PIL import Image

def plot_clusters(n, clusters, means, colors):
	"""Draw one iteration of the clustering and return it as a PIL image.

	Parameters
	----------
	n : iteration number, shown in the figure title.
	clusters : sequence of 2-D arrays, one cluster per entry, samples as rows.
	means : sequence of points, one per cluster.
	colors : sequence of colors, one per cluster.

	The three sequences are walked in parallel.  Each cluster is drawn as
	small dots in its color, its mean as a white star, and, when one can be
	built, the outline of its convex hull.

	Returns
	-------
	frame : PIL image opened from the in-memory buffer the figure was saved to.
	"""
	# local import: only needed to recognise a failed hull construction
	from scipy.spatial import QhullError

	for clust, mu, color in zip(clusters, means, colors):
		hexcolor = rgb2hex(color)
		scatter(*clust.T, s = 2, c = hexcolor)
		scatter(*mu, s = 15, c = 'w', edgecolor= 'k', marker = '*')
		# a hull in d dimensions needs at least d+1 points;
		# smaller clusters are shown as points only
		if len(clust) <= clust.shape[1]:
			continue
		try:
			hull = ConvexHull(clust)
		except QhullError:
			# degenerate cluster (e.g. collinear points): skip the outline
			continue
		for simplex in hull.simplices:
			plot(*clust[simplex].T, c = hexcolor, lw = .5)
	grid()
	title(f'iteration {n}')
	# empty buffer
	buffer = BytesIO()
	savefig(buffer)
	# To display frames, replace close() by  show()
	close()
	frame = Image.open(buffer)
	return frame
	
# delete colors corresponding to empty clusters

def update_colors(clusters, colors):
	"""Keep only the colors whose cluster is non-empty.

	`colors` and `clusters` are paired by position; a color is kept when
	the cluster at the same position has at least one row.

	Returns
	-------
	colors : array of the surviving colors, in their original order.
	"""
	return array([ color for color, cluster in zip(colors, clusters) if len(cluster) ])



In [None]:


def nearest_means_with_frames(dataset, k):
	"""Cluster `dataset` like the plain iteration, recording one image per step.

	Starts from k initial means drawn with the module-level `samples`
	callable and one color per mean taken from the module-level `cmap`.
	Each iteration assigns samples to their nearest mean, drops the colors
	and means of empty clusters, recomputes the means, renders the current
	state with `plot_clusters`, and prints the cluster sizes.  The loop
	stops once the number of means is unchanged and the means no longer
	move (per `allclose`).

	Parameters
	----------
	dataset : 2-D array, one sample per row.
	k : number of initial means.

	Returns
	-------
	frames : list of images, one per iteration, in order.
	"""
	_, d = shape(dataset)
	means = samples((k,d))
	# `cmap` is the cyclic 'hsv' map, whose two ends are both red; leaving
	# out the endpoint keeps the k cluster colors distinct
	colors = cmap(linspace(0, 1, k, endpoint = False))
	frames = []
	close_enough = False
	while not close_enough:
		clusters, _ = assign_clusters(dataset, means)
		# filter colors before the empty clusters themselves are removed,
		# while colors and clusters still line up one to one
		colors = update_colors(clusters, colors)
		newmeans, clusters = update_means_clusters(clusters)
		# only check closeness if len(means) is unchanged
		if len(newmeans) == len(means):
			close_enough = allclose(means, newmeans)
		means = newmeans
		n = len(frames)
		frames.append(plot_clusters(n, clusters, means, colors))
		print(f'iteration {n}: ', [ len(cluster) for cluster in clusters ])
	return frames

# here the first frame is saved to file
# then the subsequent frames are appended to file
 
def create_animation(frames, file):
	"""Write `frames` to `file` as one multi-frame PNG and print the path.

	The first frame performs the save; every later frame is handed to it
	through `append_images`.
	"""
	first = frames[0]
	rest = frames[1:]
	options = dict(format = 'PNG', save_all = True, append_images = rest, duration = 500, loop = 1)
	first.save(file, **options)
	print(f'output at {file}')




In [None]:

# Driver cell: build a larger random dataset, show it, run the clustering
# while recording frames, and write/display the resulting animation.
# Note: this rebinds the globals d, k, N and dataset set in an earlier cell.
d, k, N = 2, 7, 1000
dataset = samples((N,d))
# raw, unclustered points
scatter(*dataset.T, s = 2)
grid()
# NOTE(review): the first recorded frame is also titled 'iteration 0'
# (frame numbers start at len(frames) == 0) — confirm this is intended.
title('iteration 0')
show()

# one image per clustering iteration
frames = nearest_means_with_frames(dataset, k)
file = 'animation.png'
create_animation(frames, file)

# imported under an alias, presumably to avoid shadowing PIL's Image
from IPython.display import Image as DisplayImage
# last expression of the cell: displays the saved file
DisplayImage(file)
