
Add group kv support and fix past kv from cache #2263

Merged: 4 commits merged into lm-sys:main on Aug 23, 2023

Conversation

@siddartha-RE (Contributor) commented on Aug 18, 2023

Why are these changes needed?

The primary goal is to enable support for grouped key-value (GQA) attention in Llama 2 models, specifically the 70B model.
In addition, this change:

  • Simplifies flash-attn usage by avoiding most transpose operations (see the sketch after this list)
  • Avoids stacking q/k/v when possible by using the unstacked flash-attn variants
  • Fixes behavior in the presence of caching, where KV lengths can differ from Q lengths
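
For illustration, a minimal sketch of the layout this enables, assuming a flash-attn 2.x build that accepts fewer KV heads than query heads directly; the shapes and head counts below are illustrative, not the PR's exact code:

```python
# Sketch only: requires a CUDA GPU and the flash-attn package (fp16/bf16 inputs).
import torch
from flash_attn import flash_attn_func

bsz, q_len, head_dim = 2, 128, 128
num_heads, num_kv_heads = 64, 8  # Llama-2 70B style grouped KV: 8 KV heads for 64 query heads

q = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(bsz, q_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(bsz, q_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# flash_attn_func keeps the [bsz, seq_len, heads, head_dim] layout end to end,
# so no transpose to [bsz, heads, seq_len, head_dim] and no repeat_kv are needed.
out = flash_attn_func(q, k, v, causal=True)  # [bsz, q_len, num_heads, head_dim]
```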

Related issue number (if applicable)

Addresses #2229

Checks

  • I've run format.sh to lint the changes in this PR.
  • I've included any doc changes needed.
  • I've made sure the relevant tests are passing (if applicable).

@tmm1 (Contributor) commented on Aug 19, 2023

Hi @siddartha-RE, I was looking into this the other day too (inference support for FA2 Llama 2, i.e. past_kv/cache) and ended up using flash_attn_kvpacked_func. Does that seem right to you?

https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/llama_attn_hijack_flash.py#L204-L209
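
For context, a rough sketch of the kvpacked call shape referenced above (illustrative dimensions, not the axolotl code verbatim):

```python
# Sketch only: requires a CUDA GPU and the flash-attn package.
import torch
from flash_attn import flash_attn_kvpacked_func

bsz, q_len, kv_len, n_heads, n_kv_heads, head_dim = 2, 1, 512, 64, 8, 128
q = torch.randn(bsz, q_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
# Keys and values stacked along a size-2 dim; kv_len can exceed q_len when a cache is used.
kv = torch.randn(bsz, kv_len, 2, n_kv_heads, head_dim, dtype=torch.float16, device="cuda")

out = flash_attn_kvpacked_func(q, kv, causal=True)  # [bsz, q_len, n_heads, head_dim]
# How `causal=True` is aligned when q_len != kv_len is exactly the concern raised below.
```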

@siddartha-RE (Contributor, Author)

No, there is an issue with the causal-mask alignment. I opened an issue in the flash-attention repo; the owner has already responded and is planning to adjust the behavior. Bump the issue if you want to let him know there is more interest.

This PR has a pretty comprehensive check of correctness.

@merrymercy (Member)

@siddartha-RE Thanks for the contribution.

If this is comprehensively tested, can we make it the default implementation and replace the old monkey patch with this one?

@siddartha-RE (Contributor, Author)

The test runs the attention block against:

  • base HF implementation
  • current monkey patch
  • new monkey patch

The current and new patches produce exactly identical outputs, and both agree with the HF implementation within tolerance. I have also tested training the 70B model with this, and the loss looks correct, so the grouped-KV handling appears to be right. I could add a test with grouped KV enabled and compare the new patch against HF to further verify the implementation.

FYI, I have an issue open on the flash-attn code base,
Dao-AILab/flash-attention#466,
regarding using flash-attn for the use_cache case. I am considering falling back to the old implementation when past_kv_len >> q_len, as I think it may actually be faster since it would not be quadratic in q_len + past_kv_len; a sketch of that heuristic follows.
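
A hypothetical sketch of that fallback heuristic (the function name and the threshold are illustrative, not code from this PR):

```python
def choose_attention_impl(q_len: int, past_kv_len: int, ratio: int = 8) -> str:
    """Pick the attention path for one forward pass.

    Per the reasoning above: during cached decoding (short q, long cache) the plain
    softmax path costs roughly q_len * (q_len + past_kv_len), whereas running the
    flash path over the full concatenated sequence scales with (q_len + past_kv_len) ** 2.
    """
    if past_kv_len > ratio * q_len:
        return "eager"  # fall back to the original (non-flash) implementation
    return "flash"      # use the flash-attn kernel for training / long prompts
```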

@siddartha-RE (Contributor, Author)

One more comment: I tested the properly optimized implementation of use_cache
(siddartha-RE@6ad40e9)
against a build of flash-attn from this PR,
Dao-AILab/flash-attention#436,
which fixes the causal handling of the q_len != k_len case,
and confirmed that the test I added in the file still passes.
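
To illustrate why that causal fix matters, here is a small mask-only example in plain PyTorch (a sketch of the two alignments, not flash-attn internals): with a KV cache, the new queries sit at the end of the sequence, so the causal mask needs to be aligned to the bottom-right corner of the q_len x kv_len grid.

```python
import torch

q_len, kv_len = 2, 6  # e.g. 4 cached positions plus 2 new tokens

# Top-left aligned causal mask: row i may attend to keys 0..i.
# With a cache this wrongly hides most of the cached keys from the new tokens.
top_left = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool))

# Bottom-right aligned causal mask: row i may attend to keys 0..(kv_len - q_len + i),
# which is the behavior needed when q_len != kv_len due to past_key_values.
bottom_right = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=kv_len - q_len)

print(top_left.int())      # each new token sees only the first one or two keys
print(bottom_right.int())  # each new token sees all cached keys plus its causal prefix
```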

@merrymercy (Member)

@siddartha-RE Is this ready for merge? Which one do you prefer?

  1. Merge this first as an optional monkey patch and let more people test it.
  2. Directly replace this line
     replace_llama_attn_with_flash_attn()
     with your implementation in this PR.

@siddartha-RE (Contributor, Author)

This is ready for merge. It is well tested, so I am comfortable updating the train hook, but I suggest keeping both versions for now so that people can compare them if they see issues.

I have a change that further fixes the inference path, but it will need to wait until the underlying flash-attention fix for handling the causal flag when q_len != kv_len is released.

@merrymercy merged commit da2b80b into lm-sys:main on Aug 23, 2023
1 check passed
@merrymercy (Member)

@siddartha-RE Thanks! It is merged.

A review comment was left on this hunk of the diff:

```python
):
    # [bsz, seq_len]
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
```

Is this concatenation necessary? The attention_mask passed in should have shape (bsz, kv_len), and past_kv_len is already included in kv_len.
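
A hedged guess at the kind of concatenation being questioned (the hunk above is truncated, so this reconstructs the pattern rather than the PR's exact code):

```python
import torch

bsz, q_len, past_kv_len = 2, 1, 16
attention_mask = torch.ones(bsz, q_len, dtype=torch.bool)  # mask covering only the new tokens

if past_kv_len > 0 and attention_mask is not None:
    # Prepend ones for the cached positions so the mask spans the full kv_len.
    attention_mask = torch.cat(
        [torch.ones(bsz, past_kv_len, dtype=attention_mask.dtype), attention_mask],
        dim=-1,
    )  # [bsz, past_kv_len + q_len]

# The reviewer's point: if the caller already passes a (bsz, kv_len) mask that
# covers the cached positions, this concatenation is redundant and would produce
# a mask longer than kv_len.
```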
