Optimizations #19

Closed · SinanAkkoyun opened this issue Feb 8, 2024 · 16 comments
Comments

@SinanAkkoyun

Hey! Thank you so much for this repo and the great work; this is what the world needs right now. I have been waiting for a foundation model like this for years!

When I try to use the vanilla KV cache (I suppose that's the fastest inference option?), I get this error:

/home/ai/.mconda3/envs/metavoice/lib/python3.11/site-packages/df/io.py:9: UserWarning: `torchaudio.backend.common.AudioMetaData` has been moved to `torchaudio.AudioMetaData`. Please update the import path.
  from torchaudio.backend.common import AudioMetaData
Fetching 5 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 38130.04it/s]
number of parameters: 1239.00M
Traceback (most recent call last):
  File "/home/ai/ml/voice/metavoice/metavoice-src/fam/llm/sample.py", line 690, in <module>
    smodel, llm_first_stage, llm_second_stage = build_models(
                                                ^^^^^^^^^^^^^
  File "/home/ai/ml/voice/metavoice/metavoice-src/fam/llm/sample.py", line 565, in build_models
    llm_first_stage = Model(
                      ^^^^^^
  File "/home/ai/ml/voice/metavoice/metavoice-src/fam/llm/sample.py", line 92, in __init__
    self._init_model()
  File "/home/ai/ml/voice/metavoice/metavoice-src/fam/llm/sample.py", line 159, in _init_model
    raise Exception(
Exception: kv_cache only supported for flash attention 2 but found torch_attn inside model!

I would be super grateful for help, thanks!

@vatsalaggarwal
Contributor

Hey thanks for doing this... Could you open a PR, and I can have a look?

I believe doing this would resolve #1 as well, and allow CPU/macOS folks to use it, so it would be a great contribution!

@vatsalaggarwal
Contributor

It should also resolve #7!

@SinanAkkoyun
Author

Hi, I think this is a misunderstanding: I did not implement anything, I just tried to use the --use_kv_cache arg provided in sample.py.

@vatsalaggarwal
Contributor

Ah, sorry for the confusion...

So, the fastest inference is the default we've currently got, NOT "vanilla"... The current default is flash decoding (ref: https://crfm.stanford.edu/2023/10/12/flashdecoding.html).

Are you running this on an NVIDIA GPU?

We do have implementations for non-flash-attention-based KV caching (e.g. if you don't have an NVIDIA GPU), which is what should get used when you change use_kv_cache, but these haven't been hooked up properly yet, so they will require changes.
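
For reference, here is a minimal sketch of what a vanilla (non-flash) KV cache step can look like with plain torch attention; the class and argument names are illustrative assumptions, not the repo's actual interface.

import torch
import torch.nn.functional as F

class VanillaKVCache:
    # Illustrative per-layer KV cache for plain (non-flash) attention.
    # Tensors are laid out as (batch, n_heads, seq, head_dim); not the repo's actual API.
    def __init__(self, batch, n_heads, max_seq_len, head_dim, dtype=torch.bfloat16, device="cuda"):
        self.k = torch.zeros(batch, n_heads, max_seq_len, head_dim, dtype=dtype, device=device)
        self.v = torch.zeros_like(self.k)
        self.pos = 0  # number of timesteps cached so far

    def append(self, k_new, v_new):
        # Write this step's keys/values into the preallocated buffers.
        t = k_new.shape[2]
        self.k[:, :, self.pos:self.pos + t] = k_new
        self.v[:, :, self.pos:self.pos + t] = v_new
        self.pos += t
        return self.k[:, :, :self.pos], self.v[:, :, :self.pos]

def cached_attention_step(q_new, k_new, v_new, cache):
    # During single-token decoding the new query attends to the whole cached
    # prefix, so no causal mask is needed; PyTorch's SDPA falls back to its
    # non-flash ("torch_attn"-style) kernels when flash attention is unavailable.
    k, v = cache.append(k_new, v_new)
    return F.scaled_dot_product_attention(q_new, k, v)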

@SinanAkkoyun
Author

SinanAkkoyun commented Feb 8, 2024

No problem, thank you for the fast responses! :)

Are you running this on an NVIDIA GPU?

Yes, on a 4090

What is the fastest inference possible? It takes around 8 seconds for me to generate this text as speech:

!!!! USING KV-CACHING ASSUMED TORCH.BFLOAT16
tokens: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1728/1728 [00:08<00:00, 192.42it/s]
batch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.99s/it]
Text: Sir, I successfully managed to move the model parameters to another hard-drive. Should I proceed with process 2?
non-causal batching: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 186.66it/s]
Text: Sir, I successfully managed to move the model parameters to another hard-drive. Should I proceed with process 2?

Saved audio to ...

For text LLMs, a 1B model can run as fast as 300 tokens/second with flash attention. Is it possible to optimize this model for low latency too?

Thank you!

@SinanAkkoyun
Author

SinanAkkoyun commented Feb 8, 2024

In fact, it always takes around 8 seconds, no matter the input text length.

@vatsalaggarwal
Contributor

100% - lots of optimisations to be done here!

re your questions:

  • yes, we generate a fixed number of tokens and don't stop synthesis when we see the "end of generation" token, so it takes the same time regardless of input text length... this could be a good thing to fix at the beginning if someone here gets the chance! (a minimal sketch of such early stopping follows this list)
  • we can also do batching, so synthesising multiple sentences takes a similar time to synthesising one
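
For anyone picking this up, a hedged sketch of what that early stopping could look like in a greedy decode loop; model, EOT_TOKEN_ID and the fixed budget are hypothetical stand-ins, not the actual sample.py interface.

import torch

EOT_TOKEN_ID = 2048       # hypothetical "end of generation" token id
MAX_NEW_TOKENS = 1728     # fixed budget, matching the tqdm logs above

@torch.no_grad()
def generate_until_eot(model, prompt_ids):
    # Stop as soon as every sequence in the batch has emitted the EOT token,
    # instead of always decoding the full fixed budget.
    ids = prompt_ids
    finished = torch.zeros(ids.shape[0], dtype=torch.bool, device=ids.device)
    for _ in range(MAX_NEW_TOKENS):
        logits = model(ids)[:, -1, :]                        # hypothetical: model returns (batch, seq, vocab) logits
        next_ids = torch.argmax(logits, dim=-1, keepdim=True)  # greedy for simplicity
        ids = torch.cat([ids, next_ids], dim=1)
        finished |= next_ids.squeeze(1) == EOT_TOKEN_ID
        if bool(finished.all()):                             # early exit: no wasted decode steps
            break
    return ids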

@sidroopdaska
Contributor

@SinanAkkoyun, out of interest, what are you building with TTS?

@SinanAkkoyun
Author

@vatsalaggarwal
Thank you for the response!
I implemented EOT stopping, here is the PR: #29

(I encountered hallucinations with "How?" and some other short prompts when cloning a different voice embedding: the EOT token is sometimes not generated, which allows hallucinations. But as soon as you release fine-tuning, I see no issue with that.)

Batching is nice, but I am building a real-time voice assistant for my visualization company (@sidroopdaska thanks for your interest); 11labs offers sub-200ms responses, which are necessary even for long sentences, and I would love to see that with your model! :)

I did not look too much into the architecture; is the following possible?

  • Autoregressive streaming decoding of the LLM output instead of waiting for all tokens to be generated (see the sketch after this list)
  • Quantization of the model weights (a GPTQ implementation; which LLM model architecture do you use?)
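
As a sketch of the first point, the decode loop could be turned into a generator that yields tokens as they are produced, so later stages (second-stage model, vocoder, denoiser) can start before generation finishes; all names below are hypothetical, and the loop assumes batch size 1 with greedy sampling.

import torch

@torch.no_grad()
def stream_tokens(model, prompt_ids, max_new_tokens=1728, eot_id=2048):
    # Yield one first-stage token per decode step instead of returning the
    # whole sequence at the end (hypothetical interface).
    ids = prompt_ids
    for _ in range(max_new_tokens):
        logits = model(ids)[:, -1, :]
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
        token = int(next_id.item())
        yield token                  # downstream stages can consume this immediately
        if token == eot_id:
            break

# usage sketch: feed tokens to the next stage as they arrive
# for tok in stream_tokens(first_stage_model, prompt_ids):
#     second_stage_buffer.append(tok)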

With those, the model should be able to run with extremely low latency at extremely high quality. I will do my best to work with you on that.
I am happy to hear back from you both.

@SinanAkkoyun
Author

I am excited to collaborate further. Feel free to connect with me on Discord (sinan2), or let me know if there's another platform you prefer. I am eager to dive in and achieve great things together!

@SinanAkkoyun changed the title from "FlashAttn 2 not working" to "Optimizations" on Feb 9, 2024
@vatsalaggarwal
Contributor

vatsalaggarwal commented Feb 9, 2024

Very nice, thanks, will check it out when I get a sec! Are you able to share the reference you had an issue with? I might be able to help!

Yeah, totally get your point about streaming latency.


So, we've got 4 models stacked on top of each other: 1) a causal 1B LLM, 2) a non-causal 15M LLM, 3) MultiBand Diffusion, 4) DeepFilterNet.

The causal 1B LLM is naturally streamable; however, the current implementation is not faster than real-time, so, as you suggested, that would need to be improved.

The non-causal 15M LLM is super fast, and we also have a streamable version we can push after we've finished testing it.

MultiBand Diffusion is supposed to support streaming by default, but it needs to be played with to enable this.

DeepFilterNet is a super tiny model and shouldn't be a problem I think.


In terms of quantisation and the backbone of the architecture: it's roughly GPT-2 with some changes:

  1. activations from GELU to SwiGLU
  2. LayerNorm to RMSNorm
  3. we add speaker embeddings on top of token embeddings and positional embeddings
  4. (not relevant to quantisation, I think) we use the classifier-free guidance technique to boost the timbre matching of the speaker reference. This means we do two forward passes (in parallel) for each sentence, combine the resulting logits, sample, and repeat autoregressively. (see the sketch after this list)
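
To make points 1-4 concrete, here is a minimal, hedged sketch of those changes (RMSNorm, a SwiGLU MLP, a speaker embedding added to the input, and classifier-free guidance over the logits); module names and the exact guidance combination are assumptions, not the repo's actual code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    # Point 2: RMSNorm in place of LayerNorm.
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class SwiGLU(nn.Module):
    # Point 1: SwiGLU MLP in place of the GPT-2 GELU MLP.
    def __init__(self, dim, hidden):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

def embed_inputs(tok_emb, pos_emb, spk_proj, token_ids, positions, speaker_ref):
    # Point 3: a projected speaker embedding is added on top of the token and
    # positional embeddings (spk_proj is assumed to map the reference to model dim).
    return tok_emb(token_ids) + pos_emb(positions) + spk_proj(speaker_ref).unsqueeze(1)

def cfg_next_token(model, ids_cond, ids_uncond, guidance_scale=2.0):
    # Point 4: classifier-free guidance; two forward passes (conditioned on the
    # speaker reference and unconditioned), logits combined with the standard CFG
    # formula, then one token is sampled. In practice the two passes can be batched.
    logits_cond = model(ids_cond)[:, -1, :]
    logits_uncond = model(ids_uncond)[:, -1, :]
    logits = logits_uncond + guidance_scale * (logits_cond - logits_uncond)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)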

@SinanAkkoyun
Author

SinanAkkoyun commented Feb 9, 2024

I've sent the reference and output WAV to your e-mail. Thank you for the support, but with fine-tuning this should no longer be a problem! Can you say when you will release the fine-tuning script?


That sounds very promising! I am really looking forward to the streaming capabilities; those would be phenomenal given your speech quality! Then quantization would not be needed anymore and one could rely on precise bf16 outputs.


Thank you very much! May I ask why you did not use the Llama 2 architecture as a base (it utilizes SwiGLU and RMSNorm), and how long the training took for over 100k hours of speech?

@vatsalaggarwal
Contributor

I've sent the reference and output WAV to your e-mail. Thank you for the support, but with fine-tuning this should no longer be a problem! Can you say when you will release the fine-tuning script?

thanks! we're still trying to work it out as we have a few things to do... i think it's likely someone could get peft/lora working quicker than we'll be able to push this, but we'll try our best...

That sounds very promising! I am really looking forward to the streaming capabilities; those would be phenomenal given your speech quality! Then quantization would not be needed anymore and one could rely on precise bf16 outputs.

awesome!

Thank you very much! May I ask why you did not use the Llama 2 architecture as a base (it utilizes SwiGLU and RMSNorm), and how long the training took for over 100k hours of speech?

yeah, he had to swap in/out components of Llama 2 to make the various pieces work properly... will try to write about it! On training time, it depended on the number and type of GPUs, so that one is also hard to answer...

@vatsalaggarwal
Contributor

By the way @SinanAkkoyun, I sent you a DM request on Discord.

@SinanAkkoyun
Author

Great, thanks!

@jamesbiederbeck

@SinanAkkoyun, out of interest, what are you building with TTS?

Personally, my wife and I just want free local TTS for reading text in Chrome. Mostly textbooks for school.
