Implement NTK-Aware scaled and dynamically scaled RoPE for PositionRotaryEmbedding #529
Conversation
I've tested fixed NTK-Aware scaling on a project I'm working on and was successfully generating at 2400 tokens, which is about the limit my RTX 6000 Ada can handle from VRAM with Falcon 40B Instruct, but it was entirely coherent generation above the original 2048-token context. I still need to test dynamic scaling and clean up the PR further to comply with the guidelines and the checklist, but wanted to open this up in the meantime.
Just a note that Huggingface Transformers natively supports this now: huggingface/transformers@34d9409. Does this make it easier to implement here?
@ssmi153 Not particularly, most of the attention modules in this repo are custom to support flash attention. The work in transformers is good to review for my implementation, and that's about it from what I see.
Thanks for the PR.
- We need only 1 new CLI argument. There are already a LOT of arguments, so let's try to keep them to a bare minimum for new features.
- Overall, can we remove a lot of the complexity? From what I read, dynamic scaling seems just better than static scaling, so let's just use dynamic scaling, no?
- The current code has a lot of pathways; can we keep them to a minimum?
- Keep the code as close to the original as possible.
- Nothing should be directly in a `custom_modeling` file. This behavior, it seems, should be entirely agnostic of modeling code. This can go in the config for instance (like `quantize`) and live in `models/flash_llama.py` (this is not modeling code, but wrapping the model itself; it will probably be factored away at some point, but it would be a good place for now); a sketch of that shape follows below.

I'm happy to make those changes if you want, as they are mostly stylistic choices rather than business logic.
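For illustration, a minimal sketch of that config-driven shape; the helper name and the `ROPE_SCALING` / `ROPE_FACTOR` values are hypothetical, not TGI's actual API:

```python
import os

def resolve_rope_scaling(model_config):
    # Hypothetical helper living next to the model wrapper (e.g. in
    # models/flash_llama.py), not in custom_modeling: prefer an explicit
    # env/CLI value and fall back to the model config, mirroring how
    # `quantize` is threaded through.
    scaling = os.getenv("ROPE_SCALING") or getattr(model_config, "rope_scaling", None)
    factor = float(os.getenv("ROPE_FACTOR", "1.0"))
    return scaling, factor
```

The attention modules would then receive plain values and stay agnostic of where they came from.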
#[clap(default_value = "2048", long, env)] | ||
max_batch_prefill_tokens: u32, | ||
#[clap(default_value = "16000", long, env)] | ||
#[clap(default_value = "8192", long, env)] |
This does not belong in this PR.
We can discuss changing the defaults, but that's a separate concern.
Oh, yup, fair. I don't want to change them; I meant to clean this out. I'll remove!
# Default must be a string so .lower() does not fail on a bool.
if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
    ROPE_DYNAMIC_SCALING = True
else:
    ROPE_DYNAMIC_SCALING = False
Nothing should be model-specific.
@@ -369,7 +369,7 @@ def forward(self, hidden_states, residual=None):
 import rotary_emb


 class PositionRotaryEmbedding(nn.Module):
-    def __init__(self, inv_freq):
+    def __init__(self, inv_freq, scale_factor=1, dynamic_scaling=False, max_seq_len=2048, dim=None, base=None):
Can we have at most 1 extra argument? A lot of information should be extractable directly from `inv_freq`.
Yup, I can try to simplify this
if self.dynamic_scaling:
    scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
    max_seq_len = self.original_max_seq_len * scale_factor
    self.inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
This is not really OK, I think.
You're ditching entirely the original `self.inv_freq`, which unfortunately for us is sometimes different from the calculation proposed (that's why not all models are `static` and some are `load`). Llama most notably has a different saved `inv_freq` (not sure why, but it's indeed the case).
Part of dynamic scaling is calculating the new `inv_freq`; looking at the dynamic scaling implementation in Transformers, I don't see them preserving this value either.
What would you suggest alternatively?
I was thinking interpolation when I wrote this.
Now that I reflect more, it would make the code even more complex, which is not the desired effect.
Can we maybe move the scaling factor out of `get_inv_freq` and keep it directly here (since it just seems to be rescaling the `base`)?
And so let's keep rewriting `inv_freq`. It has some undesirable effects on those models, but the other way is even worse.
That sounds reasonable, I'll make this change after work.
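For reference, a minimal sketch of that agreed approach, following the dynamic NTK rescaling used in the Transformers implementation discussed above; the function name and signature are illustrative:

```python
import torch

def dynamic_ntk_inv_freq(dim, base, max_seq_len, scale_factor, seq_len, device):
    # Once the sequence grows past the trained context, rescale the rotary
    # base and recompute inv_freq from it; the original buffer is rewritten,
    # as agreed in the thread above.
    if seq_len > max_seq_len:
        base = base * (
            (scale_factor * seq_len / max_seq_len) - (scale_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (
        base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
    )
```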
@Narsil thanks for the review! So my only reason for suggesting we keep static scaling is that it's much easier to reason about the VRAM usage of a statically scaled context window. If you have a model with a 2048-token default context and scale it by 4 to a max of 8192, you can much more easily estimate the VRAM consumption at that max. Otherwise I agree; from what I've read as well, dynamic performs better. That said, as you mentioned, having both adds complexity! If you still feel that's not enough of a reason to keep static, I'll remove it 👍
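To make the VRAM predictability argument concrete, a rough back-of-the-envelope KV-cache estimate with illustrative LLaMA-7B-like shapes (not measured numbers):

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len, batch_size=1, dtype_bytes=2):
    # 2x accounts for keys and values; fp16 is 2 bytes per element.
    return 2 * batch_size * num_layers * num_heads * head_dim * seq_len * dtype_bytes

# With a static 4x scale the worst case is fixed and easy to budget for:
print(kv_cache_bytes(32, 32, 128, 8192) / 2**30)  # 4.0 GiB per sequence
```

With dynamic scaling, the effective maximum depends on how far generation actually runs, which is exactly what makes the budget harder to pin down.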
Would love to see this exposed now that the huggingface/transformers#24653 merge is complete. I understand there are various complexities with flash attention, and whether flash attention 2 will be implemented... `--rope_scaling {"type": "dynamic", "factor": 2.0}`, or however you see fit? 🙌
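A sketch of validating such a payload, loosely modeled on the checks in the linked Transformers PR; the helper name and error messages are illustrative:

```python
def parse_rope_scaling(rope_scaling):
    # Expects e.g. {"type": "dynamic", "factor": 2.0}; returns (type, factor).
    if rope_scaling is None:
        return None, 1.0
    rope_type = rope_scaling.get("type")
    factor = rope_scaling.get("factor")
    if rope_type not in ("linear", "dynamic"):
        raise ValueError(f"rope_scaling type must be 'linear' or 'dynamic', got {rope_type!r}")
    if not isinstance(factor, float) or factor < 1.0:
        raise ValueError(f"rope_scaling factor must be a float >= 1.0, got {factor!r}")
    return rope_type, factor
```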
@iantbutler01 @Narsil some information for decision-making 🤗
The current state of scaling techniques in
What's out there that I'm adding next in
So... perhaps we can jump straight into the best technique in TGI? :D It should only need one flag in practice, the
@gante That sounds good to me! I'll work on this over the next few days.
FWIW I tested against Llama via a simple overwrite on the Docker image on the latest TGI 0.9.3, using a flash-attn-v2-compatible GPU, and it works well 😄 I pulled this PR branch and layered it into the overwrite, with a few adjustments and assumptions.
A couple of thoughts/concerns:
Here is the Dockerfile used to quickly overwrite for testing, if it's helpful:
I'd be happy to share anything else if needed 🍻
@gante Looking at @jquesnelle's repo and the comment you linked to, it looks like there is actually both a standard by-parts as well as a dynamic by-parts method. So it looks like the improvement you were talking about applies to both types of NTK-aware scaling? In that case I'm inclined to make this PR just the dynamic by-parts method to save on some of the complexity.
@Narsil @gante I spent some time tonight working to implement the dynamic by-parts method I mentioned in my last comment. I'm coming to realize that, with this new method and the comment here: #512 (comment) suggesting there are now models that have been fine-tuned with scaling, the complexity here has a chance to really balloon. Even just the by-parts method itself is gnarlier and requires supporting a whole bunch of parameters on the attention module. Before I continue, at the risk of the complexity putting this in review hell, I'd like some guidance on what you all think I should proceed with. Personally, if I add the dynamic by-parts method linked in my previous comment, it will have effectively set up the ability to implement the other methods here anyway, but maybe a follow-up PR for those makes sense.
@iantbutler01 you raised good points: as users fine-tune their models with rope scaling, they may lose compatibility with TGI (depending on how we decide to do things here). And yes, let's settle on a path that avoids review hell! I'd suggest separating the two use cases and making two separate decisions/PRs:
@Narsil @iantbutler01 WDYT?
@gante I am fine with that approach; that's basically what I started last night, but I wanted to make sure that's what everyone had in mind.
Given the license change, I am no longer comfortable contributing my work.
@iantbutler01 would you be willing to contribute this change to a fork? I am strongly considering maintaining a fork of the repo from the commit before the license change. I would be adding support for speculative decoding there. |
# What does this PR do?

- Adds Rope NTK scaling. Done because #529 was closed. Took some code from huggingface/transformers#24653.
- `--rope-scaling` and `--rope-factor` are added separately. I considered having a single flag and parsing something like ("linear:4.0", or "dynamic") but decided against it because it would push more parsing + validation a bit everywhere (both in the launcher and the server).

Fixes #512
What does this PR do?
Implements NTK-Aware scaled and dynamically scaled RoPE for the PositionRotaryEmbedding to allow models to scale beyond their default max_tokens.
https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
Fixes #512
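For reviewers, a minimal sketch of the fixed NTK-aware variant from the first link above: the rotary base is bumped once at init so the lowest frequencies are stretched while the highest stay close to the original (illustrative code, not the exact PR diff):

```python
import torch

def ntk_scaled_inv_freq(dim, base=10000.0, alpha=4.0, device=None):
    # NTK-aware scaling: raise the base by alpha ** (dim / (dim - 2)), then
    # derive inv_freq exactly as standard RoPE does.
    base = base * alpha ** (dim / (dim - 2))
    return 1.0 / (
        base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
    )
```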
Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.