Incorrect Saving Peft Models using HuggingFace Trainer #96

Closed · agemagician opened this issue Feb 16, 2023 · 8 comments · Label: solved

Comments

@agemagician
Hello,

Thanks a lot for the great project.

I am fine-tuning Flan-T5-XXL using HuggingFace Seq2SeqTrainer and hyperparameter_search.
However, the trainer doesn't save PEFT models correctly, because a PEFT model is not of the "PreTrainedModel" type.
Instead, it saves the whole PyTorch model, including the Flan-T5-XXL base weights, which is around 42 GB.

I dug into the code and made a hacky workaround inside "trainer.py" for now:

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # PEFT models: save only the adapter weights and config via peft's own save_pretrained
        from peft.peft_model import PeftModelForSeq2SeqLM
        if isinstance(self.model, PeftModelForSeq2SeqLM):
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        elif not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                if state_dict is None:
                    state_dict = self.model.state_dict()
                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if state_dict is None:
                    state_dict = self.model.state_dict()
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
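
With this branch in place, save_pretrained on the PeftModel writes only the adapter weights and config, which for a LoRA adapter is typically a few tens of MB (depending on the LoRA config) instead of ~42 GB. For completeness, a hedged sketch of reloading such a checkpoint for inference; the output path, dtype, and device_map are assumptions, not taken from this issue, and device_map="auto" requires accelerate:

import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the full base model first, then attach the saved adapter on top of it.
base = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-xxl", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
model = PeftModel.from_pretrained(base, "path/to/output_dir")  # dir written by the hacked _save
model.eval()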

Do you have a better solution for saving PEFT models correctly when using the HuggingFace Seq2SeqTrainer and hyperparameter_search?

@agemagician (Author)

Another way I came up with to save storage, without any code modifications in HF, is to create a callback:

import os

from transformers.trainer_callback import TrainerCallback


class PeftSavingCallback(TrainerCallback):
    def on_train_end(self, args, state, control, **kwargs):
        # Save only the adapter weights of the best checkpoint
        peft_model_path = os.path.join(state.best_model_checkpoint, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Remove the full model weights the Trainer already wrote for that checkpoint
        pytorch_model_path = os.path.join(state.best_model_checkpoint, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

trainer = Seq2SeqTrainer(
    # model is omitted on purpose: model_init is used (required for hyperparameter_search)
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    model_init=model_init,
    callbacks=[PeftSavingCallback],
)
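
One note on this callback (an observation, not from the thread): on_train_end reads state.best_model_checkpoint, which the Trainer only populates when metric_for_best_model is set and evaluation metrics are available when a checkpoint is saved. A hedged sketch of Seq2SeqTrainingArguments that satisfy this; all values are illustrative assumptions:

from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="flan-t5-xxl-lora",
    evaluation_strategy="epoch",        # evaluate so metrics exist when checkpoints are saved
    save_strategy="epoch",              # must match the evaluation schedule
    metric_for_best_model="eval_loss",  # needed for state.best_model_checkpoint to be set
    greater_is_better=False,
    load_best_model_at_end=True,        # so the model seen in on_train_end is the best one
)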

@pacman100 (Contributor)

That is a really clean way to save the PEFT checkpoints; I think it should serve the purpose.

pacman100 added the solved label on Feb 24, 2023
@sorenmulli commented Mar 8, 2023

Very nice solution, @agemagician. I have adapted it to the use case of saving at steps/epochs instead of at the end of training:

import os

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )

        # Save only the adapter weights for this checkpoint
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Delete the full model weights the Trainer just wrote
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control
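
Note that this variant writes the adapter into every checkpoint-<global_step>/adapter_model folder, but the Trainer still writes the full pytorch_model.bin first and the callback only deletes it afterwards, so the large temporary save still happens; the missing pytorch_model.bin is also what the resume_from_checkpoint and 8-bit questions below run into.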

@github-actions bot commented Apr 3, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@jonathanasdf

How do you resume_from_checkpoint with this approach? Removing pytorch_model.bin causes "ValueError: Can't find a valid checkpoint"
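
One hedged, untested workaround (not confirmed in this thread): since only the adapter survives in the checkpoint, skip the Trainer's resume_from_checkpoint and warm-start a fresh run by loading the adapter back into the base model; optimizer and scheduler state are not restored this way. Paths and the step number below are assumptions:

from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM

base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl")
# is_trainable=True (available in recent peft versions) keeps the LoRA weights trainable
model = PeftModel.from_pretrained(
    base, "output/checkpoint-500/adapter_model", is_trainable=True
)
# pass `model` to a new Trainer and call trainer.train() without resume_from_checkpoint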

@SebastianBodza

Is there any way to only save the PEFT model with such a callback? As far as I understood the callbacks, they save the whole model and later remove the main model.

The problem is that, when training in 8-bit mode, this leads to a crash because of OOM.

@pie3636 commented May 16, 2023

> Is there any way to only save the PEFT model with such a callback? As far as I understood the callbacks, they save the whole model and later remove the main model.
>
> The problem is that, when training in 8-bit mode, this leads to a crash because of OOM.

You can do so by subclassing the Trainer class and overriding the _save_checkpoint method, in addition to using callbacks. This is what worked in my case, but I only kept the parts of _save_checkpoint that I needed, so you might need to adapt the code for your use:

import os

import numpy as np
from transformers import (
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class PeftTrainer(Trainer):
    def _save_checkpoint(self, _, trial, metrics=None):
        """ Don't save the base model, optimizer etc.,
            but create the checkpoint folder (needed for saving the adapter). """
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)

        # Keep the Trainer's best-checkpoint bookkeeping so state.best_model_checkpoint is set
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (self.state.best_metric is None or self.state.best_model_checkpoint is None
                    or operator(metric_value, self.state.best_metric)):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        os.makedirs(output_dir, exist_ok=True)

        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)


class PeftSavingCallback(TrainerCallback):
    """ Correctly save the PEFT adapter instead of the full model. """
    def _save(self, model, folder):
        peft_model_path = os.path.join(folder, "adapter_model")
        model.save_pretrained(peft_model_path)

    def on_train_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """ Save the final best model adapter. """
        self._save(kwargs["model"], state.best_model_checkpoint)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """ Save intermediate model adapters in case training is interrupted. """
        folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        self._save(kwargs["model"], folder)
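
As a design note, this overridden _save_checkpoint deliberately skips the base model weights and the optimizer/scheduler state, so these checkpoints cannot be used with resume_from_checkpoint to restore training exactly; only the adapters are kept, written by the callback at every epoch end and again for the best checkpoint at train end.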

@hanyin88 commented Jun 9, 2023

> (Quoting @pie3636's reply above, including the PeftTrainer and PeftSavingCallback code.)

Hello there, may I ask how you defined model_init in this use case? I am not sure of the appropriate way to do it with LoRA (compared to the usual way outlined in this tutorial). Many thanks. @agemagician @pie3636
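
For illustration, a hedged sketch of one way to define model_init with LoRA for hyperparameter_search; the model id and LoRA hyperparameters are assumptions, not taken from this thread:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM


def model_init(trial=None):
    # A fresh base model plus LoRA adapter is built for every trial
    base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xxl")
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=32, lora_dropout=0.1
    )
    return get_peft_model(base, lora_config)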
