
ggml-cuda : perform cublas mat mul of quantized types as f16 #3412

Merged: 3 commits into master from cublas-q-f16 on Sep 30, 2023

Conversation

@slaren (Collaborator) commented Sep 30, 2023

Improves prompt processing speed for quantized types, but only when mmq is disabled (-nommq).

Essentially this is the same as #3370, extended to quantized types by dequantizing to fp16.
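
For context, here is a minimal sketch of the approach, assuming a simplified Q8_0-style block layout; the struct, kernel, and helper below are illustrative stand-ins, not the actual ggml-cuda code. The quantized weights are expanded to fp16 by a small CUDA kernel, and the matrix multiplication is then done by cublasGemmEx entirely in fp16, which lets cuBLAS use the tensor cores on Volta and newer GPUs:

```cpp
// Sketch only: dequantize a simplified Q8_0-style layout to fp16, then run the
// GEMM with cuBLAS using fp16 inputs, outputs, and compute type.
#include <cublas_v2.h>
#include <cuda_fp16.h>

#define QK 32                      // values per quantized block

struct block_q8 {                  // simplified stand-in for a quantized block
    half   d;                      // per-block scale
    int8_t qs[QK];                 // quantized values
};

// Hypothetical kernel: expand quantized blocks into a contiguous fp16 matrix.
// Launch as: dequantize_q8_to_f16<<<(n + 255) / 256, 256>>>(x, y, n);
__global__ void dequantize_q8_to_f16(const block_q8 * x, half * y, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;   // output element index
    if (i >= n) return;
    const block_q8 & b = x[i / QK];
    y[i] = __float2half(__half2float(b.d) * (float) b.qs[i % QK]);
}

// fp16 GEMM via cuBLAS: C (m x n) = A (m x k) * B (k x n), column-major as cuBLAS expects.
void gemm_f16(cublasHandle_t handle, const half * A, const half * B, half * C,
              int m, int n, int k) {
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);
    // CUDA_R_16F operands plus CUBLAS_COMPUTE_16F allow the tensor-core GEMM kernels.
    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                 &alpha, A, CUDA_R_16F, m,
                         B, CUDA_R_16F, k,
                 &beta,  C, CUDA_R_16F, m,
                 CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```

In the real code the fp16 result may still need to be converted back to fp32 for the rest of the graph; see the discussion about dst conversion further down.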

| model | size | mmq | test | master t/s | PR t/s | speedup |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B mostly Q2_K | 2.63 GiB | 0 | pp 512 | 1740.05 ± 2.18 | 3422.52 ± 26.59 | 1.97 |
| llama 7B mostly Q3_K - Large | 3.35 GiB | 0 | pp 512 | 1704.44 ± 36.02 | 3434.08 ± 21.88 | 2.02 |
| llama 7B mostly Q3_K - Medium | 3.07 GiB | 0 | pp 512 | 1725.62 ± 2.40 | 3423.26 ± 44.50 | 1.98 |
| llama 7B mostly Q3_K - Small | 2.75 GiB | 0 | pp 512 | 1720.28 ± 17.54 | 3415.61 ± 15.86 | 1.98 |
| llama 7B mostly Q4_0 | 3.56 GiB | 0 | pp 512 | 1705.19 ± 5.29 | 3230.66 ± 16.41 | 1.89 |
| llama 7B mostly Q4_1 | 3.95 GiB | 0 | pp 512 | 1696.79 ± 16.01 | 3241.20 ± 25.87 | 1.91 |
| llama 7B mostly Q4_K - Medium | 3.80 GiB | 0 | pp 512 | 1718.43 ± 16.33 | 3507.96 ± 8.80 | 2.04 |
| llama 7B mostly Q4_K - Small | 3.59 GiB | 0 | pp 512 | 1727.00 ± 4.03 | 3413.81 ± 97.65 | 1.98 |
| llama 7B mostly Q5_0 | 4.33 GiB | 0 | pp 512 | 1695.06 ± 6.91 | 3172.95 ± 15.14 | 1.87 |
| llama 7B mostly Q5_1 | 4.72 GiB | 0 | pp 512 | 1697.97 ± 4.07 | 3179.81 ± 35.89 | 1.87 |
| llama 7B mostly Q5_K - Medium | 4.45 GiB | 0 | pp 512 | 1721.78 ± 2.12 | 3460.38 ± 34.02 | 2.01 |
| llama 7B mostly Q5_K - Small | 4.33 GiB | 0 | pp 512 | 1722.66 ± 4.93 | 3474.46 ± 36.03 | 2.02 |
| llama 7B mostly Q6_K | 5.15 GiB | 0 | pp 512 | 1712.02 ± 3.62 | 3468.20 ± 23.51 | 2.03 |
| llama 7B mostly Q8_0 | 6.67 GiB | 0 | pp 512 | 1685.94 ± 8.08 | 3176.80 ± 51.69 | 1.88 |

For comparison, this is the performance that I get with mmq enabled (the default):

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| llama 7B mostly Q2_K | 2.63 GiB | 6.74 B | CUDA | 99 | pp 512 | 1814.65 ± 4.73 |
| llama 7B mostly Q3_K - Large | 3.35 GiB | 6.74 B | CUDA | 99 | pp 512 | 1922.97 ± 19.27 |
| llama 7B mostly Q3_K - Medium | 3.07 GiB | 6.74 B | CUDA | 99 | pp 512 | 2009.16 ± 8.20 |
| llama 7B mostly Q3_K - Small | 2.75 GiB | 6.74 B | CUDA | 99 | pp 512 | 1901.53 ± 31.49 |
| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CUDA | 99 | pp 512 | 2420.71 ± 9.51 |
| llama 7B mostly Q4_1 | 3.95 GiB | 6.74 B | CUDA | 99 | pp 512 | 2099.10 ± 31.81 |
| llama 7B mostly Q4_K - Medium | 3.80 GiB | 6.74 B | CUDA | 99 | pp 512 | 2220.84 ± 1.78 |
| llama 7B mostly Q4_K - Small | 3.59 GiB | 6.74 B | CUDA | 99 | pp 512 | 2181.68 ± 54.97 |
| llama 7B mostly Q5_0 | 4.33 GiB | 6.74 B | CUDA | 99 | pp 512 | 2191.70 ± 5.07 |
| llama 7B mostly Q5_1 | 4.72 GiB | 6.74 B | CUDA | 99 | pp 512 | 1945.15 ± 6.50 |
| llama 7B mostly Q5_K - Medium | 4.45 GiB | 6.74 B | CUDA | 99 | pp 512 | 2070.01 ± 5.23 |
| llama 7B mostly Q5_K - Small | 4.33 GiB | 6.74 B | CUDA | 99 | pp 512 | 2044.41 ± 12.79 |
| llama 7B mostly Q6_K | 5.15 GiB | 6.74 B | CUDA | 99 | pp 512 | 2125.72 ± 7.12 |
| llama 7B mostly Q8_0 | 6.67 GiB | 6.74 B | CUDA | 99 | pp 512 | 2346.21 ± 35.02 |

@ggerganov (Owner) left a comment

Great. Can’t test atm, but if ppl looks ok we should merge

@slaren (Collaborator, Author) commented Sep 30, 2023

Perplexity looks good:

| model | ppl |
| --- | --- |
| 7B/ggml-model-f16.gguf | 5.9073 +/- 0.03309 |
| 7B/ggml-model-Q2_K.gguf | 6.5864 +/- 0.03755 |
| 7B/ggml-model-Q3_K_S.gguf | 6.4524 +/- 0.03672 |
| 7B/ggml-model-Q3_K.gguf | 6.1548 +/- 0.03456 |
| 7B/ggml-model-Q3_K_L.gguf | 6.0866 +/- 0.03416 |
| 7B/ggml-model-Q4_0.gguf | 6.1159 +/- 0.03504 |
| 7B/ggml-model-Q4_1.gguf | 6.0655 +/- 0.03400 |
| 7B/ggml-model-Q4_K_S.gguf | 6.0067 +/- 0.03371 |
| 7B/ggml-model-Q4_K.gguf | 5.9616 +/- 0.03342 |
| 7B/ggml-model-Q5_0.gguf | 5.9814 +/- 0.03409 |
| 7B/ggml-model-Q5_1.gguf | 5.9418 +/- 0.03328 |
| 7B/ggml-model-Q5_K_S.gguf | 5.9463 +/- 0.03330 |
| 7B/ggml-model-Q5_K.gguf | 5.9196 +/- 0.03317 |
| 7B/ggml-model-Q6_K.gguf | 5.9076 +/- 0.03309 |
| 7B/ggml-model-Q8_0.gguf | 5.9078 +/- 0.03309 |

@Ph0rk0z commented Sep 30, 2023

Will this murder P40? Also, what if I am running a model on 3090s and also P40s together?

@slaren (Collaborator, Author) commented Sep 30, 2023

This is only used on Volta and up.

@Ph0rk0z commented Sep 30, 2023

Right, but what happens if one GPU is Pascal and another is Ampere? Will it use the lowest compute capability for all of them?

@slaren (Collaborator, Author) commented Sep 30, 2023

The fp16 path was already used only on the main GPU, but I think even that may not have worked properly when converting dst to fp32, due to synchronization issues. So it is now completely disabled with multiple GPUs: the fp32 mat mul is used whenever more than one GPU is in use.
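
To make the gating concrete, here is an illustrative sketch of the kind of dispatch being described; the helper names and the n_gpus parameter are placeholders, not the actual ggml-cuda code, though CC_VOLTA is the constant this PR renames from CC_TURING:

```cpp
// Illustrative dispatch only, not the actual ggml-cuda code.
#include <cuda_runtime.h>

static const int CC_VOLTA = 700;   // minimum compute capability for the fp16 cuBLAS path

void mul_mat_dispatch(int main_device, int n_gpus /*, tensor arguments ... */) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, main_device);
    const int cc = 100 * prop.major + 10 * prop.minor;   // e.g. compute capability 7.0 -> 700

    if (cc >= CC_VOLTA && n_gpus == 1) {
        // Volta or newer, single GPU: dequantize src0 to fp16 and run the cuBLAS GEMM in fp16.
        // mul_mat_cublas_f16(...);   // placeholder
    } else {
        // Older GPUs, or more than one GPU in use: dequantize to fp32 and run the GEMM in fp32.
        // mul_mat_cublas_f32(...);   // placeholder
    }
}
```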

@ggerganov (Owner) commented Sep 30, 2023

I updated the A100 numbers using this PR: #3359

slaren merged commit f5ef5cf into master Sep 30, 2023
34 checks passed
slaren deleted the cublas-q-f16 branch September 30, 2023 16:13
@Dampfinchen commented Sep 30, 2023

This increases VRAM usage for some reason. With this build and using --nommap, my Q4_K_S model no longer fits in VRAM and it slows down dramatically.

Edit: Apologies, I misread. I was confusing the new mul mat kernels (MMQ) with MMAP, so the higher VRAM usage is expected. The difference MMQ makes is dramatic in my case:

llama_print_timings:        load time =  2527.29 ms
llama_print_timings:      sample time =   111.18 ms /   180 runs   (    0.62 ms per token,  1618.94 tokens per second)
llama_print_timings: prompt eval time = 35581.01 ms /  1849 tokens (   19.24 ms per token,    51.97 tokens per second)
llama_print_timings:        eval time =  6702.92 ms /   179 runs   (   37.45 ms per token,    26.70 tokens per second)
llama_print_timings:       total time = 42690.89 ms

llama_print_timings:        load time =  2477.21 ms
llama_print_timings:      sample time =   109.03 ms /   180 runs   (    0.61 ms per token,  1650.92 tokens per second)
llama_print_timings: prompt eval time =  4031.07 ms /  1849 tokens (    2.18 ms per token,   458.69 tokens per second)
llama_print_timings:        eval time =  6651.27 ms /   179 runs   (   37.16 ms per token,    26.91 tokens per second)
llama_print_timings:       total time = 11085.66 ms

Hopefully this change can be made to work with mmq as well.

@slaren (Collaborator, Author) commented Sep 30, 2023

> Hopefully this change can be made to work with mmq as well.

Once support for tensor cores is added to mmq, it will be as fast or faster than cublas again, while still using less VRAM. For now, cublas is the only way to use tensor cores.

@Dampfinchen commented Sep 30, 2023

> > Hopefully this change can be made to work with mmq as well.
>
> Once support for tensor cores is added to mmq, it will be as fast or faster than cublas again, while still using less VRAM. For now, cublas is the only way to use tensor cores.

Just ran Nsight and can confirm that the tensor cores are, for the first time ever, used to their full extent.

(attached screenshot: tensor core utilization in Nsight)

Awesome work! Now fingers crossed it's easy to enable tensor core support for mmq as well. If mmq (which was a lot faster than cublas before this commit) can benefit from TC support as well, then we are definitely in for another revolution here. Exciting stuff!

@YellowRoseCx (Contributor) commented

Would be interesting to see how the changes affect AMD users

joelkuiper added a commit to vortext/llama.cpp that referenced this pull request Oct 2, 2023
…example

* 'master' of github.com:ggerganov/llama.cpp:
  ggml-cuda : perform cublas mat mul of quantized types as f16 (ggerganov#3412)
  llama.cpp : add documentation about rope_freq_base and scale values (ggerganov#3401)
  train : fix KQ_pos allocation (ggerganov#3392)
  llama : quantize up to 31% faster on Linux and Windows with mmap (ggerganov#3206)
  readme : update hot topics + model links (ggerganov#3399)
  readme : add link to grammars app (ggerganov#3388)
  swift : fix build on xcode 15 (ggerganov#3387)
  build : enable more non-default compiler warnings (ggerganov#3200)
  ggml_tensor: update the structure comments. (ggerganov#3283)
  ggml : release the requested thread pool resource (ggerganov#3292)
  llama.cpp : split llama_context_params into model and context params (ggerganov#3301)
  ci : multithreaded builds (ggerganov#3311)
  train : finetune LORA (ggerganov#2632)
  gguf : basic type checking in gguf_get_* (ggerganov#3346)
  gguf : make token scores and types optional (ggerganov#3347)
  ci : disable freeBSD builds due to lack of VMs (ggerganov#3381)
  llama : custom attention mask + parallel decoding + no context swaps (ggerganov#3228)
  docs : mark code as Bash (ggerganov#3375)
  readme : add Mistral AI release 0.1 (ggerganov#3362)
  ggml-cuda : perform cublas fp16 matrix multiplication as fp16 (ggerganov#3370)
@JohannesGaessler (Collaborator) commented

> Once support for tensor cores is added to mmq, it will be as fast or faster than cublas again, while still using less VRAM.

About two weeks ago I did a prototype implementation of mmq using tensor cores and was not able to get better performance. From what I can tell, a prerequisite to getting good tensor core utilization would be to load data asynchronously. As of right now, the mmq compute pipeline utilization (without tensor cores) is only ~50%.
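
For readers wondering what tensor core support in a hand-written kernel involves, below is a bare illustration of the warp-level primitive (nvcuda::wmma) that such a kernel would be built around; it is not the prototype mentioned above and not an optimized GEMM. Note that the loads in the inner loop are synchronous, which is exactly the limitation described: overlapping them with the mma_sync calls (for example via cp.async double buffering on Ampere) is what asynchronous data loading would address.

```cpp
// Minimal wmma illustration: each warp computes one 16x16 tile of C = A * B.
// A is M x K row-major fp16, B is K x N column-major fp16, C is M x N row-major fp32.
// Requires compute capability >= 7.0 (compile with -arch=sm_70 or newer);
// M, N, K must be multiples of 16.
// Launch as: wmma_gemm_tile<<<dim3(N / 16, M / 16), 32>>>(A, B, C, M, N, K);
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

__global__ void wmma_gemm_tile(const half * A, const half * B, float * C,
                               int M, int N, int K) {
    const int tile_m = blockIdx.y;   // tile row index
    const int tile_n = blockIdx.x;   // tile column index

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    wmma::fill_fragment(c_frag, 0.0f);

    for (int k = 0; k < K; k += 16) {
        // Synchronous loads: the warp stalls here instead of overlapping the next
        // load with the current mma_sync, which limits tensor core utilization.
        wmma::load_matrix_sync(a_frag, A + tile_m * 16 * K + k, K);
        wmma::load_matrix_sync(b_frag, B + tile_n * 16 * K + k, K);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }

    wmma::store_matrix_sync(C + tile_m * 16 * N + tile_n * 16, c_frag, N, wmma::mem_row_major);
}
```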

yusiwen pushed a commit to yusiwen/llama.cpp that referenced this pull request Oct 7, 2023
…ov#3412)

* ggml-cuda : perform cublas matrix multiplication of quantized types as fp16

* rename CC_TURING to CC_VOLTA

* disable fp16 mat mul completely with multi GPU