Remove cache_position in more models (3) #44759
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I think it's overlapping with #44667 😅 Saw an opportunity and started deleting last week
Ha indeed 😬 Though from what I'm seeing, I don't think the changes are fully correct on #44667, at least for the mambas! If you don't mind, can we keep this PR at least for the mambas and a few audio models which need special treatment? I believe most of the rest is very standard, and everything can simply be erased
@Cyrilvallez ah yeah, the PR is still at the fix-and-replace stage and I didn't have time to check out the mambas and audio models. Since those need special treatment, I'm fine with merging this PR first with the mambas, and I'll rebase the second one later. Does that sound good?
run-slow: kyutai_speech_to_text mamba encoder_decoder clipseg falcon_h1 clvp xglm x_clip gptj musicgen_melody musicgen recurrent_gemma vision_encoder_decoder csm shieldgemma2 speech_to_text ctrl moshi chameleon zamba2 nemotron_h zamba owlvit umt5 groupvit falcon_mamba cpmant owlv2 falcon mamba2 codegen
This comment contains models: ["models/chameleon", "models/clipseg", "models/clvp", "models/codegen", "models/cpmant", "models/csm", "models/ctrl", "models/encoder_decoder", "models/falcon", "models/falcon_h1", "models/falcon_mamba", "models/gptj", "models/groupvit", "models/kyutai_speech_to_text", "models/mamba", "models/mamba2", "models/moshi", "models/musicgen", "models/musicgen_melody", "models/nemotron_h", "models/owlv2", "models/owlvit", "models/recurrent_gemma", "models/shieldgemma2", "models/speech_to_text", "models/umt5", "models/vision_encoder_decoder", "models/x_clip", "models/xglm", "models/zamba", "models/zamba2"]
CI Results / Commit Info
The test failure analysis could not be completed. Please check the workflow run for details.
if cache_init:
    self.conv_states[layer_idx].copy_(new_conv_state)
else:
    conv_state = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
    conv_state[:, :, -1:] = new_conv_state
    self.conv_states[layer_idx].copy_(conv_state)
@vasqu this is what I discussed with you offline. This is technically not correct here and in some other mambas, but should be fine in practice. So I simplified it to remove the useless usage of cache_position and align with mamba2. The behavior is the same as before (technically wrong), but let's fix it later when we refactor the mamba caches, as it should be fine in practice!
Yup, let's add a small comment there though so others are aware
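For context, a minimal torch-free sketch of the roll-based update above (plain Python lists standing in for the conv state tensor; the helper name is illustrative, not from the codebase): each decoding step shifts the window left by one and writes the newest state into the last slot, with no cache_position involved.

```python
# Hypothetical stand-in for the tensor ops above: list slicing mimics
# roll(shifts=-1) followed by overwriting the last slot.
def update_conv_window(conv_state, new_value):
    """Shift the window left by one and append the new value at the end."""
    rolled = conv_state[1:] + conv_state[:1]  # roll(shifts=-1, dims=-1)
    rolled[-1] = new_value                    # conv_state[:, :, -1:] = new_conv_state
    return rolled

window = [0, 0, 0, 0]  # conv kernel size 4, initially empty
for token_state in [1, 2, 3, 4, 5]:
    window = update_conv_window(window, token_state)
print(window)  # [2, 3, 4, 5] -- only the last 4 states are kept
```

The point of the simplification is visible here: the oldest entry is always dropped from the front, so the update never needs to know the absolute cache position.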
vasqu left a comment
Careful approval: mostly nits, but a few smaller questions which might be relevant
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    next_sequence_length: int | None = None,
Just to be sure that this is the correct order, so that passing positional args here wouldn't mess things up
Yes, I mirrored the general prepare_inputs_for_generation here. Though we always pass next_sequence_length as a kwarg in generate anyway, to avoid those issues 👌
Same as in t5, I suppose? Just that this one did not have any copies or similar anymore?
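A small illustration of the point about kwargs (a hypothetical, heavily simplified signature, not the actual transformers one): keyword passing is order-independent, so the position of next_sequence_length in the signature cannot silently shift which argument lands where.

```python
# Hypothetical, simplified stand-in for prepare_inputs_for_generation;
# the real method in transformers has many more parameters.
def prepare_inputs_for_generation(decoder_input_ids, next_sequence_length=None, attention_mask=None):
    return {
        "decoder_input_ids": decoder_input_ids,
        "next_sequence_length": next_sequence_length,
        "attention_mask": attention_mask,
    }

# Keyword arguments bind by name, not position, so reordering parameters
# later cannot misroute the values:
out = prepare_inputs_for_generation([1, 2, 3], attention_mask=[1, 1, 1], next_sequence_length=4)
print(out["next_sequence_length"])  # 4
```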
end_idx = (model_inputs["cache_position"][-1] + 1) * self.config.downsample_factor
past_key_values = model_inputs.get("past_key_values")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
current_seq_len = model_inputs.get("position_ids").shape[-1]
Yes, as it was already truncated in the super() call. We could use the main input (i.e. input_ids/inputs_embeds), but as it can change between the two, it's easier to use the position_ids, which are always there
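As a sketch of this point (plain Python with illustrative dicts, not the real model inputs): because the inputs are already truncated by the super() call, the length of position_ids matches the tokens processed this step whether the model receives input_ids or inputs_embeds, so it is a safe source for the current sequence length.

```python
# Hypothetical model_inputs dicts after truncation; either input_ids or
# inputs_embeds may be present on a given step, but position_ids always is.
def current_seq_len(model_inputs):
    """Length of the current step's inputs, read off position_ids."""
    return len(model_inputs["position_ids"])

prefill = {"input_ids": [5, 6, 7], "position_ids": [0, 1, 2]}
decode_step = {"inputs_embeds": [[0.1, 0.2]], "position_ids": [3]}

print(current_seq_len(prefill))      # 3
print(current_seq_len(decode_step))  # 1
```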
[For maintainers] Suggested jobs to run (before merge): run-slow: chameleon, clipseg, clvp, codegen, cpmant, csm, ctrl, encoder_decoder, falcon, falcon_h1, falcon_mamba, gptj, groupvit, kyutai_speech_to_text, mamba, mamba2
What does this PR do?
Follow-up to many related PRs, the most recent being #44602.
This PR completes all the models that may need non-trivial treatment. Only about 30-40 models still mention cache_position, and those are trivial arg forwarding/cache updates. They will be extremely easy to remove in #44667 (or any other PR).