
Fix resuming PeftModel checkpoints in Trainer #24274

Conversation

llohann-speranca
Contributor

@llohann-speranca llohann-speranca commented Jun 14, 2023

What does this PR do?

Fixes an error that occurred when the Trainer tried to resume a PeftModel from a training checkpoint. The error was caused by `PeftModel.save_pretrained` saving only adapter-related data while `_load_from_checkpoint` expected a full saved torch model. This PR fixes the issue and allows the adapter checkpoint to be loaded.
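In outline, the fix looks roughly like the sketch below (a paraphrase, not the exact diff; `_maybe_load_adapter` is a hypothetical helper name, and the `ADAPTER_WEIGHTS_NAME` constant referenced in the diff is inlined):

```python
import os

from peft import PeftModel

ADAPTER_WEIGHTS_NAME = "adapter_model.bin"  # constant referenced in the diff


def _maybe_load_adapter(model, resume_from_checkpoint):
    # If the checkpoint folder contains only PEFT adapter files, load them
    # through PeftModel.load_adapter instead of expecting pytorch_model.bin.
    if isinstance(model, PeftModel) and hasattr(model, "load_adapter"):
        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
        if os.path.exists(adapter_weights_file):
            model.load_adapter(resume_from_checkpoint, model.active_adapter)
```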

Fixes #24252

Before submitting

Who can review?

@younesbelkada

Fix an error that occurred when resuming a PeftModel from a training checkpoint. It was caused by `PeftModel.save_pretrained` saving only adapter-related data while `_load_from_checkpoint` expected a saved torch model. This PR fixes the issue and allows the adapter checkpoint to be loaded.

Resolves: huggingface#24252
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 14, 2023

The documentation is not available anymore as the PR was closed or merged.

Contributor

@younesbelkada younesbelkada left a comment

Thanks for your great work!
I managed to reproduce the issue, and your proposed fix resolves it.

Handy snippet:

```python
import os
from transformers import TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

dataset = load_dataset("imdb", split="train")
output_dir = "test"

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    max_steps=5,
    save_steps=1,
    save_strategy="steps",
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",
    train_dataset=dataset,
    args=training_args,
    dataset_text_field="text",
    peft_config=peft_config,
)
trainer.save_model(os.path.join(output_dir, "checkpoint-1"))
trainer.train(resume_from_checkpoint=True)
```

Thanks a lot

Collaborator

@sgugger sgugger left a comment

Thanks for working on this, I just have one comment.

Comment on lines 2076 to 2079
```python
# Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys

load_result = _IncompatibleKeys([], [])
```
Collaborator

Maybe just set it to None and adapt `_issue_warnings_after_load` to return early if the `load_result` is None, since this is a bit long for nothing.
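A minimal sketch of that suggestion (the surrounding `Trainer` internals are simplified here, not the actual diff):

```python
from typing import Optional

from torch.nn.modules.module import _IncompatibleKeys


def _issue_warnings_after_load(load_result: Optional[_IncompatibleKeys]) -> None:
    # PEFT's load_adapter returns nothing, so the adapter-loading path can
    # pass None instead of fabricating an empty _IncompatibleKeys([], []).
    if load_result is None:
        return
    if len(load_result.missing_keys) != 0:
        print(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
    if len(load_result.unexpected_keys) != 0:
        print(f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.")
```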


@adityaaryan77

Although this continues training, it doesn't retain the old state for me. Can someone look into this?

@techthiyanes

techthiyanes commented Jun 17, 2023

This doesn't seem to behave like the older resume-from-checkpoint support for PyTorch models. Inside `trainer.train` we need to pass `resume_from_checkpoint` as our last checkpoint path, but when passing the path it says it can't find a valid checkpoint. Could someone please post a code snippet showing how to use `resume_from_checkpoint` for PEFT models?

@brijesh-6899

@pacman100 Requesting to review and merge the changes, thanks!

@younesbelkada
Contributor

younesbelkada commented Jun 19, 2023

Hi @techthiyanes

This doesn't seem to behave like the older resume-from-checkpoint support for PyTorch models. Inside `trainer.train` we need to pass `resume_from_checkpoint` as our last checkpoint path, but when passing the path it says it can't find a valid checkpoint. Could someone please post a code snippet showing how to use `resume_from_checkpoint` for PEFT models?

As shared in the snippet above, for `resume_from_checkpoint` to work as expected, it assumes you have previously trained your model with a `Trainer` that saves artifacts under `{output_dir}/checkpoint-{i}`. I have "faked" that in the example by manually saving a model in a folder called `{output_dir}/checkpoint-1`. You therefore need to make sure the model weights live under that folder.
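A quick way to verify the expected layout (a sketch, reusing the `output_dir = "test"` from the snippet above):

```python
import os

output_dir = "test"
# List checkpoint folders; with resume_from_checkpoint=True the Trainer will
# pick the checkpoint-{i} folder with the highest step, and the adapter files
# must live inside it.
checkpoints = sorted(d for d in os.listdir(output_dir) if d.startswith("checkpoint-"))
print(checkpoints)  # e.g. ['checkpoint-1', ...]
print(os.listdir(os.path.join(output_dir, checkpoints[-1])))  # expect adapter files here
```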

Contributor

@younesbelkada younesbelkada left a comment

Hi @llohann-speranca
Again, thanks for your great work on this. I think this is a rather important fix that might unblock a lot of users. If that's OK with you, I can quickly take over the PR and address the last comment so that we can merge it. What do you think?

@adityaaryan77

This doesn't seem to behave like the older resume-from-checkpoint support for PyTorch models. Inside `trainer.train` we need to pass `resume_from_checkpoint` as our last checkpoint path, but when passing the path it says it can't find a valid checkpoint. Could someone please post a code snippet showing how to use `resume_from_checkpoint` for PEFT models?

Hi @younesbelkada can you look at my issue with the code and please address it?

@younesbelkada
Contributor

Yes @adityaaryan77, sure, please have a look at my comment on the PEFT issue and let's discuss there.

@techthiyanes

techthiyanes commented Jun 19, 2023

Hi @younesbelkada

```python
trainer.train(resume_from_checkpoint=True)
```
Thank you for your response.

Please look at the code snippet below:

```python
# -*- coding: utf-8 -*-
"""Untitled345.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1SgzMXUUDK1wDH0M0yQPfWmeNAKyy7EFs
"""

!pip install datasets transformers peft evaluate
!git clone https://github.com/llohann-speranca/transformers.git -b fix-resume-checkpoint-for-peftmodel
!cp -r /content/transformers /usr/local/lib/python3.10/dist-packages/transformers

import transformers
import numpy as np

GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
task = "cola"
model_checkpoint = "bert-large-uncased"
batch_size = 16

from datasets import load_dataset, load_metric

actual_task = "mnli" if task == "mnli-mm" else task
dataset = load_dataset("glue", actual_task)
metric = load_metric("glue", actual_task)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}
sentence1_key, sentence2_key = task_to_keys[task]
if sentence2_key is None:
    print(f"Sentence: {dataset['train'][0][sentence1_key]}")
else:
    print(f"Sentence 1: {dataset['train'][0][sentence1_key]}")
    print(f"Sentence 2: {dataset['train'][0][sentence2_key]}")

def preprocess_function(examples):
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)

encoded_dataset = dataset.map(preprocess_function, batched=True)

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    LoraConfig,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
)

peft_type = PeftType.LORA
device = "cuda"
peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
lr = 3e-4

num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model

metric_name = "pearson" if task == "stsb" else "matthews_correlation" if task == "cola" else "accuracy"
model_name = model_checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned1-{task}",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    # load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    # push_to_hub=True,
)

from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import os

class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if task != "stsb":
        predictions = np.argmax(predictions, axis=1)
    else:
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)

validation_key = "validation_mismatched" if task == "mnli-mm" else "validation_matched" if task == "mnli" else "validation"
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[SavePeftModelCallback],
)
trainer.train()

trainer.save_model("/content/bert-large-uncased-finetuned1-cola/checkpoint-1")

trainer.train(resume_from_checkpoint="/content/bert-large-uncased-finetuned1-cola/checkpoint-1070/adapter_model")
```

For `resume_from_checkpoint` I have tried the options below:

1. resume_from_checkpoint = True
2. resume_from_checkpoint = (last checkpoint path)
3. resume_from_checkpoint = (trainer-saved model path)

Everywhere I'm getting the same message: "Can't find a valid checkpoint at …".
At the same time, I'm able to resume from a checkpoint in native PyTorch code.

@llohann-speranca
Contributor Author

Hi @llohann-speranca Again, thanks for your great work on this. I think this is a rather important fix that might unblock a lot of users. If that's OK with you, I can quickly take over the PR and address the last comment so that we can merge it. What do you think?

Hi @younesbelkada. Sure! I have been very busy and still have to learn how to deal with PRs. Sorry about that.

@younesbelkada
Contributor

younesbelkada commented Jun 19, 2023

@llohann-speranca thanks!
@techthiyanes it seems you are using the API the wrong way. `resume_from_checkpoint` will try to retrieve the latest checkpoint from the output directory of the trainer. Therefore, make sure you have correct `checkpoint-{i}` folders inside `f"{model_name}-finetuned1-{task}"` in your case, and use `resume_from_checkpoint=True`.
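For reference, this is the resolution the `Trainer` applies internally; you can check what it would pick up with `get_last_checkpoint` (a sketch, using the output directory from the script above):

```python
from transformers.trainer_utils import get_last_checkpoint

model_name, task = "bert-large-uncased", "cola"  # values from the script above
# Returns the {output_dir}/checkpoint-{i} folder with the highest step,
# or None if no valid checkpoint folder exists.
print(get_last_checkpoint(f"{model_name}-finetuned1-{task}"))
```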

@techthiyanes

@llohann-speranca thanks! @techthiyanes it seems you are using the API the wrong way. `resume_from_checkpoint` will try to retrieve the latest checkpoint from the output directory of the trainer. Therefore, make sure you have correct `checkpoint-{i}` folders inside `f"{model_name}-finetuned1-{task}"` in your case, and use `resume_from_checkpoint=True`.

Thanks a lot for your response.
By default, when passing `resume_from_checkpoint`, the API automatically picks up the most recent checkpoint. But this is not working as expected for PEFT models, unlike torch models. As mentioned, I have pointed it at the correct checkpoint, and that folder is the only one inside the output directory.

If you don't mind, could you please try executing any of the Hugging Face example scripts with PEFT inserted into the Trainer and resume-from-checkpoint? Then you might be able to replicate it.

@younesbelkada
Contributor

@techthiyanes I think it works as expected with this PR. As explained in #24274 (review), I have tried the attached snippet, which was not working before this PR, and the PR properly fixes it by loading the checkpoint. If you want, you can try to replicate with a smaller example (for example imdb, as attached) and let me know if you still face an issue by opening a new ticket.

@techthiyanes

@techthiyanes I think it works as expected with this PR. As explained in #24274 (review), I have tried the attached snippet, which was not working before this PR, and the PR properly fixes it by loading the checkpoint. If you want, you can try to replicate with a smaller example (for example imdb, as attached) and let me know if you still face an issue by opening a new ticket.

Sure, thanks a lot. Let me try the above snippet for classification models and let you know.

@techthiyanes

techthiyanes commented Jun 19, 2023

@techthiyanes I think it works as expected with this PR. As explained in #24274 (review), I have tried the attached snippet, which was not working before this PR, and the PR properly fixes it by loading the checkpoint. If you want, you can try to replicate with a smaller example (for example imdb, as attached) and let me know if you still face an issue by opening a new ticket.

I am still able to replicate the issue, so I raised a separate issue for it: #24354

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for fixing!

Just some small comments. I'm not super familiar with Trainer or PEFT; however, as @sgugger has already reviewed with just a small (resolved) comment, I think this should be OK to merge once the others have been addressed.

```python
elif is_peft_available() and isinstance(model, PeftModel):
    # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
    if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
        if os.path.exists(resume_from_checkpoint) or os.path.exists(resume_from_checkpoint):
```
Collaborator

@amyeroberts amyeroberts Jun 19, 2023

In this `if x or y` check, x and y are the same

Contributor

Nice catch!
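Judging from the later diff (lines +2000 to +2001, quoted further down), the corrected condition presumably checks both adapter filenames, along these lines (a reconstruction with a hypothetical helper name, not the verbatim diff):

```python
import os

def adapter_checkpoint_exists(resume_from_checkpoint: str) -> bool:
    # Check both the regular and the safetensors adapter filenames.
    adapter_weights_file = os.path.join(resume_from_checkpoint, "adapter_model.bin")
    adapter_safe_weights_file = os.path.join(resume_from_checkpoint, "adapter_model.safetensors")
    return os.path.exists(adapter_weights_file) or os.path.exists(adapter_safe_weights_file)
```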

Comment on lines 2078 to 2080
"The intermediate checkpoints of PEFT may not be saved correctly, "
f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
"here are some examples https://github.com/huggingface/peft/issues/96"
Collaborator

I'm not able to parse this warning sentence. Is it saying I should use a `TrainerCallback` to resolve this issue, that using a `TrainerCallback` caused this issue, or that a `TrainerCallback` will be used as a result of this warning?

Contributor

Yeah, agreed. As this was a copy-paste, I have modified the original version of the warning as well.

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for iterating!

Comment on lines +2000 to +2001
```python
adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
```
Collaborator

Just for my own understanding, why don't we need the equivalent `adapter_weights_index_file` or `adapter_safe_weights_index_file` here?

Contributor

Sure! I think index files are used for sharded checkpoints only. In PEFT the saved checkpoints are always extremely light (on the order of a few MBs), even for very large models, so we never save sharded checkpoints and therefore never save index files!
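A rough illustration of why adapter checkpoints stay far below sharding thresholds (a sketch reusing the gpt-neo model from the snippet earlier in the thread):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
model = get_peft_model(base, LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"))
# Only the LoRA matrices are trainable and saved; for this model that is a
# tiny fraction of the full weights, i.e. a few MBs on disk.
model.print_trainable_parameters()
```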

Collaborator

Thanks for explaining!

@younesbelkada younesbelkada merged commit 183f442 into huggingface:main Jun 20, 2023
22 checks passed
@beyondguo

Hi, nice job! Is this new feature available if I pip install the latest peft/transformers packages, or should I install from source?

@pacman100
Contributor

Thank you so much @llohann-speranca and @younesbelkada for adding this 🤗! @beyondguo, please install from source as this isn't yet part of the release.

@shoang22

shoang22 commented Jul 10, 2023

Thanks for iterating! Will we perform inference in the same manner? Specifically, PEFT requires us to load a `PeftConfig` via `adapter_config.json`, and I saw that this is not saved with `trainer.save_model()`. Will we need to add `model.save_pretrained()` to use the model for inference?
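For reference, the usual PEFT inference flow looks like this (a sketch; `adapter_dir` is a hypothetical folder that must contain `adapter_config.json` and the adapter weights):

```python
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM

adapter_dir = "test/checkpoint-5"  # hypothetical path with saved adapter files
config = PeftConfig.from_pretrained(adapter_dir)
base = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base, adapter_dir)
model.eval()
```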

@AayushSameerShah

I am trying to resume in this way:

```python
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM

config = PeftConfig.from_pretrained("huggingface_path_TO_MY_PREVIOUSLY_TRAINED_LORA")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")  # underlying model
model.enable_input_require_grads()  # to make training possible
model = PeftModel.from_pretrained(model, "huggingface_path_TO_MY_PREVIOUSLY_TRAINED_LORA")

model.print_trainable_parameters()
# trainable params: 0 || all params: 792,587,264 || trainable%: 0.0
```

Then the usual code:

```python
from transformers import DataCollatorForSeq2Seq

# we want to ignore tokenizer pad token in the loss
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
)

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

output_dir = "./t5-large-r32-lora-JOIN-FIX-RESUME"
# batch_size = 8

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    save_strategy="steps",
    save_steps=500,
    gradient_accumulation_steps=4,
    learning_rate=1e-3,  # higher learning rate
    weight_decay=0.01,
    num_train_epochs=2,
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=100,
    push_to_hub=True,
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=TRAINING,
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

trainer.train()
```

The loss starts where I paused the training, which is low enough that I can tell it has resumed from the previously trained model, but I am not sure whether I am doing it right.

@younesbelkada, could you please confirm whether what I am doing is right?
Thank you.
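Note that `PeftModel.from_pretrained` restores only the adapter weights. With checkpoints saved by this same `Trainer` under `output_dir`, the resume path fixed by this PR would also restore the optimizer, scheduler, and trainer state (a sketch, assuming such `checkpoint-{i}` folders exist):

```python
# Resumes adapter weights plus optimizer/scheduler/trainer state from the
# latest {output_dir}/checkpoint-{i} folder (trainer as defined above).
trainer.train(resume_from_checkpoint=True)
```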

@XM-Dong

XM-Dong commented Oct 12, 2023


I've encountered the same issue. Have you resolved it? @AayushSameerShah

@AayushSameerShah

@XM-Dong Nah... it seems like LoRA needs some "special script" :(

@wei-ann-Github

Hi, I'm also wondering how I can get these changes. Are they in a new transformers version or PEFT version?

my current versions are:

transformers==4.30.1
peft==0.4.0


Successfully merging this pull request may close these issues.

Peft Model not resuming from Checkpoint