
Conversation

@molbap
Contributor

@molbap molbap commented Nov 3, 2025

What does this PR do?

As per title. Linked to huggingface/peft#2880.
It follows fairly closely the existing implementations for idefics2/3 and smolvlm, trying to cover several types of VLMs (they are named differently across the library).

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 427 to 429
if vision_module is not None:
    for parameter in vision_module.parameters():
        parameter.requires_grad = True
Contributor

From my understanding of peft#2880, the problem is mainly that the entry point of the model doesn't require gradients (not as a trainable parameter, just for gradient checkpointing), so targeting modules after that point doesn't work with reentrant gradient checkpointing. Isn't setting all vision parameters to requires_grad=True masking the changes done in enable_input_require_grads, making the check pass regardless of what that helper function does? Maybe targeting something that is clearly not an input, something resembling an attention layer for example, would be better?
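
(For context, a minimal standalone PyTorch sketch of the failure mode described above — hypothetical, not the PR's code or test:)

# Minimal repro of the reentrant-checkpointing problem (plain PyTorch).
import torch
from torch.utils.checkpoint import checkpoint

embed = torch.nn.Embedding(10, 8)
embed.weight.requires_grad_(False)  # frozen base weights, as in a typical PEFT setup
block = torch.nn.Linear(8, 8)       # stands in for the targeted (trainable) module
ids = torch.tensor([[1, 2, 3]])

# Broken: the checkpointed segment's input doesn't require grad, so reentrant
# checkpointing builds no graph and the targeted module never receives gradients.
hidden = embed(ids)
out = checkpoint(block, hidden, use_reentrant=True)
print(out.requires_grad)  # False

# Fix, equivalent to the hook that enable_input_require_grads registers on the embeddings:
def make_inputs_require_grads(module, inputs, output):
    output.requires_grad_(True)

embed.register_forward_hook(make_inputs_require_grads)
hidden = embed(ids)
out = checkpoint(block, hidden, use_reentrant=True)
out.sum().backward()
print(block.weight.grad is not None)  # True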

Contributor Author

I see, hmm - I followed the implementation of idefics2/smolvlm as I remembered they faced this issue at the time. You're right that this isn't necessary; we register twice. The lowest-module trick should work though, and I'm not sure targeting an attention layer works either. Currently @BenjaminBossan's script outputs grad norms properly with gradient checkpointing enabled and PEFT disabled on this branch, so it seems to do the trick?

no GC

{'loss': 9.4971, 'grad_norm': 23.421083450317383, 'learning_rate': 2e-05, 'epoch': 0.33}
{'loss': 7.9526, 'grad_norm': 675.1868896484375, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.67}

with GC

{'loss': 9.4971, 'grad_norm': 23.421083450317383, 'learning_rate': 2e-05, 'epoch': 0.33}
{'loss': 7.9526, 'grad_norm': 675.1868896484375, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.67}

In either case, agreed that the double registering is useless, will remove!

Contributor

Yeah, I think the implementation is fine. I'm just worried that the test is masking the behavior of the fix and is therefore not honest enough. Sorry if I didn't make that clear.

Contributor Author

No that's fair, I'll revamp the test for a narrower scope!
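
(For illustration, one hypothetical shape such a narrower check could take, asserting on the hook's effect instead of flipping requires_grad on vision parameters — a sketch, not the PR's actual test:)

# Hypothetical narrower check (sketch only): after enable_input_require_grads(),
# the outputs of every sub-model's input embeddings should require grad, without
# ever touching requires_grad on the vision parameters themselves.
from transformers import PreTrainedModel

def assert_input_embeddings_require_grads(model: PreTrainedModel, dummy_inputs: dict):
    model.enable_input_require_grads()
    captured, handles = [], []

    def capture(module, inputs, output):
        captured.append(output.requires_grad)

    for module in model.modules():
        if isinstance(module, PreTrainedModel) and module.get_input_embeddings() is not None:
            # registered after the require-grads hook, so it observes its effect
            handles.append(module.get_input_embeddings().register_forward_hook(capture))

    model(**dummy_inputs)  # dummy_inputs must exercise every modality (e.g. include pixel_values)
    for handle in handles:
        handle.remove()
    assert captured and all(captured)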

@molbap molbap requested a review from zucchini-nlp November 6, 2025 21:20
Member

@zucchini-nlp zucchini-nlp left a comment

I think this solution works only for VLMs and also depends a lot on how the vision model is named. I'm sure we listed all the possible names, but new models can get creative with it.

So I'm thinking that we could potentially make it work OOTB for all MLLMs (audio/vision/omni) by checking for each PreTrainedModel within the model and then setting grads on that model's inputs (model.get_input_embeddings()).

We use a similar trick when setting attention implementations, where we check for PreTrainedModels, so it could be a good option. WDYT?

@molbap
Contributor Author

molbap commented Nov 7, 2025

Thanks, yes it's a far less brittle option. There are a few (really a few, and hopefully zero after v5) modules that are just nn.Modules instead of PreTrainedModel, so they would be off the hook; other than these few exceptions it should work out well. Will push something like that today.

@githubnemo
Contributor

So I'm thinking that we could potentially make it work OOTB for all MLLMs (audio/vision/omni) by checking for each PreTrainedModel within the model and then setting grads on that model's inputs (model.get_input_embeddings()).

We use a similar trick when setting attention implementations, where we check for PreTrainedModels, so it could be a good option. WDYT?

Sorry, I may be misunderstanding the proposed solution, but this doesn't seem to solve the problem? In a VLM where I target a module in the vision stack, I need the vision model's inputs to require grads, not the language model's input (get_input_embeddings).

@zucchini-nlp
Member

@githubnemo model.get_input_embeddings() will return the lowest module for each model, which is the same thing the current PR does by recursing over modules. So setting grads on the inputs via model.get_input_embeddings() will enable it not only for vision modules, but also for audio and other modalities.

@zucchini-nlp
Member

BTW, when working on something else I noticed that we have code like the below, which can be deleted after this PR:

def enable_input_require_grads(self):
    """
    Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
    the model weights fixed.
    """

@molbap
Contributor Author

molbap commented Nov 13, 2025

Yes, for all idefics/smolvlm there'll be no need for that. Should ship that today (finally)

@molbap
Contributor Author

molbap commented Nov 13, 2025

Iterated a bit on that and hit a dead end on idefics2/3 code, back at it tomorrow!

@molbap
Contributor Author

molbap commented Nov 14, 2025

The failing tests appear unrelated (I rebased on main). What do you think of the new method, @zucchini-nlp? Also @githubnemo, I updated the test a tad, let me know.

Member

@zucchini-nlp zucchini-nlp left a comment

Love the clean-up! Only one major question about tests: it would be super cool to have a common test imo, though I realize it can be hard with multimodal models.

Comment on lines -805 to -795
def enable_input_require_grads(self):
    """
    Enables the gradients for the input embeddings.
    This is useful for lora when using gradient checkpointing.
    c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
Member

nice 🔪

Comment on lines +2237 to +2239
if hooks:
    # for BC
    self._require_grads_hook = hooks[0]
Member

Aren't we ignoring all hooks except the first one in this case, i.e. when we disable, it will disable the text model but not the vision model?

Contributor Author

I don't think so; this is just because we used to remove _require_grads_hook. Now we always iterate over the full list _require_grads_hooks (with an s), so every registered hook (vision or text or whatever) gets removed.
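
(Paraphrased, the disable path now looks roughly like this — a sketch based on this conversation and the excerpt below, not the exact diff:)

# Paraphrased sketch of the new disable path (not the exact diff): every hook in
# the list is removed, not only the single BC attribute.
def disable_input_require_grads(self):
    for hook in self._require_grads_hooks:
        hook.remove()
    self._require_grads_hooks = []
    if hasattr(self, "_require_grads_hook"):  # BC alias pointing at the first hook
        del self._require_grads_hook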

Member

ahh my bad, didn't see the "s" at the end

Contributor Author

might be bad naming then haha

Comment on lines +2220 to +2235
for module in self.modules():
    if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
        continue

    input_embeddings = module.get_input_embeddings()

    if input_embeddings is None:
        continue

    embedding_id = id(input_embeddings)
    if embedding_id in seen_modules:
        continue

    seen_modules.add(embedding_id)
    hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))

Member

super clean!
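
(For reference, an illustrative sketch of the user-facing flow this unblocks; the checkpoint name and LoRA target are only examples:)

# Illustrative usage sketch (model and target names are just examples): with hooks
# registered on every sub-model's input embeddings, reentrant gradient checkpointing
# works even for adapters placed in the vision/audio stacks.
from transformers import AutoModelForImageTextToText
from peft import LoraConfig, get_peft_model

model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
peft_model = get_peft_model(model, LoraConfig(target_modules="all-linear"))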


self._require_grads_hooks = []
if hasattr(self, "_require_grads_hook"):
    del self._require_grads_hook
Member

out of curiosity, is it required to explicitly delete?

Contributor Author

Just out of safety; I'm not certain it's always necessary, but not knowing what people were doing with that hook in their fine-tuning scripts, I think it's safer to remove it so no reference remains.

def test_multi_gpu_data_parallel_forward(self):
    pass

def test_enable_input_require_grads_with_gradient_checkpointing(self):
Member

I am thinking, could we make a common test for all models?

Contributor Author

eeeh... I think we should :D yes
will look before EOD if I have time
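
(One possible shape for such a common test — a hypothetical sketch following ModelTesterMixin conventions, not the merged test:)

# Hypothetical common-test sketch (ModelTesterMixin style): freeze the input
# embeddings so the checkpointed blocks only get differentiable inputs through the
# hooks registered by enable_input_require_grads(), then check gradients still flow.
from transformers import PreTrainedModel

def test_enable_input_require_grads_with_gradient_checkpointing(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_model_classes:
        model = model_class(config)
        if not model.supports_gradient_checkpointing:
            continue
        model.train()
        for module in model.modules():
            if isinstance(module, PreTrainedModel) and module.get_input_embeddings() is not None:
                for param in module.get_input_embeddings().parameters():
                    param.requires_grad_(False)
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
        model.enable_input_require_grads()
        output = model(**inputs_dict)[0]
        output.float().sum().backward()
        self.assertTrue(any(p.grad is not None for p in model.parameters() if p.requires_grad))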

Collaborator

@ArthurZucker ArthurZucker left a comment

Very nice!

@molbap
Contributor Author

molbap commented Nov 27, 2025

run-slow: bart, blip_2, idefics2, idefics3, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_omni_moe, smolvlm, timm_wrapper

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/bart", "models/blip_2", "models/idefics2", "models/idefics3", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_omni_moe", "models/smolvlm", "models/timm_wrapper"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: bart, blip_2, idefics2, idefics3, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_omni_moe, smolvlm, timm_wrapper

Collaborator

@ArthurZucker ArthurZucker left a comment

Very nice! This fixes #42489 (#42494) as well, I think.

@ArthurZucker ArthurZucker merged commit e2f08ea into main Dec 1, 2025
19 of 24 checks passed
@ArthurZucker ArthurZucker deleted the fix_reentrant_gc_vlms branch December 1, 2025 15:32