From 4ad5adaf1d224fa28ffa8e1d124846b1d55a5d0e Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 2 May 2024 11:00:26 +0100
Subject: [PATCH] Fix copies for DBRX - neuron fix (#30610)

---
 src/transformers/models/dbrx/modeling_dbrx.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 362b0085aeb36..14baf7ae19e44 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -1256,8 +1256,11 @@ def _update_causal_mask(
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
             elif attention_mask.dim() == 4:
                 # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                 # cache. In that case, the 4D attention mask attends to the newest tokens only.
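
Note for reviewers: the rewrite replaces a product of two eq() boolean masks with a
sum followed by a comparison to zero, which avoids the pattern that reportedly
misbehaved on Neuron. A minimal sketch (not part of the patch) showing the two
formulations agree, assuming causal_mask holds 0.0 for "attend" and min_dtype for
"masked", and attention_mask holds 1/0 per token:

    import torch

    min_dtype = torch.finfo(torch.float32).min
    # toy 4D causal mask: second position masked in the first row
    causal_mask = torch.tensor([[0.0, min_dtype], [0.0, 0.0]]).reshape(1, 1, 2, 2)
    attention_mask = torch.tensor([[1, 0]])  # second token is padding

    # old form: boolean product of two eq() masks
    old = causal_mask.eq(0.0) * attention_mask[:, None, None, :].eq(0.0)

    # new form: add, then compare the sum to zero; the sum is 0 only where
    # causal_mask == 0.0 and attention_mask == 0
    new = (causal_mask + attention_mask[:, None, None, :]) == 0

    assert torch.equal(old, new)  # identical padding masks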