
Remove cache_position in more models (3) #44759

Merged
Cyrilvallez merged 19 commits into main from cache-position on Mar 18, 2026

Conversation

@Cyrilvallez (Member) commented Mar 16, 2026

What does this PR do?

Follow-up to several related PRs, the most recent being #44602.

This PR completes all the models that may need non-trivial treatment. Only about 30-40 models still mention cache_position, and those are trivial cases of argument forwarding or cache updates. They will be extremely easy to remove in #44667 (or any other PR)

@Cyrilvallez changed the title from "start on the mambas" to "Remove cache_position in more models (3)" on Mar 16, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member)

I think it's overlapping with #44667 😅 Saw an opportunity and started deleting last week

@Cyrilvallez (Member Author)

Ha indeed 😬 Though from what I'm seeing, I don't think the changes are fully correct in #44667, at least for the mambas! If you don't mind, can we keep this PR at least for the mambas and a few audio models which need special treatment? I believe most of the rest is very standard, and everything can simply be erased

@zucchini-nlp (Member)

@Cyrilvallez ah yeah, the PR is still at fix-and-replace stage and I didn't have time to check out mambas and audio models. Those need special treatment

I am fine with merging this PR first with mambas and I'll rebase the second one later. Does that sound good?

@huggingface deleted a comment from the github-actions bot on Mar 17, 2026
@Cyrilvallez (Member Author)

run-slow: kyutai_speech_to_text mamba encoder_decoder clipseg falcon_h1 clvp xglm x_clip gptj musicgen_melody musicgen recurrent_gemma vision_encoder_decoder csm shieldgemma2 speech_to_text ctrl moshi chameleon zamba2 nemotron_h zamba owlvit umt5 groupvit falcon_mamba cpmant owlv2 falcon mamba2 codegen

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

Context | Commit   | Description
RUN     | 0961e31a | workflow commit (merge commit)
PR      | db7035e5 | branch commit (from PR)
main    | bbe251a4 | base commit (on main)

⚠️ No test being reported (jobs are skipped or cancelled)!

@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/chameleon", "models/clipseg", "models/clvp", "models/codegen", "models/cpmant", "models/csm", "models/ctrl", "models/encoder_decoder", "models/falcon", "models/falcon_h1", "models/falcon_mamba", "models/gptj", "models/groupvit", "models/kyutai_speech_to_text", "models/mamba", "models/mamba2", "models/moshi", "models/musicgen", "models/musicgen_melody", "models/nemotron_h", "models/owlv2", "models/owlvit", "models/recurrent_gemma", "models/shieldgemma2", "models/speech_to_text", "models/umt5", "models/vision_encoder_decoder", "models/x_clip", "models/xglm", "models/zamba", "models/zamba2"]
quantizations: []

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

Context | Commit   | Description
RUN     | a2a93b4a | workflow commit (merge commit)
PR      | aabc9f73 | branch commit (from PR)
main    | af93d384 | base commit (on main)

⚠️ Model CI failed to report results

The test failure analysis could not be completed. Please check the workflow run for details.

@Cyrilvallez (Member Author)

run-slow is not working, but I personally checked that all IntegrationTests give the same results on this PR and on main! So all good!

@huggingface deleted a comment from the github-actions bot on Mar 18, 2026
@huggingface deleted a comment from the github-actions bot on Mar 18, 2026
Comment on lines +140 to +145
if cache_init:
    self.conv_states[layer_idx].copy_(new_conv_state)
else:
    # Shift the window left by one slot and write the newest column into
    # the last slot, without consulting cache_position.
    conv_state = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
    conv_state[:, :, -1:] = new_conv_state
    self.conv_states[layer_idx].copy_(conv_state)
Member Author

@vasqu this is what I discussed with you offline. This is technically not correct here and in some other mambas, but should be fine in practice. So I simplified it to remove the useless usage of cache_position and align with mamba2. The behavior is the same as before (technically wrong), but let's fix it later when we refactor the mamba caches, as it should be fine in practice!
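The simplified update can be sketched in plain Python, with lists standing in for tensors (a minimal illustration of the roll-and-overwrite semantics; the function name and shapes are mine, not the actual transformers code):

```python
def update_conv_state(conv_states, layer_idx, new_conv_state, cache_init):
    """Sketch of the simplified mamba convolution-cache update.

    conv_states[layer_idx] is a fixed-size window of columns, oldest first.
    On the prefill step (cache_init=True) the whole window is replaced; on
    every decode step the window is shifted left by one slot and the newest
    column overwrites the last slot, with no cache_position involved.
    """
    if cache_init:
        conv_states[layer_idx] = list(new_conv_state)
    else:
        window = conv_states[layer_idx]
        # Equivalent of roll(shifts=-1, dims=-1) followed by writing the
        # newest column into the final slot.
        conv_states[layer_idx] = window[1:] + [new_conv_state[-1]]
    return conv_states[layer_idx]
```

As the comment notes, the window shifts on every decode step regardless of the actual position, which is why this is "technically wrong" for non-contiguous positions but fine for ordinary step-by-step decoding.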

Contributor

Yup, let's add a small comment there tho so others are aware

@vasqu (Contributor) left a comment

Careful approval, mostly nits but a few smaller questions which might be relevant

def prepare_inputs_for_generation(
self,
decoder_input_ids,
next_sequence_length: int | None = None,
Contributor

Just to be sure this is the correct order, so that passing positional args here wouldn't mess things up

Member Author

Yes, I mirrored the general prepare_inputs_for_generation here. Though we always pass next_sequence_length as a kwarg in generate anyway to avoid those issues 👌
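The point about keyword passing can be sketched as follows (a simplified, hypothetical stand-in for the method, not the actual generate internals): when every argument after the main input is passed by keyword, the position of next_sequence_length in the signature can never silently capture an unrelated positional argument.

```python
def prepare_inputs_for_generation(decoder_input_ids, next_sequence_length=None,
                                  attention_mask=None):
    # Simplified stand-in: just bundle the inputs into a model-inputs dict.
    return {
        "decoder_input_ids": decoder_input_ids,
        "next_sequence_length": next_sequence_length,
        "attention_mask": attention_mask,
    }

# generate-style call site: everything after the main input goes by keyword,
# so reordering the signature cannot misroute arguments.
model_inputs = prepare_inputs_for_generation([1, 2, 3], next_sequence_length=4)
```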

Contributor

Same as in t5 I suppose? Just that this did not have any copies or similar anymore?

Member Author

Exactly!

end_idx = (model_inputs["cache_position"][-1] + 1) * self.config.downsample_factor
past_key_values = model_inputs.get("past_key_values")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
current_seq_len = model_inputs.get("position_ids").shape[-1]
Contributor

Sure that it's shape?

Member Author

Yes, as it was already truncated in the super() call. We could use the main input (i.e. input_ids/inputs_embeds), but since it can switch between the two, it's easier to use the position_ids, which are always there
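The arithmetic in the snippet above amounts to recovering `cache_position[-1] + 1` (the total number of tokens seen so far) as the cache length plus the length of the already-truncated position_ids. A small sketch with plain integers (function and argument names are mine; downsample_factor stands for the config value):

```python
def compute_end_idx(past_seen_tokens, current_seq_len, downsample_factor):
    # cache_position[-1] + 1 equals the tokens already in the cache plus the
    # tokens in the current (already truncated) forward pass.
    total_tokens = past_seen_tokens + current_seq_len
    return total_tokens * downsample_factor

# Prefill: nothing cached yet, 7 new tokens, downsample factor 2.
prefill_end = compute_end_idx(0, 7, 2)
# One decode step later: 7 tokens cached, 1 new token.
decode_end = compute_end_idx(7, 1, 2)
```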

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: chameleon, clipseg, clvp, codegen, cpmant, csm, ctrl, encoder_decoder, falcon, falcon_h1, falcon_mamba, gptj, groupvit, kyutai_speech_to_text, mamba, mamba2

@Cyrilvallez merged commit 83a6c5b into main on Mar 18, 2026. 29 checks passed.
@Cyrilvallez deleted the cache-position branch on March 18, 2026 at 13:09
4 participants