NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation. #512

Closed
flozi00 opened this issue Jun 30, 2023 · 33 comments · Fixed by #741

Comments

@flozi00
Contributor

flozi00 commented Jun 30, 2023

Feature request

Longer context, up to 8k tokens; the linked discussion and notebook show promising results.

Motivation

Discussion: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

Colab Notebook: https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=d2ceb547

Your contribution

As it's only a 3-line code change, it would be pretty easy to integrate (a rough sketch is included below).

I will start training a model and provide an example demo.
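For reference, here is a minimal sketch of the NTK-aware adjustment described in the linked discussion, assuming a standard RoPE setup where inv_freq is derived from a base of 10000; the alpha value and helper name below are illustrative, not TGI code:

```python
import torch

def ntk_scaled_inv_freq(dim: int, base: float = 10000.0, alpha: float = 8.0,
                        device: str = "cpu") -> torch.Tensor:
    # NTK-aware trick from the linked discussion: instead of interpolating
    # positions, scale the RoPE base so the low frequencies are stretched.
    base = base * alpha ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))

# Replace the rotary embedding's inv_freq buffer with the scaled version
# (dim is the per-head dimension, e.g. 128 for LLaMA-7B).
inv_freq = ntk_scaled_inv_freq(dim=128, alpha=8.0)
```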

@Narsil
Collaborator

Narsil commented Jun 30, 2023

Oh nice. And if you want to write a PR, that would be awesome too.

Please be mindful that the TGI code doesn't do batching the same way as transformers, meaning the change will most likely be slightly more complex.
Also, lots of models actually define this buffer directly in their weights instead of instantiating it, which unfortunately leads to some downstream differences in generation.

@flozi00
Contributor Author

flozi00 commented Jun 30, 2023

[image: results plot]

The purple one is trained with the 3-line fix given in the Colab.

@iantbutler01

@Narsil Just wanted to chime in here and say I'm working on an implementation and PR for this

@flozi00
Contributor Author

flozi00 commented Jul 1, 2023

@iantbutler01 let me know if you need support at any point.
At the moment I am focused on training such models rather than on the integration into TGI.

@iantbutler01

I've opened a draft PR, #529

I've tested the fixed NTK-aware scaling and it seems to work. I still need to test dynamic scaling and clean up the PR to comply with the contributor guidelines, but I wanted to at least start the discussion.

@GemsFord

@iantbutler01 does this method only support LLaMA models? If yes, why did you add the support in flash_rw_modeling.py?

@iantbutler01

@GemsFord The method should work for any model using rotary embeddings; it's agnostic. My main use is Falcon 40B, which I've been running locally and testing these changes with.

@GemsFord

@iantbutler01 Thanks for adding support for Falcon. I use that too, which is why I asked. I am waiting for your PR to get merged.

@iantbutler01

Yup, I plan to clean this up and make it ready for review this weekend. I was on vacation and am now catching back up on my work, but I will have time this weekend. As far as I can tell the implementation works, so it's just a matter of cleaning up and then going through review feedback.

@flozi00
Contributor Author

flozi00 commented Jul 13, 2023

@iantbutler01

Nice, I don't think that affects this work unless they implemented it in a flash-attention-enabled module. I'll definitely check it out to make sure my implementation here is correct, though.

@flozi00
Contributor Author

flozi00 commented Jul 13, 2023

Most interesting is the dynamic NTK-aware RoPE being added.
Maybe adding the dynamic version would be an option for TGI too?
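For context, the dynamic variant recomputes the RoPE base from the current sequence length instead of using a fixed alpha, roughly as in the transformers PR mentioned above; a hedged sketch (parameter names are illustrative):

```python
import torch

def dynamic_ntk_inv_freq(dim: int, seq_len: int, max_position_embeddings: int = 2048,
                         base: float = 10000.0, scaling_factor: float = 2.0,
                         device: str = "cpu") -> torch.Tensor:
    # Dynamic NTK scaling: only adjust the base once the current sequence
    # length exceeds the training context, growing it with seq_len.
    if seq_len > max_position_embeddings:
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
```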

@iantbutler01

That's already in my PR :D

@flozi00
Contributor Author

flozi00 commented Jul 13, 2023

Great 😀👍

@keelezibel

Hi, are there any instructions on how to use this once the PR is merged? Also, I was wondering why there would be a desync between the transformers lib and this repo, since it would be too expensive to run LLMs with the transformers lib alone instead of an inference server.

@iantbutler01

@Narsil I've updated the PR to remove the draft status. I think I'm ready for review; just pinging you because you were the earliest responder from HF on this thread.

@andreaskoepf

andreaskoepf commented Jul 24, 2023

The associated PR #529 seems to add post-hoc RoPE scaling (for models trained without scaling). Now that linear & dynamic RoPE scaling got merged into transformers (huggingface/transformers#24653), more models will be fine-tuned with scaled RoPE. For example, we (Open-Assistant) uploaded today a first experiment, llama2-13b-orca-8k-3319, which was fine-tuned with an 8k context using simple linear scaling. It has the following in its config.json and can be used out of the box with transformers 4.31.0:

  "rope_scaling": {
    "factor": 2.0,
    "type": "linear"
  },

Will support for these kinds of fine-tuned models also be added to TGI? Will a separate PR be required for this?

Currently those models can simply be loaded with TGI, but since the rope scaling is not active the output is gibberish. Until rope-scaled models are supported, it might be good to generate an error or warning when rope_scaling is not None in the model's configuration.

Or will the rope scaling of the HF transformers LLaMA implementation automatically be used once the TGI transformers dependency in requirements.txt is updated (currently it is still transformers==4.29.2)?
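For reference, "linear" scaling with factor 2.0 simply compresses the positions fed to RoPE so the 8k fine-tuning range maps back onto the original 4k position range; a minimal sketch of the idea (not the actual transformers or TGI code):

```python
import torch

def linear_scaled_rope_angles(seq_len: int, inv_freq: torch.Tensor,
                              factor: float = 2.0) -> torch.Tensor:
    # Linear (position-interpolation) scaling: divide positions by the factor,
    # e.g. positions 0..8191 / 2.0 -> the original 0..4095 range.
    t = torch.arange(seq_len, dtype=torch.float32) / factor
    return torch.outer(t, inv_freq)  # angles used to build the cos/sin caches
```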

@Narsil
Collaborator

Narsil commented Jul 31, 2023

Two separate things, but we'll align with that, yes.

@Narsil
Collaborator

Narsil commented Jul 31, 2023

@andreaskoepf Can you provide an example where the rope scaling fails?

I'm trying a few dummy examples, but I'm not sure whether what I'm doing is correct or not, as the model output doesn't seem particularly bad either way (I'm guessing I'm not entering large enough prompts).

Narsil mentioned this issue Jul 31, 2023
@Narsil
Collaborator

Narsil commented Jul 31, 2023

@andreaskoepf the PR linked should fix it.

Narsil added a commit that referenced this issue Jul 31, 2023
# What does this PR do?


- Adds RoPE NTK scaling.

  Done because #529 was closed. Took some code from huggingface/transformers#24653.

- `--rope-scaling` and `--rope-factor` are added separately. I considered having a single flag and parsing something like ("linear:4.0", or "dynamic"), but decided against it because it would push more parsing + validation a bit everywhere (both in the launcher and the server).


Fixes #512




@yadamonk

@Narsil So we can now use models like the llama2 orca 8k mentioned by @andreaskoepf?

@Narsil
Collaborator

Narsil commented Jul 31, 2023

You should be able to!

I was able to get coherent results on 6k-token prompts with that model.
I'm still waiting on confirmation from someone who knows what to expect from that particular model (my test references are on non-fine-tuned LLaMA v1-7B, which I'm sure works; for the fine-tuned model the output looks OK, but without any reference points to compare against it's kind of hard to judge).

@flozi00
Contributor Author

flozi00 commented Jul 31, 2023

I tried to test using GPTQ weights. On v1.0 everything is fine; with the latest container I get:

File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 78, in serve
    server.serve(

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 184, in serve
    asyncio.run(

  File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)

  File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 647, in run_until_complete
    return future.result()

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 136, in serve_inner
    model = get_model(

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 185, in get_model
    return FlashLlama(

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/flash_llama.py", line 67, in __init__
    model = FlashLlamaForCausalLM(config, weights)

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 456, in __init__
    self.model = FlashLlamaModel(config, weights)

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 394, in __init__
    [

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 395, in <listcomp>
    FlashLlamaLayer(

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 331, in __init__
    self.self_attn = FlashLlamaAttention(

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 206, in __init__
    self.query_key_value = TensorParallelColumnLinear.load_multi(

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/layers.py", line 264, in load_multi
    weight = weights.get_multi_weights_col(

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 134, in get_multi_weights_col
    bits, groupsize = self._get_gptq_params()

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 220, in _get_gptq_params
    raise e

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 213, in _get_gptq_params
    bits = self.get_tensor("gptq_bits").item()

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 66, in get_tensor
    filename, tensor_name = self.get_filename(tensor_name)

  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 53, in get_filename
    raise RuntimeError(f"weight {tensor_name} does not exist")

RuntimeError: weight gptq_bits does not exist
 rank=0
Error: ShardCannotStart

@Narsil
Collaborator

Narsil commented Jul 31, 2023

What model is that?

@flozi00
Contributor Author

flozi00 commented Jul 31, 2023

flozi00/Llama-2-7b-german-assistant-v2-4bit-autogptq

The only commit that touched that part of the code after the 1.0 release is #738.

@Narsil
Collaborator

Narsil commented Jul 31, 2023

@Narsil
Collaborator

Narsil commented Jul 31, 2023

Should be OK after this, could you confirm?

@flozi00
Contributor Author

flozi00 commented Jul 31, 2023

Another issue found:

def _create_inv_freq(dim, base, device):

is defined here

https://github.com/huggingface/text-generation-inference/blob/15fc64668f8d3dd407768286e5a0536aeb78c2e1/server/text_generation_server/utils/layers.py#L486C24-L486C39

but it is used from the other class, where it is not accessible.

So dynamic scaling is not working and raises a "function not defined" error; linear scaling with a quantized model is working. I can see that it has problems with the stop tokens now, so the model generates whole conversations, but I think that can be solved with some configuration.
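A minimal sketch of the kind of fix needed, assuming the helper only has to live at module scope so that both the regular and the dynamic rotary-embedding classes can reach it (illustrative, not the actual layers.py code):

```python
import torch

def _create_inv_freq(dim: int, base: float, device) -> torch.Tensor:
    # Module-level helper: reachable from every rotary-embedding class,
    # not just the one it was originally nested in.
    return 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )
```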

@flozi00
Contributor Author

flozi00 commented Jul 31, 2023

solving that typo here

#745

@Narsil
Collaborator

Narsil commented Jul 31, 2023

Shoot, I just merged my PR, which is the same :)

Edit: accepted yours, so you'll end up in the contributors!
Thanks.

@flozi00
Contributor Author

flozi00 commented Jul 31, 2023

Thanks a lot :)
I love that: at most Hugging Face projects the core team is so fast 🚀

@flozi00
Contributor Author

flozi00 commented Jul 31, 2023

I can confirm that dynamic scaling is working now too.

@MUZAMMILPERVAIZ

MUZAMMILPERVAIZ commented Aug 28, 2023

What should the rope scaling factor be for a 32k context, 0.125?
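(For reference, in the transformers rope_scaling convention the linear factor is the ratio of the target context to the original context, so it is greater than 1; a quick check, assuming a LLaMA-2 base context of 4096:)

```python
original_context = 4096      # LLaMA-2 pretraining context length (assumption)
target_context = 32 * 1024   # desired 32k context
factor = target_context / original_context
print(factor)  # 8.0 -- 0.125 would be the inverse (how much positions are compressed)
```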
