In [None]:
import torch
import time

def linear_attention_flasmla_inspired(q, k, v, eps=1e-6):
    """
    FlashMLA-inspired linear attention (pure PyTorch, no flash-attn library).

    Kernelized attention with the feature map phi(x) = elu(x) + 1:

        out_i = phi(q_i) @ (sum_j phi(k_j)^T v_j) / (phi(q_i) . sum_j phi(k_j))

    Contracting phi(K)^T V first keeps the cost linear in the sequence length
    instead of materializing a (seq_q, seq_k) score matrix. No causal masking
    is applied: every query attends to every key.

    Args:
        q: Queries, shape (batch, heads, seq_q, d).
        k: Keys, shape (batch, heads, seq_k, d).
        v: Values, shape (batch, heads, seq_k, d_v).
        eps: Small constant added to the normalizer so the division stays
            finite if the feature-mapped dot products underflow to zero.

    Returns:
        Tensor of shape (batch, heads, seq_q, d_v).
    """
    # Feature map (elu + 1): mathematically > 0, so the normalizer below is a
    # sum of positive terms and each output row is a weighted average of v.
    q = torch.nn.functional.elu(q) + 1
    k = torch.nn.functional.elu(k) + 1

    # phi(K)^T V: contract over the key sequence axis -> (b, h, d, d_v).
    kv = torch.einsum("bhnd,bhne->bhde", k, v)

    # Sum of feature-mapped keys over the key sequence axis -> (b, h, d).
    k_sum = k.sum(dim=-2)

    # Numerator: phi(Q) (phi(K)^T V) -> (b, h, seq_q, d_v).
    out = torch.einsum("bhnd,bhde->bhne", q, kv)

    # Denominator: phi(q_i) . sum_j phi(k_j) -> (b, h, seq_q, 1).
    normalizer = torch.einsum("bhnd,bhd->bhn", q, k_sum).unsqueeze(-1)

    return out / (normalizer + eps)

def standard_attention(q, k, v, mask=None, softmax_scale=None):
    """
    Standard attention (scaled dot-product attention).

    Args:
        q: Queries, shape (..., seq_q, d).
        k: Keys, shape (..., seq_k, d).
        v: Values, shape (..., seq_k, d_v).
        mask: Optional tensor broadcastable to (..., seq_q, seq_k); positions
            where ``mask == 0`` (or ``False``) are excluded from attention.
        softmax_scale: Multiplier applied to the raw scores; defaults to
            ``1 / sqrt(d)``.

    Returns:
        Tuple ``(out, attention_weights)``: ``out`` has shape
        (..., seq_q, d_v) and ``attention_weights`` has shape
        (..., seq_q, seq_k).
    """
    d_k = q.size(-1)
    if softmax_scale is None:
        softmax_scale = 1.0 / (d_k ** 0.5)

    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale

    if mask is not None:
        # Fill with the most negative finite value of the score dtype rather
        # than a hard-coded -1e9, which overflows float16 and makes
        # masked_fill raise for half-precision inputs.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = torch.softmax(scores, dim=-1)
    out = torch.matmul(attention_weights, v)

    return out, attention_weights

def _time_per_call_ms(fn, num_iters, use_cuda):
    """Return the mean time of ``fn()`` in milliseconds over ``num_iters`` calls."""
    if use_cuda:
        # CUDA work is launched asynchronously, so time it with CUDA events
        # and synchronize before reading the elapsed time (reported in ms).
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / num_iters

    start_time = time.perf_counter()
    for _ in range(num_iters):
        fn()
    return (time.perf_counter() - start_time) / num_iters * 1000


def compare_attention(batch_size, num_heads, seq_len_q, seq_len_k, head_dim, device="cuda", num_iters=100, causal=False, num_warmup=5):
    """
    Compare Standard Attention and FlashMLA-inspired Linear Attention.

    Random q/k/v tensors of shape (batch_size, num_heads, seq_len, head_dim)
    are created on ``device``; each implementation is warmed up and then
    timed over ``num_iters`` calls with autograd disabled.

    Note: when ``causal`` is True only the standard path receives the
    (top-left aligned) causal mask; the linear path has no mask argument, so
    the two timings are not for identical computations.

    Args:
        batch_size, num_heads, seq_len_q, seq_len_k, head_dim: Tensor sizes.
        device: Any value accepted by ``torch.device`` ("cpu", "cuda",
            "cuda:1", ...). CUDA devices are timed with CUDA events.
        num_iters: Number of timed calls per implementation (must be >= 1).
        causal: Apply a causal mask to standard attention.
        num_warmup: Untimed warmup calls per implementation.

    Returns:
        ``(std_time, linear_time)``: mean time per call in milliseconds.

    Raises:
        ValueError: If ``num_iters`` is less than 1.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    # Normalize so "cuda:0" / torch.device("cuda") also take the CUDA path.
    use_cuda = torch.device(device).type == "cuda"

    q = torch.randn(batch_size, num_heads, seq_len_q, head_dim, device=device)
    k = torch.randn(batch_size, num_heads, seq_len_k, head_dim, device=device)
    v = torch.randn(batch_size, num_heads, seq_len_k, head_dim, device=device)

    mask = None
    if causal:
        # Scores are (seq_len_q, seq_len_k), so the mask must match that shape.
        mask = torch.tril(torch.ones(seq_len_q, seq_len_k, device=device, dtype=torch.bool))
        mask = mask.view(1, 1, seq_len_q, seq_len_k)

    def run_standard():
        standard_attention(q, k, v, mask)

    def run_linear():
        linear_attention_flasmla_inspired(q, k, v)

    with torch.no_grad():
        for _ in range(num_warmup):
            run_standard()
            run_linear()

        if use_cuda:
            torch.cuda.synchronize()

        std_time = _time_per_call_ms(run_standard, num_iters, use_cuda)
        linear_time = _time_per_call_ms(run_linear, num_iters, use_cuda)

    return std_time, linear_time

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    batch_size = 32
    num_heads = 8
    head_dim = 64

    # (seq_len_q, seq_len_k, causal) -> (standard ms, linear ms)
    results = {}
    sequence_lengths = [(64, 64), (128, 128), (256, 256), (512, 512),
                        (1024, 1024), (2048, 2048), (4096, 4096)]

    print(f"{'seq_q':>6} {'seq_k':>6} {'standard (ms)':>14} {'linear (ms)':>12} {'speedup':>8}")
    for seq_len_q, seq_len_k in sequence_lengths:
        # The causal benchmark is only run for square score matrices.
        if seq_len_q != seq_len_k:
            continue

        try:
            timings = compare_attention(
                batch_size, num_heads, seq_len_q, seq_len_k,
                head_dim, device, causal=True,
            )
        except torch.cuda.OutOfMemoryError:
            # Standard attention materializes a (batch, heads, seq_q, seq_k)
            # score tensor, which can exhaust GPU memory at long lengths.
            timings = None

        if timings is None:
            torch.cuda.empty_cache()
            print(f"{seq_len_q:>6} {seq_len_k:>6} {'OOM':>14}")
            continue

        std_time, linear_time = timings
        results[(seq_len_q, seq_len_k, True)] = (std_time, linear_time)
        print(f"{seq_len_q:>6} {seq_len_k:>6} {std_time:>14.3f} "
              f"{linear_time:>12.3f} {std_time / linear_time:>7.2f}x")

Device: cuda
