Fix resuming PeftModel checkpoints in Trainer #24274

Merged
Changes from 2 commits
28 changes: 26 additions & 2 deletions src/transformers/trainer.py
@@ -1997,14 +1997,23 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
            model = self.model

        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)

        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

Comment on lines +2000 to +2001

Collaborator: Just for my own understanding, why don't we need the equivalent adapter_weights_index_file or adapter_safe_weights_index_file here?

Contributor: Sure! I think index files are only used for sharded checkpoints. In PEFT the saved checkpoints are always extremely light (on the order of a few MB), even for very large models, so we never save sharded checkpoints and therefore never write index files.

Collaborator: Thanks for explaining!

        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

        if not any(
            os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
            os.path.isfile(f)
            for f in [
                weights_file,
                safe_weights_file,
                weights_index_file,
                safe_weights_index_file,
                adapter_weights_file,
                adapter_safe_weights_file,
            ]
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

@@ -2057,6 +2066,21 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                # release memory
                del state_dict
                self._issue_warnings_after_load(load_result)

        # Load adapters following PR # 24096
        elif is_peft_available() and isinstance(model, PeftModel):
            # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
            if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                if os.path.exists(resume_from_checkpoint) or os.path.exists(resume_from_checkpoint):

Collaborator (@amyeroberts, Jun 19, 2023): In this "if x or y" check, x and y are the same.

Contributor: Nice catch!
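
As flagged above, the duplicated operand makes the "or" a no-op. Purely for illustration (the actual follow-up commit may resolve it differently), a tightened version could probe the adapter weight files computed earlier in this method:

# Illustrative sketch only, not the literal fix merged in this PR.
if os.path.exists(adapter_weights_file) or os.path.exists(adapter_safe_weights_file):
    model.load_adapter(resume_from_checkpoint, model.active_adapter)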

                    model.load_adapter(resume_from_checkpoint, model.active_adapter)
                else:
                    logger.warning(
                        "The intermediate checkpoints of PEFT may not be saved correctly, "
                        f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
                        "here are some examples https://github.com/huggingface/peft/issues/96"

Collaborator: I'm not able to parse this warning sentence. Is it saying I should use TrainerCallback to resolve this issue, that using TrainerCallback caused this issue, or that TrainerCallback will be used as a result of this warning?

Contributor: Yeah, agreed. Since this was a copy-paste, I have modified the original version of the warning as well.

                    )
            else:
                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(
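
For readers hitting the warning above: it points at saving the adapter weights yourself on every checkpoint via a TrainerCallback (see the linked peft issue). A minimal sketch of that idea, assuming the Trainer is wrapping a PEFT model; the class name and folder layout are illustrative, not from this PR:

import os

from transformers import TrainerCallback


class SavePeftAdapterCallback(TrainerCallback):
    # Whenever the Trainer saves a checkpoint, also write the adapter weights
    # (adapter_config.json plus the adapter weight file) into the matching
    # checkpoint-<step> folder so a later resume_from_checkpoint can find them.
    def on_save(self, args, state, control, **kwargs):
        checkpoint_folder = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_folder)
        return control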