<a href="https://colab.research.google.com/github/dljones555/llm_memory_concepts/blob/main/exclusion_routing_triton_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import time
import math

@triton.jit
def exclude_kernel(
    headers_ptr,
    mask_ptr,
    seq_len,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < seq_len

    # Load headers (4 rows packed sequentially)
    h_type   = tl.load(headers_ptr + offsets, mask=mask, other=0)
    h_afford = tl.load(headers_ptr + seq_len + offsets, mask=mask, other=0)
    h_energy = tl.load(headers_ptr + 2 * seq_len + offsets, mask=mask, other=0)
    h_age    = tl.load(headers_ptr + 3 * seq_len + offsets, mask=mask, other=0)

    # Simple exclusion logic (vectorized)
    compat = ((h_type & h_afford) != 0) & (h_energy >= 3) & (h_age <= 5)

    # Store 1 = keep, 0 = excluded
    tl.store(mask_ptr + offsets, compat.to(tl.int32), mask=mask)

@triton.jit
def exclude_attention_kernel(
    q_ptr, k_ptr, v_ptr, headers_ptr, mask_ptr, output_ptr, seq_len, d_head: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)

    # Compute offsets for the current block along the sequence length
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_seq = offsets < seq_len  # Mask out-of-bounds offsets for sequence length

    # Load headers (4 rows packed sequentially)
    h_type   = tl.load(headers_ptr + offsets, mask=mask_seq, other=0)
    h_afford = tl.load(headers_ptr + seq_len + offsets, mask=mask_seq, other=0)
    h_energy = tl.load(headers_ptr + 2 * seq_len + offsets, mask=mask_seq, other=0)
    h_age    = tl.load(headers_ptr + 3 * seq_len + offsets, mask=mask_seq, other=0)

    # Simple exclusion logic (vectorized)
    compat = ((h_type & h_afford) != 0) & (h_energy >= 3) & (h_age <= 5)

    # Store the exclusion mask (assuming mask_ptr is for this purpose)
    tl.store(mask_ptr + offsets, compat.to(tl.int32), mask=mask_seq)

    # Load Q, K, V tensors (flattened batch and head dimensions)
    # Each pid loads a BLOCK_SIZE x d_head block of Q, K, V
    q_block_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    q = tl.load(q_ptr + q_block_offsets, mask=mask_seq[:, None], other=0.0)

    k_block_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    k = tl.load(k_ptr + k_block_offsets, mask=mask_seq[:, None], other=0.0)

    v_block_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    v = tl.load(v_ptr + v_block_offsets, mask=mask_seq[:, None], other=0.0)

    # Compute attention scores (scaled dot-product attention)
    # This performs a (BLOCK_SIZE, d_head) @ (d_head, BLOCK_SIZE) -> (BLOCK_SIZE, BLOCK_SIZE) matmul
    # Note: This is an attention computation *within* the current block of query/key tokens.
    scores = tl.dot(q, tl.trans(k)) / (d_head**0.5)

    # Apply exclusion mask (0 for excluded, no contribution to scores from these positions)
    # Assuming compat is (BLOCK_SIZE,) and scores is (BLOCK_SIZE, BLOCK_SIZE)
    # Expand compat to (BLOCK_SIZE, BLOCK_SIZE) to apply row-wise
    compat_expanded = compat[:, None].to(tl.float32)
    scores = scores * compat_expanded

    # Softmax (implementing F.softmax(scores, dim=-1) in Triton)
    # Subtract max for numerical stability
    scores_max = tl.max(scores, axis=1)[:, None]
    exp_scores = tl.math.exp(scores - scores_max)

    # Apply mask again after exponentiation to ensure excluded tokens are zero
    exp_scores = exp_scores * compat_expanded

    sum_exp_scores = tl.sum(exp_scores, axis=1)[:, None]
    attn_weights = exp_scores / sum_exp_scores

    # Apply attention to V
    # This performs (BLOCK_SIZE, BLOCK_SIZE) @ (BLOCK_SIZE, d_head) -> (BLOCK_SIZE, d_head) matmul
    result = tl.dot(attn_weights, v)

    # Store final attention result in output tensor
    # This assumes output_ptr stores the result of the attention computation
    result_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    tl.store(output_ptr + result_offsets, result, mask=mask_seq[:, None])

def test_baseline(seq_len=64, device="cuda"):

    print(f"Testing baseline with seq_len={seq_len}")
    print(f"-----")

    batch = 1
    heads = 1
    d_head = 64

    q = torch.randn(batch, heads, seq_len, d_head, device=device)
    k = torch.randn(batch, heads, seq_len, d_head, device=device)

    torch.cuda.synchronize()
    start = time.time()

    # pure matmul timing - no softmax
    dummy_scores = torch.matmul(q, k.transpose(-2, -1))

    # dummy_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_head)

    # dummy_attn = F.softmax(dummy_scores, dim=-1)
    dummy_attn = torch.nn.functional.scaled_dot_product_attention(
        q, k, k, is_causal=False
    )
    torch.cuda.synchronize()
    end = time.time()

    print(f"Baseline full matmul time: {end - start:.6f}s")

# Test function (run on GPU)
def test_kernel(seq_len=64):

    print(f"Test_kernel with seq_len={seq_len}")
    print(f"-----")

    device = "cuda"

    headers = torch.randint(
        0, 10, (4 * seq_len,),
        device=device,
        dtype=torch.int32
    )

    mask = torch.zeros(seq_len, device=device, dtype=torch.int32)

    BLOCK_SIZE = 32
    grid = (triton.cdiv(seq_len, BLOCK_SIZE),)

    torch.cuda.synchronize()
    start = time.time()

    exclude_attention_kernel[grid](
        headers,
        mask,
        None,  # We don't need the output tensor here for the exclusion kernel
        seq_len,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    torch.cuda.synchronize()
    end = time.time()

    print(f"Kernel time: {end - start:.6f}s")
    sparsity = (mask == 0).float().mean().item() * 100
    print(f"Excluded tokens: {sparsity:.2f}%")

def test_kernel_attention(seq_len=64, device="cuda"):
    print(f"Test_kernel_attention with seq_len={seq_len}")
    print(f"-----")

    device = "cuda"

    batch = 1
    heads = 1
    d_head = 64

    # Initialize input tensors (Q, K, V)
    q = torch.randn(batch, heads, seq_len, d_head, device=device)
    k = torch.randn(batch, heads, seq_len, d_head, device=device)
    v = torch.randn(batch, heads, seq_len, d_head, device=device)

    # Headers (to simulate your POS-related metadata)
    headers = torch.randint(0, 10, (4 * seq_len,), device=device, dtype=torch.int32)

    # Mask for storing the exclusion mask result
    # Use a separate tensor for storing the attention results

    mask = torch.zeros(seq_len, device=device, dtype=torch.int32)  # For exclusion
    output = torch.zeros_like(q, device=device)  # For storing the final attention output

    BLOCK_SIZE = 32
    grid = (triton.cdiv(seq_len, BLOCK_SIZE),)

    torch.cuda.synchronize()
    start = time.time()

    # Call the exclusion + attention kernel
    exclude_attention_kernel[grid](
        q, k, v, headers, mask, output, seq_len, d_head, BLOCK_SIZE=BLOCK_SIZE
    )

    torch.cuda.synchronize()
    end = time.time()

    print(f"Kernel time (exclusion + attention): {end - start:.6f}s")
    # Sparsity calculation might not be meaningful if mask now holds attention results
    # sparsity = (mask == 0).float().mean().item() * 100
    # print(f"Excluded tokens: {sparsity:.2f}%")

#
for seq_len in [64, 2048, 4096, 32768]:
  test_baseline(seq_len)
  test_kernel_attention(seq_len)


Testing baseline with seq_len=64
-----
Baseline full matmul time: 0.000264s
Test_kernel_attention with seq_len=64
-----
Kernel time (exclusion + attention): 0.004596s
Testing baseline with seq_len=2048
-----
Baseline full matmul time: 0.001473s
Test_kernel_attention with seq_len=2048
-----
Kernel time (exclusion + attention): 0.000108s
Testing baseline with seq_len=4096
-----
Baseline full matmul time: 0.004643s
Test_kernel_attention with seq_len=4096
-----
Kernel time (exclusion + attention): 0.000138s
Testing baseline with seq_len=32768
-----
Baseline full matmul time: 0.223682s
Test_kernel_attention with seq_len=32768
-----
Kernel time (exclusion + attention): 0.000397s
