# Week 2, Day 6: Tensor Cores

**Time:** ~1 hour

**Goal:** Understand and use Tensor Cores for matrix operations.

## The Challenge

On an H100 SXM, the CUDA cores deliver ~67 TFLOPS of FP32. The Tensor Cores deliver ~990 TFLOPS of dense FP16. That's roughly **15x more** for matrix math, and FP8 doubles it again!

But Tensor Cores have rules: specific shapes, specific types, specific layouts.

In [None]:
import torch
import triton
import triton.language as tl
from triton.testing import do_bench

---
## Step 1: The Challenge (5 min)

### What Are Tensor Cores?

Specialized hardware units that compute **D = A × B + C** on small matrix tiles as a single MMA (matrix multiply-accumulate) instruction.

```
Single Tensor Core Operation:

    A              B              C              D
  [m×k]    ×    [k×n]    +    [m×n]    =    [m×n]

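Example (FP16, V100, warp-level WMMA shape):
  [16×16]  ×   [16×16]  +   [16×16]  =   [16×16]
  = 16·16·16 = 4,096 multiply-adds = 8,192 FLOPs per warp instruction
  (each V100 Tensor Core does a 4×4×4 FMA per clock, so the hardware
   spreads this across several Tensor Cores and clock cycles)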
```
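The FLOP count in the example is just 2·m·n·k: each of the m·n outputs needs k multiply-adds, and each multiply-add counts as 2 FLOPs. The sketch below checks that in plain Python and derives V100's peak from per-SM numbers. The V100 inputs (80 SMs, 8 Tensor Cores per SM, 64 FMAs per Tensor Core per clock, ~1.53 GHz boost clock) are assumptions taken from NVIDIA's published Volta figures, not measured here.

```python
def mma_flops(m: int, n: int, k: int) -> int:
    """FLOPs for D = A @ B + C with A: m x k and B: k x n.

    Each of the m*n outputs needs k multiply-adds (FMAs),
    and each FMA counts as 2 FLOPs.
    """
    return 2 * m * n * k


def peak_tflops(num_sms: int, tcs_per_sm: int, fma_per_tc_per_clk: int, clock_ghz: float) -> float:
    """Theoretical peak = Tensor Cores x FMAs/clock x 2 FLOPs/FMA x clock rate."""
    flops_per_clk = num_sms * tcs_per_sm * fma_per_tc_per_clk * 2
    return flops_per_clk * clock_ghz * 1e9 / 1e12


print(mma_flops(16, 16, 16))   # 8192
print(mma_flops(4, 4, 4))      # 128: one V100 Tensor Core, one clock

# Assumed V100 SXM2 figures (from NVIDIA's Volta material)
print(f"{peak_tflops(80, 8, 64, 1.53):.1f} TFLOPS")   # 125.3
```

The result lines up with the 125 TFLOPS quoted for V100 in the evolution table.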

In [None]:
# Check Tensor Core support
def check_tensor_core_support():
    """Check GPU's Tensor Core capabilities."""
    if not torch.cuda.is_available():
        return None
    
    props = torch.cuda.get_device_properties(0)
    sm = props.major * 10 + props.minor
    
    capabilities = {
        'name': props.name,
        'sm': f"SM{props.major}{props.minor}",
        'has_tensor_cores': sm >= 70,  # Volta+
        'fp16': sm >= 70,
        'bf16': sm >= 80,
        'tf32': sm >= 80,
        'fp8': sm >= 89,
        'fp4': sm >= 100,  # Blackwell
    }
    
    return capabilities

caps = check_tensor_core_support()
if caps:
    print(f"GPU: {caps['name']} ({caps['sm']})")
    print("\nTensor Core Support:")
    print(f"  FP16: {'Yes' if caps['fp16'] else 'No'}")
    print(f"  BF16: {'Yes' if caps['bf16'] else 'No'}")
    print(f"  TF32: {'Yes' if caps['tf32'] else 'No'}")
    print(f"  FP8:  {'Yes' if caps['fp8'] else 'No'}")
    print(f"  FP4:  {'Yes' if caps['fp4'] else 'No'}")
else:
    print("No CUDA GPU available")
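Without a GPU, the cell above only prints "No CUDA GPU available". The same thresholds can be written as a pure function of compute capability, so you can look up other cards offline. The function name is our own; the INT8 (Turing, SM75) and FP6 (Blackwell) entries are an assumption drawn from the evolution table in the next step, not from the cell above.

```python
def tensor_core_dtypes(major: int, minor: int) -> list[str]:
    """Tensor Core input types available at a given compute capability.

    Uses the same SM thresholds as check_tensor_core_support(),
    plus INT8 from Turing (SM75) and FP6 from Blackwell (SM100).
    """
    sm = major * 10 + minor
    min_sm_for = [
        ("FP16", 70), ("INT8", 75), ("BF16", 80), ("TF32", 80),
        ("FP8", 89), ("FP6", 100), ("FP4", 100),
    ]
    return [name for name, min_sm in min_sm_for if sm >= min_sm]


for gpu, (major, minor) in [("V100", (7, 0)), ("T4", (7, 5)), ("A100", (8, 0)), ("H100", (9, 0))]:
    print(f"{gpu:<5} SM{major}{minor}: {', '.join(tensor_core_dtypes(major, minor))}")
```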

---
## Step 2: Explore (15 min)

### Tensor Core Evolution

| Generation | Architecture | Warp-level MMA Shape (M×N×K) | Types | Peak Dense TFLOPS (FP16) |
|------------|-------------|-----------|-------|--------------------|
| 1st | Volta (V100) | 16×16×16 (WMMA) | FP16 | 125 |
| 2nd | Turing (T4) | 16×16×16 (WMMA) | FP16, INT8 | 65 |
| 3rd | Ampere (A100) | 16×8×16 (`mma.sync`) | +BF16, TF32 | 312 |
| 4th | Hopper (H100) | 64×N×16 (`wgmma`, warpgroup) | +FP8 | 990 |
| 5th | Blackwell (B200) | — | +FP4, FP6 | ~2250 |

### How Tensor Cores Work

```
32 threads (1 warp) cooperate to compute one small matrix multiply:

  1. Each thread loads a small fragment of A and B into its registers.
  2. The warp issues one MMA instruction; dedicated Tensor Core hardware
     (not the CUDA cores) computes D = A × B + C on the whole tile.
  3. Each thread ends up holding a fragment of D, which it then stores.

Which elements land in which thread's registers is fixed by the hardware
and differs per instruction shape. Triton handles this mapping for you.
```
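To make the thread-to-data mapping concrete, here is the accumulator (C/D) layout that the PTX ISA documentation gives for Ampere's `mma.sync.m16n8k16` with FP32 accumulators: lane `l` belongs to group `l // 4` and holds four values, at rows `group` and `group + 8`, columns `2 * (l % 4)` and `2 * (l % 4) + 1`. This plain-Python sketch (no GPU) encodes that layout, as we read it from the PTX docs, and checks that the 32 lanes cover the 16×8 output tile exactly once. Other instruction shapes use different layouts.

```python
def m16n8_accumulator_coords(lane: int) -> list[tuple[int, int]]:
    """(row, col) of the 4 FP32 accumulator values held by `lane`
    for mma.sync.m16n8k16, following the PTX ISA fragment layout."""
    group = lane // 4          # "groupID" in the PTX docs: 0..7
    tid_in_group = lane % 4    # "threadID_in_group": 0..3
    coords = []
    for i in range(4):         # accumulator registers c0, c1, c2, c3
        row = group if i < 2 else group + 8
        col = tid_in_group * 2 + (i & 1)
        coords.append((row, col))
    return coords


# Every element of the 16x8 output tile must be owned by exactly one lane.
owners = {}
for lane in range(32):
    for rc in m16n8_accumulator_coords(lane):
        assert rc not in owners, f"{rc} claimed twice"
        owners[rc] = lane

print(len(owners))                   # 128 == 16 * 8
print(m16n8_accumulator_coords(0))   # [(0, 0), (0, 1), (8, 0), (8, 1)]
print(m16n8_accumulator_coords(5))   # [(1, 2), (1, 3), (9, 2), (9, 3)]
```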

In [None]:
# Triton uses tl.dot which maps to Tensor Cores automatically
@triton.jit
def matmul_tensor_core(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Matmul using Tensor Cores via tl.dot.
    
    Triton lowers tl.dot to Tensor Core MMA instructions when:
    - Input types are FP16, BF16, or FP8 (FP32 inputs default to
      TF32 on Ampere and newer)
    - Tile sizes are compatible (powers of two, each at least 16)
    - The GPU has Tensor Cores (Volta or newer)
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    
    # Accumulator in FP32 for precision
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
        
        # tl.dot uses Tensor Cores when inputs are FP16
        acc += tl.dot(a, b)
        
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
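The kernel's indexing is easiest to check off-GPU. The sketch below is plain Python (no Triton, no Tensor Cores) that mirrors the kernel's structure: one "program" per `(pid_m, pid_n)` output tile, a loop over K in steps of `BLOCK_K`, and masks that turn out-of-bounds elements into 0.0 just like `other=0.0`. The helper names are illustrative only. It confirms that the masking handles sizes that are not multiples of the block sizes.

```python
import random


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same role as triton.cdiv."""
    return (a + b - 1) // b


def tiled_matmul(A, B, M, N, K, BLOCK_M=4, BLOCK_N=4, BLOCK_K=2):
    """Emulate matmul_tensor_core's tiling on nested Python lists."""
    C = [[0.0] * N for _ in range(M)]
    for pid_m in range(cdiv(M, BLOCK_M)):          # grid axis 0
        for pid_n in range(cdiv(N, BLOCK_N)):      # grid axis 1
            offs_m = [pid_m * BLOCK_M + i for i in range(BLOCK_M)]
            offs_n = [pid_n * BLOCK_N + j for j in range(BLOCK_N)]
            acc = [[0.0] * BLOCK_N for _ in range(BLOCK_M)]   # tl.zeros
            for k in range(0, K, BLOCK_K):
                # Masked loads: out-of-range elements become 0.0 (other=0.0)
                a = [[A[m][k + kk] if m < M and k + kk < K else 0.0
                      for kk in range(BLOCK_K)] for m in offs_m]
                b = [[B[k + kk][n] if k + kk < K and n < N else 0.0
                      for n in offs_n] for kk in range(BLOCK_K)]
                # acc += tl.dot(a, b)
                for i in range(BLOCK_M):
                    for j in range(BLOCK_N):
                        acc[i][j] += sum(a[i][kk] * b[kk][j] for kk in range(BLOCK_K))
            # Masked store: only write rows/cols that exist in C
            for i, m in enumerate(offs_m):
                for j, n in enumerate(offs_n):
                    if m < M and n < N:
                        C[m][n] = acc[i][j]
    return C


def naive_matmul(A, B, M, N, K):
    """Reference: one dot product per output element."""
    return [[sum(A[m][k] * B[k][n] for k in range(K)) for n in range(N)] for m in range(M)]


M, N, K = 7, 5, 9   # deliberately not multiples of the block sizes
A = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(M)]
B = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(K)]
C = tiled_matmul(A, B, M, N, K)
ref = naive_matmul(A, B, M, N, K)
print(max(abs(C[m][n] - ref[m][n]) for m in range(M) for n in range(N)))
```

The printed difference only reflects floating-point summation order; the padded zeros contribute nothing to the result.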

In [None]:
def benchmark_tensor_cores(M, N, K):
    """Benchmark Tensor Core matmul."""
    # FP16 inputs (required for Tensor Cores)
    a = torch.randn(M, K, device='cuda', dtype=torch.float16)
    b = torch.randn(K, N, device='cuda', dtype=torch.float16)
    c = torch.empty(M, N, device='cuda', dtype=torch.float32)
    
    # Tile sizes should be multiples of 16 for Tensor Cores
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    
    # Triton kernel (uses Tensor Cores)
    ms_triton = do_bench(lambda: matmul_tensor_core[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    ))
    
    # PyTorch (also uses Tensor Cores via cuBLAS)
    ms_torch = do_bench(lambda: torch.mm(a, b))
    
    # Calculate TFLOPS
    flops = 2 * M * N * K
    tflops_triton = flops / (ms_triton * 1e-3) / 1e12
    tflops_torch = flops / (ms_torch * 1e-3) / 1e12
    
    return {
        'triton_ms': ms_triton,
        'torch_ms': ms_torch,
        'triton_tflops': tflops_triton,
        'torch_tflops': tflops_torch,
        'efficiency': tflops_triton / tflops_torch * 100,
    }

print("Tensor Core Matmul Benchmark")
print("=" * 60)
print(f"{'Size':<15} {'Triton (TFLOPS)':<18} {'PyTorch (TFLOPS)':<18} {'Efficiency':<12}")
print("-" * 60)

for size in [1024, 2048, 4096, 8192]:
    result = benchmark_tensor_cores(size, size, size)
    efficiency = f"{result['efficiency']:.1f}%"
    print(f"{f'{size}x{size}':<15} {result['triton_tflops']:<18.1f} {result['torch_tflops']:<18.1f} {efficiency:<12}")

---
## Step 3: The Concept (10 min)

### Tensor Core Requirements

To use Tensor Cores, you must satisfy:

1. **Data Types:**
   - A, B: FP16, BF16, TF32, FP8, INT8
   - C, D: FP32 (or FP16 for FP16 inputs); INT32 for INT8 inputs

2. **Tile Dimensions:**
   - Must be multiples of the native MMA shape
   - Typically 16×16 or 8×16

3. **Memory Layout:**
   - Specific thread-to-data mapping
   - Triton handles this automatically
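In Triton, rule 2 shows up as constraints on the `BLOCK_*` constexprs. The helper below is hypothetical (not a Triton API) and encodes two assumptions about current Triton on NVIDIA GPUs: `tl.arange` extents must be powers of two, and `tl.dot` wants every tile dimension to be at least 16. Triton itself reports a compile error when these are violated; this just surfaces the problem before launch.

```python
def check_dot_tiles(block_m: int, block_n: int, block_k: int, min_dim: int = 16) -> list[str]:
    """Hypothetical pre-flight check for tl.dot tile sizes.

    Assumptions: (1) tl.arange extents must be powers of two;
    (2) tl.dot needs every tile dimension >= 16 on NVIDIA Tensor Cores.
    Returns a list of problems; an empty list means the tiles look OK.
    """
    problems = []
    for name, size in (("BLOCK_M", block_m), ("BLOCK_N", block_n), ("BLOCK_K", block_k)):
        if size <= 0 or size & (size - 1):   # power-of-two test
            problems.append(f"{name}={size} is not a power of two")
        if size < min_dim:
            problems.append(f"{name}={size} is smaller than {min_dim}")
    return problems


print(check_dot_tiles(64, 64, 32))   # []
print(check_dot_tiles(64, 48, 8))    # ['BLOCK_N=48 is not a power of two', 'BLOCK_K=8 is smaller than 16']
```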

### Data Type Comparison

| Type | Bits | Range | Precision (mantissa bits) | Dense TFLOPS (H100 SXM) |
|------|------|-------|-----------|---------------|
| FP32 | 32 | ±3.4e38 | 23 | 67 (CUDA cores) |
| TF32 | 19 | ±3.4e38 | 10 | 495 |
| FP16 | 16 | ±65504 | 10 | 990 |
| BF16 | 16 | ±3.4e38 | 7 | 990 |
| FP8 (E4M3) | 8 | ±448 | 3 | 1979 |
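The range and precision differences between FP16 and BF16 can be checked on the CPU with only the standard library. Python's `struct` module packs IEEE half precision with the `'e'` format. It has no BF16 format, so `to_bf16` below is our own sketch: it keeps the top 16 bits of the FP32 bit pattern using round-to-nearest-even and does not special-case NaN.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE FP16 (1 sign, 5 exponent, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round x to BF16 (1 sign, 8 exponent, 7 mantissa bits).

    Sketch: round the FP32 bit pattern to its upper 16 bits with
    round-to-nearest-even. NaN inputs are not special-cased.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


x = 1 + 2**-10            # needs 10 mantissa bits
print(to_fp16(x))         # 1.0009765625  (FP16 keeps it)
print(to_bf16(x))         # 1.0           (BF16 has only 7 mantissa bits)
print(to_fp16(65504.0))   # 65504.0       (largest finite FP16)
print(to_bf16(1e38))      # close to 1e38 (BF16 shares FP32's exponent range)
```

This is the trade-off in the table: FP16 spends its bits on precision, BF16 spends them on range.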

In [None]:
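def compare_dtypes(M, N, K):
    """Compare torch.mm performance across data types.

    FP32 matmuls only run on Tensor Cores (as TF32) when
    torch.backends.cuda.matmul.allow_tf32 is True; PyTorch defaults it to False.
    """
    results = {}
    flops = 2 * M * N * K
    
    # (label, dtype, allow TF32 for FP32 inputs)
    dtypes = [
        ('float32', torch.float32, False),
        ('tf32', torch.float32, True),
        ('float16', torch.float16, False),
        ('bfloat16', torch.bfloat16, False),
    ]
    
    prev_tf32 = torch.backends.cuda.matmul.allow_tf32
    for name, dtype, allow_tf32 in dtypes:
        try:
            torch.backends.cuda.matmul.allow_tf32 = allow_tf32
            a = torch.randn(M, K, device='cuda', dtype=dtype)
            b = torch.randn(K, N, device='cuda', dtype=dtype)
            
            ms = do_bench(lambda: torch.mm(a, b))
            tflops = flops / (ms * 1e-3) / 1e12
            results[name] = {'ms': ms, 'tflops': tflops}
        except Exception as e:
            results[name] = {'error': str(e)}
        finally:
            # Restore the global setting so later cells are unaffected
            torch.backends.cuda.matmul.allow_tf32 = prev_tf32
    
    return results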

print("Data Type Performance Comparison (4096x4096)")
print("=" * 50)

results = compare_dtypes(4096, 4096, 4096)
for dtype, data in results.items():
    if 'error' in data:
        print(f"{dtype:<10}: Not supported")
    else:
        print(f"{dtype:<10}: {data['tflops']:.1f} TFLOPS")

---
## Step 4: Code It (30 min)

### Mixed Precision Matmul

Common pattern: FP16 inputs, FP32 accumulator, FP16 output.
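Why the FP32 accumulator matters, in miniature: FP16 has 10 mantissa bits, so above 2048 its representable values are 2 apart. Adding 1.0 to 2048 gives 2049, an exact tie, and round-to-nearest-even sends it back to 2048, so the sum stalls. This CPU-only sketch emulates an FP16 and an FP32 running sum with `struct` (`'e'` = half, `'f'` = single). It illustrates the rounding effect only; it is not a model of what the Tensor Core does internally.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to IEEE single precision."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def running_sum(values, round_fn):
    """Accumulate like `acc += v`, rounding the accumulator after every add."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


ones = [1.0] * 4096   # like a K=4096 dot product of two all-ones vectors
print(running_sum(ones, to_fp16))   # 2048.0  (stalls once the sum reaches 2048)
print(running_sum(ones, to_fp32))   # 4096.0  (exact)
```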

In [None]:
@triton.jit
def matmul_mixed_precision(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    OUTPUT_FP16: tl.constexpr,
):
    """Mixed precision matmul: FP16 inputs, FP32 accumulator."""
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    
    # FP32 accumulator for precision
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in range(0, K, BLOCK_K):
        # Load FP16 inputs
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
        
        # Tensor Core MMA with FP32 accumulator
        acc += tl.dot(a, b)
        
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    
    # Store output
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    
    if OUTPUT_FP16:
        # Convert to FP16 for output
        tl.store(c_ptrs, acc.to(tl.float16), mask=mask)
    else:
        tl.store(c_ptrs, acc, mask=mask)

In [None]:
def benchmark_mixed_precision(M, N, K):
    """Compare FP32 output vs FP16 output."""
    a = torch.randn(M, K, device='cuda', dtype=torch.float16)
    b = torch.randn(K, N, device='cuda', dtype=torch.float16)
    c_fp32 = torch.empty(M, N, device='cuda', dtype=torch.float32)
    c_fp16 = torch.empty(M, N, device='cuda', dtype=torch.float16)
    
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    
    # FP32 output
    ms_fp32 = do_bench(lambda: matmul_mixed_precision[grid](
        a, b, c_fp32, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c_fp32.stride(0), c_fp32.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        OUTPUT_FP16=False,
    ))
    
    # FP16 output
    ms_fp16 = do_bench(lambda: matmul_mixed_precision[grid](
        a, b, c_fp16, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c_fp16.stride(0), c_fp16.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        OUTPUT_FP16=True,
    ))
    
    flops = 2 * M * N * K
    
    return {
        'fp32_out_ms': ms_fp32,
        'fp16_out_ms': ms_fp16,
        'fp32_tflops': flops / (ms_fp32 * 1e-3) / 1e12,
        'fp16_tflops': flops / (ms_fp16 * 1e-3) / 1e12,
    }

print("Mixed Precision Output Comparison")
print("=" * 50)

for size in [2048, 4096]:
    result = benchmark_mixed_precision(size, size, size)
    print(f"\nSize {size}x{size}:")
    print(f"  FP32 output: {result['fp32_tflops']:.1f} TFLOPS")
    print(f"  FP16 output: {result['fp16_tflops']:.1f} TFLOPS")
    print(f"  Speedup: {result['fp32_out_ms'] / result['fp16_out_ms']:.2f}x")

### Warpgroup MMA (Hopper)

Hopper introduces **Warpgroup MMA**: 128 threads (4 warps) cooperate on larger tiles.

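Benefits:
- Larger instruction tiles: `wgmma` uses M = 64 with N up to 256 (K = 16 for FP16/BF16)
- Better data reuse: operand B (and optionally A) is read directly from shared memory
- Native async execution: warps can issue loads for the next tile while the MMA runs

Triton selects warpgroup MMA automatically on Hopper when the tile shape and `num_warps` allow it.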
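A rough way to see why bigger MMA instructions help is to count how many instructions one thread block needs for a fixed tile. This sketch assumes the FP16 shapes from the PTX ISA (Ampere `mma.sync` m16n8k16, issued per warp; Hopper `wgmma` m64nNk16, issued per 4-warp warpgroup, here with N = 128) and an illustrative 128×128×64 block tile. It only counts instructions and says nothing about latency or achieved throughput.

```python
def mma_instructions(tile_m: int, tile_n: int, tile_k: int,
                     inst_m: int, inst_n: int, inst_k: int) -> int:
    """MMA instructions needed to cover a tile_m x tile_n x tile_k block tile."""
    assert tile_m % inst_m == 0 and tile_n % inst_n == 0 and tile_k % inst_k == 0
    return (tile_m // inst_m) * (tile_n // inst_n) * (tile_k // inst_k)


TILE = (128, 128, 64)
ampere = mma_instructions(*TILE, 16, 8, 16)     # mma.sync.m16n8k16 (per warp)
hopper = mma_instructions(*TILE, 64, 128, 16)   # wgmma m64n128k16 (per warpgroup)
print(ampere, hopper)   # 512 8
```

Fewer, larger instructions mean less issue overhead and fewer register-file round trips for the same amount of math.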

---
## Step 5: Verify (10 min)

### Quiz

**Q1:** What data types enable Tensor Cores?

A) FP64 only  
B) FP32 only  
C) FP16, BF16, TF32, FP8  
D) Any floating point

**Q2:** Why use FP32 accumulator with FP16 inputs?

A) Faster computation  
B) Prevents precision loss during accumulation  
C) Required by hardware  
D) Saves memory

**Q3:** What tile size constraint exists for Tensor Cores?

In [None]:
print("Quiz Answers")
print("=" * 50)
print()
print("Q1: C) FP16, BF16, TF32, FP8")
print("    Tensor Cores are built for reduced-precision types (plus INT8).")
print("    A100/H100 can also run FP64 on Tensor Cores, but at far lower")
print("    throughput, so 'FP64 only' and 'FP32 only' are both wrong.")
print()
print("Q2: B) Prevents precision loss during accumulation")
print("    When adding many FP16 values, small values can be lost.")
print("    FP32 accumulator maintains precision until final output.")
print()
print("Q3: Multiples of 16 (or 8 for some shapes)")
print("    Native MMA shapes are 16x16x16 or 16x8x16.")
print("    Tile dimensions must be multiples of these; in Triton, each")
print("    tl.dot tile dimension must also be at least 16.")

---
## Summary

### Key Takeaways

1. **Tensor Cores are ~15x faster** than CUDA cores for matrix math (FP16 vs FP32 on H100), ~30x with FP8
2. **Data type matters**: FP16/BF16/TF32/FP8 (and INT8) enable Tensor Cores
3. **Tile sizes**: multiples of 16 for best performance
4. **FP32 accumulator**: use with FP16 inputs for precision
5. **Triton's tl.dot**: automatically uses Tensor Cores when the dtypes and tile sizes allow

### Tensor Core Performance (H100)

| Type | TFLOPS | vs FP32 CUDA |
|------|--------|-------------|
| FP32 (CUDA) | 67 | 1x |
| TF32 (TC) | 495 | 7.4x |
| FP16 (TC) | 990 | 14.8x |
| FP8 (TC) | 1979 | 29.5x |

### Tomorrow: Optimized GEMM

We have all the pieces: coalescing, bank conflicts, pipelining, TMA, Tensor Cores. Tomorrow we combine everything into a production-quality GEMM kernel.