diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 7a0ea8d38501..c00b22156467 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -969,6 +969,9 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         "^language_model.lm_head": "lm_head",
     }
     _tied_weights_keys = ["lm_head.weight"]
+    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+    # Fix: https://github.com/huggingface/transformers/issues/40564
+    accepts_loss_kwargs = False

     def __init__(self, config: Gemma3Config):
         super().__init__(config)
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index 6e06671ea0bb..18f10fc3ad3d 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -869,6 +869,10 @@ def forward(


 class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
+    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+    # Fix: https://github.com/huggingface/transformers/issues/40564
+    accepts_loss_kwargs = False
+
     @auto_docstring
     def forward(
         self,