In [None]:
!sudo apt update
!sudo apt install libcairo2-dev ffmpeg \
    texlive texlive-latex-extra texlive-fonts-extra \
    texlive-latex-recommended texlive-science \
    tipa libpango1.0-dev
!pip install manim
!pip install IPython --upgrade

In [50]:
from manim import *

from matplotlib import pyplot as pyplot

import numpy as np
import tensorflow as tf

In [51]:
def get_squares_dataset(n_samples=1000, noise_level=0.1):
    num_samples_per_square = n_samples // 4
    centers = [[0, 3.75], [0, 0], [3.5, 0], [3.5, 3.75]]
    variance = [[noise_level, 0], [0, noise_level]]
    X = np.vstack([
        np.random.multivariate_normal(center, variance, num_samples_per_square) for center in centers
    ])
    y = np.hstack([
        np.full(num_samples_per_square, i % 2) for i in range(4)
    ])
    return X, y

X, y = get_squares_dataset(n_samples=2000, noise_level=0.2)

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()

X_scaled = scaler.fit_transform(X)

In [59]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

from keras.layers import Dense
from keras.optimizers import Adam
from keras.losses import BinaryCrossentropy


def build_model(activation="leaky_relu", layers=5):

  model = tf.keras.Sequential()

  for _ in range(layers):
    layer = Dense(2, activation=activation)
    model.add(layer)

  return model

model = build_model('leaky_relu', layers=6)

model.compile(
    optimizer=Adam(learning_rate=0.005),
    loss=BinaryCrossentropy(),
    metrics=['accuracy']
)

In [None]:
from keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor='val_accuracy', patience=200, mode='max',
                               baseline=0.9, restore_best_weights=True)

history = model.fit(X_train, y_train, epochs=400, batch_size=32,
                    validation_data=(X_test, y_test), callbacks=[early_stopping])

In [61]:
model.evaluate(X_test, y_test)



[0.04640892520546913, 0.987500011920929]

In [98]:
%%manim -qh -v WARNING MatrixMultiplicationScene

class MatrixMultiplicationScene(Scene):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.ax = Axes(
            x_range=[-4, 4, 1],
            y_range=[-4, 4, 1],
            x_length=7,
            y_length=7,
            tips=False,
        )

        self.dataset = X_test[:100]
        self.y = y_test[:100]

        self.dots = VGroup(*self.c2p(self.dataset))

    def construct(self):

        self.add(self.ax, self.dots)

        for idx, l in enumerate(model.layers):
          if hasattr(l, 'get_weights') and len(l.get_weights()) > 0:
            w, b = l.get_weights()
            self.display_title(f'x @ w{idx+1}')
            self.mm(w)
            self.display_title(f'x + b{idx+1}')
            self.vector_addition(b)
            activation_function = tf.keras.activations.get(l.activation)
            self.display_title(l.activation.__name__.replace('_', ' ').title())
            self.dataset_transform(lambda dataset: activation_function(dataset).numpy())

        self.wait(1)

    def dataset_transform(self, f):
        transformed_dataset = f(self.dataset)
        transformed_dots = VGroup(*self.c2p(transformed_dataset))
        self.play(Transform(self.dots, transformed_dots))
        self.dataset = transformed_dataset

    def vector_addition(self, vector):
        if isinstance(vector, list):
          matrix = np.array(vector)
        self.dataset_transform(lambda dataset: dataset + vector)

    def mm(self, matrix):
        if isinstance(matrix, list):
          matrix = np.array(matrix)
        self.dataset_transform(lambda dataset: dataset @ matrix)

    def display_title(self, title_text):
        title = Text(title_text, font_size=24, font='Roboto').to_corner(UP + LEFT)
        self.play(Write(title), run_time=1)
        self.wait(0.5)
        self.play(FadeOut(title), run_time=1)

    def c2p(self, points):
      def get_color(idx):
        return ManimColor('#3476cd') if self.y[idx] == 0 else ManimColor('#ab482a')
      return [Dot(self.ax.c2p(p[0], p[1]), color=get_color(idx), fill_opacity=0.9, radius=0.05) for idx, p in enumerate(points)]

