[loading] Clean way to add/remove full parts in checkpoint names #45448

Cyrilvallez wants to merge 36 commits into main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
@dataclass(slots=True)
```
They were dataclasses, but it did not make much sense, so I removed that (while keeping the slots, the only dataclass feature we were really using - it also makes inheritance much easier).
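The trade-off mentioned here can be sketched with a toy example (class and attribute names below are illustrative, not the real transformers classes): plain classes with `__slots__` keep the fixed-attribute behavior of `@dataclass(slots=True)` while making subclassing straightforward.

```python
# Hypothetical sketch: keep slots without the @dataclass machinery.
class WeightTransform:
    __slots__ = ("source_pattern", "target_pattern")

    def __init__(self, source_pattern, target_pattern):
        self.source_pattern = source_pattern
        self.target_pattern = target_pattern


class WeightRenaming(WeightTransform):
    # Subclasses only declare the slots they add; inherited slots still apply.
    __slots__ = ()


renaming = WeightRenaming(r"encoder\.linear_q", r"encoder.q_proj")
# Attribute typos now fail loudly instead of silently creating a new attribute:
# renaming.target_patern = "..."  ->  AttributeError
```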
```diff
-WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.(.+).self_attn.q_proj"),
-WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.(.+).self_attn.k_proj"),
-WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.(.+).self_attn.v_proj"),
-WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.(.+).self_attn.o_proj"),
-WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.(.+).self_attn.relative_k_proj"),
-WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.(.+).self_attn.bias_u"),
-WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.(.+).self_attn.bias_v"),
-WeightRenaming(r"\.first_sub_layer\.query_net", r".self_attn.q_proj"),
-WeightRenaming(r"\.first_sub_layer\.key_net", r".self_attn.k_proj"),
-WeightRenaming(r"\.first_sub_layer\.value_net", r".self_attn.v_proj"),
-WeightRenaming(r"\.first_sub_layer\.out_projection", r".self_attn.o_proj"),
+WeightRenaming(r"encoder\.(.+)\.linear_q", r"encoder.\1.q_proj"),
+WeightRenaming(r"encoder\.(.+)\.linear_k", r"encoder.\1.k_proj"),
+WeightRenaming(r"encoder\.(.+)\.linear_v", r"encoder.\1.v_proj"),
+WeightRenaming(r"encoder\.(.+)\.linear_out", r"encoder.\1.o_proj"),
+WeightRenaming(r"encoder\.(.+)\.linear_pos", r"encoder.\1.relative_k_proj"),
+WeightRenaming(r"encoder\.(.+)\.pos_bias_u", r"encoder.\1.bias_u"),
+WeightRenaming(r"encoder\.(.+)\.pos_bias_v", r"encoder.\1.bias_v"),
+WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.query_net", r"decoder.\1.self_attn.q_proj"),
+WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.key_net", r"decoder.\1.self_attn.k_proj"),
+WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.value_net", r"decoder.\1.self_attn.v_proj"),
+WeightRenaming(r"decoder\.(.+)\.first_sub_layer\.out_projection", r"decoder.\1.self_attn.o_proj"),
```
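For illustration, the difference between a literal `(.+)` and a `\1` backreference in the replacement can be checked with plain `re.sub` (the key name below is made up):

```python
import re

key = "encoder.layers.0.self_attn.linear_q.weight"

# The old replacement kept "(.+)" literally, so the captured layer path is lost:
old = re.sub(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.(.+).self_attn.q_proj", key)
# old == "encoder.(.+).self_attn.q_proj.weight"

# The new pattern drops ".self_attn" from the source (the greedy "(.+)" wildcard
# already swallows it) and "\1" re-inserts the full captured path:
new = re.sub(r"encoder\.(.+)\.linear_q", r"encoder.\1.q_proj", key)
# new == "encoder.layers.0.self_attn.q_proj.weight"
```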
They were not uniquely defined before - these are the same transforms, but with the capture groups uniquely defined now.
Does the encoder only have self_attn in the state dict? I see that part removed from the regex.
Agree, looks sus with the removed self_attn - let's definitely run some slow tests (locally) to double check
It's just that they are a bit redundant with the wildcard before - I checked locally and it loads correctly! Can add back though!
zucchini-nlp left a comment:
Yep, I think this is the way to go, and it kinda matches the suggestion I made previously to iterate over the state dict dynamically. I have a few questions to be sure, and I suggest we move to PrefixMatch in the timm/sam models.
```python
# Keep the current weight conversion mapping for later saving (in case it was coming directly from the user), but
# only if it was used, i.e. it matched any weight from the checkpoints
model_specific_conversions = [conversion for conversion in weight_mapping if conversion.was_used()]
```
Does it not create another loophole if some conversions aren't used by the model? But probably, if the test checks the full list of conversions, we can consciously decide to add task-specific patterns like lm_head.
Not sure I can follow, do you have a small example where you think we would fall through?
The test that checks that all conversions are used (to make sure we don't add useless ones) does so by comparing the keys between the original model and the saved weights (from our tiny models we only know the "correct" keys, and the saved keys mimic the "wrong" ones), so it will still check this correctly!
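As a hedged sketch of the mechanism being discussed (the real `WeightRenaming`/`was_used` implementation lives in this PR; the class and patterns below are minimal made-up stand-ins):

```python
import re

class WeightRenaming:
    """Minimal stand-in: a conversion records whether it ever matched a key."""

    def __init__(self, source, target):
        self.source, self.target = source, target
        self._used = False

    def apply(self, key):
        new_key, n_subs = re.subn(self.source, self.target, key)
        self._used = self._used or n_subs > 0
        return new_key

    def was_used(self):
        return self._used


weight_mapping = [
    WeightRenaming(r"^language_model\.", ""),   # matches the checkpoint below
    WeightRenaming(r"\.linear_q$", ".q_proj"),  # never matches -> dropped
]

checkpoint_keys = ["language_model.embed_tokens.weight"]
for conversion in weight_mapping:
    checkpoint_keys = [conversion.apply(k) for k in checkpoint_keys]

# Only conversions that actually fired are kept for later saving:
model_specific_conversions = [c for c in weight_mapping if c.was_used()]
```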
```python
# Now, use a fresh model, without going through loading first, so the model won't have `_weight_conversions` attached
# and the prefix should not be added when saving directly (i.e. the conversion should be dropped)
model = DummyRoot()
saved_state_dict = revert_weight_conversion(model, model.state_dict())
model_state_dict = model.state_dict()
self.assertEqual(set(model_state_dict.keys()), set(saved_state_dict.keys()))
for k, v in saved_state_dict.items():
    self.assertTrue((v == model_state_dict[k]).all())
```
To be clear, let's take a llava example. A user creates a random llava arch, trains it, and saves without the prefix. Then if the model was loaded with from_pretrained, we still match keys, right?
If the user creates llava from scratch, then trains and saves, the saved weights will be the same as the arch weights, so without the prefix. Then when reloading those checkpoints and resaving, everything stays the same!
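A tiny sketch of that round-trip logic, with a made-up prefix pattern (not the actual llava mapping):

```python
import re

# Hypothetical prefix-removal conversion applied at load time. Whether it
# matched decides if the prefix is re-added on save.
prefix_pattern = r"^model\.language_model\.(.*)"

def load_rename(key):
    return re.sub(prefix_pattern, r"\1", key)

# Checkpoint saved *with* the prefix: the conversion fires, so it is kept and
# the prefix is restored when resaving.
stripped = load_rename("model.language_model.embed.weight")
# stripped == "embed.weight"

# Checkpoint saved *without* the prefix: nothing matches, the conversion is
# dropped, and the keys round-trip unchanged.
unchanged = load_rename("embed.weight")
# unchanged == "embed.weight"
```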
Ahh right, we are removing the prefix, not adding it. I would like us to re-check it with prefix_to_add in subsequent PRs for timm/sam/etc.
vasqu left a comment:
Some initial thoughts from my side as well. I think this is definitely needed, especially also looking at the rf detr addition - it would be nice if we could sync both problems together.
Imo, it would make sense to have some integration / slow tests with real models as well - llava (removal case) and rf detr (future model, addition case).
Agreed this would be nice - we can maybe even do it on fast CI with a small enough real model
Do you mind if I do it on the next PR where I'll switch timm and other to this?
Nope, fine with having another PR for this
Thanks @zucchini-nlp @vasqu! I answered everything.
[For maintainers] Suggested jobs to run (before merge): run-slow: altclip
Thanks for addressing this issue, @Cyrilvallez. I have tested
For example:
Is this renaming intended? This is the reason why our tests are still failing, as they rely on the previous name nesting structure.
@albertvillanova The renaming is fully intended in order to load the weights. Since the module was removed, the weights need to be renamed to match the actual module graph. However, if you resave the weights afterwards, the names should be reverted back to what they were initially. Could you please point me towards the exact code that triggers the issue? From my own tests, everything was good.
If you compare initial weights to
@Cyrilvallez this is the test code: https://github.com/huggingface/trl/blob/a09320e384461bc2a1bf301578bdc2c71fdc91b5/tests/test_dpo_trainer.py#L1085-L1102

```python
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
for n, param in previous_trainable_params.items():
    if model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.encoder.layers.1" in n:
        continue
```
What does this PR do?
As per the title.
The issue
The problem is that transforms that remove a full part of a weight name (such as a prefix, e.g. the `model.` start) are non-bijective in general, i.e. we completely lose the information once the part is dropped. So adding it back later when saving is impossible without runtime information about the checkpoint that was used: we need to know whether the prefix was there before or not, and we cannot infer it from anything else.

Proposed solution
This PR adds a simple mechanism for this: each WeightTransform carries a simple flag recording whether it was used to rename a weight or not. If it was, we keep it when we save the transforms on the model (this was already done before). If not, we drop it, so that it is not applied when resaving.
It also introduces the PrefixChange class (a simple class inheriting from WeightRenaming) to simplify the full addition/removal of name parts, because otherwise the regexes needed in such cases are hard to read/write.
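To illustrate the readability gain, here is a hedged sketch of the idea (the real PrefixChange constructor signature in the PR may differ; the base class below is also a minimal stand-in):

```python
import re

class WeightRenaming:
    # Minimal stand-in: rename a key via regex substitution.
    def __init__(self, source, target):
        self.source, self.target = source, target

    def apply(self, key):
        return re.sub(self.source, self.target, key)


class PrefixChange(WeightRenaming):
    # Sketch: build the hard-to-read prefix regexes from two plain strings.
    def __init__(self, old_prefix="", new_prefix=""):
        source = r"^" + re.escape(old_prefix) + r"(.*)"
        target = new_prefix.replace("\\", "\\\\") + r"\1"
        super().__init__(source, target)


# Dropping the "model." prefix without hand-writing the regex:
dropped = PrefixChange(old_prefix="model.").apply("model.embed.weight")
# Adding one back:
added = PrefixChange(new_prefix="model.").apply("embed.weight")
```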