
Conversation

@zhang-hui-yulo
Contributor

Enable mul_mat_f for RDNA4 and move the n >= 3 workloads from mmvf to mmf, based on the results of test-backend-ops.

Use a weird unreached branch to force the ROCm compiler to generate better-performing code for RDNA4; a bug will be submitted to ROCm.
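For context, a minimal illustration of this kind of workaround follows. It is not the actual patch; the kernel, the never-true condition and all names are invented purely to show the idea of a branch that is never taken at runtime but changes what the compiler sees.

__global__ void sum_rows(float * dst, const float * src, const int n) {
    float acc = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        acc += src[i]; // the hot loop whose code generation we actually care about
    }
    // "Unreached" branch: threadIdx.x can never be 0xFFFFFFFF, so this code is dead at
    // runtime, but its presence can push the compiler into a different instruction schedule.
    if (threadIdx.x == 0xFFFFFFFFu) {
        acc = -acc;
    }
    dst[threadIdx.x] = acc;
}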

System: Ubuntu 24.04.3 LTS
ROCm: 7.1.0
Driver: amdgpu version: 6.16.6 ROCm version: 7.1.0
GPU: 9070XT

MUL_MAT results
Backend GGML op Op parameters TFLOPS master TFLOPS 698c9f2 Speedup
ROCm0 MUL_MAT type_a=bf16,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 0.61 0.61 1.00
ROCm0 MUL_MAT type_a=bf16,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.20 1.20 1.00
ROCm0 MUL_MAT type_a=bf16,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.65 1.87 1.13
ROCm0 MUL_MAT type_a=bf16,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.98 2.49 1.26
ROCm0 MUL_MAT type_a=bf16,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.50 3.11 1.25
ROCm0 MUL_MAT type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 91.95 93.59 1.02
ROCm0 MUL_MAT type_a=bf16,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.92 4.95 1.26
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=128,n=1,k=16416,bs=[8,1],nr=[4,1],per=[0,1,2,3],k_v=32832,o=1 1.38 1.39 1.01
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=16416,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1 0.34 0.34 1.00
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 0.61 0.61 1.00
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.21 1.21 1.00
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.77 1.87 1.06
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.27 2.49 1.10
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.68 3.11 1.16
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 96.01 96.28 1.00
ROCm0 MUL_MAT type_a=f16,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.04 4.95 1.23
ROCm0 MUL_MAT type_a=f32,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 0.31 0.31 1.00
ROCm0 MUL_MAT type_a=f32,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 0.63 0.63 1.00
ROCm0 MUL_MAT type_a=f32,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 0.94 0.94 1.00
ROCm0 MUL_MAT type_a=f32,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.25 1.25 1.00
ROCm0 MUL_MAT type_a=f32,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.55 1.54 1.00
ROCm0 MUL_MAT type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.48 3.49 1.00
ROCm0 MUL_MAT type_a=f32,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.30 2.31 1.00
ROCm0 MUL_MAT type_a=iq1_m,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.63 3.67 1.01
ROCm0 MUL_MAT type_a=iq1_m,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.95 5.99 1.01
ROCm0 MUL_MAT type_a=iq1_m,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.03 7.08 1.01
ROCm0 MUL_MAT type_a=iq1_m,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.86 7.89 1.00
ROCm0 MUL_MAT type_a=iq1_m,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.21 8.32 1.01
ROCm0 MUL_MAT type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 74.78 75.01 1.00
ROCm0 MUL_MAT type_a=iq1_m,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.38 9.34 1.00
ROCm0 MUL_MAT type_a=iq1_s,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.16 4.18 1.00
ROCm0 MUL_MAT type_a=iq1_s,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.90 6.92 1.00
ROCm0 MUL_MAT type_a=iq1_s,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.42 7.49 1.01
ROCm0 MUL_MAT type_a=iq1_s,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.25 8.29 1.00
ROCm0 MUL_MAT type_a=iq1_s,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.87 8.97 1.01
ROCm0 MUL_MAT type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 75.22 75.50 1.00
ROCm0 MUL_MAT type_a=iq1_s,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.05 8.92 0.99
ROCm0 MUL_MAT type_a=iq2_s,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.61 1.62 1.01
ROCm0 MUL_MAT type_a=iq2_s,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.86 2.88 1.01
ROCm0 MUL_MAT type_a=iq2_s,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.97 3.99 1.00
ROCm0 MUL_MAT type_a=iq2_s,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.81 4.83 1.00
ROCm0 MUL_MAT type_a=iq2_s,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.11 5.12 1.00
ROCm0 MUL_MAT type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 74.32 74.48 1.00
ROCm0 MUL_MAT type_a=iq2_s,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.53 6.56 1.00
ROCm0 MUL_MAT type_a=iq2_xs,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.20 2.22 1.01
ROCm0 MUL_MAT type_a=iq2_xs,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.85 3.88 1.01
ROCm0 MUL_MAT type_a=iq2_xs,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.07 5.10 1.01
ROCm0 MUL_MAT type_a=iq2_xs,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.90 5.92 1.00
ROCm0 MUL_MAT type_a=iq2_xs,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.64 6.67 1.00
ROCm0 MUL_MAT type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 74.33 74.51 1.00
ROCm0 MUL_MAT type_a=iq2_xs,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.80 7.83 1.00
ROCm0 MUL_MAT type_a=iq2_xxs,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.66 1.69 1.01
ROCm0 MUL_MAT type_a=iq2_xxs,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.00 3.01 1.00
ROCm0 MUL_MAT type_a=iq2_xxs,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.17 4.19 1.00
ROCm0 MUL_MAT type_a=iq2_xxs,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.05 5.08 1.00
ROCm0 MUL_MAT type_a=iq2_xxs,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.84 5.86 1.00
ROCm0 MUL_MAT type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 74.68 74.76 1.00
ROCm0 MUL_MAT type_a=iq2_xxs,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.93 6.96 1.00
ROCm0 MUL_MAT type_a=iq3_s,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.56 1.57 1.01
ROCm0 MUL_MAT type_a=iq3_s,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.88 2.90 1.01
ROCm0 MUL_MAT type_a=iq3_s,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.05 4.06 1.00
ROCm0 MUL_MAT type_a=iq3_s,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.03 5.04 1.00
ROCm0 MUL_MAT type_a=iq3_s,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.86 5.88 1.00
ROCm0 MUL_MAT type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 73.48 73.44 1.00
ROCm0 MUL_MAT type_a=iq3_s,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.07 7.09 1.00
ROCm0 MUL_MAT type_a=iq3_xxs,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.14 2.16 1.01
ROCm0 MUL_MAT type_a=iq3_xxs,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.81 3.83 1.00
ROCm0 MUL_MAT type_a=iq3_xxs,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.10 5.12 1.01
ROCm0 MUL_MAT type_a=iq3_xxs,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.83 5.85 1.00
ROCm0 MUL_MAT type_a=iq3_xxs,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.88 6.90 1.00
ROCm0 MUL_MAT type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 73.48 73.65 1.00
ROCm0 MUL_MAT type_a=iq3_xxs,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.71 7.75 1.01
ROCm0 MUL_MAT type_a=iq4_nl,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.81 3.83 1.00
ROCm0 MUL_MAT type_a=iq4_nl,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.39 5.41 1.00
ROCm0 MUL_MAT type_a=iq4_nl,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.30 7.35 1.01
ROCm0 MUL_MAT type_a=iq4_nl,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.74 8.77 1.00
ROCm0 MUL_MAT type_a=iq4_nl,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.56 8.58 1.00
ROCm0 MUL_MAT type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 72.57 72.68 1.00
ROCm0 MUL_MAT type_a=iq4_nl,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.27 9.36 1.01
ROCm0 MUL_MAT type_a=iq4_xs,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.78 3.77 1.00
ROCm0 MUL_MAT type_a=iq4_xs,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.57 6.59 1.00
ROCm0 MUL_MAT type_a=iq4_xs,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.99 9.02 1.00
ROCm0 MUL_MAT type_a=iq4_xs,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.64 9.70 1.01
ROCm0 MUL_MAT type_a=iq4_xs,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 10.28 10.32 1.00
ROCm0 MUL_MAT type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 72.64 72.71 1.00
ROCm0 MUL_MAT type_a=iq4_xs,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 10.15 10.20 1.00
ROCm0 MUL_MAT type_a=mxfp4,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.58 3.61 1.01
ROCm0 MUL_MAT type_a=mxfp4,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.08 5.13 1.01
ROCm0 MUL_MAT type_a=mxfp4,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.19 7.25 1.01
ROCm0 MUL_MAT type_a=mxfp4,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.55 8.65 1.01
ROCm0 MUL_MAT type_a=mxfp4,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.72 8.72 1.00
ROCm0 MUL_MAT type_a=mxfp4,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 74.08 74.29 1.00
ROCm0 MUL_MAT type_a=mxfp4,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.15 9.24 1.01
ROCm0 MUL_MAT type_a=q2_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.87 2.88 1.00
ROCm0 MUL_MAT type_a=q2_K,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.71 3.73 1.01
ROCm0 MUL_MAT type_a=q2_K,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.03 4.07 1.01
ROCm0 MUL_MAT type_a=q2_K,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.22 4.25 1.01
ROCm0 MUL_MAT type_a=q2_K,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.33 4.35 1.00
ROCm0 MUL_MAT type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 72.45 72.60 1.00
ROCm0 MUL_MAT type_a=q2_K,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.37 4.40 1.01
ROCm0 MUL_MAT type_a=q3_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.72 1.74 1.02
ROCm0 MUL_MAT type_a=q3_K,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.83 2.85 1.01
ROCm0 MUL_MAT type_a=q3_K,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.42 3.46 1.01
ROCm0 MUL_MAT type_a=q3_K,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.84 3.88 1.01
ROCm0 MUL_MAT type_a=q3_K,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.12 4.13 1.00
ROCm0 MUL_MAT type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 68.80 68.89 1.00
ROCm0 MUL_MAT type_a=q3_K,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.42 4.44 1.01
ROCm0 MUL_MAT type_a=q4_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.05 4.08 1.01
ROCm0 MUL_MAT type_a=q4_0,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.50 5.53 1.00
ROCm0 MUL_MAT type_a=q4_0,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.40 7.45 1.01
ROCm0 MUL_MAT type_a=q4_0,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.77 8.82 1.01
ROCm0 MUL_MAT type_a=q4_0,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.64 8.63 1.00
ROCm0 MUL_MAT type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 72.69 72.98 1.00
ROCm0 MUL_MAT type_a=q4_0,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.48 9.51 1.00
ROCm0 MUL_MAT type_a=q4_1,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.99 4.00 1.00
ROCm0 MUL_MAT type_a=q4_1,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.92 6.96 1.01
ROCm0 MUL_MAT type_a=q4_1,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.65 7.70 1.01
ROCm0 MUL_MAT type_a=q4_1,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.05 9.12 1.01
ROCm0 MUL_MAT type_a=q4_1,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.75 8.76 1.00
ROCm0 MUL_MAT type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 72.43 72.55 1.00
ROCm0 MUL_MAT type_a=q4_1,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.74 9.86 1.01
ROCm0 MUL_MAT type_a=q4_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.71 2.75 1.01
ROCm0 MUL_MAT type_a=q4_K,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.70 3.73 1.01
ROCm0 MUL_MAT type_a=q4_K,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.06 4.08 1.01
ROCm0 MUL_MAT type_a=q4_K,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.25 4.28 1.01
ROCm0 MUL_MAT type_a=q4_K,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.38 4.40 1.01
ROCm0 MUL_MAT type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 72.11 72.19 1.00
ROCm0 MUL_MAT type_a=q4_K,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.58 4.61 1.01
ROCm0 MUL_MAT type_a=q5_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.12 3.14 1.01
ROCm0 MUL_MAT type_a=q5_0,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.24 5.26 1.00
ROCm0 MUL_MAT type_a=q5_0,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.44 6.47 1.00
ROCm0 MUL_MAT type_a=q5_0,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.83 7.88 1.01
ROCm0 MUL_MAT type_a=q5_0,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.18 8.18 1.00
ROCm0 MUL_MAT type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 63.07 63.16 1.00
ROCm0 MUL_MAT type_a=q5_0,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.89 8.97 1.01
ROCm0 MUL_MAT type_a=q5_1,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.37 3.39 1.01
ROCm0 MUL_MAT type_a=q5_1,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.62 5.60 1.00
ROCm0 MUL_MAT type_a=q5_1,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.07 7.14 1.01
ROCm0 MUL_MAT type_a=q5_1,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.08 8.10 1.00
ROCm0 MUL_MAT type_a=q5_1,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.82 8.84 1.00
ROCm0 MUL_MAT type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 63.58 63.86 1.00
ROCm0 MUL_MAT type_a=q5_1,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 9.56 9.69 1.01
ROCm0 MUL_MAT type_a=q5_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.55 2.57 1.01
ROCm0 MUL_MAT type_a=q5_K,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.55 3.57 1.01
ROCm0 MUL_MAT type_a=q5_K,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.92 3.95 1.01
ROCm0 MUL_MAT type_a=q5_K,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.13 4.16 1.01
ROCm0 MUL_MAT type_a=q5_K,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.31 4.33 1.01
ROCm0 MUL_MAT type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 69.48 69.57 1.00
ROCm0 MUL_MAT type_a=q5_K,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.50 4.53 1.01
ROCm0 MUL_MAT type_a=q6_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 1.84 1.86 1.01
ROCm0 MUL_MAT type_a=q6_K,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.91 2.94 1.01
ROCm0 MUL_MAT type_a=q6_K,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 3.67 3.71 1.01
ROCm0 MUL_MAT type_a=q6_K,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.16 4.20 1.01
ROCm0 MUL_MAT type_a=q6_K,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.52 4.54 1.00
ROCm0 MUL_MAT type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 69.97 70.02 1.00
ROCm0 MUL_MAT type_a=q6_K,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 5.16 5.20 1.01
ROCm0 MUL_MAT type_a=q8_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 2.57 2.56 1.00
ROCm0 MUL_MAT type_a=q8_0,type_b=f32,m=4096,n=2,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 4.66 4.67 1.00
ROCm0 MUL_MAT type_a=q8_0,type_b=f32,m=4096,n=3,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 6.47 6.54 1.01
ROCm0 MUL_MAT type_a=q8_0,type_b=f32,m=4096,n=4,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.42 7.50 1.01
ROCm0 MUL_MAT type_a=q8_0,type_b=f32,m=4096,n=5,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 7.81 7.88 1.01
ROCm0 MUL_MAT type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 68.41 68.60 1.00
ROCm0 MUL_MAT type_a=q8_0,type_b=f32,m=4096,n=8,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1 8.04 8.15 1.01

model: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

bf16
GPU Model Microbatch size Test t/s master t/s 698c9f2 Speedup
RX 9070 XT qwen2 1.5B BF16 1 pp512 156.89 157.12 1.00
RX 9070 XT qwen2 1.5B BF16 2 pp512 250.88 250.35 1.00
RX 9070 XT qwen2 1.5B BF16 4 pp512 375.40 518.08 1.38
RX 9070 XT qwen2 1.5B BF16 8 pp512 744.81 1039.92 1.40
RX 9070 XT qwen2 1.5B BF16 16 pp512 1466.76 2014.38 1.37
RX 9070 XT qwen2 1.5B BF16 32 pp512 2803.02 2758.97 0.98
RX 9070 XT qwen2 1.5B BF16 64 pp512 4754.45 4704.70 0.99
RX 9070 XT qwen2 1.5B BF16 128 pp512 7872.05 7756.98 0.99
RX 9070 XT qwen2 1.5B BF16 256 pp512 11419.86 11277.21 0.99
RX 9070 XT qwen2 1.5B BF16 512 pp512 14612.64 14515.26 0.99
f16
GPU Model Microbatch size Test t/s master t/s 698c9f2 Speedup
RX 9070 XT qwen2 1.5B F16 1 pp512 156.74 156.50 1.00
RX 9070 XT qwen2 1.5B F16 2 pp512 245.37 245.48 1.00
RX 9070 XT qwen2 1.5B F16 4 pp512 398.99 518.60 1.30
RX 9070 XT qwen2 1.5B F16 8 pp512 796.58 1033.86 1.30
RX 9070 XT qwen2 1.5B F16 16 pp512 1592.32 2017.26 1.27
RX 9070 XT qwen2 1.5B F16 32 pp512 3022.76 2982.65 0.99
RX 9070 XT qwen2 1.5B F16 64 pp512 5170.13 5083.67 0.98
RX 9070 XT qwen2 1.5B F16 128 pp512 8504.90 8457.53 0.99
RX 9070 XT qwen2 1.5B F16 256 pp512 12656.47 12706.78 1.00
RX 9070 XT qwen2 1.5B F16 512 pp512 18188.82 18325.39 1.01
f32
GPU Model Microbatch size Test t/s master t/s 698c9f2 Speedup
RX 9070 XT qwen2 1.5B all F32 1 pp512 99.78 99.41 1.00
RX 9070 XT qwen2 1.5B all F32 2 pp512 185.40 185.05 1.00
RX 9070 XT qwen2 1.5B all F32 4 pp512 343.46 341.67 0.99
RX 9070 XT qwen2 1.5B all F32 8 pp512 494.70 493.65 1.00
RX 9070 XT qwen2 1.5B all F32 16 pp512 725.07 723.68 1.00
RX 9070 XT qwen2 1.5B all F32 32 pp512 1144.63 1145.15 1.00
RX 9070 XT qwen2 1.5B all F32 64 pp512 1409.09 1404.45 1.00
RX 9070 XT qwen2 1.5B all F32 128 pp512 846.13 842.86 1.00
RX 9070 XT qwen2 1.5B all F32 256 pp512 1263.61 1263.73 1.00
RX 9070 XT qwen2 1.5B all F32 512 pp512 1370.07 1361.65 0.99

Best Regards
Hui

github-actions bot added the Nvidia GPU (issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Nov 22, 2025.
zhang-hui-yulo changed the title from "Enable mul_mat_f for RDNA4" to "HIP: enable mul_mat_f for RDNA4" on Nov 22, 2025.
@zhang-hui-yulo
Contributor Author

Added the output of "test-backend-ops test -o MUL_MAT" to check the ROCm-generated code; the results look fine.

test.txt

@JohannesGaessler
Collaborator

As of right now the data is being loaded in chunks of 8 bytes; the maximum size that AMD/NVIDIA GPUs support, and the size that I am targeting in terms of SRAM padding, is 16 bytes. Before we resort to black magic, please try loading the data using ggml_cuda_memcpy_1<16>; to my understanding tile<16, 8, half2>/tile<16, 8, nv_bfloat162> have the correct data layout for it. If I had to guess, the ROCm compiler is failing to correctly batch the data loads into chunks of 16 bytes on master.
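For illustration, a sketch of such a load is below. It assumes the tile/pointer names used elsewhere in this thread (t, xs0, stride) and that ggml_cuda_memcpy_1 is the copy helper from the CUDA backend headers, so treat it as a sketch rather than the exact mmf code.

// Sketch: each thread copies its whole 16-byte fragment in one go instead of two 8-byte halves.
// Assumes a tile<16, 8, half2> whose t.x holds 4 half2 (= 16 bytes), as in load_generic.
template <typename tile_t, typename T>
static __device__ void load_fragment_16B(tile_t & t, const T * xs0, const int stride) {
    ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
}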

@zhang-hui-yulo
Contributor Author

zhang-hui-yulo commented Nov 25, 2025

Hello @JohannesGaessler

I'm not sure what "chunks of 8 bytes" means here: is it smem or gmem? Based on my understanding, gmem -> smem uses only 4-byte loads along the K dim (two halves or one float), while smem -> rmem uses 16-byte loads (ldmatrix for NVIDIA and ggml_cuda_memcpy_1 for AMD).

I just looked at mma.cuh; I'm not sure why 0543f92 removed my ggml_cuda_memcpy_1<sizeof(t.x)> in load_generic.

Honestly, I've tried 16-byte loads gmem -> rmem for the A matrix (which is reasonable, as tile_xy is used to load the full matrix A) as well as 16-byte loads gmem -> smem and smem -> rmem; the performance is the same as or worse than doing nothing.

Best Regards
Hui

@zhang-hui-yulo
Contributor Author

zhang-hui-yulo commented Nov 25, 2025

I have one more suggestion. mma.cuh was developed on NVIDIA GPUs, and the generic tile<I, J, T> is designed for matrices A, B and C in mma. This isn't suitable for AMD GPUs because matrix A (row-major) and matrix C (col-major) are laid out differently, which is why mmq looks weird when using tile<16, 4, int> tile_A (row-major) and tile<16, 16, int> tile_C (col-major).

This piece of code isn't friendly to read; honestly, all the data for matrices A and B on RDNA is contiguous, so a simple ggml_cuda_memcpy_1 is enough, and all data-position-related code should belong to the tile itself.

Less friendly

#elif defined(AMD_WMMA_AVAILABLE)
        if constexpr (I == 16 && J == 4) {
            int64_t * xi = (int64_t *) t.x;
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
            xi[0] = xs[0];
        }else if constexpr (I == 16 && J == 8) {
            int64_t * xi = (int64_t *) t.x;
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
            xi[0] = xs[0];

            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
            xi[1] = xs1[0];
        }else{
            NO_DEVICE_CODE;
        }
#else

More friendly and can use load_128 not two load_64 (maybe compiler might do the optimization but we do it by ourself)

#elif defined(AMD_WMMA_AVAILABLE)
        if constexpr (I == 16 && J == 4) {
            int64_t * xi = (int64_t *) t.x;
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
            xi[0] = xs[0];
        }else if constexpr (I == 16 && J == 8) {
            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
        }else{
            NO_DEVICE_CODE;
        }
#else

So, yes, mma.cuh needs a refactor. This is what I want to do for FA: add a subclass of the generic tile to handle transposed mma, like matrix C on RDNA.

template <int I, int J, typename T, bool trans>
class tile : public tile<I, J, T> {
    int get_i() {
        if constexpr (trans) {
            return tile<I, J, T>::get_j();
        } else {
            return tile<I, J, T>::get_i();
        }
    }
};

@JohannesGaessler
Collaborator

I just looked at mma.cuh; I'm not sure why 0543f92 removed my ggml_cuda_memcpy_1<sizeof(t.x)> in load_generic.

I don't know either, I didn't spot it in the code diff or else I would have asked about it during review. In any case, as of right now data loading is unfortunately handled in an inconsistent way in mma.cuh. The SRAM data layout in the kernels I use the interface in has a padding of 16 bytes between columns/rows. This is the requirement for ldmatrix. But all relevant GPUs also support loading data from SRAM in chunks of 16 bytes (I forgot what the exact instruction was called in the AMD ISA documentation). When optimizing code for RX 6800/MI50 I found for fattn-tile.cuh that this works quite well. cc @jiachengjason
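As an illustration of that layout (sizes invented for this sketch, not taken from the actual kernels), the idea is a shared-memory tile with 16 bytes of padding per row so that the 16-byte per-thread loads stay free of bank conflicts:

#include <cuda_fp16.h>

__global__ void padded_tile_sketch(const half2 * __restrict__ src, const int k) {
    constexpr int rows = 64; // tile height, arbitrary for this sketch
    constexpr int cols = 32; // row length in half2 units (64 half values)
    constexpr int pad  = 8;  // 8 half2 == 16 bytes of padding between consecutive rows
    __shared__ half2 tile_xy[rows][cols + pad];

    // ... fill tile_xy from global memory here, then hand 16-byte fragments to the mma tiles ...
    (void) src; (void) k; (void) tile_xy;
}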

I have one more suggestion. mma.cuh was developed on NVIDIA GPUs, and the generic tile<I, J, T> is designed for matrices A, B and C in mma. This isn't suitable for AMD GPUs because matrix A (row-major) and matrix C (col-major) are laid out differently, which is why mmq looks weird when using tile<16, 4, int> tile_A (row-major) and tile<16, 16, int> tile_C (col-major).

I will either today or tomorrow make a PR that extends the tile template:

    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
    //     effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
    // In those cases the data can be split in different ways across the warp.
    enum data_split {
        DATA_SPLIT_NONE     =  0, // Each data value is held exactly once per warp (always applies to Turing, Ampere, Ada Lovelace, consumer Blackwell).
        DATA_SPLIT_MIRRORED = 10, // Each data value is held exactly once per subgroup.
    };
    // Implemented mma combinations are:
    //   - (NONE,     NONE)     -> NONE
    //   - (NONE,     MIRRORED) -> NONE

    template <int I_, int J_, typename T, data_split ds_=DATA_SPLIT_NONE, bool transposed=false>
    struct tile {};
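To make the intent concrete, declarations against the extended template might look like this (shapes and the transposed accumulator chosen purely for illustration):

    // Hypothetical usage of the extended template, matching the (NONE, MIRRORED) -> NONE combination:
    tile<16, 8, half2>                          tile_A; // defaults: DATA_SPLIT_NONE, not transposed
    tile<16, 8, half2, DATA_SPLIT_MIRRORED>     tile_B; // each subgroup holds its own full copy of B
    tile<16, 16, float, DATA_SPLIT_NONE, true>  tile_C; // accumulator stored transposed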

Originally I also had support for other data layouts but I found this to not perform well and cut it:

    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
    //     effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
    // In those cases the data can be split in different ways across the warp.
    enum data_split {
        DATA_SPLIT_NONE     =  0, // Each data value is held exactly once per warp (always applies to Turing, Ampere, Ada Lovelace, consumer Blackwell).
        DATA_SPLIT_MIRRORED = 10, // Each data value is held exactly once per subgroup.
        DATA_SPLIT_I        = 20, // Each data value is held exactly once per warp with striping in the I dimension.
        DATA_SPLIT_J        = 30, // Each data value is held exactly once per warp with striping in the J dimension.
        DATA_SPLIT_PARTIAL  = 40, // Each subgroup holds a partial sum for each data value.
    };
    // Implemented mma combinations are:
    //   - (NONE,     NONE)     -> NONE
    //   - (NONE,     MIRRORED) -> NONE
    //   - (MIRRORED, I)        -> J
    //   - (J,        J)        -> PARTIAL (Due to transposition of B the combination of (J, I) -> PARTIAL is actually implemented.)

@zhang-hui-yulo
Contributor Author

Thank you for the info. I assume that 0543f92 needs to handle tile<16, 4, int> for int8 and tile<16, 8, half2> for fp16; I shall revert the code back for <16, 8>.

Anyway, I just used the two-int64 loads for mul_mat_f and the performance is the same as before, with or without the black magic. :(

Based on my knowledge, a 16-byte padding for fp16 mma on RDNA4 is a reasonable number, as the data layout is the same as ldmatrix; I just use swizzle<3,3,3> for RDNA4. For RDNA3 you need a 32-byte padding, as each thread needs 16 half members for the 16x16x16 mma; swizzle<3,3,3> makes the performance terrible there, and swizzle<2,4,2> makes more sense.
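For readers unfamiliar with the notation: a swizzle<B, M, S> in the CUTE sense permutes shared-memory offsets by XORing a few high "row" bits into the low "column" bits. The sketch below shows the general idea only; it is not taken from the mmf code.

// Minimal CUTE-style swizzle sketch: fold B bits located S positions above the low M bits
// back into those low bits, so that consecutive rows land on different shared-memory banks.
template <int B, int M, int S>
__host__ __device__ constexpr int swizzle_offset(int offset) {
    constexpr int mask = ((1 << B) - 1) << (M + S); // the B "row" bits
    return offset ^ ((offset & mask) >> S);         // XOR them into the "column" bits
}
// e.g. swizzle_offset<3, 3, 3>(offset) corresponds to the swizzle<3,3,3> mentioned above.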

@zhang-hui-yulo
Contributor Author

zhang-hui-yulo commented Nov 25, 2025

@JohannesGaessler, may I have any further suggestions other than removing the black magic? I really don't have another way to make the performance normal.

Anyway, I still need to submit a bug to the ROCm compiler team asking it to generate higher-performance code. Putting this PR into the main branch will give ROCm more motivation to fix it; otherwise, based on my experience, the ROCm compiler team will put RDNA at very low priority.

Also, I will add the memcpy back for tile<16, 8, half2> in mma.cuh to use the lds.128 instruction, although the performance isn't much different.

@JohannesGaessler
Collaborator

On Volta I am already permuting the data layout by default:

// On Volta each warp is doing 4 8x8 mma operations in parallel.
// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
// However, the i indices in this file are by default permuted to simplify the index calculations.
// #define GGML_CUDA_MMA_NO_VOLTA_PERM

This has comparatively little impact but I think it would be possible to do something similar for RDNA3 by permuting the data layout in the J dimension (will possibly need larger tiles).

@zhang-hui-yulo
Contributor Author

Sorry, I don't understand. I think Volta's layout is quite different from RDNA3; you can just assume that RDNA3 is RDNA4 with duplicated data in matrices A and B, and empty output in matrix C for fp16. Unless you use tile<16, 16, half2> tile_A and do the index calculation in load_generic, but that will make the code more complicated.

For RDNA3, I would suggest reading https://gpuopen.com/learn/wmma_on_rdna3/

@JohannesGaessler
Collaborator

What I mean for RDNA3 is to load the data for A and B as if it were 16x16x32 tiles (in logical units) with the RDNA4 layout repeated twice. The permutation should exactly cancel out.

But actually, now that I think about it, RDNA3 shouldn't be an issue anyway. What matters is that the stride between threads has a padding of 16 bytes; if you load the data using 2 consecutive 16-byte copies, that should still work to avoid shared memory bank conflicts.
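A rough sketch of what that could look like in load_generic terms (names assumed as earlier in this thread, not actual mmf code):

// Sketch for RDNA3: each thread holds 8 half2 (32 bytes), fetched as two consecutive
// 16-byte copies from the padded shared-memory tile; t.x is assumed to hold 8 half2 here.
template <typename tile_t, typename T>
static __device__ void load_fragment_32B(tile_t & t, const T * xs0, const int stride) {
    ggml_cuda_memcpy_1<16>(t.x,     xs0 + t.get_i(0) * stride + t.get_j(0));
    ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0) * stride + t.get_j(0) + 4);
}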

@zhang-hui-yulo
Contributor Author

Honestly, just doubling the "ne" in tile is enough for matrices A and B on RDNA3. The fp16 matrix C on RDNA3 is trouble, but it can be fixed; I have dealt with it in my personal CUTE library.

I think you've got the idea for RDNA3, but I would suggest padding 32 bytes. I didn't pad it on RDNA3, but using the same swizzle value on RDNA3 as on Ampere, swizzle<3,3,3>, makes the performance extremely terrible: only about 10% of what the right padding value gives.

The root cause is that two load-128 instructions will be executed at the same time. Anyway, it's just a parameter; it's easy to run the perf test and then adjust it.

@zhang-hui-yulo
Contributor Author

Hello @JohannesGaessler, I've reverted the memcpy in mma; do you have any new suggestions from your side? smem -> rmem should be a load-128 now.

Sorry about the black magic; I really don't have another way to get the performance to a reasonable level. If you approve it, I assume ROCm would have more motivation to fix the issue since it's in the official repo; I'm not sure ROCm would deal with my personal repo. Thank you for the support.

@jiachengjason
Contributor

Hi @zhang-hui-yulo, I removed ggml_cuda_memcpy_1<sizeof(t.x)> in load_generic because it was causing issues when running ./build/bin/test-backend-ops test -o MUL_MAT: there were no 16x8 and 16x4 versions of get_i and get_j for int. But I am going to add it back in for Half2 and Float162; I agree that using memcpy is more friendly and will eventually change to that for Int as well.

Also, when running ./build/bin/test-backend-ops test -o MUL_MAT on the 028f93e changes, it does not go through your float162 and half2 mmf changes because the input tiles are int in mmq.cuh. I am wondering what test cases you used for the float162 and half2 mmf changes?

@zhang-hui-yulo
Contributor Author

Hello @jiachengjason, although 028f93e enabled mul_mat_f, I disabled it on the CPU side because the performance was not good, so it still uses the hipBLAS path.

This is why I raised this PR to enable mul_mat_f on RDNA4 with some black-magic code; the performance makes more sense, but I'm not sure whether @JohannesGaessler will accept it. In any case, a ROCm compiler bug shall be raised to fix the low-performance generated code.

@JohannesGaessler
Collaborator

Please rebase your branches instead of merging master into them; it makes them easier to work with. I cherry-picked the relevant commits to a branch enable_mmf_for_rdna4-2 on my fork of llama.cpp that I use for development: https://github.com/JohannesGaessler/llama.cpp/tree/enable_mmf_for_rdna4-2 . With the weird unreached branch I get:

ggml_cuda_compute_forward: MUL_MAT failed
ROCm error: the operation cannot be performed in the present state
  current device: 0, in function ggml_cuda_compute_forward at /home/johannesg/Projects/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2727

I don't know why this is happening but we cannot merge broken code onto master. When I remove the unreached branch the code works correctly and the performance looks good as well:

GPU Model Microbatch size Test t/s master t/s d4e50a4 Speedup
RX 9060 XT llama 8B BF16 1 pp512 20.74 20.77 1.00
RX 9060 XT llama 8B BF16 2 pp512 40.77 40.81 1.00
RX 9060 XT llama 8B BF16 4 pp512 6.58 78.82 11.99
RX 9060 XT llama 8B BF16 8 pp512 12.91 137.39 10.65
RX 9060 XT llama 8B BF16 16 pp512 25.91 213.08 8.22
RX 9060 XT llama 8B BF16 32 pp512 51.66 51.39 0.99
RX 9060 XT llama 8B BF16 64 pp512 103.31 102.77 0.99
RX 9060 XT llama 8B BF16 128 pp512 333.85 333.37 1.00
RX 9060 XT llama 8B BF16 256 pp512 339.94 339.25 1.00
RX 9060 XT llama 8B BF16 512 pp512 343.15 342.42 1.00
RX 9060 XT llama 8B F16 1 pp512 20.72 20.71 1.00
RX 9060 XT llama 8B F16 2 pp512 40.41 40.37 1.00
RX 9060 XT llama 8B F16 4 pp512 71.99 65.17 0.91
RX 9060 XT llama 8B F16 8 pp512 43.85 120.91 2.76
RX 9060 XT llama 8B F16 16 pp512 85.50 205.97 2.41
RX 9060 XT llama 8B F16 32 pp512 168.99 168.66 1.00
RX 9060 XT llama 8B F16 64 pp512 206.55 205.82 1.00
RX 9060 XT llama 8B F16 128 pp512 360.94 360.26 1.00
RX 9060 XT llama 8B F16 256 pp512 367.29 366.66 1.00
RX 9060 XT llama 8B F16 512 pp512 369.98 369.52 1.00

The performance declines beyond batch size 16, where rocBLAS is used instead. But as it is, ggml doesn't have a floating-point GEMM kernel to handle that case. The original mul_mat_f kernel was only ever designed for very thin matrices. And while the copied mul_mat_f_ids kernel can handle larger batch sizes, it is suboptimally written for that because it took over the memory access patterns from mul_mat_f.

@zhang-hui-yulo
Contributor Author

zhang-hui-yulo commented Nov 26, 2025

Sorry about the branch issue; I'm still not very familiar with working on a public GitHub repo.

The reason the test case cannot pass is that the ROCm compiler doesn't generate correct code; I've seen this many times before on RDNA, e.g. broken code and wrong results.

Since the 9060 gets a large performance improvement, the only two solutions I can offer are:

  • Enable the weird code on gfx1201 only, but the risk is high.
  • Enable mul_mat_f on RDNA4 without the weird code, then submit a bug to ROCm about the performance drop; that's more reasonable, but I'm not sure whether you can accept the perf drop.

@JohannesGaessler
Collaborator

My preferred solution would be to merge the code without the weird fix. As it is, I do not have the means to test it, so I would only be fine with keeping it if you commit to long-term llama.cpp/ggml maintenance.

@zhang-hui-yulo
Contributor Author

zhang-hui-yulo commented Nov 26, 2025

I also prefer to remove the weird code and accept the performance drop on the 9070 XT, then submit the bug to the ROCm compiler team, as it's a ROCm bug. Also, I'm not sure whether this piece of code would crash on other versions of ROCm, and I don't want to take the risk.

If you agree, I will just enable mul_mat_f for RDNA4 and remove the other parts, including the weird code and the mmvf change.

As for a future maintenance plan: at a minimum I will spend all of 2026 optimizing the performance of llama.cpp on RDNA and CDNA3 (if I can get one; I'm trying now), barring circumstances out of my control.

@zhang-hui-yulo
Contributor Author

This now just enables mul_mat_f for RDNA4, based on the 9060 data; I will raise a bug with ROCm for the 9070 XT once this PR is merged.

@zhang-hui-yulo
Contributor Author

Attaching the test result from test-backend-ops:
Uploading mul_mat.txt…

@zhang-hui-yulo
Contributor Author

Hello @JohannesGaessler, could you approve the PR if the data looks good enough on your 9060? Thank you.

JohannesGaessler merged commit 6bca76f into ggml-org:master on Nov 28, 2025
70 of 74 checks passed
zhang-hui-yulo deleted the enable_mmf_for_rdna4 branch on November 30, 2025 02:38