<a href="https://colab.research.google.com/github/dljones555/llm_block_exclusion/blob/main/block_level_exclusion_mve.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import time
import math

@triton.jit
def exclude_kernel(
    headers_ptr,
    mask_ptr,
    block_mast_ptr,
    seq_len,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute per-token and per-block keep flags from packed header rows.

    Each program instance handles ``BLOCK_SIZE`` consecutive token positions.

    Args:
        headers_ptr: Base pointer to four header rows of ``seq_len`` entries
            each, stored back to back (type, afford, energy, age).
        mask_ptr: Output pointer; one int32 per token (1 = keep, 0 = excluded).
        block_mast_ptr: Output pointer; one int32 per program instance
            (1 if at least one token in the block is kept).
            NOTE(review): name is presumably a typo for "block_mask_ptr";
            left unchanged because it is part of the kernel signature.
        seq_len: Number of tokens in the sequence.
        BLOCK_SIZE: Compile-time number of tokens handled per program.
    """
    block_id = tl.program_id(0)

    token_idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = token_idx < seq_len

    # The four header rows live one after another, each seq_len entries long.
    # Out-of-range lanes read 0, which can never satisfy the keep predicate.
    row_type = headers_ptr + token_idx
    row_afford = row_type + seq_len
    row_energy = row_type + 2 * seq_len
    row_age = row_type + 3 * seq_len

    type_bits = tl.load(row_type, mask=in_range, other=0)
    afford_bits = tl.load(row_afford, mask=in_range, other=0)
    energy = tl.load(row_energy, mask=in_range, other=0)
    age = tl.load(row_age, mask=in_range, other=0)

    # A token is kept when type and afford share a bit, energy is at least 3
    # and age is at most 5.
    shares_bit = (type_bits & afford_bits) != 0
    keep = shares_bit & (energy >= 3) & (age <= 5)
    keep_i32 = keep.to(tl.int32)

    # Token-level flags: 1 = keep, 0 = excluded.
    tl.store(mask_ptr + token_idx, keep_i32, mask=in_range)

    # Block-level flag: the block survives if any of its tokens is kept.
    kept_in_block = tl.sum(keep_i32, axis=0)
    block_flag = kept_in_block > 0
    tl.store(block_mast_ptr + block_id, block_flag.to(tl.int32))

@triton.jit
def exclude_attention_kernel(
    q_ptr, k_ptr, v_ptr, headers_ptr, mask_ptr, block_mask_ptr, output_ptr,
    seq_len, d_head: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """Recompute the token keep-mask and run block-local attention.

    Each program instance handles ``BLOCK_SIZE`` consecutive token positions
    and attends only within that block (queries and keys/values use the same
    offsets). Blocks whose entry in ``block_mask_ptr`` is 0 return early and
    leave ``output_ptr`` untouched for their rows.

    Args:
        q_ptr, k_ptr, v_ptr: Base pointers indexed as ``token * d_head + dim``,
            i.e. treated as contiguous ``(seq_len, d_head)`` matrices.
            NOTE(review): no batch/head stride is applied — confirm callers
            only pass a single batch and head.
        headers_ptr: Four header rows of ``seq_len`` entries each, stored back
            to back (type, afford, energy, age).
        mask_ptr: Output; one int32 per token (1 = keep, 0 = excluded).
        block_mask_ptr: Input; one int32 per program (0 = skip this block).
        output_ptr: Output, indexed like ``q_ptr``.
        seq_len: Number of tokens in the sequence.
        d_head: Compile-time head dimension.
        BLOCK_SIZE: Compile-time number of tokens handled per program.
    """
    pid = tl.program_id(0)  # Program ID (block)

    # Token positions covered by this program, and which of them are in range.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_seq = offsets < seq_len

    # Load headers (4 rows packed sequentially); out-of-range lanes read 0.
    h_type   = tl.load(headers_ptr + offsets, mask=mask_seq, other=0)
    h_afford = tl.load(headers_ptr + seq_len + offsets, mask=mask_seq, other=0)
    h_energy = tl.load(headers_ptr + 2 * seq_len + offsets, mask=mask_seq, other=0)
    h_age    = tl.load(headers_ptr + 3 * seq_len + offsets, mask=mask_seq, other=0)

    # Keep predicate: type/afford share a bit, energy >= 3, age <= 5.
    compat = ((h_type & h_afford) != 0) & (h_energy >= 3) & (h_age <= 5)

    # Store the exclusion mask (token level). This happens even for blocks
    # that are skipped below.
    tl.store(mask_ptr + offsets, compat.to(tl.int32), mask=mask_seq)

    # Load the block-level flag and skip the whole block when it is excluded.
    block_keep = tl.load(block_mask_ptr + pid)
    if block_keep == 0:
        return

    # Q, K, V and the output all share the same (token, dim) tile addressing.
    tile_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    tile_mask = mask_seq[:, None]

    q = tl.load(q_ptr + tile_offsets, mask=tile_mask, other=0.0)
    k = tl.load(k_ptr + tile_offsets, mask=tile_mask, other=0.0)
    v = tl.load(v_ptr + tile_offsets, mask=tile_mask, other=0.0)

    # Scaled dot-product scores for this block: (BLOCK_SIZE, BLOCK_SIZE).
    scores = tl.dot(q, tl.trans(k)) / (d_head**0.5)

    # Key columns past seq_len were loaded as zeros; push them to -inf so they
    # get zero weight instead of diluting the softmax of a partial last block.
    # At least one key column is always in range, so every row keeps a finite
    # maximum.
    scores = tl.where(mask_seq[None, :], scores, float("-inf"))

    # Numerically stable softmax along the key axis.
    scores_max = tl.max(scores, axis=1)[:, None]
    exp_scores = tl.math.exp(scores - scores_max)

    # The exclusion mask is applied along the query axis, as in the original:
    # excluded query rows get all-zero weights.
    # NOTE(review): excluded tokens still participate as keys/values — confirm
    # whether key-side masking was intended.
    exp_scores = tl.where(compat[:, None], exp_scores, 0.0)

    # Excluded rows sum to 0; divide by 1 there so they yield zeros, not NaN.
    sum_exp_scores = tl.sum(exp_scores, axis=1)[:, None]
    safe_sum = tl.where(sum_exp_scores > 0.0, sum_exp_scores, 1.0)
    attn_weights = exp_scores / safe_sum

    # Apply attention to V and store the result for in-range rows.
    result = tl.dot(attn_weights, v)
    tl.store(output_ptr + tile_offsets, result, mask=tile_mask)


def test_baseline(seq_len=64, device="cuda"):

    print(f"Testing baseline with seq_len={seq_len}")
    print(f"-----")

    batch = 1
    heads = 1
    d_head = 64

    q = torch.randn(batch, heads, seq_len, d_head, device=device)
    k = torch.randn(batch, heads, seq_len, d_head, device=device)

    torch.cuda.synchronize()
    start = time.time()

    # pure matmul timing - no softmax
    dummy_scores = torch.matmul(q, k.transpose(-2, -1))

    # dummy_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_head)

    # dummy_attn = F.softmax(dummy_scores, dim=-1)
    dummy_attn = torch.nn.functional.scaled_dot_product_attention(
        q, k, k, is_causal=False
    )
    torch.cuda.synchronize()
    end = time.time()

    print(f"Baseline full matmul time: {end - start:.6f}s")

def test_kernel_attention(seq_len=64, device="cuda"):
    """Run and time the exclusion kernel and the exclusion + attention kernel.

    Builds random Q/K/V tensors of shape ``(1, 1, seq_len, 64)`` and random
    header values in ``[0, 10)``, launches ``exclude_kernel`` to fill the
    token- and block-level masks, prints block statistics, then launches
    ``exclude_attention_kernel`` and prints its elapsed time.

    Args:
        seq_len: Number of tokens in the sequence.
        device: Torch device for all tensors.
    """
    print(f"Test_kernel_attention with seq_len={seq_len}")
    print("-----")

    batch = 1
    heads = 1
    d_head = 64
    BLOCK_SIZE = 32

    # One program instance per block; the last block may be partial.
    num_blocks = triton.cdiv(seq_len, BLOCK_SIZE)
    grid = (num_blocks,)

    # Random inputs (Q, K, V).
    q = torch.randn(batch, heads, seq_len, d_head, device=device)
    k = torch.randn(batch, heads, seq_len, d_head, device=device)
    v = torch.randn(batch, heads, seq_len, d_head, device=device)

    # Four header rows of seq_len entries each, packed back to back.
    headers = torch.randint(0, 10, (4 * seq_len,), device=device, dtype=torch.int32)

    # Token-level keep mask written by the kernels.
    mask = torch.zeros(seq_len, device=device, dtype=torch.int32)

    # Attention output; rows of skipped blocks keep this zero initialisation.
    output = torch.zeros_like(q)

    # Block-level keep mask. It must have one slot per launched program:
    # sizing it with seq_len // BLOCK_SIZE would leave the partial last block
    # without a slot and make the kernels access memory past the buffer.
    block_mask = torch.zeros(num_blocks, device=device, dtype=torch.int32)

    torch.cuda.synchronize()
    start = time.perf_counter()

    # First pass: token-level mask and block-level mask.
    exclude_kernel[grid](
        headers,
        mask,
        block_mask,
        seq_len,
        BLOCK_SIZE=BLOCK_SIZE
    )

    torch.cuda.synchronize()
    end = time.perf_counter()
    print(f"Exclusion kernel time: {end - start:.6f}s")
    print(f"Block mask (first 100 values): {block_mask[:100].tolist()}")
    print(f"Number of blocks to process: {block_mask.sum().item()}")
    print(f"Total blocks: {len(block_mask)}")

    torch.cuda.synchronize()
    start = time.perf_counter()

    # Second pass: exclusion + block-local attention.
    exclude_attention_kernel[grid](
        q, k, v, headers, mask, block_mask, output, seq_len, d_head, BLOCK_SIZE=BLOCK_SIZE
    )

    torch.cuda.synchronize()
    end = time.perf_counter()
    print(f"Kernel time (exclusion + attention): {end - start:.6f}s")


# Run the baseline and the Triton kernels back to back for each sequence
# length, from smallest to largest.
for seq_len in (64, 2048, 4096, 32768):
    test_baseline(seq_len)
    test_kernel_attention(seq_len)

Testing baseline with seq_len=64
-----
Baseline full matmul time: 0.000442s
Test_kernel_attention with seq_len=64
-----
Exclusion kernel time: 0.003163s
Block mask (first 100 values): [1, 1]
Number of blocks to process: 2
Total blocks: 2
Kernel time (exclusion + attention): 0.004679s
Testing baseline with seq_len=2048
-----
Baseline full matmul time: 0.001466s
Test_kernel_attention with seq_len=2048
-----
Exclusion kernel time: 0.000094s
Block mask (first 100 values): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Number of blocks to process: 64
Total blocks: 64
Kernel time (exclusion + attention): 0.000103s
Testing baseline with seq_len=4096
-----
Baseline full matmul time: 0.004552s
Test_kernel_attention with seq_len=4096
-----
Exclusion kernel time: 0.000059s
Block mask (first 100 values): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1