# K-means interactive

> Yang

In [16]:
%matplotlib qt
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button
from matplotlib.widgets import PolygonSelector
from sklearn.cluster import KMeans

def colors_from_lbs(lbs, colors=None):
    """Map integer labels to colors, cycling through the palette.

    Parameters
    ----------
    lbs : array-like of int
        Labels to translate. A label outside ``[0, len(palette))`` wraps
        around via modulo. May be empty.
    colors : sequence, optional
        Palette to index into. When ``None`` the built-in list of 20 hex
        colors is used.

    Returns
    -------
    numpy.ndarray
        One palette entry per label, in the same order as *lbs*.
    """
    mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
              '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',
              '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']

    palette = np.array(mpl_20 if colors is None else colors)

    idx = np.asarray(lbs)
    if idx.size == 0:
        # NumPy infers float64 for an empty sequence, and a float array is
        # rejected as an index. Only the empty case is coerced, so genuinely
        # non-integer labels still fail loudly as before.
        idx = idx.astype(np.intp)

    return palette[idx % len(palette)]

# Synthetic data: three Gaussian blobs sharing one diagonal covariance, drawn
# in a fixed order from a seeded generator so every run yields the same points.
rng = np.random.RandomState(0)
n_samples = 1000
cov = [[0.4, 0], [0, 0.4]]
X = np.concatenate(
    [
        rng.multivariate_normal(mean=blob_mean, cov=cov, size=n_samples)
        for blob_mean in ([-2, 0], [2, 0], [0.3, 1])
    ],
    axis=0,
)

# Initial clustering: two clusters are requested although three blobs were
# generated above.
kmeans = KMeans(n_clusters=2, n_init="auto", random_state=0)
labels = kmeans.fit_predict(X)
centers = kmeans.cluster_centers_

# One figure with two side-by-side axes, both filled in by plot_figure.
fig, (ax_orig, ax_redim) = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))

def plot_figure(axe_list, X, centers):
    """Redraw both panels for the given cluster centers.

    Overwrites ``cluster_centers_`` on the module-level ``kmeans`` model with
    *centers*, predicts a label for every row of *X*, then repaints the
    scatter plot and the per-cluster bar chart and schedules a redraw of the
    module-level ``fig``.

    Parameters
    ----------
    axe_list : sequence of two Axes
        ``(scatter_axes, bar_axes)``, in that order.
    X : numpy.ndarray
        Samples; columns 0 and 1 are plotted.
    centers : array-like of (x, y) pairs
        Cluster centers. Lists of tuples are accepted as well as arrays.
    """
    ax_orig, ax_redim = axe_list

    # Convert once and use the array everywhere below, so list-like input
    # (e.g. raw polygon vertices) works for the ``[:, 0]`` slicing too.
    centers = np.array(centers, dtype=np.float64)
    n_centers = len(centers)

    kmeans.cluster_centers_ = centers
    labels = kmeans.predict(X)

    # NOTE(review): clear() drops every artist on the axes; if a widget has
    # attached its own artists to this axes they may be removed as well --
    # confirm the selector still renders after the first redraw.
    ax_orig.clear()
    ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label="samples", c=colors_from_lbs(labels))
    ax_orig.scatter(centers[:, 0], centers[:, 1], s=50, c='black', edgecolors='r')
    ax_orig.set(
        aspect="auto",
        title="Interactive K-means",
        xlabel="first feature",
        ylabel="second feature",
    )

    ax_redim.clear()
    class_name = ['class {0}'.format(i + 1) for i in range(n_centers)]

    # Number of samples assigned to each center; minlength keeps a zero entry
    # for centers that attracted no samples.
    counts = np.bincount(labels, minlength=n_centers)

    ax_redim.bar(class_name, counts,
                 label=class_name,
                 color=colors_from_lbs(range(n_centers)))
    ax_redim.set(
        aspect="auto",
        title="Clustering results",
        xlabel="Main feature",
        ylabel="Number of samples",
    )
    fig.canvas.draw_idle()

# Initial render, using the centers of the fitted k-means model.
plot_figure((ax_orig, ax_redim), X, centers)

def onselect(verts):
    """Selector callback: use the polygon's vertices as the new centers.

    *verts* is turned into an array and both module-level axes are redrawn
    against the module-level samples ``X``.
    """
    plot_figure((ax_orig, ax_redim), X, np.array(verts))

# Interactive polygon on the scatter axes; each change to it is forwarded to
# onselect, which redraws both panels.
selector = PolygonSelector(
    ax_orig,
    onselect=onselect,
    props=dict(color='r', linestyle='', linewidth=3, alpha=0.6, label="Component"),
)
# ``verts`` is documented as a list of (x, y) pairs, so hand over plain tuples
# of floats rather than rows of the centers array.
selector.verts = [tuple(xy) for xy in centers.tolist()]

plt.tight_layout()
plt.show()
