In [8]:
import torch
import triton
import triton.language as tl

In [None]:
# Name of the CUDA device to target (device index 0).
DEVICE = 'cuda:0'
# torch.device object built from DEVICE, handy for device equality checks.
torch_device = torch.device(DEVICE)

## Kernel to load every other element from a tensor

In [9]:
@triton.jit
def stride_copy_kernel(in_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Copy elements 0, 2, 4, ... of in_ptr into consecutive slots of out_ptr.

    Each program instance handles BLOCK_SIZE consecutive destination slots.

    Args:
        in_ptr: pointer to the source buffer, read with a stride of 2.
        out_ptr: pointer to the destination buffer, written contiguously.
        N: number of elements in the source buffer.
        BLOCK_SIZE: compile-time number of destination slots per program.
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    dst_idx = block_start + tl.arange(0, BLOCK_SIZE)
    src_idx = dst_idx * 2

    # Reads are guarded by the source length, writes by the destination
    # length (half of the source).
    can_read = src_idx < N
    can_write = dst_idx < (N // 2)

    values = tl.load(in_ptr + src_idx, mask=can_read)
    tl.store(out_ptr + dst_idx, values, mask=can_write)

In [5]:
# Wrapper function to launch the kernel
def stride_copy_wrapper(input_tensor, output_tensor):
    """Launch stride_copy_kernel to copy every other element of input_tensor.

    Elements 0, 2, 4, ... of the input are written to positions 0, 1, 2, ...
    of output_tensor. The result is written in place; nothing is returned.

    Args:
        input_tensor: source tensor whose last dimension has an even length.
            Only ``input_tensor.shape[-1]`` elements are processed and the
            kernel addresses them as a flat buffer, so this assumes a
            contiguous, effectively 1-D tensor -- TODO confirm before using
            it with multi-dimensional inputs.
        output_tensor: destination tensor with room for at least
            ``input_tensor.shape[-1] // 2`` elements.

    Raises:
        AssertionError: if the last dimension of input_tensor is odd.
        ValueError: if output_tensor is too small to hold the result.
    """
    # Assume last dimension is the one to stride over
    n_elements = input_tensor.shape[-1]
    assert n_elements % 2 == 0, "Input tensor must have an even number of elements"

    n_out = n_elements // 2
    # The kernel masks its stores by n_out only; an undersized destination
    # would be written past its end on the device, so reject it up front.
    if output_tensor.numel() < n_out:
        raise ValueError(
            f"output_tensor has {output_tensor.numel()} elements, "
            f"but {n_out} are required"
        )

    BLOCK_SIZE = 64
    # One program per BLOCK_SIZE destination elements.
    grid = (triton.cdiv(n_out, BLOCK_SIZE),)

    stride_copy_kernel[grid](
        input_tensor,
        output_tensor,
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
    )

In [7]:
# Create input tensor
input_tensor = torch.arange(300, dtype=torch.float32, device='cuda')
# Create output tensor to hold the result: one slot per even-indexed input element
output_tensor = torch.empty(input_tensor.numel() // 2, dtype=torch.float32, device='cuda')

# Call the wrapper function
stride_copy_wrapper(input_tensor, output_tensor)

# Verify the result against PyTorch's own strided slice
assert torch.equal(output_tensor, input_tensor[::2])
print(output_tensor)  # Should contain [0, 2, 4, ..., 298]

tensor([  0.,   2.,   4.,   6.,   8.,  10.,  12.,  14.,  16.,  18.,  20.,  22.,
         24.,  26.,  28.,  30.,  32.,  34.,  36.,  38.,  40.,  42.,  44.,  46.,
         48.,  50.,  52.,  54.,  56.,  58.,  60.,  62.,  64.,  66.,  68.,  70.,
         72.,  74.,  76.,  78.,  80.,  82.,  84.,  86.,  88.,  90.,  92.,  94.,
         96.,  98., 100., 102., 104., 106., 108., 110., 112., 114., 116., 118.,
        120., 122., 124., 126., 128., 130., 132., 134., 136., 138., 140., 142.,
        144., 146., 148., 150., 152., 154., 156., 158., 160., 162., 164., 166.,
        168., 170., 172., 174., 176., 178., 180., 182., 184., 186., 188., 190.,
        192., 194., 196., 198., 200., 202., 204., 206., 208., 210., 212., 214.,
        216., 218., 220., 222., 224., 226., 228., 230., 232., 234., 236., 238.,
        240., 242., 244., 246., 248., 250., 252., 254., 256., 258., 260., 262.,
        264., 266., 268., 270., 272., 274., 276., 278., 280., 282., 284., 286.,
        288., 290., 292., 294., 296., 29

## RoPE Forward Pass Implementation

A work-in-progress implementation of the RoPE forward pass in Triton.
The goal is to understand the micro-kernels required to make this efficient.

In [1]:
# Write the initial implementation for the RoPE kernel in Triton
import triton
import triton.language as tl
import torch

@triton.jit
def rope_kernel(
    q_ptr,
    k_ptr,
    cos_ptr,
    sin_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr  # Block size for parallel processing
):
    """Apply a rotary position embedding (RoPE) rotation to q and k.

    q and k are treated as flat buffers of interleaved (real, imag) pairs:
    element 2*i is the real part and element 2*i + 1 the imaginary part of
    pair i. Each pair is rotated as

        real' = real * cos - imag * sin
        imag' = real * sin + imag * cos

    Triton tensors cannot be sliced with a step (``x[::2]``), so the real and
    imaginary parts are fetched with stride-2 pointer arithmetic instead.

    Args:
        q_ptr: pointer to the query buffer (n_elements values).
        k_ptr: pointer to the key buffer (n_elements values).
        cos_ptr: pointer to cosine values, read at the pair index, i.e. one
            value per (real, imag) pair -- TODO confirm this matches the
            layout of the table the caller passes in.
        sin_ptr: pointer to sine values, same layout as cos_ptr.
        out_ptr: pointer to the output buffer. Rotated q is written to
            [0, n_elements) and rotated k to [n_elements, 2 * n_elements),
            both interleaved like the inputs, so the buffer must hold at
            least 2 * n_elements values.
        n_elements: number of values in q (and in k). A trailing unpaired
            element (odd n_elements) is ignored.
        BLOCK_SIZE: compile-time number of q/k elements handled by one
            program; must be even (and a power of two >= 2 so that
            BLOCK_SIZE // 2 is a valid tl.arange extent).
    """
    tl.static_assert(BLOCK_SIZE % 2 == 0, "BLOCK_SIZE must be even")
    # One program covers BLOCK_SIZE elements, i.e. BLOCK_SIZE // 2 pairs.
    PAIRS_PER_BLOCK: tl.constexpr = BLOCK_SIZE // 2

    # Define the program's index in the grid
    pid = tl.program_id(0)

    # Pair indices handled by this program
    pair_idx = pid * PAIRS_PER_BLOCK + tl.arange(0, PAIRS_PER_BLOCK)

    # Check that both elements of every pair are within bounds
    pair_mask = pair_idx < n_elements // 2

    # Element offsets of the real (even) and imaginary (odd) components
    real_off = 2 * pair_idx
    imag_off = real_off + 1

    # Load query, key, cos and sin values from memory with masking
    q_real = tl.load(q_ptr + real_off, mask=pair_mask, other=0.0)
    q_imag = tl.load(q_ptr + imag_off, mask=pair_mask, other=0.0)
    k_real = tl.load(k_ptr + real_off, mask=pair_mask, other=0.0)
    k_imag = tl.load(k_ptr + imag_off, mask=pair_mask, other=0.0)
    cos = tl.load(cos_ptr + pair_idx, mask=pair_mask, other=0.0)
    sin = tl.load(sin_ptr + pair_idx, mask=pair_mask, other=0.0)

    # For query: q_rotated = [q_real * cos - q_imag * sin,
    #                         q_real * sin + q_imag * cos]
    q_rotated_real = q_real * cos - q_imag * sin
    q_rotated_imag = q_real * sin + q_imag * cos

    # For key: k_rotated = [k_real * cos - k_imag * sin,
    #                       k_real * sin + k_imag * cos]
    k_rotated_real = k_real * cos - k_imag * sin
    k_rotated_imag = k_real * sin + k_imag * cos

    # Store rotated vectors back to memory: q first, then k right after it
    tl.store(out_ptr + real_off, q_rotated_real, mask=pair_mask)
    tl.store(out_ptr + imag_off, q_rotated_imag, mask=pair_mask)
    tl.store(out_ptr + n_elements + real_off, k_rotated_real, mask=pair_mask)
    tl.store(out_ptr + n_elements + imag_off, k_rotated_imag, mask=pair_mask)


In [None]:
# Seed the RNG so the random q/k inputs are reproducible across runs.
torch.manual_seed(0)

n_elements = 1024
DEVICE = 'cuda:0'  # TODO: find a better way to capture the active device

# q is drawn before k; the draw order determines which random values each gets.
q_ptr = torch.randn(1, n_elements, device=DEVICE)
k_ptr = torch.randn(1, n_elements, device=DEVICE)

In [None]:
# rope_kernel stores rotated values for both q and k into out_ptr, so a
# buffer the size of q alone is too small; allocate room for both inputs.
out_ptr = torch.zeros((2, n_elements), device=DEVICE)
# TODO: derive the device from Triton instead of hard-coding it:
# DEVICE = triton.runtime.driver.active.get_active_torch_device()

# All tensors handed to the kernel must live on the same device.
torch_device = torch.device(DEVICE)
assert q_ptr.device == torch_device and k_ptr.device == torch_device and \
    out_ptr.device == torch_device

# NOTE(review): rope_ref is not created in this cell; it must already exist
# in the kernel session and expose `cos` / `sin` -- confirm where it is defined.

# Use a 1D grid: one program per BLOCK_SIZE elements of q/k.
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
rope_kernel[grid](
    q_ptr,
    k_ptr,
    rope_ref.cos,
    rope_ref.sin,
    out_ptr,
    n_elements,
    BLOCK_SIZE=32
)