diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index bf05563e9b36..e1948205264a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3912,20 +3912,31 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.
         if "shift_labels" in inputs:
             buffers.append(inputs["shift_labels"])
             buffer_seq_dims.append(1)
-        if "attention_mask" in inputs and not getattr(self, "_attn_mask_causal_checked", False):
-            # Context parallel currently doesn't support other masks than causal
-            # Accelerate applies hooks to replace mask with is_causal arg in SDPA
-            # Check if the mask is really causal and if not throw an error
-            # TODO: check this only once or always, with speed being the cost
-            attention_mask = inputs["attention_mask"]
-            if not self._is_attention_mask_causal(attention_mask):
-                raise ValueError(
-                    "Context parallelism only supports causal attention masks. "
-                    "The provided attention_mask is not causal. "
-                    "Please ensure your data uses causal masking (lower triangular) "
-                    "or remove the attention_mask to use the model's default causal masking."
-                )
-            self._attn_mask_causal_checked = True
+        # Add attention_mask to buffers for context parallel splitting (only if causal)
+        if "attention_mask" in inputs:
+            # Only validate causal mask once for performance
+            if not getattr(self, "_attn_mask_causal_checked", False):
+                # Context parallel currently doesn't support other masks than causal
+                # Accelerate applies hooks to replace mask with is_causal arg in SDPA
+                # Check if the mask is really causal and if not throw an error
+                attention_mask = inputs["attention_mask"]
+                if not self._is_attention_mask_causal(attention_mask):
+                    raise ValueError(
+                        "Context parallelism only supports causal attention masks. "
+                        "The provided attention_mask is not causal. "
+                        "Please ensure your data uses causal masking (lower triangular) "
+                        "or remove the attention_mask to use the model's default causal masking."
+                    )
+                self._attn_mask_causal_checked = True
+            if self._attn_mask_causal_checked:
+                # Add to buffers only after validation (or if validation already passed)
+                attention_mask = inputs["attention_mask"]
+                if attention_mask.dim() == 2:
+                    buffers.append(attention_mask)
+                    buffer_seq_dims.append(1)
+                else:
+                    # Other dimensionality; keep as-is without sharding to avoid incorrect splits
+                    pass
         # Include position_ids in context parallelism splitting
         if "position_ids" in inputs and inputs["position_ids"] is not None:
             buffers.append(inputs["position_ids"])
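
Reviewer note: pairing `buffers.append(attention_mask)` with `buffer_seq_dims.append(1)` registers the 2D `(batch, seq_len)` mask to be sharded along its sequence dimension, the same way `input_ids` and `shift_labels` are. A minimal standalone sketch of what that split looks like, assuming a hypothetical context-parallel group of 4 ranks (the names `world_size`, `rank`, and `local_mask` are illustrative only and not part of this PR or of any Accelerate/PyTorch API):

    import torch

    world_size, rank = 4, 1                # hypothetical context-parallel group and this rank's index
    attention_mask = torch.ones(2, 16)     # (batch, seq_len) all-ones mask, compatible with causal attention
    local_mask = attention_mask.chunk(world_size, dim=1)[rank]
    print(local_mask.shape)                # torch.Size([2, 4]): each rank keeps only its slice of the sequence
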