Commit ad28ca2

[bloom] fix alibi device placement (#18087)

1 parent 8b332a6 commit ad28ca2
1 file changed (+5, -5 lines)

src/transformers/models/bloom/modeling_bloom.py
Lines changed: 5 additions & 5 deletions
@@ -93,7 +93,7 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
     )
 
 
-def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
+def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
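For orientation, here is a minimal sketch of the helper after this change, assuming a power-of-two `n_head` (the real function also handles non-power-of-two head counts); `build_alibi_tensor_sketch` is an illustrative name, not the library API:

```python
import torch

def build_alibi_tensor_sketch(max_seq_len, n_head, device, dtype=torch.bfloat16):
    # ALiBi slopes for a power-of-two head count: a geometric sequence
    # ending at 2^-8 (see https://arxiv.org/abs/2108.12409)
    slopes = torch.pow(2.0, -torch.arange(1, n_head + 1) * (8.0 / n_head)).view(n_head, 1, 1)
    # Per-position offsets, broadcast against the per-head slopes
    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
    alibi = slopes * arange_tensor.expand(n_head, -1, -1)
    # The fix: move to the caller's device while casting, instead of
    # returning a CPU tensor that later collides with GPU activations
    return alibi.to(device=device, dtype=dtype)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alibi = build_alibi_tensor_sketch(16, 8, device)  # shape (8, 1, 16), on `device`
```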
@@ -129,7 +129,7 @@ def get_slopes_power_of_2(n):
     arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
     alibi = slopes * arange_tensor.expand(n_head, -1, -1)
 
-    alibi = alibi.to(dtype)
+    alibi = alibi.to(device=device, dtype=dtype)
 
     return alibi
 
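Worth noting: `Tensor.to` accepts `device` and `dtype` together, so the move and the cast above happen in a single call rather than through an intermediate tensor. A quick standalone illustration:

```python
import torch

x = torch.zeros(2, 3)  # float32, on CPU by default
device = "cuda" if torch.cuda.is_available() else "cpu"
# One call handles both placement and precision; it is a no-op if nothing changes
y = x.to(device=device, dtype=torch.bfloat16)
assert y.dtype == torch.bfloat16
```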
@@ -147,7 +147,7 @@ def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
     # This usually happens when the inference is done with past_key_values
     # In this case we re-create the alibi tensor with the correct sequence length
     if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
+        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype).repeat(
             attention_mask.shape[0], 1, 1
         )
     # Get the indexes of the padding tokens
@@ -156,7 +156,7 @@ def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
 
     # Clone the embeddings - we can detach because the embeddings are not learned
     # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
+    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.device, alibi.dtype)
 
     # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
     # Only where you do not have padding. Replace padding tokens by zeros
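Both call sites in `pre_process_alibi_for_pad` above follow the same pattern: a rebuilt bias now inherits `alibi.device` along with `alibi.dtype` from the tensor it replaces. A hedged sketch of that pattern with a stand-in helper (`rebuild_like` is hypothetical, for illustration only):

```python
import torch

def rebuild_like(ref: torch.Tensor, seq_len: int) -> torch.Tensor:
    # Build a fresh bias with a new sequence length, but keep the reference
    # tensor's device and dtype, mirroring the updated build_alibi_tensor calls
    return torch.zeros(ref.shape[0], 1, seq_len).to(device=ref.device, dtype=ref.dtype)

alibi = torch.randn(4, 1, 7)
longer = rebuild_like(alibi, 12)
assert longer.device == alibi.device and longer.dtype == alibi.dtype
```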
@@ -767,7 +767,7 @@ def forward(
         current_sequence_length = hidden_states.shape[1]
         if past_key_values[0] is not None:
             current_sequence_length += past_key_values[0][0].shape[1]
-        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
+        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.device, hidden_states.dtype)
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
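The net effect: the model's `forward` now builds the ALiBi bias directly on `hidden_states.device`, so GPU inference no longer mixes a CPU-built bias with CUDA attention scores. A sketch of exercising that path (the checkpoint name is just an example; any BLOOM checkpoint on a CUDA device would hit it):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").to(device)

# Before this commit, generation on a CUDA device could fail with a
# device-mismatch error when the CPU-built alibi met GPU hidden states
inputs = tokenizer("Hello", return_tensors="pt").to(device)
out = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```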