In [None]:
# Worked solution for Problem 7: reference PyTorch attention with GQA, sliding window, and sink tokens
import torch
import math

def flash_attention_forward(q, k, v, is_causal=True, window_size=None, sink_size=None):
    """
    Problem 7: Grouped-Query Attention with Sliding Window + Sink Tokens

    Exact softmax attention computed by materializing the full (L, L) score
    matrix. Each group of Hq // Hkv query heads shares one key/value head.

    Args:
        q: (B,Hq,L,D)
        k: (B,Hkv,L,D), Hq must be a multiple of Hkv
        v: (B,Hkv,L,D)
        is_causal: bool. When False no mask is applied at all, so
            window_size and sink_size are ignored.
        window_size: int or None. Each query attends to itself and the
            previous window_size - 1 positions; None means the full causal
            prefix. Expected to be >= 1: a smaller value leaves rows without
            any visible key (unless sink tokens cover them) and softmax over
            such a row yields NaN.
        sink_size: int or None. The first sink_size positions stay visible
            to every later query even when they fall outside the window.
            None or a value <= 0 disables sink tokens.
    Returns:
        out: (B,Hq,L,D), same dtype as q
    Raises:
        ValueError: if k or v does not have shape (B,Hkv,L,D), or if Hq is
            not a multiple of Hkv.
    """
    B, Hq, L, D = q.shape
    _, Hkv, _, _ = k.shape

    # Explicit checks instead of `assert`: asserts vanish under `python -O`, and
    # a mismatched k could otherwise broadcast inside einsum and silently
    # produce wrong values.
    kv_shape = (B, Hkv, L, D)
    if tuple(k.shape) != kv_shape or tuple(v.shape) != kv_shape:
        raise ValueError(
            f"k and v must both have shape {kv_shape}, "
            f"got k={tuple(k.shape)} and v={tuple(v.shape)}"
        )
    if Hkv == 0 or Hq % Hkv != 0:
        raise ValueError(
            f"number of query heads ({Hq}) must be a multiple of "
            f"the number of key/value heads ({Hkv})"
        )

    # Repeat kv for GQA if needed: query head h reads kv head h // num_groups
    if Hq != Hkv:
        num_groups = Hq // Hkv
        k = k.repeat_interleave(num_groups, dim=1)
        v = v.repeat_interleave(num_groups, dim=1)

    scale = 1.0 / math.sqrt(D)
    scores = torch.einsum("bhld,bhmd->bhlm", q, k) * scale

    if is_causal:
        idx = torch.arange(L, device=q.device)
        row, col = idx[:, None], idx[None, :]  # row = query position, col = key position
        # sliding window mask (always restricted to the causal prefix)
        if window_size is None:
            sliding = col <= row
        else:
            sliding = (col <= row) & (col >= row - (window_size - 1))
        # sink tokens mask: leading keys stay visible, still causally limited
        if sink_size is not None and sink_size > 0:
            sink = (col < sink_size) & (col <= row)
            mask = sliding | sink
        else:
            mask = sliding
        # Broadcast the (L, L) mask over batch and head dimensions
        mask = mask.view(1, 1, L, L)
        scores = scores.masked_fill(~mask, float("-inf"))

    # Softmax with explicit float32 for stability, then cast back
    weights = torch.softmax(scores.to(torch.float32), dim=-1).to(q.dtype)
    out = torch.einsum("bhlm,bhmd->bhld", weights, v)
    return out

if __name__ == "__main__":
    # Smoke test: one GQA configuration (8 query heads sharing 2 kv heads)
    # with a 32-token window and 4 sink tokens, on GPU when one is present.
    torch.manual_seed(0)
    B, Hq, Hkv, L, D = 1, 8, 2, 128, 16
    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(B, Hq, L, D, device=device)
    k = torch.randn(B, Hkv, L, D, device=q.device)
    v = torch.randn(B, Hkv, L, D, device=q.device)
    out = flash_attention_forward(
        q, k, v,
        is_causal=True,
        window_size=32,
        sink_size=4,
    )
    print("smoke test out:", out.shape, out.dtype)


In [None]:
# Autograder Cell for Problem 7 (Second Cell)
import torch, math, time
# Precision used for the autograder inputs: bfloat16 when the GPU supports it,
# float16 otherwise. is_available() is checked first because some PyTorch
# releases query the current CUDA device inside is_bf16_supported() and raise
# on CPU-only machines, which would abort this cell before it can print its
# "CUDA not available" message.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

def repeat_kv(x, num_groups):
    """Duplicate every key/value head num_groups times along the head axis.

    x is (B, H_kv, N, D); the result is (B, H_kv * num_groups, N, D) with the
    copies of each head placed next to each other. When num_groups is 1 the
    input tensor itself is returned.
    """
    if num_groups == 1:
        return x
    batch, kv_heads, seq_len, head_dim = x.shape
    # Insert a group axis after the head axis, broadcast it, then fold it
    # back into the head axis.
    grouped = x[:, :, None, :, :].expand(batch, kv_heads, num_groups, seq_len, head_dim)
    return grouped.reshape(batch, kv_heads * num_groups, seq_len, head_dim)

def create_mask_bool(seq_len, window_size, sink_size, device=None):
    """Build the (seq_len, seq_len) boolean visibility mask; True = may attend.

    A key position is visible to a query position when it is not in the
    future and is either one of the last window_size positions up to and
    including the query, or one of the first sink_size positions.
    """
    positions = torch.arange(seq_len, device=device)
    query_pos = positions[:, None]
    key_pos = positions[None, :]

    not_future = key_pos <= query_pos
    inside_window = (query_pos - key_pos) <= (window_size - 1)
    is_sink = key_pos < sink_size
    # Causality applies to both the window and the sink tokens.
    return not_future & (inside_window | is_sink)

def naive_attention(Q, K, V, is_causal=False, window_size=None, sink_size=None):
    """Reference attention that materializes the full score matrix.

    Q is (B, Hq, L, D); K and V are (B, Hkv, L, D) and are repeated across
    heads when Hq != Hkv. With is_causal=True a causal mask is applied,
    narrowed to a sliding window when window_size is given and extended with
    sink tokens when sink_size is also given (sink_size has no effect unless
    window_size is set). Returns a tuple (output, logsumexp), where
    logsumexp is taken over the key axis of the masked scores in float32.
    """
    B, Hq, L, D = Q.shape
    Hkv = K.shape[1]
    if Hq != Hkv:
        groups = Hq // Hkv
        K, V = repeat_kv(K, groups), repeat_kv(V, groups)

    scale = 1.0 / math.sqrt(D)
    S = torch.matmul(Q, K.transpose(-2, -1)) * scale

    if is_causal:
        # Choose the effective (window, sink) pair; without a window the mask
        # degenerates to plain causal attention.
        if window_size is None:
            window, sink = L, 0
        elif sink_size is None:
            window, sink = window_size, 0
        else:
            window, sink = window_size, sink_size
        mask = create_mask_bool(L, window, sink, Q.device)
        S.masked_fill_(~mask, -float('inf'))

    # Softmax and logsumexp are both evaluated on the float32 scores.
    S_fp32 = S.to(torch.float32)
    P = torch.softmax(S_fp32, dim=-1).to(Q.dtype)
    lse = torch.logsumexp(S_fp32, dim=-1)
    return P @ V, lse

def run_correctness_test(test_params, student_func):
    """Compare student_func against naive_attention on random CUDA inputs.

    test_params is a tuple (B, Hq, Hkv, L, D, W, SNK): batch, query heads,
    key/value heads, sequence length, head dim, window size and sink size.
    Returns (ok, max_abs_diff); max_abs_diff is reported as 0.0 when the
    outputs agree within rtol=atol=5e-2.
    """
    batch, q_heads, kv_heads, seq_len, head_dim, window, sink = test_params

    def random_heads(num_heads):
        return torch.randn(batch, num_heads, seq_len, head_dim, device='cuda', dtype=DTYPE)

    # Sampling order (q, k, v) is kept so results are reproducible under a seed.
    q = random_heads(q_heads)
    k = random_heads(kv_heads)
    v = random_heads(kv_heads)

    attn_kwargs = dict(is_causal=True, window_size=window, sink_size=sink)
    expected, _ = naive_attention(q, k, v, **attn_kwargs)
    actual = student_func(q, k, v, **attn_kwargs)

    if torch.allclose(expected, actual, rtol=5e-2, atol=5e-2):
        return True, 0.0
    return False, (expected - actual).abs().max().item()

# Autograder driver: checks flash_attention_forward against the reference on
# several sequence lengths; the tests need a CUDA device.
if not torch.cuda.is_available():
    print('CUDA not available; cannot run tests.')
else:
    print('🚀 Running Problem 7 Tests (GQA + SWA + Sink)')
    window_size, sink_size = 128, 8
    # Each case is (B, Hq, Hkv, L, D, W, S); only heads, length and head dim vary.
    cases = [
        (1, heads, 2, length, dim, window_size, sink_size)
        for heads, length, dim in (
            (8, 512, 32),
            (8, 1024, 32),
            (16, 2048, 16),
            (16, 4096, 16),
        )
    ]
    all_ok = True
    for params in cases:
        ok, diff = run_correctness_test(params, flash_attention_forward)
        desc = f"B={params[0]},Hq={params[1]},Hkv={params[2]},L={params[3]},D={params[4]},W={params[5]},S={params[6]}"
        if not ok:
            all_ok = False
            print(f'❌ Failed {desc} max diff {diff:.5f}')
            continue
        print(f'✅ Passed {desc}')
    if all_ok:
        print('\nAll Problem 7 tests passed!')