In [None]:
from sklearn.cluster import estimate_bandwidth
from sklearn.cluster import get_bin_seeds
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_array
from sklearn.cluster import MeanShift

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML

# Seed NumPy's legacy global random generator so that any code drawing from
# it produces the same numbers on every run of the notebook.
np.random.seed(1234)

In [None]:
# Plug in any dataset (and preprocessing) you like here.
# Constraint: exactly two features must remain, because the clustering is
# visualised on a 2-D scatter plot.
from sklearn import datasets
dataset = datasets.load_wine()
df = pd.DataFrame(data=dataset.data, columns=dataset.feature_names)
# target = pd.DataFrame(dataset['target'])
display(df.head())

# Keep only the two features that will be clustered and plotted.
df = df.loc[:, ['alcohol', 'color_intensity']]
df.head()

In [None]:
# if you want to normalize the features, uncomment the lines below:

# from sklearn.preprocessing import StandardScaler
# scaler = StandardScaler()
# df = pd.DataFrame(scaler.fit_transform(df), columns=df.columns)
# df.describe()

In [None]:
def start_algo(X, bandwidth=None, seeds=None, bin_seeding=True, max_iter=1000, cluster_all=True):
    """
    Run mean shift clustering step by step, yielding one frame per step.

    Implementation taken and modified from scikit-learn's ``MeanShift``:
    https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data to cluster. It is converted to a numeric ndarray.
    bandwidth : float, optional
        Radius used for the neighbour queries. When ``None`` it is computed
        with ``estimate_bandwidth(X)``.
    seeds : array-like of shape (n_seeds, n_features), optional
        Starting points. When ``None`` they come from ``get_bin_seeds`` if
        ``bin_seeding`` is true, otherwise every sample is a seed.
    bin_seeding : bool, default=True
        See ``seeds``.
    max_iter : int, default=1000
        Upper bound on the number of shift iterations per seed.
    cluster_all : bool, default=True
        If true every sample gets the label of its nearest centre; otherwise
        samples farther than ``bandwidth`` from every centre get label ``-1``.

    Yields
    ------
    (centers, X, labels) : tuple
        ``centers`` is the current set of candidate centres (a list of points
        while seeds are shifting, an ndarray while near-duplicates are being
        merged). ``X`` is the validated data array. ``labels`` is an empty
        list for every frame except the last one, where it is the integer
        label array of shape (n_samples,).

    Raises
    ------
    ValueError
        If no sample lies within ``bandwidth`` of any seed.
    """
    # ``m`` is only a holder for the MeanShift parameters and their defaults
    # (e.g. ``min_bin_freq``); its ``fit`` method is never called.
    m = MeanShift(bandwidth=bandwidth, bin_seeding=bin_seeding, seeds=seeds,
                  max_iter=max_iter, cluster_all=cluster_all)
    # Public validation helper instead of the private ``_validate_data``.
    X = check_array(X)

    bandwidth = m.bandwidth
    if bandwidth is None:
        bandwidth = estimate_bandwidth(X, n_jobs=None)

    seeds = m.seeds
    if seeds is None:
        if m.bin_seeding:
            seeds = get_bin_seeds(X, bandwidth, m.min_bin_freq)
        else:
            seeds = X

    n_samples = X.shape[0]
    center_intensity_dict = {}

    nbrs = NearestNeighbors(radius=bandwidth).fit(X)
    stop_thresh = 1e-3 * bandwidth  # when mean has converged
    all_res = []
    # Independent copy of every seed, keeping ALL of its coordinates so the
    # neighbour queries match the dimensionality of X.
    current = [row.copy() for row in np.asarray(seeds, dtype=float)]

    for i in range(len(current)):
        # For each seed, climb gradient until convergence or max_iter
        completed_iterations = 0
        while True:
            # Find mean of points within bandwidth
            i_nbrs = nbrs.radius_neighbors([current[i]], bandwidth, return_distance=False)[0]
            points_within = X[i_nbrs]
            if len(points_within) == 0:
                break  # Depending on seeding strategy this condition may occur
            previous_mean = current[i]  # save the old mean

            current[i] = np.mean(points_within, axis=0)
            # Yield a snapshot (new list) so frames stored by the consumer are
            # not altered by later shifts.
            yield list(current), X, []

            # If converged or at max_iter, adds the cluster
            if (
                np.linalg.norm(current[i] - previous_mean) < stop_thresh
                or completed_iterations == max_iter
            ):
                break
            completed_iterations += 1
        all_res.append((tuple(current[i]), len(points_within), completed_iterations))
        # Extra frame marking the end of this seed's trajectory.
        yield list(current), X, []

    yield list(current), X, []

    # Post process
    # copy results in a dictionary: centre -> number of points in its window
    for center, n_points, _ in all_res:
        if n_points:  # i.e. len(points_within) > 0
            center_intensity_dict[center] = n_points
    # default=0 lets an empty seed set fall through to the ValueError below.
    m.n_iter_ = max((n_iter for _, _, n_iter in all_res), default=0)

    if not center_intensity_dict:
        # nothing near seeds
        raise ValueError(
            "No point was within bandwidth=%f of any seed. Try a different seeding"
            " strategy or increase the bandwidth."
            % bandwidth
        )

    # Visit centres from the most to the least populated window and drop any
    # centre lying within ``bandwidth`` of an already accepted one.
    sorted_by_intensity = sorted(
        center_intensity_dict.items(),
        key=lambda tup: (tup[1], tup[0]),
        reverse=True,
    )
    sorted_centers = np.array([tup[0] for tup in sorted_by_intensity])
    unique = np.ones(len(sorted_centers), dtype=bool)
    nbrs = NearestNeighbors(radius=bandwidth).fit(sorted_centers)
    for i, center in enumerate(sorted_centers):
        if unique[i]:
            neighbor_idxs = nbrs.radius_neighbors([center], return_distance=False)[0]
            for n in neighbor_idxs:
                was_unique = unique[n]
                unique[n] = False
                unique[i] = True  # leave the current point as unique
                if was_unique:
                    # One frame per centre examined while it was still kept.
                    yield sorted_centers[unique], X, []

    cluster_centers = sorted_centers[unique]
    yield cluster_centers, X, []

    # ASSIGN LABELS: a point belongs to the cluster that it is closest to
    nbrs = NearestNeighbors(n_neighbors=1).fit(cluster_centers)
    labels = np.zeros(n_samples, dtype=int)
    distances, idxs = nbrs.kneighbors(X)
    if m.cluster_all:
        labels = idxs.flatten()
    else:
        labels.fill(-1)
        bool_selector = distances.flatten() <= bandwidth
        labels[bool_selector] = idxs.flatten()[bool_selector]
    m.cluster_centers_ = cluster_centers
    m.labels_ = labels
    yield cluster_centers, X, labels

In [None]:
# start_algo is a generator function: this only creates the generator, no
# clustering work happens until frames are requested. A generator can be
# consumed once, so re-run this cell before consuming it again.
generator = start_algo(X=df.copy())

In [None]:
# Figure and axes that the animation redraws, with an initial scatter of the
# raw samples (first column on x, second column on y).
fig, ax = plt.subplots()
rects = ax.scatter(df.iloc[:, 0], df.iloc[:, 1])

In [None]:
# Empty text artist near the top-left corner of the axes and a one-element
# list holding 0.
# NOTE(review): neither name is referenced by the other cells shown here;
# presumably they were meant for an on-plot iteration counter - confirm.
text = ax.text(x=0.01, y=0.95, s="", transform=ax.transAxes)
iteration = [0]

In [None]:
def animate(A):
    """
    Redraw the axes for one animation frame.

    Parameters
    ----------
    A : sequence
        Frame data where ``A[0]`` holds the current centres (each indexable
        with ``[0]`` and ``[1]``) and ``A[2]`` holds the per-sample labels, or
        an empty sequence when no labels are available yet. ``A[1]`` is not
        used; the samples are always taken from the global ``df``.
    """
    centers, labels = A[0], A[2]
    x_col, y_col = df.columns[0], df.columns[1]

    ax.clear()
    if len(labels) == 0:
        ax.scatter(df[x_col], df[y_col])
    else:
        # Pass the labels directly: wrapping them in another list produces a
        # nested (1, n) sequence instead of one value per point.
        ax.scatter(df[x_col], df[y_col], c=labels)
    ax.scatter([c[0] for c in centers], [c[1] for c in centers], c='red')
    # ax.clear() was just called, so put the axis labels back on every frame.
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)

In [None]:
# Pass a generator *function* rather than an already-created generator: a
# generator object can be consumed only once, so re-running this cell or
# rendering the animation a second time would otherwise find no frames left.
# save_count bounds how many frames are taken from the sequence.
anim = FuncAnimation(fig, func=animate, frames=lambda: start_algo(df.copy()),
                     interval=100, repeat=False, save_count=10000)

In [None]:
# Render the animation to a JavaScript/HTML snippet and display it as this
# cell's output. Calling to_jshtml() is what pulls the frames from the
# animation's frame source.
HTML(anim.to_jshtml())

In [None]:
# Sanity check: cluster the same data with scikit-learn's own MeanShift and
# plot labels and centres for visual comparison with the animation's last frame.
from sklearn.cluster import MeanShift
m = MeanShift(bin_seeding=True)
m.fit(df)
l = m.labels_
plt.scatter(df.iloc[:, 0], df.iloc[:, 1], c=l)
plt.scatter(m.cluster_centers_[:, 0], m.cluster_centers_[:, 1], c='red')