src/transformers/trainer.py (39 changes: 25 additions & 14 deletions)
@@ -3912,20 +3912,31 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.
if "shift_labels" in inputs:
buffers.append(inputs["shift_labels"])
buffer_seq_dims.append(1)
if "attention_mask" in inputs and not getattr(self, "_attn_mask_causal_checked", False):
# Context parallel currently doesn't support other masks than causal
# Accelerate applies hooks to replace mask with is_causal arg in SDPA
# Check if the mask is really causal and if not throw an error
# TODO: check this only once or always, with speed being the cost
attention_mask = inputs["attention_mask"]
if not self._is_attention_mask_causal(attention_mask):
raise ValueError(
"Context parallelism only supports causal attention masks. "
"The provided attention_mask is not causal. "
"Please ensure your data uses causal masking (lower triangular) "
"or remove the attention_mask to use the model's default causal masking."
)
self._attn_mask_causal_checked = True
# Add attention_mask to buffers for context parallel splitting (only if causal)
if "attention_mask" in inputs:
# Only validate causal mask once for performance
if not getattr(self, "_attn_mask_causal_checked", False):
# Context parallel currently doesn't support other masks than causal
# Accelerate applies hooks to replace mask with is_causal arg in SDPA
# Check if the mask is really causal and if not throw an error
attention_mask = inputs["attention_mask"]
if not self._is_attention_mask_causal(attention_mask):
raise ValueError(
"Context parallelism only supports causal attention masks. "
"The provided attention_mask is not causal. "
"Please ensure your data uses causal masking (lower triangular) "
"or remove the attention_mask to use the model's default causal masking."
)
self._attn_mask_causal_checked = True
if self._attn_mask_causal_checked:
# Add to buffers only after validation (or if validation already passed)
attention_mask = inputs["attention_mask"]
if attention_mask.dim() == 2:
buffers.append(attention_mask)
buffer_seq_dims.append(1)
else:
# Other dimensionality; keep as-is without sharding to avoid incorrect splits
pass
# Include position_ids in context parallelism splitting
if "position_ids" in inputs and inputs["position_ids"] is not None:
buffers.append(inputs["position_ids"])
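
The added branch only shards 2D `(batch, seq_len)` masks and delegates the one-time validation to `self._is_attention_mask_causal`. As a rough sketch of what such a check could look like for the 2D case (a hypothetical `is_causal_compatible_2d`, not the Trainer's actual implementation), one plausible criterion is that each row is a contiguous run of 1s followed only by padding zeros:

```python
import torch

def is_causal_compatible_2d(attention_mask: torch.Tensor) -> bool:
    """Treat a (batch, seq_len) padding mask as compatible with causal attention
    when every row is a run of 1s followed only by 0s (right padding), i.e. there
    are no "holes" that a plain lower-triangular causal mask could not express."""
    if attention_mask.dim() != 2:
        return False  # only the 2D case is sketched here
    # A 0 -> 1 transition along the sequence dimension would be a hole in the middle.
    diffs = attention_mask[:, 1:].int() - attention_mask[:, :-1].int()
    return bool((diffs <= 0).all().item())

# Right padding passes, a hole in the middle of a row does not.
ok = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
bad = torch.tensor([[1, 0, 1, 1]])
assert is_causal_compatible_2d(ok)
assert not is_causal_compatible_2d(bad)
```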
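
For context, the `buffers` / `buffer_seq_dims` pairs collected here are later split along their sequence dimensions across context-parallel ranks (in practice by Accelerate's context-parallel machinery, not by the Trainer itself). A minimal, hypothetical sketch of that kind of split using `torch.chunk`:

```python
import torch

def shard_along_seq(buffers, buffer_seq_dims, cp_rank, cp_world_size):
    """Give each context-parallel rank one contiguous slice of every buffer,
    split along that buffer's sequence dimension."""
    local = []
    for buf, seq_dim in zip(buffers, buffer_seq_dims):
        chunks = torch.chunk(buf, cp_world_size, dim=seq_dim)
        local.append(chunks[cp_rank])
    return local

# A (batch=2, seq=8) attention mask split across 2 ranks along dim 1:
# each rank ends up with a (2, 4) slice.
mask = torch.ones(2, 8, dtype=torch.long)
local_mask = shard_along_seq([mask], [1], cp_rank=0, cp_world_size=2)[0]
assert local_mask.shape == (2, 4)
```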