Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server: enable lookup decoding #6828

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

JohannesGaessler
Copy link
Collaborator

@JohannesGaessler JohannesGaessler commented Apr 22, 2024

This PR aims to enable lookup decoding for the llama.cpp server in the same way as it is used in examples/lookup, see #5479 . To recapitulate, the implementation tries to guess the next few tokens that will be generated using simple text statistics. I think the current implementation works but it is difficult to properly benchmark. The intended way for it to work is:

  • Start with empty context cache, empty dynamic cache (unless the user provides one), and static cache loaded from file.
  • When generating tokens, try to continue with context cache, validate with static cache.
  • If that fails, try to continue with dynamic cache, validate with static cache.
  • If that fails, try to continue with static cache.
  • When new tokens are generated, use them to update the context cache.
  • When a generation is finished, update the dynamic cache with the context cache, then empty the context cache.
  • On server shutdown, save dynamic cache to file if the user provided a path.

These are the results I get from examples/server/bench.py using an RTX 4090 and various static lookup caches and an initially empty dynamic cache:

Model Static lookup cache Iterations master PR --draft 0 Iterations PR --draft 5 Speedup
Phi-2 3b q4_0 None 274 365 1.33
Phi-2 3b q4_0 Wikitext 103 274 361 1.32
Phi-2 3b q4_0 Mistral 1.64e8 274 354 1.29
LLaMA 2 7b q4_0 None 148 256 1.73
LLaMA 2 7b q4_0 Wikitext 103 148 255 1.72
LLaMA 2 7b q4_0 Mistral 1.64e8 148 255 1.72

Edit: the table was labeled incorrectly. The speedup was not relative to master but relative to --draft 0 which included the overhead for no benefit.

It does seem to provide a speedup but adding a static lookup cache does not seem to help (the caches are created either from Wikitext 103 or from 164 million tokens generated with Mistral q8_0). Assuming there are no bugs, what I think is happening is that the dataset for the benchmark (see server bench README) is very repetitive so using a static cache pulls the drafts away from these very repetitive patterns and reduces the speed. Also for Phi-2 in particular I think that I simply don't have enough input data for the static cache to get sufficiently precise text statistics (since it has a larger vocabulary size). Regarding the latter, I recently built a machine with 6x RTX 4090 so I think I will be able to significantly scale up the rate at which I can produce synthetic text (I was previously using 3x P40 and 1x RX 6800).

In this PR I also changed the interface of llama_ngram_cache_load to be more in line with the rest of llama.cpp; I'll maybe change the interface of some of the other functions as well.

Also: is it somehow possible to retrieve the tokens that were previously fed to the model? I'm currently manually tracking this in server.cpp but this adds the potential for error.

@phymbert
Copy link
Collaborator

Great work. As we discussed previously, servers' test coverage matters, and adding a new scenario in the test framework is mandatory.

@JohannesGaessler
Copy link
Collaborator Author

adding a new scenario in the test framework is mandatory.

Are there already any tests that assert correctness for the server? I didn't see any so as part of this implementation I would try to add some.

@phymbert
Copy link
Collaborator

Are there already any tests that assert correctness for the server? I didn't see any so as part of this implementation I would try to add some.

https://github.com/ggerganov/llama.cpp/tree/master/examples/server/tests

@JohannesGaessler
Copy link
Collaborator Author

While writing tests I'm noticing that when using > 1 slots the results for a given seed are not consistent on master. @phymbert is this a known problem?

@phymbert
Copy link
Collaborator

I was not aware, but this is not asserted in the parallel test suite AFAIK.

Also, I recall that each architecture generates different results.

@arnfaldur
Copy link

arnfaldur commented Apr 25, 2024

A research paper studying this exact technique was recently published and suggested for integration in an issue #6813
I have been looking at the current implementation and trying to make it match the implementation in the paper.

  • The short version of it is that there are no static or dynamic N-gram caches, only ones like you call context, generated on the fly.
  • They use what they call a multi level N-gram. There are multiple caches. Each cache only considers N-grams of length N. When querying the multi level N-gram, the module with the longest N-grams that match the token suffix is the one used.
  • They then generate K tokens using this multi level N-gram as a proposal to be validated by the LLM.
  • Based on their ablation study, they suggest N = 5 and K = 7.

I haven't spent enough time reading ngram-cache.cpp and friends to tell how it works and how it differs, besides the persistence.
I will post results if there are any. I'm new to this codebase and a bit rusty at C++. Adjust your expectations accordingly.

@JohannesGaessler
Copy link
Collaborator Author

Thanks for the input. I saw the paper but didn't yet get around to reading it.

They use what they call a multi level N-gram. There are multiple caches. Each cache only considers N-grams of length N. When querying the multi level N-gram, the module with the longest N-grams that match the token suffix is the one used.

In practice my implementation also uses N-grams of varying sizes. A llama_ngram_cache can contain N-grams of multiple sizes simultaneously; it's just easier to bundle them together. The context and dynamic caches contain 1-grams, 2-grams, 3-grams, and 4-grams. The static caches only contain 2-grams (given enough input data 3-grams or 4-grams should also be viable).

@JohannesGaessler
Copy link
Collaborator Author

I forgot: if you want to play around with the llama.cpp implementation, take a look at lookup-stats. It has the same interface as perplexity and can be used to estimate how many tokens tokens would be predicted on a pre-generated text (so you don't have to actually evaluate the model).

@JamshedQurbonboev
Copy link

How much does this PR increase token generation? As far I am aware #5479 had rather tiny speedup. And when do you think this PR will be ready to be merged?

@JohannesGaessler
Copy link
Collaborator Author

How much does this PR increase token generation? As far I am aware #5479 had rather tiny speedup.

Should be something like 1.1-1.4 for natural language with an essentially empty context. For source code or summarization it's going to be a lot more. The numbers in the OP are indicative of the long-term speedup using similar prompts once the dynamic cache fills up.

And when do you think this PR will be ready to be merged?

I'm aiming for the end of the week.

@sorasoras
Copy link

Does this allow us to create a static cache during inference?

@JohannesGaessler
Copy link
Collaborator Author

No, use lookup-create for that. I'll upload the caches that I've been using myself before I merge this PR.

@JohannesGaessler
Copy link
Collaborator Author

Functionally I think everything is in order now. Unfortunately I think that it's currently not possible to get bit-for-bit identical results with lookup decoding since the results seem to change slightly when the batch size is varied, see #6950 . For this reason there are no automated tests for lookup decoding that assert that the results do not change (because they do).

Copy link
Contributor

github-actions bot commented Apr 28, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 484 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=9681.58ms p(95)=24109.08ms fails=, finish reason: stop=429 truncated=55
  • Prompt processing (pp): avg=116.51tk/s p(95)=508.02tk/s
  • Token generation (tg): avg=28.36tk/s p(95)=54.29tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=server-ngram-4 commit=71c98cc3bd4afd19a813b89197b93a29d8cc0e86

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 484 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1715552714 --> 1715553340
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 664.96, 664.96, 664.96, 664.96, 664.96, 554.92, 554.92, 554.92, 554.92, 554.92, 569.16, 569.16, 569.16, 569.16, 569.16, 597.17, 597.17, 597.17, 597.17, 597.17, 662.36, 662.36, 662.36, 662.36, 662.36, 684.79, 684.79, 684.79, 684.79, 684.79, 684.96, 684.96, 684.96, 684.96, 684.96, 689.13, 689.13, 689.13, 689.13, 689.13, 700.63, 700.63, 700.63, 700.63, 700.63, 701.78, 701.78, 701.78, 701.78, 701.78, 717.84, 717.84, 717.84, 717.84, 717.84, 721.95, 721.95, 721.95, 721.95, 721.95, 746.06, 746.06, 746.06, 746.06, 746.06, 791.89, 791.89, 791.89, 791.89, 791.89, 801.42, 801.42, 801.42, 801.42, 801.42, 728.36, 728.36, 728.36, 728.36, 728.36, 730.82, 730.82, 730.82, 730.82, 730.82, 730.84, 730.84, 730.84, 730.84, 730.84, 729.46, 729.46, 729.46, 729.46, 729.46, 740.83, 740.83, 740.83, 740.83, 740.83, 743.75, 743.75, 743.75, 743.75, 743.75, 744.05, 744.05, 744.05, 744.05, 744.05, 751.84, 751.84, 751.84, 751.84, 751.84, 752.26, 752.26, 752.26, 752.26, 752.26, 758.02, 758.02, 758.02, 758.02, 758.02, 762.14, 762.14, 762.14, 762.14, 762.14, 762.87, 762.87, 762.87, 762.87, 762.87, 764.75, 764.75, 764.75, 764.75, 764.75, 766.72, 766.72, 766.72, 766.72, 766.72, 781.72, 781.72, 781.72, 781.72, 781.72, 778.63, 778.63, 778.63, 778.63, 778.63, 779.73, 779.73, 779.73, 779.73, 779.73, 778.17, 778.17, 778.17, 778.17, 778.17, 778.67, 778.67, 778.67, 778.67, 778.67, 783.54, 783.54, 783.54, 783.54, 783.54, 785.1, 785.1, 785.1, 785.1, 785.1, 783.67, 783.67, 783.67, 783.67, 783.67, 786.25, 786.25, 786.25, 786.25, 786.25, 783.81, 783.81, 783.81, 783.81, 783.81, 792.32, 792.32, 792.32, 792.32, 792.32, 801.87, 801.87, 801.87, 801.87, 801.87, 802.64, 802.64, 802.64, 802.64, 802.64, 801.62, 801.62, 801.62, 801.62, 801.62, 801.58, 801.58, 801.58, 801.58, 801.58, 804.53, 804.53, 804.53, 804.53, 804.53, 806.0, 806.0, 806.0, 806.0, 806.0, 796.97, 796.97, 796.97, 796.97, 796.97, 786.16, 786.16, 786.16, 786.16, 786.16, 748.27, 748.27, 748.27, 748.27, 748.27, 747.44, 747.44, 747.44, 747.44, 747.44, 746.86, 746.86, 746.86, 746.86, 746.86, 751.51, 751.51, 751.51, 751.51, 751.51, 750.58, 750.58, 750.58, 750.58, 750.58, 751.4, 751.4, 751.4, 751.4, 751.4, 757.75, 757.75, 757.75, 757.75, 757.75, 757.78, 757.78, 757.78, 757.78, 757.78, 762.74, 762.74, 762.74, 762.74, 762.74, 767.18, 767.18, 767.18, 767.18, 767.18, 766.77, 766.77, 766.77, 766.77, 766.77, 772.08, 772.08, 772.08, 772.08]
                    
Loading
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 484 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1715552714 --> 1715553340
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 29.58, 29.58, 29.58, 29.58, 29.58, 26.11, 26.11, 26.11, 26.11, 26.11, 25.16, 25.16, 25.16, 25.16, 25.16, 25.5, 25.5, 25.5, 25.5, 25.5, 26.79, 26.79, 26.79, 26.79, 26.79, 26.66, 26.66, 26.66, 26.66, 26.66, 26.73, 26.73, 26.73, 26.73, 26.73, 26.87, 26.87, 26.87, 26.87, 26.87, 27.42, 27.42, 27.42, 27.42, 27.42, 27.42, 27.42, 27.42, 27.42, 27.42, 27.78, 27.78, 27.78, 27.78, 27.78, 27.4, 27.4, 27.4, 27.4, 27.4, 27.19, 27.19, 27.19, 27.19, 27.19, 26.8, 26.8, 26.8, 26.8, 26.8, 26.91, 26.91, 26.91, 26.91, 26.91, 26.54, 26.54, 26.54, 26.54, 26.54, 26.52, 26.52, 26.52, 26.52, 26.52, 26.44, 26.44, 26.44, 26.44, 26.44, 26.38, 26.38, 26.38, 26.38, 26.38, 26.17, 26.17, 26.17, 26.17, 26.17, 25.61, 25.61, 25.61, 25.61, 25.61, 25.57, 25.57, 25.57, 25.57, 25.57, 25.48, 25.48, 25.48, 25.48, 25.48, 25.23, 25.23, 25.23, 25.23, 25.23, 25.42, 25.42, 25.42, 25.42, 25.42, 25.38, 25.38, 25.38, 25.38, 25.38, 25.36, 25.36, 25.36, 25.36, 25.36, 25.44, 25.44, 25.44, 25.44, 25.44, 25.4, 25.4, 25.4, 25.4, 25.4, 25.4, 25.4, 25.4, 25.4, 25.4, 24.75, 24.75, 24.75, 24.75, 24.75, 24.39, 24.39, 24.39, 24.39, 24.39, 24.38, 24.38, 24.38, 24.38, 24.38, 24.4, 24.4, 24.4, 24.4, 24.4, 24.45, 24.45, 24.45, 24.45, 24.45, 24.62, 24.62, 24.62, 24.62, 24.62, 24.74, 24.74, 24.74, 24.74, 24.74, 24.84, 24.84, 24.84, 24.84, 24.84, 24.91, 24.91, 24.91, 24.91, 24.91, 24.87, 24.87, 24.87, 24.87, 24.87, 24.81, 24.81, 24.81, 24.81, 24.81, 24.51, 24.51, 24.51, 24.51, 24.51, 24.61, 24.61, 24.61, 24.61, 24.61, 24.65, 24.65, 24.65, 24.65, 24.65, 24.79, 24.79, 24.79, 24.79, 24.79, 24.83, 24.83, 24.83, 24.83, 24.83, 24.88, 24.88, 24.88, 24.88, 24.88, 24.81, 24.81, 24.81, 24.81, 24.81, 24.64, 24.64, 24.64, 24.64, 24.64, 24.52, 24.52, 24.52, 24.52, 24.52, 24.09, 24.09, 24.09, 24.09, 24.09, 23.89, 23.89, 23.89, 23.89, 23.89, 23.83, 23.83, 23.83, 23.83, 23.83, 23.94, 23.94, 23.94, 23.94, 23.94, 23.99, 23.99, 23.99, 23.99, 23.99, 24.14, 24.14, 24.14, 24.14, 24.14, 24.19, 24.19, 24.19, 24.19, 24.19, 24.19, 24.19, 24.19, 24.19, 24.19, 24.23, 24.23, 24.23, 24.23, 24.23, 24.09, 24.09, 24.09, 24.09]
                    
Loading

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 484 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1715552714 --> 1715553340
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.24, 0.24, 0.24, 0.24, 0.24, 0.38, 0.38, 0.38, 0.38, 0.38, 0.19, 0.19, 0.19, 0.19, 0.19, 0.11, 0.11, 0.11, 0.11, 0.11, 0.21, 0.21, 0.21, 0.21, 0.21, 0.14, 0.14, 0.14, 0.14, 0.14, 0.11, 0.11, 0.11, 0.11, 0.11, 0.12, 0.12, 0.12, 0.12, 0.12, 0.15, 0.15, 0.15, 0.15, 0.15, 0.13, 0.13, 0.13, 0.13, 0.13, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.24, 0.24, 0.24, 0.24, 0.24, 0.16, 0.16, 0.16, 0.16, 0.16, 0.39, 0.39, 0.39, 0.39, 0.39, 0.12, 0.12, 0.12, 0.12, 0.12, 0.16, 0.16, 0.16, 0.16, 0.16, 0.13, 0.13, 0.13, 0.13, 0.13, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.24, 0.21, 0.21, 0.21, 0.21, 0.21, 0.2, 0.2, 0.2, 0.2, 0.2, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11, 0.15, 0.15, 0.15, 0.15, 0.15, 0.16, 0.16, 0.16, 0.16, 0.16, 0.38, 0.38, 0.38, 0.38, 0.38, 0.2, 0.2, 0.2, 0.2, 0.2, 0.16, 0.16, 0.16, 0.16, 0.16, 0.12, 0.12, 0.12, 0.12, 0.12, 0.13, 0.13, 0.13, 0.13, 0.13, 0.14, 0.14, 0.14, 0.14, 0.14, 0.09, 0.09, 0.09, 0.09, 0.09, 0.14, 0.14, 0.14, 0.14, 0.14, 0.11, 0.11, 0.11, 0.11, 0.11, 0.27, 0.27, 0.27, 0.27, 0.27, 0.15, 0.15, 0.15, 0.15, 0.15, 0.34, 0.34, 0.34, 0.34, 0.34, 0.13, 0.13, 0.13, 0.13, 0.13, 0.14, 0.14, 0.14, 0.14, 0.14, 0.09, 0.09, 0.09, 0.09, 0.09, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.31, 0.31, 0.31, 0.31, 0.31, 0.49, 0.49, 0.49, 0.49, 0.49, 0.53, 0.53, 0.53, 0.53, 0.53, 0.46, 0.46, 0.46, 0.46, 0.46, 0.4, 0.4, 0.4, 0.4, 0.4, 0.1, 0.1, 0.1, 0.1, 0.1, 0.18, 0.18, 0.18, 0.18, 0.18, 0.1, 0.1, 0.1, 0.1, 0.1, 0.19, 0.19, 0.19, 0.19, 0.19, 0.12, 0.12, 0.12, 0.12, 0.12, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.17, 0.13, 0.13, 0.13, 0.13, 0.13, 0.23, 0.23, 0.23, 0.23]
                    
Loading
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 484 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1715552714 --> 1715553340
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 2.0, 2.0, 2.0, 2.0, 2.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 2.0, 2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0]
                    
Loading

@JohannesGaessler JohannesGaessler force-pushed the server-ngram-4 branch 2 times, most recently from ce22137 to 1d516d3 Compare April 28, 2024 19:53
@JohannesGaessler
Copy link
Collaborator Author

JohannesGaessler commented Apr 28, 2024

I'm very sorry but it seems the numbers that I previously reported were incorrect. The speed I reported for "master" was actually the speed for this PR with --draft 0. However, this means that the numbers were still including the overhead associated with the lookup caches which is quite significant. These are the correct numbers for the most recent version:

Model Static lookup cache Slots Iterations master Iterations PR --draft 0 Iterations PR --draft 5 Speedup vs. master Speedup vs. --draft 0
Phi-2 3b q4_0 None 1 549 274 363 0.66 1.32
Phi-2 3b q4_0 None 2 947 455 599 0.63 1.32
Phi-2 3b q4_0 None 4 1465 704 797 0.54 1.13
Phi-2 3b q4_0 None 8 1856 855 900 0.48 1.05

For Phi-2 on an RTX 4090 there is a regression relative to master because it is quite fast so the constant overhead per token is too large relative to the speedup. I'll investigate performance for larger models/slower hardware.

@JohannesGaessler
Copy link
Collaborator Author

JohannesGaessler commented Apr 28, 2024

Performance for LLaMA 3 70 on 3x RTX 4090 is looking much better:

Model Static lookup cache Slots Iterations master Iterations PR --draft 5 Speedup vs. master
LLaMA 3 70b q4_K_M None 1 24 44 1.83
LLaMA 3 70b q4_K_M WT 103 1 24 42 1.75

@Green-Sky
Copy link
Collaborator

Green-Sky commented Apr 29, 2024

Regarding performance, it seems your hashes for the lookup table are of low quality. std::hash<llama_token> is the same as std::hash<int32_t>, which just returns the identity of the token. Also the standard containers are known to be not performing best in class, but that's a different issue. :)

edit: this is a wold class article on how to make a fast lookup table, including a pretty neat hash function https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/

edit2: the way the hashes are combined in the ngram means that for the 64bit, only 32bit have any entropy at all. A better hash would probably fix this, but hashes are often combined with an extra shift or another multiplication.

@JohannesGaessler
Copy link
Collaborator Author

Thank you for the high-quality post. I definitely agree that the hashing is suboptimal, my main concern for now is to get something that works at all, and to also implement tests that assert this.

@JohannesGaessler
Copy link
Collaborator Author

JohannesGaessler commented Apr 29, 2024

Prior to reading the hashing function blog post I wrote a simple implementation that just uses bit shifts and xors but that already results in much better performance:

Model Static lookup cache Slots Iterations master Iterations PR --draft 5 Speedup vs. master
Phi-2 3b q4_0 None 1 549 634 1.15
Phi-2 3b q4_0 None 2 947 1113 1.18
Phi-2 3b q4_0 None 4 1465 1572 1.07
Phi-2 3b q4_0 None 8 1856 1790 0.96
Phi-2 3b q4_0 WT 103 1 549 643 1.17
Phi-2 3b q4_0 WT 103 2 947 1098 1.16
Phi-2 3b q4_0 WT 103 4 1465 1549 1.06
Phi-2 3b q4_0 WT 103 8 1856 1766 0.95

@JamshedQurbonboev
Copy link

Thanks for improving performance of llama.cpp. It seems that you were correct: lookup decoding improves speed, but adds constant overhead. So larger models have greater benefit from it. How does performance looks like for 7-13b models, in slower GPU and CPU-only backends?

@JohannesGaessler
Copy link
Collaborator Author

I think the model and prompt will be a bigger factor than the hardware as long as the hashing is fast enough. These are some numbers I get on my Epyc 7742 CPU with 8x 3200 MHz Micron DIMMs:

Model Static lookup cache Slots Iterations master Iterations PR --draft 5 Speedup vs. master
Phi-2 3b q4_0 None 1 103 119 1.16
LLaMA 3 70b q4_K_M None 1 3 5 1.67

Note that the comparatively large speedups with LLaMA 3 70b are likely a product of heavy repetition since I am using the base model.

@JohannesGaessler
Copy link
Collaborator Author

I've added a test for asserting that lookup decoding produces correct results. The sequences are the same for temperature 0 though the results are not going to be bit-for-bit identical. I've also investigated the performance for LLaMA 3 Instruct in more detail:

Results
Model GPU Static lookup cache Slots Seed Previous runs Iterations master Iterations PR --draft 5 Speedup vs. master
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 1 42 0 166 182 1.10
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 2 42 0 278 283 1.02
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 4 42 0 429 367 0.86
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 8 42 0 531 407 0.77
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 42 0 166 183 1.10
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 42 0 278 282 1.01
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 42 0 429 355 0.83
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 0 166 182 1.10
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 1 166 186 1.12
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 2 166 204 1.23
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 3 166 208 1.25
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 4 166 215 1.30
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 5 166 219 1.32
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 6 166 222 1.34
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 7 166 222 1.34
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 0 278 285 1.03
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 1 278 283 1.02
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 2 278 309 1.11
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 3 278 315 1.13
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 4 278 324 1.17
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 5 278 326 1.17
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 6 278 329 1.18
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 7 278 333 1.20
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 0 429 363 0.85
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 1 429 353 0.82
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 2 429 370 0.86
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 3 429 378 0.88
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 4 429 378 0.88
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 5 429 383 0.89
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 6 429 383 0.89
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 7 429 388 0.90
LLaMA 3 Instruct 70b q4_K_M 2x RTX 4090 None 1 42 0 28 31 1.11
LLaMA 3 Instruct 70b q4_K_M 2x RTX 4090 None 2 42 0 55 57 1.04
LLaMA 3 Instruct 70b q4_K_M 2x RTX 4090 None 4 42 0 96 69 0.72
LLaMA 3 Instruct 70b q4_K_M 2x RTX 4090 None 8 42 0 120 OOM ???

The speedup between LLaMA 3 instruct 8b and 70b seems to be very similar. The current implementation is only faster for small numbers of slots since there is comparatively less benefit for adding more tokens to the batch if you're already at 8 tokens per batch without any speculative decoding. Successive, similar runs with different seeds but a carried over dynamic cache result in increasing performance over time, for a single slot the 8th run was ~1.2x faster than the first one.

From my side I would consider this PR ready to be merged if one last issue is resolved: whether n-gram lookup should be enabled or disabled by default. The default for the number of slots is 1 and for that case it is faster. However, due to the varying batch size it also causes nondeterministic results. I personally would tend more towards having n-gram lookup be disabled by default but do not have a strong opinion on it.

@Green-Sky
Copy link
Collaborator

@JohannesGaessler can I convince you to quickly add an overload for std::hash<llama_token_t> and do a quick comparison? While the shift in the ngram hash stuffles the hash a bit, it probably is still pretty bad. + this is a very small change.

@JohannesGaessler
Copy link
Collaborator Author

I'm not sure what you mean by overload but I'm happy to test suggested alternatives.

@Green-Sky
Copy link
Collaborator

Try the following:

diff --git a/common/ngram-cache.h b/common/ngram-cache.h
index 6575ea05..df420e1f 100644
--- a/common/ngram-cache.h
+++ b/common/ngram-cache.h
@@ -37,13 +37,18 @@ struct llama_ngram {
     }
 };

 };

+struct llama_token_hash_function {
+    size_t operator()(const llama_token token) const {
+        return token * 11400714819323198485llu;
+    }
+};
+
 struct llama_ngram_hash_function {
     size_t operator()(const llama_ngram & ngram) const {
-        size_t hash = ngram.tokens[0];
+        size_t hash = llama_token_hash_function{}(ngram.tokens[0]);

         for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) {
-            hash <<= 15;
-            hash ^= ngram.tokens[i];
+            hash ^= llama_token_hash_function{}(ngram.tokens[i]);
         }

         return hash;
@@ -51,7 +56,7 @@ struct llama_ngram_hash_function {
 };

I went the route you went instead and used another callable type.
Notes:

  • I removed the shift, since it discards a lot in this case.
  • Since ngrams hash is always over the #define LLAMA_NGRAM_MAX 4 and the unused are -1, you actually shift any entropy away, which collapses them to the same hash again.
  • The multiply is probably enough.

Please test :)

@JohannesGaessler
Copy link
Collaborator Author

I took over the Fibonacci hash implementation. For LLaMA 3 q4_K_M on an RTX 4090 it's maybe a ~1% end-to-end speedup.

Results
Model GPU Static lookup cache Slots Seed Previous runs Iterations master Iterations PR --draft 5 Speedup vs. master
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 1 42 0 166 183 1.10
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 2 42 0 278 285 1.03
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 4 42 0 429 365 0.85
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 8 42 0 531 417 0.79
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 42 0 166 183 1.10
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 42 0 278 284 1.02
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 42 0 429 360 0.84
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 0 166 183 1.10
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 1 166 184 1.11
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 2 166 206 1.24
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 3 166 212 1.28
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 4 166 215 1.30
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 5 166 219 1.32
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 6 166 221 1.33
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 1 -1 7 166 223 1.34
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 0 278 288 1.04
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 1 278 283 1.02
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 2 278 308 1.11
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 3 278 315 1.13
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 4 278 322 1.16
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 5 278 322 1.16
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 6 278 329 1.18
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 2 -1 7 278 330 1.19
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 0 429 358 0.83
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 1 429 353 0.82
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 2 429 372 0.87
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 3 429 377 0.88
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 4 429 380 0.89
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 5 429 383 0.89
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 6 429 386 0.90
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 WT 103 4 -1 7 429 389 0.91

@Green-Sky
Copy link
Collaborator

Different STL implementations will perform differently here.

@@ -163,6 +164,10 @@ struct server_slot {
// when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens;

llama_ngram_cache nc_context;
std::vector<llama_token> draft;
std::vector<llama_token> context_tokens;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this the same as cache_tokens?

Ideally, the llama_sampling_context should maintain a history of the processed tokens. There are already steps in that direction via the std::vector<llama_token> prev; member and the llama_sampling_accept() API, but some more work would be needed (such as API for removing discarded tokens). Not needed to be done in this PR - just clarifying the long-term goals

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think it should be possible to refactor the code in such a way that the same vector is used for caching and lookup.

@mofosyne mofosyne added Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level enhancement New feature or request examples labels May 9, 2024
@JohannesGaessler
Copy link
Collaborator Author

I re-tested the performance on 1x RTX 4090 with CUDA graphs but against my expectations I am seeing virtually no performance difference compared to before:

Model GPU Static lookup cache Slots Seed Previous runs Iterations master Iterations PR --draft 5 Speedup vs. master
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 1 42 0 167 183 1.10
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 2 42 0 277 284 1.03
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 4 42 0 426 363 0.85
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 8 42 0 540 417 0.77

@sorasoras
Copy link

I re-tested the performance on 1x RTX 4090 with CUDA graphs but against my expectations I am seeing virtually no performance difference compared to before:

Model GPU Static lookup cache Slots Seed Previous runs Iterations master Iterations PR --draft 5 Speedup vs. master
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 1 42 0 167 183 1.10
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 2 42 0 277 284 1.03
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 4 42 0 426 363 0.85
LLaMA 3 Instruct 8b q4_K_M 1x RTX 4090 None 8 42 0 540 417 0.77

a quick question
How does number of draft affect the performance?
I saw you have many branch of different draft.

@JohannesGaessler
Copy link
Collaborator Author

The numbers for the server-ngram branches on my repository are just the numbers I use internally to keep my branches apart. Just use the branch I'm using for this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request examples Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants