In [1]:
import torch
import triton 
import triton.language as tl

In [7]:
#

In [57]:
@triton.jit
def atte_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    scale,
    B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, D: tl.constexpr,
    Block_M: tl.constexpr, Block_N: tl.constexpr
):
    """Non-causal attention forward pass using an online (streaming) softmax.

    Each program computes one (Block_M, D) tile of the output for a single
    (batch, head) slice: program axis 0 selects the query tile along T,
    program axis 1 selects the flattened batch*head index.

    Q, K, V and O are addressed as row-major (T, D) matrices with strides
    (D, 1), one per (batch, head) slice, located T * D elements apart.

    Args:
        Q_ptr, K_ptr, V_ptr: input pointers.
        O_ptr: output pointer.
        scale: multiplier applied to the Q @ K^T scores before the softmax.
        B, H: batch size and head count (not used inside the kernel body;
            kept as part of the launch signature).
        T: sequence length. It does not need to be a multiple of Block_M or
            Block_N: out-of-range rows are zero-padded on load, masked out of
            the softmax, and dropped on store.
        D: head dimension; also used as the tile width.
        Block_M: number of query rows processed per program.
        Block_N: number of key/value rows consumed per loop iteration.
    """
    pid0 = tl.program_id(0)  # index of the query tile along T
    pid1 = tl.program_id(1)  # flattened (batch, head) index

    # Widen to 64-bit so pid1 * T * D cannot wrap around for large tensors.
    head_offset = pid1.to(tl.int64) * T * D

    Q_bp = tl.make_block_ptr(
        base=Q_ptr + head_offset,
        shape=(T, D),
        strides=(D, 1),
        block_shape=(Block_M, D),
        offsets=(pid0 * Block_M, 0),
        order=(1, 0)
    )

    K_bp = tl.make_block_ptr(
        base=K_ptr + head_offset,
        shape=(T, D),
        strides=(D, 1),
        block_shape=(Block_N, D),
        offsets=(0, 0),
        order=(1, 0)
    )

    V_bp = tl.make_block_ptr(
        base=V_ptr + head_offset,
        shape=(T, D),
        strides=(D, 1),
        block_shape=(Block_N, D),
        offsets=(0, 0),
        order=(1, 0)
    )

    O_bp = tl.make_block_ptr(
        base=O_ptr + head_offset,
        shape=(T, D),
        strides=(D, 1),
        block_shape=(Block_M, D),
        offsets=(pid0 * Block_M, 0),
        order=(1, 0)
    )

    # Rows past T are read as zeros; their results are discarded by the
    # boundary-checked store at the end.
    Q = tl.load(Q_bp, boundary_check=(0,), padding_option="zero").to(tl.float32)
    iter_max = tl.cdiv(T, Block_N)

    # Running state of the online softmax, one entry per query row.
    O_online = tl.zeros((Block_M, D), dtype=tl.float32)                # unnormalised output
    l_online = tl.zeros((Block_M,), dtype=tl.float32)                  # running sum of exp
    max_online = tl.full((Block_M,), float('-inf'), dtype=tl.float32)  # running row max

    for i in range(0, iter_max):
        Ki = tl.load(K_bp, boundary_check=(0,), padding_option="zero").to(tl.float32)  # Bn x D
        Vi = tl.load(V_bp, boundary_check=(0,), padding_option="zero").to(tl.float32)  # Bn x D

        scorei = tl.dot(Q, tl.trans(Ki), allow_tf32=False)  # Bm x Bn
        scorei = scorei * scale.to(tl.float32)

        # Key columns past T come from zero padding, not real data: push them
        # to -inf so they contribute exp(-inf) = 0 to the softmax. Every tile
        # starts at i * Block_N < T, so each row keeps at least one finite score.
        key_idx = i * Block_N + tl.arange(0, Block_N)
        scorei = tl.where(key_idx[None, :] < T, scorei, float('-inf'))

        max_new = tl.maximum(max_online, tl.max(scorei, axis=1))
        # Factor that re-bases the previously accumulated sums onto the new max.
        alpha = tl.exp(max_online - max_new)
        pi = tl.exp(scorei - max_new[:, None])

        l_online = alpha * l_online + tl.sum(pi, axis=1)
        O_online = alpha[:, None] * O_online + tl.dot(pi, Vi, allow_tf32=False)
        max_online = max_new

        K_bp = tl.advance(K_bp, (Block_N, 0))
        V_bp = tl.advance(V_bp, (Block_N, 0))

    out = O_online / l_online[:, None]
    # Only rows that lie inside the (T, D) output are written.
    tl.store(O_bp, out, boundary_check=(0,))

In [58]:
def my_atte(Q, K, V):
    """Compute non-causal scaled dot-product attention with `atte_kernel`.

    Args:
        Q, K, V: tensors of identical shape (B, H, T, D). Non-contiguous
            inputs are copied to contiguous memory first, because the kernel
            addresses each tensor with fixed strides (D, 1) and a per-head
            offset of T * D.

    Returns:
        A contiguous tensor with the same shape, dtype and device as Q.

    Raises:
        ValueError: if Q is not 4-dimensional, or if K or V does not have
            the same shape as Q.
    """
    if Q.dim() != 4:
        raise ValueError(
            f"Q must have shape (B, H, T, D), got {tuple(Q.shape)}"
        )
    if K.shape != Q.shape or V.shape != Q.shape:
        raise ValueError(
            "Q, K and V must have the same shape, got "
            f"{tuple(Q.shape)}, {tuple(K.shape)}, {tuple(V.shape)}"
        )

    # The kernel assumes a dense row-major layout; a strided view (e.g. the
    # result of a transpose) would otherwise be read with the wrong strides.
    Q = Q.contiguous()
    K = K.contiguous()
    V = V.contiguous()

    B, H, T, D = Q.shape
    scale = 1.0 / (D ** 0.5)
    O = torch.empty_like(Q, memory_format=torch.contiguous_format)

    # Tile sizes: query rows per program / key rows per loop iteration.
    Bm = 16
    Bn = 16

    # One program per (query tile, batch*head) pair.
    grid = (triton.cdiv(T, Bm), B * H)
    atte_kernel[grid](
        Q, K, V, O,
        scale,
        B, H, T, D,
        Block_M=Bm,
        Block_N=Bn,
    )

    return O

In [59]:
from torch.backends.cuda import sdp_kernel
# Fix the RNG seed so the random inputs (and therefore the comparison below)
# are reproducible from a fresh kernel.
torch.manual_seed(0)
Q = torch.randn(2, 2, 32, 16, device='cuda:0')
K = torch.randn(2, 2, 32, 16, device='cuda:0')
V = torch.randn(2, 2, 32, 16, device='cuda:0')

# Result of the Triton kernel under test.
O = my_atte(Q, K, V)

# Reference 1: PyTorch's built-in scaled dot-product attention.
O_ref = torch.nn.functional.scaled_dot_product_attention(Q, K, V, dropout_p=0.0, is_causal=False)

# Reference 2: softmax(Q @ K^T * scale) @ V written out by hand.
scale = 1.0 / (Q.size(-1) ** 0.5)
scores = (Q @ K.transpose(-2, -1)) * scale     # (B,H,T,T)
m = scores.max(dim=-1, keepdim=True).values
P = torch.exp(scores - m)                      # row max subtracted before exp for numerical stability
O_ref2 = (P @ V) / P.sum(dim=-1, keepdim=True)



In [60]:
# Pairwise agreement checks, printed in this order:
#   kernel vs. built-in SDPA, kernel vs. manual softmax, SDPA vs. manual softmax.
for lhs, rhs in ((O, O_ref), (O, O_ref2), (O_ref, O_ref2)):
    print(torch.allclose(lhs, rhs, atol=1e-6))

True
True
True
