# Matrix Multiplication: Unleashing the Power of Tensors! ⚡

> "Behold! The sacred art of matrix multiplication - where dimensions dance and vectors bend to my will!" — **Professor Victor py Torchenstein**

## The Attention Formula (Preview of Things to Come)

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $Q$ is the Query matrix
- $K$ is the Key matrix  
- $V$ is the Value matrix
- $d_k$ is the dimension of the key vectors
- $\text{softmax}$ normalizes the attention weights
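
As a preview, the formula above can be sketched in a few lines of PyTorch. This is a minimal, single-head illustration without masking or batching; the function name `scaled_dot_product_attention` and the tensor shapes are assumptions chosen for the example, not part of the original notebook.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V for 2D inputs (hypothetical helper)."""
    d_k = K.shape[-1]                                  # dimension of the key vectors
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (queries, keys) similarity scores
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(5, 8)    # 5 queries, d_k = 8
K = torch.randn(6, 8)    # 6 keys,    d_k = 8
V = torch.randn(6, 16)   # 6 values,  16 features each

output, weights = scaled_dot_product_attention(Q, K, V)
print(output.shape)          # torch.Size([5, 16])
print(weights.sum(dim=-1))   # every row of attention weights sums to 1
```

Note that the whole computation is two matrix multiplications with a softmax in between, which is why the rest of this section focuses on `matmul`.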

## Basic Matrix Operations

Let's start with the fundamentals before we conquer attention mechanisms! 

Element-wise multiplication (both matrices must have the same shape): $C_{ij} = A_{ij} \times B_{ij}$

Matrix multiplication (the inner dimensions must match): $C_{ij} = \sum_{k} A_{ik} \times B_{kj}$
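
To make the difference concrete, here is a small sketch on two hand-picked 2 x 2 matrices (the values are chosen only for illustration). In PyTorch, `*` is element-wise and `@` is the matrix product.

```python
import torch

A = torch.tensor([[1., 2.],
                  [3., 4.]])
B = torch.tensor([[5., 6.],
                  [7., 8.]])

elementwise = A * B      # C_ij = A_ij * B_ij
matrix_product = A @ B   # C_ij = sum_k A_ik * B_kj

print(elementwise)       # values: [[ 5, 12], [21, 32]]
print(matrix_product)    # values: [[19, 22], [43, 50]]
```

For example, the top-left entry of the matrix product is $1 \times 5 + 2 \times 7 = 19$, while the element-wise result there is simply $1 \times 5 = 5$.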

In [2]:
import torch

# Create some matrices for experimentation
A = torch.randn(3, 4)
B = torch.randn(4, 2)

print("Matrix A shape:", A.shape)
print("Matrix B shape:", B.shape)

# Matrix multiplication
C = torch.matmul(A, B)
print("Result C shape:", C.shape)
print("\nMwahahaha! The matrices have been multiplied!")

Matrix A shape: torch.Size([3, 4])
Matrix B shape: torch.Size([4, 2])
Result C shape: torch.Size([3, 2])

Mwahahaha! The matrices have been multiplied!


## PyTorch Matrix Multiplication Methods

Professor Torchenstein's arsenal includes multiple ways to multiply matrices:

1. **`torch.matmul()`** - The general matrix multiplication function (broadcasts over batch dimensions)
2. **`@` operator** - Pythonic matrix multiplication (equivalent to `torch.matmul`)
3. **`torch.mm()`** - For 2D matrices only (no broadcasting)
4. **`torch.bmm()`** - Batch matrix multiplication for 3D tensors with the same batch size
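
The four methods can be compared side by side. The sketch below uses arbitrary shapes picked for illustration and checks that the 2D variants agree with each other.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4)
B = torch.randn(4, 2)

via_matmul = torch.matmul(A, B)   # general entry point
via_operator = A @ B              # same operation, operator syntax
via_mm = torch.mm(A, B)           # strictly 2D x 2D

# torch.bmm expects two 3D tensors: (batch, n, m) and (batch, m, p)
batch_A = torch.randn(10, 3, 4)
batch_B = torch.randn(10, 4, 2)
via_bmm = torch.bmm(batch_A, batch_B)

# torch.matmul broadcasts: a (10, 3, 4) batch times a single (4, 2) matrix
broadcasted = torch.matmul(batch_A, B)

print(torch.allclose(via_matmul, via_operator))  # True
print(torch.allclose(via_matmul, via_mm))        # True
print(via_bmm.shape)       # torch.Size([10, 3, 2])
print(broadcasted.shape)   # torch.Size([10, 3, 2])
```

In practice `@` or `torch.matmul` covers most cases; `torch.mm` and `torch.bmm` are useful when you want a shape mismatch to fail loudly instead of being silently broadcast.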

### Mathematical Foundations

For matrices $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$, the product $C \in \mathbb{R}^{m \times p}$ is:

$$C = AB \quad \text{where} \quad C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$$

This operation is fundamental to:
- Linear transformations
- Neural network forward passes  
- Attention mechanisms in Transformers
- And much more! 🧠⚡
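
The summation in the formula above maps directly onto three nested loops. The following sketch (with a hypothetical helper named `naive_matmul`) is far too slow for real use, but it shows exactly what `torch.matmul` computes.

```python
import torch

def naive_matmul(A, B):
    """Triple-loop reference for C_ij = sum_k A_ik * B_kj (illustration only)."""
    m, n = A.shape
    n_b, p = B.shape
    assert n == n_b, "inner dimensions must match"
    C = torch.zeros(m, p, dtype=A.dtype)
    for i in range(m):            # rows of A
        for j in range(p):        # columns of B
            for k in range(n):    # shared inner dimension
                C[i, j] += A[i, k] * B[k, j]
    return C

torch.manual_seed(0)
A = torch.randn(3, 4)
B = torch.randn(4, 2)

print(naive_matmul(A, B).shape)                               # torch.Size([3, 2])
print(torch.allclose(naive_matmul(A, B), A @ B, atol=1e-5))   # True
```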

## PyTorch Doesn't Support Integer Matrix Multiplication on CUDA

In [2]:
import time
import torch

# 1. Our "Model" and "Input": one random 1024 x 1024 matrix, cast to each float dtype
fp32_matrix = torch.randn(2**10, 2**10)
fp16_matrix = fp32_matrix.to(torch.float16)
bf16_matrix = fp32_matrix.to(torch.bfloat16)


# 2. The Symmetric Quantization Spell (zero_point = 0)
def quantize_symmetric(tensor, dtype=torch.int8):
    """
    Quantize a tensor to the desired integer dtype.
    The scale maps the tensor's absolute max value to the dtype's max value.
    Returns the quantized tensor and the scale that was used.
    """

    assert dtype in [torch.int8, torch.int16, torch.int32]

    int_max = torch.iinfo(dtype).max
    int_min = torch.iinfo(dtype).min
    # Find the scale: map the absolute max of the tensor to the quantization range
    scale = tensor.abs().max() / int_max
    # Quantize in float64: float32 cannot represent the int32 limits exactly,
    # so clamping in float32 could overflow when casting to int32
    quantized_tensor = (tensor.double() / scale).round().clamp(int_min, int_max).to(dtype)
    return quantized_tensor, scale

# 3. Quantize the same matrix to each integer dtype, keeping one scale per dtype
int8_quantized_matrix, int8_scale = quantize_symmetric(fp32_matrix, dtype=torch.int8)
int16_quantized_matrix, int16_scale = quantize_symmetric(fp32_matrix, dtype=torch.int16)
int32_quantized_matrix, int32_scale = quantize_symmetric(fp32_matrix, dtype=torch.int32)

# --- GPU Operations ---
if torch.cuda.is_available():
    gpu_device = torch.device("cuda")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:
    gpu_device = None
    print("GPU not available.")
if gpu_device:
    print("\n--- GPU Timings ---")

    # Move matrices to GPU
    fp32_matrix_gpu = fp32_matrix.to(gpu_device)
    fp16_matrix_gpu = fp16_matrix.to(gpu_device)
    bf16_matrix_gpu = bf16_matrix.to(gpu_device)
    int8_quantized_matrix_gpu = int8_quantized_matrix.to(gpu_device)
    int16_quantized_matrix_gpu = int16_quantized_matrix.to(gpu_device)
    int32_quantized_matrix_gpu = int32_quantized_matrix.to(gpu_device)
    
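    # Correct timing for GPU operations requires synchronization.
    # Note: there is no warm-up here, so the first timed calls also include
    # one-time CUDA and cuBLAS initialization and are not representative.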
    
    # FP32 on GPU
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fp32_output_gpu = torch.matmul(fp32_matrix_gpu, fp32_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    fp32_time_gpu = start.elapsed_time(end) / 1000  # Convert ms to s
    print(f"Float32 matmul on GPU took: {fp32_time_gpu:.6f} seconds")

    # FP16 on GPU
    start.record()
    fp16_output_gpu = torch.matmul(fp16_matrix_gpu, fp16_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    fp16_time_gpu = start.elapsed_time(end) / 1000
    print(f"Float16 matmul on GPU took: {fp16_time_gpu:.6f} seconds")

    # BF16 on GPU
    # Note: BF16 performance depends on GPU architecture (Ampere and newer)
    try:
        start.record()
        bf16_output_gpu = torch.matmul(bf16_matrix_gpu, bf16_matrix_gpu)
        end.record()
        torch.cuda.synchronize()
        bf16_time_gpu = start.elapsed_time(end) / 1000
        print(f"BFloat16 matmul on GPU took: {bf16_time_gpu:.6f} seconds")
    except Exception as e:
        print(f"BFloat16 matmul on GPU failed: {e}")

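    # INT32 on GPU
    # In the run shown below, this call raises a RuntimeError because there is
    # no CUDA kernel for integer matmul, so the INT16 and INT8 sections never run.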
    start.record()
    int32_output_gpu = torch.mm(int32_quantized_matrix_gpu, int32_quantized_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    int32_time_gpu = start.elapsed_time(end) / 1000
    print(f"Int32 matmul on GPU took: {int32_time_gpu:.6f} seconds")

    # INT16 on GPU
    start.record()
    int16_output_gpu = torch.matmul(int16_quantized_matrix_gpu, int16_quantized_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    int16_time_gpu = start.elapsed_time(end) / 1000
    print(f"Int16 matmul on GPU took: {int16_time_gpu:.6f} seconds")

    # INT8 on GPU
    start.record()
    int8_output_gpu = torch.matmul(int8_quantized_matrix_gpu, int8_quantized_matrix_gpu)
    end.record()
    torch.cuda.synchronize()
    int8_time_gpu = start.elapsed_time(end) / 1000
    print(f"Int8 matmul on GPU took: {int8_time_gpu:.6f} seconds")

GPU Name: NVIDIA GeForce RTX 3080 Laptop GPU

--- GPU Timings ---
Float32 matmul on GPU took: 0.107393 seconds
Float16 matmul on GPU took: 0.234943 seconds
BFloat16 matmul on GPU took: 0.116011 seconds
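
The three floating-point timings above each come from a single call with no warm-up, which likely explains why Float16 appears slower than Float32: the earliest calls also pay for one-time initialization. A more robust measurement could look like the sketch below. The helper name `time_matmul` and its `warmup` and `repeats` parameters are assumptions made for this example; it falls back to the CPU when no GPU is present.

```python
import time
import torch

def time_matmul(a, b, warmup=3, repeats=10):
    """Return mean wall-clock seconds per matmul after warm-up (hypothetical helper)."""
    on_cuda = a.is_cuda
    for _ in range(warmup):          # untimed calls absorb one-time setup costs
        torch.matmul(a, b)
    if on_cuda:
        torch.cuda.synchronize()     # make sure warm-up kernels have finished
    start = time.perf_counter()
    for _ in range(repeats):
        torch.matmul(a, b)
    if on_cuda:
        torch.cuda.synchronize()     # wait for all timed kernels before stopping the clock
    return (time.perf_counter() - start) / repeats

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(256, 256, device=device)
print(f"{device} float32: {time_matmul(x, x):.6f} s per matmul")
```

Averaging over several repeats also smooths out run-to-run noise, which matters when a single 1024 x 1024 multiply takes well under a millisecond on a modern GPU.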


RuntimeError: "addmm_cuda" not implemented for 'Int'
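
The `RuntimeError` above comes from the CUDA backend. On the CPU, integer matmul does work, so the quantized arithmetic can still be checked there. The sketch below is a simplified int8-only variant of the quantization spell (the helper name `quantize_symmetric_int8`, the matrix sizes, and the clamp to ±127 are assumptions for this example). It quantizes input and weights independently, multiplies in int32, and rescales the result by both scales.

```python
import torch

def quantize_symmetric_int8(tensor):
    """Symmetric int8 quantization with zero_point = 0 (hypothetical helper)."""
    scale = tensor.abs().max() / 127
    q = (tensor / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

torch.manual_seed(0)
x = torch.randn(64, 64)   # stand-in for the input
w = torch.randn(64, 64)   # stand-in for the weights

x_q, x_scale = quantize_symmetric_int8(x)
w_q, w_scale = quantize_symmetric_int8(w)

# Widen to int32 so sums of int8 products cannot overflow; this runs on the CPU
acc = torch.matmul(x_q.to(torch.int32), w_q.to(torch.int32))

# Dequantize: (x / x_scale) @ (w / w_scale) = (x @ w) / (x_scale * w_scale)
approx = acc.to(torch.float32) * (x_scale * w_scale)
reference = x @ w

rel_error = ((approx - reference).norm() / reference.norm()).item()
print(f"relative error vs float32: {rel_error:.4f}")
```

On a GPU, one possible workaround is to cast the quantized values to a floating dtype before multiplying; that is a suggestion rather than something the experiment above tested.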