<a href="https://colab.research.google.com/github/jpli02/learn-bpf/blob/master/causal_attn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install triton==2.0.0.dev20220709
!pip install pytest
import pytest
!pip install flash-attn

[31mERROR: Could not find a version that satisfies the requirement triton==2.0.0.dev20220709 (from versions: 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.3.1, 3.0.0, 3.1.0, 3.2.0)[0m[31m
[0m[31mERROR: No matching distribution found for triton==2.0.0.dev20220709[0m[31m
Collecting flash-attn
  Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m70.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->flash-attn)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->flash-attn)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->flash-attn)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl

In [2]:
import torch

import triton
import triton.language as tl


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr, #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX

    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(V_block_ptr)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(tl.float16)
        acc = tl.dot(p, v, acc)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))

    return acc, l_i, m_i

@triton.jit
def _acc_attention_score(acc_score, k,  #
                        Q_block_ptr, M_block_ptr, #
                        start_m, qk_scale,  #
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,  #
                        STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
                        N_CTX: tl.constexpr):

    # range of values handled by this stage
    lo, hi = 0, N_CTX
    for start_n in range(lo, hi, BLOCK_M):
        q = tl.load(Q_block_ptr)
        m = tl.load(M_block_ptr)
        qk = tl.dot(q, k)
        if STAGE == 1 or STAGE == 2:
          mask = (offs_m[:, None] + start_n) >= (offs_n[None, :])
          qk = qk * qk_scale + tl.where(mask, 0, -1.0e8) - m[:, None]
        # causal = False
        else:
          qk = qk * qk_scale  - m[:, None]

        p = tl.math.exp2(qk)

        acc_score += tl.sum(p, 0)
        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
        M_block_ptr = tl.advance(M_block_ptr, (BLOCK_M,))

    return acc_score

# We don't run auto-tuning every time to keep the tutorial fast. Keeping`
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
    # for BM in [64, 128]\
    # for BN in [32, 64]\
    for BM in [16]\
    for BN in [16]\
    for s in ([1] if is_hip() else [3, 4, 7])\
    for w in [4, 8]\
]


def keep(conf):
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
        return False
    return True


@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out, C, # C = (Z, H, N_CTX)
              stride_qz, stride_qh, stride_qm, stride_qk,  #
              stride_kz, stride_kh, stride_kn, stride_kk,  #
              stride_vz, stride_vh, stride_vk, stride_vn,  #
              stride_oz, stride_oh, stride_om, stride_on,  #
              stride_cz, stride_ch, stride_cn,  #
              stride_mz, stride_mh, stride_mn, #
              Z, H, N_CTX,  #
              HEAD_DIM: tl.constexpr,  #
              BLOCK_M: tl.constexpr,  #
              BLOCK_N: tl.constexpr,  #
              STAGE: tl.constexpr  #
              ):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    # corresponds to a q, k and v for a particular head and batch
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compielr to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))

    # second-pass accumulated score calculation
    # required condition: BLOCK_M == BLOCK_N
    m_offset = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
    c_offset = off_z.to(tl.int64) * stride_cz + off_h.to(tl.int64) * stride_ch

    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, BLOCK_N * start_m),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    C_block_ptr = tl.make_block_ptr(
        base=C + c_offset,
        shape=(N_CTX,),
        strides=(stride_cn,),
        offsets=(start_m * BLOCK_N,),
        block_shape=(BLOCK_N,),
        order=(0,),
    )

    M_block_ptr = tl.make_block_ptr(
        base=M + m_offset,
        shape=(N_CTX,),
        strides=(stride_mn,),
        offsets=(0,),
        block_shape=(BLOCK_N,),
        order=(0,)
    )

    offs_m = tl.arange(0, BLOCK_M)
    offs_n = start_m * BLOCK_N + tl.arange(0, BLOCK_N)

    acc_score = tl.zeros([BLOCK_N,], dtype=tl.float32)
    k = tl.load(K_block_ptr)
    acc_score = _acc_attention_score(acc_score, k,
                    Q_block_ptr, M_block_ptr, #
                    start_m, qk_scale,  #
                    BLOCK_M, BLOCK_N,  #
                    4 - STAGE, offs_m, offs_n, #
                    N_CTX)


    tl.store(C_block_ptr, acc_score.to(C.type.element_ty))


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        c = torch.zeros((q.shape[0], q.shape[1], q.shape[2]), dtype=torch.float32, device=q.device)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float16)

        _attn_fwd[grid](
            q, k, v, sm_scale, M, o, c, #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            c.stride(0), c.stride(1), c.stride(2),  #
            M.stride(0), M.stride(1), M.stride(2),
            q.shape[0], q.shape[1],  #
            N_CTX=q.shape[2],  #
            HEAD_DIM=HEAD_DIM_K,  #
            STAGE=stage,  #
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o, c, M


selection_attention = _attention.apply

In [4]:
import torch
import torch.nn as nn
import argparse
import math
import time
import gc
import pandas as pd
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

def gpu_cleanup():
    """
    Function to clean up GPU memory.
    """
    gc.collect()
    torch.cuda.empty_cache()

def create_tensors(Z, H, N_CTX, HEAD_DIM, dtype=torch.float16):
    """
    Create tensors for attention computation.
    """
    torch.manual_seed(int(time.time()))
    q = torch.rand((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda")
    k = torch.rand((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda")
    v = torch.rand((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda")
    return q, k, v

def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
    """
    Make causal mask used for bi-directional self-attention.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def ref_attention(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
    q, k, v = create_tensors(Z, H, N_CTX, HEAD_DIM, dtype)
    attn_weights = torch.matmul(q, k.transpose(2,3)) / math.sqrt(HEAD_DIM)

    if causal:
        attention_mask = _make_causal_mask(
            bsz=Z,
            tgt_len=N_CTX,
            past_key_values_length=0,
            dtype=q.dtype,
            device=q.device,
        )

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to(q.dtype)
    attn_output = torch.matmul(attn_weights, v)
    cumulative_attn_map = attn_weights.sum(2)
    return attn_output, cumulative_attn_map

def flash_attention(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
    q, k, v = create_tensors(Z, H, N_CTX, HEAD_DIM, dtype)
    attn_output = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=causal,
                window_size=(-1, -1), alibi_slopes=None, deterministic=False)

    return attn_output

def triton_attention(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
    """
    Perform Triton-based attention computation on the GPU.
    """
    q, k, v = create_tensors(Z, H, N_CTX, HEAD_DIM, dtype)
    sm_scale = 1.0 / math.sqrt(HEAD_DIM)
    tri_out, tri_c, tri_m = selection_attention(q, k, v, causal, sm_scale)
    return tri_out, tri_c, tri_m

def test_attention(Z, H, N_CTX, HEAD_DIM, causal=False, dtype=torch.float16):
    """
    Test to compare correctness of triton cmul attention kernel
    """
    gpu_cleanup()
    ref_out_gpu1, ref_c_gpu1 = ref_attention(Z, H, N_CTX, HEAD_DIM, causal, dtype)
    # Convert reference tensors to match dtype of Triton results
    ref_c_gpu1 = ref_c_gpu1.half()
    ref_out_gpu1 = ref_out_gpu1.half()
    tri_out_gpu, tri_c_gpu, tri_m_gpu = triton_attention(Z, H, N_CTX, HEAD_DIM, causal, dtype)

    flash_out_gpu = flash_attention(Z, H, N_CTX, HEAD_DIM, causal, dtype)

    # Compare results
    print(f"Attention max diff: {(tri_out_gpu.half() - ref_out_gpu1).abs().max().item()}")
    assert torch.allclose(ref_out_gpu1, tri_out_gpu.half(), atol=0.8, rtol=0), "Attention output mismatch"
    print("Attention check passed")

    print(f"accum score max diff: {(tri_c_gpu.half() - ref_c_gpu1).abs().max().item()}")
    # print("---------------------------------------------")
    # print("ref attention")
    # print(ref_c_gpu1)
    # print("---------------------------------------------")
    # print("triton attention")
    # print(tri_c_gpu.half())

    assert torch.allclose(ref_c_gpu1, tri_c_gpu.half(), atol=0.05, rtol=0), "col-wise sum score acc mismatch"
    print("Attention score acc check passed")

    # save results
    # pd.DataFrame(ref_c_gpu.cpu().numpy().flatten()).to_csv("/u/ndani/selection_kernel/reference_scores.csv", index=False, header=False, float_format="%.5f")
    # pd.DataFrame(tri_c_gpu.cpu().numpy().flatten()).to_csv("/u/ndani/selection_kernel/ours_scores.csv", index=False, header=False, float_format="%.5f")

if __name__ == "__main__":

    # Execute the test
    # test_attention(16, 32, 4096, 16, False)
    print("causal false passed")

    test_attention(16, 32, 1024, 16, True)

causal false passed


RuntimeError: FlashAttention only supports Ampere GPUs or newer.