Support nSlices > nKvHeads #70

b4rtaz commented May 27, 2024

After the attention layers were split across all nodes, I missed some implications of that change.


Long story short: to calculate the attention for a single Q head, I need the whole corresponding head from the K output. For Q head x, I need the whole K head floor(x / (nHeads / nKvHeads)) to calculate the result.

For example Llama 3 8B:

💡 dim: 4096
💡 nHeads: 32
💡 nKvHeads: 8

Q head 0  => floor(  0 / ( 32 / 8) ) => K head 0
Q head 1  => floor(  1 / ( 32 / 8) ) => K head 0
Q head 2  => floor(  2 / ( 32 / 8) ) => K head 0
...
Q head 8  => floor(  8 / ( 32 / 8) ) => K head 2
Q head 9  => floor(  9 / ( 32 / 8) ) => K head 2
...
Q head 31 => floor( 31 / ( 32 / 8) ) => K head 7
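
For clarity, here is a minimal C++ sketch of that mapping (the helper name `kvHeadForQHead` is hypothetical, not a function from this repo):

```cpp
#include <cstdio>

// Hypothetical helper: maps a Q head index to the K/V head that holds
// the keys/values it attends over (the standard GQA grouping).
int kvHeadForQHead(int qHead, int nHeads, int nKvHeads) {
    int groupSize = nHeads / nKvHeads; // Q heads sharing one K/V head (4 here)
    return qHead / groupSize;          // integer division == floor
}

int main() {
    const int nHeads = 32, nKvHeads = 8; // Llama 3 8B
    for (int q : {0, 1, 2, 8, 9, 31})
        printf("Q head %2d => K head %d\n", q, kvHeadForQHead(q, nHeads, nKvHeads));
    return 0;
}
```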

Because of this, it is currently not possible to split the model across more than nKvHeads nodes.

The same problem exists for the V layer.


How could this be fixed?

1. Synchronize missing outputs

For setups where nSlices > nKvHeads, a new synchronization step could be introduced. This step would synchronize the missing K/V outputs across nodes. Of course, synchronization is the slowest part of Distributed Llama.
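
A toy in-process sketch of such a gather step, assuming each node has already computed its kvDim0-wide slice of the K output (`gatherK` is illustrative only; the real step would run over the node sockets):

```cpp
#include <vector>
#include <algorithm>

// Assembles the full K output from per-node slices so every node can
// read the whole K heads its Q heads require. Distributed, this is one
// extra all-gather per attention block, on top of the existing syncs.
std::vector<float> gatherK(const std::vector<std::vector<float>>& slices, int kvDim0) {
    std::vector<float> fullK(slices.size() * kvDim0);
    for (std::size_t node = 0; node < slices.size(); node++)
        std::copy(slices[node].begin(), slices[node].end(),
                  fullK.begin() + node * kvDim0);
    return fullK;
}
```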

2. Redundancy

Redundancy could be introduced for the K/V layers. These layers would be split with alignment to headSize. With this approach no synchronization is needed, and the amount of redundant computation seems small (headSize - kvDim0 outputs per node).

For example Llama 3 8B:

headSize = dim / nHeads = 128
kvDim = (dim * nKvHeads) / nHeads = 1024

nSlices = 16
kvDim0 = kvDim / nSlices = 64
redundancy = 128 - 64 = 64 outputs of K & V

nSlices = 32
kvDim0 = kvDim / nSlices = 32
redundancy = 128 - 32 = 96 outputs of K & V
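
A small sketch that reproduces this arithmetic (the alignment rule, rounding each slice up to a whole head, is my reading of the proposal):

```cpp
#include <cstdio>

int main() {
    const int dim = 4096, nHeads = 32, nKvHeads = 8; // Llama 3 8B
    const int headSize = dim / nHeads;               // 128
    const int kvDim = (dim * nKvHeads) / nHeads;     // 1024

    for (int nSlices : {16, 32}) {
        int kvDim0 = kvDim / nSlices; // unaligned slice width per node
        // Align the slice up to a whole K/V head so no synchronization
        // is needed; valid here because kvDim0 < headSize.
        int alignedKvDim0 = headSize;
        int redundancy = alignedKvDim0 - kvDim0; // re-computed outputs
        printf("nSlices=%d kvDim0=%d redundancy=%d outputs of K & V\n",
               nSlices, kvDim0, redundancy);
    }
    return 0;
}
```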