In [8]:
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import time
import math

# Triton kernel for initial exclusion logic and block-level masking
# This kernel processes header data to determine which tokens are 'compatible'
# and computes a block-level mask for optimization.
@triton.jit
def exclude_kernel(
    headers_ptr,       # Pointer to the input header data (e.g., POS tags, affordances, energy, age)
    mask_ptr,          # Output pointer for the token-level exclusion mask (1=keep, 0=exclude)
    block_mast_ptr,    # Output pointer for the block-level exclusion mask (1=keep block, 0=exclude block)
    seq_len,           # Total sequence length
    BLOCK_SIZE: tl.constexpr, # Size of the processing block for each program instance
):
    # Get the program ID (block ID) for the current Triton program instance
    pid = tl.program_id(0)

    # Calculate global offsets for tokens within this block
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Create a mask to ensure we don't access out-of-bounds tokens at the end of the sequence
    mask = offsets < seq_len

    # Load headers for the current block. We assume 4 rows of header data packed sequentially.
    # headers_ptr points to the start of 'h_type', then 'h_afford' starts at seq_len offset, etc.
    h_type   = tl.load(headers_ptr + offsets, mask=mask, other=0)
    h_afford = tl.load(headers_ptr + seq_len + offsets, mask=mask, other=0)
    h_energy = tl.load(headers_ptr + 2 * seq_len + offsets, mask=mask, other=0)
    h_age    = tl.load(headers_ptr + 3 * seq_len + offsets, mask=mask, other=0)

    # Apply a simple exclusion logic based on header values.
    # For demonstration: compatible if (type AND affordance are non-zero) AND (energy >= 3) AND (age <= 5)
    compat = ((h_type & h_afford) != 0) & (h_energy >= 3) & (h_age <= 5)

    # Store the token-level compatibility mask (1 if compatible, 0 if excluded).
    tl.store(mask_ptr + offsets, compat.to(tl.int32), mask=mask)

    # Compute the block-level mask: if any token in the block is compatible, the block should be kept.
    # This allows for early exiting of entire blocks in subsequent kernels.
    block_keep = tl.sum(compat.to(tl.int32), axis=0) > 0

    # Store the block-level mask. Each block has one block_keep value.
    tl.store(block_mast_ptr + pid, block_keep.to(tl.int32))

# Triton kernel for attention computation with integrated exclusion logic and block pruning.
# This kernel performs scaled dot-product attention but only for compatible tokens and blocks.
@triton.jit
def exclude_attention_kernel(
    q_ptr,             # Pointer to the Query tensor
    k_ptr,             # Pointer to the Key tensor
    v_ptr,             # Pointer to the Value tensor
    headers_ptr,       # Pointer to the header data (same as exclude_kernel)
    mask_ptr,          # Output pointer for the token-level exclusion mask (recomputed or loaded)
    block_mask_ptr,    # Input pointer for the block-level exclusion mask (computed by exclude_kernel)
    output_ptr,        # Output pointer for the final attention result
    seq_len,           # Total sequence length
    d_head: tl.constexpr,   # Dimension of each attention head
    BLOCK_SIZE: tl.constexpr # Size of the processing block
):
    pid = tl.program_id(0)  # Program ID (block index)

    # Compute offsets for the current block along the sequence length
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_seq = offsets < seq_len  # Mask out-of-bounds offsets for sequence length

    # Load headers for the current block to re-evaluate or confirm token compatibility
    h_type   = tl.load(headers_ptr + offsets, mask=mask_seq, other=0)
    h_afford = tl.load(headers_ptr + seq_len + offsets, mask=mask_seq, other=0)
    h_energy = tl.load(headers_ptr + 2 * seq_len + offsets, mask=mask_seq, other=0)
    h_age    = tl.load(headers_ptr + 3 * seq_len + offsets, mask=mask_seq, other=0)

    # Simple exclusion logic (vectorized)
    compat = ((h_type & h_afford) != 0) & (h_energy >= 3) & (h_age <= 5)

    # Store the exclusion mask (token level) for this block
    tl.store(mask_ptr + offsets, compat.to(tl.int32), mask=mask_seq)

    # Load the block-level mask, which was pre-computed by exclude_kernel.
    # This allows skipping entire blocks if they contain no compatible tokens.
    block_keep = tl.load(block_mask_ptr + pid)

    # If the block is entirely excluded (no compatible tokens), skip all attention computation for it.
    if block_keep == 0:
        return

    # Load Q, K, V tensors for the current block.
    # The batch and head dimensions are flattened, so we load directly from the sequence dimension.
    q_block_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    q = tl.load(q_ptr + q_block_offsets, mask=mask_seq[:, None], other=0.0)

    k_block_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    k = tl.load(k_ptr + k_block_offsets, mask=mask_seq[:, None], other=0.0)

    v_block_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    v = tl.load(v_ptr + v_block_offsets, mask=mask_seq[:, None], other=0.0)

    # Compute attention scores (scaled dot-product attention)
    scores = tl.dot(q, tl.trans(k)) / (d_head**0.5)

    # Apply exclusion mask to scores: incompatible tokens should not contribute to attention.
    # Expand compat to match score dimensions for element-wise multiplication.
    compat_expanded = compat[:, None].to(tl.float32)
    scores = scores * compat_expanded

    # Softmax implementation (equivalent to F.softmax(scores, dim=-1) in Triton)
    # Subtract max for numerical stability
    scores_max = tl.max(scores, axis=1)[:, None]
    exp_scores = tl.math.exp(scores - scores_max)

    # Apply mask again after exponentiation to ensure excluded tokens remain zeroed out
    exp_scores = exp_scores * compat_expanded

    # Sum exponentials and divide to get attention weights
    sum_exp_scores = tl.sum(exp_scores, axis=1)[:, None]
    # Handle potential division by zero if all values in a row were masked out
    attn_weights = tl.where(sum_exp_scores > 0, exp_scores / sum_exp_scores, 0.0)

    # Apply attention weights to the Value tensor
    result = tl.dot(attn_weights, v)

    # Store the final attention result for the current block
    result_offsets = (offsets[:, None] * d_head) + tl.arange(0, d_head)[None, :]
    tl.store(output_ptr + result_offsets, result, mask=mask_seq[:, None])


# Python function to test baseline (standard PyTorch) attention performance
def test_baseline(seq_len=64, device="cuda"):

    print(f"Testing baseline with seq_len={seq_len}")
    print(f"-----")

    batch = 1
    heads = 1
    d_head = 64

    # Initialize dummy Query and Key tensors
    q = torch.randn(batch, heads, seq_len, d_head, device=device)
    k = torch.randn(batch, heads, seq_len, d_head, device=device)
    # Value tensor is also needed for scaled_dot_product_attention, even if not explicitly passed to matmul
    v = torch.randn(batch, heads, seq_len, d_head, device=device)

    torch.cuda.synchronize()
    start = time.time()

    # Use PyTorch's native scaled_dot_product_attention for a fair comparison
    # This function handles the matmul, scaling, softmax, and multiplication with V
    dummy_attn = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=False
    )
    torch.cuda.synchronize()
    end = time.time()

    print(f"Baseline full matmul time: {end - start:.6f}s")

# Python function to test the custom Triton attention kernel with exclusion
def test_kernel_attention(seq_len=64, device="cuda"):
    print(f"Test_kernel_attention with seq_len={seq_len}")
    print(f"-----")

    batch = 1
    heads = 1
    d_head = 64
    BLOCK_SIZE = 32 # Define the block size for Triton kernels

    # Initialize input tensors (Q, K, V) with random data
    q = torch.randn(batch, heads, seq_len, d_head, device=device)
    k = torch.randn(batch, heads, seq_len, d_head, device=device)
    v = torch.randn(batch, heads, seq_len, d_head, device=device)

    # Headers: Simulate part-of-speech related metadata or other exclusion criteria.
    # Stored as a flat tensor, 4 categories (type, afford, energy, age), each of seq_len.
    # To ensure more blocks are skipped, we will make the conditions for 'compat' less likely to be met.

    # headers = torch.randint(0, 10, (4 * seq_len,), device=device, dtype=torch.int32)
    # Generate header components separately
    h_type_data = torch.randint(0, 10, (seq_len,), device=device, dtype=torch.int32)
    h_afford_data = torch.randint(0, 10, (seq_len,), device=device, dtype=torch.int32)
    # Make h_energy have a mix of values >=3 and <3
    h_energy_data = torch.randint(0, 5, (seq_len,), device=device, dtype=torch.int32) # Range [0, 4]
    # Make h_age have a mix of values <=5 and >5
    h_age_data = torch.randint(0, 8, (seq_len,), device=device, dtype=torch.int32)    # Range [0, 7]

    # Concatenate them to form the headers tensor as expected by the kernel
    headers = torch.cat([h_type_data, h_afford_data, h_energy_data, h_age_data])

    # Token-level mask: will store 1 or 0 for each token indicating compatibility.
    mask = torch.zeros(seq_len, device=device, dtype=torch.int32)

    # Output tensor for storing the final attention result from the Triton kernel.
    output = torch.zeros_like(q)

    # Block-level mask: Stores 1 or 0 for each block (seq_len / BLOCK_SIZE blocks).
    # This mask determines if an entire block can be skipped.
    num_blocks = triton.cdiv(seq_len, BLOCK_SIZE)
    block_mask = torch.zeros(num_blocks, device=device, dtype=torch.int32)

    # Define the grid for Triton kernel launch (one program per block)
    grid = (num_blocks,)

    torch.cuda.synchronize()
    start = time.time()

    # Step 1: Run the `exclude_kernel` to populate the token-level and block-level masks.
    # This kernel identifies compatible tokens and sets the 'block_mask' for optimization.
    exclude_kernel[grid](
        headers,
        mask,
        block_mask, # Populated by exclude_kernel
        seq_len,
        BLOCK_SIZE=BLOCK_SIZE
    )

    torch.cuda.synchronize()
    end = time.time()
    print(f"Exclusion kernel time: {end - start:.6f}s")
    print(f"Block mask (first 100 values): {block_mask[:100].tolist()}...")
    print(f"Number of blocks to process (after exclusion): {block_mask.sum().item()}")
    print(f"Total blocks: {len(block_mask)}")

    torch.cuda.synchronize()
    start = time.time()

    # Step 2: Run the `exclude_attention_kernel`.
    # This kernel performs attention, leveraging the 'block_mask' for early exits
    # and applying the token-level exclusion logic during score calculation and softmax.
    exclude_attention_kernel[grid](
        q, k, v, headers, mask, block_mask, output, seq_len, d_head, BLOCK_SIZE=BLOCK_SIZE
    )

    torch.cuda.synchronize()
    end = time.time()
    print(f"Kernel time (exclusion + attention): {end - start:.6f}s")


# Loop through different sequence lengths to test performance scaling
for seq_len in [64, 2048, 4096, 32768]:
  test_baseline(seq_len)
  test_kernel_attention(seq_len)


Testing baseline with seq_len=64
-----
Baseline full matmul time: 0.000220s
Test_kernel_attention with seq_len=64
-----
Exclusion kernel time: 0.005018s
Block mask (first 100 values): [1, 1]...
Number of blocks to process (after exclusion): 2
Total blocks: 2
Kernel time (exclusion + attention): 0.004723s
Testing baseline with seq_len=2048
-----
Baseline full matmul time: 0.001172s
Test_kernel_attention with seq_len=2048
-----
Exclusion kernel time: 0.000068s
Block mask (first 100 values): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]...
Number of blocks to process (after exclusion): 64
Total blocks: 64
Kernel time (exclusion + attention): 0.000102s
Testing baseline with seq_len=4096
-----
Baseline full matmul time: 0.003628s
Test_kernel_attention with seq_len=4096
-----
Exclusion kernel time: 0.000058s
Block mask (first 100 values): [1, 1, 1