In [2]:
import jupyter_manim
import numpy as np
from manim import *

# Manim global configuration for this notebook session.
# NOTE(review): presumably makes the embedded render span the full width of
# the notebook output cell — confirm against the manim config documentation.
config.media_width = "100%"
# NOTE(review): presumably limits manim's log output to warnings and above —
# confirm against the manim config documentation.
config.verbosity = "WARNING"

In [None]:
%%manim -qm ClusterVectors


# This shows 64 2D vectors scattered around 3 cluster centers, colors each one by
# its nearest center, and then rearranges them into one postings list per cluster.
class ClusterVectors(Scene):
    """Animate how 2D points are clustered and turned into postings lists.

    The scene runs in three steps (see ``construct``):

    1. ``plot_vectors`` pops 64 white dots onto a pair of axes.
    2. ``cluster_vectors`` shows the three cluster centers and colors every
       dot by the center it is closest to.
    3. ``transform_to_postings`` fades the axes out and moves the dots into
       three rows (one per cluster, each starting with the cluster center),
       illustrating the postings lists of an IVF index.
    """

    # Color of each cluster, index-aligned with ``self.cluster_centers``.
    CLUSTER_COLORS = (RED, GREEN, BLUE)

    def setup_axes(self):
        """Create the axes ([-4, 4] on both axes, 8 units long) and add them."""
        self.axes = Axes(
            x_range=[-4, 4],
            y_range=[-4, 4],
            x_length=8,
            y_length=8,
            axis_config={"color": BLUE},
        )
        self.add(self.axes)

    def plot_vectors(self):
        """Generate 64 pseudo-random vectors around 3 centers and show them.

        Sets ``self.cluster_centers`` (list of 2-element arrays),
        ``self.vectors`` (list of 2-element arrays, in generation order) and
        ``self.dots`` (one white ``Dot`` per vector, same order).
        Requires ``setup_axes`` to have been called first.
        """
        self.cluster_centers = [np.array([-1, 1]), np.array([1, 1]), np.array([0, -1])]

        # A local legacy generator seeded with 0 yields the same stream as
        # ``np.random.seed(0)`` followed by the module-level functions, but
        # does not overwrite the process-wide random state.
        rng = np.random.RandomState(0)
        self.vectors = []
        for _ in range(64):
            center = self.cluster_centers[rng.choice([0, 1, 2])]
            angle = rng.uniform(0, 1.5 * np.pi)
            # A negative length flips the direction, so points spread out on
            # both sides of the center and clusters can overlap.
            length = rng.uniform(-2.5, 2.5)
            self.vectors.append(center + length * np.array([np.cos(angle), np.sin(angle)]))

        # Show every vector as a small, not-yet-colored dot that pops in.
        self.dots = []
        for vector in self.vectors:
            dot = Dot(color=WHITE, radius=0.05)
            dot.move_to(self.axes.c2p(vector[0], vector[1]))
            # FadeIn already adds the dot to the scene; no extra add() needed.
            self.play(FadeIn(dot), run_time=0.01)
            self.dots.append(dot)

    def _assign_clusters(self):
        """Return, for every vector, the index of its nearest cluster center."""
        return [
            int(np.argmin([np.linalg.norm(vector - center) for center in self.cluster_centers]))
            for vector in self.vectors
        ]

    def _make_center_dots(self):
        """Return one large colored dot per cluster center, already positioned."""
        return [
            Dot(color=color, radius=0.1).move_to(self.axes.c2p(center[0], center[1]))
            for center, color in zip(self.cluster_centers, self.CLUSTER_COLORS)
        ]

    def cluster_vectors(self):
        """Show the cluster centers and color each dot by its nearest center.

        Sets ``self.center_dots`` so that later steps can reuse the very same
        center dots instead of drawing duplicates on top of them.
        Requires ``plot_vectors`` to have been called first.
        """
        # Centers are positioned *before* fading in, so they appear in place
        # rather than at the scene origin.
        self.center_dots = self._make_center_dots()
        for center_dot in self.center_dots:
            self.play(FadeIn(center_dot), run_time=0.01)

        # Recolor the dots one by one so the clusters visibly "fill in".
        for dot, cluster in zip(self.dots, self._assign_clusters()):
            self.play(dot.animate.set_color(self.CLUSTER_COLORS[cluster]), run_time=0.01)

    def transform_to_postings(self):
        """Fade out the axes and move the dots into one postings list per cluster.

        Each postings list is a horizontal row that starts with the cluster's
        center dot followed by the dots assigned to that cluster. Rows are
        stacked vertically, labelled "Cluster 1..3", and titled
        "Postings Lists".
        Requires ``plot_vectors`` to have been called first; if
        ``cluster_vectors`` was skipped, the center dots are created here.
        """
        colors = self.CLUSTER_COLORS
        center_dots = getattr(self, "center_dots", None) or self._make_center_dots()
        assignments = self._assign_clusters()

        self.play(FadeOut(self.axes), run_time=0.5)

        # Lay out *targets* (copies) first, so the real dots stay where they
        # are until they are animated to their final slot.
        rows = []
        target_rows = []
        for index, center_dot in enumerate(center_dots):
            members = [center_dot] + [
                dot for dot, cluster in zip(self.dots, assignments) if cluster == index
            ]
            targets = VGroup(*[member.generate_target() for member in members])
            targets.arrange(RIGHT, buff=0.08)
            rows.append(members)
            target_rows.append(targets)

        layout = VGroup(*target_rows)
        layout.arrange(DOWN, buff=0.6, aligned_edge=LEFT)
        layout.move_to(ORIGIN)

        # Move one cluster at a time into its row.
        for members in rows:
            self.play(*[MoveToTarget(member) for member in members], run_time=0.5)

        # Label each row to its left so the text does not cover any dots.
        labels = ["Cluster 1", "Cluster 2", "Cluster 3"]
        for text, color, targets in zip(labels, colors, target_rows):
            label = Text(text, color=color, font_size=24)
            label.next_to(targets, LEFT, buff=0.3)
            self.play(FadeIn(label), run_time=0.5)

        # Title above the whole group of rows.
        title = Text("Postings Lists", color=WHITE)
        title.next_to(layout, UP, buff=0.6)
        self.play(FadeIn(title), run_time=0.5)

    def construct(self):
        """Run the three animation steps with a one-second pause after each."""
        self.setup_axes()
        self.plot_vectors()
        self.wait(1)
        self.cluster_vectors()
        self.wait(1)
        self.transform_to_postings()
        self.wait(1)

                                                                                              