# Week 1, Day 6: Tiling Basics

**Time:** ~1 hour

**Goal:** Understand tiling for data reuse in shared memory.

## The Challenge

Yesterday we learned that data reuse is key to performance. Today we'll implement it:
1. Understand tile coordinates and indexing
2. Use shared memory as a fast scratchpad
3. Build intuition for tiled matmul

In [None]:
import numpy as np

try:
    import torch
    import triton
    import triton.language as tl
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

---
## Step 1: The Challenge (5 min)

### Why Tiling?

Consider matrix multiplication C = A @ B where A is MxK and B is KxN.

**Naive approach:**
```
For each output element C[i,j]:
    Load row i of A (K elements)
    Load column j of B (K elements)
    Compute dot product
```

Problem: Each element of A is loaded N times. Each element of B is loaded M times.

**Tiled approach:**
```
For each tile of C:
    For each k-tile:
        Load tile of A into shared memory (once)
        Load tile of B into shared memory (once)
        Compute partial products (many times)
    Write tile of C to global memory
```

Benefit: Each global memory load is reused TILE_SIZE times!
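
The reuse claim can be checked on a toy problem. The following is a minimal pure-Python sketch, not a GPU kernel: `matmul_naive` and `matmul_tiled` are hypothetical helper names, the local lists `tile_a` and `tile_b` stand in for shared memory, and the matrix dimensions are assumed to be multiples of the tile size `T`.

```python
def matmul_naive(A, B, M, N, K):
    """Naive matmul that counts every operand read as one global load."""
    loads = 0
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                acc += A[i][k] * B[k][j]
                loads += 2  # one element of A, one element of B
            C[i][j] = acc
    return C, loads


def matmul_tiled(A, B, M, N, K, T):
    """Tiled matmul; only copies into the local tiles count as global loads.

    Assumes M, N and K are multiples of T to keep the sketch short.
    """
    loads = 0
    C = [[0.0] * N for _ in range(M)]
    for i0 in range(0, M, T):
        for j0 in range(0, N, T):
            for k0 in range(0, K, T):
                # Stage one T x T tile of A and one of B in "shared memory".
                tile_a = [[A[i0 + i][k0 + k] for k in range(T)] for i in range(T)]
                tile_b = [[B[k0 + k][j0 + j] for j in range(T)] for k in range(T)]
                loads += 2 * T * T
                # Each staged element is now reused T times.
                for i in range(T):
                    for j in range(T):
                        for k in range(T):
                            C[i0 + i][j0 + j] += tile_a[i][k] * tile_b[k][j]
    return C, loads


M = N = K = 8
T = 4
A = [[float(i + k) for k in range(K)] for i in range(M)]
B = [[float(k - j) for j in range(N)] for k in range(K)]

C_naive, naive_loads = matmul_naive(A, B, M, N, K)
C_tiled, tiled_loads = matmul_tiled(A, B, M, N, K, T)

print(f"Naive loads: {naive_loads}")
print(f"Tiled loads: {tiled_loads}")
print(f"Reduction:   {naive_loads / tiled_loads:.0f}x")
```

Both versions produce the same `C`, and under these assumptions the load count drops by a factor equal to `T`.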

In [None]:
# Calculate the data reuse benefit
def calculate_reuse_benefit(M, N, K, tile_size):
    """Compare memory traffic: naive vs tiled."""
    
    # Naive: every output C[i, j] reads K elements of A and K elements of B.
    # Total over all M*N outputs: M * N * 2K element loads.
    naive_traffic = M * N * 2 * K
    
    # Tiled (T = tile_size): each T x T tile of C iterates over K/T k-tiles,
    # loading one T x T tile of A and one T x T tile of B per iteration.
    # Traffic per C-tile: (K/T) * 2 * T * T = 2 * K * T
    # Total traffic: (M/T) * (N/T) * 2 * K * T = 2 * M * N * K / T
    # (Exact when M, N and K are multiples of T; an approximation otherwise.)
    tiled_traffic = M * N * 2 * K / tile_size
    
    reduction = naive_traffic / tiled_traffic
    
    print(f"Matrix sizes: A={M}x{K}, B={K}x{N}, C={M}x{N}")
    print(f"Tile size: {tile_size}x{tile_size}")
    print(f"")
    print(f"Memory traffic (elements):")
    print(f"  Naive:  {naive_traffic:>15,}")
    print(f"  Tiled:  {tiled_traffic:>15,.0f}")
    print(f"  Reduction: {reduction:.0f}x")

calculate_reuse_benefit(M=1024, N=1024, K=1024, tile_size=32)

---
## Step 2: Explore - Tile Coordinates (15 min)

### Understanding Tile Indexing

When processing a tile of the output matrix C:
1. **Block coordinates:** Which tile are we processing?
2. **Thread coordinates:** Where within the tile is this thread?
3. **Global coordinates:** What's the actual matrix position?
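
The relationship between these three coordinate systems can be sketched in plain Python. The helper names `global_to_tile` and `tile_to_global` are hypothetical, chosen only for this illustration.

```python
def global_to_tile(row, col, tile_m, tile_n):
    """Split a global (row, col) into block coordinates and in-tile coordinates."""
    block_m, local_m = divmod(row, tile_m)
    block_n, local_n = divmod(col, tile_n)
    return (block_m, block_n), (local_m, local_n)


def tile_to_global(block, local, tile_m, tile_n):
    """Inverse mapping: block coordinates plus in-tile coordinates give the global position."""
    row = block[0] * tile_m + local[0]
    col = block[1] * tile_n + local[1]
    return row, col


block, local = global_to_tile(row=300, col=70, tile_m=128, tile_n=128)
print(f"Global (300, 70) -> block {block}, local {local}")
print(f"Round trip: {tile_to_global(block, local, 128, 128)}")
```

The second function is the same `block * TILE + local` arithmetic that a tiled kernel applies to a whole vector of local indices at once.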

In [None]:
def visualize_tiling(M, N, tile_m, tile_n):
    """Visualize how a matrix is divided into tiles."""
    
    num_tiles_m = (M + tile_m - 1) // tile_m
    num_tiles_n = (N + tile_n - 1) // tile_n
    
    print(f"Matrix: {M} x {N}")
    print(f"Tile size: {tile_m} x {tile_n}")
    print(f"Grid: {num_tiles_m} x {num_tiles_n} = {num_tiles_m * num_tiles_n} tiles")
    print()
    
    # Create visualization
    print("Tile layout:")
    print("-" * (min(num_tiles_n, 8) * 8 + 1))
    
    for tm in range(min(num_tiles_m, 4)):
        row = "|"
        for tn in range(min(num_tiles_n, 8)):
            row += f" ({tm},{tn}) |"
        print(row)
        print("-" * (min(num_tiles_n, 8) * 8 + 1))
    
    if num_tiles_m > 4 or num_tiles_n > 8:
        print(f"... ({num_tiles_m * num_tiles_n} tiles total)")

visualize_tiling(M=1024, N=1024, tile_m=128, tile_n=128)

In [None]:
def tile_coordinate_examples(M, N, tile_m, tile_n):
    """Show how to compute coordinates for a specific tile."""
    
    print("Coordinate calculation examples:")
    print("="*60)
    
    for block_m, block_n in [(0, 0), (0, 1), (1, 0), (2, 3)]:
        # Global start position of this tile
        row_start = block_m * tile_m
        col_start = block_n * tile_n
        
        # Range of elements this tile covers
        row_end = min(row_start + tile_m, M)
        col_end = min(col_start + tile_n, N)
        
        print(f"\nTile ({block_m}, {block_n}):")
        print(f"  Row range: [{row_start}, {row_end})")
        print(f"  Col range: [{col_start}, {col_end})")
        print(f"  Elements: {(row_end - row_start) * (col_end - col_start)}")

tile_coordinate_examples(M=1024, N=1024, tile_m=128, tile_n=128)

### The Key Formulas

For a 2D tiled kernel:

```python
# Block (tile) position
block_m = tl.program_id(0)  # Which tile row
block_n = tl.program_id(1)  # Which tile column

# Global row/column offsets for this tile
row_offsets = block_m * TILE_M + tl.arange(0, TILE_M)
col_offsets = block_n * TILE_N + tl.arange(0, TILE_N)

# Offsets along K for the first k-tile (advance by TILE_K on each k-iteration)
k_offsets = tl.arange(0, TILE_K)

# For accessing A[row, k] in row-major:
# Linear index = row * K + k
a_ptrs = a_ptr + row_offsets[:, None] * K + k_offsets[None, :]

# For accessing B[k, col] in row-major:
# Linear index = k * N + col
b_ptrs = b_ptr + k_offsets[:, None] * N + col_offsets[None, :]
```
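
To see how these formulas behave as the kernel steps along K, here is a hedged pure-Python sketch. The `arange` helper is a stand-in for `tl.arange`, and the loop variable `k0` plus the small sizes are assumptions made for illustration. It checks that, across all k-tiles, one tile row of A and one tile column of B are each visited exactly once.

```python
def arange(start, size):
    """Plain-Python stand-in for `start + tl.arange(0, size)`."""
    return list(range(start, start + size))


M, N, K = 4, 6, 8
TILE_M, TILE_N, TILE_K = 2, 2, 4
block_m, block_n = 1, 2  # the tile of C being computed

row_offsets = arange(block_m * TILE_M, TILE_M)  # rows 2..3
col_offsets = arange(block_n * TILE_N, TILE_N)  # cols 4..5

visited_a, visited_b = [], []
for k0 in range(0, K, TILE_K):
    k_offsets = arange(k0, TILE_K)
    # Same arithmetic as a_ptrs / b_ptrs, minus the base pointer.
    a_offsets = [[r * K + k for k in k_offsets] for r in row_offsets]
    b_offsets = [[k * N + c for c in col_offsets] for k in k_offsets]
    print(f"k-tile starting at {k0}: A offsets {a_offsets}")
    visited_a += [off for row in a_offsets for off in row]
    visited_b += [off for row in b_offsets for off in row]

# Every element of A's rows 2..3 and B's columns 4..5 is touched exactly once.
expected_a = [r * K + k for r in row_offsets for k in range(K)]
expected_b = [k * N + c for k in range(K) for c in col_offsets]
assert sorted(visited_a) == expected_a
assert sorted(visited_b) == expected_b
```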

---
## Step 3: The Concept - Shared Memory (10 min)

### What is Shared Memory?

- **On-chip SRAM** (not HBM)
- **Shared within a block** (all threads in block can access)
- **Fast:** roughly 20 cycles vs. roughly 400 cycles for HBM (order-of-magnitude figures; exact latencies vary by GPU)
- **Limited:** 48-228 KB per SM (depending on GPU)

### Usage Pattern

```python
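# Triton has no explicit shared-memory allocation: you write block-level
# loads, and the compiler decides when to stage a tile in shared memory
# (for example, the operands of tl.dot).
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)  # Accumulator

# Load one tile of A and B from global memory (slow, but only once per k-tile)
tile_a = tl.load(a_ptrs)  # Shape (TILE_M, TILE_K)
tile_b = tl.load(b_ptrs)  # Shape (TILE_K, TILE_N)

# Every loaded element is reused across the whole tile product (fast!)
acc += tl.dot(tile_a, tile_b)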
```

In [None]:
# Calculate shared memory requirements
def calculate_smem_usage(tile_m, tile_n, tile_k, dtype_bytes=4):
    """Calculate shared memory needed for tiled matmul."""
    
    # Need to store:
    # - Tile of A: TILE_M x TILE_K
    # - Tile of B: TILE_K x TILE_N
    
    bytes_a = tile_m * tile_k * dtype_bytes
    bytes_b = tile_k * tile_n * dtype_bytes
    total = bytes_a + bytes_b
    
    print(f"Tile sizes: A={tile_m}x{tile_k}, B={tile_k}x{tile_n}")
    print(f"Shared memory needed:")
    print(f"  Tile A: {bytes_a / 1024:.1f} KB")
    print(f"  Tile B: {bytes_b / 1024:.1f} KB")
    print(f"  Total:  {total / 1024:.1f} KB")
    print(f"")
    print(f"Typical SM limits:")
    print(f"  48 KB:  {'OK' if total <= 48*1024 else 'TOO BIG'}")
    print(f"  96 KB:  {'OK' if total <= 96*1024 else 'TOO BIG'}")
    print(f"  164 KB: {'OK' if total <= 164*1024 else 'TOO BIG'}")

# Common tile sizes
print("Small tiles (32x32):")
calculate_smem_usage(32, 32, 32)

print("\nMedium tiles (64x64):")
calculate_smem_usage(64, 64, 64)

print("\nLarge tiles (128x128):")
calculate_smem_usage(128, 128, 64)  # K often smaller
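
The calculation can also be run in reverse to pick a tile size for a given budget. This is a rough sketch under stated assumptions: square tiles (`TILE_M = TILE_N = TILE_K = T`), power-of-two sizes (which `tl.arange` requires), and only the A and B tiles counted. The function name `largest_square_tile` and its `max_tile` cap are hypothetical.

```python
def largest_square_tile(smem_bytes, dtype_bytes=4, max_tile=256):
    """Largest power-of-two T such that a T x T tile of A plus a T x T tile of B fit."""
    best = 0
    t = 16
    while t <= max_tile:
        if 2 * t * t * dtype_bytes <= smem_bytes:
            best = t
        t *= 2
    return best


for budget_kb in (48, 96, 164):
    t = largest_square_tile(budget_kb * 1024)
    print(f"{budget_kb:>3} KB budget, fp32: largest square tile = {t}x{t}")
```

Real kernels often use non-square tiles and extra buffers for pipelining, so treat the result as an upper bound rather than a recommendation.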

---
## Step 4: Code It - Tiled Copy Kernel (30 min)

Before tackling matmul, let's understand 2D tiling with a simpler operation: matrix copy.

In [None]:
if GPU_AVAILABLE:
    @triton.jit
    def tiled_copy_kernel(
        src_ptr, dst_ptr,
        M, N,
        stride_m, stride_n,  # Strides for 2D indexing
        TILE_M: tl.constexpr,
        TILE_N: tl.constexpr,
    ):
        """Copy a 2D matrix using tiled access."""
        
        # Step 1: Which tile is this program handling?
        tile_m = tl.program_id(0)  # Tile row index
        tile_n = tl.program_id(1)  # Tile column index
        
        # Step 2: Calculate the row and column offsets for this tile
        row_offsets = tile_m * TILE_M + tl.arange(0, TILE_M)
        col_offsets = tile_n * TILE_N + tl.arange(0, TILE_N)
        
        # Step 3: Create masks for valid elements
        row_mask = row_offsets < M
        col_mask = col_offsets < N
        mask = row_mask[:, None] & col_mask[None, :]
        
        # Step 4: Calculate memory offsets (row-major layout)
        # offset = row * stride_m + col * stride_n
        offsets = row_offsets[:, None] * stride_m + col_offsets[None, :] * stride_n
        
        # Step 5: Load and store
        data = tl.load(src_ptr + offsets, mask=mask, other=0.0)
        tl.store(dst_ptr + offsets, data, mask=mask)
    
    print("Tiled copy kernel defined (Triton compiles it on first launch).")

In [None]:
if GPU_AVAILABLE:
    def tiled_copy(src: torch.Tensor) -> torch.Tensor:
        """Copy a 2D matrix using tiled kernel."""
        
        M, N = src.shape
        dst = torch.empty_like(src)
        
        TILE_M, TILE_N = 32, 32
        
        # Grid: number of tiles in each dimension
        grid = (
            triton.cdiv(M, TILE_M),
            triton.cdiv(N, TILE_N),
        )
        
        tiled_copy_kernel[grid](
            src, dst,
            M, N,
            src.stride(0), src.stride(1),
            TILE_M=TILE_M,
            TILE_N=TILE_N,
        )
        
        return dst
    
    # Test
    M, N = 100, 100  # Non-multiple of tile size to test masking
    src = torch.randn(M, N, device='cuda', dtype=torch.float32)
    dst = tiled_copy(src)
    
    print(f"Source shape: {src.shape}")
    print(f"Destination shape: {dst.shape}")
    print(f"Results match: {torch.allclose(src, dst)}")
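
The masking in Step 3 of the kernel matters for edge tiles. The sketch below emulates the mask logic in plain Python for the same 100x100 matrix and 32x32 tiles; `cdiv` mirrors `triton.cdiv`, while `valid_counts` is a hypothetical helper that counts how many elements each tile actually touches.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def valid_counts(M, N, TILE_M, TILE_N):
    """Number of unmasked elements for every tile in the grid."""
    counts = {}
    for tm in range(cdiv(M, TILE_M)):
        for tn in range(cdiv(N, TILE_N)):
            row_offsets = [tm * TILE_M + i for i in range(TILE_M)]
            col_offsets = [tn * TILE_N + j for j in range(TILE_N)]
            row_mask = [r < M for r in row_offsets]
            col_mask = [c < N for c in col_offsets]
            # The 2D mask is the outer AND of the two 1D masks.
            counts[(tm, tn)] = sum(row_mask) * sum(col_mask)
    return counts


counts = valid_counts(M=100, N=100, TILE_M=32, TILE_N=32)
print(f"Interior tile (0, 0): {counts[(0, 0)]} valid elements")
print(f"Right-edge tile (0, 3): {counts[(0, 3)]} valid elements")
print(f"Corner tile (3, 3): {counts[(3, 3)]} valid elements")
print(f"Total: {sum(counts.values())} (expected {100 * 100})")
```

Without the mask, the edge tiles would read and write past the end of each row or past the end of the buffer.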

### Understanding the 2D Indexing

The key insight is how we construct 2D offsets:

```python
# row_offsets is a 1D array: [row0, row1, ..., row_TILE_M-1]
# col_offsets is a 1D array: [col0, col1, ..., col_TILE_N-1]

# Using broadcasting to create 2D grid of offsets:
# row_offsets[:, None] * stride_m gives rows
# col_offsets[None, :] * stride_n gives columns
# Adding them gives the full offset grid
```

In [None]:
# Visualize the broadcasting
def visualize_2d_indexing(tile_m, tile_n, stride_m, stride_n):
    """Show how 2D offsets are computed."""
    
    row_offsets = np.arange(tile_m)
    col_offsets = np.arange(tile_n)
    
    print(f"row_offsets: {row_offsets}")
    print(f"col_offsets: {col_offsets}")
    print(f"stride_m (row stride): {stride_m}")
    print(f"stride_n (col stride): {stride_n}")
    print()
    
    # Compute 2D offsets
    offsets = row_offsets[:, None] * stride_m + col_offsets[None, :] * stride_n
    
    print("2D offset grid (showing first 4x4):")
    print(offsets[:4, :4])
    print()
    print("For a row-major matrix, adjacent columns are adjacent in memory (stride=1)")
    print("Rows are separated by N elements (stride=N)")

visualize_2d_indexing(tile_m=4, tile_n=4, stride_m=8, stride_n=1)
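
Passing strides explicitly, rather than hard-coding `row * N + col`, lets the same offset formula read layouts other than row-major. As an illustrative sketch (the `gather` helper is hypothetical), swapping the two strides reads the transpose of a row-major buffer without copying it:

```python
def gather(flat, num_rows, num_cols, stride_m, stride_n):
    """Read a 2D view of a flat buffer using row * stride_m + col * stride_n."""
    return [
        [flat[r * stride_m + c * stride_n] for c in range(num_cols)]
        for r in range(num_rows)
    ]


M, N = 3, 4
flat = list(range(M * N))  # row-major 3x4 buffer: element [r][c] lives at r * N + c

# Row-major view: strides (N, 1)
tile = gather(flat, M, N, stride_m=N, stride_n=1)

# Transposed 4x3 view of the same buffer: strides (1, N)
tile_t = gather(flat, N, M, stride_m=1, stride_n=N)

print("Row-major view:")
for row in tile:
    print(" ", row)
print("Transposed view (same buffer, swapped strides):")
for row in tile_t:
    print(" ", row)
```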

---
## Step 5: Verify - Quiz (10 min)

In [None]:
# Q1: Why does tiling reduce memory traffic?
print("Q1: Tiling enables data reuse")
print("    - Load data once from slow HBM to fast shared memory")
print("    - Reuse that data multiple times from shared memory")
print("    - Traffic reduction factor ≈ tile size")

In [None]:
# Q2: What limits tile size?
print("Q2: Tile size is limited by shared memory")
print("    - For matmul: need tiles of A and B in SMEM")
print("    - Total SMEM per block: typically 48-164 KB")
print("    - Also limited by registers per thread")

In [None]:
# Q3: Calculate grid size
print("Q3: For 1024x1024 matrix with 64x64 tiles:")
M, N = 1024, 1024
TILE_M, TILE_N = 64, 64
grid_m = (M + TILE_M - 1) // TILE_M
grid_n = (N + TILE_N - 1) // TILE_N
print(f"    Grid size: {grid_m} x {grid_n} = {grid_m * grid_n} tiles")

In [None]:
# Q4: Calculate memory offset
print("Q4: For element [3, 5] in a 1024x1024 row-major matrix:")
row, col = 3, 5
N = 1024
offset = row * N + col
print(f"    Linear offset = row * N + col = {row} * {N} + {col} = {offset}")

---
## Summary

| Concept | Key Point |
|---------|----------|
| Tiling | Divide matrix into blocks that fit in fast memory |
| Data Reuse | Load once, use many times |
| Shared Memory | Fast on-chip scratchpad (compiler-managed in Triton) |
| Grid | 2D array of tiles (programs) |
| Offsets | `tile_id * TILE_SIZE + tl.arange(0, TILE_SIZE)` |

### Key Formula for 2D Indexing

```python
# For row-major matrix:
offsets = row_offsets[:, None] * stride_row + col_offsets[None, :] * stride_col
```

### What's Next?

Tomorrow we combine everything: tiling + shared memory + coalescing = fast matmul!

---
## Next: Day 7 - Fast Matmul

Tomorrow we'll implement our first fast matrix multiplication kernel.

[Continue to 07_fast_matmul.ipynb](./07_fast_matmul.ipynb)