# Linear Algebra Intuition

### Loading Libraries

In [1]:
# Numerical Computing
import numpy as np

# Data Manipulation
import pandas as pd

# Data Visualization
import plotly.express as px
import plotly.graph_objs as go
import matplotlib.pyplot as plt

# PyTorch
import torch

In [2]:
# Fix NumPy's global random state so any later np.random draws in this notebook are reproducible.
np.random.seed(seed=42)

### Helper Functions

In [3]:
def sample_grid(M=500, x_max=2.0):
    """
    Return an evenly spaced square grid of 2-D points as an (M*M, 2) tensor.

    Both coordinates take the values ``torch.linspace(-x_max, x_max, M)``.
    With "ij" indexing the first column varies slowest and the second
    fastest, i.e. row ``i*M + j`` is ``(axis[i], axis[j])``.

    Parameters
    ----------
    M : int
        Number of samples along each axis.
    x_max : float
        Half-width of the square; points span [-x_max, x_max] on both axes.

    Returns
    -------
    torch.Tensor
        Tensor of shape (M*M, 2) holding every grid point.
    """
    axis = torch.linspace(-x_max, x_max, M)
    # Pass indexing="ij" explicitly: it is the historical default, and
    # omitting it makes torch>=1.10 emit a deprecation UserWarning.
    ii, jj = torch.meshgrid(axis, axis, indexing="ij")
    return torch.stack([ii, jj], dim=-1).reshape(-1, 2)

# Assign each (x, y) position a colour that varies smoothly with position.
def colorizer(x, y):
    """
    Map x-y coordinates to an RGB colour tuple.

    Red decreases as y grows, green decreases as y shrinks, and blue
    increases with x. Every channel is clipped to [0, 1] so the result is
    always a valid Matplotlib RGB colour, even for points far from the
    origin (previously r/g went negative for |y| > 3 and b left [0, 1]
    for x < -4 or x > 12).

    Parameters
    ----------
    x, y : float
        Coordinates of the point to colour.

    Returns
    -------
    tuple
        ``(r, g, b)`` with each component in [0, 1].
    """
    r = max(0, min(1, 1 - y / 3))
    g = max(0, min(1, 1 + y / 3))
    b = max(0, min(1, 1 / 4 + x / 16))
    return (r, g, b)

In [4]:
def plot_vector_points(vector_points, names, title="Vectors", width=600, height=600, font_size=15, axis_range=None):
    """
    Draw each 2-D vector as a segment from the origin, annotated with its coordinates.

    Parameters
    ----------
    vector_points : iterable of (x, y) pairs
        End points of the vectors. It is read exactly once, so generators
        and other one-shot iterators are supported.
    names : iterable of str
        One label per vector, paired with ``vector_points`` via ``zip``
        (the shorter of the two determines how many vectors are drawn).
    title : str
        Figure title, centred horizontally.
    width, height : int
        Figure size in pixels.
    font_size : int
        Font size for axis titles, tick labels and annotations.
    axis_range : sequence of two numbers, optional
        ``[min, max]`` range applied to both axes; defaults to ``[-1, 1]``.

    Returns
    -------
    plotly.graph_objs.Figure
        The figure; the legend is hidden.
    """
    if axis_range is None:
        axis_range = [-1, 1]
    # Materialise the (point, name) pairs once: they are needed for both the
    # traces and the annotations, and a second zip over an exhausted
    # iterator would silently yield nothing.
    vectors = [((x, y), name) for (x, y), name in zip(vector_points, names)]

    layout = go.Layout(
        title=go.layout.Title(text=title, x=0.5),
        xaxis={'title': 'x', 'range': axis_range},
        yaxis={'title': 'y', 'range': axis_range},
    )
    traces = [go.Scatter(x=[0, x], y=[0, y], name=name) for (x, y), name in vectors]
    fig = go.Figure(data=traces, layout=layout)

    axis_font = dict(size=font_size)
    # Use nested `title.font` rather than the `titlefont` alias, which was
    # removed in plotly.py 6 (plotly.js 3); `title.font` exists since plotly 4.
    fig.update_layout(
        autosize=False,
        width=width,
        height=height,
        title=dict(text=title, font={"size": 20}),
        legend_title=None,
        showlegend=False,
        yaxis=dict(title=dict(font=axis_font), tickfont=axis_font),
        xaxis=dict(title=dict(font=axis_font), tickfont=axis_font),
    )
    for (x, y), name in vectors:
        fig.add_annotation(
            x=x,
            y=y,
            text=f"{name}-> [{x:.2f}, {y:.2f}]",
            font=dict(size=font_size),
            arrowhead=2,
            arrowsize=1,
            arrowwidth=2,
            arrowcolor="#636363",
        )
    return fig

def plot_grid(xygrids, titles, colors, figsize=(16,8)):
    """
    Scatter-plot one or more 2-D point sets side by side and show the figure.

    Parameters
    ----------
    xygrids : sequence of array-likes
        Each entry is indexed as ``grid[:, 0]`` (x) and ``grid[:, 1]`` (y);
        presumably an (N, 2) NumPy array or torch tensor — confirm with callers.
    titles : sequence of str
        Subplot titles, paired with ``xygrids`` via ``zip``.
    colors : colour spec accepted by ``Axes.scatter(c=...)``
        Passed unchanged to every subplot's scatter call.
    figsize : tuple of float
        Matplotlib figure size in inches.

    Raises
    ------
    ValueError
        If ``xygrids`` is empty.
    """
    n_plots = len(xygrids)
    if n_plots < 1:
        raise ValueError("plot_grid needs at least one grid to plot")
    # squeeze=False always returns a 2-D array of Axes, so the loop below
    # also works for a single grid (plain subplots(1, 1) returns a bare Axes).
    _, axes = plt.subplots(1, n_plots, figsize=figsize, facecolor="w",
                           sharey=True, sharex=True, squeeze=False)
    for ax, xygrid, title in zip(axes[0], xygrids, titles):
        ax.scatter(xygrid[:, 0], xygrid[:, 1], s=36, c=colors, edgecolor="none")
        ax.grid(True)
        ax.set_title(title)
        # Equal aspect keeps distances and angles visually faithful.
        ax.set_aspect('equal')
    plt.show()