In [1]:
import torch
import torch.nn as nn
import time

class StandardAttention(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim)
        attn = torch.einsum('bthd,bshd->bhts', q, k) * self.scale
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('bhts,bshd->bthd', attn, v).contiguous().view(B, T, C)
        return self.out_proj(out)

class GQAAttention(nn.Module):
    def __init__(self, dim, n_heads, n_kv_heads):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, self.head_dim * n_kv_heads)
        self.v_proj = nn.Linear(dim, self.head_dim * n_kv_heads)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Repeat k and v to match the number of query heads
        repeat_factor = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(repeat_factor, dim=2)
        v = v.repeat_interleave(repeat_factor, dim=2)

        attn = torch.einsum('bthd,bshd->bhts', q, k) * self.scale
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('bhts,bshd->bthd', attn, v).contiguous().view(B, T, C)
        return self.out_proj(out)

def benchmark_attention(attn_module, input_tensor, label):
    torch.cuda.empty_cache()
    start = time.time()
    output = attn_module(input_tensor)
    end = time.time()
    mem = torch.cuda.memory_allocated() / 1e6 if torch.cuda.is_available() else 0
    print(f"{label}: Output shape: {output.shape}, Time: {end - start:.4f}s, CUDA Mem: {mem:.2f} MB")

In [3]:
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)
    B, T, C = 2, 128, 512   # batch size, sequence length, embedding dim
    n_heads = 16
    n_kv_heads = 4

    x = torch.randn(B, T, C).to(device)

    print("=== Standard Multi-Head Attention ===")
    std_attn = StandardAttention(C, n_heads).to(device)
    benchmark_attention(std_attn, x, "Standard Attention")

    print("\n=== Grouped Query Attention (GQA) ===")
    gqa_attn = GQAAttention(C, n_heads, n_kv_heads).to(device)
    benchmark_attention(gqa_attn, x, "GQA Attention")


=== Standard Multi-Head Attention ===
Standard Attention: Output shape: torch.Size([2, 128, 512]), Time: 0.0038s, CUDA Mem: 0.00 MB

=== Grouped Query Attention (GQA) ===
GQA Attention: Output shape: torch.Size([2, 128, 512]), Time: 0.0029s, CUDA Mem: 0.00 MB
