From 0dd1b041523ee9922d14725feea40be66fa5ce35 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 2 Sep 2025 12:58:12 +0200
Subject: [PATCH 1/2] Fix attention mask validation for context parallelism

---
 src/transformers/trainer.py | 34 ++++++++++++++++++++--------------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index bf05563e9b36..6806800d9680 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3912,20 +3912,26 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.
         if "shift_labels" in inputs:
             buffers.append(inputs["shift_labels"])
             buffer_seq_dims.append(1)
-        if "attention_mask" in inputs and not getattr(self, "_attn_mask_causal_checked", False):
-            # Context parallel currently doesn't support other masks than causal
-            # Accelerate applies hooks to replace mask with is_causal arg in SDPA
-            # Check if the mask is really causal and if not throw an error
-            # TODO: check this only once or always, with speed being the cost
-            attention_mask = inputs["attention_mask"]
-            if not self._is_attention_mask_causal(attention_mask):
-                raise ValueError(
-                    "Context parallelism only supports causal attention masks. "
-                    "The provided attention_mask is not causal. "
-                    "Please ensure your data uses causal masking (lower triangular) "
-                    "or remove the attention_mask to use the model's default causal masking."
-                )
-            self._attn_mask_causal_checked = True
+        # Add attention_mask to buffers for context parallel splitting (only if causal)
+        if "attention_mask" in inputs:
+            # Only validate causal mask once for performance
+            if not getattr(self, "_attn_mask_causal_checked", False):
+                # Context parallel currently doesn't support other masks than causal
+                # Accelerate applies hooks to replace mask with is_causal arg in SDPA
+                # Check if the mask is really causal and if not throw an error
+                attention_mask = inputs["attention_mask"]
+                if not self._is_attention_mask_causal(attention_mask):
+                    raise ValueError(
+                        "Context parallelism only supports causal attention masks. "
+                        "The provided attention_mask is not causal. "
+                        "Please ensure your data uses causal masking (lower triangular) "
+                        "or remove the attention_mask to use the model's default causal masking."
+                    )
+                self._attn_mask_causal_checked = True
+            if self._attn_mask_causal_checked:
+                # Add to buffers only after validation (or if validation already passed)
+                buffers.append(inputs["attention_mask"])
+                buffer_seq_dims.append(1)
         # Include position_ids in context parallelism splitting
         if "position_ids" in inputs and inputs["position_ids"] is not None:
             buffers.append(inputs["position_ids"])

From 484321d7870fbfb3eef29293258d172e1e08076c Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 3 Sep 2025 12:10:25 +0200
Subject: [PATCH 2/2] only split 2d attention masks

---
 src/transformers/trainer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 6806800d9680..e1948205264a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3930,8 +3930,13 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.
                 self._attn_mask_causal_checked = True
             if self._attn_mask_causal_checked:
                 # Add to buffers only after validation (or if validation already passed)
-                buffers.append(inputs["attention_mask"])
-                buffer_seq_dims.append(1)
+                attention_mask = inputs["attention_mask"]
+                if attention_mask.dim() == 2:
+                    buffers.append(attention_mask)
+                    buffer_seq_dims.append(1)
+                else:
+                    # Other dimensionality; keep as-is without sharding to avoid incorrect splits
+                    pass
         # Include position_ids in context parallelism splitting
         if "position_ids" in inputs and inputs["position_ids"] is not None:
             buffers.append(inputs["position_ids"])
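Note on the helper these patches rely on: _is_attention_mask_causal is defined elsewhere in
trainer.py and is not part of this diff. The sketch below only illustrates the kind of test
such a helper could perform, assuming the usual 1-means-attend convention; the name comes
from the patch, but the implementation, the right-padding heuristic for 2-D masks, and the
lower-triangular test for 4-D masks are assumptions for illustration, not the actual
Transformers code.

import torch


def is_attention_mask_causal(attention_mask: torch.Tensor) -> bool:
    """Illustrative sketch only, not the Transformers implementation."""
    if attention_mask.dim() == 2:
        # (batch, seq_len) padding mask: compatible with causal attention when each row
        # is a contiguous block of 1s starting at position 0 (right padding only),
        # e.g. [1, 1, 1, 0, 0] passes while [1, 0, 1, 0, 0] does not.
        lengths = attention_mask.sum(dim=-1, keepdim=True)
        positions = torch.arange(attention_mask.size(-1), device=attention_mask.device)
        expected = (positions < lengths).to(attention_mask.dtype)
        return bool(torch.equal(attention_mask, expected))
    if attention_mask.dim() == 4:
        # (batch, heads, q_len, kv_len) mask: causal when no allowed position lies
        # above the diagonal (lower triangular over the last two dimensions).
        allowed = attention_mask.bool()
        causal = torch.tril(torch.ones_like(allowed, dtype=torch.int64)).bool()
        return bool((allowed & ~causal).sum() == 0)
    # Any other shape: be conservative and report non-causal.
    return False

The 2-D restriction added in the second patch follows the same shape distinction: only a
(batch, seq_len) mask shares its sequence dimension (dim 1) with input_ids and can be split
alongside it, while a 4-D score mask has two sequence dimensions and would need different
handling, so it is left unsharded here.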