System Info
- transformers version: 4.53
- Python version: 3.12
- Platform: Linux
- Models affected: google/gemma-3-12b-it (confirmed), likely affects other models
- Training setup: torchrun with DeepSpeed, LoRA fine-tuning
Description
The Hugging Face Trainer exhibits incorrect loss scaling when gradient accumulation is used: the logged loss values increase proportionally with gradient_accumulation_steps instead of remaining consistent. This distorts training monitoring, convergence analysis, and hyperparameter tuning.
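For reference, here is a minimal sketch (plain PyTorch, independent of the Trainer internals) of the invariance that gradient accumulation is expected to preserve: the mean loss over an effective batch is the same whether it is computed in one pass or accumulated over micro-batches scaled by 1/gradient_accumulation_steps.

import torch
# Per-sample losses for one effective batch (illustrative values only).
losses = torch.tensor([2.0, 1.5, 3.0, 2.5])
full_batch_loss = losses.mean()  # single pass, no accumulation
ga_steps = 4
micro_batches = losses.chunk(ga_steps)
# Standard accumulation scaling: each micro-batch loss is divided by ga_steps
# before summing, so the total matches the full-batch mean.
accumulated_loss = sum(mb.mean() / ga_steps for mb in micro_batches)
assert torch.isclose(full_batch_loss, accumulated_loss)  # both equal 2.25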
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Minimal Example
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-12b-it")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-12b-it")
tokenizer.padding_side = "right"
# Create dummy dataset
dummy_data = {
"input_ids": [[1] * 100] * 1000, # Simple repeated tokens
"labels": [[1] * 100] * 1000
}
dataset = Dataset.from_dict(dummy_data)
# Test with different gradient accumulation steps
for ga_steps in [1, 2, 4, 8]:
    trainer = Trainer(
        model=model,
        train_dataset=dataset,
        args=TrainingArguments(
            output_dir=f"./test_ga_{ga_steps}",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=ga_steps,
            max_steps=10,
            logging_steps=1,
            save_steps=1000,
        ),
    )
    trainer.train()
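    # Sketch: pull the logged losses from trainer.state.log_history (populated at each
    # logging step) so runs with different ga_steps can be compared side by side.
    step_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
    print(f"ga_steps={ga_steps}: logged losses = {step_losses}")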
Expected behavior
Loss values should remain mathematically consistent regardless of gradient_accumulation_steps, as the effective batch size and optimization steps should be equivalent.
Actual Behavior
- gradient_accumulation_steps=8: Loss = X
- gradient_accumulation_steps=16: Loss = 2X (exactly double)
- gradient_accumulation_steps=32: Loss = 4X
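Based only on the pattern above, the logged value appears to scale linearly with gradient_accumulation_steps. As a monitoring-only workaround (an assumption derived from these numbers, not a fix), the logged loss can be rescaled before comparing runs:

def rescale_logged_loss(logged_loss: float, ga_steps: int, baseline_ga_steps: int = 8) -> float:
    # Divides out the linear factor observed above so that runs with different
    # gradient_accumulation_steps become comparable on a dashboard; assumes the
    # scaling really is linear, as the reported numbers suggest.
    return logged_loss * baseline_ga_steps / ga_steps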
Root Cause Analysis
The issue appears to stem from the model_accepts_loss_kwargs flag in the Trainer class. This issue has been raised here before, but it has not been properly handled yet (transformers 4.56.0 shows the same behavior).
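Continuing from the reproduction script above, the sketch below shows how to inspect that flag. Forcing it off is a possible workaround (treat it as an assumption, not an official fix): the Trainer should then fall back to dividing the loss by gradient_accumulation_steps before it is accumulated for logging.

# Sketch: assumes model_accepts_loss_kwargs is exposed as an instance attribute on
# Trainer, as in recent transformers releases.
trainer = Trainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(output_dir="./inspect_flag", per_device_train_batch_size=1),
)
print(getattr(trainer, "model_accepts_loss_kwargs", None))
# Possible workaround (not an official fix): disable the flag so the loss scaling
# path that divides by gradient_accumulation_steps is used again.
trainer.model_accepts_loss_kwargs = False
trainer.train()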