In [121]:
# !pip install --upgrade setuptools wheel
# !pyenv global 3.9
# !conda install -c conda-forge pycairo --yes
# !pip3 install manim
# !pip install jupyter_manim
# !python -V
# !conda install -c conda-forge pygobject

In [None]:
# Use Python 3.9.20 (installed/selected through pyenv).
from manim import *
from copy import deepcopy as copy
import numpy as np
import math
import random
# import jupyter_manim

In [133]:
%%manim -v WARNING -pqh GradientDescent

# Seed the module-level RNG so the random.uniform() draws made inside the
# scene below are the same on every render.
random.seed(11)

class GradientDescent(VectorScene):
    """Animated explanation of fitting a linear model with gradient descent.

    The scene plays four parts in order:

    1. a dataset table (GPA, credits, part-time job, grade) and the linear
       model ``y_hat``;
    2. the mean-squared-error formula next to a plotted curve;
    3. randomly initialised betas, then a tangent line that slides along the
       curve while the displayed betas are nudged after every step;
    4. the final model with the resulting beta values substituted in.

    The displayed beta values are drawn from the module-level ``random``
    generator (they are illustrative, not computed from the table), so the
    numbers shown depend on how that generator was seeded before rendering.
    """

    def construct(self):
        """Build and play the whole animation, section by section."""
        name = Text("Made by: Ahmed Alnasser").scale(0.3).to_corner(DL)
        self.add(name)

        # ------------------------------------------------------------------
        # Part 1: dataset table and the linear model
        # ------------------------------------------------------------------
        # Three stacked dots stand in for the rows left out of the table.
        dot = Text(".").scale(2)
        dots = VGroup(*[dot.copy() for _ in range(3)]).arrange(DOWN, buff=0.1).scale(0.7)

        # One colour per table column, in column order.
        column_colors = (BLUE_B, RED_A, GREEN_B, YELLOW_C)

        def text_row(*cells):
            """Return one Text per cell, coloured according to its column."""
            return [Text(cell).set_color(color) for cell, color in zip(cells, column_colors)]

        table = MobjectTable(
            [
                text_row("2.5", "16", "1", "65"),
                text_row("3.7", "14", "0", "82"),
                [dots.copy().set_color(color) for color in column_colors],
                text_row("2.3", "18", "0", "71"),
                text_row("3.1", "15", "1", "90"),
            ],
            col_labels=text_row("GPA", "Credits", "Part-time Job", "Grade (y)"),
            row_labels=[
                Tex("$Sample_1$"),
                Tex("$Sample_2$"),
                dots.copy(),
                Tex("$Sample_{n-1}$"),
                Tex("$Sample_n$"),
            ],
            top_left_entry=Text("DATASET"),
        )

        # Raw strings keep the LaTeX backslashes literal ("\s", "\h" are not
        # valid Python escapes and warn in regular strings).
        mse = MathTex(
            r"\epsilon(\hat{y}, y)", "=", r"\frac{1}{n}", r"\sum_{i=1}^{n}",
            "(", r"\hat{y}_i", "-", "y_i", ")", "^2",
        )
        mse[-3].set_color(YELLOW_C)  # the "y_i" part

        # Intercept beta_1 plus one coefficient per feature (beta_2..beta_4),
        # the same model that is shown again at the end of the scene.
        y_hat = MathTex(
            r"\hat{y}", "=", r"\beta_1", "+", r"\beta_2", "x_1", "+",
            r"\beta_3", "x_2", "+", r"\beta_4", "x_3",
        )

        columns = table.get_columns()

        y_hat.shift(DOWN * 2.5).scale(1.2)
        mse.move_to(y_hat.get_center()).scale(1.3)
        y_hat_symbol = mse[5].copy()  # the "\hat{y}_i" part y_hat later collapses into

        self.play(Write(table.scale(0.65).to_edge(UP)), run_time=3.7)
        self.wait(0.5)
        self.play(Write(y_hat), run_time=3.5)
        self.wait(1)

        # Copies of columns 1-3 with their first entry sliced off
        # (presumably the header cell); they morph into x_1, x_2 and x_3.
        columns_list = [columns[i][1:].copy().set_z_index(0) for i in range(1, 4)]

        # Transforming straight into y_hat[8] gave an unexpected result, so the
        # middle column morphs into a throw-away copy that is removed afterwards.
        x2_term_copy = y_hat[8].copy()
        self.play(
            y_hat[5].animate.set_color(BLUE_C),
            y_hat[8].animate.set_color(RED_B),
            x2_term_copy.animate.set_color(RED_B),
            y_hat[11].animate.set_color(GREEN_C),
            ReplacementTransform(columns_list[0], y_hat[5].set_z_index(2)),
            ReplacementTransform(columns_list[1], x2_term_copy.set_z_index(2)),
            ReplacementTransform(columns_list[2], y_hat[11].set_z_index(2)),
            lag_ratio=0.07,
            run_time=3,
        )
        self.remove(x2_term_copy)
        self.wait(1)

        # Collapse the whole model into the y-hat symbol, then write the loss
        # formula on top of it and drop the now-redundant y_hat mobject.
        self.play(Transform(y_hat, y_hat_symbol))
        self.wait(0.2)
        self.play(Write(mse))
        self.remove(y_hat)
        self.wait(1.5)

        # ------------------------------------------------------------------
        # Part 2: the loss formula next to a plotted curve
        # ------------------------------------------------------------------
        def func(x):
            """Degree-6 polynomial that is plotted as the loss curve."""
            u = x - 5
            return (1/30)*(u-0.2)*(u-1.11)*(u+1)*(u+3.7)*(u+4)*(u+5)

        def d_func(x):
            """Slope of ``func`` at ``x`` (its derivative in expanded form)."""
            return (30000*x**5-440250*x**4+2270500*x**3-4861779*x**2+3924180*x-826072)/150000

        plane = NumberPlane()

        self.play(
            FadeOut(table),
            mse.animate.set_color(WHITE).scale(0.7).to_corner(UL),
        )

        # Semi-transparent black panel behind the formulas. stretch() scales
        # about the panel's centre; the original additionally chained
        # align_to(background.get_top(), UP), which read the top *after* the
        # stretch and therefore moved nothing, so it is omitted here.
        background = SurroundingRectangle(mse, color=BLACK, fill_opacity=0.7, buff=0.2)
        background.stretch(8, dim=1)
        self.add(background)
        self.bring_to_front(name)
        self.bring_to_front(mse)
        self.wait(0.5)
        self.bring_to_back(plane)
        self.play(Write(plane))

        func_plt = plane.plot(func).set_color(RED_C)
        self.play(
            Create(func_plt),
            mse[0].animate.set_color(RED_C),
            run_time=3.3,
            rate_func=smooth,
        )
        # One-second pause (the original played an empty `.animate` here).
        self.wait(1)

        # ------------------------------------------------------------------
        # Part 3a: beta read-outs with random initial values
        # ------------------------------------------------------------------
        betas_val_trackers = [ValueTracker(0) for _ in range(4)]

        def tracked_number(tracker, anchor):
            """Return a 2-decimal number right of ``anchor`` that mirrors ``tracker``."""
            number = DecimalNumber(0, num_decimal_places=2).next_to(anchor, RIGHT)
            number.add_updater(lambda m: m.set_value(tracker.get_value()))
            return number

        # Layout: beta_1 and beta_3 on the first row, beta_2 and beta_4 below.
        beta1_eq = MathTex(r"\beta_1", "=").next_to(mse, DOWN, aligned_edge=LEFT).shift(LEFT * 0.2)
        beta1_val = tracked_number(betas_val_trackers[0], beta1_eq)
        beta3_eq = MathTex(r"\beta_3", "=").next_to(beta1_val, RIGHT).shift(RIGHT * 0.2)
        beta3_val = tracked_number(betas_val_trackers[2], beta3_eq)

        beta2_eq = MathTex(r"\beta_2", "=").next_to(beta1_eq, DOWN, aligned_edge=LEFT)
        beta2_val = tracked_number(betas_val_trackers[1], beta2_eq)
        beta4_eq = MathTex(r"\beta_4", "=").next_to(beta2_val, RIGHT).shift(RIGHT * 0.2)
        beta4_val = tracked_number(betas_val_trackers[3], beta4_eq)

        beta_equations = (beta1_eq, beta2_eq, beta3_eq, beta4_eq)
        beta_numbers = (beta1_val, beta2_val, beta3_val, beta4_val)
        for equation in beta_equations:
            equation[0].set_color(GREEN)

        beta1 = VGroup(beta1_eq, beta1_val)
        beta2 = VGroup(beta2_eq, beta2_val)
        beta3 = VGroup(beta3_eq, beta3_val)
        beta4 = VGroup(beta4_eq, beta4_val)
        betas = VGroup(beta1, beta2, beta3, beta4)

        self.play(Write(betas), run_time=1.5)

        note = Text("Randomize betas").scale(0.4).next_to(mse, DOWN*7).set_color(RED_B)
        self.play(Write(note))
        for i in range(3):
            # Four uniform draws per round, in tracker order.
            animations = [
                tracker.animate.set_value(random.uniform(-10, 10))
                for tracker in betas_val_trackers
            ]
            # Numbers are tinted while shuffling and turn white on the last round.
            number_color = WHITE if i == 2 else MAROON_B
            change_color = [number.animate.set_color(number_color) for number in beta_numbers]
            self.play(*animations, *change_color, run_time=0.5)
        self.play(FadeOut(note))
        self.wait(0.7)

        # Copies of the four beta symbols fly into the y-hat term of the loss.
        self.play(*[
            ReplacementTransform(equation[0].copy(), mse[5])
            for equation in beta_equations
        ])

        # ------------------------------------------------------------------
        # Part 3b: tangent line on the curve and the update rules
        # ------------------------------------------------------------------
        # point_x drives everything; point_y and slope follow it via updaters
        # and are added to the scene so those updaters actually run.
        point_x = ValueTracker(1.2)
        point_y = ValueTracker(func(point_x.get_value()))
        slope = ValueTracker(d_func(point_x.get_value()))
        point_y.add_updater(lambda y: y.set_value(func(point_x.get_value())))
        slope.add_updater(lambda s: s.set_value(d_func(point_x.get_value())))
        self.add(point_y, slope)

        line_length = 1.5

        def place_tangent(segment):
            """Centre ``segment`` on the tracked point, tilted to the tracked slope."""
            m = slope.get_value()
            centre = np.array([point_x.get_value(), point_y.get_value(), 0])
            half = (line_length / 2) * np.array([1, m, 0]) / np.sqrt(1 + m**2)
            segment.put_start_and_end_on(centre - half, centre + half)

        # call_updater=True positions both mobjects immediately: updaters are
        # suspended while Create() runs, so without it they would be drawn at
        # their default location and only jump onto the curve afterwards.
        line = Line().add_updater(place_tangent, call_updater=True).set_color(GREEN_B)
        middle_dot = Dot().set_color(GREEN_E).add_updater(
            lambda d: d.move_to(np.array([point_x.get_value(), point_y.get_value(), 0])),
            call_updater=True,
        )
        line_dot = VGroup(line, middle_dot)

        threshold_line = DashedLine(
            start=np.array([-7.2, -2.22, 0]),
            end=np.array([7.2, -2.22, 0]),
            dash_length=0.2,
            dashed_ratio=0.5,
        ).set_color(RED_A)
        threshold_text = (
            Text("Threshold line")
            .next_to(threshold_line, DOWN)
            .scale(0.7)
            .shift(RIGHT * 0.1)
            .set_color(RED_A)
        )

        self.play(Create(line_dot))
        self.wait(1)

        def update_rule(k):
            """Return the rule ``beta_k = beta_k - alpha * d(eps)/d(beta_k)``."""
            beta = rf"\beta_{k}"
            return MathTex(
                beta, "=", beta, "-", r"\alpha",
                r"{\partial\epsilon\over\partial" + beta + "}",
            ).scale(0.7)

        # Stack the four rules under the beta read-outs, then push them right.
        b1_grad = update_rule(1).next_to(beta2, DOWN, aligned_edge=LEFT)
        b2_grad = update_rule(2).next_to(b1_grad, DOWN, aligned_edge=LEFT)
        b3_grad = update_rule(3).next_to(b2_grad, DOWN, aligned_edge=LEFT)
        b4_grad = update_rule(4).next_to(b3_grad, DOWN, aligned_edge=LEFT)
        grads = VGroup(b1_grad, b2_grad, b3_grad, b4_grad).shift(RIGHT * 2)
        grads_text = Text("Gradients: ").scale(0.5).set_color(GREEN).next_to(grads, LEFT*2)

        def update_grads_move_line(x, change_rate=1):
            """Play one step: fold each rule's right side into its left side,
            nudge every beta by a uniform random amount in
            ``[-change_rate, change_rate]`` and slide the tangent point to ``x``.
            """
            ghosts = [grad[2:].copy() for grad in grads]
            self.play(*[Transform(ghost, grad[0]) for ghost, grad in zip(ghosts, grads)])
            nudges = [
                tracker.animate.set_value(
                    tracker.get_value() + random.uniform(-change_rate, change_rate)
                )
                for tracker in betas_val_trackers
            ]
            self.play(*nudges, point_x.animate.set_value(x))
            # The ghosts now sit on top of the left-hand betas; drop them.
            self.remove(*ghosts)

        # Colour individual glyphs of the derivative fraction and both betas.
        for grad in grads:
            grad[-1][-5].set_color(RED)
            grad[-1][-2:].set_color(GREEN)
            grad[0].set_color(GREEN)
            grad[2].set_color(GREEN)

        self.play(Write(grads), Write(grads_text))
        self.wait(0.3)
        # Keep the grid below the dashed line, and both below everything else.
        plane.set_z_index(-2)
        threshold_line.set_z_index(-1)
        self.play(Create(threshold_line), Write(threshold_text))
        self.wait(1)

        # (target x, beta change rate) per step; the steps and the random
        # nudges both shrink from one step to the next.
        descent_steps = [
            (1.68, 1),
            (1.94, 0.8),
            (2.17, 0.6),
            (2.3, 0.5),
            (2.4, 0.2),
            (2.47, 0.2),
            (2.57, 0.1),
        ]
        for step_index, (target_x, rate) in enumerate(descent_steps):
            if step_index:
                self.wait(0.5)
            update_grads_move_line(target_x, rate)
        self.wait(3)

        # ------------------------------------------------------------------
        # Part 4: the model with the final beta values substituted in
        # ------------------------------------------------------------------
        def color_features(formula):
            """Colour the x_1, x_2 and x_3 parts of a 12-part model formula."""
            formula[5].set_color(BLUE_C)
            formula[8].set_color(RED_B)
            formula[11].set_color(GREEN_C)

        def joiner(number_text):
            """Return "+" before a non-negative number; a negative number brings
            its own minus sign, so an empty group is used instead."""
            return "{}" if number_text.startswith("-") else "+"

        # Final beta values rounded to two decimals, already as text.
        weights = [f"{round(tracker.get_value(), 2)}" for tracker in betas_val_trackers]

        y_hat = MathTex(
            r"\hat{y}", "=", r"\beta_1", "+", r"\beta_2", "x_1", "+",
            r"\beta_3", "x_2", "+", r"\beta_4", "x_3",
        )
        y_hat_with_weights = MathTex(
            r"\hat{y}", "=", f"({weights[0]})", "+", f"({weights[1]})", "x_1", "+",
            f"({weights[2]})", "x_2", "+", f"({weights[3]})", "x_3",
        )
        color_features(y_hat)
        color_features(y_hat_with_weights)

        y_hat.move_to(ORIGIN).shift(DOWN)
        y_hat_with_weights.move_to(ORIGIN).shift(DOWN)

        self.play(
            FadeOut(mse),
            FadeOut(grads),
            FadeOut(background),
            FadeOut(plane),
            FadeOut(func_plt),
            FadeOut(line_dot),
            FadeOut(grads_text),
            FadeOut(threshold_line),
            FadeOut(threshold_text),
            # mse[-5] is the y-hat part that was the target of the beta
            # ReplacementTransforms above; it is faded explicitly as well.
            # NOTE(review): presumably those transforms registered it with
            # the scene on its own - confirm before removing this line.
            FadeOut(mse[-5]),
            betas.animate.move_to(ORIGIN).shift(UP*2),
        )

        self.play(Write(y_hat))
        self.wait(1)
        # Symbols become parenthesised numbers while each beta label flies
        # into the slot of its own value and the live read-outs fade away.
        self.play(
            *[ReplacementTransform(src, dst) for src, dst in zip(y_hat, y_hat_with_weights)],
            ReplacementTransform(betas[0][0], y_hat_with_weights[2]),
            ReplacementTransform(betas[1][0], y_hat_with_weights[4]),
            ReplacementTransform(betas[2][0], y_hat_with_weights[7]),
            ReplacementTransform(betas[3][0], y_hat_with_weights[10]),
            FadeOut(betas[0][1]),
            FadeOut(betas[1][1]),
            FadeOut(betas[2][1]),
            FadeOut(betas[3][1]),
        )
        self.wait(0.4)

        # Same formula without parentheses; the joiner before each coefficient
        # follows the sign of the value instead of being hard-coded.
        y_hat_simplified = MathTex(
            r"\hat{y}", "=", weights[0],
            joiner(weights[1]), weights[1], "x_1",
            joiner(weights[2]), weights[2], "x_2",
            joiner(weights[3]), weights[3], "x_3",
        ).move_to(y_hat_with_weights)
        color_features(y_hat_simplified)

        self.play(*[
            ReplacementTransform(src, dst)
            for src, dst in zip(y_hat_with_weights, y_hat_simplified)
        ])
        self.wait(0.4)

        # Clear whatever is left of the parenthesised version from the scene.
        self.remove(y_hat_with_weights)

        self.play(y_hat_simplified.animate.move_to(ORIGIN))
        self.play(Circumscribe(y_hat_simplified, color=YELLOW, time_width=4))

        self.wait(6)


                                                                                                                                                                      