[loading] Clean way to add/remove full parts in checkpoint names#45448

Open
Cyrilvallez wants to merge 36 commits into main from fix-clips
Conversation

@Cyrilvallez
Member

@Cyrilvallez Cyrilvallez commented Apr 15, 2026

What does this PR do?

As per the title.

The issue

The problem is that transforms that remove a full part of a weight name (such as a prefix, e.g. the leading `model.`) are in general non-bijective: once the part is dropped, the information is completely lost. Adding it back later when saving is therefore impossible without runtime information about the checkpoint that was used, i.e. we need to know whether the prefix was present before loading; we cannot infer it from anything else.

Proposed solution

This PR adds a simple mechanism for this: each WeightTransform carries a flag recording whether it was actually used to rename a weight. If it was, we keep it when saving the transforms on the model (this was already done before). If not, we drop it, so that it is not applied when resaving.
It also introduces the PrefixChange class (a simple subclass of WeightRenaming) to simplify the addition/removal of full name parts, because the regexes needed in such cases are otherwise hard to read and write.
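The mechanism can be sketched roughly as follows. This is an illustration only: the method names `apply` and the `PrefixChange` constructor signature are assumptions for the sketch, not the actual transformers API (only `was_used` appears in the diff):

```python
import re


class WeightRenaming:
    """Sketch of a renaming transform that remembers whether it ever matched."""

    def __init__(self, source_pattern: str, target_pattern: str):
        self.source_pattern = source_pattern
        self.target_pattern = target_pattern
        self._used = False

    def apply(self, key: str) -> str:
        new_key, n_subs = re.subn(self.source_pattern, self.target_pattern, key)
        if n_subs > 0:
            self._used = True  # this transform actually renamed a checkpoint key
        return new_key

    def was_used(self) -> bool:
        return self._used


class PrefixChange(WeightRenaming):
    """Remove a full prefix without hand-writing the anchored, escaped regex."""

    def __init__(self, prefix_to_remove: str):
        # Anchor at the start and escape dots so "model." only matches literally
        super().__init__(rf"^{re.escape(prefix_to_remove)}", "")


renamer = PrefixChange("model.")
print(renamer.apply("model.layers.0.q_proj.weight"))  # layers.0.q_proj.weight
print(renamer.was_used())  # True: keep this transform for resaving

# A transform that never fired is dropped at save time, so the prefix is not
# spuriously re-added to a checkpoint that never had it
unused = PrefixChange("transformer.")
unused.apply("model.layers.0.q_proj.weight")
print(unused.was_used())  # False
```

At save time, only the transforms whose flag is set would be reverted, which is what makes the otherwise non-bijective prefix removal recoverable.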

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez Cyrilvallez changed the title Fix clips [loading] Clean way to add/remove full parts in checkpoint names Apr 16, 2026
Comment on lines 580 to -582

@dataclass(slots=True)
Member Author

@Cyrilvallez Cyrilvallez Apr 16, 2026

They were dataclasses, but that did not make any sense, so I removed it (while keeping the slots, the only dataclass feature we were really using; it makes inheriting etc. much easier).

Comment on lines -571 to +584
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.(.+).self_attn.q_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.(.+).self_attn.k_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.(.+).self_attn.v_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.(.+).self_attn.o_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.(.+).self_attn.relative_k_proj"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.(.+).self_attn.bias_u"),
WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.(.+).self_attn.bias_v"),
WeightRenaming(r"\.first_sub_layer\.query_net", r".self_attn.q_proj"),
WeightRenaming(r"\.first_sub_layer\.key_net", r".self_attn.k_proj"),
WeightRenaming(r"\.first_sub_layer\.value_net", r".self_attn.v_proj"),
WeightRenaming(r"\.first_sub_layer\.out_projection", r".self_attn.o_proj"),
WeightRenaming(r"encoder\.(.+)\.linear_q", r"encoder.\1.q_proj"),
WeightRenaming(r"encoder\.(.+)\.linear_k", r"encoder.\1.k_proj"),
WeightRenaming(r"encoder\.(.+)\.linear_v", r"encoder.\1.v_proj"),
WeightRenaming(r"encoder\.(.+)\.linear_out", r"encoder.\1.o_proj"),
WeightRenaming(r"encoder\.(.+)\.linear_pos", r"encoder.\1.relative_k_proj"),
WeightRenaming(r"encoder\.(.+)\.pos_bias_u", r"encoder.\1.bias_u"),
WeightRenaming(r"encoder\.(.+)\.pos_bias_v", r"encoder.\1.bias_v"),
WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.query_net", r"decoder.\1.self_attn.q_proj"),
WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.key_net", r"decoder.\1.self_attn.k_proj"),
WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.value_net", r"decoder.\1.self_attn.v_proj"),
WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.out_projection", r"decoder.\1.self_attn.o_proj"),
Member Author

They were not uniquely defined before (the replacement strings used a literal `(.+)` instead of the `\1` backreference). These are the same transforms, now written unambiguously.
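For context, a minimal sketch of why the old replacement strings lost information (using Python's `re` to illustrate; the key example is made up): in a replacement string, a literal `(.+)` is just text, while `\1` re-inserts the captured group:

```python
import re

key = "encoder.layers.0.self_attn.linear_q.weight"

# Old target string: "(.+)" is literal text inside a replacement,
# so the captured layer path ("layers.0") is thrown away
broken = re.sub(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.(.+).self_attn.q_proj", key)
print(broken)  # encoder.(.+).self_attn.q_proj.weight

# New target string: \1 re-inserts the captured group, keeping the layer path
fixed = re.sub(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.\1.self_attn.q_proj", key)
print(fixed)  # encoder.layers.0.self_attn.q_proj.weight
```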

Member

Does the encoder only have self_attn in the state dict? I see that part was removed from the regex.

Contributor

Agree, looks sus with the removed self_attn - let's definitely run some slow tests (locally) to double check

Member Author

It's just that they are a bit redundant with the wildcard before them - I checked locally and it loads correctly! Can add them back though!

Member

@zucchini-nlp zucchini-nlp left a comment

Yep, I think this is the way to go, and it kinda matches my earlier suggestion to iterate over the state dict dynamically. I have a few questions to be sure, and I suggest we move to PrefixMatch in the timm/sam models.

Comment on lines +1432 to +1434
# Keep the current weight conversion mapping for later saving (in case it was coming directly from the user), but
# only if it was used, i.e. it matched any weight from the checkpoints
model_specific_conversions = [conversion for conversion in weight_mapping if conversion.was_used()]
Member

Does this not create another loophole if some conversions aren't used by the model? But probably, if the test checks the full list of conversions, we can consciously decide to add task-specific patterns like lm_head.

Contributor

Not sure I can follow, do you have a small example where you think we would fall through?

Member Author

The test that checks that all conversions are used (to make sure we don't add useless ones) compares the keys between the original model and the saved weights (from our tiny models we only know the "correct" keys, while the saved keys mimic the "wrong" ones), so it will still check this correctly!

Comment on lines +833 to +841
# Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
# and the prefix should not be added when saving directly (i.e. the conversion should be dropped)
model = DummyRoot()
saved_state_dict = revert_weight_conversion(model, model.state_dict())
model_state_dict = model.state_dict()
self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
for k, v in saved_state_dict.items():
    self.assertTrue((v == model_state_dict[k]).all())

Member

To be clear, let's take a llava example. A user creates a random llava arch, trains it, and saves without the prefix. Then if the model is loaded with from_pretrained, we still match keys, right?

Member Author

If the user creates llava from scratch, then trains, then saves, the saved weights will have the same names as the arch weights, so without the prefix. Then when reloading those checkpoints and resaving, it stays the same!

Member

Ahh right, we are removing the prefix, not adding it. I would like us to re-check it with prefix_to_add in subsequent PRs for timm/sam/etc.

Contributor

@vasqu vasqu left a comment

Some initial thoughts from my side as well. I think this is definitely needed, especially looking at the rf detr addition - it would be nice if we could sync both problems together.



Contributor

Imo, it would make sense to have some integration / slow tests with real models as well - llava (the removal case) and rf detr (a future model, the add case).

Member Author

Agreed, this would be nice - we could maybe even do it on the fast CI with a small enough real model.

Member Author

Do you mind if I do it in the next PR, where I'll switch timm and others to this?

Contributor

Nope, fine with having another PR for this

@Cyrilvallez
Member Author

Thanks @zucchini-nlp @vasqu! I answered everything.
We definitely need to switch timm (and I think a few others as well) to this, but I wanted to do it in another PR to avoid too much friction (it simplifies reversal if I end up messing it up, and simplifies testing as well) - see also my comment here #45448 (comment)

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: altclip

@albertvillanova
Member

albertvillanova commented Apr 17, 2026

Thanks for addressing this issue, @Cyrilvallez.

I have tested trl using your PR, and unfortunately it seems there is the same renaming issue I mentioned in the previous PR: #45361 (comment)

However, it looks like there is still a change in parameter naming: the vision_model nesting was eliminated (as I commented in the trl issue: huggingface/trl#5497 (comment)).

For example:

  • model.vision_tower.vision_model.encoder.layers.1.self_attn.k_proj.weight before
  • model.vision_tower.encoder.layers.1.self_attn.k_proj.weight now
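The change in nesting can be reproduced with a hypothetical one-line rename (the exact pattern here is an assumption for illustration, not the transform the library actually applies):

```python
import re

# Illustrative sketch of the renaming reported above: the redundant
# "vision_model." nesting is dropped to match the new module graph
old_key = "model.vision_tower.vision_model.encoder.layers.1.self_attn.k_proj.weight"
new_key = re.sub(r"vision_tower\.vision_model\.", "vision_tower.", old_key)
print(new_key)  # model.vision_tower.encoder.layers.1.self_attn.k_proj.weight
```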

Is this renaming intended? This is the reason why our tests are still failing, as they rely on the previous name nesting structure.

@Cyrilvallez
Member Author

@albertvillanova The renaming is fully intended in order to load the weights. Since the module was removed, the weights need to be renamed to match the actual module graph. However, if you resave the weights afterwards, the names should be reverted back to what they were initially. Could you please point me to the exact code that triggers the issue? From my own tests, everything was good.

@Cyrilvallez
Member Author

If you compare the initial weights to model.state_dict() or something, it is expected that they do not match, as the model actually sees other names.

@albertvillanova
Member

albertvillanova commented Apr 17, 2026

@Cyrilvallez this is the test code: https://github.com/huggingface/trl/blob/a09320e384461bc2a1bf301578bdc2c71fdc91b5/tests/test_dpo_trainer.py#L1085-L1102

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

for n, param in previous_trainable_params.items():
    if model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.encoder.layers.1" in n:
        continue
