In [2]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.offline import iplot
from plotly.offline import iplot,init_notebook_mode
from plotly.subplots import make_subplots
# Set up Plotly rendering in the notebook (connected=True presumably loads
# plotly.js from a CDN rather than embedding it — confirm against the plotly docs).
init_notebook_mode(connected=True)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
def initialize_centers(points,k):
    """
    Initialize the centers to be k distinct random points within the generated set.
    args:
        points: the np.array holding the points (one point per row)
        k: the number of groups to split the points into
    returns:
        np.array with k distinct rows drawn from `points`
    raises:
        ValueError: (from np.random.choice) if k exceeds the number of points
    """
    # Sample row indices WITHOUT replacement: with replacement the same point can
    # be picked twice, giving two identical centroids, one of which would then
    # never be the nearest centroid of any point (an empty cluster).
    random_index = np.random.choice(points.shape[0], k, replace=False)
    return points[random_index]

def initialize_centroids(points, k):
    """
    Pick k starting centroids by shuffling a copy of the points and keeping its first k rows.
    args:
        points: the np.array holding the points; the caller's array is not modified
        k: the number of groups to split the points into
    returns:
        the first k rows of the shuffled copy (all rows if k exceeds their number)
    """
    shuffled = points.copy()
    # np.random.shuffle permutes along the first axis in place -- only the copy is touched.
    np.random.shuffle(shuffled)
    first_k = shuffled[:k]
    return first_k

In [4]:
def closest_centroid(points, centroids):
    """
    Given centroids, return the index of the centroid that each point is nearest to,
    using the Euclidean distance as metric.
    args:
        points: np.array of shape (n_points, n_dims) holding the points
        centroids: np.array of shape (k, n_dims) with the current cluster centres
    returns:
        np.array of shape (n_points,) giving, for every point, the row index in
        `centroids` of its nearest centroid (ties go to the lowest index)
    """
    # centroids[:, np.newaxis] has shape (k, 1, n_dims), so the difference
    # broadcasts to (k, n_points, n_dims) and the sum over axis 2 to (k, n_points).
    squared_distances = ((points - centroids[:, np.newaxis]) ** 2).sum(axis=2)
    # sqrt is monotonic, so the argmin over squared distances selects the same
    # centroid as the argmin over Euclidean distances -- no need to take the root.
    return np.argmin(squared_distances, axis=0)

In [5]:
def move_centroids(points, closest, centroids):
    """
    Derive the new centroids as the mean of the points assigned to each current centroid.
    args:
        points: the np.array holding the points (one point per row)
        closest: vector with length equal to the number of points, holding the
            cluster index assigned to each point
        centroids: the centroids from the previous iteration; row j is returned
            unchanged when no point is assigned to cluster j
    returns:
        np.array with one row per centroid
    """
    new_centroids = []
    for j, previous in enumerate(centroids):
        members = points[closest == j]
        # The mean of an empty selection is NaN (with a RuntimeWarning); keep the
        # previous centroid instead so no NaN rows are produced.
        new_centroids.append(members.mean(axis=0) if len(members) else previous)
    return np.array(new_centroids)

In [7]:
from kmeans_plot import kmeans_plot

In [8]:
# Fix the global NumPy seed so the synthetic data (and the random centroid
# initialisations in later cells) are reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)

center_1 = np.array([1, 0])
center_2 = np.array([-0.5, 0.5])
center_3 = np.array([-0.5, -0.5])

# (center, number of points, standard deviation) for each Gaussian blob;
# the blobs deliberately differ in size and spread.
blob_specs = [
    (center_1, 350, 0.75),
    (center_2, 150, 0.25),
    (center_3, 150, 0.5),
]
points = np.vstack([np.random.randn(n, 2) * std + center for center, n, std in blob_specs])

In [10]:
# Arguments: the data, then 3 and 10 (presumably the number of clusters and of
# iterations — TODO confirm against kmeans_plot's signature), then the
# initialise / assign / update functions defined above.
kmeans_plot(
    points,
    3,
    10,
    initialize_centers,
    closest_centroid,
    move_centroids,
)

In [None]:
# Visualise one k-means step side by side: the random initial centroids (left)
# and the centroids after a single assign-and-move update (right).
# Drawn with Plotly because matplotlib's `plt` is not imported in this notebook.
initial_centroids = initialize_centers(points, 3)
closest = closest_centroid(points, initial_centroids)
centroids = move_centroids(points, closest, initial_centroids)

step_fig = make_subplots(rows=1, cols=2, subplot_titles=("Initial centroids", "After one update"))
for col, panel_centroids in enumerate((initial_centroids, centroids), start=1):
    step_fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode="markers", showlegend=False),
        row=1,
        col=col,
    )
    step_fig.add_trace(
        go.Scatter(
            x=panel_centroids[:, 0],
            y=panel_centroids[:, 1],
            mode="markers",
            marker={"color": "red", "size": 12},
            showlegend=False,
        ),
        row=1,
        col=col,
    )
step_fig

In [None]:
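# Scratch notes: alternative one-liners for the k-means helpers.
# Distance matrix between centroids and points, shape (k, n_points):
#np.linalg.norm((centroids[:,np.newaxis]-points),axis=(2)).shape

#np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2)).shape

# Centroid update as a list comprehension:
#[points[closest==k].mean(axis=0) for k in range(centroids.shape[0])]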


In [None]:
# Build a Plotly animation of k-means: the base trace shows the random initial
# centroids, and each frame shows one assign-then-update step.
N_CLUSTERS = 3
N_ITERATIONS = 10
iterations = list(range(N_ITERATIONS))


def _kmeans_scatter(points, centroids, closest):
    """
    Build one scatter-trace dict holding the data points followed by the centroids.

    Points are coloured by their cluster index from `closest`; every centroid gets
    the colour value len(centroids) + 1, which lies outside the cluster-index range
    (for three centroids this is the value 4).
    """
    return {
        "x": list(points[:, 0]) + list(centroids[:, 0]),
        "y": list(points[:, 1]) + list(centroids[:, 1]),
        "mode": "markers",
        "marker": {
            "sizemode": "area",
            "sizeref": 200000,
            "size": 8,
            "color": list(closest) + [len(centroids) + 1] * len(centroids),
        },
    }


# make figure
fig_dict = {
    "data": [],
    "layout": {},
    "frames": []
}
# fill in most of layout
fig_dict["layout"]["xaxis"] = {"range": [-1.5, 3], "title": "X1"}
fig_dict["layout"]["yaxis"] = {"range": [-2, 3], "title": "Y1"}
fig_dict["layout"]["hovermode"] = "closest"

# Play / Pause buttons driving the frame animation.
fig_dict["layout"]["updatemenus"] = [
    {
        "buttons": [
            {
                "args": [None, {"frame": {"duration": 500, "redraw": False},
                                "fromcurrent": True, "transition": {"duration": 300,
                                                                    "easing": "quadratic-in-out"}}],
                "label": "Play",
                "method": "animate"
            },
            {
                "args": [[None], {"frame": {"duration": 0, "redraw": False},
                                  "mode": "immediate",
                                  "transition": {"duration": 0}}],
                "label": "Pause",
                "method": "animate"
            }
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 87},
        "showactive": False,
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top"
    }
]

# Initial state: random centroids and the assignment they induce.
centroids = initialize_centers(points, N_CLUSTERS)
closest = closest_centroid(points, centroids)
fig_dict["data"].append(_kmeans_scatter(points, centroids, closest))

# One frame per k-means step: assign points, then move centroids to their cluster means.
for iter_ in iterations:
    closest = closest_centroid(points, centroids)
    centroids = move_centroids(points, closest, centroids)
    fig_dict["frames"].append({"data": [_kmeans_scatter(points, centroids, closest)], "name": str(iter_)})

fig = go.Figure(fig_dict)
fig
#fig


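The Euclidean distance from centroid $j$ to a point $x$, and to the $i$-th point $x_i$ (where $l$ indexes the coordinates):

$$
\begin{aligned}
d_j &= \lVert \text{centroid}_j - x \rVert \\
d_{ij} &= \lVert \text{centroid}_j - x_i \rVert = \sqrt{\sum_l \left(\text{centroid}_{lj} - x_{li}\right)^2}
\end{aligned}
$$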