<a href="https://colab.research.google.com/github/dljones555/llm_block_exclusion/blob/main/exclusion_routing_triton_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import time
import math

@triton.jit
def exclude_kernel(
    headers_ptr,
    mask_ptr,
    seq_len,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < seq_len

    # Load headers (4 rows packed sequentially)
    h_type   = tl.load(headers_ptr + offsets, mask=mask, other=0)
    h_afford = tl.load(headers_ptr + seq_len + offsets, mask=mask, other=0)
    h_energy = tl.load(headers_ptr + 2 * seq_len + offsets, mask=mask, other=0)
    h_age    = tl.load(headers_ptr + 3 * seq_len + offsets, mask=mask, other=0)

    # Simple exclusion logic (vectorized)
    compat = ((h_type & h_afford) != 0) & (h_energy >= 3) & (h_age <= 5)

    # Store 1 = keep, 0 = excluded
    tl.store(mask_ptr + offsets, compat.to(tl.int32), mask=mask)

def test_baseline(seq_len=64, device="cuda"):

    print(f"Testing baseline with seq_len={seq_len}")
    print(f"-----")

    batch = 1
    heads = 1
    d_head = 64

    q = torch.randn(batch, heads, seq_len, d_head, device=device)
    k = torch.randn(batch, heads, seq_len, d_head, device=device)

    torch.cuda.synchronize()
    start = time.time()

    # pure matmul timing - no softmax
    dummy_scores = torch.matmul(q, k.transpose(-2, -1))

    # dummy_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_head)

    # dummy_attn = F.softmax(dummy_scores, dim=-1)
    dummy_attn = torch.nn.functional.scaled_dot_product_attention(
        q, k, k, is_causal=False
    )
    torch.cuda.synchronize()
    end = time.time()

    print(f"Baseline full matmul time: {end - start:.6f}s")

# Test function (run on GPU)
def test_kernel(seq_len=64):

    print(f"Test_kernel with seq_len={seq_len}")
    print(f"-----")

    device = "cuda"

    headers = torch.randint(
        0, 10, (4 * seq_len,),
        device=device,
        dtype=torch.int32
    )

    mask = torch.zeros(seq_len, device=device, dtype=torch.int32)

    BLOCK_SIZE = 32
    grid = (triton.cdiv(seq_len, BLOCK_SIZE),)

    torch.cuda.synchronize()
    start = time.time()

    exclude_kernel[grid](
        headers,
        mask,
        seq_len,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    torch.cuda.synchronize()
    end = time.time()

    print(f"Kernel time: {end - start:.6f}s")
    sparsity = (mask == 0).float().mean().item() * 100
    print(f"Excluded tokens: {sparsity:.2f}%")

seq_len=2048
test_baseline(seq_len)
test_kernel(seq_len)

seq_len=4096
test_baseline(seq_len)
test_kernel(seq_len)

#seq_len=8192
#test_baseline(seq_len)
#test_kernel(seq_len)

sql_len=37268
test_baseline(sql_len)
test_kernel(sql_len)


Testing baseline with seq_len=2048
-----
Baseline full matmul time: 0.001669s
Test_kernel with seq_len=2048
-----
Kernel time: 0.003026s
Excluded tokens: 79.05%
Testing baseline with seq_len=4096
-----
Baseline full matmul time: 0.004678s
Test_kernel with seq_len=4096
-----
Kernel time: 0.000064s
Excluded tokens: 79.83%
Testing baseline with seq_len=37268
-----
Baseline full matmul time: 0.268003s
Test_kernel with seq_len=37268
-----
Kernel time: 0.001621s
Excluded tokens: 79.42%
