Skip to content

Commit

Permalink
Fix copies for DBRX - neuron fix (#30610)
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed May 2, 2024
1 parent f953025 commit 4ad5ada
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,8 +1256,11 @@ def _update_causal_mask(
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
Expand Down

0 comments on commit 4ad5ada

Please sign in to comment.