Incorrect Results with TF32 on Main #1937

Open · IzzyPutterman opened this issue Jul 12, 2023 · 8 comments

@IzzyPutterman (Contributor)

Running on Ampere A6000, Triton commit fd89aa1d2bca4652f383b70f81d993f258e4440f
Taken from this issue: #1840

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 32,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_M": 8,
            },
            num_stages=3,
            num_warps=8,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def linear_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the linear C = A x B^T.
    A has shape (M, K), B has shape (N, K) and C has shape (M, N).
    Here M=batch_size, N=input_features, K=output_features.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse
    # See above `L2 Cache Optimizations` section for details
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_N, BLOCK_SIZE_K] pointers
    # see above `Pointer Arithmetics` section for details
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_k[None, :] * stride_bk)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # Since inputs and output are fp32 here, the accumulator is never down-converted.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Note that for simplicity, we don't apply a mask here.
        # This means that if K is not a multiple of BLOCK_SIZE_K,
        # this will access out-of-bounds memory and produce an
        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, tl.trans(b), allow_tf32=True)
        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # The accumulator is already fp32, so this conversion is a no-op.
    c = accumulator.to(tl.float32)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def linear(input, weight):
    # Check constraints.
    assert input.shape[1] == weight.shape[1], "incompatible dimensions"
    assert input.is_contiguous(), "matrix A must be contiguous"
    assert weight.is_contiguous(), "matrix B must be contiguous"
    M, K = input.shape
    N, K = weight.shape
    assert (
        K % 32 == 0
    ), "We don't mask out-of-bounds accesses along K, so K must be divisible by BLOCK_SIZE_K"
    # Allocate output.
    c = torch.empty((M, N), device=input.device, dtype=input.dtype)
    # 1D launch grid where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    linear_kernel[grid](
        input,
        weight,
        c,
        M,
        N,
        K,
        input.stride(0),
        input.stride(1),
        weight.stride(0),
        weight.stride(1),
        c.stride(0),
        c.stride(1),
    )
    return c


# %%
# Unit Test
# -----------
#
# We can test our custom Linear against torch Linear

torch.manual_seed(0)
a = torch.randn((256, 128), device="cuda", dtype=torch.float32)
b = torch.randn((512, 128), device="cuda", dtype=torch.float32)
triton_output = linear(a, b)
torch_output = torch.nn.functional.linear(a, b)
print(f"triton_output={triton_output}")
print(f"torch_output={torch_output}")
triton.testing.assert_close(triton_output, torch_output)

Output on my end

AssertionError: 
Not equal to tolerance rtol=0, atol=0.01

Mismatched elements: 106533 / 131072 (81.3%)
Max absolute difference: 73.00191
Max relative difference: 107711.47
 x: array([[ 11.251299, -15.285846,  18.187695, ..., -14.179108,   6.608107,
          2.885912],
       [ 21.856983,  -4.152937,  19.599163, ..., -14.177496, -11.181219,...
 y: array([[ 11.254191,   0.467534,  -8.10821 , ...,   4.932623,  -4.681879,
         -2.327722],
       [ 21.879164, -15.190123,  13.634354, ..., -11.283187,  -3.57296 ,...

Setting allow_tf32=False makes everything work.
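
For reference, the reported workaround is a one-line change to the dot call (everything else in the kernel stays the same):

    accumulator += tl.dot(a, tl.trans(b), allow_tf32=False)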

@Jokeren (Contributor) commented Jul 13, 2023

There might be a problem with trans. Can you try to rewrite the kernel without trans? I think that should work.

@Jokeren (Contributor) commented Jul 13, 2023

fp16 works but fp32 doesn't

Jokeren added the bug label on Jul 13, 2023
@3gx commented Jul 25, 2023

If I use "BLOCK_SIZE_N": 32 and change the accumulator statement to accumulator += tl.dot(a, tl.trans(b)*(0*b+1), allow_tf32=True), the test appears to pass (with rtol=atol=0.01):

Mismatched elements: 1 / 131072 (0.000763%)
Max absolute difference: 0.04163361
Max relative difference: 37.487473
 x: array([[ 11.251299,   0.469602,  -8.10081 , ...,   4.933978,  -4.67709 ,
         -2.331005],
       [ 21.856983, -15.175253,  13.622528, ..., -11.272243,  -3.573601,...
 y: array([[ 11.254191,   0.467534,  -8.10821 , ...,   4.932623,  -4.681879,
         -2.327722],
       [ 21.879164, -15.190123,  13.634354, ..., -11.283187,  -3.57296 ,...

Same result when changing b_ptrs and accumulator statements to
b_ptrs = b_ptr + (offs_bn[None,:] * stride_bn + offs_k[:, None] * stride_bk) and accumulator += tl.dot(a, b, allow_tf32=True) respectively.
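
Putting those two changes together, the trans-free inner loop would look roughly like this (a sketch based on the description above, not a fully verified kernel):

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # B is addressed directly in (BLOCK_SIZE_K, BLOCK_SIZE_N) layout, so no in-kernel transpose is needed
    b_ptrs = b_ptr + (offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)  # shape (BLOCK_SIZE_K, BLOCK_SIZE_N)
        accumulator += tl.dot(a, b, allow_tf32=True)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk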

@ymwangg commented Aug 17, 2023

@3gx I'm also looking into similar issues. Is it possible to express the transposed block ptrs b_ptrs = b_ptr + (offs_bn[None,:] * stride_bn + offs_k[:, None] * stride_bk) using the tl.make_block_ptr API?

@ymwangg commented Aug 22, 2023

Same result when using block ptrs by making the following changes:

    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(0, 1))

and

accumulator += tl.dot(a, b, allow_tf32=False)
a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
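
For completeness, a sketch of how the K loop would read with these block pointers (assuming K is a multiple of BLOCK_SIZE_K, so no boundary_check is needed on the loads):

    for k in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_block_ptr)
        b = tl.load(b_block_ptr)
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))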

@BinFan (Contributor) commented Aug 25, 2023

I created #2180 to address the issue

@BinFan (Contributor) commented Aug 25, 2023

By the way, I think fp16 works only by luck: in this case the mma shape is m=n=k=8, so even if we swap n and k the result is the same.

BinFan added a commit to BinFan/triton that referenced this issue Aug 25, 2023
… swizzling code

Also some additional fixes in the Pipeline pass

This fixes bug triton-lang#1937
ptillet added a commit that referenced this issue Sep 13, 2023
… correct swizzling code (#2180)

fix bug #1937

Co-authored-by: Philippe Tillet <phil@openai.com>
@IzzyPutterman (Contributor, Author)

Results look closer after the merge, but how close are we expecting to get to torch? Running the above script gives:

Mismatched elements: 32488 / 131072 (24.8%)
Max absolute difference: 0.04163361
Max relative difference: 37.487473
 x: array([[ 11.251299,   0.469602,  -8.10081 , ...,   4.933978,  -4.67709 ,
         -2.331005],
       [ 21.856983, -15.175253,  13.622528, ..., -11.272243,  -3.573601,...
 y: array([[ 11.254191,   0.467534,  -8.10821 , ...,   4.932623,  -4.681879,
         -2.327722],
       [ 21.879164, -15.190123,  13.634354, ..., -11.283187,  -3.57296 ,...
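
One possible explanation for the remaining mismatch (an assumption on my part, not confirmed above): the torch reference runs in full fp32 unless TF32 is enabled on the torch side, while the Triton kernel uses TF32, so some residual difference is expected. Enabling TF32 for the reference should give an apples-to-apples comparison:

    # Assumption: let cuBLAS use TF32 so the reference matches the kernel's precision
    torch.backends.cuda.matmul.allow_tf32 = True
    torch_output = torch.nn.functional.linear(a, b)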

alexander-zinoviev pushed a commit to alexander-zinoviev/triton that referenced this issue Sep 21, 2023
… correct swizzling code (triton-lang#2180)

fix bug triton-lang#1937

Co-authored-by: Philippe Tillet <phil@openai.com>
pingzhuu pushed a commit to siliconflow/triton that referenced this issue Apr 2, 2024
… correct swizzling code (triton-lang#2180)

fix bug triton-lang#1937

Co-authored-by: Philippe Tillet <phil@openai.com>