In [23]:
import numpy as np

class Point():
    """A 2D point that can be linearly transformed and later restored.

    The class also tracks an axis-aligned bounding box shared by every Point
    constructed so far (all of these class attributes are None until the
    first Point is created):

    * ``x_min``, ``x_max``, ``y_min``, ``y_max`` -- current bounds; widened by
      ``transform`` and restored by ``reset_max_min``.
    * ``standard_x_min`` ... ``standard_y_max`` -- bounds of the points'
      untransformed coordinates, which ``reset_max_min`` restores.
    """
    x_max, standard_x_max = (None, None)
    x_min, standard_x_min = (None, None)
    y_max, standard_y_max = (None, None)
    y_min, standard_y_min = (None, None)

    @staticmethod
    def get_x_max():
        """Return the current shared maximum x coordinate."""
        return Point.x_max

    @staticmethod
    def get_x_min():
        """Return the current shared minimum x coordinate."""
        return Point.x_min

    @staticmethod
    def get_y_max():
        """Return the current shared maximum y coordinate."""
        return Point.y_max

    @staticmethod
    def get_y_min():
        """Return the current shared minimum y coordinate."""
        return Point.y_min

    @staticmethod
    def _widen(lo, hi, value):
        """Return ``(lo, hi)`` extended to include ``value``; a None ``lo`` starts a new range."""
        if lo is None:
            return value, value
        return min(lo, value), max(hi, value)

    def __init__(self, coords):
        """Create a point at ``coords`` (x is ``coords[0]``, y is ``coords[1]``).

        The coordinates are also remembered as the point's standard state, the
        point starts with the standard basis vectors, and the shared bounds are
        grown to include it.
        """
        self.coords = coords
        self.n = len(coords)
        # Shares the object with self.coords; transform() rebinds self.coords
        # instead of mutating it, so this keeps the original values.
        self.standard_coords = self.coords
        self.color = self.color_by_quadrant(coords)
        self.standard_unit_i = np.array([1, 0])
        self.standard_unit_j = np.array([0, 1])
        self.unit_i = self.standard_unit_i
        self.unit_j = self.standard_unit_j
        self.update_max_min()

    def update_max_min(self):
        """Grow both the current and the standard shared bounds to include this point.

        The two sets of bounds are grown independently, so creating a Point while
        other points are transformed does not record their transformed bounds as
        the standard (reset) bounds.
        """
        x, y = self.coords[0], self.coords[1]
        Point.x_min, Point.x_max = Point._widen(Point.x_min, Point.x_max, x)
        Point.y_min, Point.y_max = Point._widen(Point.y_min, Point.y_max, y)
        Point.standard_x_min, Point.standard_x_max = Point._widen(
            Point.standard_x_min, Point.standard_x_max, x)
        Point.standard_y_min, Point.standard_y_max = Point._widen(
            Point.standard_y_min, Point.standard_y_max, y)

    def transform_max_min(self):
        """Grow the current shared bounds (not the standard ones) to include this point's coordinates."""
        Point.x_min, Point.x_max = Point._widen(Point.x_min, Point.x_max, self.coords[0])
        Point.y_min, Point.y_max = Point._widen(Point.y_min, Point.y_max, self.coords[1])

    def transform(self, matrix):
        """Apply ``matrix`` to the coordinates and both basis vectors, then widen the current bounds.

        ``matrix`` is applied with ``matrix.dot`` -- presumably a 2x2 numpy array.
        """
        self.coords = matrix.dot(self.coords)
        self.unit_i = matrix.dot(self.unit_i)
        self.unit_j = matrix.dot(self.unit_j)
        self.transform_max_min()

    @staticmethod
    def reset_max_min():
        """Restore the current shared bounds to the standard (untransformed) bounds."""
        Point.x_max = Point.standard_x_max
        Point.x_min = Point.standard_x_min
        Point.y_max = Point.standard_y_max
        Point.y_min = Point.standard_y_min

    def reset(self):
        """Undo all transforms: restore coordinates, basis vectors and the shared bounds."""
        self.coords = self.standard_coords
        self.unit_i = self.standard_unit_i
        self.unit_j = self.standard_unit_j
        Point.reset_max_min()

    def __str__(self):
        return str(self.coords)

    def __repr__(self):
        return str(self.coords)

    def color_by_quadrant(self, coords):
        """Return the plot color for ``coords``.

        Quadrant I is 'red', II 'blue', III 'green', IV 'yellow'; a point lying
        on either axis is 'black'.
        """
        if coords[0] > 0 and coords[1] > 0:
            return 'red'
        elif coords[0] < 0 and coords[1] > 0:
            return 'blue'
        elif coords[0] < 0 and coords[1] < 0:
            return 'green'
        elif coords[0] > 0 and coords[1] < 0:
            return 'yellow'
        else:
            return 'black'

    def get_i_unit_length(self):
        """Return the Euclidean length of the (possibly transformed) i basis vector."""
        return np.linalg.norm(self.unit_i)

    def get_j_unit_length(self):
        """Return the Euclidean length of the (possibly transformed) j basis vector."""
        return np.linalg.norm(self.unit_j)

    def get_plot_coords(self):
        """Return the current position as an ``[x, y]`` list."""
        return [self.coords[0], self.coords[1]]

In [24]:
# Build a 10 x 10 grid (100 points) covering the square [-1, 1] x [-1, 1].
# np.linspace(-1, 1, 10) does not contain 0, so no point lies on an axis.
points = [
    Point(np.array([x, y]))
    for x in np.linspace(-1, 1, 10)
    for y in np.linspace(-1, 1, 10)
]

In [25]:
# Iterate through the points and plot them
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display

def plot(points, ax, x_min=None, x_max=None, y_max=None, y_min=None):
    """Draw ``points`` on ``ax`` together with the reference and transformed basis vectors.

    The axes are cleared first, so this can be called once per animation frame.

    Args:
        points: sequence of Point-like objects providing ``get_plot_coords()``,
            ``color``, ``unit_i``/``unit_j``, ``standard_unit_i``/``standard_unit_j``
            and ``get_i_unit_length()``/``get_j_unit_length()``.
        ax: matplotlib Axes to draw on (passed in so animations can reuse it).
        x_min, x_max, y_max, y_min: data bounds used to size the view. Each bound
            left as None falls back to the corresponding shared bound on ``Point``.
            Both axes get the same limits (the union of the x and y ranges, padded
            by half its width on each side) so the grid is not distorted.
    """
    ax.clear()
    if points:
        # A single scatter call draws the whole set as one artist.
        xs, ys = zip(*(point.get_plot_coords() for point in points))
        ax.scatter(xs, ys, c=[point.color for point in points])
    ax.grid()

    # Fill each missing bound individually so a partially specified set of bounds
    # never reaches min()/max() with None in it.
    if x_max is None:
        x_max = Point.get_x_max()
    if x_min is None:
        x_min = Point.get_x_min()
    if y_max is None:
        y_max = Point.get_y_max()
    if y_min is None:
        y_min = Point.get_y_min()

    ax_min = min(x_min, y_min)
    ax_max = max(x_max, y_max)
    pad = (ax_max - ax_min) / 2
    if pad == 0:
        # Everything collapsed to one value: avoid asking for identical low/high limits.
        pad = 1
    ax.set_xlim(ax_min - pad, ax_max + pad)
    ax.set_ylim(ax_min - pad, ax_max + pad)

    ax.axhline(0, color='black')
    ax.axvline(0, color='black')

    if not points:
        return
    sample_point = points[0]
    # Check both basis vectors: a shear can leave i-hat fixed while moving j-hat.
    transformed = not (np.array_equal(sample_point.unit_i, sample_point.standard_unit_i)
                       and np.array_equal(sample_point.unit_j, sample_point.standard_unit_j))

    # While transformed, the reference basis is drawn scaled by the length of the
    # longer transformed basis vector.
    scaling_value = 1
    if transformed:
        scaling_value = max(sample_point.get_i_unit_length(), sample_point.get_j_unit_length())

    arrow_style = dict(angles='xy', scale_units='xy', scale=1)
    ref_i = sample_point.standard_unit_i * scaling_value
    ref_j = sample_point.standard_unit_j * scaling_value
    ax.quiver(0, 0, ref_i[0], ref_i[1], color='purple', **arrow_style)
    ax.quiver(0, 0, ref_j[0], ref_j[1], color='orange', **arrow_style)

    if transformed:
        ax.quiver(0, 0, sample_point.unit_i[0], sample_point.unit_i[1], color='pink', **arrow_style)
        ax.quiver(0, 0, sample_point.unit_j[0], sample_point.unit_j[1], color='cyan', **arrow_style)

def reset(points):
    """Call ``reset()`` on every point, restoring each to its untransformed state."""
    for pt in points:
        pt.reset()

def transform(points, matrix):
    """Call ``transform(matrix)`` on every point, in order."""
    for pt in points:
        pt.transform(matrix)

def animate_transform(points, matrix, steps):
    """Animate ``points`` being deformed by the 2x2 ``matrix`` and display it inline.

    Frame ``i`` (``i = 0 .. steps``) applies
    ``identity + i * (matrix - identity) / steps`` to the points, where
    ``identity`` is the matrix whose columns are the first point's current basis
    vectors (the 2x2 identity when the points are untransformed). Each frame is
    drawn with ``plot`` and the points are reset afterwards, so they are back in
    their standard state once the animation has been rendered.

    Args:
        points: non-empty list of Point objects (the basis is read from ``points[0]``).
        matrix: 2x2 numpy array, the target transformation.
        steps: number of interpolation steps; the animation has ``steps + 1`` frames.

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        # steps == 0 would divide by zero below and produce a NaN interpolation matrix.
        raise ValueError(f"steps must be >= 1, got {steps}")

    # derive identity from current unit vectors
    identity = np.array([points[0].unit_i, points[0].unit_j]).T
    delta_matrix = (matrix - identity) / steps
    fig, ax = plt.subplots(figsize=(5,5))

    # Plot limits shared by every frame. A linear map sends the bounding rectangle
    # to the parallelogram spanned by the images of ALL FOUR corners; two opposite
    # corners are not enough (e.g. the shear [[3, -3], [0, 0.1]] sends (1, 1) to
    # (0, 0.1) but (1, -1) to (6, -0.1)). The interpolated matrix is linear in i,
    # so the corner images at the first and last frames bound every frame between.
    corners = np.array([
        [Point.get_x_min(), Point.get_x_max(), Point.get_x_max(), Point.get_x_min()],
        [Point.get_y_min(), Point.get_y_min(), Point.get_y_max(), Point.get_y_max()],
    ], dtype=float)
    corner_images = np.hstack([identity.dot(corners), matrix.dot(corners)])
    final_x_min, final_y_min = corner_images.min(axis=1)
    final_x_max, final_y_max = corner_images.max(axis=1)

    def init():
        plot(points, ax, final_x_min, final_x_max, final_y_max, final_y_min)

    def update(i):
        # plot() clears the axes itself before drawing.
        current_matrix = identity + i * delta_matrix
        transform(points, current_matrix)
        plot(points, ax, final_x_min, final_x_max, final_y_max, final_y_min)
        reset(points)

    ani = FuncAnimation(fig, update, frames=steps+1, init_func=init, repeat=False)
    plt.close(fig)  # avoid a duplicate static figure; the animation is shown as HTML below
    display(HTML(ani.to_jshtml()))


# Linear Transformation Tool

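Animate how a $2 \times 2$ matrix deforms the grid of sample points. The matrix is specified by where the basis vectors land: its first column is the image of $\hat{\imath}$ and its second column is the image of $\hat{\jmath}$.

**Reference:** 3Blue1Brown's video on linear transformations (*Linear transformations and matrices*, Essence of Linear Algebra, Chapter 3)

Link: <https://www.youtube.com/watch?v=kYB8IZa5AuE&list=PLZHQObOWTQDPD3MizzM2xVFitgF8hE_ab&index=3>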

In [28]:
import math
# Rows give where i-hat and j-hat land: i-hat -> (sqrt(2)/2, -sqrt(2)/2) and
# j-hat -> (sqrt(2)/2, sqrt(2)/2). Transposing puts them in the matrix columns,
# giving a 45-degree clockwise rotation.
transform_matrix = (math.sqrt(2) / 2) * np.array([[1, -1], [1, 1]]).T
animate_transform(points, transform_matrix, 10)