@@ -93,7 +93,7 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
     )
 
 
-def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
+def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@@ -129,7 +129,7 @@ def get_slopes_power_of_2(n):
     arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
     alibi = slopes * arange_tensor.expand(n_head, -1, -1)
 
-    alibi = alibi.to(dtype)
+    alibi = alibi.to(device=device, dtype=dtype)
 
     return alibi
 
@@ -147,7 +147,7 @@ def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
     # This usually happens when the inference is done with past_key_values
     # In this case we re-create the alibi tensor with the correct sequence length
     if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
+        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype).repeat(
             attention_mask.shape[0], 1, 1
         )
     # Get the indexes of the padding tokens
@@ -156,7 +156,7 @@ def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
 
     # Clone the embeddings - we can detach because the embeddings are not learned
     # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
+    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.device, alibi.dtype)
 
     # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
     # Only where you do not have padding. Replace padding tokens by zeros
@@ -767,7 +767,7 @@ def forward(
         current_sequence_length = hidden_states.shape[1]
         if past_key_values[0] is not None:
             current_sequence_length += past_key_values[0][0].shape[1]
-        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
+        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.device, hidden_states.dtype)
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
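For reference, a minimal, self-contained sketch of the pattern every hunk above moves toward: building the ALiBi bias directly on the caller's device (and dtype) instead of leaving it on CPU for the caller to move. The helper names `alibi_slopes` and `build_alibi_sketch` are illustrative, not the functions in this file; the slope formula follows the power-of-two recipe from the ALiBi paper, and the signature mirrors the new `(max_seq_len, n_head, device, dtype)` order.

```python
import math

import torch


def alibi_slopes(n_head: int) -> torch.Tensor:
    """Per-head slopes: a geometric sequence starting at 2**(-8/n) for power-of-two head counts."""

    def slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    if math.log2(n_head).is_integer():
        slopes = slopes_power_of_2(n_head)
    else:
        # For head counts that are not a power of two, take the closest power of two
        # and interleave extra slopes, as the ALiBi paper suggests.
        closest = 2 ** math.floor(math.log2(n_head))
        slopes = slopes_power_of_2(closest) + slopes_power_of_2(2 * closest)[0::2][: n_head - closest]
    return torch.tensor(slopes)


def build_alibi_sketch(max_seq_len, n_head, device, dtype=torch.bfloat16):
    # One bias row of shape (1, max_seq_len) per head, created directly on the
    # requested device so the caller no longer has to move it afterwards.
    slopes = alibi_slopes(n_head).to(device).unsqueeze(-1).unsqueeze(-1)  # (n_head, 1, 1)
    arange_tensor = torch.arange(max_seq_len, device=device).unsqueeze(0).unsqueeze(0)
    alibi = slopes * arange_tensor.expand(n_head, -1, -1)  # (n_head, 1, max_seq_len)
    return alibi.to(device=device, dtype=dtype)


# Mirror of the updated call sites: device and dtype both come from the hidden states.
hidden_states = torch.randn(2, 16, 64)  # (batch, seq_len, hidden) on CPU for this demo
alibi = build_alibi_sketch(
    hidden_states.shape[1], n_head=8, device=hidden_states.device, dtype=hidden_states.dtype
)
print(alibi.shape)  # torch.Size([8, 1, 16])
```

The docstring touched in the first hunk notes that this non-causal ALiBi formulation leans on the translation invariance of softmax. A quick numerical check of that property (my own illustration, not part of the patch): adding a fixed value along the softmax dimension leaves the probabilities unchanged, which is what allows the bias to be folded into the attention scores without re-normalising.

```python
import torch

scores = torch.randn(4, 10)  # arbitrary attention logits
shift = 3.7                  # any fixed value added along the softmax dimension
print(torch.allclose(torch.softmax(scores + shift, dim=-1), torch.softmax(scores, dim=-1), atol=1e-6))  # True
```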