# 2. Matrix Multiplication in PyTorch

Matrix multiplication is essential for neural networks!
PyTorch provides efficient operations for multiplying matrices and tensors.


In [None]:
import torch


## 1. Matrix Multiplication (@ operator)

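Use the @ operator or torch.matmul() for matrix multiplication!
A @ B multiplies matrices A and B. The number of columns of A must equal the number of rows of B: an (n×m) matrix times an (m×p) matrix gives an (n×p) result.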


In [None]:
# Two small 2×2 float matrices to multiply
A = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32)

B = torch.tensor([[5, 6],
                  [7, 8]], dtype=torch.float32)

# Each print emits: label, matrix, its shape, then a blank separator line
print("Matrix A:", A, f"Shape: {A.shape}", "", sep="\n")
print("Matrix B:", B, f"Shape: {B.shape}", "", sep="\n")

# @ is Python's dedicated matrix-multiplication operator
result = A @ B
print("A @ B (matrix multiplication):", result, f"Shape: {result.shape}", "", sep="\n")

# torch.matmul is the function form of the same operation
result2 = torch.matmul(A, B)
print("torch.matmul(A, B):", result2, "(Same result!)", sep="\n")


## 2. Matrix-Vector Multiplication

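Multiply a matrix by a vector to transform it! An (m×n) matrix times a vector of length n gives a vector of length m: each entry is the dot product of one row of the matrix with the vector.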


In [None]:
# Matrix-vector product: M is 2×2 and v has 2 entries, so M @ v has 2 entries
M = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32)

v = torch.tensor([5, 6], dtype=torch.float32)

# Each print emits: label, tensor, its shape, then a blank separator line
print("Matrix M:", M, f"Shape: {M.shape}", "", sep="\n")
print("Vector v:", v, f"Shape: {v.shape}", "", sep="\n")

# Matrix × Vector
result = M @ v
print("M @ v (matrix-vector multiplication):", result, f"Shape: {result.shape}", "", sep="\n")

# How it works: entry i of the result is the dot product of row i of M with v.
# The generator keeps the row index out of the notebook namespace.
print(
    "How it works:",
    *(
        f"Result[{i}] = M[{i}, 0]*v[0] + M[{i}, 1]*v[1] = "
        f"{M[i, 0]}*{v[0]} + {M[i, 1]}*{v[1]} = {result[i]}"
        for i in range(M.shape[0])
    ),
    sep="\n",
)


## 3. Batch Matrix Multiplication

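Process multiple matrices at once using a batch dimension! torch.bmm takes tensors of shape (b, n, m) and (b, m, p), multiplies each pair of matrices, and returns shape (b, n, p).
This is very useful in neural networks, which process data in batches.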


In [None]:
# Batch matrix multiplication
# Seed the global RNG so the "random" matrices (and every printed result)
# are identical each time this cell is run.
torch.manual_seed(0)

# Create batch of matrices: (batch_size, rows, cols)
batch_A = torch.randn(3, 2, 4)  # 3 matrices, each 2×4
batch_B = torch.randn(3, 4, 3)  # 3 matrices, each 4×3

print("Batch A (3 matrices, each 2×4):")
print(batch_A)
print(f"Shape: {batch_A.shape}")
print()

print("Batch B (3 matrices, each 4×3):")
print(batch_B)
print(f"Shape: {batch_B.shape}")
print()

# torch.bmm computes batch_A[i] @ batch_B[i] for each i; the batch sizes (3)
# and the inner sizes (4) must match, giving shape (3, 2, 3)
batch_result = torch.bmm(batch_A, batch_B)
print("Batch matrix multiplication result:")
print(batch_result)
print(f"Shape: {batch_result.shape}")
print("(3 matrices, each 2×3)")
print()

# The @ operator (torch.matmul) also treats the leading dimension as a batch
print(f"Same as batch_A @ batch_B? {torch.allclose(batch_result, batch_A @ batch_B)}")
print()

# Verify: first matrix multiplication
A0 = batch_A[0]
B0 = batch_B[0]
result0 = A0 @ B0

print("First matrix multiplication (manual):")
print(result0)
print()

print("First matrix multiplication (from batch):")
print(batch_result[0])
print("(Should match!)")
print(f"Are they equal? {torch.allclose(result0, batch_result[0])}")


## 4. Element-wise Multiplication

Multiply elements at the same position (this is not matrix multiplication)!
Use * for element-wise multiplication; the two tensors must have the same shape (or shapes that broadcast).


In [None]:
# Element-wise multiplication (different from matrix multiplication!)
# A * B multiplies entries at the same (row, col) position, while A @ B takes
# row-by-column dot products. Both matrices are 2×2, so both products exist.
A = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32)

B = torch.tensor([[5, 6],
                  [7, 8]], dtype=torch.float32)

print("Matrix A:")
print(A)
print()

print("Matrix B:")
print(B)
print()

# Element-wise multiplication
element_wise = A * B
print("A * B (element-wise multiplication):")
print(element_wise)
print()

# Matrix multiplication (for comparison)
matrix_mult = A @ B
print("A @ B (matrix multiplication):")
print(matrix_mult)
print()

print("Notice: Element-wise (*) is different from matrix multiplication (@)!")
print(f"Element-wise: {A[0, 0]} * {B[0, 0]} = {element_wise[0, 0]}")
# Spell out the actual numbers behind matrix_mult[0, 0]
# (row 0 of A dotted with column 0 of B), not just the symbolic formula
print(
    f"Matrix mult: A[0, 0]*B[0, 0] + A[0, 1]*B[1, 0] = "
    f"{A[0, 0]}*{B[0, 0]} + {A[0, 1]}*{B[1, 0]} = {matrix_mult[0, 0]}"
)
