Any tips for speeding up generation? #21

Open
pabloppp opened this issue Mar 11, 2021 · 14 comments

@pabloppp

Because of the autoregressive nature of Transformers, I know they are fairly slow when generating new sequences from scratch, but I was wondering if you have any tips or tricks for faster inference, or plans to add some of the tricks that avoid recomputing everything, like the ones used by Huggingface: https://huggingface.co/blog/accelerated-inference

Thank you very much for your amazing work!

@lucidrains
Owner

@pabloppp Oh hey Pablo! Are you using the repository in production? Yeah, I can make inference fast (by adding caching of keys / values, standard practice)

@pabloppp
Author

pabloppp commented Mar 12, 2021

That would be awesome.
What I have tried, to speed up inference in my custom implementations of autoregressive self-attention, is caching the output of the self-attention at timestep T. Then, at timestep T+1, I pass the full keys/values but only the last element of the query sequence, get the output, and concatenate it with the cache. That way each query can attend to the full previous sequence, but we don't need to recompute attention for all the previous queries when we only need the output at T+1.
It looks something like this:
[screenshot of the implementation: Captura de pantalla 2021-03-12 a las 17 01 32]
But I only achieved a 3x speedup 🤔
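
For reference, here is a minimal sketch of that kind of key / value caching in plain PyTorch. It is not the code from the screenshot above; the module and argument names (e.g. `CachedSelfAttention`) are made up for illustration.

```python
# Minimal sketch of key / value caching for causal self-attention (plain PyTorch).
# Not the code from the screenshot; names like CachedSelfAttention are hypothetical.
import torch
from torch import nn

class CachedSelfAttention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, cache=None):
        # x: (batch, new_tokens, dim); during incremental decoding new_tokens == 1
        b, n, _, h = *x.shape, self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(b, n, h, -1).transpose(1, 2), (q, k, v))

        if cache is not None:
            # prepend cached keys / values so the new query attends to the full history
            k = torch.cat((cache[0], k), dim=2)
            v = torch.cat((cache[1], v), dim=2)
        new_cache = (k, v)

        sim = q @ k.transpose(-1, -2) * self.scale
        if n > 1:
            # causal mask, only needed when more than one query is processed at once
            i, j = sim.shape[-2:]
            causal_mask = torch.ones(i, j, dtype=torch.bool, device=x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, float('-inf'))

        out = (sim.softmax(dim=-1) @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out), new_cache

# usage: run the prompt once, then decode token by token, reusing the cache
attn = CachedSelfAttention(dim=512)
out, cache = attn(torch.randn(1, 16, 512))              # full pass over the prompt
out, cache = attn(torch.randn(1, 1, 512), cache=cache)  # only the newest query is computed
```

This removes the redundant query computation, but each step still attends over the whole cached history, which is why the speedup stays in the single digits rather than scaling further.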

I actually needed to perform autoregressive inference on a very large dataset, and it was taking more than a day even with the above speedup. I am currently doing some weird custom stuff: keeping the Transformer's (cross-)attention layers but replacing the self-attention layers with LSTMs, which are way faster at generating sequences token by token, and with that I achieve the 10x speedup that I needed.

@lucidrains
Owner

@pabloppp The fastest speedup you'll get is to train a vanilla transformer and then fine-tune it with Performer linear attention (https://github.com/lucidrains/performer-pytorch); that's probably the ultimate trick
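
For anyone finding this later, here is a rough sketch of that recipe, assuming performer-pytorch's standalone `SelfAttention` module; the model layout (a `blocks` list with an `.attn` submodule) is hypothetical and needs adapting to your own transformer.

```python
# Rough sketch of "train vanilla, then fine-tune with Performer linear attention".
# Assumes performer-pytorch's standalone SelfAttention; `model.blocks` / `block.attn`
# are hypothetical names -- adapt them to your architecture.
import torch
from performer_pytorch import SelfAttention

def swap_in_performer_attention(model, dim=512, heads=8):
    for block in model.blocks:
        perf_attn = SelfAttention(dim=dim, heads=heads, causal=True)
        # copy over whatever trained weights line up by name and shape
        # (typically the q/k/v and output projections); skip the rest
        src = block.attn.state_dict()
        dst = perf_attn.state_dict()
        dst.update({k: v for k, v in src.items() if k in dst and v.shape == dst[k].shape})
        perf_attn.load_state_dict(dst)
        block.attn = perf_attn
    return model

# model = <load your trained vanilla transformer checkpoint>
# model = swap_in_performer_attention(model)
# ...then fine-tune for a while so the network adapts to the softmax approximation
```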

@pabloppp
Author

What do you mean by 'fine-tune'? Training a vanilla transformer, then replacing the attention layers with Performer attention layers and doing some more training?

@lucidrains
Owner

Yes exactly!

@pabloppp
Author

I will try that, thanks! Any idea about what could be the expected speedup?

@lucidrains
Owner

In short, it will be as fast as if you had an RNN

@lucidrains
Owner

https://arxiv.org/abs/2006.16236
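
That paper (Katharopoulos et al., "Transformers are RNNs") rewrites causal linear attention as a recurrence, which is where the RNN-like generation speed comes from. A minimal sketch of the recurrent update, using the elu(x) + 1 feature map from the paper (single head, unbatched, for clarity):

```python
# Sketch of the recurrent view of causal linear attention from arxiv.org/abs/2006.16236.
# The running state S (and normaliser z) is updated at each step, so generating one
# token costs O(d^2) instead of attending over the whole history.
import torch
import torch.nn.functional as F

def phi(x):
    return F.elu(x) + 1  # positive feature map used in the paper

def linear_attention_step(q_t, k_t, v_t, S, z, eps=1e-6):
    # q_t, k_t, v_t: (dim,) for a single timestep; S: (dim, dim); z: (dim,)
    S = S + torch.outer(phi(k_t), v_t)   # S_t = S_{t-1} + phi(k_t) v_t^T
    z = z + phi(k_t)                     # z_t = z_{t-1} + phi(k_t)
    y_t = (phi(q_t) @ S) / (phi(q_t) @ z + eps)
    return y_t, S, z

# usage: iterate over timesteps, carrying (S, z) just like an RNN hidden state
dim = 64
S, z = torch.zeros(dim, dim), torch.zeros(dim)
for q_t, k_t, v_t in zip(torch.randn(10, dim), torch.randn(10, dim), torch.randn(10, dim)):
    y_t, S, z = linear_attention_step(q_t, k_t, v_t, S, z)
```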

@stas-sl
Contributor

stas-sl commented Mar 27, 2021

@lucidrains thanks for your awesome work! Can you explain a bit why you recommend training a vanilla model and then fine-tuning, rather than training a Performer from scratch?

@tomweingarten

@stas-sl Performers scale very efficiently at longer sequence lengths (roughly 1500+), but they lose that advantage for short sequences. This is especially true for the softmax Performer, which is the version that's directly compatible with vanilla Transformers. For the softmax Performer, the constant costs of calculating the attention can cause it to be even slower than a Transformer during training. Hope that helps!
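
To make the scaling argument concrete, here is a back-of-envelope comparison of just the dominant attention matmuls per head (illustrative only; the constant factors, the random-feature map itself, and the extra features the softmax kernel needs are what push the break-even point out to roughly the 1500-token regime mentioned above):

```python
# Illustrative per-head cost of the attention matmuls only (ignores projections,
# the feature-map computation, and all constant factors).
def softmax_attention_flops(n, d):
    return 2 * n * n * d        # Q K^T and attn @ V are both ~ n^2 * d

def linear_attention_flops(n, d, r):
    return 2 * n * r * d        # phi(K)^T V and phi(Q) @ (phi(K)^T V) are both ~ n * r * d

for n in (512, 2048, 8192):   # quadratic vs linear growth in sequence length n
    print(n, softmax_attention_flops(n, d=64), linear_attention_flops(n, d=64, r=256))
```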

@lucidrains
Owner

@stas-sl What Tom said :)

@pabloppp relevant to your interests https://arxiv.org/abs/2103.13076

@pabloppp
Author

Awesome, thanks!

@DLExpert

DLExpert commented Apr 9, 2021

@lucidrains

> I can make the inference fast (by adding caching of key / values, standard practice)

Can you please explain how to make inference fast by adding caching of keys / values (the standard practice you mention)?

@cloudhan

I am curious why the keys / values can be cached. Aren't the keys and values changed globally (except for the first decoder layer) after a new id is produced?
