[linen.SelfAttention] Enable auto-regressive decoding for prompt length > 1 #1316

Conversation

@patrickvonplaten (Contributor) commented on May 13, 2021

What does this PR do?

I think it is currently not possible to do auto-regressive decoding with flax.linen.SelfAttention if the input prompt is longer than 1. For example, code of the following kind works fine if prompt_length is 1, but fails if prompt_length is 2; see issue #1317.
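
A minimal sketch of that kind of reproduction, assuming the flax.linen API of the time (shapes, hyperparameters, and variable names are illustrative, not the exact code from the issue):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

prompt_length = 2   # works when this is 1, fails when it is 2
max_length = 8      # size of the auto-regressive decode cache
features = 4

attn = nn.SelfAttention(num_heads=1, qkv_features=features, decode=True)

# Initializing on a max_length-long dummy input also sets up the decode cache.
variables = attn.init(jax.random.PRNGKey(0), jnp.ones((1, max_length, features)))
params, cache = variables['params'], variables['cache']

# Feeding the whole prompt in one call trips the single-token shape check
# in the cached-decode path when prompt_length > 1.
prompt = jnp.ones((1, prompt_length, features))
out, mutated = attn.apply(
    {'params': params, 'cache': cache}, prompt, mutable=['cache'])
```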

This PR corrects this behavior by allowing the input length to be > 1 the very first time the cache is used.

For decoder-only auto-regressive models such as GPT-1 through GPT-3, it is very important to be able to pass input prompts of length > 1. Maybe this is already possible and I'm just using it incorrectly?

If this PR looks like a good solution to you, then I'm more than happy to add a couple of tests to check the behavior. @avital @marcvanzee

Checklist

  • This PR fixes a minor issue (e.g. a typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a GitHub issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

google-cla bot added the cla: yes label on May 13, 2021
patrickvonplaten changed the title from "finish" to "[linen.SelfAttention] Enable auto-regressive decoding for prompt length > 1" on May 13, 2021
@avital (Contributor) commented on May 13, 2021

@levskaya do you have time to take a look at this PR improving cached decoding?


```python
# update cache index to overwrite
updated_cache_indices = query.shape[1] if cur_index == 0 else 1
cache_index.value = cache_index.value + updated_cache_indices
# causal mask for cached decoder self-attention:
# our single query position should only attend to those key
# positions that have already been generated and cached,
# not the remaining zero elements.
mask = combine_masks(
    mask,
```
patrickvonplaten (Contributor, Author) commented on this change:

This might actually be a bit trickier here, since mask's shape doesn't always fit. E.g. if a prompt of length 6 is passed with a mask of shape [1, 1, 6, 6] (batch_size=1, num_heads=1, length=6) and max_length is 20, we would try to combine a [1, 1, 6, 6] tensor with a [1, 1, 6, 20] tensor, which could lead to problems.
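
For concreteness, a small sketch of the shape clash described above, using flax.linen.combine_masks and purely illustrative shapes:

```python
import jax.numpy as jnp
from flax.linen import combine_masks

# [batch, heads, query_len, key_len]: a causal mask over the length-6 prompt
# vs. a cache mask that spans the full max_length of 20.
prompt_mask = jnp.ones((1, 1, 6, 6), dtype=bool)
cache_mask = jnp.ones((1, 1, 6, 20), dtype=bool)

# combine_masks takes the logical AND of its arguments; the last dimensions
# (6 vs. 20) don't broadcast, so this raises a shape error.
combined = combine_masks(prompt_mask, cache_mask)
```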

marcvanzee linked an issue on May 14, 2021 that may be closed by this pull request

@levskaya (Collaborator) left a comment

Sorry for the delay! So in practice we probably want to handle the cache initialization in the sampler loop itself, or in a separate mode, rather than having shape-dependent or even value-dependent logic in the core attention layer.

There are two ways people handle "prompts" in autoregressive decoders:

  • the fast way: take the initial prompt region and just run through it token by token, forcing the decoding to always choose the pre-existing value (see the sketch after this list). This still imposes autoregressive/causal structure on the keys and values generated along the prompt region, but it can be implemented with a tiny change to the core decoder/sampling algorithm - in fact, in the "lm1b" example the temperature sampler implements this "shortcut" logic: https://github.com/google/flax/blob/master/examples/lm1b/temperature_sampler.py#L101
  • the more general way: introduce a third "mode" of operation for the attention layer and the transformer. Right now there's "normal" and "decode", but you can have a third "cache-init" mode, where you run the transformer normally across the prompt/prefix region but still stuff the keys and values for that region into the cache, and update the cache index to the length of the prefix region (a rough sketch of this appears at the end of this comment). Though this takes a bit more code, it's more efficient even for causal forced decoding, and it also allows one to use non-causal attention within the prompt region, effectively treating that as a small transformer encoder in practice. Most people refer to this as a "prefix-LM".
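
A hedged sketch of the first, "forced decoding" approach, in the spirit of the lm1b temperature sampler linked above (the function and names here are simplified and illustrative, not the actual sampler code):

```python
import jax.numpy as jnp

def force_prompt_token(i, prompt, sampled_token):
    """One step inside a sampling loop: keep the prompt token wherever one exists.

    prompt is [batch, max_length], zero-padded past the real prompt. While
    position i still lies inside the prompt, the sampled token is overwritten
    with the pre-existing value, so 'real' sampling only starts after the prompt.
    """
    prompt_token = prompt[:, i]
    return jnp.where(prompt_token > 0, prompt_token, sampled_token)
```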

I think we'd be happy to accept a patch to add a prefix-LM init mode, or I might be able to get to it myself sometime in the foreseeable future. However, a lot of people working on pure decoders are often happy with the results of a causal forced-decoding approach, which is already supported.

Just let us know if either of these approaches would work for you, or if I could be more clear on some point here. :)
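
And a rough sketch of what the second, "cache-init" / prefix-LM mode could do. This is a hypothetical illustration, not existing Flax API; the buffer names only mirror the cached-decode variables in MultiHeadDotProductAttention:

```python
import jax.numpy as jnp

def init_cache_from_prefix(cached_key, cached_value, key, value):
    """Hypothetical cache-init step: the prefix is attended to normally
    (possibly even non-causally), but its keys/values are also written into
    the decode cache and the cache index is advanced to the prefix length.

    cached_key / cached_value: [batch, max_length, num_heads, head_dim] buffers.
    key / value:               [batch, prefix_length, num_heads, head_dim].
    """
    prefix_length = key.shape[1]
    cached_key = cached_key.at[:, :prefix_length, :, :].set(key)
    cached_value = cached_value.at[:, :prefix_length, :, :].set(value)
    # Subsequent single-token decode steps then continue from this index.
    cache_index = jnp.asarray(prefix_length, dtype=jnp.int32)
    return cached_key, cached_value, cache_index
```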

```python
if expected_shape != query.shape:
cur_index = cache_index.value
# first input prompt can have sequence length > 1
if cur_index > 0 and cache_expected_shape != query.shape:
```
levskaya (Collaborator) commented on this change:

where is cache_expected_shape defined?

levskaya (Collaborator) commented on the same lines:

cur_index > 0 <-- we can't use value-dependent Python control flow in JAX, as this value is an abstract tracer during trace + compilation.
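
To illustrate the constraint with a generic JAX sketch (not code from this PR): a Python if on a traced value fails under jit, whereas expressing the branch as data flow with jnp.where (or jax.lax.cond) keeps it inside the traced computation:

```python
import jax
import jax.numpy as jnp

@jax.jit
def broken(cur_index, x):
    # Fails at trace time: cur_index is an abstract tracer, so Python cannot
    # evaluate the boolean condition (a concretization error is raised).
    if cur_index > 0:
        return x + 1
    return x

@jax.jit
def works(cur_index, x):
    # Both branches are traced; the selection happens inside the computation.
    return jnp.where(cur_index > 0, x + 1, x)
```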

jheek added the Priority: P2 - no schedule label (best effort response and resolution; we have no plan to work on this at the moment) on Jun 4, 2021
Labels: cla: yes; Priority: P2 - no schedule (best effort response and resolution; we have no plan to work on this at the moment)
Projects: None yet
Development: Successfully merging this pull request may close the issue "AutoRegressive Decoding currently fails if input prompt > 1".
4 participants