Attempt to fix VLM gradient enabling #41993
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
if vision_module is not None:
    for parameter in vision_module.parameters():
        parameter.requires_grad = True
From my understanding of peft#2880, the problem is mainly that the entry point of the model doesn't require gradients (not as a trainable parameter, just for gradient checkpointing), so targeting modules after that point doesn't work with reentrant gradient checkpointing. Isn't setting all vision parameters to requires_grad=True masking the changes done in enable_input_require_grads, so the check is always true regardless of what that helper function does? Maybe targeting something that is clearly not an input, something resembling an attention layer for example, would be better?
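For context, the usual workaround for that reentrant-checkpointing limitation is to force the output of the input embeddings to require grad. A minimal sketch of the pattern (the model name is only an example, and this is not the exact code from this PR):

```python
from transformers import AutoModelForCausalLM

# Minimal sketch: hook the input embeddings so the first activation requires grad,
# which lets checkpointed segments backpropagate even when all the early
# parameters (e.g. the embedding weights) stay frozen.
model = AutoModelForCausalLM.from_pretrained("gpt2")

def make_inputs_require_grads(module, inputs, output):
    output.requires_grad_(True)

hook = model.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
# ... fine-tune ...
hook.remove()  # detach the hook once training is done
```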
I see, hmm. I followed the implementation of idefics2/smolvlm as I remembered they faced this issue at the time. You're right that this isn't necessary; we register twice. The lowest-module trick should work though, and I'm not sure targeting an attention layer works either. Currently @BenjaminBossan 's script outputs grad norms properly with gradient checkpointing enabled and PEFT disabled on this branch, so it seems to do the trick?
no GC:
{'loss': 9.4971, 'grad_norm': 23.421083450317383, 'learning_rate': 2e-05, 'epoch': 0.33}
{'loss': 7.9526, 'grad_norm': 675.1868896484375, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.67}

with GC:
{'loss': 9.4971, 'grad_norm': 23.421083450317383, 'learning_rate': 2e-05, 'epoch': 0.33}
{'loss': 7.9526, 'grad_norm': 675.1868896484375, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.67}

In either case: agree the double registering is useless, will remove!
Yeah, I think the implementation is fine. I'm just worried that the test is masking the behavior of the fix and is therefore not honest enough. Sorry if I didn't make that clear.
No that's fair, I'll revamp the test for a narrower scope!
I think this solution works only for VLMs and also depends a lot on how the vision model is named. I'm sure we listed all possible names, but new models can get creative with it.
So I'm thinking that we could potentially make it work OOTB for all MLLMs (audio/vision/omni) by checking for each PreTrainedModel within the model and then setting grads on that model's inputs (model.get_input_embeddings()).
We use a similar trick when setting attention implementations and checking for PreTrainedModel instances, so it could be a good option. WDYT?
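Roughly, that proposal could look like the sketch below (the helper names and the defensive try/except are assumptions, not the final PR code): walk every PreTrainedModel submodule and hook its own input embeddings, instead of guessing how the vision tower is named.

```python
from transformers import PreTrainedModel


def make_inputs_require_grads(module, inputs, output):
    output.requires_grad_(True)


def enable_grads_on_all_submodel_inputs(model):
    """Illustrative sketch: hook the input embeddings of every PreTrainedModel
    submodule (text, vision, audio, ...), deduplicated by embedding identity."""
    hooks, seen = [], set()
    for submodule in model.modules():  # includes `model` itself
        if not isinstance(submodule, PreTrainedModel):
            continue
        try:
            embeddings = submodule.get_input_embeddings()
        except NotImplementedError:  # some submodels don't expose input embeddings
            continue
        if embeddings is None or id(embeddings) in seen:
            continue
        seen.add(id(embeddings))
        hooks.append(embeddings.register_forward_hook(make_inputs_require_grads))
    return hooks
```

Deduplicating by id matters because nested wrappers (e.g. a text model inside a conditional-generation class) can return the same embedding module twice.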
Thanks, yes, it's a far less brittle option. There are a few (really a few, and hopefully zero after v5) modules that were just …
Sorry, I may misunderstand the proposed solution, but this doesn't seem to solve the problem? In a VLM where I target a module in the vision stack, I need the vision model's inputs to require grads, not the language model's input (…)
@githubnemo …
BTW, while working on something else I noticed that we have code like the below, which can be deleted after this PR: transformers/src/transformers/models/idefics2/modeling_idefics2.py, lines 1026 to 1030 at 80134e6.
Yes, for all idefics/smolvlm there'll be no need for that. Should ship that today (finally).
Iterated a bit on that and hit a dead end on idefics2/3 code, back at it tomorrow!
The failing tests appear unrelated (I rebased on main). What do you think of the new method, @zucchini-nlp? Also @githubnemo, I updated the test a tad, let me know.
zucchini-nlp
left a comment
Love the clean-up! Only one major question about tests: it would be super super cool to have a common test imo, though I realize it can be hard with multimodal models.
def enable_input_require_grads(self):
    """
    Enables the gradients for the input embeddings.
    This is useful for lora when using gradient checkpointing.
    c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
    """
nice 🔪
if hooks:
    # for BC
    self._require_grads_hook = hooks[0]
Aren't we ignoring all hooks except the first one in this case? I.e., when we disable, will it disable the text model hook but not the vision model hook?
I don't think so; this is just because we used to remove _require_grads_hook. Now we always iterate over the full list _require_grads_hooks (with an "s"), so every registered hook (vision, text, or whatever) should be removed.
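In other words, something shaped like this sketch (an illustrative free function, not the PR's method):

```python
def remove_all_require_grads_hooks(model):
    """Illustrative helper: remove every forward hook registered by
    enable_input_require_grads, assuming they are kept in a list attribute."""
    for hook in getattr(model, "_require_grads_hooks", []):
        hook.remove()
    model._require_grads_hooks = []
```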
ahh my bad, didn't see the "s" at the end
might be a bad naming then haha
for module in self.modules():
    if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
        continue

    input_embeddings = module.get_input_embeddings()

    if input_embeddings is None:
        continue

    embedding_id = id(input_embeddings)
    if embedding_id in seen_modules:
        continue

    seen_modules.add(embedding_id)
    hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
super clean!
self._require_grads_hooks = []
if hasattr(self, "_require_grads_hook"):
    del self._require_grads_hook
out of curiosity, is it required to explicitly delete?
Just out of safety. I'm not certain it's always necessary, but not knowing what people were doing with that hook in their FT scripts, I think it's safer to remove it so no reference remains.
def test_multi_gpu_data_parallel_forward(self):
    pass


def test_enable_input_require_grads_with_gradient_checkpointing(self):
I'm thinking, can we make a common test for all models?
eeeh... I think we should :D yes
will look before EOD if I have time
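For what it's worth, a common test could boil down to a check along these lines (the names, the reentrant flag, and the loss fallback are assumptions, not the test that landed in this PR): enable gradient checkpointing and input-require-grads, run one backward pass, and assert that gradients actually flow.

```python
def check_grads_flow_with_gradient_checkpointing(model, inputs):
    """Illustrative check, not the PR's test: with reentrant gradient
    checkpointing and enable_input_require_grads(), one backward pass should
    leave a gradient on at least one trainable parameter."""
    model.train()
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
    model.enable_input_require_grads()

    outputs = model(**inputs)
    loss = getattr(outputs, "loss", None)
    if loss is None:
        loss = outputs.logits.float().mean()  # fallback when no labels are in `inputs`
    loss.backward()

    assert any(p.grad is not None for p in model.parameters() if p.requires_grad)
```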
ArthurZucker
left a comment
Very nice!
run-slow: bart, blip_2, idefics2, idefics3, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_omni_moe, smolvlm, timm_wrapper
This comment contains models: ["models/bart", "models/blip_2", "models/idefics2", "models/idefics3", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_omni_moe", "models/smolvlm", "models/timm_wrapper"]
CI Results: ✅ No failing test specific to this PR 🎉!
[For maintainers] Suggested jobs to run (before merge): run-slow: bart, blip_2, idefics2, idefics3, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_omni_moe, smolvlm, timm_wrapper
ArthurZucker
left a comment
What does this PR do?
As per title. Linked to huggingface/peft#2880.
Follows more or less closely the already existing implementations for idefics2/3 and smolvlm, trying to cover several types of VLMs (they are named differently across the library).
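For illustration, the end-user flow this is meant to unblock looks roughly like the snippet below (the model id and LoRA targets are placeholders, not taken from the PR or the linked issue):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForImageTextToText

# Placeholder checkpoint and targets: the point is that LoRA on vision-tower
# modules combined with reentrant gradient checkpointing should now yield
# non-zero grad norms instead of silently producing no gradients.
model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()

lora_config = LoraConfig(target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```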