# Part 4, Lab 6: Fused Quantized Attention

**Time:** ~45 minutes

Combine quantization with fused attention kernels for maximum inference efficiency. This lab brings together concepts from attention and quantization.

## Learning Objectives

1. Understand fused attention kernels
2. Implement attention with quantized KV cache
3. Estimate memory bandwidth savings
4. Profile and optimize

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import time

# Run on the GPU when PyTorch can see one; tensors in the following cells are
# created on `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

---
## 1. Standard vs Fused Attention

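Standard attention computes QK^T, the softmax, and the product with V as separate steps, materializing the full `seq_q × seq_kv` score matrix in memory.
Fused attention (FlashAttention) streams K and V in blocks and uses an *online softmax*, so the score matrix is never written to memory and memory traffic drops. `memory_efficient_attention` below emulates this blocking in PyTorch; a real fused kernel performs all of it inside a single GPU kernel.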

In [None]:
def standard_attention(Q, K, V, scale):
    """
    Reference (unfused) scaled dot-product attention.

    Computes ``softmax(Q @ K^T * scale) @ V``. The full score tensor
    (``[..., seq_q, seq_kv]``) is materialized, which is simple and clear but
    uses memory proportional to ``seq_q * seq_kv``.

    Args:
        Q: queries, ``[..., seq_q, head_dim]``.
        K: keys, ``[..., seq_kv, head_dim]``.
        V: values, ``[..., seq_kv, v_dim]``.
        scale: multiplier applied to the raw dot products.

    Returns:
        Attention output, ``[..., seq_q, v_dim]``.
    """
    # Softmax over the key axis turns scaled scores into attention weights.
    probs = F.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1)
    return probs @ V

def memory_efficient_attention(Q, K, V, scale, block_size=64):
    """
    Memory-efficient attention using online softmax.

    K and V are processed in blocks of ``block_size`` keys, so only a
    ``[..., seq_q, block_size]`` slice of the score matrix exists at a time
    instead of the full ``[..., seq_q, seq_kv]`` matrix. A running row max
    ``m`` and running row sum of exponentials ``l`` let earlier partial
    results be rescaled whenever a larger score appears. The final result
    therefore equals ``softmax(Q @ K^T * scale) @ V``.

    Args:
        Q: queries, ``[..., seq_q, head_dim]``.
        K: keys, ``[..., seq_kv, head_dim]``.
        V: values, ``[..., seq_kv, v_dim]``; ``v_dim`` may differ from
            ``head_dim``.
        scale: multiplier applied to the raw dot products.
        block_size: number of keys per block; must be positive.

    Returns:
        ``[..., seq_q, v_dim]`` tensor in Q's dtype. Accumulation happens in
        at least float32, as a fused kernel would do it.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    *lead, seq_q, _ = Q.shape
    seq_kv = K.shape[-2]
    v_dim = V.shape[-1]
    # Low-precision inputs accumulate in float32; float64 inputs stay float64.
    acc_dtype = torch.promote_types(Q.dtype, torch.float32)

    # The output width follows V, not Q: the value dim need not equal the key dim.
    output = torch.zeros(*lead, seq_q, v_dim, dtype=acc_dtype, device=Q.device)
    if seq_kv == 0:
        # No keys to attend to: return zeros, which is what
        # softmax(...) @ V over an empty key axis produces, instead of 0/0.
        return output.to(Q.dtype)

    q = Q.to(acc_dtype)
    m = torch.full((*lead, seq_q, 1), float('-inf'), dtype=acc_dtype, device=Q.device)
    l = torch.zeros((*lead, seq_q, 1), dtype=acc_dtype, device=Q.device)

    for block_start in range(0, seq_kv, block_size):
        block_end = min(block_start + block_size, seq_kv)
        K_block = K[..., block_start:block_end, :].to(acc_dtype)
        V_block = V[..., block_start:block_end, :].to(acc_dtype)

        # Scores for this block only: [..., seq_q, block_len]
        scores = torch.matmul(q, K_block.transpose(-2, -1)) * scale

        # Online softmax update. On the first block m is -inf, so
        # exp(m - m_new) == 0 and the empty accumulators are discarded.
        m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(m - m_new)
        exp_scores = torch.exp(scores - m_new)

        l = l * correction + exp_scores.sum(dim=-1, keepdim=True)
        output = output * correction + torch.matmul(exp_scores, V_block)
        m = m_new

    # Every processed block contributes exp(0) = 1 for its max, so l > 0 here.
    return (output / l).to(Q.dtype)

# Compare implementations on a decode-shaped problem: a single query token
# attending over seq_len cached keys/values (several blocks of the default size).
batch, heads, seq_len, head_dim = 1, 8, 512, 64
scale = 1.0 / (head_dim ** 0.5)

Q = torch.randn(batch, heads, 1, head_dim, device=device)  # Decode: 1 query token
# K is drawn first, then V (same RNG order as drawing them one by one).
K, V = (torch.randn(batch, heads, seq_len, head_dim, device=device) for _ in range(2))

out_standard = standard_attention(Q, K, V, scale)
out_efficient = memory_efficient_attention(Q, K, V, scale)

print(f"Output difference: {(out_efficient - out_standard).abs().max().item():.6f}")

---
## 2. Attention with Quantized KV Cache

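Dequantize K and V as part of the attention computation. This reference version dequantizes the whole cache before the matmuls, so it checks accuracy but not memory savings. A fused kernel (Section 4) instead dequantizes each block on the fly in registers, so full-precision K/V are never written to memory.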

In [None]:
def quantized_kv_attention(Q, K_int8, K_scales, V_int8, V_scales, scale):
    """
    Attention with INT8 quantized KV cache (unfused reference).

    Dequantizes all of K and V and then runs ordinary softmax attention.
    A fused kernel would dequantize block by block instead; this version is
    for checking numerics, not for saving memory.

    Args:
        Q: queries, ``[batch, heads, seq_q, head_dim]``, any float dtype.
        K_int8, V_int8: ``[batch, heads, seq_kv, head_dim]`` as INT8.
        K_scales, V_scales: scales broadcastable against K_int8 / V_int8,
            e.g. ``[batch, heads, seq_kv, 1]`` per-token scales.
        scale: multiplier applied to the raw dot products.

    Returns:
        ``[batch, heads, seq_q, head_dim]`` tensor in Q's dtype.
    """
    # Do all math in at least float32. Otherwise a half-precision Q would be
    # matmul'ed against float32 dequantized K, which raises a dtype mismatch.
    compute_dtype = torch.promote_types(Q.dtype, torch.float32)
    q = Q.to(compute_dtype)

    # Dequantize K
    K = K_int8.to(compute_dtype) * K_scales.to(compute_dtype)

    # Compute attention scores
    scores = torch.matmul(q, K.transpose(-2, -1)) * scale
    attn_weights = F.softmax(scores, dim=-1)

    # Dequantize V and compute output
    V = V_int8.to(compute_dtype) * V_scales.to(compute_dtype)
    output = torch.matmul(attn_weights, V)

    return output.to(Q.dtype)

def _quantize_per_token(x):
    """Symmetric INT8 quantization of ``x`` with one scale per last-dim row.

    Returns ``(x_int8, scales)``. ``scales`` has x's shape with the last dim
    reduced to 1, and ``x_int8.float() * scales`` approximates ``x``.
    """
    # Work in at least float32. In float16 the 1e-8 floor below underflows
    # to 0, so an all-zero row would produce 0/0 = NaN before the int8 cast.
    x = x.to(torch.promote_types(x.dtype, torch.float32))
    scales = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
    # Symmetric range: each row's largest magnitude maps to +/-127; -128 is unused.
    x_int8 = (x / scales).round().clamp(-127, 127).to(torch.int8)
    return x_int8, scales


def quantize_kv(K, V):
    """Quantize K, V to INT8 with per-token scales.

    Args:
        K, V: float tensors, e.g. ``[batch, heads, seq_kv, head_dim]``; each
            row along the last dim (one token of one head) gets its own scale.

    Returns:
        ``(K_int8, K_scales, V_int8, V_scales)``. The INT8 tensors have the
        input shapes. The scales have the last dim reduced to 1 and are at
        least float32.
    """
    K_int8, K_scales = _quantize_per_token(K)
    V_int8, V_scales = _quantize_per_token(V)
    return K_int8, K_scales, V_int8, V_scales

# Test quantized attention against the full-precision reference
K_int8, K_scales, V_int8, V_scales = quantize_kv(K, V)
out_quantized = quantized_kv_attention(Q, K_int8, K_scales, V_int8, V_scales, scale)

print(f"Quantized vs Standard difference: {(out_standard - out_quantized).abs().max().item():.6f}")

# Memory comparison. The baseline is a hypothetical FP16 cache (2 bytes per
# element). The INT8 side counts the tensors produced above, including the
# per-token scales at their real storage size (element_size() rather than
# an assumed 2 bytes).
fp16_memory = (K.numel() + V.numel()) * 2  # FP16
int8_memory = sum(t.numel() * t.element_size() for t in (K_int8, V_int8, K_scales, V_scales))
print(f"Memory: {fp16_memory / 1024:.1f} KB (FP16) vs {int8_memory / 1024:.1f} KB (INT8)")

---
## 3. Bandwidth Analysis

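For decode (batch=1, seq_q=1), attention is memory-bound. Each K/V element read from memory is used for only about one multiply-add, so runtime is dominated by how many bytes of KV cache must be streamed per generated token.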

In [None]:
def analyze_bandwidth(seq_kv, num_heads, head_dim, dtype_bytes):
    """
    Bytes of KV cache that one decode step must read for a single layer.

    Every generated token reads the entire cached K and the entire cached V,
    and each holds ``seq_kv × num_heads × head_dim`` elements of
    ``dtype_bytes`` bytes.

    Args:
        seq_kv: number of cached tokens (context length).
        num_heads: number of KV heads.
        head_dim: per-head dimension.
        dtype_bytes: storage bytes per element (e.g. 2 for FP16, 1 for INT8).

    Returns:
        Total bytes read for K plus V in one layer.
    """
    bytes_per_tensor = seq_kv * num_heads * head_dim * dtype_bytes
    return bytes_per_tensor * 2  # K and V

# Llama-2-7B-like configuration
num_layers = 32
num_heads = 32
head_dim = 128

print("Decode Attention Memory Bandwidth Analysis (Llama-2-7B):")
print("=" * 60)
# Every column after "Context" is 12 characters wide. Rows below use a 9-wide
# number + " MB", or an 11-wide ratio + "x", so they line up with the headers.
print(f"{'Context':<10} {'FP16':>12} {'FP8/INT8':>12} {'Savings':>12}")
print("-" * 60)

for ctx in [1024, 4096, 16384, 65536]:
    bw_fp16 = analyze_bandwidth(ctx, num_heads, head_dim, 2) * num_layers
    bw_int8 = analyze_bandwidth(ctx, num_heads, head_dim, 1) * num_layers

    # MB here means 1024**2 bytes.
    print(f"{ctx:<10} {bw_fp16 / 1024**2:>9.1f} MB {bw_int8 / 1024**2:>9.1f} MB {bw_fp16 / bw_int8:>11.1f}x")

# Estimate tokens per second from KV-cache reads alone. This ignores weight
# loads and scale overhead and assumes peak bandwidth is reached, so treat it
# as an upper bound for the attention part of decode.
h100_bandwidth = 3.35e12  # bytes/second (H100 SXM peak HBM bandwidth)
ctx = 4096
bw_per_token_fp16 = analyze_bandwidth(ctx, num_heads, head_dim, 2) * num_layers
bw_per_token_int8 = analyze_bandwidth(ctx, num_heads, head_dim, 1) * num_layers

print(f"\nEstimated decode throughput at ctx={ctx} (H100, attention only):")
print(f"  FP16: {h100_bandwidth / bw_per_token_fp16:.0f} tokens/sec")
print(f"  INT8: {h100_bandwidth / bw_per_token_int8:.0f} tokens/sec")

---
## 4. Triton Kernel Structure (Conceptual)

Here's the structure of a fused quantized attention kernel.

In [None]:
# Pseudo-code for fused quantized attention kernel (decode: one query token per head).
# Pointer offsets are elided with `...`; filling them in is Exercise 1.
# Single-row dot products are written as broadcast-multiply + tl.sum because
# tl.dot requires 2-D operands with every dimension >= 16.
quantized_attention_kernel = '''
@triton.jit
def fused_quantized_attention(
    Q_ptr, K_int8_ptr, K_scale_ptr, V_int8_ptr, V_scale_ptr, O_ptr,
    seq_len, num_heads, head_dim, sm_scale,   # sm_scale = 1/sqrt(head_dim), computed on the host
    BLOCK_KV: tl.constexpr, BLOCK_HEAD: tl.constexpr
):
    # Each program handles one query head
    head_idx = tl.program_id(0)
    d_offs = tl.arange(0, BLOCK_HEAD)   # BLOCK_HEAD: power of 2 >= head_dim
    d_mask = d_offs < head_dim

    # Load query vector (fits in registers) and upcast to fp32
    q = tl.load(Q_ptr + head_idx * head_dim + d_offs, mask=d_mask, other=0.0).to(tl.float32)

    # Initialize online softmax accumulators
    m = float('-inf')  # Running max
    l = 0.0           # Running sum of exp
    acc = tl.zeros([BLOCK_HEAD], dtype=tl.float32)  # Weighted V accumulator

    # Stream through KV cache in blocks
    for kv_start in range(0, seq_len, BLOCK_KV):
        kv_offs = kv_start + tl.arange(0, BLOCK_KV)
        kv_mask = kv_offs < seq_len     # the last block may be partial
        tile_mask = kv_mask[:, None] & d_mask[None, :]

        # Load and dequantize K block: [BLOCK_KV, BLOCK_HEAD]
        k_int8 = tl.load(K_int8_ptr + ..., mask=tile_mask, other=0)     # INT8 values
        k_scale = tl.load(K_scale_ptr + ..., mask=kv_mask, other=0.0)  # one scale per token
        k = k_int8.to(tl.float32) * k_scale[:, None]

        # Dot products q . k_j for every token j in the block: [BLOCK_KV]
        scores = tl.sum(q[None, :] * k, axis=1) * sm_scale
        scores = tl.where(kv_mask, scores, float('-inf'))  # ignore padded slots

        # Online softmax update
        m_block = tl.max(scores, axis=0)
        m_new = tl.maximum(m, m_block)

        # Rescale old accumulator
        exp_diff = tl.exp(m - m_new)
        l = l * exp_diff
        acc = acc * exp_diff

        # Add new block (padded slots contribute exp(-inf) = 0)
        exp_scores = tl.exp(scores - m_new)
        l += tl.sum(exp_scores, axis=0)

        # Load and dequantize V block, accumulate weighted sum
        v_int8 = tl.load(V_int8_ptr + ..., mask=tile_mask, other=0)
        v_scale = tl.load(V_scale_ptr + ..., mask=kv_mask, other=0.0)
        v = v_int8.to(tl.float32) * v_scale[:, None]
        acc += tl.sum(exp_scores[:, None] * v, axis=0)

        m = m_new

    # Normalize and store
    output = acc / l
    tl.store(O_ptr + head_idx * head_dim + d_offs, output, mask=d_mask)
'''

print("Fused Quantized Attention Kernel Structure:")
print(quantized_attention_kernel)

---
## Exercises

1. **Implement in Triton**: Turn the pseudo-code into a working Triton kernel
2. **FP8 Variant**: Implement FP8 E4M3 KV cache instead of INT8
3. **Profile**: Use Nsight Compute to measure actual memory bandwidth utilization

## Key Takeaways

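- Decode attention is memory-bound: its speed is limited by how fast the KV cache can be read
- An INT8/FP8 KV cache halves the bytes read per element compared with FP16, roughly doubling decode-attention throughput (minus the small overhead of storing scales)
- Fused kernels minimize memory traffic by using an online softmax, so the score matrix is never materialized, and by dequantizing K/V in registers
- Per-token scaling is a good balance between accuracy and overhead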