diff --git a/ring_attention_pytorch/ring_flash_attention_cuda.py b/ring_attention_pytorch/ring_flash_attention_cuda.py index 0460277..df56247 100644 --- a/ring_attention_pytorch/ring_flash_attention_cuda.py +++ b/ring_attention_pytorch/ring_flash_attention_cuda.py @@ -18,10 +18,6 @@ get_world_size ) -from ring_attention_pytorch.triton_flash_attn import ( - _flash_attn_backward -) - from beartype import beartype from einops import rearrange, repeat @@ -107,369 +103,10 @@ def inverse_fn(y): import triton import triton.language as tl -# taking the flash attention forwards from Tri's flash_attn repository -# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py -# and modifying to return unnormalized accumulation, row maxes, row lse - reduced over passed rings - -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } +from ring_attention_pytorch.triton_flash_attn import ( + flash_attn_backward, + flash_attn_forward ) -@triton.jit -def _fwd_kernel( - Q, - K, - V, - Bias, - Out, - M, - Lse, - softmax_scale, - stride_qb, - stride_qh, - stride_qm, - stride_kb, - stride_kh, - stride_kn, - stride_vb, - stride_vh, - stride_vn, - stride_bb, - stride_bh, - stride_bm, - stride_ob, - stride_oh, - stride_om, - nheads, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - headdim, - CACHE_KEY_SEQLEN_Q, - CACHE_KEY_SEQLEN_K, - HAS_BIAS: tl.constexpr, - IS_CAUSAL: tl.constexpr, - CAUSAL_MASK_DIAGONAL: tl.constexpr, - LOAD_ACCUMULATED: tl.constexpr, - RETURN_NORMALIZED_OUTPUT: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - - q_ptrs = ( - Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) - ) - k_ptrs = ( - K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) - ) - v_ptrs = ( - V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) - ) - - if HAS_BIAS: - b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n - - # maximum - - m_ptrs = M + off_hb * seqlen_q_rounded + offs_m - - if LOAD_ACCUMULATED: - m_i = tl.load(m_ptrs) - else: - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - - # load lse - - lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m - - if LOAD_ACCUMULATED: - lse_i = tl.load(lse_ptrs) - else: - lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - - # load accumualted output - - offs_d = tl.arange(0, BLOCK_HEADDIM) - - out_ptrs = ( - Out - + off_b * stride_ob - + off_h * stride_oh - + (offs_m[:, None] * stride_om + offs_d[None, :]) - ) - - if LOAD_ACCUMULATED: - if EVEN_M: - if EVEN_HEADDIM: - acc_o = tl.load(out_ptrs) - else: - acc_o = tl.load(out_ptrs, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - acc_o = tl.load(out_ptrs, mask=offs_m[:, None] < seqlen_q) - else: - acc_o = tl.load( - out_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) - ) - - acc_o = acc_o.to(tl.float32) - else: - acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) - - # load 
queries, keys, values - - if EVEN_M & EVEN_N: - if EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 - ) - - end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) - for start_n in range(0, end_n, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, tl.trans(k)) - - if not EVEN_N: - qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - - if IS_CAUSAL: - if CAUSAL_MASK_DIAGONAL: - # needed for stripe attention - qk += tl.where(offs_m[:, None] > (start_n + offs_n)[None, :], 0, float("-inf")) - else: - qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) - - if HAS_BIAS: - if EVEN_N: - bias = tl.load(b_ptrs + start_n) - else: - bias = tl.load( - b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0 - ) - bias = bias[None, :] - - bias = bias.to(tl.float32) - qk = qk * softmax_scale + bias - m_ij = tl.maximum(tl.max(qk, 1), lse_i) - p = tl.exp(qk - m_ij[:, None]) - else: - m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) - p = tl.exp(qk * softmax_scale - m_ij[:, None]) - - l_ij = tl.sum(p, 1) - - acc_o_scale = tl.exp(m_i - m_ij) - acc_o = acc_o * acc_o_scale[:, None] - - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) - else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - - p = p.to(v.dtype) - acc_o += tl.dot(p, v) - - # -- update statistics - - m_i = m_ij - l_i_new = tl.exp(lse_i - m_ij) + l_ij - lse_i = m_ij + tl.log(l_i_new) - - if RETURN_NORMALIZED_OUTPUT: - acc_o_scale = tl.exp(m_i - lse_i) - acc_o = acc_o * acc_o_scale[:, None] - - # offsets for m and lse - - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # write back lse and m - - tl.store(lse_ptrs, lse_i) - - if not RETURN_NORMALIZED_OUTPUT: - tl.store(m_ptrs, m_i) - - # write to output - - if EVEN_M: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o) - else: - tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) - else: - tl.store( - out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) - ) - -def flash_attn_forward( - q, - k, - v, - bias = None, - causal = False, - o = None, - m = None, - lse = None, - softmax_scale = None, - causal_mask_diagonal = False, - return_normalized_output = False, - load_accumulated = True -): - q, k, v = [x if 
is_contiguous(x) else x.contiguous() for x in (q, k, v)] - - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - - assert k.shape == (batch, seqlen_k, nheads, d) - assert v.shape == (batch, seqlen_k, nheads, d) - assert d <= 128, "FlashAttention only support head dimensions up to 128" - assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" - assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" - assert q.is_cuda and k.is_cuda and v.is_cuda - - softmax_scale = default(softmax_scale, d ** -0.5) - - has_bias = exists(bias) - - if has_bias: - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - - if bias.ndim == 2: - bias = repeat(bias, 'b j -> b h i j', h = nheads, i = seqlen_q) - - if not is_contiguous(bias): - bias = bias.contiguous() - - assert bias.shape[-2:] == (seqlen_q, seqlen_k) - bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) - - bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) - - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - - if not exists(lse): - max_neg_value = -torch.finfo(torch.float32).max - init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty - lse = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - - if not exists(m): - max_neg_value = -torch.finfo(torch.float32).max - init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty - m = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - - if not exists(o): - init_fn = torch.zeros_like if load_accumulated else torch.empty_like - o = init_fn(q) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - BLOCK = 128 - num_warps = 4 if d <= 64 else 8 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - - _fwd_kernel[grid]( - q, - k, - v, - bias, - o, - m, - lse, - softmax_scale, - q.stride(0), - q.stride(2), - q.stride(1), - k.stride(0), - k.stride(2), - k.stride(1), - v.stride(0), - v.stride(2), - v.stride(1), - *bias_strides, - o.stride(0), - o.stride(2), - o.stride(1), - nheads, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - d, - seqlen_q // 32, - seqlen_k // 32, - has_bias, - causal, - causal_mask_diagonal, - load_accumulated, - return_normalized_output, - BLOCK_HEADDIM, - BLOCK_M = BLOCK, - BLOCK_N = BLOCK, - num_warps = num_warps, - num_stages = 1, - ) - - return o, m, lse # ring + (flash) attention forwards and backwards @@ -717,20 +354,18 @@ def backward(ctx, do): if causal or not exists(mask): - block_causal = False - causal_mask_diagonal = False - - need_accum = True - - if causal: - if striped_ring_attn: - block_causal = True - causal_mask_diagonal = get_rank() < ring_rank - else: - block_causal = get_rank() == ring_rank - - if get_rank() < ring_rank: - need_accum = False + if causal and striped_ring_attn: + need_accum = True + block_causal = True + causal_mask_diagonal = get_rank() < ring_rank + elif causal: + need_accum = get_rank() >= ring_rank + block_causal = get_rank() == ring_rank + causal_mask_diagonal = False + else: + need_accum = True + block_causal = False + causal_mask_diagonal = False # use flash attention backwards kernel to calculate dq, dk, dv and accumulate diff --git a/ring_attention_pytorch/triton_flash_attn.py b/ring_attention_pytorch/triton_flash_attn.py index 6611876..1921879 100644 --- a/ring_attention_pytorch/triton_flash_attn.py +++ b/ring_attention_pytorch/triton_flash_attn.py @@ -1,5 +1,6 @@ -# 
taken from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L618 +# taken from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py # with fixes for triton 2.3 and preparing for modifications to backwards +# forward is modified to return unnormalized accumulation, row maxes, row lse - reduced over passed rings import math @@ -7,6 +8,366 @@ import triton import triton.language as tl +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Bias, + Out, + M, + Lse, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_ob, + stride_oh, + stride_om, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + HAS_BIAS: tl.constexpr, + IS_CAUSAL: tl.constexpr, + CAUSAL_MASK_DIAGONAL: tl.constexpr, + LOAD_ACCUMULATED: tl.constexpr, + RETURN_NORMALIZED_OUTPUT: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + q_ptrs = ( + Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + ) + k_ptrs = ( + K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + ) + v_ptrs = ( + V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + ) + + if HAS_BIAS: + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + + # maximum + + m_ptrs = M + off_hb * seqlen_q_rounded + offs_m + + if LOAD_ACCUMULATED: + m_i = tl.load(m_ptrs) + else: + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + + # load lse + + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + + if LOAD_ACCUMULATED: + lse_i = tl.load(lse_ptrs) + else: + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + + # load accumualted output + + offs_d = tl.arange(0, BLOCK_HEADDIM) + + out_ptrs = ( + Out + + off_b * stride_ob + + off_h * stride_oh + + (offs_m[:, None] * stride_om + offs_d[None, :]) + ) + + if LOAD_ACCUMULATED: + if EVEN_M: + if EVEN_HEADDIM: + acc_o = tl.load(out_ptrs) + else: + acc_o = tl.load(out_ptrs, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + acc_o = tl.load(out_ptrs, mask=offs_m[:, None] < seqlen_q) + else: + acc_o = tl.load( + out_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) + ) + + acc_o = acc_o.to(tl.float32) + else: + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + + # load queries, keys, values + + if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 + ) + + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, 
seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + + if not EVEN_N: + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + + if IS_CAUSAL: + if CAUSAL_MASK_DIAGONAL: + # needed for stripe attention + qk += tl.where(offs_m[:, None] > (start_n + offs_n)[None, :], 0, float("-inf")) + else: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + + if HAS_BIAS: + if EVEN_N: + bias = tl.load(b_ptrs + start_n) + else: + bias = tl.load( + b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0 + ) + bias = bias[None, :] + + bias = bias.to(tl.float32) + qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + + l_ij = tl.sum(p, 1) + + acc_o_scale = tl.exp(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + if RETURN_NORMALIZED_OUTPUT: + acc_o_scale = tl.exp(m_i - lse_i) + acc_o = acc_o * acc_o_scale[:, None] + + # offsets for m and lse + + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # write back lse and m + + tl.store(lse_ptrs, lse_i) + + if not RETURN_NORMALIZED_OUTPUT: + tl.store(m_ptrs, m_i) + + # write to output + + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store( + out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) + ) + +def flash_attn_forward( + q, + k, + v, + bias = None, + causal = False, + o = None, + m = None, + lse = None, + softmax_scale = None, + causal_mask_diagonal = False, + return_normalized_output = False, + load_accumulated = True +): + q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)] + + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, "FlashAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support 
fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda + + softmax_scale = default(softmax_scale, d ** -0.5) + + has_bias = exists(bias) + + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + + if bias.ndim == 2: + bias = repeat(bias, 'b j -> b h i j', h = nheads, i = seqlen_q) + + if not is_contiguous(bias): + bias = bias.contiguous() + + assert bias.shape[-2:] == (seqlen_q, seqlen_k) + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + + if not exists(lse): + max_neg_value = -torch.finfo(torch.float32).max + init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty + lse = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + + if not exists(m): + max_neg_value = -torch.finfo(torch.float32).max + init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty + m = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + + if not exists(o): + init_fn = torch.zeros_like if load_accumulated else torch.empty_like + o = init_fn(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + + _fwd_kernel[grid]( + q, + k, + v, + bias, + o, + m, + lse, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + o.stride(0), + o.stride(2), + o.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, + has_bias, + causal, + causal_mask_diagonal, + load_accumulated, + return_normalized_output, + BLOCK_HEADDIM, + BLOCK_M = BLOCK, + BLOCK_N = BLOCK, + num_warps = num_warps, + num_stages = 1, + ) + + return o, m, lse + @triton.jit def _bwd_preprocess_do_o_dot( Out, diff --git a/setup.py b/setup.py index 46b27cb..5d98217 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'ring-attention-pytorch', packages = find_packages(exclude=[]), - version = '0.3.12', + version = '0.3.14', license='MIT', description = 'Ring Attention - Pytorch', author = 'Phil Wang',