Add `adapter_only` option to `save_fsdp_model` and `load_fsdp_model` to only save/load PEFT weights #2321
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot! I left one comment - I will let @pacman100 give his final opinion on the PR, as he is more familiar with FSDP and accelerate.
src/accelerate/utils/fsdp_utils.py (Outdated)
```python
def _is_peft_model(model):
    if is_peft_available():
        from peft import PeftModel

        unwrapped_model = getattr(model.module, "_orig_mod", model.module)
```
Do we have an unwrap-model utility method in accelerate? If `model` is not a DDP model, `model.module` will fail here, no?
I think you meant FSDP, not DDP, but this should be safe: these functions (`save_fsdp_...`, `load_fsdp_...`) are only used with FSDP-wrapped models, and therefore they will always have `.module`.
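For reference, a minimal sketch (not code from this PR) of what that guarantee looks like; it assumes a single CUDA device and is launched with `torchrun --nproc_per_node=1`:

```python
# Sketch only: illustrates that an FSDP-wrapped model always exposes the inner
# model via `.module`, which the `.module` access in the diff above relies on.
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

fsdp_model = FSDP(nn.Linear(8, 8).cuda())
print(type(fsdp_model.module))  # the original nn.Linear, unwrapped from FSDP
dist.destroy_process_group()
```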
Ok, perfect then! Maybe worth mentioning in a comment that `_is_peft_model` is only meant to be used for FSDP models, or maybe change the method name to `_is_fsdp_peft_model`.
I think `extract_model_from_parallel` in accelerate should cover this case if desired.
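For illustration, a quick sketch (not part of the PR) of how `extract_model_from_parallel` unwraps a parallel container; `nn.DataParallel` is used here only because it can be constructed without a process group, but FSDP/DDP containers expose the inner model through `.module` in the same way:

```python
from torch import nn
from accelerate.utils import extract_model_from_parallel

base = nn.Linear(4, 4)
wrapped = nn.DataParallel(base)  # parallel container exposing the model as `.module`

unwrapped = extract_model_from_parallel(wrapped)
print(type(unwrapped))  # <class 'torch.nn.modules.linear.Linear'>, the inner model
```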
But worth double checking/writing a test :)
Thanks @muellerzr, I've just updated this with the unwrapping utility and re-ran the test we have on the other branch on the multiple-GPU machine and it works.
Thank you @AjayP13, the changes in this PR make sense to me. However, they have only been tested with the `FULL_STATE_DICT` state dict type; would this work with `SHARDED_STATE_DICT` too? If not, then this logic should be limited to the cases wherein `FULL_STATE_DICT` is used.
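For context, the two flavours under discussion map to PyTorch's `StateDictType`; a rough sketch (not from this PR, assuming `fsdp_model` is an FSDP-wrapped model inside an initialized process group):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# FULL_STATE_DICT: ranks materialize the full, unsharded weights
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
    full_sd = fsdp_model.state_dict()

# SHARDED_STATE_DICT: each rank only produces its own shard of the weights
with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
    sharded_sd = fsdp_model.state_dict()
```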
Thanks for the review @pacman100, I just tested this, and it works with `SHARDED_STATE_DICT` as well.
Thank you @AjayP13 for enabling only the storage of adapter weights when using PEFT+FSDP, very useful! ✨
Thanks! Overall this seems very helpful, but let's work on the design a bit to adhere more closely to our practices. Left some comments and suggestions 😄
src/accelerate/utils/fsdp_utils.py (Outdated)
```python
def _is_peft_model(model):
    if is_peft_available():
        from peft import PeftModel

        return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel)


def _get_model_state_dict(model, adapter_only=False):
    if adapter_only and _is_peft_model(model):
        from peft import get_peft_model_state_dict

        return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
    else:
        return model.state_dict()


def _set_model_state_dict(model, state_dict, adapter_only=False):
    if adapter_only and _is_peft_model(model):
        from peft import set_peft_model_state_dict

        return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter)
    else:
        return model.load_state_dict(state_dict)
```
I'm not the biggest fan of this pattern we're working with here.
These should live inside of one function, and currently, from what I can tell, we have 2 different ways of getting the state_dict inside accelerate now:
- here
- `Accelerator.get_state_dict`

This hints to me that perhaps we need to chunk up `Accelerator`'s `get_state_dict` into something that can be called elsewhere (and stored in `utils.modeling`, probably).
Same with `_set_state_dict` as well. While yes, it's just FSDP, we can write a check for that if necessary.
Secondly: we generally don't follow the practice of hidden function names in Accelerate; everything should be made public except in extreme circumstances, and I'm not convinced this is one of those.
Can we rewrite this a bit to make it more extensible?
Perhaps just do an `if/else` for calling `set_peft_model_state_dict` or not, and just import them at the top of the file. This way we can avoid this entirely.
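A rough sketch of the suggested shape (hypothetical helper name; it presumes `peft` could be imported at the top of `fsdp_utils.py`, which the next comment pushes back on):

```python
# Hypothetical top-of-file import, only valid if peft were a hard dependency
from peft import get_peft_model_state_dict


def get_state_dict_to_save(model, adapter_only=False):
    """Hypothetical inlined helper illustrating the suggested if/else shape."""
    if adapter_only:
        return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
    return model.state_dict()
```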
@muellerzr See the newest refactor, but I believe this is not possible. `Accelerator.get_state_dict` already has logic for FSDP, but it handles a different use case for getting the FSDP state dict: `Accelerator.get_state_dict` always returns the `FULL_STATE_DICT` from FSDP. Meanwhile, this file gets the state dict according to how the user configured it to be saved (which could be a `SHARDED_STATE_DICT`). I believe that is why this file originally never called `Accelerator.get_state_dict`.
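Roughly, the FSDP branch of `Accelerator.get_state_dict` boils down to something like the following (paraphrased sketch, not the exact source, assuming `model` is the FSDP-wrapped model), which is why it always produces a full state dict:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Gather the full, unsharded weights onto rank 0, offloaded to CPU,
# regardless of which state_dict_type the user configured in the FSDP plugin.
full_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_config):
    state_dict = model.state_dict()
```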
As for your suggestion of importing `get_peft_model_state_dict` at the top vs. encapsulated in the hidden function: we can't import at the top because accelerate does not have a dependency on `peft`, and importing at the top would throw an error if `peft` is not available on a user's machine. Also, the hidden function `_get_model_state_dict` is used 3 times in this file. I could get rid of the hidden function and replace it with an inline if-statement in those 3 places, but that would introduce 3 places of repeated code that would need to be updated together. Between the repeated code and the need to keep the `peft` imports local vs. top-of-file, I think it makes sense to keep the two hidden functions here.
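A minimal illustration (hypothetical helper name, not the PR diff; it assumes `is_peft_available` is exposed in `accelerate.utils`, as the snippet above suggests) of the optional-dependency pattern being defended here: `peft` is imported lazily, only on the code path that needs it, so accelerate still imports cleanly when `peft` is absent.

```python
from accelerate.utils import is_peft_available


def get_adapter_state_dict(model):
    # Hypothetical helper: fail loudly when peft is missing, rather than at import time.
    if not is_peft_available():
        raise ImportError("Saving adapter-only checkpoints requires the `peft` library.")
    from peft import get_peft_model_state_dict  # local import keeps peft optional

    return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
```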
@muellerzr Have you had a chance to look this over? Waiting on this PR to get huggingface/transformers#28297 merged.
Thanks, your comments make sense!
What does this PR do?
- See the corresponding PR on `transformers` for motivation: Support saving only PEFT adapter in checkpoints when using PEFT + FSDP (huggingface/transformers#28297).
- Adds an `adapter_only` parameter (default off) to `save_fsdp_model` and `load_fsdp_model` that controls whether only the PEFT weights will be saved/loaded (see the usage sketch below).
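A hedged usage sketch for the new flag (argument order follows `fsdp_utils` at the time of this PR; `accelerator` and `model` are assumed to already be prepared with an FSDP plugin and a PEFT model):

```python
from accelerate.utils import save_fsdp_model, load_fsdp_model

fsdp_plugin = accelerator.state.fsdp_plugin

# Save and restore only the PEFT adapter weights instead of the full model
save_fsdp_model(fsdp_plugin, accelerator, model, "checkpoints/step_100", adapter_only=True)
load_fsdp_model(fsdp_plugin, accelerator, model, "checkpoints/step_100", adapter_only=True)
```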
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?