In [2]:
import torch
import triton
import triton.language as tl

In [3]:
DEVICE = 'cuda:0'
torch_device = torch.device(DEVICE)

## Kernel to load every other element from a tensor

In [4]:
@triton.jit
def stride_copy_kernel(in_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Gather the odd-indexed elements of a flat input into a contiguous output.

    Computes ``out[i] = in[2 * i + 1]`` for every ``i`` with ``2 * i + 1 < N``.
    Each program instance handles ``BLOCK_SIZE`` consecutive output elements.

    Args:
        in_ptr: Pointer to the first input element.
        out_ptr: Pointer to the first output element; ``N // 2`` elements are written.
        N: Number of readable input elements.
        BLOCK_SIZE: Output elements handled per program (compile-time constant).
    """
    pid = tl.program_id(0)
    out_offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Index of the odd element belonging to each output slot.
    in_offsets = 2 * out_offsets + 1

    # Guard on the index that is actually dereferenced. The previous
    # `(2 * offsets) < N` check let `2 * offsets + 1 == N` through when N is odd.
    # `2 * i + 1 < N` is equivalent to `i < N // 2`, so the same mask bounds the store.
    mask = in_offsets < N

    odd_values = tl.load(in_ptr + in_offsets, mask=mask)
    tl.store(out_ptr + out_offsets, odd_values, mask=mask)

In [6]:
# Wrapper function to launch the kernel
def stride_copy_wrapper(input_tensor, output_tensor, block_size=64):
    """Launch ``stride_copy_kernel`` to copy every other element of ``input_tensor``.

    Args:
        input_tensor: Source tensor. Only ``input_tensor.shape[-1]`` elements,
            starting at the tensor's data pointer, are read.
        output_tensor: Destination tensor; receives ``input_tensor.shape[-1] // 2``
            elements starting at its data pointer.
        block_size: Output elements processed per Triton program. It feeds
            ``tl.arange`` in the kernel, which needs a power of two.

    Raises:
        AssertionError: If the strided dimension is odd, or if ``output_tensor``
            is too small to hold the result.
    """
    # Assume last dimension is the one to stride over
    n_elements = input_tensor.shape[-1]
    assert n_elements % 2 == 0, "Input tensor must have an even number of elements"

    n_outputs = n_elements // 2
    # The kernel writes through a raw pointer, so an undersized output would be
    # an out-of-bounds write rather than a Python error.
    assert output_tensor.numel() >= n_outputs, (
        "Output tensor must hold at least half as many elements as the input"
    )
    if n_outputs == 0:
        # Nothing to copy; avoid launching an empty grid.
        return

    grid = (triton.cdiv(n_outputs, block_size),)
    stride_copy_kernel[grid](
        input_tensor,
        output_tensor,
        n_elements,
        BLOCK_SIZE=block_size,
    )

In [7]:
# Create input tensor
input_tensor = torch.arange(300, dtype=torch.float32, device='cuda')
# Create output tensor to hold the result: one value per input pair.
output_tensor = torch.empty(input_tensor.numel() // 2, dtype=torch.float32, device='cuda')

# Call the wrapper function
stride_copy_wrapper(input_tensor, output_tensor)

# Verify the result: the kernel keeps the odd-indexed inputs.
assert torch.equal(output_tensor, input_tensor[1::2])
print(output_tensor)  # Should contain [1, 3, 5, ..., 299]

tensor([  1.,   3.,   5.,   7.,   9.,  11.,  13.,  15.,  17.,  19.,  21.,  23.,
         25.,  27.,  29.,  31.,  33.,  35.,  37.,  39.,  41.,  43.,  45.,  47.,
         49.,  51.,  53.,  55.,  57.,  59.,  61.,  63.,  65.,  67.,  69.,  71.,
         73.,  75.,  77.,  79.,  81.,  83.,  85.,  87.,  89.,  91.,  93.,  95.,
         97.,  99., 101., 103., 105., 107., 109., 111., 113., 115., 117., 119.,
        121., 123., 125., 127., 129., 131., 133., 135., 137., 139., 141., 143.,
        145., 147., 149., 151., 153., 155., 157., 159., 161., 163., 165., 167.,
        169., 171., 173., 175., 177., 179., 181., 183., 185., 187., 189., 191.,
        193., 195., 197., 199., 201., 203., 205., 207., 209., 211., 213., 215.,
        217., 219., 221., 223., 225., 227., 229., 231., 233., 235., 237., 239.,
        241., 243., 245., 247., 249., 251., 253., 255., 257., 259., 261., 263.,
        265., 267., 269., 271., 273., 275., 277., 279., 281., 283., 285., 287.,
        289., 291., 293., 295., 297., 29

## RoPE Forward Pass Implementation

A work-in-progress implementation of the RoPE forward pass in Triton.
The goal is to understand which micro-kernels are required to make it efficient.

In [8]:
# Seed the global RNG so the random Q/K tensors created below are reproducible.
torch.manual_seed(0)

# The d-head dimension: length of each vector that RoPE rotates.
# Must be even, because elements are rotated as (real, imag) pairs.
n_elements = 256

In [9]:
# Write the initial implementation for the RoPE kernel in Triton
import triton
import triton.language as tl
import torch

@triton.jit
def rope_kernel(
    q_ptr,
    k_ptr,
    cos_ptr,
    sin_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr  # Block size for parallel processing
):
    """Rotate one interleaved query vector by the RoPE angles.

    The vector is read as ``n_elements // 2`` consecutive (real, imag) pairs.
    Pair ``i`` is rotated by the angle whose cosine and sine are ``cos_ptr[i]``
    and ``sin_ptr[i]``, and written back to the same interleaved positions of
    ``out_ptr``. Each program instance handles ``BLOCK_SIZE`` pairs.

    Args:
        q_ptr: Pointer to the ``n_elements`` query values.
        k_ptr: Pointer to the key values. Not read yet: only the rotated query
            is written for now (see the TODO at the end of the kernel).
        cos_ptr: Pointer to ``n_elements // 2`` cosine values.
        sin_ptr: Pointer to ``n_elements // 2`` sine values.
        out_ptr: Pointer to the ``n_elements`` output values.
        n_elements: Length of the vector being rotated.
        BLOCK_SIZE: Number of pairs handled per program (compile-time constant).
    """
    # Define the program's index in the grid
    pid = tl.program_id(0)

    # Each lane owns one (real, imag) pair.
    pair_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # One mask guards every access. There are n_elements // 2 pairs, and lane i
    # touches elements 2*i and 2*i + 1. The previous store mask
    # (`offsets < n_elements`) allowed writes up to index 2*n_elements - 1,
    # i.e. past the end of the output buffer.
    pair_mask = pair_idx < (n_elements // 2)

    real_offsets = 2 * pair_idx
    imag_offsets = real_offsets + 1

    # `other=0.0` gives masked-off lanes a defined value.
    q_real = tl.load(q_ptr + real_offsets, mask=pair_mask, other=0.0)
    q_imag = tl.load(q_ptr + imag_offsets, mask=pair_mask, other=0.0)

    cos = tl.load(cos_ptr + pair_idx, mask=pair_mask, other=0.0)
    sin = tl.load(sin_ptr + pair_idx, mask=pair_mask, other=0.0)

    # Complex multiplication (q_real + i*q_imag) * (cos + i*sin).
    q_rotated_real = q_real * cos - q_imag * sin
    q_rotated_imag = q_real * sin + q_imag * cos

    # Store rotated vectors back to memory
    tl.store(out_ptr + real_offsets, q_rotated_real, mask=pair_mask)
    tl.store(out_ptr + imag_offsets, q_rotated_imag, mask=pair_mask)

    # TODO: rotate K as well. That needs a destination for the rotated keys
    # (a second output pointer, or an output with room for 2 * n_elements values,
    # K stored at `out_ptr + n_elements + ...`), using the same pair_mask.


In [10]:
import torch.nn as nn
class RoPEEmbeddings(nn.Module):
    """Reference rotary positional embedding (RoPE) on interleaved pairs.

    The last dimension of the input is read as ``dim // 2`` consecutive
    (real, imag) pairs; pair ``i`` at position ``p`` is rotated by the angle
    ``p * base ** (-2 * i / dim)``.

    Args:
        dim: Size of the rotated (last) dimension. Must be even.
        max_seq_len: Number of positions for which cos/sin are precomputed.
        base: Base of the geometric progression of rotation frequencies.

    Buffers (all non-persistent):
        inv_freq: ``(dim // 2,)`` inverse frequencies.
        cos, sin: ``(max_seq_len, dim // 2)`` cached cosines / sines.
    """

    def __init__(self, dim, max_seq_len=4096, base=10000):
        super().__init__()
        assert dim % 2 == 0, "dim must be even for RoPE."
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        self.register_buffer("inv_freq",
            1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim)), persistent=False)

        self.build_rope_cache()

    # Add the function to rotate half
    def rotate_half(self, x):
        """Split the last dim into halves ``(x1, x2)`` and return ``(-x2, x1)``.

        Not used by ``forward`` (which works on interleaved pairs).
        """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def build_rope_cache(self):
        """Precompute the ``cos`` / ``sin`` buffers for every cached position."""
        # Build positions on inv_freq's device so the cache can also be rebuilt
        # after the module has been moved off the CPU.
        pos_array = torch.arange(
            self.max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        theta = pos_array.unsqueeze(-1) * self.inv_freq  # (max_seq_len, dim // 2)

        # No interleaving or duplicating of cos/sin (neither cat([cos, cos]) nor
        # repeat_interleave): one angle per pair is enough, because forward()
        # and the Triton kernel both index angles per pair.
        self.register_buffer("cos", torch.cos(theta), persistent=False)
        self.register_buffer("sin", torch.sin(theta), persistent=False)

    def forward(self, x, positions=None):
        """Rotate the interleaved pairs in the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``dim``.
            positions: Optional integer index tensor selecting rows of the
                cos/sin cache, on the same device as the module's buffers.
                ``cos[positions]`` must broadcast against ``x[..., ::2]``,
                e.g. shape ``(seq,)`` for ``x`` of shape ``(..., seq, dim)``.
                When ``None`` the whole ``(max_seq_len, dim // 2)`` cache is
                broadcast against ``x[..., ::2]``.

        Returns:
            Tensor with the broadcast shape of the inputs and the same
            interleaved layout as ``x``.
        """
        if positions is None:
            cos, sin = self.cos, self.sin
        else:
            cos, sin = self.cos[positions], self.sin[positions]

        x_real, x_imag = x[..., ::2], x[..., 1::2]
        out_real = x_real * cos - x_imag * sin
        out_imag = x_real * sin + x_imag * cos

        # stack on a new trailing axis, then flatten -> [r0, i0, r1, i1, ...]
        return torch.stack((out_real, out_imag), dim=-1).flatten(-2)


# Reference module used to check the Triton kernel, moved to DEVICE.
# NOTE(review): with max_seq_len=1 the cache holds only position 0, where
# theta = 0 * inv_freq = 0, so cos = 1 and sin = 0 and the rotation is the
# identity. A comparison against this module cannot catch a wrong rotation.
rope_ref = RoPEEmbeddings(n_elements, max_seq_len=1).to(DEVICE)

In [11]:
# Target device for every tensor in this section.
# (There should be a better way to capture this than re-declaring it here.)
DEVICE = 'cuda:0'
torch_device = torch.device(DEVICE)

# Despite the `_ptr` suffix these are tensors of shape (1, n_elements);
# Triton receives them as pointers when they are passed to a kernel.
q_ptr = torch.randn((1, n_elements), device=DEVICE)
k_ptr = torch.randn((1, n_elements), device=DEVICE)
out_ptr = torch.zeros((1, n_elements), device=DEVICE)

# Everything handed to the kernel must live on the same device.
assert all(t.device == torch_device for t in (q_ptr, k_ptr, out_ptr))

In [54]:
# Use a 1D grid. Each program rotates BLOCK_SIZE (real, imag) pairs, so the grid
# only has to cover the n_elements // 2 pairs. Covering n_elements launched
# twice as many programs as there is work for.
grid = lambda meta: (triton.cdiv(n_elements // 2, meta['BLOCK_SIZE']), )
rope_kernel[grid](
    q_ptr,
    k_ptr,
    rope_ref.cos,
    rope_ref.sin,
    out_ptr,
    n_elements,
    BLOCK_SIZE=32
)

<triton.compiler.compiler.CompiledKernel at 0x7df74ba00150>

In [92]:
import torch
from torchtune.modules import RotaryPositionalEmbeddings

# Initialize the RoPE module
# Built with max_seq_len=1, matching the single-position setup of `rope_ref`.
rope = RotaryPositionalEmbeddings(dim=n_elements, max_seq_len=1).to(q_ptr.device)

# Apply RoPE to the input tensor
# The (1, n_elements) query is viewed as a 4-D tensor (1, 1, 1, n_elements);
# presumably torchtune's layout is (batch, seq, heads, head_dim) — TODO confirm.
x_transformed = rope(q_ptr.view(1, 1, 1, n_elements))


In [96]:
# Check the `torchtune` result against (1) our PyTorch reference module and
# (2) the buffer written by the Triton kernel; one True/False line per check.
# `x_transformed` is (1, 1, 1, n_elements), so allclose broadcasts it against
# the (1, n_elements) candidates.
for candidate in (rope_ref(q_ptr), out_ptr):
    print(torch.allclose(x_transformed, candidate))

True
True


## Flash Attention

Following the tutorial from https://www.youtube.com/watch?v=zy8ChVd_oTM&t=7049s

In [None]:
@triton.jit
def _attn_fwd(
    Q,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    K,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    V,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    O,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    M,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN
    causal,
    softmax_scale,
    stride_Q_batch,
    stride_Q_head,
    stride_Q_seq,
    stride_Q_dim,
    stride_K_batch,
    stride_K_head,
    stride_K_seq,
    stride_K_dim,
    stride_V_batch,
    stride_V_head,
    stride_V_seq,
    stride_V_dim,
    stride_O_batch,
    stride_O_head,
    stride_O_seq,
    stride_O_dim,
    BATCH_SIZE,
    NUM_HEADS: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
):
    """Flash-attention forward pass (work in progress).

    Launch layout: axis 0 of the grid selects a block of ``BLOCK_SIZE_Q``
    queries, axis 1 selects a flattened (batch, head) pair.

    So far the kernel only sets up the block pointers and the running softmax
    statistics; the loop over K/V blocks and the writes to ``O`` / ``M`` are
    still TODO, so launching it produces no output.
    """
    tl.static_assert(BLOCK_SIZE_KV <= HEAD_DIM)

    block_index_q = tl.program_id(0)
    index_batch_head = tl.program_id(1)
    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS

    # Index into Q to figure out where does the head start
    # NOTE(review): this offset is computed from Q's batch/head strides and then
    # reused for K, V and O, which assumes all four tensors share those strides.
    qkv_offset = (
        index_batch.to(tl.int64) * stride_Q_batch
        + index_head.to(tl.int64) * stride_Q_head
    )

    q_block_ptr = tl.make_block_ptr(  # Q[index_batch, index_head, block_index_q * BLOCK_SIZE_Q, :]
        base=Q + qkv_offset,  # a 2D tensor
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_Q_seq, stride_Q_dim),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        # In the tensor view, the start offsets of the queries this block will work on.
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        # `order` describes the memory layout of the parent tensor; (1, 0)
        # presumably lists dim 1 (HEAD_DIM) as the fastest-varying one, i.e.
        # row-major — TODO confirm against the Triton docs.
        order=(1, 0)
    )

    # NOTE: Not skipping KVs into a block of KVs. TODO later.
    v_block_ptr = tl.make_block_ptr(  # V[index_batch, index_head, :, :]
        # Fixed: this pointer was based on Q, so "V" blocks would have been
        # read from the query tensor.
        base=V + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_V_seq, stride_V_dim),
        block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
        offsets=(0, 0),
        # Same layout order as Q.
        order=(1, 0)
    )

    # K should be indexed in the transposed manner.
    k_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(HEAD_DIM, SEQ_LEN),
        strides=(stride_K_dim, stride_K_seq),
        block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
        # offsets are (0, 0) because we are not skipping anything. We are at the beginning of the cache block
        offsets=(0, 0),
        # Shape and strides are swapped to view K as K^T, so the order is
        # swapped as well (presumably dim 0, HEAD_DIM, is now the fastest one).
        order=(0, 1)
    )

    # How many outputs do we generate?
    O_block_ptr = tl.make_block_ptr(  # O[index_batch, index_head, block_index_q * BLOCK_SIZE_Q, :]
        base=O + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_O_seq, stride_O_dim),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        order=(1, 0)
    )

    # offs_q: the offsets for the tokens in the Q sequence this program processes
    offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)

    # offs_kv: the offsets for the token in the K and V sequence to process
    offs_kv = tl.arange(0, BLOCK_SIZE_KV)

    # m_i: the running maximum of the softmax block.
    m_i = tl.zeros((BLOCK_SIZE_Q,), dtype=tl.float32) - float('inf')

    # l_i: the running sum of the softmax block. We have one for each query
    l_i = tl.zeros((BLOCK_SIZE_Q,), dtype=tl.float32) + 1.0  # added in the algorithm

    # What are stages? Just following the tutorial for now.
    if STAGE == 3:
        # This step runs for the blocks to the right of the diagonal in causal attention
        # TODO: not implemented yet. `pass` keeps the cell syntactically valid;
        # previously the `if` had no body.
        pass


In [5]:
# If we want to define a function that we can backpropagate through, the class needs to be
# derived from torch.autograd.Function.

class TritonAttention(torch.autograd.Function):
    """Autograd wrapper around the Triton flash-attention forward kernel.

    Only ``forward`` exists so far; no ``backward`` is defined yet.
    """

    @staticmethod
    def forward(ctx, Q, K, V, causal, softmax_scale):
        """Launch ``_attn_fwd`` and return the attention output.

        Args:
            ctx: Autograd context; Q, K, V and M are saved on it for backward.
            Q, K, V: Tensors of shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).
            causal: Whether to use causal attention (selects the kernel STAGE).
            softmax_scale: Scale forwarded to the kernel.

        Returns:
            Tensor ``O`` with the same shape and dtype as ``Q``.

        Raises:
            AssertionError: If Q, K and V do not share the same head dimension.
        """
        # ctx allows us to save intermediate results for backward pass.

        HEAD_DIM_Q, HEAD_DIM_K, HEAD_DIM_V = Q.shape[-1], K.shape[-1], V.shape[-1]
        # The kernel receives a single HEAD_DIM for all three tensors.
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V, (
            "Q, K and V must share the same head dimension"
        )

        BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape

        O = torch.empty_like(Q)

        # What is stage?
        stage = 3 if causal else 1

        # Parallelize over the batch dimension and the number of heads
        grid = lambda args: (
            # ceil(SEQ_LEN / BLOCK_SIZE_Q) = How many blocks of Q we have.
            triton.cdiv(SEQ_LEN, args['BLOCK_SIZE_Q']),
            BATCH_SIZE * NUM_HEADS,
            1, # z in the CUDA launch grid
        )

        # M is the logsumexp for the backward pass, one for each query.
        M = torch.empty((BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32)

        # Keyword names must match the kernel's parameter names exactly: the
        # kernel declares `stride_*_dim`, not `stride_*_d`.
        # NOTE(review): BLOCK_SIZE_Q / BLOCK_SIZE_KV are not passed here; they
        # have to come from a `triton.autotune` config on the kernel or be added
        # to this call — TODO.
        _attn_fwd[grid](
            Q=Q,
            K=K,
            V=V,
            O=O,
            M=M,
            causal=causal,
            softmax_scale=softmax_scale,
            stride_Q_batch=Q.stride(0),
            stride_Q_head=Q.stride(1),
            stride_Q_seq=Q.stride(2),
            stride_Q_dim=Q.stride(3),
            stride_K_batch=K.stride(0),
            stride_K_head=K.stride(1),
            stride_K_seq=K.stride(2),
            stride_K_dim=K.stride(3),
            stride_V_batch=V.stride(0),
            stride_V_head=V.stride(1),
            stride_V_seq=V.stride(2),
            stride_V_dim=V.stride(3),
            stride_O_batch=O.stride(0),
            stride_O_head=O.stride(1),
            stride_O_seq=O.stride(2),
            stride_O_dim=O.stride(3),
            BATCH_SIZE=Q.shape[0],
            NUM_HEADS=Q.shape[1],
            SEQ_LEN=Q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            STAGE=stage
        )

        # save the intermediate results for backward pass.
        ctx.save_for_backward(Q, K, V, M)
        return O

In [94]:
import math

def test_op(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, causal, dtype=torch.float16, DEVICE='cuda'):
    Q = (torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), device=DEVICE, dtype=dtype
        )
        .normal_(mean=0.0, std=0.5)
    )
    K = (torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), device=DEVICE, dtype=dtype
        )
        .normal_(mean=0.0, std=0.5)
    )
    V = (torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), device=DEVICE, dtype=dtype
        )
        .normal_(mean=0.0, std=0.5)
    )

    softmax_scale = 1.0 / math.sqrt(HEAD_DIM)
    # Backpropagate through the output.
    dO = torch.randn_like(Q)

    mask = torch.tril(torch.ones(SEQ_LEN, SEQ_LEN), device=DEVICE)
    P = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale
    
    if causal:
        P[..., mask == 0] = float('-inf')
    P = torch.softmax(P, dim=-1).half()
    ref_O = torch.matmul(P, V)
    ref_O.backward(dO)

    ref_dV, V.grad = V.grad.clone(), None
    ref_dP, P.grad = P.grad.clone(), None
    ref_dK, K.grad = K.grad.clone(), None
    ref_dQ, Q.grad = Q.grad.clone(), None

    # Compare with Triton implementation
    
    

True