# Week 2, Day 5: TMA — Tensor Memory Accelerator

**Time:** ~1 hour

**Goal:** Learn to use Hopper's TMA for efficient bulk data transfers.

## The Challenge

Even with pipelining, SMs spend instructions and registers computing addresses for memory loads. On Hopper (H100), the **Tensor Memory Accelerator (TMA)** takes over that work for tile-shaped transfers: a single thread issues "load this tile" and the SM moves on.

**Note:** TMA requires Hopper (SM90) or newer. This notebook covers concepts and shows the API.

In [None]:
import torch
import triton
import triton.language as tl
from triton.testing import do_bench

---
## Step 1: The Challenge (5 min)

### Traditional Memory Loads

```
SM Work for Loading a 64x64 Tile:

1. Calculate 4096 element addresses (64*64), spread across the block's threads
2. Issue the matching load instructions (fewer than 4096 when loads are vectorized)
3. Wait for data
4. Repeat for next tile

→ SM is busy with bookkeeping, not compute!
```

### TMA Loads

```
SM Work:

1. Tell TMA: "Load tile at (row, col)" (one instruction)
2. Go do compute
3. TMA handles everything in the background

→ SM focuses on math!
```
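
To make the contrast concrete, the sketch below lists the element offsets a traditional load has to produce for one 64x64 tile of a row-major matrix, next to the single coordinate pair a TMA-style request carries. The helper name `tile_element_offsets` and the 1024-column matrix are illustrative assumptions. On a real GPU the offsets are computed in parallel across threads, so the count shows the amount of arithmetic, not a cycle cost.

```python
def tile_element_offsets(row0, col0, tile_m, tile_n, stride_m, stride_n):
    """Element offsets (from the tensor base) for one tile, computed one by one."""
    return [
        (row0 + i) * stride_m + (col0 + j) * stride_n
        for i in range(tile_m)
        for j in range(tile_n)
    ]

# Hypothetical row-major 1024x1024 matrix: stride_m = 1024, stride_n = 1.
offsets = tile_element_offsets(row0=64, col0=128, tile_m=64, tile_n=64,
                               stride_m=1024, stride_n=1)

# The TMA-style request for the same tile is just its coordinates;
# the descriptor already holds the shape, strides, and tile size.
tma_request = (64, 128)

print(f"Traditional: {len(offsets)} element offsets")
print(f"TMA-style:   {len(tma_request)} coordinates {tma_request}")
```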

In [None]:
# Check GPU capability
def check_tma_support():
    """Check if current GPU supports TMA."""
    if not torch.cuda.is_available():
        return False, "No CUDA available"
    
    props = torch.cuda.get_device_properties(0)
    sm_major = props.major
    sm_minor = props.minor
    
    # TMA requires SM90 (Hopper) or newer
    has_tma = sm_major >= 9
    
    return has_tma, f"SM{sm_major}{sm_minor} ({props.name})"

has_tma, gpu_info = check_tma_support()
print(f"GPU: {gpu_info}")
print(f"TMA Support: {'Yes' if has_tma else 'No (requires Hopper/SM90+)'}")

---
## Step 2: Explore (15 min)

### TMA Key Features

| Feature | Benefit |
|---------|--------|
| **Descriptor-based** | Define tensor layout once, reuse |
| **2D/3D tiles** | Native multidimensional support |
| **Multicast** | Same data to multiple SMs in a thread block cluster |
| **Async** | Non-blocking; completion is signaled through a barrier |
| **Address calculation** | Done by TMA unit, not SM |

### TMA Descriptor

A TMA descriptor contains:
- Tensor base address
- Tensor dimensions (global shape)
- Tile dimensions (what to load)
- Data type
- Swizzle pattern (for bank conflict avoidance)
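
The swizzle entry is the least intuitive item in that list, so here is a toy model of the idea. It assumes 32 shared-memory banks of 4 bytes each, tile rows of 128 bytes split into 16-byte chunks, and an XOR of the chunk index with the row index. These constants and function names are assumptions chosen for illustration; the real hardware swizzle modes differ in detail.

```python
ROW_BYTES = 128      # one tile row: 64 fp16 values
CHUNK_BYTES = 16     # swizzling granularity assumed in this sketch
NUM_BANKS = 32       # shared memory banks, 4 bytes wide each
BANK_BYTES = 4

def smem_offset(row, chunk, swizzle):
    """Byte offset of a 16-byte chunk inside the shared-memory tile."""
    if swizzle:
        chunk ^= row % 8          # XOR the chunk index with the row index
    return row * ROW_BYTES + chunk * CHUNK_BYTES

def first_bank(offset):
    """First bank touched by the chunk starting at this byte offset."""
    return (offset // BANK_BYTES) % NUM_BANKS

def banks_for_column(chunk, swizzle, rows=8):
    """Banks hit when 8 rows each read the same logical chunk (a column access)."""
    return [first_bank(smem_offset(r, chunk, swizzle)) for r in range(rows)]

plain = banks_for_column(chunk=0, swizzle=False)
swizzled = banks_for_column(chunk=0, swizzle=True)

print("No swizzle:", plain)      # every row starts in the same bank
print("Swizzled:  ", swizzled)   # rows spread across different banks
```

Without the XOR, all eight rows land on the same banks and the accesses serialize. With it, each row starts at a different bank group.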

In [None]:
# Conceptual TMA descriptor (actual API is more complex)
class TMAConcept:
    """Conceptual representation of TMA operations."""
    
    @staticmethod
    def create_descriptor(tensor, tile_shape):
        """Create a TMA descriptor for a tensor.
        
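        In real CUDA (Triton wraps this for you):
        - The host builds the descriptor (a CUtensorMap) with
          cuTensorMapEncodeTiled()
        - The descriptor is typically passed to the kernel as a
          const __grid_constant__ parameter
        - The kernel hands the descriptor's address to the TMA instruction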
        """
        return {
            'base_ptr': tensor.data_ptr(),
            'global_shape': tensor.shape,
            'tile_shape': tile_shape,
            'dtype': tensor.dtype,
            'strides': tensor.stride(),
        }
    
    @staticmethod
    def async_load(desc, smem, tile_coords):
        """Conceptual async TMA load.
        
        Real PTX instruction: cp.async.bulk.tensor
        
        SM just specifies:
        - Which tile (by coordinates)
        - Where in SMEM to put it
        
        TMA unit handles:
        - Address calculation
        - Memory transactions
        - Swizzling
        """
        pass

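# Example setup (falls back to CPU so this conceptual cell runs anywhere)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tensor = torch.randn(1024, 1024, device=device, dtype=torch.float16)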
tile_shape = (64, 64)

desc = TMAConcept.create_descriptor(tensor, tile_shape)
print("Conceptual TMA Descriptor:")
for k, v in desc.items():
    print(f"  {k}: {v}")

---
## Step 3: The Concept (10 min)

### TMA vs Traditional Loads

```
Traditional Load (per tile):
┌───────────┐    ┌───────────┐    ┌───────────┐
│ Addr calc │ →  │   Load    │ →  │  Barrier  │
│ (SM busy) │    │ (SM busy) │    │  (wait)   │
└───────────┘    └───────────┘    └───────────┘

TMA Load (per tile):
┌───────────┐    
│TMA trigger│  → SM does compute while TMA loads!
│(1 instr)  │    
└───────────┘    
```

### Multicast

TMA can deliver the same tile to the shared memory of multiple SMs at once, provided their thread blocks belong to the same thread block cluster:

```
Without Multicast:                With Multicast:
HBM → SM0                         HBM ─┬→ SM0
HBM → SM1   (3 separate loads)         ├→ SM1  (1 load, 3 destinations)
HBM → SM2                              └→ SM2
```

This is especially useful for broadcast patterns like attention's K/V sharing.
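
A quick back-of-the-envelope calculation shows the saving in the diagram above. The function name `hbm_bytes_read` is hypothetical, and the model is deliberately crude: it assumes every separate load goes all the way to HBM, whereas in practice repeated reads of the same tile may be served from L2.

```python
def hbm_bytes_read(tile_shape, elem_bytes, num_sms, multicast):
    """Bytes fetched from HBM to deliver one tile to `num_sms` SMs."""
    tile_bytes = tile_shape[0] * tile_shape[1] * elem_bytes
    # With multicast the tile is read once and delivered to every destination;
    # without it, each SM issues its own read of the same tile.
    return tile_bytes if multicast else tile_bytes * num_sms

separate = hbm_bytes_read((64, 64), elem_bytes=2, num_sms=3, multicast=False)
shared = hbm_bytes_read((64, 64), elem_bytes=2, num_sms=3, multicast=True)

print(f"Without multicast: {separate} bytes")
print(f"With multicast:    {shared} bytes")
```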

In [None]:
# Triton matmul kernel, defined only on Hopper+
# Note: the has_tma guard skips the kernel definition on older GPUs

if has_tma:
    @triton.jit
    def matmul_tma(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn, 
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        """Tiled matmul baseline using pointer-based loads.
        
        The tl.load calls below compute one address per element, so they
        are generally not lowered to TMA, even on Hopper.
        
        Explicit TMA in Triton goes through its tensor-descriptor APIs,
        which have changed between releases; check the documentation
        for your installed version.
        """
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
        
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        
        for k in range(0, K, BLOCK_K):
            # Masked pointer loads; a descriptor-based kernel would issue tile loads here
            a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
            b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
            
            acc += tl.dot(a, b)
            
            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk
        
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
    
    print("Matmul kernel defined (pointer-based baseline, Hopper guard active)")
else:
    print("TMA not available on this GPU - showing concepts only")

---
## Step 4: Code It (30 min)

### TMA Benefits Simulation

Let's simulate the performance benefit of TMA by comparing:
1. Traditional loads (SM does address calculation)
2. Simulated TMA (address calculation "free")

In [None]:
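def estimate_tma_benefit(M, N, K, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32):
    """Estimate potential benefit of TMA vs traditional loads.
    
    This is a toy model: it charges address calculation serially per
    element, which overstates its cost on real hardware. The per-tile
    figures and the speedup do not depend on M, N, or K; only the tile
    count does. Treat the output as an illustration, not a prediction.
    """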
    # Number of tiles
    num_tiles_m = (M + BLOCK_M - 1) // BLOCK_M
    num_tiles_n = (N + BLOCK_N - 1) // BLOCK_N
    num_tiles_k = (K + BLOCK_K - 1) // BLOCK_K
    
    # Elements per tile
    a_tile_elems = BLOCK_M * BLOCK_K
    b_tile_elems = BLOCK_K * BLOCK_N
    
    # Traditional: SM calculates addresses
    # Rough estimate: ~4 cycles per address calculation
    addr_calc_cycles = (a_tile_elems + b_tile_elems) * 4
    
    # Compute cycles per tile
    # Assumed throughput: ~256 FLOPs per cycle (placeholder, not a measured figure)
    compute_flops = BLOCK_M * BLOCK_N * BLOCK_K * 2
    compute_cycles = compute_flops / 256
    
    # Memory latency
    mem_latency_cycles = 400  # Approximate
    
    # Total per tile (simplified)
    traditional_cycles = addr_calc_cycles + max(compute_cycles, mem_latency_cycles)
    tma_cycles = compute_cycles + mem_latency_cycles / 4  # assumes 1/4 of the latency stays exposed
    
    return {
        'tiles': num_tiles_m * num_tiles_n * num_tiles_k,
        'traditional_cycles_per_tile': traditional_cycles,
        'tma_cycles_per_tile': tma_cycles,
        'estimated_speedup': traditional_cycles / tma_cycles,
        'addr_calc_overhead': addr_calc_cycles / traditional_cycles * 100,
    }

print("Estimated TMA Benefit")
print("=" * 60)

for size in [1024, 2048, 4096, 8192]:
    est = estimate_tma_benefit(size, size, size)
    print(f"\nSize {size}x{size}:")
    print(f"  Tiles: {est['tiles']:,}")
    print(f"  Address calculation overhead: {est['addr_calc_overhead']:.1f}%")
    print(f"  Estimated TMA speedup: {est['estimated_speedup']:.2f}x")

### When TMA Helps Most

TMA benefits are highest when:
1. **Small tiles** (more address calculations per byte)
2. **Complex layouts** (strided access, padding)
3. **Multicast patterns** (same data to many SMs)
4. **High occupancy** (more concurrent operations)

### TMA Limitations

- Only available on Hopper (H100) and newer
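- Tile (box) shapes are constrained: each dimension is capped (256 elements in the CUDA driver API), and Triton block shapes must be powers of 2
- Requires aligned tensors: a 16-byte-aligned base address and strides that are multiples of 16 bytes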
- Descriptor setup has overhead (but amortized)
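
The shape and alignment limits can be checked before building a descriptor. The sketch below encodes a few rules as we understand them from the CUDA driver API documentation for `cuTensorMapEncodeTiled`; the function name `check_tma_constraints` and the exact rule set are assumptions, so confirm them against the documentation for your CUDA version.

```python
def check_tma_constraints(base_addr, strides_bytes, box_shape, elem_bytes):
    """Return a list of constraint violations (empty list means OK).

    strides_bytes: byte strides of every dimension except the innermost.
    box_shape: tile shape in elements, innermost dimension last.
    """
    problems = []
    if base_addr % 16 != 0:
        problems.append("base address is not 16-byte aligned")
    for s in strides_bytes:
        if s % 16 != 0:
            problems.append(f"stride of {s} bytes is not a multiple of 16")
    for d in box_shape:
        if not 1 <= d <= 256:
            problems.append(f"box dimension {d} is outside 1..256")
    if (box_shape[-1] * elem_bytes) % 16 != 0:
        problems.append("innermost box dimension is not a multiple of 16 bytes")
    return problems

# fp16 1024x1024 row-major tensor, 64x64 tile: row stride is 2048 bytes.
ok = check_tma_constraints(0x7F0000000000, [2048], (64, 64), elem_bytes=2)
# fp16 tensor with 1023 columns: row stride is 2046 bytes.
bad = check_tma_constraints(0x7F0000000000, [2046], (64, 64), elem_bytes=2)

print("1024 columns:", ok)
print("1023 columns:", bad)
```

The second case is the common trap: an odd column count in fp16 breaks the stride rule, which is why padding the leading dimension is a frequent workaround.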

In [None]:
# Real benchmark (if on Hopper)
if has_tma:
    def benchmark_tma(M, N, K):
        a = torch.randn(M, K, device='cuda', dtype=torch.float16)
        b = torch.randn(K, N, device='cuda', dtype=torch.float16)
        c = torch.empty(M, N, device='cuda', dtype=torch.float32)
        
        BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
        grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
        
        ms = do_bench(lambda: matmul_tma[grid](
            a, b, c, M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        ))
        
        flops = 2 * M * N * K
        tflops = flops / (ms * 1e-3) / 1e12
        
        ms_torch = do_bench(lambda: torch.mm(a, b))
        tflops_torch = flops / (ms_torch * 1e-3) / 1e12
        
        return {'triton_tflops': tflops, 'torch_tflops': tflops_torch}
    
    print("Triton Matmul Benchmark (Hopper)")
    print("=" * 50)
    for size in [2048, 4096]:
        result = benchmark_tma(size, size, size)
        print(f"Size {size}x{size}: Triton={result['triton_tflops']:.1f} TFLOPS, PyTorch={result['torch_tflops']:.1f} TFLOPS")
else:
    print("Skipping TMA benchmark - requires Hopper GPU")

---
## Step 5: Verify (10 min)

### Quiz

**Q1:** What does TMA offload from the SM?

A) Floating point computation  
B) Address calculation for memory loads  
C) Thread synchronization  
D) Register allocation

**Q2:** What GPU architecture introduced TMA?

A) Ampere (A100)  
B) Hopper (H100)  
C) Turing  
D) Volta

**Q3:** What is TMA multicast useful for?

In [None]:
print("Quiz Answers")
print("=" * 50)
print()
print("Q1: B) Address calculation for memory loads")
print("    TMA handles address calculation, boundary checks, and")
print("    memory transactions - SM just triggers the load.")
print()
print("Q2: B) Hopper (H100)")
print("    TMA was introduced with Hopper (SM90) and carries over to newer GPUs.")
print()
print("Q3: Broadcast patterns")
print("    When multiple SMs need the same data (like K/V in attention),")
print("    TMA can send it to all of them with a single memory read.")
print("    This saves memory bandwidth significantly.")

---
## Summary

### Key Takeaways

1. **TMA = Tensor Memory Accelerator** (Hopper+ only)
2. **Offloads address calculation** from SM to dedicated unit
3. **Descriptor-based**: define tensor once, load tiles by coordinates
4. **Multicast**: same data to multiple SMs efficiently
5. **Works with pipelining** for even better latency hiding

### TMA vs Traditional

| Aspect | Traditional | TMA |
|--------|------------|-----|
| Address calc | SM cycles | TMA unit (no SM instructions) |
| Tile loads | Many instructions | One instruction |
| Multicast | N loads | 1 load, N destinations |
| Swizzling | Manual | Configured in the descriptor |

### Tomorrow: Tensor Cores

Now that we can load data efficiently, let's use the fastest compute units: **Tensor Cores**.