In [1]:
import torch
import triton
import triton.language as tl

from typing import Tuple

In [2]:
def _pick_block_d(head_dim: int) -> int:
    for block_d in (128, 64, 32):
        if head_dim >= block_d:
            return block_d
    return 16

In [3]:
@triton.jit
def _rope_fused_kernel(
    x_ptr, cos_ptr, sin_ptr, out_ptr,
    STRIDE_ROW: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    ROPE_DIM: tl.constexpr,
    ROPE_OFFSET: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_d = tl.program_id(axis=1)

    offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    
    offs_rope = offs_d - ROPE_OFFSET
    use_rope = offs_rope >= 0

    cos = tl.load(cos_ptr + tl.where(use_rope, offs_rope, 0),mask=use_rope, other=1.0)
    sin = tl.load(sin_ptr + tl.where(use_rope, offs_rope, 0),mask=use_rope, other=0.0)

    offs = pid_m * STRIDE_ROW + offs_d
    is_in_bounds = offs_d < HEAD_DIM

    x = tl.load(x_ptr + offs, mask=is_in_bounds)

    half_rope_dim = tl.constexpr(ROPE_DIM // 2)
    is_first_half = offs_rope < half_rope_dim
    
    rope_partner = tl.load(x_ptr + offs + tl.where(is_first_half, half_rope_dim, -half_rope_dim), mask=use_rope, other=0.0)
    out = tl.where(use_rope, x * cos + tl.where(is_first_half, -1.0,  1.0) * rope_partner * sin, x)

    tl.store(out_ptr + offs, out, mask=is_in_bounds)

In [4]:
def apply_rotary_pos_emb_triton(
    x: torch.Tensor,
    cos_sin: Tuple[torch.Tensor, torch.Tensor],
    *,
    block_d: int | None = None,
) -> torch.Tensor:
    if x.device.type != "cuda":
        raise RuntimeError("Triton kernel requires CUDA tensor")
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError("x must be fp16, bf16, or fp32")

    cos, sin = cos_sin

    head_dim = x.size(-1)
    rope_dim = cos.size(-1)

    if rope_dim % 2:
        raise ValueError("rope_dim must be even")
    if rope_dim > head_dim:
        raise ValueError("rope_dim should be less than or equal to head_dim")

    x_flat = x.contiguous().view(-1, head_dim)
    out = torch.empty_like(x_flat)

    block_d = block_d or _pick_block_d(head_dim)
    if block_d & (block_d - 1):
        raise ValueError("block_d should be a power of two")
    if block_d > 128:
        raise ValueError("block_d should be less than or equal to 128")

    num_rows = x_flat.shape[0]
    num_chunks = (head_dim + block_d - 1) // block_d
    
    _rope_fused_kernel[(num_rows, num_chunks)](
        x_flat, cos, sin, out,
        head_dim,
        head_dim,
        rope_dim,
        head_dim - rope_dim,
        BLOCK_D=block_d,
        num_warps=1,
    )

    return out.view_as(x)

In [5]:
def apply_rotary_pos_emb(
    x: torch.Tensor, cos_sin: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
    cos, sin = cos_sin

    head_dim = x.size(-1)
    rope_dim = cos.size(-1)

    if head_dim == rope_dim:
        x = (x * cos) + (_rotate_half(x) * sin)
    elif rope_dim < head_dim:
        x_nope, x_rope = x.split((head_dim - rope_dim, rope_dim), dim=-1)
        x_rope = (x_rope * cos) + (_rotate_half(x_rope) * sin)
        x = torch.cat([x_nope, x_rope], dim=-1)
    else:
        raise ValueError("rope_dim should be less than head_dim")

    return x

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

In [9]:
if __name__ == "__main__":
    torch.manual_seed(42)
    batch_size = 32
    seq_len = 1024
    num_heads = 8
    head_dim = 64
    rope_dim = 48

    x = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32)
    cos = torch.randn(rope_dim, device="cuda", dtype=torch.float32)
    sin = torch.randn_like(cos)

    for _ in range(10):
        _ = apply_rotary_pos_emb(x, (cos, sin))
        _ = apply_rotary_pos_emb_triton(x, (cos, sin))
    torch.cuda.synchronize()

    start_ref = torch.cuda.Event(enable_timing=True)
    end_ref = torch.cuda.Event(enable_timing=True)
    start_ref.record()
    ref = apply_rotary_pos_emb(x, (cos, sin))
    end_ref.record()
    torch.cuda.synchronize()
    print("ref time:", start_ref.elapsed_time(end_ref), "ms")

    start_tri = torch.cuda.Event(enable_timing=True)
    end_tri = torch.cuda.Event(enable_timing=True)
    start_tri.record()
    tri = apply_rotary_pos_emb_triton(x, (cos, sin))
    end_tri.record()
    torch.cuda.synchronize()
    print("triton time:", start_tri.elapsed_time(end_tri), "ms")

    diff = (ref - tri).abs().max()
    print("max|diff| =", diff.item())

ref time: 3.528991937637329 ms
triton time: 0.9875839948654175 ms
max|diff| = 9.5367431640625e-07


In [7]:
if __name__ == "__main__":
    torch.manual_seed(42)
    batch_size = 32
    seq_len = 1024
    num_heads = 8
    head_dim = 64
    rope_dim = 48

    x = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32)
    cos = torch.randn(rope_dim, device="cuda", dtype=torch.float32)
    sin = torch.randn_like(cos)

    for _ in range(10):
        _ = apply_rotary_pos_emb(x, (cos, sin))
        _ = apply_rotary_pos_emb_triton(x, (cos, sin))
    torch.cuda.synchronize()

    start_ref = torch.cuda.Event(enable_timing=True)
    end_ref = torch.cuda.Event(enable_timing=True)
    start_ref.record()
    ref = apply_rotary_pos_emb(x, (cos, sin))
    end_ref.record()
    torch.cuda.synchronize()
    print("ref time:", start_ref.elapsed_time(end_ref), "ms")

    start_tri = torch.cuda.Event(enable_timing=True)
    end_tri = torch.cuda.Event(enable_timing=True)
    start_tri.record()
    tri = apply_rotary_pos_emb_triton(x, (cos, sin))
    end_tri.record()
    torch.cuda.synchronize()
    print("triton time:", start_tri.elapsed_time(end_tri), "ms")

    diff = (ref - tri).abs().max()
    print("max|diff| =", diff.item())

ref time: 3.5327999591827393 ms
triton time: 0.9911999702453613 ms
max|diff| = 9.5367431640625e-07
