
Tutorial example 03 performance issue #1122

Open
narendrachaudhary51 opened this issue May 14, 2024 · 16 comments · Fixed by #1185
Comments

@narendrachaudhary51

Hi,

I am trying to run the benchmark in the Python tutorial 03-matrix-multiplication.py, but I do not see the expected performance with Triton. Even for larger square matrix sizes, Triton's performance does not improve.

I am using the following XPU hardware.
Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1100 1.3 [1.3.26918]


matmul-performance-fp16:
M N K rocBLAS Triton
0 256.0 256.0 256.0 0.108552 0.065673
1 384.0 384.0 384.0 0.363976 0.154921
2 512.0 512.0 512.0 0.857502 0.162093
3 640.0 640.0 640.0 1.664666 0.174643
4 768.0 768.0 768.0 2.829421 0.307113
5 896.0 896.0 896.0 4.453225 0.395680
6 1024.0 1024.0 1024.0 6.550690 0.316575
7 1152.0 1152.0 1152.0 9.213149 0.378349
8 1280.0 1280.0 1280.0 12.485583 0.444944
9 1408.0 1408.0 1408.0 15.589347 0.518402
10 1536.0 1536.0 1536.0 19.637789 0.576684
11 1664.0 1664.0 1664.0 23.102231 0.644403
12 1792.0 1792.0 1792.0 29.274080 0.726035
13 1920.0 1920.0 1920.0 32.775947 0.721385
14 2048.0 2048.0 2048.0 36.019792 0.722632
15 2176.0 2176.0 2176.0 43.476061 0.755484
16 2304.0 2304.0 2304.0 50.967526 0.797482
17 2432.0 2432.0 2432.0 59.675966 0.516829
18 2560.0 2560.0 2560.0 65.795927 0.524481
19 2688.0 2688.0 2688.0 72.057158 0.525320
20 2816.0 2816.0 2816.0 75.685494 0.397117
21 2944.0 2944.0 2944.0 78.605996 0.397091
22 3072.0 3072.0 3072.0 77.401139 0.392022
23 3200.0 3200.0 3200.0 93.879067 0.395438
24 3328.0 3328.0 3328.0 100.308276 0.366734
25 3456.0 3456.0 3456.0 107.904954 0.323059
26 3584.0 3584.0 3584.0 116.285356 0.305686
27 3712.0 3712.0 3712.0 97.801647 0.296061
28 3840.0 3840.0 3840.0 100.229800 0.293752
29 3968.0 3968.0 3968.0 103.656807 0.278301
30 4096.0 4096.0 4096.0 105.869743 0.265378

@vlad-penkin
Contributor

@narendrachaudhary51 could you please provide information on your environment?

@narendrachaudhary51
Author

  • I built Triton, IPEX, and PyTorch from source using the "scripts/compile-triton.sh" and "scripts/test-triton.sh" scripts.
  • My oneAPI version is 2024.1.0.
  • My OS is Rocky Linux 9.2 on a cluster, so the commands mentioned in the link did not work for me.
  • I tried the same with "yum info", but that doesn't seem to work either.

I was able to run the following command:
xpu-smi discovery | grep "Device Name" | sed -n 's/.*Device Name: \(.*\)\s*|/\1/p' > gpu.txt
and gpu.txt contains this:
Intel(R) Data Center GPU Max 1100
Intel(R) Data Center GPU Max 1100

@fcharras

fcharras commented May 16, 2024

@narendrachaudhary51 I'm playing with this example too and noticed the same performance issue.

I think the grid search parameters have not been tuned. The grid search parameters currently used are the same as for CUDA, but I noticed that using a higher num_warps offers up to a 3x speedup on the Max Series GPU. The best parameters I have found so far are:

BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 16, num_ctas: 1, num_stages: 5

With those parameters, the 512 * 512 matmul performance is very close to torch.matmul performance. (However, I unfortunately also noticed that performance becomes worse at higher dimensions.)
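
A minimal sketch of how those values map to a single entry in the tutorial's autotune list (illustrative only; triton.Config is the standard API, but the exact list and function name in the tutorial may differ):

    import triton

    # Candidate config with the parameters above; intended as one entry of the
    # list consumed by the tutorial's @triton.autotune decorator.
    xpu_config = triton.Config(
        {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
        num_warps=16, num_ctas=1, num_stages=5,
    )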

@fcharras

fcharras commented May 16, 2024

Here are the changes to the example that I find gives better performance: #1139

It also improves performance for higher dimensions with a 3 to 4 times speedup, but not to the point of reaching torch.matmul performance; e.g. I find torch.matmul 100x faster on the 4096 * 4096 example, and still 20x faster after this change:

In [6]: %time matmul(a, b)[0,0].cpu()
CPU times: user 17.3 ms, sys: 8.88 ms, total: 26.2 ms
Wall time: 25.6 ms
Out[6]: tensor(-64.9375, dtype=torch.float16)

In [7]: %time torch.matmul(a, b)[0,0].cpu()
CPU times: user 1.15 ms, sys: 788 µs, total: 1.94 ms
Wall time: 1.82 ms
Out[7]: tensor(-64.9375, dtype=torch.float16)

Maybe more work on the grid search could help achieve a good speedup at higher dimensions too.

@narendrachaudhary51
Author

@fcharras Thank you for your reply. I suspected that grid search parameters could be the cause.
I will play with the grid search parameters and check performance. Do you have a guess as to why more warps help with XPU performance?

@fcharras

I have only some intuition and a limited understanding of which XPU concepts warps (a CUDA-only concept) map to (execution units? subslices?), but I suspected that the "num_warps" value in the grid was too low in this regard. To be honest, I'm surprised that setting it so low is good for CUDA devices to begin with. So it was more luck than anything...

@fcharras

fcharras commented May 16, 2024

From (again) an entry-level understanding of it, matmul performance is a hard-to-achieve balance between global and local memory bandwidth, cache bandwidth and hit rate, and actual compute. I thought that if something had to be the bottleneck here, it might be the compute, because too few threads are being leveraged, and that number is (I think) increased by increasing num_warps.

@etiotto
Contributor

etiotto commented May 16, 2024

By default, the Triton kernel in the tutorial compiles the matrix multiplication operation to a sequence of scalar floating-point multiply-add instructions. We can force the tutorial to use 16 threads per warp, which allows our compiler to generate specialized HW instructions (DPAS) rather than scalar FMAs. Performance then improves ~7X for large problem sizes.

The code to change in the tutorial is:

    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation,  #
        threads_per_warp=16,  # <<< add this parameter
    )

Performance still lags behind the torch.matmul implementation (which offloads that computation to the specialized oneDNN library). Future improvements are WIP.

I will post a PR to change the warp size used by the tutorial.

@vlad-penkin vlad-penkin linked a pull request May 16, 2024 that will close this issue
@fcharras

Thanks @etiotto, with this tip, and also the suggestion in #1139 (comment), I'm seeing almost equivalent wall time on the 512 * 512 example, and it has improved a lot on the larger (4096 * 4096) example, albeit still below torch.matmul:

In [7]: %time torch.matmul(a, b)[0,0].cpu()
CPU times: user 0 ns, sys: 1.77 ms, total: 1.77 ms
Wall time: 1.67 ms
Out[7]: tensor(-64.9375, dtype=torch.float16)

In [8]: %time matmul(a, b)[0,0].cpu()
CPU times: user 4.79 ms, sys: 4.3 ms, total: 9.09 ms
Wall time: 8.77 ms
Out[8]: tensor(-64.9375, dtype=torch.float16)

Maybe the matmul with the experimental block pointer approach (tutorial 09) will give better results?

@ogrisel

ogrisel commented May 17, 2024

maybe the matmul with the experimental block pointer approach (tutorial 09) will give better results?

That's an interesting question. Note, however, that the experimental examples, including tutorial 09 on the use of block pointers for matrix multiplication, have been removed from the upstream repo (triton-lang/triton#3371), but I am not sure why.

Still, I would be interested in the performance results on a Max Series GPU with a tuned grid and optimal threads_per_warp with block pointers.
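
For reference, a minimal sketch of the block-pointer load pattern that tutorial 09 used, as a fragment from inside a @triton.jit matmul kernel (the names, shapes, and strides are illustrative, following the usual kernel signature):

    # Build a 2D block pointer for an A tile and load it with boundary checks.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                                    strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    a = tl.load(a_block_ptr, boundary_check=(0, 1))            # 2D block load
    a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))   # step to the next K tile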

@narendrachaudhary51
Author

narendrachaudhary51 commented May 17, 2024

I tried the change suggested by @etiotto and increased the number of warps. This gave me a speedup across the board.
I used the following autotune parameters and measured the resulting matmul performance. It is now 100x faster than with the default parameters. However, we are still 7-8x slower than the torch version.
def get_xpu_autotune_config():
    return [
        # FIXME: Once tl.dot uses DPAS put back the workload commented out.
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
        #               num_stages=3, num_warps=64),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=5, num_warps=16),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=5, num_warps=16),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=32),
    ]


I tried several other configurations, but I was not able to go beyond this. Is there a way to inspect and analyze the Triton-generated code? We could then identify inefficiencies by comparing it to the torch implementation.

@ogrisel

ogrisel commented May 17, 2024

The rocBLAS legend actually refers to Intel oneDNN's kernel wrapped as pytorch XPU matmul, right?

@narendrachaudhary51
Author

@ogrisel Yes. It is the PyTorch XPU matmul. It must be using the oneDNN kernel underneath.
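
For context, a minimal sketch of what that baseline measures, assuming the tutorial's benchmark harness (the sizes, the do_bench call, and the IPEX import are illustrative for this stack):

    import torch
    import triton
    import intel_extension_for_pytorch  # noqa: F401  (may be needed to register the "xpu" device)

    # The "rocBLAS" column is plain torch.matmul on XPU tensors, which dispatches to oneDNN.
    a = torch.randn((4096, 4096), device='xpu', dtype=torch.float16)
    b = torch.randn((4096, 4096), device='xpu', dtype=torch.float16)
    ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
    print(f'torch.matmul: {2 * 4096**3 * 1e-12 / (ms * 1e-3):.1f} TFLOP/s')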

@narendrachaudhary51
Author

Is the triton generated code using 2D loads when doing the following operations?

a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

How can I dump the generated code?

@whitneywhtsang
Contributor

How can I dump the generated code?

Output of different stages can be found in TRITON_CACHE_DIR, or you can use MLIR_ENABLE_DUMP=1 to dump the IR before every MLIR pass Triton runs.
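
A minimal sketch of how these could be used when running the tutorial (the cache path and the placement at the top of the script are illustrative; the variables just need to be set in the environment before the kernel compiles):

    import os

    # Keep the per-kernel compilation artifacts in a known location ...
    os.environ['TRITON_CACHE_DIR'] = os.path.abspath('./triton_cache')
    # ... and dump the IR before every MLIR pass while the kernel compiles.
    os.environ['MLIR_ENABLE_DUMP'] = '1'

    # Set these before the first kernel launch (e.g. at the top of
    # 03-matrix-multiplication.py), then inspect the dumped IR and cache afterwards.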

@narendrachaudhary51
Copy link
Author

narendrachaudhary51 commented May 27, 2024

@ogrisel Block pointer-based matrix multiplication is faster than the previous implementation, which only reached a peak performance of 25 TFlops. However, the current performance is still below oneDNN matrix multiplication.

matmul-performance-block-pointerfp16
