[linen.SelfAttention] Enable auto-regressive decoding for prompt length > 1 #1316
Conversation
@levskaya do you have time to take a look at this PR improving cached decoding?
# update cache index to overwrite
updated_cache_indices = query.shape[1] if cur_index == 0 else 1
cache_index.value = cache_index.value + updated_cache_indices
# causal mask for cached decoder self-attention:
# our single query position should only attend to those key
# positions that have already been generated and cached,
# not the remaining zero elements.
mask = combine_masks(
    mask,
This might actually be a bit trickier here, since mask's shape doesn't always fit. E.g., if a prompt of length 6 is passed with a mask of shape [1, 1, 6, 6] (batch_size=1, num_heads=1, length=6) and max_length is 20, we would try to combine the tensors [1, 1, 6, 6] and [1, 1, 6, 20], which could lead to problems.
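To make the shape clash concrete, here is a minimal illustration (not code from the PR; the shapes are taken from the example above, and combine_masks is imported from flax.linen.attention):

```python
import jax.numpy as jnp
from flax.linen.attention import combine_masks

# Shapes from the example above: a causal mask built for a length-6 prompt
# versus a mask built against the cache length max_length=20.
prompt_mask = jnp.ones((1, 1, 6, 6), dtype=bool)   # (batch, heads, q_len, kv_len=6)
cache_mask = jnp.ones((1, 1, 6, 20), dtype=bool)   # (batch, heads, q_len, kv_len=20)

try:
  # combine_masks logically ANDs its inputs after broadcasting; the trailing
  # dimensions 6 and 20 are not broadcast-compatible, so this fails.
  combine_masks(prompt_mask, cache_mask)
except (ValueError, TypeError) as err:
  print("masks do not broadcast:", err)
```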
Sorry for the delay! In practice we probably want to handle the cache initialization in the sampler loop itself, or in a separate mode, rather than having shape-dependent or even value-dependent logic in the core attention layer.
There are two ways people handle "prompts" in autoregressive decoders:
- the fast way: take the initial prompt region and just run through it token by token, forcing the decoding to always choose the pre-existing value. This still imposes autoregressive/causal structure on the keys and values generated along the prompt region, but it can be implemented with a tiny change to the core decoder/sampling algorithm; in fact, the temperature sampler in the "lm1b" example implements this "shortcut" logic (sketched just after this list): https://github.com/google/flax/blob/master/examples/lm1b/temperature_sampler.py#L101
- the more general way: introduce a third "mode" of operation for the attention layer and the transformer. Right now there are "normal" and "decode", but you can have a third "cache-init" mode, where you run the transformer normally across the prompt/prefix region but still stuff the keys and values for that region into the cache, and update the cache index to the length of the prefix region. Though this takes a bit more code, it's more efficient even for causal forced decoding, and it also allows one to use non-causal attention within the prompt region, effectively treating it as a small transformer encoder. Most people refer to this as a "prefix-LM".
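For illustration, a minimal sketch of the forced-decoding shortcut from the first bullet (this is not the lm1b sampler itself; decode_step, sample_fn, and the padding convention are assumptions made for the example):

```python
import jax.numpy as jnp

def forced_decode(decode_step, prompt, max_len, sample_fn, bos_id=0):
  """Causal forced decoding: run the model one position at a time and,
  inside the prompt region, override the sampled token with the known
  prompt token, so the cache is still filled step by step."""
  # decode_step(tokens) is assumed to run a single-position decode step
  # (updating the attention cache internally) and return logits of shape
  # [batch, 1, vocab]; prompt is [batch, max_len] with 0 used as padding.
  batch = prompt.shape[0]
  out = jnp.zeros((batch, max_len), dtype=jnp.int32)
  cur = jnp.full((batch, 1), bos_id, dtype=jnp.int32)
  for i in range(max_len):
    logits = decode_step(cur)
    sampled = sample_fn(logits)                 # [batch, 1]
    in_prompt = prompt[:, i:i + 1] > 0          # still inside the prompt?
    cur = jnp.where(in_prompt, prompt[:, i:i + 1], sampled)
    out = out.at[:, i].set(cur[:, 0])
  return out
```

The linked sampler writes this loop with lax.while_loop so it can be compiled; the sketch above only shows the forcing idea.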
I think we'd be happy to accept a patch to add a prefix-LM init mode (a rough sketch of the idea follows), or I might be able to get to it myself sometime in the foreseeable future. However, a lot of people working on pure decoders are often happy with the results of a causal forced-decoding approach, which is already supported.
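To make the cache-init idea a bit more concrete, here is a hypothetical sketch (this is not Flax's actual API; the cache layout [batch, max_length, num_heads, head_dim] and the dict keys are assumptions):

```python
import jax.numpy as jnp

def init_cache_with_prefix(cache, prefix_key, prefix_value):
  """Hypothetical cache-init step: keys/values computed by one normal
  forward pass over the prefix are written into the decode cache, and the
  cache index is advanced to the prefix length."""
  prefix_len = prefix_key.shape[1]
  return {
      'cached_key': cache['cached_key'].at[:, :prefix_len].set(prefix_key),
      'cached_value': cache['cached_value'].at[:, :prefix_len].set(prefix_value),
      'cache_index': jnp.array(prefix_len, dtype=jnp.int32),
  }
```

Subsequent decode steps would then proceed exactly as in the single-token case, starting from position prefix_len.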
Just let us know if either of these approaches would work for you, or if I could be more clear on some point here. :)
if expected_shape != query.shape:
cur_index = cache_index.value
# first input prompt can have sequence length > 1
if cur_index > 0 and cache_expected_shape != query.shape:
Where is cache_expected_shape defined?
if expected_shape != query.shape:
cur_index = cache_index.value
# first input prompt can have sequence length > 1
if cur_index > 0 and cache_expected_shape != query.shape:
cur_index > 0 <-- we can't use value-dependent Python control flow in JAX, as this value is an abstract tracer during tracing and compilation.
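A small standalone illustration of the tracer problem (unrelated to the PR code; the jnp.where variant is just one possible workaround):

```python
import jax
import jax.numpy as jnp

@jax.jit
def bad(cur_index):
  # Under jit, cur_index is an abstract tracer; this Python `if` needs a
  # concrete boolean and therefore raises a ConcretizationTypeError.
  if cur_index > 0:
    return cur_index + 1
  return cur_index

try:
  bad(jnp.array(3))
except jax.errors.ConcretizationTypeError:
  print("value-dependent Python control flow fails under jit")

@jax.jit
def good(cur_index):
  # A traced select keeps the branch decision inside the computation.
  return jnp.where(cur_index > 0, cur_index + 1, cur_index)
```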
What does this PR do?
I think it is currently not possible to do auto-regressive decoding with flax.linen.SelfAttention if the input prompt is longer than 1. For example, code along the lines of the sketch below works fine if prompt_length is 1, but fails if prompt_length is 2; see issue #1317.
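Roughly the kind of reproduction meant here (a sketch rather than the exact snippet from the issue; it assumes the standard linen decode-cache pattern, where init with a full-length input creates the cache):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

max_length, features = 20, 4
attn = nn.SelfAttention(num_heads=2, qkv_features=features, decode=True)

# Initializing with a full-length input creates the autoregressive cache.
variables = attn.init(jax.random.PRNGKey(0),
                      jnp.ones((1, max_length, features)))

# A prompt of length 1 decodes fine, one step at a time:
_, cache_updates = attn.apply(variables, jnp.ones((1, 1, features)),
                              mutable=['cache'])

# A prompt of length 2 in a single call trips the cached-decode shape check:
try:
  attn.apply(variables, jnp.ones((1, 2, features)), mutable=['cache'])
except ValueError as err:
  print("prompt length > 1 fails:", err)
```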
This PR corrects this behavior by allowing the input length to be > 1 the very first time the cache is used. For decoder-only auto-regressive models such as GPT-1/2/3, it is very important to be able to pass input prompts longer than 1. Maybe this is already possible and I'm using it incorrectly?
If this PR looks like a good solution to you, then I'm more than happy to add a couple of tests to check the behavior @avital @marcvanzee
Checklist
- checks if that's the case).
- discussion (please add a link).
- documentation guidelines.
- (No quality testing = no merge!)