@JohannesGaessler (Collaborator) commented Jul 10, 2024

This PR optimizes MMQ performance and refactors the code in preparation for i-quants. Brief summary of changes:

  • The k dimension of the shared memory tiles for matrix multiplication is set to 256 logical values. Previously the k dimension was chosen such that each warp would load a single 32 bit value and then unpack it. While this is beneficial for memory bandwidth, it causes shared memory and register use to differ greatly between quantization formats, which makes optimization much harder. The affected formats are q2_K, q3_K, and q8_0.
  • The loop structure of the tensor core kernels is optimized in a way that reduces register pressure, which in turn allows for more aggressive loop unrolling.
  • q8_0 and q5_0 now use the same code except for loading the data.
  • The q8_1 activations for q2_K now use a single FP16 scale per 64 values instead of one per 32 values, which frees room for pre-quantization partial sums at 16-value granularity (see the sketch after this list). For the vector dot product, 6/8 blocks of size 16 can then be sped up (instead of 4/8 with one scale per 32 values).
  • The auxiliary kernels for quantizing the activations and for the stream-k fixup have received optimizations.
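
As a concrete illustration of the q8_1 layout change, the following sketch contrasts the two arrangements of the 16-byte scale/partial-sum region for a block of 128 quantized activation values. The struct and field names are hypothetical, and the actual CUDA code may pack the values differently (e.g. interleave scales and sums); this is only meant to show the byte accounting:

```cpp
#include <cstdint>

typedef uint16_t f16; // raw FP16 bit patterns, stored as 16-bit words in this sketch

// One FP16 scale and one FP16 pre-quantization partial sum per 32 values:
// 8 FP16 slots for a block of 128 values.
struct scales_ds4 {
    f16 d[4]; // scales, one per 32 values
    f16 s[4]; // partial sums, one per 32 values
};

// Layout used for q2_K after this PR: halving the number of scales (one
// per 64 values) frees four FP16 slots, so partial sums can be stored at
// 16-value granularity for 6 of the 8 sub-blocks of size 16.
struct scales_d2s6 {
    f16 d[2]; // scales, one per 64 values
    f16 s[6]; // partial sums for 6 of the 8 sub-blocks of size 16
};

static_assert(sizeof(scales_ds4) == sizeof(scales_d2s6), "both layouts fit the same 16 bytes");
```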

Precision loss for q2_K:

| Model | imatrix | Code | PPL | KL Divergence vs. FP16 | Mean Δp | RMS Δp |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA 3 q2_K_M | WT 10m | master | 8.730622 | 0.342380 | -6.776% | 19.187% |
| LLaMA 3 q2_K_M | WT 10m | PR | 8.734961 | 0.342814 | -6.774% | 19.190% |
| LLaMA 3 q2_K_S | WT 10m | master | 9.371474 | 0.408614 | -7.289% | 20.149% |
| LLaMA 3 q2_K_S | WT 10m | PR | 9.378482 | 0.409435 | -7.319% | 20.175% |

I would argue that this is negligible.

Performance changes
| Model | GPU | Microbatch size | Test | t/s 6b2a849 | t/s cuda-mmq-256k-5 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 16 | pp2048 | 236.52 | 237.22 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 32 | pp2048 | 328.45 | 332.38 | 1.01 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 64 | pp2048 | 404.76 | 411.18 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 128 | pp2048 | 509.28 | 516.32 | 1.01 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 256 | pp2048 | 603.86 | 617.01 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 512 | pp2048 | 612.29 | 621.83 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 1024 | pp2048 | 686.38 | 700.95 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 2048 | pp2048 | 627.16 | 640.30 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 16 | pp2048 | 1084.71 | 1088.77 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 32 | pp2048 | 1760.23 | 1787.89 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 64 | pp2048 | 2678.81 | 2776.67 | 1.04 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 128 | pp2048 | 3306.59 | 3585.50 | 1.08 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 256 | pp2048 | 3720.80 | 4095.23 | 1.10 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 512 | pp2048 | 3852.71 | 4298.42 | 1.12 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 1024 | pp2048 | 3886.39 | 4341.15 | 1.12 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 2048 | pp2048 | 3775.64 | 4201.85 | 1.11 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 16 | pp2048 | 1897.16 | 1950.39 | 1.03 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 32 | pp2048 | 3343.59 | 3428.25 | 1.03 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 64 | pp2048 | 5449.91 | 5671.69 | 1.04 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 128 | pp2048 | 7380.43 | 7964.68 | 1.08 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 256 | pp2048 | 9588.41 | 10556.45 | 1.10 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 512 | pp2048 | 10303.19 | 11383.62 | 1.10 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 1024 | pp2048 | 10008.66 | 11077.53 | 1.11 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 2048 | pp2048 | 9077.18 | 9842.67 | 1.08 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 16 | pp2048 | 249.55 | 253.31 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 32 | pp2048 | 417.65 | 416.95 | 1.00 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 64 | pp2048 | 572.82 | 582.69 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 128 | pp2048 | 680.85 | 690.43 | 1.01 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 256 | pp2048 | 768.59 | 780.00 | 1.01 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 512 | pp2048 | 792.24 | 805.79 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 1024 | pp2048 | 760.64 | 777.41 | 1.02 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 2048 | pp2048 | 723.52 | 724.94 | 1.00 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 16 | pp2048 | 235.71 | 235.49 | 1.00 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 32 | pp2048 | 331.57 | 333.53 | 1.01 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 64 | pp2048 | 405.95 | 410.99 | 1.01 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 128 | pp2048 | 510.76 | 520.78 | 1.02 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 256 | pp2048 | 608.01 | 619.50 | 1.02 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 512 | pp2048 | 614.68 | 624.71 | 1.02 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 1024 | pp2048 | 690.04 | 702.35 | 1.02 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 2048 | pp2048 | 627.83 | 639.34 | 1.02 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 16 | pp2048 | 1092.11 | 1112.37 | 1.02 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 32 | pp2048 | 1756.50 | 1839.03 | 1.05 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 64 | pp2048 | 2665.05 | 2791.84 | 1.05 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 128 | pp2048 | 3253.58 | 3600.24 | 1.11 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 256 | pp2048 | 3644.70 | 4113.13 | 1.13 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 512 | pp2048 | 3779.13 | 4314.15 | 1.14 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 1024 | pp2048 | 3793.53 | 4348.22 | 1.15 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 2048 | pp2048 | 3686.52 | 4217.29 | 1.14 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 16 | pp2048 | 1963.94 | 2014.78 | 1.03 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 32 | pp2048 | 3408.65 | 3550.23 | 1.04 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 64 | pp2048 | 5569.97 | 5814.91 | 1.04 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 128 | pp2048 | 7355.95 | 8064.64 | 1.10 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 256 | pp2048 | 9445.91 | 10645.13 | 1.13 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 512 | pp2048 | 10131.50 | 11475.40 | 1.13 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 1024 | pp2048 | 9808.51 | 11156.90 | 1.14 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 2048 | pp2048 | 8951.82 | 10025.18 | 1.12 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 16 | pp2048 | 258.44 | 256.69 | 0.99 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 32 | pp2048 | 418.49 | 421.41 | 1.01 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 64 | pp2048 | 581.16 | 596.30 | 1.03 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 128 | pp2048 | 686.45 | 703.01 | 1.02 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 256 | pp2048 | 773.25 | 794.03 | 1.03 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 512 | pp2048 | 806.87 | 828.14 | 1.03 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 1024 | pp2048 | 763.50 | 795.61 | 1.04 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 2048 | pp2048 | 732.66 | 744.06 | 1.02 |
| llama 8B Q2_K_M | RX 6800 | 16 | pp2048 | 126.96 | 184.08 | 1.45 |
| llama 8B Q2_K_M | RX 6800 | 32 | pp2048 | 188.46 | 261.40 | 1.39 |
| llama 8B Q2_K_M | RX 6800 | 64 | pp2048 | 241.87 | 286.11 | 1.18 |
| llama 8B Q2_K_M | RX 6800 | 128 | pp2048 | 299.51 | 354.08 | 1.18 |
| llama 8B Q2_K_M | RX 6800 | 256 | pp2048 | 364.77 | 425.51 | 1.17 |
| llama 8B Q2_K_M | RX 6800 | 512 | pp2048 | 383.76 | 439.03 | 1.14 |
| llama 8B Q2_K_M | RX 6800 | 1024 | pp2048 | 416.47 | 481.05 | 1.16 |
| llama 8B Q2_K_M | RX 6800 | 2048 | pp2048 | 398.74 | 455.77 | 1.14 |
| llama 8B Q2_K_M | RTX 3090 | 16 | pp2048 | 1081.38 | 1228.22 | 1.14 |
| llama 8B Q2_K_M | RTX 3090 | 32 | pp2048 | 1561.28 | 1868.28 | 1.20 |
| llama 8B Q2_K_M | RTX 3090 | 64 | pp2048 | 2146.91 | 2516.09 | 1.17 |
| llama 8B Q2_K_M | RTX 3090 | 128 | pp2048 | 2651.83 | 2577.82 | 0.97 |
| llama 8B Q2_K_M | RTX 3090 | 256 | pp2048 | 2971.92 | 3081.47 | 1.04 |
| llama 8B Q2_K_M | RTX 3090 | 512 | pp2048 | 3112.39 | 3353.02 | 1.08 |
| llama 8B Q2_K_M | RTX 3090 | 1024 | pp2048 | 3134.75 | 3505.83 | 1.12 |
| llama 8B Q2_K_M | RTX 3090 | 2048 | pp2048 | 3093.43 | 3451.92 | 1.12 |
| llama 8B Q2_K_M | RTX 4090 | 16 | pp2048 | 2009.40 | 2132.92 | 1.06 |
| llama 8B Q2_K_M | RTX 4090 | 32 | pp2048 | 3077.56 | 3637.09 | 1.18 |
| llama 8B Q2_K_M | RTX 4090 | 64 | pp2048 | 4608.38 | 5385.63 | 1.17 |
| llama 8B Q2_K_M | RTX 4090 | 128 | pp2048 | 6048.25 | 5467.48 | 0.90 |
| llama 8B Q2_K_M | RTX 4090 | 256 | pp2048 | 7741.86 | 7518.69 | 0.97 |
| llama 8B Q2_K_M | RTX 4090 | 512 | pp2048 | 8270.00 | 9055.11 | 1.09 |
| llama 8B Q2_K_M | RTX 4090 | 1024 | pp2048 | 8143.32 | 9159.31 | 1.12 |
| llama 8B Q2_K_M | RTX 4090 | 2048 | pp2048 | 7539.51 | 8445.80 | 1.12 |
| llama 8B Q2_K_M | P40 | 16 | pp2048 | 281.13 | 332.60 | 1.18 |
| llama 8B Q2_K_M | P40 | 32 | pp2048 | 404.55 | 504.05 | 1.25 |
| llama 8B Q2_K_M | P40 | 64 | pp2048 | 497.27 | 573.08 | 1.15 |
| llama 8B Q2_K_M | P40 | 128 | pp2048 | 604.08 | 664.99 | 1.10 |
| llama 8B Q2_K_M | P40 | 256 | pp2048 | 674.40 | 749.42 | 1.11 |
| llama 8B Q2_K_M | P40 | 512 | pp2048 | 685.04 | 788.49 | 1.15 |
| llama 8B Q2_K_M | P40 | 1024 | pp2048 | 666.09 | 774.29 | 1.16 |
| llama 8B Q2_K_M | P40 | 2048 | pp2048 | 629.81 | 724.67 | 1.15 |
| llama 8B Q3_K_S | RX 6800 | 16 | pp2048 | 97.78 | 218.59 | 2.24 |
| llama 8B Q3_K_S | RX 6800 | 32 | pp2048 | 151.35 | 294.39 | 1.95 |
| llama 8B Q3_K_S | RX 6800 | 64 | pp2048 | 202.30 | 318.56 | 1.57 |
| llama 8B Q3_K_S | RX 6800 | 128 | pp2048 | 250.39 | 397.31 | 1.59 |
| llama 8B Q3_K_S | RX 6800 | 256 | pp2048 | 302.50 | 471.41 | 1.56 |
| llama 8B Q3_K_S | RX 6800 | 512 | pp2048 | 319.38 | 483.63 | 1.51 |
| llama 8B Q3_K_S | RX 6800 | 1024 | pp2048 | 343.96 | 531.08 | 1.54 |
| llama 8B Q3_K_S | RX 6800 | 2048 | pp2048 | 332.81 | 495.86 | 1.49 |
| llama 8B Q3_K_S | RTX 3090 | 16 | pp2048 | 1114.12 | 1169.83 | 1.05 |
| llama 8B Q3_K_S | RTX 3090 | 32 | pp2048 | 1601.18 | 1890.46 | 1.18 |
| llama 8B Q3_K_S | RTX 3090 | 64 | pp2048 | 2163.78 | 2734.38 | 1.26 |
| llama 8B Q3_K_S | RTX 3090 | 128 | pp2048 | 2890.74 | 3226.13 | 1.12 |
| llama 8B Q3_K_S | RTX 3090 | 256 | pp2048 | 3191.09 | 3680.64 | 1.15 |
| llama 8B Q3_K_S | RTX 3090 | 512 | pp2048 | 3316.23 | 3894.79 | 1.17 |
| llama 8B Q3_K_S | RTX 3090 | 1024 | pp2048 | 3334.05 | 3977.03 | 1.19 |
| llama 8B Q3_K_S | RTX 3090 | 2048 | pp2048 | 3269.66 | 3879.93 | 1.19 |
| llama 8B Q3_K_S | RTX 4090 | 16 | pp2048 | 1883.09 | 1908.42 | 1.01 |
| llama 8B Q3_K_S | RTX 4090 | 32 | pp2048 | 3029.86 | 3615.06 | 1.19 |
| llama 8B Q3_K_S | RTX 4090 | 64 | pp2048 | 4480.88 | 5639.45 | 1.26 |
| llama 8B Q3_K_S | RTX 4090 | 128 | pp2048 | 6515.56 | 7656.04 | 1.18 |
| llama 8B Q3_K_S | RTX 4090 | 256 | pp2048 | 8307.18 | 9696.19 | 1.17 |
| llama 8B Q3_K_S | RTX 4090 | 512 | pp2048 | 8762.30 | 10233.61 | 1.17 |
| llama 8B Q3_K_S | RTX 4090 | 1024 | pp2048 | 8568.31 | 10104.30 | 1.18 |
| llama 8B Q3_K_S | RTX 4090 | 2048 | pp2048 | 7907.97 | 9156.33 | 1.16 |
| llama 8B Q3_K_S | P40 | 16 | pp2048 | 220.20 | 348.10 | 1.58 |
| llama 8B Q3_K_S | P40 | 32 | pp2048 | 319.66 | 491.34 | 1.54 |
| llama 8B Q3_K_S | P40 | 64 | pp2048 | 427.35 | 580.71 | 1.36 |
| llama 8B Q3_K_S | P40 | 128 | pp2048 | 512.99 | 657.29 | 1.28 |
| llama 8B Q3_K_S | P40 | 256 | pp2048 | 546.78 | 701.47 | 1.28 |
| llama 8B Q3_K_S | P40 | 512 | pp2048 | 559.32 | 706.64 | 1.26 |
| llama 8B Q3_K_S | P40 | 1024 | pp2048 | 545.48 | 677.40 | 1.24 |
| llama 8B Q3_K_S | P40 | 2048 | pp2048 | 524.01 | 656.32 | 1.25 |
| llama 8B Q4_0 | RX 6800 | 16 | pp2048 | 267.42 | 269.53 | 1.01 |
| llama 8B Q4_0 | RX 6800 | 32 | pp2048 | 371.21 | 372.44 | 1.00 |
| llama 8B Q4_0 | RX 6800 | 64 | pp2048 | 431.24 | 435.60 | 1.01 |
| llama 8B Q4_0 | RX 6800 | 128 | pp2048 | 538.75 | 548.28 | 1.02 |
| llama 8B Q4_0 | RX 6800 | 256 | pp2048 | 635.11 | 647.32 | 1.02 |
| llama 8B Q4_0 | RX 6800 | 512 | pp2048 | 638.26 | 648.59 | 1.02 |
| llama 8B Q4_0 | RX 6800 | 1024 | pp2048 | 717.49 | 731.47 | 1.02 |
| llama 8B Q4_0 | RX 6800 | 2048 | pp2048 | 650.85 | 662.03 | 1.02 |
| llama 8B Q4_0 | RTX 3090 | 16 | pp2048 | 1281.27 | 1306.26 | 1.02 |
| llama 8B Q4_0 | RTX 3090 | 32 | pp2048 | 1935.28 | 2088.34 | 1.08 |
| llama 8B Q4_0 | RTX 3090 | 64 | pp2048 | 2711.35 | 2859.45 | 1.05 |
| llama 8B Q4_0 | RTX 3090 | 128 | pp2048 | 3499.77 | 3673.60 | 1.05 |
| llama 8B Q4_0 | RTX 3090 | 256 | pp2048 | 3916.27 | 4194.32 | 1.07 |
| llama 8B Q4_0 | RTX 3090 | 512 | pp2048 | 4179.44 | 4404.74 | 1.05 |
| llama 8B Q4_0 | RTX 3090 | 1024 | pp2048 | 4220.53 | 4459.62 | 1.06 |
| llama 8B Q4_0 | RTX 3090 | 2048 | pp2048 | 4074.92 | 4306.90 | 1.06 |
| llama 8B Q4_0 | RTX 4090 | 16 | pp2048 | 1913.70 | 1970.75 | 1.03 |
| llama 8B Q4_0 | RTX 4090 | 32 | pp2048 | 3391.11 | 3547.00 | 1.05 |
| llama 8B Q4_0 | RTX 4090 | 64 | pp2048 | 5319.52 | 5662.40 | 1.06 |
| llama 8B Q4_0 | RTX 4090 | 128 | pp2048 | 7420.40 | 8062.83 | 1.09 |
| llama 8B Q4_0 | RTX 4090 | 256 | pp2048 | 9683.83 | 10592.87 | 1.09 |
| llama 8B Q4_0 | RTX 4090 | 512 | pp2048 | 10723.48 | 11469.81 | 1.07 |
| llama 8B Q4_0 | RTX 4090 | 1024 | pp2048 | 10498.61 | 11155.82 | 1.06 |
| llama 8B Q4_0 | RTX 4090 | 2048 | pp2048 | 9539.28 | 10074.84 | 1.06 |
| llama 8B Q4_0 | P40 | 16 | pp2048 | 441.42 | 444.28 | 1.01 |
| llama 8B Q4_0 | P40 | 32 | pp2048 | 623.30 | 621.07 | 1.00 |
| llama 8B Q4_0 | P40 | 64 | pp2048 | 679.84 | 683.57 | 1.01 |
| llama 8B Q4_0 | P40 | 128 | pp2048 | 797.60 | 803.85 | 1.01 |
| llama 8B Q4_0 | P40 | 256 | pp2048 | 886.51 | 890.30 | 1.00 |
| llama 8B Q4_0 | P40 | 512 | pp2048 | 927.27 | 930.13 | 1.00 |
| llama 8B Q4_0 | P40 | 1024 | pp2048 | 906.63 | 913.20 | 1.01 |
| llama 8B Q4_0 | P40 | 2048 | pp2048 | 861.62 | 865.13 | 1.00 |
| llama 8B Q4_1 | RX 6800 | 16 | pp2048 | 250.31 | 249.05 | 0.99 |
| llama 8B Q4_1 | RX 6800 | 32 | pp2048 | 345.15 | 349.96 | 1.01 |
| llama 8B Q4_1 | RX 6800 | 64 | pp2048 | 403.89 | 403.89 | 1.00 |
| llama 8B Q4_1 | RX 6800 | 128 | pp2048 | 506.08 | 509.53 | 1.01 |
| llama 8B Q4_1 | RX 6800 | 256 | pp2048 | 597.20 | 603.53 | 1.01 |
| llama 8B Q4_1 | RX 6800 | 512 | pp2048 | 605.02 | 611.31 | 1.01 |
| llama 8B Q4_1 | RX 6800 | 1024 | pp2048 | 674.91 | 683.83 | 1.01 |
| llama 8B Q4_1 | RX 6800 | 2048 | pp2048 | 616.99 | 624.88 | 1.01 |
| llama 8B Q4_1 | RTX 3090 | 16 | pp2048 | 1392.57 | 1400.65 | 1.01 |
| llama 8B Q4_1 | RTX 3090 | 32 | pp2048 | 1782.54 | 2138.78 | 1.20 |
| llama 8B Q4_1 | RTX 3090 | 64 | pp2048 | 2747.31 | 2931.09 | 1.07 |
| llama 8B Q4_1 | RTX 3090 | 128 | pp2048 | 3351.07 | 3294.11 | 0.98 |
| llama 8B Q4_1 | RTX 3090 | 256 | pp2048 | 3739.25 | 3764.51 | 1.01 |
| llama 8B Q4_1 | RTX 3090 | 512 | pp2048 | 3901.46 | 4013.45 | 1.03 |
| llama 8B Q4_1 | RTX 3090 | 1024 | pp2048 | 3909.61 | 4098.31 | 1.05 |
| llama 8B Q4_1 | RTX 3090 | 2048 | pp2048 | 3785.17 | 3997.97 | 1.06 |
| llama 8B Q4_1 | RTX 4090 | 16 | pp2048 | 1813.98 | 1852.97 | 1.02 |
| llama 8B Q4_1 | RTX 4090 | 32 | pp2048 | 2945.00 | 3356.13 | 1.14 |
| llama 8B Q4_1 | RTX 4090 | 64 | pp2048 | 5398.91 | 5665.09 | 1.05 |
| llama 8B Q4_1 | RTX 4090 | 128 | pp2048 | 7280.85 | 7536.47 | 1.04 |
| llama 8B Q4_1 | RTX 4090 | 256 | pp2048 | 9357.25 | 9716.65 | 1.04 |
| llama 8B Q4_1 | RTX 4090 | 512 | pp2048 | 10040.79 | 10494.50 | 1.05 |
| llama 8B Q4_1 | RTX 4090 | 1024 | pp2048 | 9814.30 | 10393.62 | 1.06 |
| llama 8B Q4_1 | RTX 4090 | 2048 | pp2048 | 8964.59 | 9337.89 | 1.04 |
| llama 8B Q4_1 | P40 | 16 | pp2048 | 450.16 | 451.50 | 1.00 |
| llama 8B Q4_1 | P40 | 32 | pp2048 | 612.08 | 612.42 | 1.00 |
| llama 8B Q4_1 | P40 | 64 | pp2048 | 666.95 | 674.17 | 1.01 |
| llama 8B Q4_1 | P40 | 128 | pp2048 | 785.71 | 793.62 | 1.01 |
| llama 8B Q4_1 | P40 | 256 | pp2048 | 865.36 | 876.04 | 1.01 |
| llama 8B Q4_1 | P40 | 512 | pp2048 | 902.32 | 914.56 | 1.01 |
| llama 8B Q4_1 | P40 | 1024 | pp2048 | 886.82 | 896.99 | 1.01 |
| llama 8B Q4_1 | P40 | 2048 | pp2048 | 842.85 | 851.55 | 1.01 |
| llama 8B Q4_K_S | RX 6800 | 16 | pp2048 | 231.04 | 231.56 | 1.00 |
| llama 8B Q4_K_S | RX 6800 | 32 | pp2048 | 296.76 | 302.02 | 1.02 |
| llama 8B Q4_K_S | RX 6800 | 64 | pp2048 | 330.54 | 337.78 | 1.02 |
| llama 8B Q4_K_S | RX 6800 | 128 | pp2048 | 407.30 | 418.60 | 1.03 |
| llama 8B Q4_K_S | RX 6800 | 256 | pp2048 | 490.15 | 505.31 | 1.03 |
| llama 8B Q4_K_S | RX 6800 | 512 | pp2048 | 504.09 | 518.25 | 1.03 |
| llama 8B Q4_K_S | RX 6800 | 1024 | pp2048 | 557.35 | 574.65 | 1.03 |
| llama 8B Q4_K_S | RX 6800 | 2048 | pp2048 | 519.62 | 533.56 | 1.03 |
| llama 8B Q4_K_S | RTX 3090 | 16 | pp2048 | 1376.25 | 1408.99 | 1.02 |
| llama 8B Q4_K_S | RTX 3090 | 32 | pp2048 | 2030.30 | 2157.70 | 1.06 |
| llama 8B Q4_K_S | RTX 3090 | 64 | pp2048 | 2652.98 | 2899.88 | 1.09 |
| llama 8B Q4_K_S | RTX 3090 | 128 | pp2048 | 3251.68 | 3437.08 | 1.06 |
| llama 8B Q4_K_S | RTX 3090 | 256 | pp2048 | 3623.10 | 3934.53 | 1.09 |
| llama 8B Q4_K_S | RTX 3090 | 512 | pp2048 | 3771.06 | 4181.71 | 1.11 |
| llama 8B Q4_K_S | RTX 3090 | 1024 | pp2048 | 3831.95 | 4284.17 | 1.12 |
| llama 8B Q4_K_S | RTX 3090 | 2048 | pp2048 | 3737.99 | 4163.34 | 1.11 |
| llama 8B Q4_K_S | RTX 4090 | 16 | pp2048 | 2003.27 | 2041.99 | 1.02 |
| llama 8B Q4_K_S | RTX 4090 | 32 | pp2048 | 3531.27 | 3647.90 | 1.03 |
| llama 8B Q4_K_S | RTX 4090 | 64 | pp2048 | 5458.35 | 5893.00 | 1.08 |
| llama 8B Q4_K_S | RTX 4090 | 128 | pp2048 | 7328.10 | 7902.76 | 1.08 |
| llama 8B Q4_K_S | RTX 4090 | 256 | pp2048 | 9284.65 | 10190.76 | 1.10 |
| llama 8B Q4_K_S | RTX 4090 | 512 | pp2048 | 9954.41 | 11032.18 | 1.11 |
| llama 8B Q4_K_S | RTX 4090 | 1024 | pp2048 | 9787.04 | 10898.24 | 1.11 |
| llama 8B Q4_K_S | RTX 4090 | 2048 | pp2048 | 9005.04 | 9864.98 | 1.10 |
| llama 8B Q4_K_S | P40 | 16 | pp2048 | 406.29 | 406.26 | 1.00 |
| llama 8B Q4_K_S | P40 | 32 | pp2048 | 503.41 | 512.72 | 1.02 |
| llama 8B Q4_K_S | P40 | 64 | pp2048 | 619.38 | 627.37 | 1.01 |
| llama 8B Q4_K_S | P40 | 128 | pp2048 | 729.17 | 734.96 | 1.01 |
| llama 8B Q4_K_S | P40 | 256 | pp2048 | 779.99 | 788.10 | 1.01 |
| llama 8B Q4_K_S | P40 | 512 | pp2048 | 791.46 | 799.43 | 1.01 |
| llama 8B Q4_K_S | P40 | 1024 | pp2048 | 766.79 | 780.88 | 1.02 |
| llama 8B Q4_K_S | P40 | 2048 | pp2048 | 725.33 | 733.57 | 1.01 |
| llama 8B Q5_0 | RX 6800 | 16 | pp2048 | 226.09 | 226.36 | 1.00 |
| llama 8B Q5_0 | RX 6800 | 32 | pp2048 | 330.01 | 335.07 | 1.02 |
| llama 8B Q5_0 | RX 6800 | 64 | pp2048 | 402.30 | 407.39 | 1.01 |
| llama 8B Q5_0 | RX 6800 | 128 | pp2048 | 501.50 | 510.97 | 1.02 |
| llama 8B Q5_0 | RX 6800 | 256 | pp2048 | 586.79 | 600.49 | 1.02 |
| llama 8B Q5_0 | RX 6800 | 512 | pp2048 | 592.98 | 605.56 | 1.02 |
| llama 8B Q5_0 | RX 6800 | 1024 | pp2048 | 663.32 | 678.13 | 1.02 |
| llama 8B Q5_0 | RX 6800 | 2048 | pp2048 | 606.36 | 620.93 | 1.02 |
| llama 8B Q5_0 | RTX 3090 | 16 | pp2048 | 1085.35 | 1037.79 | 0.96 |
| llama 8B Q5_0 | RTX 3090 | 32 | pp2048 | 1839.64 | 1861.96 | 1.01 |
| llama 8B Q5_0 | RTX 3090 | 64 | pp2048 | 2810.22 | 2869.14 | 1.02 |
| llama 8B Q5_0 | RTX 3090 | 128 | pp2048 | 3443.30 | 3704.81 | 1.08 |
| llama 8B Q5_0 | RTX 3090 | 256 | pp2048 | 3896.08 | 4218.52 | 1.08 |
| llama 8B Q5_0 | RTX 3090 | 512 | pp2048 | 4077.99 | 4455.14 | 1.09 |
| llama 8B Q5_0 | RTX 3090 | 1024 | pp2048 | 4088.51 | 4549.91 | 1.11 |
| llama 8B Q5_0 | RTX 3090 | 2048 | pp2048 | 3949.49 | 4362.42 | 1.10 |
| llama 8B Q5_0 | RTX 4090 | 16 | pp2048 | 1644.45 | 1630.53 | 0.99 |
| llama 8B Q5_0 | RTX 4090 | 32 | pp2048 | 3024.16 | 3101.52 | 1.03 |
| llama 8B Q5_0 | RTX 4090 | 64 | pp2048 | 5173.85 | 5349.46 | 1.03 |
| llama 8B Q5_0 | RTX 4090 | 128 | pp2048 | 7218.09 | 7957.21 | 1.10 |
| llama 8B Q5_0 | RTX 4090 | 256 | pp2048 | 9491.88 | 10658.20 | 1.12 |
| llama 8B Q5_0 | RTX 4090 | 512 | pp2048 | 10487.57 | 11545.50 | 1.10 |
| llama 8B Q5_0 | RTX 4090 | 1024 | pp2048 | 10270.52 | 11383.14 | 1.11 |
| llama 8B Q5_0 | RTX 4090 | 2048 | pp2048 | 9419.49 | 10201.96 | 1.08 |
| llama 8B Q5_0 | P40 | 16 | pp2048 | 355.58 | 359.96 | 1.01 |
| llama 8B Q5_0 | P40 | 32 | pp2048 | 531.72 | 533.27 | 1.00 |
| llama 8B Q5_0 | P40 | 64 | pp2048 | 617.74 | 629.08 | 1.02 |
| llama 8B Q5_0 | P40 | 128 | pp2048 | 717.33 | 731.98 | 1.02 |
| llama 8B Q5_0 | P40 | 256 | pp2048 | 800.06 | 817.11 | 1.02 |
| llama 8B Q5_0 | P40 | 512 | pp2048 | 838.73 | 857.04 | 1.02 |
| llama 8B Q5_0 | P40 | 1024 | pp2048 | 827.11 | 845.22 | 1.02 |
| llama 8B Q5_0 | P40 | 2048 | pp2048 | 788.75 | 803.56 | 1.02 |
| llama 8B Q5_1 | RX 6800 | 16 | pp2048 | 222.78 | 223.49 | 1.00 |
| llama 8B Q5_1 | RX 6800 | 32 | pp2048 | 322.53 | 323.28 | 1.00 |
| llama 8B Q5_1 | RX 6800 | 64 | pp2048 | 391.56 | 395.53 | 1.01 |
| llama 8B Q5_1 | RX 6800 | 128 | pp2048 | 492.43 | 500.66 | 1.02 |
| llama 8B Q5_1 | RX 6800 | 256 | pp2048 | 581.63 | 591.55 | 1.02 |
| llama 8B Q5_1 | RX 6800 | 512 | pp2048 | 591.61 | 599.74 | 1.01 |
| llama 8B Q5_1 | RX 6800 | 1024 | pp2048 | 660.05 | 672.26 | 1.02 |
| llama 8B Q5_1 | RX 6800 | 2048 | pp2048 | 604.36 | 613.82 | 1.02 |
| llama 8B Q5_1 | RTX 3090 | 16 | pp2048 | 1105.27 | 1107.41 | 1.00 |
| llama 8B Q5_1 | RTX 3090 | 32 | pp2048 | 1632.07 | 1768.71 | 1.08 |
| llama 8B Q5_1 | RTX 3090 | 64 | pp2048 | 2479.76 | 2684.07 | 1.08 |
| llama 8B Q5_1 | RTX 3090 | 128 | pp2048 | 3261.49 | 3085.09 | 0.95 |
| llama 8B Q5_1 | RTX 3090 | 256 | pp2048 | 3642.14 | 3581.00 | 0.98 |
| llama 8B Q5_1 | RTX 3090 | 512 | pp2048 | 3804.44 | 3812.90 | 1.00 |
| llama 8B Q5_1 | RTX 3090 | 1024 | pp2048 | 3824.69 | 3916.59 | 1.02 |
| llama 8B Q5_1 | RTX 3090 | 2048 | pp2048 | 3729.99 | 3872.03 | 1.04 |
| llama 8B Q5_1 | RTX 4090 | 16 | pp2048 | 1572.05 | 1606.43 | 1.02 |
| llama 8B Q5_1 | RTX 4090 | 32 | pp2048 | 2672.81 | 2860.89 | 1.07 |
| llama 8B Q5_1 | RTX 4090 | 64 | pp2048 | 4440.16 | 5273.79 | 1.19 |
| llama 8B Q5_1 | RTX 4090 | 128 | pp2048 | 7094.02 | 7149.64 | 1.01 |
| llama 8B Q5_1 | RTX 4090 | 256 | pp2048 | 9060.70 | 9178.58 | 1.01 |
| llama 8B Q5_1 | RTX 4090 | 512 | pp2048 | 9861.47 | 10070.20 | 1.02 |
| llama 8B Q5_1 | RTX 4090 | 1024 | pp2048 | 9746.59 | 10128.23 | 1.04 |
| llama 8B Q5_1 | RTX 4090 | 2048 | pp2048 | 8897.90 | 9259.38 | 1.04 |
| llama 8B Q5_1 | P40 | 16 | pp2048 | 375.78 | 380.30 | 1.01 |
| llama 8B Q5_1 | P40 | 32 | pp2048 | 550.57 | 561.59 | 1.02 |
| llama 8B Q5_1 | P40 | 64 | pp2048 | 633.51 | 635.96 | 1.00 |
| llama 8B Q5_1 | P40 | 128 | pp2048 | 732.92 | 735.79 | 1.00 |
| llama 8B Q5_1 | P40 | 256 | pp2048 | 819.88 | 820.69 | 1.00 |
| llama 8B Q5_1 | P40 | 512 | pp2048 | 856.29 | 857.99 | 1.00 |
| llama 8B Q5_1 | P40 | 1024 | pp2048 | 838.17 | 845.51 | 1.01 |
| llama 8B Q5_1 | P40 | 2048 | pp2048 | 794.40 | 803.97 | 1.01 |
| llama 8B Q5_K_S | RX 6800 | 16 | pp2048 | 221.02 | 222.01 | 1.00 |
| llama 8B Q5_K_S | RX 6800 | 32 | pp2048 | 304.83 | 305.84 | 1.00 |
| llama 8B Q5_K_S | RX 6800 | 64 | pp2048 | 326.12 | 334.29 | 1.03 |
| llama 8B Q5_K_S | RX 6800 | 128 | pp2048 | 399.84 | 412.08 | 1.03 |
| llama 8B Q5_K_S | RX 6800 | 256 | pp2048 | 482.44 | 499.21 | 1.03 |
| llama 8B Q5_K_S | RX 6800 | 512 | pp2048 | 496.96 | 513.65 | 1.03 |
| llama 8B Q5_K_S | RX 6800 | 1024 | pp2048 | 549.13 | 568.78 | 1.04 |
| llama 8B Q5_K_S | RX 6800 | 2048 | pp2048 | 511.45 | 528.45 | 1.03 |
| llama 8B Q5_K_S | RTX 3090 | 16 | pp2048 | 1161.60 | 1182.85 | 1.02 |
| llama 8B Q5_K_S | RTX 3090 | 32 | pp2048 | 1845.83 | 1977.99 | 1.07 |
| llama 8B Q5_K_S | RTX 3090 | 64 | pp2048 | 2436.93 | 2747.16 | 1.13 |
| llama 8B Q5_K_S | RTX 3090 | 128 | pp2048 | 3070.83 | 3305.02 | 1.08 |
| llama 8B Q5_K_S | RTX 3090 | 256 | pp2048 | 3424.08 | 3811.23 | 1.11 |
| llama 8B Q5_K_S | RTX 3090 | 512 | pp2048 | 3566.58 | 4068.29 | 1.14 |
| llama 8B Q5_K_S | RTX 3090 | 1024 | pp2048 | 3606.63 | 4155.53 | 1.15 |
| llama 8B Q5_K_S | RTX 3090 | 2048 | pp2048 | 3545.85 | 4037.04 | 1.14 |
| llama 8B Q5_K_S | RTX 4090 | 16 | pp2048 | 1712.71 | 1749.88 | 1.02 |
| llama 8B Q5_K_S | RTX 4090 | 32 | pp2048 | 3161.70 | 3265.98 | 1.03 |
| llama 8B Q5_K_S | RTX 4090 | 64 | pp2048 | 5040.32 | 5521.94 | 1.10 |
| llama 8B Q5_K_S | RTX 4090 | 128 | pp2048 | 6886.68 | 7644.84 | 1.11 |
| llama 8B Q5_K_S | RTX 4090 | 256 | pp2048 | 8726.65 | 9782.59 | 1.12 |
| llama 8B Q5_K_S | RTX 4090 | 512 | pp2048 | 9422.20 | 10634.01 | 1.13 |
| llama 8B Q5_K_S | RTX 4090 | 1024 | pp2048 | 9329.98 | 10606.55 | 1.14 |
| llama 8B Q5_K_S | RTX 4090 | 2048 | pp2048 | 8621.52 | 9589.72 | 1.11 |
| llama 8B Q5_K_S | P40 | 16 | pp2048 | 350.42 | 361.75 | 1.03 |
| llama 8B Q5_K_S | P40 | 32 | pp2048 | 449.11 | 470.73 | 1.05 |
| llama 8B Q5_K_S | P40 | 64 | pp2048 | 590.43 | 611.75 | 1.04 |
| llama 8B Q5_K_S | P40 | 128 | pp2048 | 694.95 | 714.26 | 1.03 |
| llama 8B Q5_K_S | P40 | 256 | pp2048 | 755.47 | 757.14 | 1.00 |
| llama 8B Q5_K_S | P40 | 512 | pp2048 | 765.46 | 765.82 | 1.00 |
| llama 8B Q5_K_S | P40 | 1024 | pp2048 | 742.85 | 748.53 | 1.01 |
| llama 8B Q5_K_S | P40 | 2048 | pp2048 | 695.89 | 712.74 | 1.02 |
| llama 8B Q6_K | RX 6800 | 16 | pp2048 | 207.37 | 213.56 | 1.03 |
| llama 8B Q6_K | RX 6800 | 32 | pp2048 | 283.20 | 277.09 | 0.98 |
| llama 8B Q6_K | RX 6800 | 64 | pp2048 | 287.87 | 301.34 | 1.05 |
| llama 8B Q6_K | RX 6800 | 128 | pp2048 | 351.28 | 372.26 | 1.06 |
| llama 8B Q6_K | RX 6800 | 256 | pp2048 | 423.73 | 447.08 | 1.06 |
| llama 8B Q6_K | RX 6800 | 512 | pp2048 | 440.74 | 462.90 | 1.05 |
| llama 8B Q6_K | RX 6800 | 1024 | pp2048 | 481.65 | 507.73 | 1.05 |
| llama 8B Q6_K | RX 6800 | 2048 | pp2048 | 453.79 | 476.32 | 1.05 |
| llama 8B Q6_K | RTX 3090 | 16 | pp2048 | 1008.80 | 1038.01 | 1.03 |
| llama 8B Q6_K | RTX 3090 | 32 | pp2048 | 1740.13 | 1784.98 | 1.03 |
| llama 8B Q6_K | RTX 3090 | 64 | pp2048 | 2568.33 | 2611.05 | 1.02 |
| llama 8B Q6_K | RTX 3090 | 128 | pp2048 | 3129.10 | 3270.89 | 1.05 |
| llama 8B Q6_K | RTX 3090 | 256 | pp2048 | 3537.01 | 3777.64 | 1.07 |
| llama 8B Q6_K | RTX 3090 | 512 | pp2048 | 3696.48 | 3980.75 | 1.08 |
| llama 8B Q6_K | RTX 3090 | 1024 | pp2048 | 3724.88 | 4026.53 | 1.08 |
| llama 8B Q6_K | RTX 3090 | 2048 | pp2048 | 3596.34 | 3892.68 | 1.08 |
| llama 8B Q6_K | RTX 4090 | 16 | pp2048 | 1452.60 | 1479.91 | 1.02 |
| llama 8B Q6_K | RTX 4090 | 32 | pp2048 | 2785.91 | 2842.04 | 1.02 |
| llama 8B Q6_K | RTX 4090 | 64 | pp2048 | 4710.12 | 4555.03 | 0.97 |
| llama 8B Q6_K | RTX 4090 | 128 | pp2048 | 6728.92 | 7180.00 | 1.07 |
| llama 8B Q6_K | RTX 4090 | 256 | pp2048 | 8717.74 | 9366.04 | 1.07 |
| llama 8B Q6_K | RTX 4090 | 512 | pp2048 | 9488.11 | 10199.94 | 1.08 |
| llama 8B Q6_K | RTX 4090 | 1024 | pp2048 | 9300.89 | 10097.27 | 1.09 |
| llama 8B Q6_K | RTX 4090 | 2048 | pp2048 | 8473.21 | 9150.08 | 1.08 |
| llama 8B Q6_K | P40 | 16 | pp2048 | 333.72 | 295.57 | 0.89 |
| llama 8B Q6_K | P40 | 32 | pp2048 | 460.03 | 466.11 | 1.01 |
| llama 8B Q6_K | P40 | 64 | pp2048 | 581.10 | 590.98 | 1.02 |
| llama 8B Q6_K | P40 | 128 | pp2048 | 649.29 | 661.03 | 1.02 |
| llama 8B Q6_K | P40 | 256 | pp2048 | 685.27 | 701.79 | 1.02 |
| llama 8B Q6_K | P40 | 512 | pp2048 | 696.77 | 711.19 | 1.02 |
| llama 8B Q6_K | P40 | 1024 | pp2048 | 681.71 | 700.28 | 1.03 |
| llama 8B Q6_K | P40 | 2048 | pp2048 | 648.71 | 659.70 | 1.02 |
| llama 8B Q8_0 | RX 6800 | 16 | pp2048 | 254.44 | 248.89 | 0.98 |
| llama 8B Q8_0 | RX 6800 | 32 | pp2048 | 355.17 | 352.40 | 0.99 |
| llama 8B Q8_0 | RX 6800 | 64 | pp2048 | 418.60 | 435.19 | 1.04 |
| llama 8B Q8_0 | RX 6800 | 128 | pp2048 | 527.21 | 547.83 | 1.04 |
| llama 8B Q8_0 | RX 6800 | 256 | pp2048 | 627.28 | 650.18 | 1.04 |
| llama 8B Q8_0 | RX 6800 | 512 | pp2048 | 634.19 | 657.36 | 1.04 |
| llama 8B Q8_0 | RX 6800 | 1024 | pp2048 | 714.38 | 739.91 | 1.04 |
| llama 8B Q8_0 | RX 6800 | 2048 | pp2048 | 647.66 | 671.90 | 1.04 |
| llama 8B Q8_0 | RTX 3090 | 16 | pp2048 | 989.94 | 977.21 | 0.99 |
| llama 8B Q8_0 | RTX 3090 | 32 | pp2048 | 1739.38 | 1834.08 | 1.05 |
| llama 8B Q8_0 | RTX 3090 | 64 | pp2048 | 2784.83 | 2895.64 | 1.04 |
| llama 8B Q8_0 | RTX 3090 | 128 | pp2048 | 3681.37 | 3751.25 | 1.02 |
| llama 8B Q8_0 | RTX 3090 | 256 | pp2048 | 4243.38 | 4395.91 | 1.04 |
| llama 8B Q8_0 | RTX 3090 | 512 | pp2048 | 4434.48 | 4641.51 | 1.05 |
| llama 8B Q8_0 | RTX 3090 | 1024 | pp2048 | 4482.22 | 4733.34 | 1.06 |
| llama 8B Q8_0 | RTX 3090 | 2048 | pp2048 | 4291.45 | 4533.04 | 1.06 |
| llama 8B Q8_0 | RTX 4090 | 16 | pp2048 | 1233.64 | 1364.73 | 1.11 |
| llama 8B Q8_0 | RTX 4090 | 32 | pp2048 | 2340.64 | 2536.91 | 1.08 |
| llama 8B Q8_0 | RTX 4090 | 64 | pp2048 | 4201.01 | 4469.49 | 1.06 |
| llama 8B Q8_0 | RTX 4090 | 128 | pp2048 | 6776.67 | 7006.66 | 1.03 |
| llama 8B Q8_0 | RTX 4090 | 256 | pp2048 | 9835.87 | 10517.99 | 1.07 |
| llama 8B Q8_0 | RTX 4090 | 512 | pp2048 | 11124.57 | 11976.77 | 1.08 |
| llama 8B Q8_0 | RTX 4090 | 1024 | pp2048 | 11070.31 | 11805.52 | 1.07 |
| llama 8B Q8_0 | RTX 4090 | 2048 | pp2048 | 10117.85 | 10644.41 | 1.05 |
| llama 8B Q8_0 | P40 | 16 | pp2048 | 318.64 | 352.32 | 1.11 |
| llama 8B Q8_0 | P40 | 32 | pp2048 | 562.07 | 559.65 | 1.00 |
| llama 8B Q8_0 | P40 | 64 | pp2048 | 609.36 | 629.17 | 1.03 |
| llama 8B Q8_0 | P40 | 128 | pp2048 | 731.95 | 751.52 | 1.03 |
| llama 8B Q8_0 | P40 | 256 | pp2048 | 815.45 | 846.11 | 1.04 |
| llama 8B Q8_0 | P40 | 512 | pp2048 | 854.82 | 886.85 | 1.04 |
| llama 8B Q8_0 | P40 | 1024 | pp2048 | 840.69 | 865.01 | 1.03 |
| llama 8B Q8_0 | P40 | 2048 | pp2048 | 798.29 | 820.32 | 1.03 |

@github-actions bot added the Nvidia GPU label (issues specific to Nvidia GPUs) on Jul 10, 2024
@JohannesGaessler added the Review Complexity : High label (generally requires in-depth knowledge of LLMs or GPUs) on Jul 10, 2024
@slaren (Member) commented Jul 10, 2024

The need_sum stuff seems very hacky, and it will be very hard for other people to understand what is going on. If these are effectively different types, it would be much clearer if each type had its own name and struct.

@JohannesGaessler (Collaborator, Author) commented:
I would prefer to keep block_q8_1_mmq as a single struct because sizeof(block_q8_1_mmq) is used in many places to calculate pointers, and it would be unclear why you would arbitrarily choose one of the three memory layouts over the others. How about this: define structs only for the scales/partial sums and change the 16-byte padding where these are saved to a union of those structs. Also replace the integer in need_sum with an enum.
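
For illustration, a minimal sketch of what this proposal could look like; the enum values, field names, and exact layouts here are hypothetical stand-ins, not the code that was eventually merged:

```cpp
#include <cstdint>

// Named layouts for the 16-byte scale/partial-sum region, replacing the
// old need_sum integer.
enum mmq_q8_1_ds_layout {
    MMQ_Q8_1_DS_LAYOUT_D4,   // scales only, no partial sums
    MMQ_Q8_1_DS_LAYOUT_DS4,  // one FP16 (scale, sum) pair per 32 values
    MMQ_Q8_1_DS_LAYOUT_D2S6, // 2 FP16 scales (per 64 values) + 6 partial sums (per 16 values)
};

// The block stays a single struct, so sizeof() remains layout-independent
// and can still be used everywhere for pointer arithmetic; only the
// interpretation of the first 16 bytes changes, expressed as a union
// instead of opaque padding. uint16_t stands in for raw FP16 words.
struct block_q8_1_mmq_sketch {
    union {
        uint16_t d4[8];   // D4: FP16 scales (unused slots ignored)
        uint16_t ds4[8];  // DS4: interleaved FP16 scales and partial sums
        uint16_t d2s6[8]; // D2S6: 2 FP16 scales followed by 6 partial sums
    };
    int8_t qs[128];       // quantized activation values
};

static_assert(sizeof(block_q8_1_mmq_sketch) == 144, "size independent of layout");
```

A kernel template could then branch on the enum at compile time instead of checking an opaque integer.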

@slaren (Member) commented Jul 11, 2024

I think that would work, and it would still achieve the goal of making the code easier to understand.

@JohannesGaessler merged commit 808aba3 into ggml-org:master on Jul 11, 2024
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 12, 2024
* CUDA: optimize and refactor MMQ

* explicit q8_1 memory layouts, add documentation
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jul 13, 2024
* CUDA: optimize and refactor MMQ

* explicit q8_1 memory layouts, add documentation