
SFTTrainer compute_loss fails when model is PEFT Prefix-tuning (size mismatch) #4168

@gabrielle-lebellier

Description

Reproduction

from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, TaskType, get_peft_model
from datasets import Dataset
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling

model_name = 'allenai/OLMo-7B-Instruct-hf'
num_virtual_tokens = 16
encoder_hidden_size = 256
save_dir = './peft_bug'
pad_to_multiple_of = 154

# Dummy dataset
train_examples = [
    {'input_ids': [510, 5347, 273, 6181, 310, 7785, 15],
     'attention_mask': [1, 1, 1, 1, 1, 1, 1],
     'completion_mask': [0, 0, 0, 0, 0, 1, 1]}
]
eval_examples = [
    {'input_ids': [510, 5347, 273, 6176, 310, 12911, 15],
     'attention_mask': [1, 1, 1, 1, 1, 1, 1],
     'completion_mask': [0, 0, 0, 0, 0, 1, 1]}
]
train_dataset = Dataset.from_list(train_examples)
eval_dataset = Dataset.from_list(eval_examples)

# Prefix-tuning config
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=num_virtual_tokens,
    prefix_projection=True,
    encoder_hidden_size=encoder_hidden_size,
)

model = AutoModelForCausalLM.from_pretrained(model_name)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Training arguments
training_args = SFTConfig(
    output_dir=save_dir,
    report_to="wandb",
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=5,
    use_cpu=False,
    seed=42,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=0.0001,
    lr_scheduler_type="linear",
    warmup_steps=0,
    num_train_epochs=30,
    label_names=["labels"],
    gradient_checkpointing=False,
)


# Data collator
collator = DataCollatorForLanguageModeling(
    pad_token_id=tokenizer.pad_token_id,
    completion_only_loss=True,
    pad_to_multiple_of=pad_to_multiple_of,
)


# Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator,
)

trainer.train()

This outputs:

File ".../site-packages/transformers/trainer.py", line 2329, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/transformers/trainer.py", line 2678, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/trl/trainer/sft_trainer.py", line 1145, in training_step
    return super().training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/transformers/trainer.py", line 4022, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/trl/trainer/sft_trainer.py", line 1083, in compute_loss
    entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()
                        ~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (154) must match the size of tensor b (170) at non-singleton dimension 1

The size mismatch is exactly 16, my num_virtual_tokens (170 = 154 + 16, where 154 is the padded sequence length coming from pad_to_multiple_of). It seems that the SFTTrainer compute_loss function does not account for how prefix-tuning differs from the other PEFT methods. The same code works fine with prompt-tuning; I did not check any other PEFT methods.

In the peft library, prefix-tuning is handled differently: the virtual embeddings are stored in past_key_values. We can observe that in the forward of PeftModelForCausalLM:

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            # overwrite past_kv in kwargs
            # some archs require max_cache_len to re-initialize the cache
            if input_ids is not None:
                max_cache_len = input_ids.shape[1] + peft_config.num_virtual_tokens
            else:
                max_cache_len = inputs_embeds.shape[1] + peft_config.num_virtual_tokens
            kwargs["past_key_values"] = self.get_prompt(batch_size, max_cache_len=max_cache_len)
            return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)

Therefore, in compute_loss of SFTTrainer, the following code should NOT prepend virtual_attention_mask in the case of prefix-tuning:

if "attention_mask" in inputs:
                    attention_mask = inputs["attention_mask"]
                    # When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1).
                    virtual_attention_mask = torch.ones(
                        attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device
                    )
                    attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1)
                    entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()

I replaced this part with:

if "attention_mask" in inputs:
                    attention_mask = inputs["attention_mask"]
                    if isinstance(model, PeftModel) and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING:
                        virtual_attention_mask = torch.ones(
                        attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device
                        )
                        attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1)

and, a few lines below, the same change for the shift_logits:

if isinstance(model, PeftModel) and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING:
    shift_logits = shift_logits[:, self.num_virtual_tokens :, :]

Everything seems to work fine now: the training runs smoothly, and the results look coherent when the training is done on a more complex dataset. Is this solution correct, or am I missing something? Is there a better solution?
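
For readability, the check could also be factored into a small helper used by both places (a minimal sketch, not TRL's actual API; the function name is mine):

from peft import PeftModel, PeftType

def virtual_tokens_in_logits(model) -> bool:
    # Prompt-tuning-style adapters prepend virtual tokens that show up in the
    # logits and therefore need a matching attention mask; prefix-tuning keeps
    # them in past_key_values, so the logits retain the original sequence length.
    if not isinstance(model, PeftModel):
        return False
    peft_type = model.peft_config[model.active_adapter].peft_type
    return peft_type != PeftType.PREFIX_TUNING

Both branches above would then simply start with: if virtual_tokens_in_logits(model):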

System Info

  • Python version: 3.11.11
  • TRL version: 0.22.2
  • PyTorch version: 2.5.1
  • accelerator(s): cpu
  • Transformers version: 4.56.1
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • Datasets version: 3.3.2
  • HF Hub version: 0.34.4
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: 0.17.1
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete
