Update model definition to support Flash-Decoding #177

Merged

15 commits merged into octoml:batch-serving on Jan 30, 2024

Conversation

@masahi (Member) commented on Jan 30, 2024

This PR integrates Flash-Decoding support from apache/tvm#16474. It is a drop-in replacement for the vLLM kernel. The only difference from the vLLM-based build is the shape of KV cache blocks: the block size for vLLM is 16, while for Flash-Decoding it is 256.
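As a rough sketch of that difference, cache allocation only needs to branch on the backend to pick the block size; the helper name below is an assumption for illustration, not code from this PR.

```python
# Illustrative only: the block-size constants come from the description above;
# the helper name and error handling are assumptions.
def kv_cache_block_size(paged_kv_cache_type: str) -> int:
    """Number of tokens stored per paged KV cache block."""
    if paged_kv_cache_type == "vllm":
        return 16   # vLLM paged-attention kernel uses 16-token blocks
    if paged_kv_cache_type == "flash-decoding":
        return 256  # Flash-Decoding kernel uses 256-token blocks
    raise ValueError(f"Unknown paged KV cache type: {paged_kv_cache_type}")
```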

In addition, it supports decoding with multiple, fixed-length queries per request, which is necessary for speculative decoding. evaluate_multi_query from #156 can also be used for this purpose, but it supports variable-length queries per request and piggybacks on the prefill attention, which is not efficient when the number of queries is fixed and small. The changes in run_llama_batched_vllm.py demonstrate that the new Relax function, decode_multi_query, can do exactly the same thing as evaluate_multi_query when the query length is fixed.
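To make the layout difference concrete, here is a small NumPy sketch; the tensor shapes and sizes are illustrative assumptions, not the actual Relax function signatures.

```python
import numpy as np

num_seqs, num_heads, head_dim = 4, 32, 128

# decode_multi_query-style: every sequence decodes the same fixed number of
# queries, so a dense [num_seqs, num_queries, num_heads, head_dim] tensor
# suffices and the decode attention kernel can be used directly.
num_queries = 5
q_fixed = np.zeros((num_seqs, num_queries, num_heads, head_dim), dtype="float16")

# evaluate_multi_query-style: variable-length queries per sequence must be
# flattened and carried with per-sequence lengths, going through the prefill
# attention path instead.
query_lens = [3, 5, 2, 4]
q_ragged = np.zeros((sum(query_lens), num_heads, head_dim), dtype="float16")
```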

This PR updates only the model definition and the run_llama_batched_vllm.py example script. I'll follow up with the integration into mlc-serve next.

Requires the latest https://github.com/octoml/tvm/tree/for-mlc-serve-jan12

@sunggg @yelite @vinx13

@@ -402,6 +403,10 @@ class BuildArgs:
"action": "store_true",
},
)
paged_kv_cache_type: str = field(
default="vllm",
metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"},
@masahi (Member, Author) commented on the diff above:


This new option makes --use_vllm_attention obsolete. Since removing it is a breaking change, I'll do that later when I integrate Flash-Decoding into mlc-serve. @sunggg
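For context, here is a self-contained sketch of how a dataclass field like the one above is typically surfaced as a build flag; the parser wiring is an illustration, not the repository's actual build script.

```python
import argparse
from dataclasses import dataclass, field, fields

@dataclass
class BuildArgs:
    paged_kv_cache_type: str = field(
        default="vllm",
        metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"},
    )

parser = argparse.ArgumentParser()
for f in fields(BuildArgs):
    # Expose each dataclass field as a --<field_name> option.
    parser.add_argument(f"--{f.name}", type=str, default=f.default,
                        help=f.metadata.get("help", ""))

args = parser.parse_args(["--paged_kv_cache_type", "flash-decoding"])
print(args.paged_kv_cache_type)  # flash-decoding
```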

Lunderberg pushed a commit to Lunderberg/mlc-llm that referenced this pull request Jan 30, 2024
The repetition penalty (introduced in [CTRL](https://arxiv.org/abs/1909.05858)) can help prevent the LLM from generating repetitive tokens.
This PR implements the repetition penalty.

Note: Previously the logits softmax was performed on GPU; this PR moves it to CPU to accommodate the repetition penalty.
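For reference, a minimal NumPy sketch of the CTRL-style repetition penalty applied to CPU-side logits; the function name, signature, and example penalty value are assumptions, not the commit's actual code.

```python
import numpy as np

def apply_repetition_penalty(logits, generated_token_ids, penalty=1.3):
    """CTRL-style penalty: shrink positive logits of already generated tokens
    and push negative ones further down, discouraging repetition."""
    logits = logits.copy()
    for tok in set(generated_token_ids):
        if logits[tok] > 0:
            logits[tok] /= penalty
        else:
            logits[tok] *= penalty
    return logits

# Example: token 2 was generated before, so its logit is penalized before the
# CPU-side softmax and sampling step.
logits = np.array([1.0, -2.0, 3.0, 0.5], dtype="float32")
print(apply_repetition_penalty(logits, generated_token_ids=[2]))
```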
@sunggg (Member) left a comment:


LGTM, in the follow-up PR, would you share some benchmark numbers? Thank you!

@sunggg merged commit 253da78 into octoml:batch-serving on Jan 30, 2024
1 check passed