In [6]:
import manim 
import torch 
from matplotlib import pyplot as plt
import numpy as np 
import torch
import os 


In [7]:
def setup_plot():
    ## the default setting of plot
    plt.style.use(['dark_background'])
    plt.rc('axes', facecolor='k')
    plt.rc('figure', facecolor='k')
    plt.rc('figure', figsize=(10,10), dpi=100)
    plt.set_cmap('YlOrRd')

In [8]:
from torch import nn 
from torch import optim


class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(2, 2, bias=False)


    def forward(self, x):
        x = self.l1(x)
        return x

In [9]:
linear_model = LinearModel()
linear_model.load_state_dict(torch.load(r'D:\codebox\activation_functions_deeper\linear_model.pth'))

<All keys matched successfully>

In [10]:
# making the dots
dot_counts = 100
d = torch.randn(dot_counts, 2) / 2  # dot_counts number of 2d points, scale down to limit max distance to 1
l = torch.zeros(dot_counts)  # label for dots

# dividing data into 4 quadrants
l[d[:,0] < 0] = 1   
l[d[:,1] < 0] = 2
l[(d[:,0] > 0) & (d[:,1] < 0)] = 3

π = torch.pi
t = torch.linspace(0, 2*π, 100)
circle = torch.stack((torch.cos(t), torch.sin(t)), dim=1)

In [11]:
with torch.no_grad():
    predictions = linear_model(d)
    transformed_circle = linear_model(circle)

In [12]:
from manim import Scene, always_redraw, Axes, Dot, Circle, VMobject, BLUE, RED, GREEN, YELLOW, WHITE
from manim.animation.transform import Transform
from manim.animation.creation import Create
from manim import *

DOT_COLORS = (manim.YELLOW_A, manim.YELLOW_D, manim.ORANGE, manim.PURE_RED)

In [13]:

class LinearTransformationScene(MovingCameraScene):
    def construct(self):
        axes = Axes(
            x_range=[-3, 3, 1],  # Adding step size to show numbers on the x-axis
            y_range=[-3, 3, 1],  # Adding step size to show numbers on the y-axis
            x_length=8,
            y_length=8,
            tips=False,
            axis_config={"color": WHITE, "include_numbers": True},  # Including numbers on both axes
        )
        self.add(axes)

        # Original points
        dots = [Dot(axes.c2p(*point.numpy()), color=DOT_COLORS[int(l[i])]) for i, point in enumerate(d)]
        circle_shape = Circle(radius=1, color=YELLOW).move_to(axes.c2p(0, 0))

        # Transformed points
        transformed_d = linear_model(d)
        transformed_circle = linear_model(circle)
        transformed_dots = [
            Dot(axes.c2p(*point.detach().numpy()), color=DOT_COLORS[int(l[i])], radius=0.04) for i, point in enumerate(transformed_d)
        ]
        
        # Adjusting the input to VMobject to be 3D by adding a zero to the third dimension
        transformed_circle_3d = torch.cat((transformed_circle, torch.zeros(transformed_circle.size(0), 1)), dim=1).detach().numpy()
        transformed_circle_shape = VMobject(color=YELLOW).set_points_as_corners([*transformed_circle_3d])
        
        
        self.play(Create(circle_shape))
        self.play(*[Create(dot) for dot in dots], run_time=0.5)
        self.wait(2)

        # Transforming shapes
        self.play(
            Transform(circle_shape, transformed_circle_shape),
            *[Transform(dots[i], transformed_dots[i]) for i in range(len(dots))]
        )

        # Rescaling axes for more precision after zoom
        new_axes = Axes(
            x_range=[-1.5, 1.5, 0.5],  # More precise step size
            y_range=[-1.5, 1.5, 0.5],  # More precise step size
            x_length=8,
            y_length=8,
            tips=False,
            axis_config={"color": WHITE, "include_numbers": True},  # Smaller numbers
        ).move_to(transformed_circle_shape.get_center())
        
        # Zooming in and adjusting the camera focus on the transformed objects
        self.play(Transform(axes, new_axes),
                  self.camera.frame.animate.scale(0.5).move_to(transformed_circle_shape.get_center()))
        self.wait(2)

In [None]:
%manim -ql -v WARNING LinearTransformationScene2222

                                                                                                    

In [14]:

# Assuming necessary imports, variable initializations, and definitions are done before this class
# Applying ReLU activation function to the transformed points and circle
relu = torch.nn.ReLU()
transformed_d = linear_model(d)
transformed_circle = linear_model(circle)
transformed_d_relu = relu(transformed_d)
transformed_circle_relu = relu(transformed_circle)

In [85]:
# for param in linear_model.parameters():
#     print(param)
for name, param in linear_model.named_parameters():
    if name == 'weight':
        print(param.data)


TypeError: 'generator' object is not subscriptable

In [78]:
from manim import *


class ReLUTransformationScene(Scene):
    def construct(self):
        axes = Axes(
            x_range=[-2, 2, 1],
            y_range=[-2, 2, 1],
            x_length=10,
            y_length=7,
            tips=False,
            axis_config={"color": WHITE, "include_numbers": True},
        )
        # Original points
        dots = [Dot(axes.c2p(*point.numpy()), color=DOT_COLORS[int(l[i])]) for i, point in enumerate(d)]
        first_dot_coords = dots[0].get_center()
        circle_shape = Circle(radius=1, color=YELLOW).move_to(axes.c2p(0, 0))
        group = VGroup(axes, *dots[1:], circle_shape)
        self.add(group)
        group.save_state()
        self.wait(1)

        ### SHOW X AND Y values of a Point #################
        # Show x and y values of the first dot on new 1D axes and display the values on them
        # self.play(group.animate.set_opacity(0.3), run_time=0.5)
        highlighted_dot = Dot(first_dot_coords, color=WHITE).scale(3).move_to(ORIGIN)
        self.play(Transform(dots[0], highlighted_dot))
        self.play(FadeOut(group))

        x_val, y_val = np.round(d[0].numpy(), 4)
        x_axis = NumberLine(x_range=[min(x_val-1, -3), max(x_val+1, 3), 1], length=4, color=BLUE, include_numbers=True).next_to(highlighted_dot, UP)
        y_axis = NumberLine(x_range=[min(y_val-1, -3), max(y_val+1, 3), 1], length=4, color=GREEN, include_numbers=True).next_to(highlighted_dot, DOWN)
        
        # Create labels for "X value" and "Y value"
        x_axis_label = Text("X value").next_to(x_axis, RIGHT, buff=0.1)
        y_axis_label = Text("Y value").next_to(y_axis, RIGHT, buff=0.1)
        
        # Create labels for x_val and y_val with first 4 digits
        x_val_label = MathTex(f"{x_val:.4}").next_to(x_axis, UP)
        y_val_label = MathTex(f"{y_val:.4}").next_to(y_axis, DOWN)
        
        # Create dotted lines from highlighted_dot to the corresponding x and y values on the axes
        x_line_to_axis = DashedLine(start=highlighted_dot.get_center(), end=x_axis.n2p(x_val), color=WHITE)
        y_line_to_axis = DashedLine(start=highlighted_dot.get_center(), end=y_axis.n2p(y_val), color=WHITE)
        
        # Create new dots on the axes to show the values better
        x_val_dot = Dot(x_axis.n2p(x_val), color=BLUE)
        y_val_dot = Dot(y_axis.n2p(y_val), color=GREEN)
        dot_matrix = Matrix([[x_val], [y_val]])

        self.play(FadeIn(x_axis), FadeIn(y_axis), Write(x_axis_label), Write(y_axis_label), 
                Write(x_val_label), Write(y_val_label), Create(x_line_to_axis), Create(y_line_to_axis),
                FadeIn(x_val_dot), FadeIn(y_val_dot))
        substitle_1 = Text("Each dot is composed of two values", font_size=24).to_edge(UP)
        substitle_2 = Text("Let's see what happens to the dot:", font_size=24).to_edge(UP)
        self.play(Write(substitle_1))
        self.wait(2)

        # Fade the axes, labels, and dots out
        self.play(FadeOut(x_axis), FadeOut(y_axis), FadeOut(x_axis_label), FadeOut(y_axis_label),
                FadeOut(x_val_label), FadeOut(y_val_label),
                FadeOut(x_line_to_axis), FadeOut(y_line_to_axis),
                FadeOut(x_val_dot), FadeOut(y_val_dot), 
                # FadeOut(dots[0]),
                FadeIn(dot_matrix),
                dots[0].animate.next_to(dot_matrix, LEFT),
                Transform(substitle_1, substitle_2)
        )

        # Show the weights from the linear_model as a matrix next to the dot_matrix with a "*" in between them
        weights_matrix = Matrix(linear_model.detach().numpy()).next_to(dot_matrix, RIGHT)
        multiplication_sign = MathTex("*").next_to(dot_matrix, RIGHT, buff=0.1)
        self.play(FadeIn(weights_matrix), Write(multiplication_sign))
        

        # Reset the dots to their original state before transformation
        self.wait(2)
        self.play(Restore(group))
        ###################################################

        # # Transforming the original points and circle
        # transformed_dots = [
        #     Dot(axes.c2p(*point.detach().numpy()), color=DOT_COLORS[int(l[i])]) for i, point in enumerate(transformed_d)
        # ]
        # transformed_circle_3d = torch.cat((transformed_circle, torch.zeros(transformed_circle.size(0), 1)), dim=1).detach().numpy()
        # transformed_circle_shape = VMobject(color=YELLOW).set_points_as_corners([*transformed_circle_3d])

        # # Visualizing the initial transformation
        # self.play(
        #     Transform(circle_shape, transformed_circle_shape),
        #     *[Transform(dots[i], transformed_dots[i]) for i in range(len(dots))]
        # )
        # self.wait(1)

        # # Converting the transformed points and circle with ReLU for visualization
        # transformed_dots_relu = [
        #     Dot(axes.c2p(*point.detach().numpy()), color=DOT_COLORS[int(l[i])], radius=0.05) for i, point in enumerate(transformed_d_relu)
        # ]
        # transformed_circle_3d_relu = torch.cat((transformed_circle_relu, torch.zeros(transformed_circle_relu.size(0), 1)), dim=1).detach().numpy()
        # transformed_circle_shape_relu = VMobject(color=RED).set_points_as_corners([*transformed_circle_3d_relu])

        # # Visualizing the transformation with ReLU activation
        # self.play(
        #     Transform(transformed_circle_shape, transformed_circle_shape_relu),
        #     *[Transform(transformed_dots[i], transformed_dots_relu[i]) for i in range(len(transformed_dots))]
        # )
        # self.wait(2)


In [79]:
%manim -ql -v WARNING ReLUTransformationScene

                                                                                                     

AttributeError: 'LinearModel' object has no attribute 'weight'