In [2]:
import manim 
import torch 
from matplotlib import pyplot as plt
import numpy as np 
import torch
import os 


In [3]:
def setup_plot(cmap='YlOrRd', figsize=(10, 10), dpi=100):
    """Apply the notebook's default Matplotlib look: dark style with black axes/figure faces.

    Parameters
    ----------
    cmap : str or Colormap, default 'YlOrRd'
        Default colormap for images (the ``image.cmap`` rcParam).
    figsize : tuple of float, default (10, 10)
        Default figure size in inches.
    dpi : float, default 100
        Default figure resolution.
    """
    plt.style.use(['dark_background'])
    plt.rc('axes', facecolor='k')
    plt.rc('figure', facecolor='k', figsize=figsize, dpi=dpi)
    # Set the default colormap through rcParams only. plt.set_cmap() would also
    # call plt.gci() -> plt.gcf(), which opens an empty figure when none exists.
    plt.rc('image', cmap=cmap)

In [4]:
from torch import nn 
from torch import optim


class LinearModel(nn.Module):
    """A single affine layer mapping 2-D points to 2-D points (``y = x @ W.T + b``).

    The layer attribute must stay named ``l1``. The checkpoint loaded in this
    notebook is a state dict whose keys are derived from that name
    (``l1.weight``, ``l1.bias``).
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(in_features=2, out_features=2)

    def forward(self, x):
        """Apply the affine map to ``x``; its last dimension must be 2."""
        return self.l1(x)

In [5]:
# Checkpoint location. Override it with the LINEAR_MODEL_PATH environment
# variable instead of editing this machine-specific default.
MODEL_PATH = os.environ.get(
    "LINEAR_MODEL_PATH",
    r"D:\codebox\activation_functions_deeper\linear_model2.pth",
)
linear_model = LinearModel().eval()  # used for inference only in this notebook
# weights_only=True refuses arbitrary pickled objects (a state dict needs none).
# map_location="cpu" lets a checkpoint saved on GPU load on any machine.
linear_model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=True))

<All keys matched successfully>

In [6]:
# Sample dots and the unit circle that the linear model will transform.
SEED = 0  # fixed so the sampled dots are identical on every Restart & Run All
dot_counts = 100
_dot_rng = torch.Generator().manual_seed(SEED)
# Isotropic Gaussian with std 0.5 per axis. This does NOT bound the distance
# from the origin; on average about 86% of the points fall inside the unit circle.
d = torch.randn(dot_counts, 2, generator=_dot_rng) / 2

# Integer quadrant label per dot, used to pick its colour:
#   0: x >= 0, y >= 0    1: x < 0, y >= 0    2: x < 0, y < 0    3: x >= 0, y < 0
is_left = d[:, 0] < 0
is_below = d[:, 1] < 0
l = torch.zeros(dot_counts, dtype=torch.long)
l[is_left & ~is_below] = 1
l[is_left & is_below] = 2
l[~is_left & is_below] = 3

π = torch.pi
t = torch.linspace(0, 2 * π, 100)  # last sample repeats the first, closing the curve
circle = torch.stack((torch.cos(t), torch.sin(t)), dim=1)

In [7]:
# Evaluate the trained map on both point sets; no autograd graph is needed for plotting.
with torch.no_grad():
    predictions, transformed_circle = linear_model(d), linear_model(circle)

In [26]:
from manim import Scene, always_redraw, Axes, Dot, Circle, VMobject, BLUE, RED, GREEN, YELLOW, WHITE
from manim.animation.transform import Transform
from manim.animation.creation import Create
from manim import *

# One colour per quadrant label 0-3; looked up as DOT_COLORS[int(label)].
DOT_COLORS = (
    manim.YELLOW_A,
    manim.YELLOW_D,
    manim.ORANGE,
    manim.PURE_RED,
)

In [29]:

class LinearTransformationScene(MovingCameraScene):
    """Animate how ``linear_model`` maps the sample dots ``d`` and the unit ``circle``.

    Reads the notebook globals ``linear_model``, ``d``, ``l``, ``circle`` and
    ``DOT_COLORS``. Every shape is placed through ``axes.c2p``, so the original
    and the transformed objects live in the same data coordinate system as the
    axis tick labels.
    """

    @staticmethod
    def _make_axes(step, **axis_extra):
        """Build axes over [-3, 3] x [-3, 3] drawn 8 scene units wide, ticked every ``step``."""
        return Axes(
            x_range=[-3, 3, step],
            y_range=[-3, 3, step],
            x_length=8,
            y_length=8,
            tips=False,
            axis_config={"color": WHITE, "include_numbers": True, **axis_extra},
        )

    @staticmethod
    def _polyline(axes, points, **style):
        """Join the (N, 2) data-space ``points`` in order, mapped to scene space via ``axes``."""
        scene_points = np.array([axes.c2p(x, y) for x, y in points.tolist()])
        return VMobject(**style).set_points_as_corners(scene_points)

    def construct(self):
        axes = self._make_axes(1)
        self.add(axes)

        # Forward passes for display only: no autograd graph, so no .detach() needed.
        with torch.no_grad():
            transformed_d = linear_model(d)
            transformed_circle = linear_model(circle)

        dots = [
            Dot(axes.c2p(*point.tolist()), color=DOT_COLORS[int(label)])
            for point, label in zip(d, l)
        ]
        transformed_dots = [
            Dot(axes.c2p(*point.tolist()), color=DOT_COLORS[int(label)], radius=0.04)
            for point, label in zip(transformed_d, l)
        ]
        # Both curves go through axes.c2p, the same mapping used for the dots, so
        # the circle and the dots stay aligned whatever the axes' unit size is.
        circle_shape = self._polyline(axes, circle, color=YELLOW)
        transformed_circle_shape = self._polyline(axes, transformed_circle, color=YELLOW)

        self.play(Create(circle_shape))
        self.play(*[Create(dot) for dot in dots], run_time=0.5)
        self.wait(2)

        self.play(
            Transform(circle_shape, transformed_circle_shape),
            *[Transform(dot, target) for dot, target in zip(dots, transformed_dots)],
        )

        # Finer ticks and smaller labels for the zoomed view. The new axes keep the
        # same range, length and position, so the data-to-scene mapping is unchanged.
        # The already-placed dots therefore still line up with the tick labels.
        fine_axes = self._make_axes(0.5, font_size=18)
        self.play(
            Transform(axes, fine_axes),
            self.camera.frame.animate.scale(0.5).move_to(transformed_circle_shape.get_center()),
        )
        self.wait(2)

In [30]:
# Render LinearTransformationScene with manim's IPython magic (-ql: low quality, -v WARNING: log level).
%manim -ql -v WARNING LinearTransformationScene

                                                                                                    

In [44]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

def update_dot(num, data, dot):
    """FuncAnimation callback: move ``dot`` to the position stored in column ``num`` of ``data``.

    Parameters
    ----------
    num : int
        Frame index supplied by FuncAnimation; selects a column of ``data``.
    data : np.ndarray
        Positions indexed as ``data[0, num]`` (x) and ``data[1, num]`` (y).
    dot : matplotlib.lines.Line2D
        Marker artist to reposition.

    Returns
    -------
    tuple
        ``(dot,)``, the artists modified this frame (used when ``blit=True``).
    """
    # Line2D.set_data requires sequences; bare scalars are deprecated since
    # Matplotlib 3.7 and rejected by later releases, so wrap each coordinate.
    dot.set_data([data[0, num]], [data[1, num]])
    return (dot,)

N_FRAMES = 300  # animation length = number of interpolated positions

point1 = np.array([0, 0])
point2 = np.array([1, 1])
# Shape (2, N_FRAMES): row 0 holds x, row 1 holds y, evenly spaced from point1 to point2.
data = np.linspace(point1, point2, N_FRAMES, axis=1)

fig, ax = plt.subplots()
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.set_title("Dot moving from point1 to point2")
dot, = ax.plot([], [], 'ro')

ani = animation.FuncAnimation(fig, update_dot, frames=N_FRAMES, fargs=(data, dot), interval=10)
plt.close(fig)  # Prevents duplicate display
HTML(ani.to_jshtml())  # Display the animation as HTML using JavaScript
