In [None]:
# problem_6.py
import torch
import math

def _build_q_to_kv_map(Hq, Hkv):
    # Mirror the original integer-division mapping used in the naive reference
    if Hq % Hkv != 0:
        group = max(1, Hq // Hkv)
    else:
        group = Hq // Hkv
    mapping = [min(hq // group, Hkv - 1) for hq in range(Hq)]
    return torch.tensor(mapping, dtype=torch.long)

def flash_attention_forward(q, k, v, is_causal=True, window_size=128, sink_size=None):
    """
    Vectorized (non-fused) sliding-window attention with GQA for Problem 6.

    Note: despite the name this is not a tiled/fused FlashAttention kernel; it
    materializes the full (B, Hq, L, L) score matrix, so memory grows
    quadratically with the sequence length.

    Masking follows the naive reference (``naive_attention``):
    - ``is_causal=True``, ``window_size=None``: plain causal mask, key j is
      visible to query i iff j <= i.
    - ``is_causal=True``, ``window_size=W``: key j is visible to query i iff
      i - (W - 1) <= j <= i.
    - ``is_causal=True``, ``window_size=W``, ``sink_size=S``: additionally every
      key j < S is visible to each query i >= j ("attention sinks").
    - ``is_causal=False``: no mask is applied; ``window_size`` and ``sink_size``
      are ignored, exactly as in the reference.

    Args:
        q: (B, Hq, L, D) queries.
        k: (B, Hkv, L, D) keys.
        v: (B, Hkv, L, Dv) values; expected to share q's dtype.
        is_causal: apply causal / sliding-window / sink masking.
        window_size: sliding-window width (must be >= 1), or None for no window.
        sink_size: number of leading keys that stay visible outside the window
            (causal + windowed mode only); None or 0 disables sinks.

    Returns:
        out: (B, Hq, L, Dv) tensor in q's dtype.

    Raises:
        AssertionError: on rank, head-dim or sequence-length mismatches.
        ValueError: if ``window_size`` is not None and < 1 (such a window leaves
            rows with no visible key and would produce NaNs).
    """
    assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "q/k/v must be 4D tensors"
    B, Hq, Lq, D = q.shape
    _, Hkv, Lk, Dk = k.shape
    assert D == Dk, "query/key head dim mismatch"
    assert Lq == Lk, "sequence lengths of q and k/v must match for sliding attention"
    if window_size is not None and window_size < 1:
        raise ValueError(f"window_size must be >= 1 or None, got {window_size}")

    device = q.device

    # GQA: query head h reads KV head mapping[h] (same grouping as repeat_kv
    # in the reference). Advanced indexing copies K/V out to Hq heads.
    mapping = _build_q_to_kv_map(Hq, Hkv).to(device)
    k_exp = k[:, mapping, :, :]  # (B, Hq, L, D)
    v_exp = v[:, mapping, :, :]  # (B, Hq, L, Dv)

    # Scaled dot-product scores (B, Hq, L, L), computed in the input dtype like
    # the reference. This allocates the whole L x L matrix per head.
    scale = 1.0 / math.sqrt(D)
    scores = torch.einsum("bhld,bhmd->bhlm", q, k_exp) * scale

    if is_causal:
        idx = torch.arange(Lq, device=device)
        # offset[i, j] = i - j (query position minus key position).
        offset = idx[:, None] - idx[None, :]
        mask = offset >= 0  # causal: never attend to future keys
        if window_size is not None:
            visible = offset < window_size  # i - (W - 1) <= j
            if sink_size:
                visible = visible | (idx[None, :] < sink_size)
            mask = mask & visible
        # The (L, L) mask broadcasts over the batch and head dimensions; the
        # diagonal is always allowed, so no row is fully masked.
        scores = scores.masked_fill(~mask, float("-inf"))

    # Normalize in float32 (as the reference does) to avoid bf16/fp16 rounding
    # in the softmax, then cast back for the value matmul.
    weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)

    # Weighted sum of values: (B, Hq, L, Dv)
    return torch.einsum("bhlm,bhmd->bhld", weights, v_exp)

# Quick shape check when this cell/file is executed directly (inside a
# notebook __name__ is also "__main__", so it runs there too).
if __name__ == "__main__":
    torch.manual_seed(0)
    smoke_device = "cuda" if torch.cuda.is_available() else "cpu"
    B, Hq, Hkv, L, D = 1, 8, 2, 512, 16
    q = torch.randn(B, Hq, L, D, device=smoke_device)
    k = torch.randn(B, Hkv, L, D, device=q.device)
    v = torch.randn(B, Hkv, L, D, device=q.device)
    out = flash_attention_forward(q, k, v, is_causal=True, window_size=128)
    print("smoke test output shape:", out.shape)


In [None]:
# Autograder for Problem 6: Sliding Window Attention
import torch
import torch.nn.functional as F
import math
import time

# Reduced-precision dtype for all autograder tensors: bf16 where the GPU
# supports it, fp16 otherwise. CUDA availability is checked first because on
# some PyTorch releases is_bf16_supported() queries the current device and can
# raise on machines without a GPU, which would crash this cell before the
# "CUDA not available" guard further down gets a chance to report it.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

def repeat_kv(x, num_groups):
    """Repeat each K/V head ``num_groups`` times along dim 1 (GQA -> MHA).

    Heads are repeated in place (h0, h0, ..., h1, h1, ...), turning a
    (B, H_kv, N, D) tensor into (B, H_kv * num_groups, N, D). With
    ``num_groups == 1`` the input tensor itself is returned unchanged.
    """
    if num_groups == 1:
        return x
    batch, kv_heads, seq_len, head_dim = x.shape
    grouped = x[:, :, None, :, :].expand(batch, kv_heads, num_groups, seq_len, head_dim)
    return grouped.reshape(batch, kv_heads * num_groups, seq_len, head_dim)

def create_mask_bool(
    seq_len: int,
    window_size: int,
    sink_size: int,
    device=None,
) -> torch.Tensor:
    """Build a (seq_len, seq_len) boolean mask, True where query i may see key j.

    Key j is visible to query i when it is not in the future (j <= i) and it is
    either inside the sliding window (j >= i - (window_size - 1)) or one of the
    first ``sink_size`` "sink" positions (j < sink_size).
    """
    positions = torch.arange(seq_len, device=device)
    distance = positions[:, None] - positions[None, :]  # i - j, shape (L, L)

    not_future = distance >= 0
    in_window = distance <= window_size - 1
    is_sink = positions[None, :] < sink_size  # shape (1, L), broadcasts

    return not_future & (in_window | is_sink)

def naive_attention(Q, K, V, is_causal=False, window_size=None, sink_size=None):
    """
    Reference (unfused) attention used to grade the optimized implementations.

    Supports GQA (K/V heads are expanded with ``repeat_kv``), sliding windows
    and attention sinks. Masking is only applied when ``is_causal`` is True;
    with ``is_causal=False`` the window/sink arguments are ignored and every
    query attends to every key.

    Returns:
        (O, LSE): the attention output in Q's dtype, and the float32
        log-sum-exp of the scaled (and, if causal, masked) scores over keys.
    """
    _, num_heads_q, _, _ = Q.shape
    # seq_len and head_dim are read from K.
    _, num_heads_kv, seq_len, head_dim = K.shape

    if num_heads_q != num_heads_kv:
        # Presumably Hq is a multiple of Hkv; otherwise the head counts will
        # not line up in the matmul below.
        group_count = num_heads_q // num_heads_kv
        K, V = repeat_kv(K, group_count), repeat_kv(V, group_count)

    scores = Q @ K.transpose(-1, -2)
    scores = scores * (1.0 / math.sqrt(head_dim))

    if is_causal:
        if window_size is None:
            # A window as long as the sequence is plain causal masking.
            allowed = create_mask_bool(seq_len=seq_len, window_size=seq_len, sink_size=0, device=Q.device)
        else:
            # Sliding window, optionally with attention sinks.
            sinks = 0 if sink_size is None else sink_size
            allowed = create_mask_bool(seq_len, window_size=window_size, sink_size=sinks, device=Q.device)
        scores.masked_fill_(~allowed, -float('inf'))

    probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(Q.dtype)
    output = probs @ V
    lse = torch.logsumexp(scores.to(torch.float32), dim=-1)

    return output, lse

def run_correctness_test(test_case, triton_func, is_causal=False, is_gqa=False, is_swa=False, problem_num=1):
    """Run a single correctness test of ``triton_func`` against ``naive_attention``.

    ``test_case`` layout depends on the flags:
      - GQA only (``is_gqa and not is_swa``): (B, Hq, Hkv, L, D)
      - SWA (``is_swa``): (B, Hq, Hkv, L, D, W) or (B, Hq, Hkv, L, D, W, S)
      - otherwise (MHA): (B, H, L, D)

    Inputs are random CUDA tensors of dtype ``DTYPE``. The test passes when the
    output has exactly the reference's shape and both agree within
    atol=rtol=1e-2 (compared in float32). Any exception raised while running
    either implementation is printed and counted as a failure.

    Returns:
        True if the test passed, False otherwise.
    """
    window_size, sink_size = None, None
    if is_gqa and not is_swa: # GQA only
        batch, heads_q, heads_kv, seq_len, dim = test_case
        config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}"
    elif is_swa: # GQA + SWA
        batch, heads_q, heads_kv, seq_len, dim, *window_params = test_case
        if len(window_params) == 1:
            window_size = window_params[0]
            config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}, W={window_size}"
        else:
            window_size, sink_size = window_params
            config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}, W={window_size}, S={sink_size}"
    else:
        batch, heads_q, seq_len, dim = test_case
        heads_kv = heads_q
        config_str = f"B={batch}, H={heads_q}, L={seq_len}, D={dim}"

    q = torch.randn(batch, heads_q, seq_len, dim, device='cuda', dtype=DTYPE)
    k = torch.randn(batch, heads_kv, seq_len, dim, device='cuda', dtype=DTYPE)
    v = torch.randn(batch, heads_kv, seq_len, dim, device='cuda', dtype=DTYPE)

    # Only pass the SWA keyword arguments the implementation under test is
    # expected to accept (sink_size only when the test case specifies one).
    func_kwargs = {"is_causal": is_causal}
    if is_swa:
        func_kwargs["window_size"] = window_size
        if sink_size is not None:
            func_kwargs["sink_size"] = sink_size

    try:
        triton_out = triton_func(q, k, v, **func_kwargs)
        naive_out, _ = naive_attention(q, k, v, is_causal=is_causal, window_size=window_size, sink_size=sink_size)

        # torch.allclose broadcasts its arguments, so a wrongly-shaped output
        # (e.g. a single head) could otherwise pass: require an exact match.
        if tuple(triton_out.shape) != tuple(naive_out.shape):
            print(
                f"❌ P{problem_num} Correctness Test Failed! ({config_str}) "
                f"Shape mismatch: got {tuple(triton_out.shape)}, expected {tuple(naive_out.shape)}"
            )
            return False

        # Compare in float32: allclose rejects mixed dtypes, and an output in a
        # different floating dtype should be compared rather than errored out.
        triton_f32 = triton_out.float()
        naive_f32 = naive_out.float()
        if torch.allclose(triton_f32, naive_f32, atol=1e-2, rtol=1e-2):
            print(f"✅ P{problem_num} Correctness Test Passed! ({config_str})")
            return True

        max_diff = torch.max(torch.abs(triton_f32 - naive_f32)).item()
        print(f"❌ P{problem_num} Correctness Test Failed! ({config_str}) Max diff: {max_diff:.6f}")
        return False

    except Exception as e:
        # Grader boundary: report the student's error and count it as a failure.
        print(f"❌ P{problem_num} Error during execution ({config_str}): {str(e)}")
        return False

def benchmark_attention(triton_func, naive_func, test_params, is_causal, is_gqa=False, is_swa=False):
    """Benchmark ``triton_func`` against ``naive_func`` and print a comparison table.

    ``test_params`` layout depends on the flags:
      - GQA only (``is_gqa and not is_swa``): (B, Hq, Hkv, L, D)
      - SWA (``is_swa``): (B, Hq, Hkv, L, D, W) or (B, Hq, Hkv, L, D, W, S)
      - otherwise (MHA): (B, H, L, D)

    Each implementation gets 10 warm-up calls followed by 100 timed calls on
    the same random CUDA inputs. The reported time is the mean wall-clock time
    per timed call; the memory is the CUDA peak allocation since the start of
    that implementation's run (warm-up and the shared inputs included).
    """
    print("\n--- Running Performance Benchmark ---")
    window_size, sink_size = None, None
    if is_gqa and not is_swa: # GQA only
        batch, heads_q, heads_kv, seq_len, dim = test_params
        config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}"
    elif is_swa: # GQA + SWA
        batch, heads_q, heads_kv, seq_len, dim, *window_params = test_params
        if len(window_params) == 1:
            window_size = window_params[0]
            config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}, W={window_size}"
        else:
            window_size, sink_size = window_params
            config_str = f"B={batch}, Hq={heads_q}, Hkv={heads_kv}, L={seq_len}, D={dim}, W={window_size}, S={sink_size}"
    else:
        batch, heads_q, seq_len, dim = test_params
        heads_kv = heads_q
        config_str = f"B={batch}, H={heads_q}, L={seq_len}, D={dim}"

    print(f"Benchmark Config: {config_str}, Causal={is_causal}")

    q = torch.randn(batch, heads_q, seq_len, dim, device='cuda', dtype=DTYPE)
    k = torch.randn(batch, heads_kv, seq_len, dim, device='cuda', dtype=DTYPE)
    v = torch.randn(batch, heads_kv, seq_len, dim, device='cuda', dtype=DTYPE)

    n_warmup, n_iters = 10, 100

    def _call(func, is_triton):
        """Call one implementation with the keyword arguments it expects."""
        if not is_triton:
            # The naive reference accepts every option.
            return func(q, k, v, is_causal=is_causal, window_size=window_size, sink_size=sink_size)
        kwargs = {"is_causal": is_causal}
        if is_swa:
            kwargs["window_size"] = window_size
            if sink_size is not None:
                kwargs["sink_size"] = sink_size
        return func(q, k, v, **kwargs)

    def _run_benchmark(func, is_triton):
        """Return (average ms per call, peak CUDA memory in GB) for ``func``."""
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

        for _ in range(n_warmup):
            _call(func, is_triton)

        # CUDA work is asynchronous: synchronize around the timed region so it
        # measures completed GPU work rather than just kernel-launch overhead.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(n_iters):
            _call(func, is_triton)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        avg_time = elapsed / n_iters * 1000  # Convert to ms
        peak_memory = torch.cuda.max_memory_allocated() / 1e9  # Convert to GB

        return avg_time, peak_memory

    # Benchmark both implementations
    triton_time, triton_memory = _run_benchmark(triton_func, is_triton=True)
    naive_time, naive_memory = _run_benchmark(naive_func, is_triton=False)

    print("\n--- Benchmark Results ---")
    print(f"{'Implementation':<25} | {'Avg Time (ms)':<20} | {'Peak Memory (GB)':<20}")
    print("-" * 70)
    print(f"{'PyTorch (Naive)':<25} | {naive_time:<20.4f} | {naive_memory:<20.4f}")
    print(f"{'Triton (Flash)':<25} | {triton_time:<20.4f} | {triton_memory:<20.4f}")
    print("-" * 70)

    speedup = naive_time / triton_time
    memory_reduction = naive_memory / triton_memory

    print(f"Triton is {speedup:.2f}x faster than PyTorch (Naive).")
    print(f"Triton uses {memory_reduction:.2f}x less memory.")

def check_problem_6():
    """Autograde Problem 6 (sliding-window attention).

    Compares ``flash_attention_forward`` with ``naive_attention`` on four causal
    GQA + sliding-window configurations and, only if every one of them passes,
    benchmarks the largest configuration.
    """
    problem_num = 6
    print(f"\n--- Running Autograder for Problem {problem_num}: Sliding Window Attention ---")

    torch.manual_seed(47)
    window_size = 128
    # Each case: (Batch, Heads_Q, Heads_KV, SeqLen, Dim, WindowSize)
    test_cases = [
        (1, heads_q, 2, seq_len, 16, window_size)
        for heads_q, seq_len in ((8, 512), (8, 1024), (16, 2048), (16, 4096))
    ]

    passed = [
        run_correctness_test(case, flash_attention_forward, is_causal=True, is_gqa=True, is_swa=True, problem_num=problem_num)
        for case in test_cases
    ]
    if not all(passed):
        print(f"\n❌ Some P{problem_num} tests failed. Please check your implementation.")
        return

    print(f"\nAll P{problem_num} correctness tests passed!")
    benchmark_attention(flash_attention_forward, naive_attention, test_cases[-1], is_causal=True, is_gqa=True, is_swa=True)

# Run the autograder; it needs a CUDA device, so otherwise just explain why not.
if not torch.cuda.is_available():
    print("❌ CUDA not available. Please run this on a GPU-enabled environment.")
else:
    print("🚀 Starting Problem 6 Autograder...")
    print("📝 Testing: GQA + Sliding Window Attention + Causal Masking")
    check_problem_6()
