Incorrect Saving Peft Models using HuggingFace Trainer #96
Comments
Another way to save storage, without any code modifications in HF, is to create a callback:
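The original callback is not shown here; a minimal sketch, assuming it saves the adapter once at the end of training (the class name and exact behavior are assumptions, per the follow-up below):

```python
import os

from transformers import TrainerCallback


class SavePeftAtTrainEndCallback(TrainerCallback):  # hypothetical reconstruction, not the original code
    def on_train_end(self, args, state, control, **kwargs):
        # Save only the PEFT adapter weights, not the full base model
        peft_model_path = os.path.join(args.output_dir, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)
```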
That is a really clean way to save the PEFT checkpoints; I think that should serve the purpose.
Very nice solution, @agemagician. I have adapted it to the use case of saving at steps/epochs instead of at the end of training:

```python
import os

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )

        # Save only the adapter weights inside the checkpoint folder
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Remove the full model weights that the Trainer already wrote
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

        return control
```
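For reference, such a callback would be attached when constructing the Trainer. A hypothetical setup, where peft_model, training_args, and train_dataset are placeholders for your own objects:

```python
from transformers import Trainer

# Hypothetical usage; peft_model, training_args, and train_dataset are placeholders.
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[SavePeftModelCallback()],
)
trainer.train()
```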
How do you resume_from_checkpoint with this approach? Removing pytorch_model.bin causes "ValueError: Can't find a valid checkpoint".
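One possible workaround, not confirmed in this thread: skip resume_from_checkpoint and reload the saved adapter manually before continuing training. A sketch, assuming a checkpoint path like the one produced above:

```python
# A possible workaround (an assumption, not an answer given in this thread):
# rebuild the base model and load the saved adapter on top of it.
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM

base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl")
model = PeftModel.from_pretrained(
    base_model, "output_dir/checkpoint-500/adapter_model"  # hypothetical path
)
```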
Is there any way to save only the PEFT model with such a callback? As far as I understand the callbacks, they save the whole model and later remove the main model. The problem is that when training in 8-bit mode, this leads to a crash because of OOM.
You can do so by subclassing the Trainer class and overriding the _save_checkpoint method:

```python
import os

import numpy as np
from transformers import (
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class PeftTrainer(Trainer):
    def _save_checkpoint(self, _, trial, metrics=None):
        """Don't save the base model, optimizer, etc.,
        but create the checkpoint folder (needed for saving the adapter)."""
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)

        # Keep the best-checkpoint bookkeeping from the original _save_checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (self.state.best_metric is None or self.state.best_model_checkpoint is None
                    or operator(metric_value, self.state.best_metric)):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        os.makedirs(output_dir, exist_ok=True)

        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)


class PeftSavingCallback(TrainerCallback):
    """Correctly save the PEFT model and not the full model."""

    def _save(self, model, folder):
        peft_model_path = os.path.join(folder, "adapter_model")
        model.save_pretrained(peft_model_path)

    def on_train_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """Save the final best model adapter."""
        self._save(kwargs["model"], state.best_model_checkpoint)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """Save intermediate model adapters in case of interrupted training."""
        folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        self._save(kwargs["model"], folder)
```
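A hypothetical way to wire these two classes together (variable names and arguments are illustrative):

```python
# Hypothetical usage of the PeftTrainer and PeftSavingCallback defined above.
trainer = PeftTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[PeftSavingCallback()],
)
trainer.train()
```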
Hello there, may I ask how you defined model_init in this use case? I am not sure of the appropriate way to do it with LoRA (compared to the usual way outlined in this tutorial). Many thanks. @agemagician @pie3636
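For what it's worth, one possible way to define model_init for hyperparameter_search with LoRA (a sketch with assumed hyperparameters, not an answer given in this thread) is to return a freshly wrapped PEFT model on every call, so each trial starts from the same base weights:

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM


def model_init():
    # Rebuild the model from scratch for every trial; r, lora_alpha, and
    # lora_dropout are illustrative values, not ones recommended in this thread.
    base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl")
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=32, lora_dropout=0.1
    )
    return get_peft_model(base_model, lora_config)
```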
Hello,
Thanks a lot for the great project.
I am fine-tuning Flan-T5-XXL using HuggingFace Seq2SeqTrainer and hyperparameter_search.
However, the Trainer doesn't store PEFT models correctly, because they are not of the "PreTrainedModel" type.
It stores the whole PyTorch model, including Flan-T5-XXL, which is around 42 GB.
I have dug into the code, and I made a hacky solution inside "trainer.py" for now:
Do you have a better solution for saving the "Peft models" correctly using HuggingFace Seq2SeqTrainer and hyperparameter_search?