In [1]:
import torch
from functools import partial

show = partial(print, sep="\n", end="\n\n")

## Linear Transformations are Function Application (conceptually)

Vector-Matrix multiplication and Matrix-Matrix multiplication are both linear transformations. This notebook explores linear transformations through the lens of function application and list indexing, two concepts programmers already have intuition for, with the goal of making the topic more approachable and concrete.

This perspective matters because linear transformations underpin neural networks (a fully connected layer, for example, is a matrix multiplication plus a bias), so understanding them as "functions" can demystify how models process data.

### Vector-Matrix Multiplication is Function Application

(N,) @ (N, K) => (K,)

f @ X == f(X)

Where:
- f is a row vector (shape (N,))
- X is a matrix (shape (N, K))

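f acts on each column of X: entry k of the result is the dot product of f with column k of X, so the result has one entry per column (shape (K,)).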
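To make "f acts on each column of X" concrete, here is a plain-Python sketch (no torch) that spells the rule out with lists. The helper names `columns`, `dot`, `vecmat`, and `as_function` are hypothetical, chosen for illustration, and the matrix values are the ones the seeded `X` in the next cell happens to contain.

```python
# Plain-Python sketch of f @ X (no torch). Matrices are row-major lists of rows.
# The helper names are hypothetical, chosen for illustration.

def columns(X):
    """Return the columns of X, each as a list."""
    return [list(col) for col in zip(*X)]


def dot(a, b):
    """Dot product of two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return sum(x * y for x, y in zip(a, b))


def vecmat(f, X):
    """(N,) @ (N, K) => (K,): entry k is dot(f, column k of X)."""
    return [dot(f, col) for col in columns(X)]


def as_function(f):
    """Wrap the vector f as an ordinary Python function of a matrix."""
    return lambda X: vecmat(f, X)


# Same values as the seeded X in the next cell.
X = [[4, 9, 3],
     [0, 3, 9],
     [7, 3, 7]]

first_row = as_function([1, 0, 0])
print(first_row(X))  # [4, 9, 3]
```

With this wrapper, `f @ X == f(X)` reads literally: `as_function(f)(X)` returns the same list as `vecmat(f, X)`.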

In [2]:
torch.manual_seed(0)

X = torch.randint(0, 10, (3, 3))

In [3]:
"""
What if I want to select the first row of a matrix?

That can be expressed as:
- Selecting the first row
- Selecting the first element of each column

Vector-Matrix multiplication frames the problem in the second way.

The following are equivalent:
- "Return the first row of the matrix"
- For each column, return its 0th element
    - More accurately: (1 * col[0]) + (0 * col[1]) + (0 * col[2])
"""

f1 = torch.tensor([1, 0, 0])

result = f1 @ X

show("X:", X)
show("f1 @ X:", result)

X:
tensor([[4, 9, 3],
        [0, 3, 9],
        [7, 3, 7]])

f1 @ X:
tensor([4, 9, 3])
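Because `f1` is one-hot (a single 1, zeros everywhere else), `f1 @ X` is list indexing in disguise. The plain-Python sketch below checks that multiplying by a one-hot vector for index `i` returns `X[i]` for every row. `one_hot` and `vecmat` are hypothetical helpers defined here (this `one_hot` is unrelated to `torch.nn.functional.one_hot`).

```python
# Plain-Python sketch: a one-hot vector turns f @ X into list indexing.
# one_hot and vecmat are hypothetical helpers, not torch APIs.

def one_hot(i, n):
    """Length-n vector with a 1 at index i and 0 everywhere else."""
    return [1 if j == i else 0 for j in range(n)]


def vecmat(f, X):
    """f @ X for a row-major list of rows: one dot product per column."""
    return [sum(w * v for w, v in zip(f, col)) for col in zip(*X)]


X = [[4, 9, 3],
     [0, 3, 9],
     [7, 3, 7]]

# For every row index i, multiplying by one_hot(i) returns exactly X[i].
for i in range(len(X)):
    assert vecmat(one_hot(i, len(X)), X) == X[i]

print(vecmat(one_hot(0, 3), X))  # [4, 9, 3]
```

In this sense `f1` is the index `0` written as a vector, and `f2` in the next cell is the index `1`.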



In [4]:
"""
What if I want to select the second row of a matrix?

This can be expressed as:
- Selecting the second row directly
- Selecting the second element from each column

Vector-Matrix multiplication frames it as the second approach.

The following are equivalent:
- "Return the second row of the matrix"
- For each column, return its 1st element
    - More accurately: (0 * col[0]) + (1 * col[1]) + (0 * col[2])
"""

f2 = torch.tensor([0, 1, 0])

result = f2 @ X

show("X:", X)
show("f2 @ X:", result)

X:
tensor([[4, 9, 3],
        [0, 3, 9],
        [7, 3, 7]])

f2 @ X:
tensor([0, 3, 9])
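The "linear" in linear transformation is a concrete, checkable property of these functions: applying `f2` to a sum of matrices gives the sum of the individual results, and scaling the input scales the output. The same holds for the vector itself, so adding two selectors adds the rows they select. A plain-Python sketch; `vecmat`, `mat_add`, `vec_add`, and the second matrix `Y` are hypothetical, made up for illustration.

```python
# Plain-Python sketch of linearity. vecmat, mat_add, vec_add and Y are hypothetical.

def vecmat(f, X):
    """f @ X: dot product of f with each column of X."""
    return [sum(w * v for w, v in zip(f, col)) for col in zip(*X)]


def mat_add(X, Y):
    """Element-wise sum of two matrices with the same shape."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def vec_add(a, b):
    """Element-wise sum of two equal-length vectors."""
    return [x + y for x, y in zip(a, b)]


X = [[4, 9, 3],
     [0, 3, 9],
     [7, 3, 7]]
# Hypothetical second input with the same shape as X.
Y = [[1, 0, 2],
     [5, 1, 0],
     [0, 2, 1]]
f2 = [0, 1, 0]

# Additivity in the input: f2(X + Y) == f2(X) + f2(Y)
assert vecmat(f2, mat_add(X, Y)) == vec_add(vecmat(f2, X), vecmat(f2, Y))

# Homogeneity: f2(3 * X) == 3 * f2(X)
assert vecmat(f2, [[3 * v for v in row] for row in X]) == [3 * v for v in vecmat(f2, X)]

# Linear in the vector too: [1, 1, 0] selects row 0 plus row 1.
assert vecmat(vec_add([1, 0, 0], f2), X) == vec_add(X[0], X[1])
print(vecmat([1, 1, 0], X))  # [4, 12, 12]
```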



In [5]:
"""
What if I want to double each element and sum each column?

Vector-Matrix multiplication frames this as weighting each element of a column by the matching entry of f3 and adding the products (AKA the dot product of f3 with that column).

The following are equivalent:
- "Double each element and sum each column"
- For each column, return (2 * col[0]) + (2 * col[1]) + (2 * col[2])
"""

f3 = torch.tensor([2, 2, 2])

result = f3 @ X

show("X:", X)
show("f3 @ X:", result)

X:
tensor([[4, 9, 3],
        [0, 3, 9],
        [7, 3, 7]])

f3 @ X:
tensor([22, 30, 38])
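The descriptions of `f3 @ X` above (a dot product per column, or "double every element, then sum each column") can be checked against each other in plain Python, using the same `X` values. A third form, "sum each column, then double", also agrees, but only because every weight in `f3` is the same.

```python
# Plain-Python check that three descriptions of f3 @ X agree.
X = [[4, 9, 3],
     [0, 3, 9],
     [7, 3, 7]]
f3 = [2, 2, 2]

# Vector-matrix view: dot product of f3 with each column.
via_dot = [sum(w * v for w, v in zip(f3, col)) for col in zip(*X)]

# Direct description: double every element, then sum each column.
doubled = [[2 * v for v in row] for row in X]
via_double_then_sum = [sum(col) for col in zip(*doubled)]

# Valid only because all weights in f3 are equal: 2 * (column sum).
via_sum_then_double = [2 * sum(col) for col in zip(*X)]

print(via_dot)  # [22, 30, 38]
```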



### Matrix-Matrix Multiplication is Batch Function Application

(M, N) @ (N, K) => (M, K)

F @ X == F(X)

Where:
- F is a matrix (shape (M, N))
- X is a matrix (shape (N, K))

Each row of F is a vector (function!) that acts on the columns of X, so row i of F @ X equals F[i] @ X.
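The batch view translates directly into code: matrix-matrix multiplication becomes a list comprehension that applies each row of F, as a function, to X. Below is a plain-Python sketch with hypothetical helpers `vecmat` and `matmul`. It uses a non-square, made-up F to show the (M, N) @ (N, K) => (M, K) shape rule.

```python
# Plain-Python sketch of F @ X as batch function application.
# vecmat, matmul and this F are hypothetical, chosen for illustration.

def vecmat(f, X):
    """(N,) @ (N, K) => (K,): dot product of f with each column of X."""
    if len(f) != len(X):
        raise ValueError("f needs one weight per row of X")
    return [sum(w * v for w, v in zip(f, col)) for col in zip(*X)]


def matmul(F, X):
    """(M, N) @ (N, K) => (M, K): apply each row of F to X and stack the results."""
    return [vecmat(f, X) for f in F]


X = [[4, 9, 3],
     [0, 3, 9],
     [7, 3, 7]]

# M = 2 functions, each expecting N = 3 inputs.
F = [[1, 1, 0],   # sum of rows 0 and 1
     [0, 0, 1]]   # select row 2

result = matmul(F, X)
print(result)                       # [[4, 12, 12], [7, 3, 7]]
print(len(result), len(result[0]))  # 2 3
```

M comes from the number of functions (rows of F) and K from the number of columns of X. N only has to match between the two and does not appear in the output shape.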

In [6]:
"""
What if I want to apply multiple operations at once?

For demonstration, let's apply all the operations from the Vector-Matrix section simultaneously. This can be done easily by stacking row vectors into a matrix.

Each row of the first matrix applies its transformation to all columns of the second.
"""

F = torch.stack(
    [
        f1,
        f2,
        f3,
    ]
)

result = F @ X

show("X:", X)
show("F:", F)
show("F @ X:", result)

X:
tensor([[4, 9, 3],
        [0, 3, 9],
        [7, 3, 7]])

F:
tensor([[1, 0, 0],
        [0, 1, 0],
        [2, 2, 2]])

F @ X:
tensor([[ 4,  9,  3],
        [ 0,  3,  9],
        [22, 30, 38]])

