[generate] Always use decoder config to init cache #40772
Conversation
Force-pushed from 0d88c0d to dd57459
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
 if requires_cross_attention_cache:
     cross_attention_cache_kwargs = {
-        "config": self.config.get_text_config(encoder=True),
+        "config": self.config.get_text_config(decoder=True),
```
🙈 🙈 🙈 past self is derp
Sorry? This one should not be changed, no?
No, we want to use the decoder config! I had the exact same thought [use the encoder config] in a recent PR :D
In a nutshell, the config is used here to:
- determine which type of layers are used for cross-attention
- determine how many cross-attention layers we have

Cross-attention is a layer in the 👉 decoder 👈 model -- it's the attention between the encoder outputs and the data (the tokens, in an LLM) being decoded. A toy sketch of the layer-count point follows below.
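Toy sketch of that point (illustrative code, not the transformers implementation; `SubConfig` and the layer counts are made up):

```python
# Toy sketch: why the cross-attention cache must be sized by the decoder.
from dataclasses import dataclass

@dataclass
class SubConfig:
    num_hidden_layers: int

encoder = SubConfig(num_hidden_layers=12)  # e.g. a deep encoder...
decoder = SubConfig(num_hidden_layers=6)   # ...paired with a shallow decoder

# Cross-attention lives inside each *decoder* block, so the cache needs
# exactly one slot per decoder layer.
cross_attention_cache = [None] * decoder.num_hidden_layers
assert len(cross_attention_cache) == 6  # sizing by the encoder's 12 would over-allocate
```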
Oh indeed! Easy to get confused with the name of the cache (EncoderDecoderCache) haha!
```diff
 # If a config is passed, use it to infer the layer types and initialize accordingly
 if config is not None:
-    config = config.get_text_config()
+    config = config.get_text_config(decoder=True)
```
Could that somehow clash if we pass an encoder config to it elsewhere, i.e. for the encoder/decoder cache?
We always want the decoder config for KV cache purposes.
The encoder is only run once to produce the encoder outputs, which are then used autoregressively in the decoder through cross-attention. Both self-attention and cross-attention are decoder layers, and are thus parameterized by the decoder config.
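As a hedged sketch of what the fixed call returns (assuming a transformers version with the `decoder=` flag on `get_text_config`, which this PR relies on; `BartConfig` is just a convenient example):

```python
# Sketch: `get_text_config(decoder=True)` hands back the config whose
# decoder attributes drive KV cache sizing. For single-config models like
# BART it returns the config object itself.
from transformers import BartConfig

config = BartConfig(encoder_layers=12, decoder_layers=6)
decoder_config = config.get_text_config(decoder=True)

# Cache sizing should follow the decoder side, not the encoder's 12 layers.
print(decoder_config.decoder_layers)  # 6
```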
LGTM!! Indeed, thanks for the fix and the added explanation!
Merging it, as the only failing test is unrelated and this is ready to go!
* mega derp
* fix
* always use the decoder
What does this PR do?
(see title)
Fixes #40644