Description
Reproduction
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, TaskType, get_peft_model
from datasets import Dataset
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
model_name = 'allenai/OLMo-7B-Instruct-hf'
num_virtual_tokens = 16
encoder_hidden_size = 256
save_dir = './peft_bug'
pad_to_multiple_of = 154
# Dummy dataset
train_examples = [
    {
        'input_ids': [510, 5347, 273, 6181, 310, 7785, 15],
        'attention_mask': [1, 1, 1, 1, 1, 1, 1],
        'completion_mask': [0, 0, 0, 0, 0, 1, 1],
    }
]
eval_examples = [
    {
        'input_ids': [510, 5347, 273, 6176, 310, 12911, 15],
        'attention_mask': [1, 1, 1, 1, 1, 1, 1],
        'completion_mask': [0, 0, 0, 0, 0, 1, 1],
    }
]
train_dataset = Dataset.from_list(train_examples)
eval_dataset = Dataset.from_list(eval_examples)
# Prefix-tuning config
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=num_virtual_tokens,
    prefix_projection=True,
    encoder_hidden_size=encoder_hidden_size,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Training arguments
training_args = SFTConfig(
    output_dir=save_dir,
    report_to="wandb",
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=5,
    use_cpu=False,
    seed=42,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=0.0001,
    lr_scheduler_type="linear",
    warmup_steps=0,
    num_train_epochs=30,
    label_names=["labels"],
    gradient_checkpointing=False,
)
# Data collator
collator = DataCollatorForLanguageModeling(
    pad_token_id=tokenizer.pad_token_id,
    completion_only_loss=True,
    pad_to_multiple_of=pad_to_multiple_of,
)
# Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator,
)
trainer.train()
This outputs:
File ".../site-packages/transformers/trainer.py", line 2329, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/transformers/trainer.py", line 2678, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/trl/trainer/sft_trainer.py", line 1145, in training_step
return super().training_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/transformers/trainer.py", line 4022, in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/trl/trainer/sft_trainer.py", line 1083, in compute_loss
entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()
~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (154) must match the size of tensor b (170) at non-singleton dimension 1
The size mismatch is exactly 16, i.e. my num_virtual_tokens. It seems to me that the compute_loss function of SFTTrainer does not account for how prefix-tuning differs from the other PEFT methods. The same code works fine with prompt-tuning; I did not check any other PEFT methods.
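To make the mismatch concrete, here is a minimal sketch of the shapes involved in the run above (batch of 1, inputs padded to 154, 16 virtual tokens; the numbers are illustrative only):
import torch

per_token_entropy = torch.zeros(1, 154)     # one entropy value per padded input position
attention_mask = torch.ones(1, 154)         # collator pads to pad_to_multiple_of=154
virtual_attention_mask = torch.ones(1, 16)  # prepended by compute_loss for the virtual tokens
attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1)  # shape (1, 170)
per_token_entropy * attention_mask          # RuntimeError: 154 vs 170 at dimension 1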
In the peft library, prefix-tuning is handled differently: the virtual embeddings are stored in past_key_values rather than prepended to the input embeddings. This can be seen in the forward of PeftModelForCausalLM:
if peft_config.peft_type == PeftType.PREFIX_TUNING:
    # overwrite past_kv in kwargs
    # some archs require max_cache_len to re-initialize the cache
    if input_ids is not None:
        max_cache_len = input_ids.shape[1] + peft_config.num_virtual_tokens
    else:
        max_cache_len = inputs_embeds.shape[1] + peft_config.num_virtual_tokens
    kwargs["past_key_values"] = self.get_prompt(batch_size, max_cache_len=max_cache_len)
return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
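As a consequence, with prefix-tuning the logits returned by the PEFT model keep the original sequence length, while with prompt-tuning they are num_virtual_tokens longer. A minimal sketch of that difference (gpt2 is only a small stand-in base model here, not the one from the reproduction):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, PromptTuningConfig, TaskType, get_peft_model

base = "gpt2"  # small stand-in model, used only to compare output shapes
tokenizer = AutoTokenizer.from_pretrained(base)
inputs = tokenizer("The city of Paris is nice.", return_tensors="pt")

prefix_model = get_peft_model(
    AutoModelForCausalLM.from_pretrained(base),
    PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=16),
)
prompt_model = get_peft_model(
    AutoModelForCausalLM.from_pretrained(base),
    PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=16),
)

with torch.no_grad():
    # Prefix-tuning: the virtual tokens only live in past_key_values,
    # so the logits cover just the real input positions.
    print(prefix_model(**inputs).logits.shape[1])  # == inputs["input_ids"].shape[1]
    # Prompt-tuning: the virtual embeddings are prepended to the input,
    # so the logits are num_virtual_tokens longer.
    print(prompt_model(**inputs).logits.shape[1])  # == inputs["input_ids"].shape[1] + 16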
Therefore, in compute_loss of SFTTrainer, the following code should NOT prepend virtual_attention_mask in the case of prefix-tuning:
if "attention_mask" in inputs:
attention_mask = inputs["attention_mask"]
# When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1).
virtual_attention_mask = torch.ones(
attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device
)
attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1)
entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()
I replaced this part with:
if "attention_mask" in inputs:
attention_mask = inputs["attention_mask"]
if isinstance(model, PeftModel) and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING:
virtual_attention_mask = torch.ones(
attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device
)
attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1)
and, a few lines below, did the same for the shift_logits slicing:
if isinstance(model, PeftModel) and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING:
    shift_logits = shift_logits[:, self.num_virtual_tokens :, :]
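For completeness, a shape-based variant of the same guard (only a sketch, not what I actually ran) would avoid checking the PEFT type and instead prepend the virtual mask only when the logits really contain the extra positions:
if "attention_mask" in inputs:
    attention_mask = inputs["attention_mask"]
    # Sketch: pad the mask only if the model output has extra virtual-token
    # positions (prompt-tuning style methods); for prefix-tuning num_extra is 0.
    num_extra = per_token_entropy.size(1) - attention_mask.size(1)
    if num_extra > 0:
        virtual_attention_mask = torch.ones(
            attention_mask.size(0), num_extra, device=attention_mask.device
        )
        attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1)
    entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()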
Everything now seems to work: the training runs smoothly, and the results look coherent when the training is done on a more complex dataset. Is this solution correct, or am I missing something? Is there a better solution here?
System Info
- Python version: 3.11.11
- TRL version: 0.22.2
- PyTorch version: 2.5.1
- accelerator(s): cpu
- Transformers version: 4.56.1
- Accelerate version: 1.10.1
- Accelerate config: not found
- Datasets version: 3.3.2
- HF Hub version: 0.34.4
- bitsandbytes version: not installed
- DeepSpeed version: not installed
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: 0.17.1
- vLLM version: not installed
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshot, more on code blocks)
- Any traceback provided is complete