Handle n_kv_heads for fused layers #54

Closed
casper-hansen opened this issue Sep 15, 2023 · 0 comments
Labels
enhancement New feature or request

casper-hansen (Owner) commented Sep 15, 2023

Currently, the TinyLlama model fails with fused layers. The fused-layer shapes are not set up correctly for this model because `n_kv_heads` is not handled.

The general calculation for the size of the fused QKV output seems to be:
`n_tokens * (n_heads + n_kv_heads * 2) * (hidden_size // n_heads)`

This should fix the problem for all models that use grouped-query attention (GQA), i.e. models where `n_kv_heads` is smaller than `n_heads`.
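A minimal sketch of that calculation, assuming the fused layer packs Q, K and V back to back along the last dimension. `AttnShape`, `fused_qkv_numel` and `qkv_split_sizes` are hypothetical names for illustration, not functions from this project, and the example values assume a TinyLlama-1.1B-style config (hidden size 2048, 32 query heads, 4 KV heads).

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class AttnShape:
    """Hypothetical container for the attention dimensions of one layer."""
    hidden_size: int
    n_heads: int
    n_kv_heads: Optional[int] = None  # None = no GQA, one KV head per query head

    @property
    def kv_heads(self) -> int:
        return self.n_heads if self.n_kv_heads is None else self.n_kv_heads

    @property
    def head_dim(self) -> int:
        return self.hidden_size // self.n_heads


def fused_qkv_numel(n_tokens: int, shape: AttnShape) -> int:
    # n_tokens * (n_heads + n_kv_heads * 2) * (hidden_size // n_heads)
    return n_tokens * (shape.n_heads + shape.kv_heads * 2) * shape.head_dim


def qkv_split_sizes(shape: AttnShape) -> Tuple[int, int, int]:
    # Per-token widths of Q, K and V, assuming they are packed back to back.
    q = shape.n_heads * shape.head_dim
    kv = shape.kv_heads * shape.head_dim
    return q, kv, kv


tiny = AttnShape(hidden_size=2048, n_heads=32, n_kv_heads=4)
print(fused_qkv_numel(1, tiny))  # 2560
print(qkv_split_sizes(tiny))     # (2048, 256, 256)
```

If `n_kv_heads` were ignored (treated as equal to `n_heads`), the same config would give 3 * 2048 = 6144 elements per token instead of 2560, so any reshape built on the old assumption no longer matches the actual QKV output.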

EDIT: The shapes in the forward pass of GEMM may also need fixing.

`print(x.reshape(-1, x.shape[-1]).shape[1] / self.scales.shape[0])` prints 4 for TinyLlama, but this value needs to be a multiple of 32.
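The check in that print statement can be turned into an explicit assertion. A minimal sketch, assuming `x` and `self.scales` are represented only by their shapes; `scale_group_ratio` and `check_scale_group_ratio` are hypothetical helpers, and the shapes in the usage line are made up for illustration. Note that `x.reshape(-1, x.shape[-1]).shape[1]` is simply `x.shape[-1]`.

```python
from typing import Sequence


def scale_group_ratio(x_shape: Sequence[int], scales_shape: Sequence[int]) -> int:
    # x.reshape(-1, x.shape[-1]).shape[1] is just the last dimension of x.
    in_features = x_shape[-1]
    n_scale_rows = scales_shape[0]
    if in_features % n_scale_rows != 0:
        raise ValueError(
            f"in_features={in_features} is not divisible by {n_scale_rows} scale rows"
        )
    return in_features // n_scale_rows


def check_scale_group_ratio(
    x_shape: Sequence[int], scales_shape: Sequence[int], multiple: int = 32
) -> int:
    # Fail early instead of printing: the ratio must be a multiple of `multiple`.
    ratio = scale_group_ratio(x_shape, scales_shape)
    if ratio % multiple != 0:
        raise ValueError(f"ratio {ratio} is not a multiple of {multiple}")
    return ratio


print(check_scale_group_ratio((1, 5, 4096), (32, 12288)))  # 128
```

A ratio of 4, as reported for TinyLlama, would raise `ValueError` here rather than failing later inside the kernel.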

EDIT: It is probably best to take inspiration from Llama / CodeLlama:

Without `n_kv_heads` (Llama):
https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L76

With `n_kv_heads` (CodeLlama):
https://github.com/facebookresearch/codellama/blob/427d6ac90f0b7db206bc4c62f4c5d38f92ca4d10/llama/model.py#L90
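A minimal pure-Python sketch of the GQA idea behind the CodeLlama version, assuming it follows the Llama 2 pattern of defaulting `n_kv_heads` to `n_heads` and repeating each KV head `n_heads // n_kv_heads` times (done there with tensors in `repeat_kv`). `resolve_kv_heads` and `repeat_kv_heads` are hypothetical helpers that use plain lists instead of tensors, so they only illustrate the head mapping.

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def resolve_kv_heads(n_heads: int, n_kv_heads: Optional[int]) -> int:
    # Without GQA, n_kv_heads is unset and every query head has its own KV head.
    kv = n_heads if n_kv_heads is None else n_kv_heads
    if n_heads % kv != 0:
        raise ValueError(f"n_heads={n_heads} is not divisible by n_kv_heads={kv}")
    return kv


def repeat_kv_heads(kv_heads: List[T], n_rep: int) -> List[T]:
    # Repeat each KV head n_rep times in order (repeat_interleave semantics),
    # so query head i ends up paired with KV head i // n_rep.
    # Repeated entries are the same objects, not copies.
    return [head for head in kv_heads for _ in range(n_rep)]


n_heads = 32
n_kv = resolve_kv_heads(n_heads, 4)
n_rep = n_heads // n_kv  # 8 query heads share each KV head
expanded = repeat_kv_heads([f"kv{i}" for i in range(n_kv)], n_rep)
print(len(expanded), expanded[9])  # 32 kv1
```

Repeating K/V up to `n_heads` lets the rest of the attention code keep assuming one KV head per query head; the fused layer only has to size and split the QKV projection with the real `n_kv_heads`.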

casper-hansen added the enhancement (New feature or request) label Sep 15, 2023