In [1]:
pip install triton==3.3.1



In [2]:
pip install torch==2.7.1



In [3]:
import torch
import triton
import triton.language as tl

from typing import Tuple

In [4]:
print(torch.__version__)
print(triton.__version__)

2.7.1+cu126
3.3.1


In [5]:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_D': 16}, num_warps=1),
        triton.Config({'BLOCK_D': 16}, num_warps=2),

        triton.Config({'BLOCK_D': 32}, num_warps=2),
        triton.Config({'BLOCK_D': 32}, num_warps=4),

        triton.Config({'BLOCK_D': 64}, num_warps=1),
        triton.Config({'BLOCK_D': 64}, num_warps=2),
        triton.Config({'BLOCK_D': 64}, num_warps=4),

        triton.Config({'BLOCK_D': 128}, num_warps=2),
        triton.Config({'BLOCK_D': 128}, num_warps=4),
        triton.Config({'BLOCK_D': 128}, num_warps=8),

        triton.Config({'BLOCK_D': 256}, num_warps=4),
        triton.Config({'BLOCK_D': 256}, num_warps=8),

        triton.Config({'BLOCK_D': 512}, num_warps=4),
        triton.Config({'BLOCK_D': 512}, num_warps=8),

        triton.Config({'BLOCK_D': 1024}, num_warps=8),
        triton.Config({'BLOCK_D': 1024}, num_warps=16),
    ],
    key=['HEAD_DIM', 'ROPE_DIM', 'NUM_ELEMENTS'],
)

@triton.jit
def _rope_fused_kernel(
    x_ptr, cos_ptr, sin_ptr, out_ptr,
    HEAD_DIM: tl.constexpr,
    ROPE_DIM: tl.constexpr,
    NUM_ELEMENTS: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    offs_d = tl.arange(0, BLOCK_D)
    offs = tl.program_id(axis=0) * BLOCK_D + offs_d

    is_in_bounds = offs < NUM_ELEMENTS
    x = tl.load(x_ptr + offs, mask=is_in_bounds, other=0.0)

    ROPE_OFFSET: tl.constexpr = HEAD_DIM - ROPE_DIM
    offs_rope = (offs_d % HEAD_DIM) - ROPE_OFFSET

    HALF_ROPE_DIM: tl.constexpr = ROPE_DIM // 2
    is_first_half = offs_rope < HALF_ROPE_DIM
    rope_partner = tl.gather(x, offs_d + tl.where(is_first_half, HALF_ROPE_DIM, -HALF_ROPE_DIM), axis=0) # can also just use tl.load but tl.gather is more suited for this

    use_rope = offs_rope >= 0
    cos = tl.load(cos_ptr + offs_rope, mask=use_rope, other=0.0)
    sin = tl.load(sin_ptr + offs_rope, mask=use_rope, other=0.0)

    out = tl.where(use_rope, x * cos + tl.where(is_first_half, -1.0,  1.0) * rope_partner * sin, x)
    tl.store(out_ptr + offs, out, mask=is_in_bounds)

In [None]:
# Kernel generated by torch.compile

# import triton
# import triton.language as tl
# from triton.compiler.compiler import AttrsDescriptor

# from torch._inductor.runtime import triton_helpers, triton_heuristics
# from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
# from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
# triton_helpers.set_driver_to_gpu()

# @triton_heuristics.pointwise(
#     size_hints={'x': 16777216},
#     filename=__file__,
#     triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=40, cc=75, major=7, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1024, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
#     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_cat_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': '9182018CCD6A4F758231D68D0B1E1E23CEBB32E5D78CB36B65791C4EB96774A2', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
#     min_elem_per_thread=0
# )
# @triton.jit
# def triton_poi_fused_cat_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
#     xnumel = 16777216
#     xoffset = tl.program_id(0) * XBLOCK
#     xindex = xoffset + tl.arange(0, XBLOCK)[:]
#     xmask = tl.full([XBLOCK], True, tl.int1)
#     x0 = (xindex % 64)
#     x1 = xindex // 64
#     x2 = xindex
#     tmp0 = x0
#     tmp1 = tl.full([1], 0, tl.int64)
#     tmp2 = tmp0 >= tmp1
#     tmp3 = tl.full([1], 16, tl.int64)
#     tmp4 = tmp0 < tmp3
#     tmp5 = tl.load(in_ptr0 + (64*x1 + (x0)), tmp4, eviction_policy='evict_last', other=0.0)
#     tmp6 = tmp0 >= tmp3
#     tmp7 = tl.full([1], 64, tl.int64)
#     tmp8 = tmp0 < tmp7
#     tmp9 = tl.load(in_ptr0 + (16 + 64*x1 + ((-16) + x0)), tmp6, eviction_policy='evict_last', other=0.0)
#     tmp10 = tl.load(in_ptr1 + ((-16) + x0), tmp6, eviction_policy='evict_last', other=0.0)
#     tmp11 = tmp9 * tmp10
#     tmp12 = (-16) + x0
#     tmp13 = tl.full([1], 0, tl.int64)
#     tmp14 = tmp12 >= tmp13
#     tmp15 = tl.full([1], 24, tl.int64)
#     tmp16 = tmp12 < tmp15
#     tmp17 = tmp16 & tmp6
#     tmp18 = tl.load(in_ptr0 + (40 + 64*x1 + ((-16) + x0)), tmp17, eviction_policy='evict_last', other=0.0)
#     tmp19 = -tmp18
#     tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
#     tmp21 = tl.where(tmp17, tmp19, tmp20)
#     tmp22 = tmp12 >= tmp15
#     tmp23 = tl.full([1], 48, tl.int64)
#     tmp24 = tmp12 < tmp23
#     tmp25 = tmp22 & tmp6
#     tmp26 = tl.load(in_ptr0 + (16 + 64*x1 + ((-24) + ((-16) + x0))), tmp25, eviction_policy='evict_last', other=0.0)
#     tmp27 = tl.where(tmp16, tmp21, tmp26)
#     tmp28 = tl.load(in_ptr2 + ((-16) + x0), tmp6, eviction_policy='evict_last', other=0.0)
#     tmp29 = tmp27 * tmp28
#     tmp30 = tmp11 + tmp29
#     tmp31 = tl.full(tmp30.shape, 0.0, tmp30.dtype)
#     tmp32 = tl.where(tmp6, tmp30, tmp31)
#     tmp33 = tl.where(tmp4, tmp5, tmp32)
#     tl.store(out_ptr0 + (x2), tmp33, None)


In [7]:
def apply_rotary_pos_emb_triton(
    x: torch.Tensor,
    cos_sin: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    if x.device.type != "cuda":
        raise RuntimeError("Triton kernel requires CUDA tensor")
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise TypeError("x must be fp16, bf16, or fp32")

    cos, sin = cos_sin

    head_dim = x.size(-1)
    rope_dim = cos.size(-1)

    if rope_dim % 2:
        raise ValueError("rope_dim must be even")
    if rope_dim > head_dim:
        raise ValueError("rope_dim should be less than or equal to head_dim")

    out = torch.empty_like(x)
    x_numel = x.numel()

    grid = lambda META: (triton.cdiv(x_numel, META['BLOCK_D']),)

    _rope_fused_kernel[grid](
        x, cos, sin, out,
        head_dim,
        rope_dim,
        x_numel,
    )

    return out.view_as(x)

In [8]:
def apply_rotary_pos_emb(
    x: torch.Tensor, cos_sin: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
    cos, sin = cos_sin

    head_dim = x.size(-1)
    rope_dim = cos.size(-1)

    if head_dim == rope_dim:
        x = (x * cos) + (_rotate_half(x) * sin)
    elif rope_dim < head_dim:
        x_nope, x_rope = x.split((head_dim - rope_dim, rope_dim), dim=-1)
        x_rope = (x_rope * cos) + (_rotate_half(x_rope) * sin)
        x = torch.cat([x_nope, x_rope], dim=-1)
    else:
        raise ValueError("rope_dim should be less than head_dim")

    return x

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

In [9]:
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size = 32
    seq_len = 1024
    num_heads = 8
    head_dim = 64
    rope_dim = 48

    x = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32)
    cos = torch.randn(rope_dim, device="cuda", dtype=torch.float32)
    sin = torch.randn_like(cos)

    ref = apply_rotary_pos_emb(x, (cos, sin))
    tri = apply_rotary_pos_emb_triton(x, (cos, sin))

    diff = (ref - tri).abs().max()
    print("max|diff| =", diff.item())

    ref = torch.compile(apply_rotary_pos_emb)(x, (cos, sin))
    tri = torch.compile(apply_rotary_pos_emb_triton)(x, (cos, sin))

    diff = (ref - tri).abs().max()
    print("max|diff| =", diff.item())

max|diff| = 9.5367431640625e-07
max|diff| = 0.0


In [10]:
torch.manual_seed(42)
apply_rotary_pos_emb = torch.compile(apply_rotary_pos_emb)
apply_rotary_pos_emb_triton = torch.compile(apply_rotary_pos_emb_triton)

batch_size = 32
seq_len = 1024
num_heads = 8
head_dim = 64
rope_dim = 64

x = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32)
cos = torch.randn(rope_dim, device="cuda", dtype=torch.float32)
sin = torch.randn_like(cos)

# compile for the first time
for _ in range(1):
    _ = apply_rotary_pos_emb(x, (cos, sin))
    _ = apply_rotary_pos_emb_triton(x, (cos, sin))

total_ref_time = 0.0
total_triton_time = 0.0
num_runs = 100

for i in range(num_runs):
    torch.manual_seed(i)

    x = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32)
    cos = torch.randn(rope_dim, device="cuda", dtype=torch.float32)
    sin = torch.randn_like(cos)

    torch.cuda.synchronize()

    start_ref = torch.cuda.Event(enable_timing=True)
    end_ref = torch.cuda.Event(enable_timing=True)
    start_ref.record()
    ref = apply_rotary_pos_emb(x, (cos, sin))
    end_ref.record()
    torch.cuda.synchronize()
    ref_time = start_ref.elapsed_time(end_ref)
    total_ref_time += ref_time

    start_tri = torch.cuda.Event(enable_timing=True)
    end_tri = torch.cuda.Event(enable_timing=True)
    start_tri.record()
    tri = apply_rotary_pos_emb_triton(x, (cos, sin))
    end_tri.record()
    torch.cuda.synchronize()
    triton_time = start_tri.elapsed_time(end_tri)
    total_triton_time += triton_time

    print(ref_time)
    print(triton_time)
    print()

    diff = (ref - tri).abs().max()

average_ref_time = total_ref_time / num_runs
average_triton_time = total_triton_time / num_runs

print("-" * 30)
print(f"Average ref time over {num_runs} runs: {average_ref_time:.4f} ms")
print(f"Average triton time over {num_runs} runs: {average_triton_time:.4f} ms")

0.7671679854393005
0.98716801404953

0.6861119866371155
0.7167999744415283

0.6747519969940186
0.6750079989433289

0.6701440215110779
0.6723840236663818

0.6635519862174988
0.6817600131034851

0.6696959733963013
0.6689280271530151

0.6635839939117432
0.6801279783248901

0.6614720225334167
0.6581760048866272

0.6609280109405518
0.6921600103378296

0.6594560146331787
0.6758400201797485

0.6635519862174988
0.6625919938087463

0.6914880275726318
0.8195199966430664

0.7024959921836853
0.6758400201797485

0.6696959733963013
0.6649600267410278

0.6718720197677612
0.66975998878479

0.6532160043716431
0.6584640145301819

0.667136013507843
0.6657919883728027

0.6755520105361938
0.7344319820404053

0.7309759855270386
0.6792320013046265

0.6799359917640686
0.6750079989433289

0.6689599752426147
0.6650239825248718

0.6717439889907837
0.6778879761695862

0.6594560146331787
0.6625919938087463

0.6956160068511963
0.7403839826583862

0.6778879761695862
0.713919997215271

0.6533120274543762
0.6812480092