# Matrix Multiplication: Building Intuition

This notebook demonstrates matrix multiplication with PyTorch and builds intuition for how dimensions work when multiplying matrices.

We'll cover:
1. Visualizing Matrices
2. Basic Matrix Multiplication
3. How to interpret Matrix Multiplication as Dot Products
4. Dimension Rules and Compatibility
5. Special Cases like Matrix-Vector Multiplication
6. Real-world Applications

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from typing import Tuple, List

# Notebook-wide plotting defaults: a larger default figure and base font size
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
})

# Two-stop gradient used as the colormap for the matrix heatmaps
colors = [
    (0.8, 0.8, 1),    # low end: light blue
    (0.1, 0.3, 0.8),  # high end: darker blue
]
cmap = LinearSegmentedColormap.from_list("custom_blue", colors, N=100)

## 1. Visualizing Matrices

Let's create some helper functions to visualize matrices.

In [None]:
def visualize_matrix(matrix: torch.Tensor, title: str = "") -> None:
    """
    Visualize a 2D matrix as an annotated heatmap.

    Every cell is labelled with its value (one decimal place). The label colour
    is chosen from the cell's position within the matrix's own value range,
    because that is how ``imshow`` picks the cell colour: cells in the darker
    top part of the range get white text, the rest get black text.

    Args:
        matrix: 2D PyTorch tensor to visualize
        title: Optional title for the plot

    Raises:
        ValueError: If ``matrix`` is not 2-dimensional
    """
    # Convert to numpy for matplotlib
    matrix_np = matrix.detach().cpu().numpy()
    if matrix_np.ndim != 2:
        raise ValueError(f"Expected a 2D matrix, got shape {tuple(matrix_np.shape)}")
    n_rows, n_cols = matrix_np.shape

    plt.figure(figsize=(7, 7))
    image = plt.imshow(matrix_np, cmap=cmap)
    plt.colorbar(shrink=0.8)

    # Add grid lines: a minor grid is only drawn where minor ticks exist, so
    # place minor ticks on the cell boundaries (cell centres sit on integers).
    ax = image.axes
    ax.set_xticks([col - 0.5 for col in range(n_cols + 1)], minor=True)
    ax.set_yticks([row - 0.5 for row in range(n_rows + 1)], minor=True)
    ax.grid(which='minor', color='w', linestyle='-', linewidth=0.5)
    ax.tick_params(which='minor', length=0)  # keep the grid, hide the tick marks

    # Value range of this matrix, used to decide label contrast per cell
    if matrix_np.size:
        lowest = float(matrix_np.min())
        span = float(matrix_np.max()) - lowest
    else:
        lowest = span = 0.0

    # Write each cell's value at its centre
    for i in range(n_rows):
        for j in range(n_cols):
            value = matrix_np[i, j]
            # Position of the value within [min, max]; a constant matrix maps to 0
            relative = (float(value) - lowest) / span if span > 0 else 0.0
            plt.text(j, i, f"{value:.1f}",
                     ha="center", va="center",
                     color="black" if relative < 0.7 else "white")

    # Add dimension annotations
    plt.title(f"{title}\nShape: {matrix_np.shape}")
    plt.xlabel(f"Columns (n={n_cols})")
    plt.ylabel(f"Rows (m={n_rows})")
    plt.tight_layout()
    
def visualize_matrix_multiplication(A: torch.Tensor, B: torch.Tensor) -> None:
    """
    Visualize matrix multiplication A @ B with dimensions.

    Draws A, B and C = A @ B side by side as annotated heatmaps, with the
    operation symbols between them and the shape rule in the figure title.
    Each panel is colour-scaled to its own value range, and the cell labels
    pick black or white text from the cell's position within that range.

    Args:
        A: First matrix (m × n)
        B: Second matrix (n × p)

    Raises:
        ValueError: If either input is not 2-dimensional, or if the inner
            dimensions do not match
    """
    # Check compatibility (2D first, so the shape indexing below is safe)
    if A.dim() != 2 or B.dim() != 2:
        raise ValueError(f"Expected 2D matrices: A is {A.shape}, B is {B.shape}")
    if A.shape[1] != B.shape[0]:
        raise ValueError(f"Incompatible dimensions: A is {A.shape}, B is {B.shape}")

    # Perform the multiplication
    C = A @ B

    # Create figure with 3 subplots
    fig, axs = plt.subplots(1, 3, figsize=(16, 5))

    # Plot matrices
    matrices = [A, B, C]
    titles = [
        f"Matrix A\n{A.shape[0]}×{A.shape[1]}",
        f"Matrix B\n{B.shape[0]}×{B.shape[1]}",
        f"Result C = A @ B\n{C.shape[0]}×{C.shape[1]}"
    ]

    for i, (matrix, title) in enumerate(zip(matrices, titles)):
        matrix_np = matrix.detach().cpu().numpy()
        im = axs[i].imshow(matrix_np, cmap=cmap)
        axs[i].set_title(title)

        # Value range of this panel; imshow colours each panel by its own range
        if matrix_np.size:
            lowest = float(matrix_np.min())
            span = float(matrix_np.max()) - lowest
        else:
            lowest = span = 0.0

        # Add text annotations
        for r in range(matrix_np.shape[0]):
            for c in range(matrix_np.shape[1]):
                value = matrix_np[r, c]
                # Position of the value within [min, max]; constant panels map to 0
                relative = (float(value) - lowest) / span if span > 0 else 0.0
                axs[i].text(c, r, f"{value:.1f}",
                            ha="center", va="center",
                            color="black" if relative < 0.7 else "white")

    # Add a shared colorbar
    # NOTE: `im` is the last image drawn (C), so the bar shows C's scale only;
    # A and B are coloured relative to their own ranges.
    fig.colorbar(im, ax=axs, shrink=0.6)

    # Add the operation text between plots
    plt.figtext(0.31, 0.5, "@", fontsize=24)
    plt.figtext(0.64, 0.5, "=", fontsize=24)

    # Add dimension explanation
    m, n = A.shape
    n_check, p = B.shape
    plt.suptitle(f"Matrix Multiplication: ({m}×{n}) @ ({n_check}×{p}) → ({m}×{p})\n"
                 f"The inner dimensions must match: {n} = {n_check}", fontsize=14)

    plt.tight_layout()
    plt.subplots_adjust(top=0.85)

## 2. Basic Matrix Multiplication Example

Let's start with a basic example of matrix multiplication.

In [None]:
# Two small matrices whose inner dimensions agree: (2×3) @ (3×2)
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])   # 2 rows, 3 columns

B = torch.tensor([[7., 8.],
                  [9., 10.],
                  [11., 12.]])     # 3 rows, 2 columns

# Report both shapes
print(f"Matrix A shape: {A.shape}")
print(f"Matrix B shape: {B.shape}")

In [None]:
# Heatmap of A with its values and shape
visualize_matrix(A, title="Matrix A")

In [None]:
# Heatmap of B with its values and shape
visualize_matrix(B, title="Matrix B")

In [None]:
# Show A, B and the product A @ B side by side
visualize_matrix_multiplication(A=A, B=B)

## 3. Matrix Multiplication as Dot Products

Each element of the result matrix can be computed as a dot product of a row from the first matrix and a column from the second matrix.

In [None]:
def demonstrate_dot_product():
    """Draw a text-only figure that explains A @ B as row-by-column dot products.

    Uses a fixed 2×3 matrix A and 3×2 matrix B. The figure lists the three
    matrices, breaks C[0,0] down into row, column, element-wise products and
    their sum, and then spells out the arithmetic for C[0,0] and C[1,1].
    """
    A = torch.tensor([[1., 2., 3.], [4., 5., 6.]])  # 2×3
    B = torch.tensor([[7., 8.], [9., 10.], [11., 12.]])  # 3×2
    C = A @ B

    # NumPy versions, used to format individual entries
    A_np, B_np, C_np = A.numpy(), B.numpy(), C.numpy()

    plt.figure(figsize=(12, 10))

    # --- small drawing helpers -------------------------------------------
    def heading(x, y, text):
        """Bold section heading."""
        plt.text(x, y, text, fontsize=14, fontweight='bold')

    def caption(y, text):
        """Plain left-aligned line of text."""
        plt.text(0, y, text, fontsize=12)

    def boxed(x, y, text, colour):
        """Text inside a semi-transparent coloured box."""
        plt.text(x, y, text, fontsize=12,
                 bbox=dict(facecolor=colour, alpha=0.5))

    def boxed_grid(values, x_left, colour):
        """Lay out a 2D array as boxed entries, top row at y=0.8."""
        for i, row in enumerate(values):
            for j, value in enumerate(row):
                boxed(x_left + j*0.1, 0.8 - i*0.1, f"{value:.0f}", colour)

    def boxed_row(values, y, colour):
        """Lay out a 1D sequence as boxed entries on one line."""
        for k, value in enumerate(values):
            boxed(0.2 + k*0.1, y, f"{value:.0f}", colour)

    def explain_entry(r, c, ys):
        """Write the heading and three-line expansion of C[r,c] at heights ys."""
        inner = range(A.shape[1])
        symbolic = " + ".join(f"A[{r},{k}]×B[{k},{c}]" for k in inner)
        substituted = " + ".join(f"{A_np[r,k]}×{B_np[k,c]}" for k in inner)
        multiplied = " + ".join(f"{A_np[r,k]*B_np[k,c]}" for k in inner)
        heading(0, ys[0], f"Calculating C[{r},{c}]:")
        caption(ys[1], f"C[{r},{c}] = {symbolic}")
        caption(ys[2], f"C[{r},{c}] = {substituted}")
        caption(ys[3], f"C[{r},{c}] = {multiplied} = {C_np[r,c]}")

    # --- the three matrices ----------------------------------------------
    heading(0, 0.9, "Matrix A (2×3):")
    heading(0.5, 0.9, "Matrix B (3×2):")
    heading(0.8, 0.9, "Result C (2×2):")

    boxed_grid(A_np, 0, 'lightblue')
    boxed_grid(B_np, 0.5, 'lightgreen')
    boxed_grid(C_np, 0.8, 'coral')

    # --- step-by-step view of C[0,0] ---------------------------------------
    heading(0, 0.6, "Visual Demonstration of C[0,0] Calculation:")

    first_row = A_np[0]
    first_col = B_np[:, 0]

    caption(0.55, "Row 0 from A:")
    boxed_row(first_row, 0.55, 'lightblue')

    caption(0.5, "Column 0 from B:")
    boxed_row(first_col, 0.5, 'lightgreen')

    caption(0.45, "Element-wise multiplication:")
    products = [a * b for a, b in zip(first_row, first_col)]
    boxed_row(products, 0.45, 'yellow')

    caption(0.4, "Sum of products:")
    boxed(0.2, 0.4, f"{sum(products):.0f} = C[0,0]", 'coral')

    # --- written-out arithmetic for two entries ----------------------------
    explain_entry(0, 0, (0.3, 0.25, 0.2, 0.15))
    explain_entry(1, 1, (0.1, 0.05, 0.0, -0.05))

    # --- takeaway ------------------------------------------------------------
    plt.text(0, -0.15, "Intuition: Each element C[i,j] is the dot product of row i from A and column j from B",
             fontsize=14, fontweight='bold', bbox=dict(facecolor='lightgray', alpha=0.2))

    # Only text is drawn, so hide the axes frame and ticks
    plt.axis('off')
    plt.tight_layout()

# Run the demonstration: builds the figure explaining A @ B as row-by-column dot products
demonstrate_dot_product()

## 4. Matrix Dimension Rules and Compatibility

Let's explore which matrices can be multiplied together and understand the dimension rules.

In [None]:
# State the shape rule for A @ B, then list shape pairs that do and do not satisfy it
print(
    "Matrix Multiplication Dimension Rules\n",
    "For multiplication A @ B to be valid:",
    "- A must be of shape (m × n)",
    "- B must be of shape (n × p)",
    "- The inner dimensions (n) must match",
    "- The result C will be of shape (m × p)",
    "\nExamples of compatible dimensions:",
    sep="\n",
)

# (shape of A, shape of B) pairs whose inner dimensions agree
examples = [
    ((2, 3), (3, 4)),
    ((5, 2), (2, 3)),
    ((1, 4), (4, 10)),
    ((10, 7), (7, 1)),
]

for left_shape, right_shape in examples:
    rows, inner = left_shape
    inner_check, cols = right_shape
    print(f"- ({rows} × {inner}) @ ({inner_check} × {cols}) → ({rows} × {cols})")

print("\nExamples of incompatible dimensions:")

# (shape of A, shape of B) pairs whose inner dimensions differ
incompatible = [
    ((2, 3), (4, 5)),  # 3 ≠ 4
    ((5, 2), (3, 4)),  # 2 ≠ 3
]

for left_shape, right_shape in incompatible:
    rows, inner = left_shape
    inner_check, cols = right_shape
    print(f"- ({rows} × {inner}) @ ({inner_check} × {cols}) → Error! Inner dimensions don't match: {inner} ≠ {inner_check}")

In [None]:
# Compatible shapes: a random (2×3) times a random (3×4) gives a (2×4) result
print("Demonstrating a compatible example:")
A, B = torch.rand(2, 3), torch.rand(3, 4)
C = A @ B

print(f"A shape: {A.shape}")
print(f"B shape: {B.shape}")
print(f"C = A @ B shape: {C.shape}")

# Show the operands and the product side by side
visualize_matrix_multiplication(A, B)

In [None]:
# Incompatible shapes: (2×3) @ (4×5) has inner dimensions 3 and 4
print("Demonstrating an incompatible example:")
A = torch.rand(2, 3)
D = torch.rand(4, 5)
print(f"A shape: {A.shape}")
print(f"D shape: {D.shape}")

# Attempt the product and report the RuntimeError raised for the shape mismatch
try:
    E = A @ D
except RuntimeError as err:
    print(f"Error: {err}")
else:
    print(f"E = A @ D shape: {E.shape}")

## 5. Special Case: Matrix-Vector Multiplication

A common special case is multiplying a matrix by a column vector: an (m × n) matrix times an (n × 1) vector gives an (m × 1) result.

In [None]:
# Matrix times column vector: (2×3) @ (3×1) -> (2×1)
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])  # 2×3

# The vector is written as a 2D tensor with a single column
v = torch.tensor([[7.],
                  [8.],
                  [9.]])  # 3×1 (column vector)

print(f"Matrix A shape: {A.shape}")
print(f"Vector v shape: {v.shape}")

# Show A, v and the product A @ v side by side
visualize_matrix_multiplication(A, v)

### Alternative Vector Notation

In PyTorch, we can also use a 1D tensor as the vector. In `A @ v`, a 1D `v` on the right-hand side is treated as a column vector for the multiplication, and the result comes back as a 1D tensor:

In [None]:
# Same product as before, but with the vector stored as a 1D tensor
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])  # 2×3

v_1d = torch.tensor([7., 8., 9.])  # shape (3,): no explicit column dimension

print(f"Matrix A shape: {A.shape}")
print(f"Vector v_1d shape: {v_1d.shape}")

# 2D @ 1D is a matrix-vector product; the result is 1D as well
result = A @ v_1d

print(f"Result shape: {result.shape}")
print(f"Result: {result}")

## 6. Real-world Applications of Matrix Multiplication

Matrix multiplication is used in many applications. Let's explore a few examples:

### 6.1 Image Transformation

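Images can be represented as matrices, and geometric transformations like rotation can be performed using matrix multiplication. To keep the example small, we rotate the corner points of a square rather than a full image.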

In [None]:
def create_rotation_matrix(angle_degrees, dtype=torch.float32):
    """
    Build the 2×2 rotation matrix for a given angle.

    Multiplying this matrix by a 2×k matrix of column vectors rotates each
    column counter-clockwise about the origin by ``angle_degrees``.

    Args:
        angle_degrees: Rotation angle in degrees (positive is counter-clockwise)
        dtype: dtype of the returned tensor. It is set explicitly rather than
            inferred from the NumPy scalars, because ``@`` requires both
            operands to share a dtype and the default tensor dtype is float32.

    Returns:
        A 2×2 tensor ``[[cos θ, -sin θ], [sin θ, cos θ]]``
    """
    angle_radians = np.radians(angle_degrees)
    cos_theta = np.cos(angle_radians)
    sin_theta = np.sin(angle_radians)

    # 2x2 rotation matrix
    rotation_matrix = torch.tensor([
        [cos_theta, -sin_theta],
        [sin_theta, cos_theta]
    ], dtype=dtype)

    return rotation_matrix

# Corners of a square, listed as rows and closed by repeating the first corner
points = torch.tensor([
    [-1, -1],  # bottom-left
    [1, -1],   # bottom-right
    [1, 1],    # top-right
    [-1, 1],   # top-left
    [-1, -1]   # back to the start so the outline closes
], dtype=torch.float32).T  # transpose -> 2×5, one point per column

# Rotate every column at once with a single matrix product
rotation = create_rotation_matrix(45)
rotated_points = rotation @ points

# Draw both outlines on one set of axes (row 0 holds x values, row 1 holds y values)
fig, ax = plt.subplots(figsize=(8, 8))
ax.plot(points[0], points[1], 'b-o', label='Original')
ax.plot(rotated_points[0], rotated_points[1], 'r-o', label='Rotated 45°')
ax.grid(True)
ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
ax.axvline(x=0, color='k', linestyle='--', alpha=0.3)
ax.axis('equal')  # equal scaling so the square is not distorted
ax.legend()
ax.set_title('2D Rotation Using Matrix Multiplication')
ax.set_xlabel('X')
ax.set_ylabel('Y')

# Show the numbers behind the plot
print("Rotation Matrix (45 degrees):")
print(rotation)
print("\nOriginal Points (each column is a point):")
print(points)
print("\nRotated Points:")
print(rotated_points)

### 6.2 Simple Neural Network Layer

In neural networks, each fully connected layer's operation can be represented as a matrix multiplication plus a bias term, followed by a non-linear activation function.

In [None]:
def sigmoid(x):
    """Apply the logistic function 1 / (1 + e^(-x)) element-wise to a tensor."""
    denominator = 1 + torch.exp(-x)
    return 1 / denominator

# Sizes for a tiny two-layer network
batch_size, input_dim, hidden_dim, output_dim = 3, 4, 5, 2

# Random input batch: one row per example, one column per feature
X = torch.rand(batch_size, input_dim)
print(f"Input shape: {X.shape}")

# Random parameters: a weight matrix and a bias vector for each layer
W1, b1 = torch.rand(input_dim, hidden_dim), torch.rand(hidden_dim)
W2, b2 = torch.rand(hidden_dim, output_dim), torch.rand(output_dim)

print(f"W1 shape: {W1.shape}")
print(f"W2 shape: {W2.shape}")

# Forward pass, layer 1: (batch×input) @ (input×hidden), bias added to every row
hidden = X @ W1 + b1
hidden_activated = sigmoid(hidden)
print(f"Hidden layer output shape: {hidden_activated.shape}")

# Forward pass, layer 2: (batch×hidden) @ (hidden×output), bias added to every row
output = hidden_activated @ W2 + b2
output_activated = sigmoid(output)
print(f"Output shape: {output_activated.shape}")

# Heatmaps of the two matrix products (bias terms not included), side by side
plt.figure(figsize=(10, 6))
panels = [
    (X @ W1, f"First Layer Pre-Activation\nX @ W1: {batch_size}×{hidden_dim}"),
    (hidden_activated @ W2, f"Second Layer Pre-Activation\nH @ W2: {batch_size}×{output_dim}"),
]
for position, (product, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(product, cmap='viridis')
    plt.title(panel_title)
    plt.colorbar()

# Lay out the panels, add the overall title above them, then lay out again
plt.tight_layout()
plt.suptitle("Neural Network Layer Operations as Matrix Multiplications", y=1.05)
plt.tight_layout()

## 7. Conclusion

Matrix multiplication is one of the most fundamental operations in linear algebra. It forms the basis for many algorithms in machine learning, computer graphics, and numerical computing.

Key points to remember:

1. For A @ B to be valid, the inner dimensions must match
2. If A is (m×n) and B is (n×p), the result will be (m×p)
3. Each element in the result is a dot product between a row of A and a column of B
4. Matrix multiplication is associative: (A @ B) @ C = A @ (B @ C)
5. Matrix multiplication is NOT commutative: A @ B ≠ B @ A (in general)

Understanding matrix multiplication is crucial for higher-dimensional tensor operations used in deep learning and computer vision.