-
Notifications
You must be signed in to change notification settings - Fork 25.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Fix resuming PeftModel checkpoints in Trainer #24274
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1997,14 +1997,23 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): | |
model = self.model | ||
|
||
config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) | ||
|
||
adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME) | ||
adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) | ||
weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) | ||
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) | ||
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) | ||
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) | ||
|
||
if not any( | ||
os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file] | ||
os.path.isfile(f) | ||
for f in [ | ||
weights_file, | ||
safe_weights_file, | ||
weights_index_file, | ||
safe_weights_index_file, | ||
adapter_weights_file, | ||
adapter_safe_weights_file, | ||
] | ||
): | ||
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") | ||
|
||
|
@@ -2057,6 +2066,25 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): | |
# release memory | ||
del state_dict | ||
self._issue_warnings_after_load(load_result) | ||
|
||
# Load adapters following PR # 24096 | ||
elif is_peft_available() and isinstance(model, PeftModel): | ||
# If training a model using PEFT & LoRA, assume that the adapter has been saved properly. | ||
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): | ||
if os.path.exists(resume_from_checkpoint) or os.path.exists(resume_from_checkpoint): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice catch! |
||
model.load_adapter(resume_from_checkpoint, model.active_adapter) | ||
# `load_adapter` currently has no return value; update this when it returns one. | ||
from torch.nn.modules.module import _IncompatibleKeys | ||
|
||
load_result = _IncompatibleKeys([], []) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe just set it to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or maybe completely remove that line since |
||
else: | ||
logger.warning( | ||
"The intermediate checkpoints of PEFT may not be saved correctly, " | ||
f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, " | ||
"here are some examples https://github.com/huggingface/peft/issues/96" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not able to parse this warning sentence. Is it saying I should use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah agreed, as this was a copy paste, I have modified the original version of the warning as well |
||
) | ||
else: | ||
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") | ||
else: | ||
# We load the sharded checkpoint | ||
load_result = load_sharded_checkpoint( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just for my own understanding, why don't we need the equivalent
adapter_weights_index_file
or adapter_safe_weights_index_file
here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure! I think index files are used for sharded checkpoints only. In PEFT the saved checkpoints are always extremely light (order of magnitude of few MBs) - even for very large models - thus we never save sharded checkpoint, therefore we don't save index files!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for explaining!