In [None]:
from tests.test_attention import _attention_and_lse, _make_attn_inputs
from cs336_systems.flashattention_autograd_function_pytorch import FlashAttentionAutogradFunctionPytorch
# Short alias for the PyTorch autograd Function's `apply` entry point; the next
# cell calls it as `impl(q, k, v, is_causal)`.
impl = FlashAttentionAutogradFunctionPytorch.apply


In [None]:
# Run the PyTorch implementation once on inputs built by the test helper.
device="cuda"
is_causal = False
# The helper returns four values; the last one (`_do`) is not used in this cell.
q, k, v, _do = _make_attn_inputs(device)
o = impl(q, k, v, is_causal)


In [None]:
import torch
import triton
import triton.language as tl

@triton.jit
def flash_fwd_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr, L_ptr,
    stride_qb, stride_qq, stride_qd,
    stride_kb, stride_kk, stride_kd,
    stride_vb, stride_vk, stride_vd,
    stride_ob, stride_oq, stride_od,
    stride_lb, stide_lq,
    N_QUERIES, N_KEYS,
    scale,
    D: tl.constexpr,
    Q_TILE_SIZE: tl.constexpr,
    K_TILE_SIZE: tl.constexpr,
):
    """Tiled attention forward pass using a running (online) softmax.

    Each program instance handles one query tile (program axis 0) of one
    batch element (program axis 1). It streams over every key/value tile,
    maintaining a running row maximum, a running softmax denominator and an
    unnormalised output accumulator, then writes the normalised output tile
    and the per-query logsumexp.

    Args:
        Q_ptr, K_ptr, V_ptr: Input pointers. Per batch element they are
            addressed as (N_QUERIES, D), (N_KEYS, D) and (N_KEYS, D).
        O_ptr: Output pointer, addressed as (N_QUERIES, D) per batch element.
        L_ptr: Logsumexp output pointer, addressed as (N_QUERIES,) per batch
            element.
        stride_*: Element strides of the corresponding tensors (batch, row,
            feature). ``stide_lq`` is the per-query stride of L; the
            misspelled name is kept so existing keyword callers keep working.
        N_QUERIES, N_KEYS: Number of query / key rows.
        scale: Multiplier applied to the Q @ K^T scores.
        D: Feature dimension (compile-time constant).
        Q_TILE_SIZE, K_TILE_SIZE: Tile sizes (compile-time constants).
    """
    # Program indices: which query tile and which batch element we own.
    query_tile_index = tl.program_id(0)
    batch_index = tl.program_id(1)

    Q_block_ptr = tl.make_block_ptr(
        Q_ptr + batch_index * stride_qb,
        shape=(N_QUERIES, D),
        strides=(stride_qq, stride_qd),
        offsets=(query_tile_index * Q_TILE_SIZE, 0),
        block_shape=(Q_TILE_SIZE, D),
        order=(1, 0),
    )

    K_block_ptr = tl.make_block_ptr(
        K_ptr + batch_index * stride_kb,
        shape=(N_KEYS, D),
        strides=(stride_kk, stride_kd),
        offsets=(0, 0),
        block_shape=(K_TILE_SIZE, D),
        order=(1, 0),
    )

    V_block_ptr = tl.make_block_ptr(
        V_ptr + batch_index * stride_vb,
        shape=(N_KEYS, D),
        strides=(stride_vk, stride_vd),
        offsets=(0, 0),
        block_shape=(K_TILE_SIZE, D),
        order=(1, 0),
    )

    O_block_ptr = tl.make_block_ptr(
        O_ptr + batch_index * stride_ob,
        shape=(N_QUERIES, D),
        strides=(stride_oq, stride_od),
        offsets=(query_tile_index * Q_TILE_SIZE, 0),
        block_shape=(Q_TILE_SIZE, D),
        order=(1, 0),
    )

    # L is one value per query, so it needs exactly one stride: its own.
    L_block_ptr = tl.make_block_ptr(
        L_ptr + batch_index * stride_lb,
        shape=(N_QUERIES,),
        strides=(stide_lq,),
        offsets=(query_tile_index * Q_TILE_SIZE,),
        block_shape=(Q_TILE_SIZE,),
        order=(0,),
    )

    # float32 accumulators: unnormalised output, softmax denominator, row max.
    O_block = tl.zeros((Q_TILE_SIZE, D), dtype=tl.float32)
    L_block = tl.zeros((Q_TILE_SIZE,), dtype=tl.float32)
    m = tl.full((Q_TILE_SIZE,), float('-inf'), dtype=tl.float32)
    Q_block = tl.load(Q_block_ptr, boundary_check=(0,), padding_option='zero')

    key_offsets = tl.arange(0, K_TILE_SIZE)

    for i in range(tl.cdiv(N_KEYS, K_TILE_SIZE)):
        K_block = tl.load(K_block_ptr, boundary_check=(0,), padding_option='zero')
        V_block = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero')

        S = scale * tl.dot(Q_block, tl.trans(K_block))
        # Zero-padded keys past N_KEYS would otherwise score 0 and leak into
        # the softmax; force them to -inf so they contribute exp(-inf) == 0.
        key_in_range = (i * K_TILE_SIZE + key_offsets) < N_KEYS
        S = tl.where(key_in_range[None, :], S, float('-inf'))

        m_curr = tl.maximum(m, tl.max(S, axis=1))
        # Row-wise shift: broadcast the per-query max across the key axis.
        P = tl.exp(S - m_curr[:, None])
        # Rescale factor that re-bases the old accumulators onto the new max.
        alpha = tl.exp(m - m_curr)
        L_block = alpha * L_block + tl.sum(P, axis=1)
        # No `diag` in triton.language, so diag(alpha) @ O is done by broadcasting.
        O_block = alpha[:, None] * O_block + tl.dot(P.to(V_block.dtype), V_block)
        m = m_curr

        # Move the pointers to the next key/value tile.
        K_block_ptr = K_block_ptr.advance((K_TILE_SIZE, 0))
        V_block_ptr = V_block_ptr.advance((K_TILE_SIZE, 0))

    O_block = O_block / L_block[:, None]
    L_block = m + tl.log(L_block)

    # Cast to the destination element types before the boundary-checked stores.
    tl.store(O_block_ptr, O_block.to(O_ptr.type.element_ty), boundary_check=(0,))
    tl.store(L_block_ptr, L_block.to(L_ptr.type.element_ty), boundary_check=(0,))

class FlashAttentionAutogradFunctionTriton(torch.autograd.Function):
    """Autograd wrapper that runs the attention forward pass via `flash_fwd_kernel`."""

    @staticmethod
    def forward(ctx, Q, K, V, is_causal=False):
        """Launch the Triton forward kernel and return the attention output.

        Args:
            ctx: Autograd context; receives the tensors saved for backward.
            Q: Query tensor, unpacked as (batch, n_queries, d).
            K: Key tensor, unpacked as (batch, n_keys, d).
            V: Value tensor; addressed by the kernel with the same row count as K.
            is_causal: Causal-masking flag. The kernel has no causal masking,
                so a truthy value is rejected rather than silently ignored.

        Returns:
            Tensor with the same shape and dtype as ``Q``.

        Raises:
            NotImplementedError: If ``is_causal`` is truthy.
        """
        if is_causal:
            raise NotImplementedError(
                "flash_fwd_kernel does not implement causal masking"
            )

        device = Q.device
        batch_size, N_q, D = Q.shape
        _, N_k, _ = K.shape
        # Tile sizes along the query and key dimensions.
        B_q = 16
        B_k = 16

        O = torch.empty_like(Q)
        # Per-query logsumexp, kept in float32 regardless of the input dtype.
        L = torch.empty((batch_size, N_q), device=device, dtype=torch.float32)

        scale = D ** -0.5
        # One program per (query tile, batch element), matching the kernel's
        # program_id(0) / program_id(1) usage.
        grid = (triton.cdiv(N_q, B_q), batch_size)
        flash_fwd_kernel[grid](
            Q, K, V, O, L,
            Q.stride(0), Q.stride(1), Q.stride(2),
            K.stride(0), K.stride(1), K.stride(2),
            V.stride(0), V.stride(1), V.stride(2),
            O.stride(0), O.stride(1), O.stride(2),
            L.stride(0), L.stride(1),
            N_q, N_k,
            scale,
            D=D,
            Q_TILE_SIZE=B_q,
            K_TILE_SIZE=B_k,
        )

        ctx.save_for_backward(Q, K, V, O, L)
        ctx.is_causal = is_causal
        return O

In [None]:
import triton
@triton.jit
def test_kernel():
    """Scratch JIT function that builds a block filled with -inf and returns it."""
    # `tl.full` needs an explicit dtype, and Triton block shapes must be
    # powers of two, so the original 10-element request is rounded up to 16.
    x = tl.full([16], float('-inf'), dtype=tl.float32)
    return x

# But even then, you can't call it directly - you'd need to launch it as a kernel

In [None]:
import triton.language as tl 

# Scratch experiment: call `tl.full` directly from Python, outside any @triton.jit function.
# NOTE(review): presumably this raises, since triton.language builtins are meant to be
# traced inside JIT-compiled code and no dtype is passed here — confirm and drop the cell.
tl.full([10], float('-inf'))