# In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Generate synthetic data: 200 two-column samples around 3 centers
# (the cluster labels returned by make_blobs are discarded).
from sklearn.datasets import make_blobs
X, _ = make_blobs(n_samples=200, centers=3, cluster_std=0.8, random_state=42)

# Mean-shift parameters
bandwidth = 1.0  # Radius of the flat kernel: points strictly closer are averaged
max_iter = 10    # Number of shift steps recorded for every trajectory
n_starts = 10    # Number of data points whose trajectories are tracked

# Pick a few random starting points from the data (seeded for reproducibility).
# Sampling is without replacement, so the sample size is clamped to len(X)
# to avoid a ValueError when n_starts exceeds the number of samples.
np.random.seed(42)
start_points = X[np.random.choice(len(X), size=min(n_starts, len(X)), replace=False)]
# Each trajectory is a list of positions, beginning with its starting point.
trajectories = [[pt.copy()] for pt in start_points]  # store path of each point

def shift_point(point, X, bandwidth):
    """Perform one flat-kernel mean-shift step starting from ``point``.

    Every row of ``X`` whose Euclidean distance to ``point`` is strictly
    less than ``bandwidth`` is averaged. If no row qualifies, ``point``
    itself (the very same object) is returned unchanged.

    Args:
        point: Current position; broadcast against the rows of ``X``.
        X: Data array with one sample per row.
        bandwidth: Kernel radius (the boundary is exclusive).

    Returns:
        The column-wise mean of the neighbouring rows, or ``point`` when
        the neighbourhood is empty.
    """
    neighbour_mask = np.linalg.norm(X - point, axis=1) < bandwidth
    # Guard clause: nothing inside the kernel, so stay where we are.
    if not neighbour_mask.any():
        return point
    return X[neighbour_mask].mean(axis=0)

# Run mean-shift manually: on each iteration advance every trajectory by one
# shift step, appending the new position so the whole path can be animated.
# No convergence check is done, so every path ends up with max_iter + 1 entries.
for _ in range(max_iter):
    for path in trajectories:
        path.append(shift_point(path[-1], X, bandwidth))

# Animation: each frame redraws the data, every trajectory's trail so far,
# and each trajectory's current position.
fig, ax = plt.subplots(figsize=(6, 6))

# One frame per stored position (start point + one per iteration). Deriving
# this from the paths keeps the animation in sync with however they were built.
n_frames = max((len(path) for path in trajectories), default=max_iter + 1)


def animate(frame):
    """Draw the state of all trajectories after ``frame`` mean-shift steps."""
    # clear() removes previous artists, title and legend, so redraw everything.
    ax.clear()
    ax.scatter(X[:, 0], X[:, 1], c='lightgray', label='Data')
    for idx, path in enumerate(trajectories):
        trail = np.array(path[:frame + 1])
        # Label only the first trajectory so the legend gets a single entry
        # per artist type; labels starting with '_' are skipped by legend().
        first = idx == 0
        ax.plot(trail[:, 0], trail[:, 1], color='blue', alpha=0.6,
                label='Trajectory' if first else '_nolegend_')
        ax.scatter(*trail[-1], color='red', s=50,
                   label='Current position' if first else '_nolegend_')
    ax.set_title(f'Mean-Shift Iteration {frame}')
    ax.legend(loc='upper left')


anim = FuncAnimation(fig, animate, frames=n_frames, interval=800)
# Close this specific figure so the notebook shows only the animation,
# not an extra static image of the last drawn frame.
plt.close(fig)

# Show animation (the last expression of the cell is rendered by the notebook)
HTML(anim.to_jshtml())
