Transformer - High VRAM, context length #39

Open
MarcusLoppe opened this issue Dec 27, 2023 · 4 comments

MarcusLoppe commented Dec 27, 2023

Hello again, this issue is for next year 😃

When training the transformer, I used the following config:

transformer = MeshTransformer(
    autoencoder,
    dim = 128,
    attn_depth = 4,
    attn_dim_head = 8,
    attn_heads = 4,
    coarse_pre_gateloop_depth = 1,  # 6
    fine_pre_gateloop_depth = 0,    # 4
    max_seq_len = max_seq,
    gateloop_use_heinsen = False,
    condition_on_text = True
)

This resulted in a transformer with 22M parameters.
I then tried to train it on a mesh with 6,206 faces, which is 37,236 tokens (6,206 * 6).
When I fed it the face codes (1, 6206, 128), it used about 11 GB of VRAM, and by the end of the forward pass it was using about 20 GB.
With a 188M-parameter transformer (dim = 256), it used 50 GB of VRAM.

My suggestion is to implement sliding-window / local attention, since most long-context LLMs use it and it seems to work well.
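Something like this, just as a rough sketch in plain PyTorch (the window size is made up, and this isn't the actual meshgpt-pytorch attention code):

import torch
import torch.nn.functional as F

def sliding_window_attention(q, k, v, window_size = 512):
    # q, k, v: (batch, heads, seq_len, dim_head)
    seq_len = q.shape[-2]
    idx = torch.arange(seq_len, device = q.device)
    # causal + local: token i only attends to the previous `window_size` tokens
    dist = idx[:, None] - idx[None, :]
    mask = (dist >= 0) & (dist < window_size)
    return F.scaled_dot_product_attention(q, k, v, attn_mask = mask)

(This only illustrates the masking pattern; a real implementation would chunk the sequence so the full seq x seq mask is never materialized.)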

Alternatively, you could create an embedding of the generated tokens and concatenate it with the text conditioner embedding, so the cross attention can be aware of previous tokens as well.
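Very roughly, something like this (the pooling and the names are just placeholders, not the library's API):

import torch
import torch.nn.functional as F

def build_cross_attn_context(text_embed, token_embed, chunk_size = 256):
    # text_embed:  (batch, text_len, dim) - text conditioner embeddings
    # token_embed: (batch, seq_len, dim)  - embeddings of previously generated tokens
    batch, seq_len, dim = token_embed.shape
    pad = (-seq_len) % chunk_size
    token_embed = F.pad(token_embed, (0, 0, 0, pad))
    # mean-pool fixed-size chunks so the extra context stays short
    pooled = token_embed.reshape(batch, -1, chunk_size, dim).mean(dim = 2)
    # cross attention would then attend over text + pooled token history
    return torch.cat((text_embed, pooled), dim = 1)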

It would also be worth checking whether Grouped-Query Attention is beneficial :)
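For reference, a bare-bones version of what grouped-query attention does (plain PyTorch, not the x-transformers implementation):

import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    # q: (batch, q_heads, seq, dim_head); k, v: (batch, kv_heads, seq, dim_head)
    # each group of query heads shares one key/value head, shrinking the kv cache
    groups = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(groups, dim = 1)
    v = v.repeat_interleave(groups, dim = 1)
    return F.scaled_dot_product_attention(q, k, v, is_causal = True)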

For reference, here is the decoder call where the cross-attention context kwargs are passed in:

attended_face_codes, coarse_cache = self.decoder(
    face_codes,
    cache = coarse_cache,
    return_hiddens = True,
    **attn_context_kwargs
)


lucidrains commented Dec 27, 2023

@MarcusLoppe yes indeed, you are correct on both counts. local attention is tricky to handle with a kv cache

grouped query attention is also already available in x-transformers, which this lib is using

ok, no more AI stuff until after the new years 😆

MarcusLoppe commented:


@lucidrains

Have you given any thought to implementing Flash Attention 2? 😄 It seems like it would be a great way to speed up the transformer's training and inference times.
https://github.com/kyegomez/FlashAttention20/tree/main



lucidrains commented Jan 4, 2024

@MarcusLoppe flash attention 2 will make it into the next release of pytorch, so no need!
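for reference, pytorch's F.scaled_dot_product_attention already dispatches to a flash kernel when the inputs allow it (rough sketch, shapes made up):

import torch
import torch.nn.functional as F

# fp16 tensors on cuda with a supported head dim can take the flash path
q = torch.randn(1, 8, 4096, 64, device = 'cuda', dtype = torch.float16)
k = torch.randn(1, 8, 4096, 64, device = 'cuda', dtype = torch.float16)
v = torch.randn(1, 8, 4096, 64, device = 'cuda', dtype = torch.float16)

# restricting the backends raises an error if the flash kernel cannot be used
with torch.backends.cuda.sdp_kernel(enable_flash = True, enable_math = False, enable_mem_efficient = False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)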

claudiomartella commented:
Will it need a code change?
