Forward of Flash Attention produces incorrect results in fp32 with tl.dot(allow_tf32=False) #1821

Closed
szmigacz opened this issue Jun 22, 2023 · 5 comments · Fixed by #1913

Comments

@szmigacz

szmigacz commented Jun 22, 2023

Repro with modified code from the fused attention tutorial (06-fused-attention.py):

import pytest
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, allow_tf32=False)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(Q.dtype.element_ty)
        v = tl.load(v_ptrs)
        acc += tl.dot(p, v, allow_tf32=False)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_prev)
    tl.store(m_ptrs, m_prev)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        BLOCK = 32
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8

        _fwd_kernel[grid](
            q, k, v, sm_scale,
            L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
            num_stages=2,
        )

        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o



attention = _attention.apply


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float32):
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
    sm_scale = 0.2
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1)
    # p = torch.exp(p)
    ref_out = torch.matmul(p, v)
    tri_out = attention(q, k, v, sm_scale)
    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)

Launch:

python3 -m pytest 06-fused-attention.py -s -k test_op

Output:

(...)
>       assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x7f80d7baa540>(tensor([[[[ 0.4037,  0.3309,  0.3346,  ...,  0.0815,  0.3073,  0.3725],\n          [ 0.3167,  0.5633,  0.2115,  ...,  0...[ 0.2973,  0.3009,  0.3038,  ...,  0.3017,  0.2963,  0.3065]]]],\n       device='cuda:0', grad_fn=<UnsafeViewBackward0>), tensor([[[[ 0.4037,  0.3309,  0.3346,  ...,  0.0815,  0.3073,  0.3725],\n          [ 0.3167,  0.5633,  0.2115,  ...,  0... [ 0.1591,  0.1614,  0.1616,  ...,  0.1601,  0.1591,  0.1632]]]],\n       device='cuda:0', grad_fn=<_attentionBackward>), atol=0.01, rtol=0)
E        +    where <built-in method allclose of type object at 0x7f80d7baa540> = torch.allclose
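A quick way to quantify the mismatch (not part of the original repro; purely illustrative) is to report the worst-case error next to the allclose check at the end of test_op:

    # Illustrative addition: print the largest absolute deviation so the
    # failure is easier to compare across block sizes and dtypes.
    max_abs_err = (ref_out - tri_out).abs().max().item()
    print(f"max abs error: {max_abs_err:.6f} (allclose atol: 1e-2)")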

Modifications applied to the original tutorial code:

  • tl.dot(...) -> tl.dot(..., allow_tf32=False)
  • changed BLOCK from 128 to 32
  • changed input dtype from torch.float16 to torch.float32

Triton: latest main (5686c51)
GPU: NVIDIA RTX A6000 (Ampere, sm_86)
Container: pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
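For what it's worth, tl.dot can also be checked in isolation with a stripped-down fp32 matmul. This is only a sketch under the assumption that the whole problem fits in a single block (the names, sizes, and single-program grid are arbitrary choices, not taken from the tutorial):

import torch
import triton
import triton.language as tl


@triton.jit
def _dot_kernel(A, B, C, N, K,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # One program computes the full (BLOCK_M, BLOCK_N) product; row-major, contiguous inputs assumed.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a = tl.load(A + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(B + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b, allow_tf32=False)  # same flag as in the attention kernel
    tl.store(C + offs_m[:, None] * N + offs_n[None, :], c)


def check_dot(m=32, n=32, k=64):
    a = torch.randn(m, k, device="cuda", dtype=torch.float32)
    b = torch.randn(k, n, device="cuda", dtype=torch.float32)
    c = torch.empty(m, n, device="cuda", dtype=torch.float32)
    _dot_kernel[(1,)](a, b, c, n, k, BLOCK_M=m, BLOCK_N=n, BLOCK_K=k)
    torch.backends.cuda.matmul.allow_tf32 = False  # keep the reference in true fp32
    print("max abs diff vs torch.matmul:", (a @ b - c).abs().max().item())

If this already shows an error well above fp32 round-off, the problem is in tl.dot itself rather than in the attention kernel.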

@Jokeren
Contributor

Jokeren commented Jun 22, 2023

Thanks, we are working on it. #1671

@ThomasRaoux
Contributor

> Thanks, we are working on it. #1671

Is it the same problem? The other issue happens only with allow_tf32=True, and I'm not able to reproduce it on ToT.

@Jokeren
Contributor

Jokeren commented Jun 22, 2023

> Thanks, we are working on it. #1671
>
> Is it the same problem? The other issue happens only with allow_tf32=True, and I'm not able to reproduce it on ToT.

Oh I see, hmm, let's take a look at this after the other issue has been solved.

@szmigacz
Author

I reproduced this problem on RTX A6000 and A100 in both pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel and nvcr.io/nvidia/pytorch:23.03-py3 containers (with latest triton: 7b30e24)
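One additional thing worth ruling out in these containers (just a suggestion, not something the report depends on) is TF32 sneaking into the PyTorch reference matmul. It can be pinned to true fp32 before building ref_out:

# Force the PyTorch reference to genuine fp32 accumulation; some container
# images may enable TF32 by default. This does not affect the Triton kernel.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False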

@ThomasRaoux
Contributor

Will be fixed with #1913
