Add adapter_only option to save_fsdp_model and load_fsdp_model to only save/load PEFT weights #2321

Merged · 8 commits · Jan 26, 2024

Conversation

@AjayP13 (Contributor) commented on Jan 9, 2024

What does this PR do?
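A minimal usage sketch of the option this PR adds (illustrative only: it assumes the existing save_fsdp_model/load_fsdp_model call signatures, a script launched with an FSDP accelerate config, and an already-built PEFT model named peft_model):

    from accelerate import Accelerator
    from accelerate.utils import load_fsdp_model, save_fsdp_model

    accelerator = Accelerator()              # launched via `accelerate launch` with FSDP enabled
    model = accelerator.prepare(peft_model)  # peft_model: a PEFT-wrapped transformer (assumed to exist)

    # Save only the PEFT adapter weights instead of the full model state dict
    save_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, "ckpt", adapter_only=True)

    # Later, load only the adapter weights back into a freshly prepared PEFT model
    load_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, "ckpt", adapter_only=True)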

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@AjayP13 marked this pull request as ready for review on January 9, 2024 at 17:44.
@younesbelkada (Contributor) left a comment

Thanks a lot! I left one comment. I will let @pacman100 give his final opinion on the PR, as he is more familiar with FSDP and accelerate.

    def _is_peft_model(model):
        if is_peft_available():
            from peft import PeftModel

            unwrapped_model = getattr(model.module, "_orig_mod", model.module)
@younesbelkada (Contributor):

Do we have an unwrap model utility method in accelerate? If model is not a DDP model, .module will fail here, no?

@AjayP13 (Contributor, Author):

I think you meant FSDP, not DDP, but this should be safe: these functions (save_fsdp_..., load_fsdp_...) are only used with FSDP-wrapped models, so they will always have .module.
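For context, a minimal illustration of the unwrapping step under discussion (the helper name is hypothetical; it just mirrors the line under review):

    def unwrap_fsdp_module(model):  # hypothetical helper mirroring the reviewed line
        # FSDP (like DDP) keeps the wrapped network on `.module`, and torch.compile
        # keeps the original module on `._orig_mod`, so this peels both layers.
        # It raises AttributeError for a model that was never FSDP/DDP-wrapped,
        # which is exactly the concern raised above.
        inner = model.module
        return getattr(inner, "_orig_mod", inner)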

@younesbelkada (Contributor):

OK, perfect then! Maybe worth mentioning in a comment that _is_peft_model is only meant to be used for FSDP models, or maybe change the method name to _is_fsdp_peft_model.

@muellerzr (Collaborator):

I think extract_model_from_parallel in accelerate should cover this case if desired.

@muellerzr (Collaborator):

But worth double checking/writing a test :)
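A sketch of what the suggested check could look like (illustrative only, not the PR's exact code; the helper name and the import paths are assumptions):

    from accelerate.utils import extract_model_from_parallel, is_peft_available  # assumed import paths

    def is_peft_wrapped_model(model):  # hypothetical name for the check
        # extract_model_from_parallel strips DDP/FSDP wrappers, so the check no
        # longer relies on the model having a `.module` attribute.
        if not is_peft_available():
            return False
        from peft import PeftModel  # local import: peft is an optional dependency
        return isinstance(extract_model_from_parallel(model), PeftModel)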

@AjayP13 (Contributor, Author):

Thanks @muellerzr, I've just updated this with the unwrapping utility and re-ran the test we have on the other branch on the multi-GPU machine, and it works.

@pacman100 (Contributor) left a comment

Thank you @AjayP13, the changes in this PR make sense to me. However, they have only been tested with the FULL_STATE_DICT state dict type; would this work with SHARDED_STATE_DICT too? If not, this logic should be limited to the cases where FULL_STATE_DICT is used.

@AjayP13 (Contributor, Author) commented on Jan 18, 2024

Thanks for the review @pacman100. I just tested this, and it works with SHARDED_STATE_DICT as well (similarly, the loss drops, the sharded state dicts per worker get saved to disk, and both resuming and load_best work).
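For reference, a rough sketch of how the state dict type under discussion is usually selected (hedged: this assumes accelerate's FSDP plugin reads the FSDP_STATE_DICT_TYPE setting from the accelerate config/environment, and that the script runs under a distributed FSDP launch):

    import os

    # The layout FSDP uses when accelerate saves/loads checkpoints is governed by
    # the FSDP plugin's state_dict_type, normally set in the accelerate config file
    # or through the corresponding environment variable.
    os.environ["FSDP_STATE_DICT_TYPE"] = "SHARDED_STATE_DICT"  # or "FULL_STATE_DICT"

    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    fsdp_plugin = FullyShardedDataParallelPlugin()      # picks up the setting above
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)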

@pacman100 (Contributor) left a comment

Thank you @AjayP13 for enabling storage of only the adapter weights when using PEFT+FSDP, very useful! ✨

@muellerzr (Collaborator) left a comment

Thanks! Overall this seems very helpful, but let's work on the design a bit to adhere more closely to our practices. Left some comments and suggestions 😄

Comment on lines 36 to 57

    def _is_peft_model(model):
        if is_peft_available():
            from peft import PeftModel
        return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel)


    def _get_model_state_dict(model, adapter_only=False):
        if adapter_only and _is_peft_model(model):
            from peft import get_peft_model_state_dict

            return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
        else:
            return model.state_dict()


    def _set_model_state_dict(model, state_dict, adapter_only=False):
        if adapter_only and _is_peft_model(model):
            from peft import set_peft_model_state_dict

            return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter)
        else:
            return model.load_state_dict(state_dict)
@muellerzr (Collaborator):

I'm not the biggest fan of this pattern we're working with here.

These should live inside one function, and from what I can tell we now have two different ways of getting the state_dict inside accelerate:

  1. Here
  2. Accelerator.get_state_dict

This hints to me that perhaps we need to chunk up Accelerator's get_state_dict into something that can be called elsewhere (and stored in utils.modeling, probably).

Same with _set_state_dict. While yes, it's just FSDP for now, we can write a check for that if necessary.

Secondly: we generally don't follow the practice of hidden function names in Accelerate; everything should be made public except in extreme circumstances, and I'm not convinced this is one of those.

Can we rewrite this a bit to make it more extensible?

Perhaps just do an if/else for calling set_peft_model_state_dict or not, and just import them at the top of the file. That way we can avoid this entirely.

@AjayP13 (Contributor, Author):

@muellerzr See the newest refactor, but I believe this is not possible.

Accelerator.get_state_dict already has logic for FSDP, but it handles a different use case: it always returns the FULL_STATE_DICT from FSDP. Meanwhile, this file gets the state dict according to how the user configured it to be saved (which could be a SHARDED_STATE_DICT). I believe that is why this file originally never called Accelerator.get_state_dict.

As for your suggestion of importing get_peft_model_state_dict at the top of the file vs. encapsulating it in the hidden function: we can't import it at the top because accelerate does not have a dependency on peft, and a top-level import would throw an error if peft is not installed on a user's machine. Also, the hidden function _get_model_state_dict is used three times in this file; I could get rid of it and use an inline if-statement in those three places, but that would introduce three pieces of repeated code that need to be updated together. Between the repeated code and the need to keep the peft imports local rather than at the top of the file, I think it makes sense to keep the two hidden functions here.
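To make the distinction above concrete, a rough sketch (paraphrased, not the real accelerate code; it assumes an already FSDP-wrapped model) of what a forced full gather looks like, as opposed to fsdp_utils honouring whatever state dict type the user configured on the FSDP plugin:

    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def gather_full_state_dict(fsdp_model):
        # Roughly what "always returns the FULL_STATE_DICT" means: a consolidated
        # state dict is gathered (offloaded to CPU, rank 0 only), regardless of
        # the state dict type configured for saving elsewhere.
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
            return fsdp_model.state_dict()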

@AjayP13 (Contributor, Author):

@muellerzr Have you had a chance to look this over? Waiting on this PR to get huggingface/transformers#28297 merged.

Resolved review threads: src/accelerate/utils/fsdp_utils.py (outdated), src/accelerate/utils/imports.py
@muellerzr (Collaborator) left a comment

Thanks, your comments make sense!

@muellerzr merged commit 581fabb into huggingface:main on Jan 26, 2024 · 23 checks passed