🚨 Generation cache preparation #43679
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
# 10. Prefill
model_inputs.update({"output_attentions": generation_config.output_attentions})
model_inputs.update({"output_hidden_states": generation_config.output_hidden_states})
outputs = self(**model_inputs, return_dict=True)
```
The model was running prefill two times because `self._sample` also calls prefill. That caused the first prefill to flush its cache and the second prefill to start over.
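To make the issue concrete, here is a minimal standalone sketch of the double-prefill pattern (toy `ToyCache`, `prefill`, and `sample` helpers, not the actual `transformers` code): prefilling before the sampling loop and again inside it processes the prompt twice and resets the cache in between.

```python
# Hypothetical sketch of the double-prefill problem; not the transformers implementation.
class ToyCache:
    def __init__(self):
        self.tokens = []

    def reset(self):
        self.tokens = []


def prefill(prompt, cache):
    # A fresh prefill starts from an empty cache, so any earlier prefill work is thrown away.
    cache.reset()
    cache.tokens.extend(prompt)
    return {"prompt_len": len(cache.tokens)}


def sample(prompt, cache, prefill_already_done=False):
    if not prefill_already_done:
        # Second prefill: flushes what the first one computed and starts over.
        prefill(prompt, cache)
    return cache.tokens + ["<generated>"]


cache = ToyCache()
prefill(["a", "b", "c"], cache)        # first prefill, in the outer generate step
print(sample(["a", "b", "c"], cache))  # sample() prefills again -> the prompt pass runs twice
```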
```python
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
prefill_outputs=outputs,
```
Not an expected arg for `self._sample`.
```diff
 parent,
 batch_size=4,
-seq_length=128,
+seq_length=12,
```
The model has a sliding window of 128, so when generating, the cache is cropped to max=128. We need to either override many generation tests to match the expected length, or use a smaller seq length (see the quick sketch below this thread).
Totally fine, this is absurdly high so better to reduce
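For context, a rough sketch of the arithmetic (toy `cached_length` helper; the real tests and cache classes differ): with a 128-token sliding window, a 128-token prompt already saturates the cache, so its length no longer tracks `prompt_len + new_tokens`.

```python
# Illustrative only: a sliding-window cache keeps at most `sliding_window` positions.
sliding_window = 128


def cached_length(prompt_len: int, new_tokens: int) -> int:
    # Total positions seen so far, cropped to the window size.
    return min(prompt_len + new_tokens, sliding_window)


# Old test value: the prompt alone saturates the window, so generic assertions
# of the form `cache_len == prompt_len + new_tokens` break.
assert cached_length(prompt_len=128, new_tokens=5) == 128  # not 133
# Smaller value: the cache grows exactly as the generic tests expect.
assert cached_length(prompt_len=12, new_tokens=5) == 17
```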
```diff
-super().__init__(config.get_text_config())
+super().__init__(config)
+self.text_config = config.get_text_config()
```
Not a good practice if `self.config` is only the decoder sub-config.
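A hedged illustration of the point with simplified stand-in classes (not the real `PreTrainedModel`/config API): initializing the base class with only the text sub-config makes `self.config` the decoder sub-config and loses the composite config, whereas the new pattern keeps the full config and stores the sub-config separately.

```python
# Simplified stand-ins for illustration; not the real transformers classes.
class TextConfig:
    sliding_window = 128


class CompositeConfig:
    vision_size = 1024

    def get_text_config(self):
        return TextConfig()


class Base:
    def __init__(self, config):
        self.config = config


class Old(Base):
    def __init__(self, config):
        super().__init__(config.get_text_config())   # self.config is only the decoder sub-config


class New(Base):
    def __init__(self, config):
        super().__init__(config)                      # self.config stays the full composite config
        self.text_config = config.get_text_config()  # decoder settings still easy to reach


cfg = CompositeConfig()
assert not hasattr(Old(cfg).config, "vision_size")  # composite info is gone
assert New(cfg).config.vision_size == 1024          # composite info preserved
```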
```diff
-def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
+def _prepare_static_cache(
+    self, cache_implementation: str, batch_size: int, max_cache_len: int, model_kwargs
+) -> Cache:
```
Just a naming change; imo this is more descriptive.
vasqu left a comment
LGTM, added a few smaller comments but nothing major
Let's run slow tests for the special models: blt, dia, kyutai, plus the general special ones like mamba1/2, bamba, etc.
src/transformers/generation/utils.py (Outdated)
```python
if backend == "quanto" and not is_optimum_quanto_available():
    raise ImportError(
        "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto "
        "backend. Please install it via with `pip install optimum-quanto`"
    )
elif backend == "HQQ" and not is_hqq_available():
    raise ImportError(
        "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
        "Please install it via with `pip install hqq`"
    )
model_kwargs["past_key_values"] = QuantizedCache(backend=backend, **cache_config)
```
Imo, this error should be raised within the constructor of the cache - we should not have to check this ourselves here
Actually, it is checked inside the cache class as well. I don't remember exactly what the reason was to check it here; I'll git blame and see.
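As a sketch of what raising inside the constructor could look like (hypothetical `_is_backend_available` helper and `SketchQuantizedCache` class, not the actual `QuantizedCache` code), so that `generate()` never has to duplicate the check:

```python
# Hypothetical sketch: the availability check lives in the cache constructor.
def _is_backend_available(backend: str) -> bool:
    # Placeholder for is_optimum_quanto_available() / is_hqq_available().
    return backend == "quanto"


class SketchQuantizedCache:
    def __init__(self, backend: str, **cache_config):
        if not _is_backend_available(backend):
            raise ImportError(
                f"The `{backend}` backend is not installed; install it to use KV cache quantization."
            )
        self.backend = backend
        self.cache_config = cache_config


cache = SketchQuantizedCache(backend="quanto", residual_length=128)
```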
Remark: kinda breaking, but in a good way. Previously, models init their cache in `model.forward` without passing `config`. That means sliding windows weren't always respected (e.g. Afmoe). From now on we always respect the sliding window if it is in the config. Added a 🚨.
[For maintainers] Suggested jobs to run (before merge):
run-slow: afmoe, blt, dia, janus, kyutai_speech_to_text
@vasqu can you force merge? CI is still super flaky and I already retriggered 5-6 times 😢
What does this PR do?
I also want to see if the linear cache thing can be squeezed into this PR. If it requires big diffs, I'll split it into two.
Fixes #43673
Sidenote: kinda breaking, but in a good way. Previously, models init their cache in `model.forward` without passing `config`. That means sliding windows weren't always respected (e.g. Afmoe or remote models). From now on we always respect the sliding window if it is in the config.
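A hedged sketch of that behavioral change with a toy cache class (not the real `StaticCache`/`SlidingWindowCache` API): once the config is passed at cache-init time, the sliding window caps the allocated cache length; without the config, the window was silently ignored.

```python
# Toy illustration of the breaking change; not the real transformers cache classes.
class ToyConfig:
    sliding_window = 128


class ToyStaticCache:
    def __init__(self, max_cache_len: int, config=None):
        window = getattr(config, "sliding_window", None) if config is not None else None
        # With a config, the sliding window bounds the allocated cache length.
        self.max_cache_len = min(max_cache_len, window) if window else max_cache_len


# Before: cache built in model.forward without the config -> window ignored.
assert ToyStaticCache(max_cache_len=4096).max_cache_len == 4096
# After: config is always passed -> the 128-token window is respected.
assert ToyStaticCache(max_cache_len=4096, config=ToyConfig()).max_cache_len == 128
```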