Fix fused_qkv check in rope and optimize falcon-7b (#2700)
* Fix fused_qkv check in rope

* optimize falcon-7b: directly use fused_qkv instead of splitting heads before rope
blzheng committed Mar 27, 2024
1 parent c8ce285 commit f57307d
Showing 2 changed files with 51 additions and 22 deletions.
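The first commit bullet concerns how the ROPE kernel decides whether its input is a single fused qkv buffer: a genuinely fused (batch, seq_len, qkv_hidden_size) tensor has a per-token stride wider than hidden_size, but a query/key/value tensor that is merely a view into such a buffer shows the same inflated stride. The sketch below mirrors the stride checks used by the new is_fused_qkv helper in plain PyTorch (sizes are illustrative; this is not the kernel's actual entry point):

import torch

# Illustrative sizes only: hidden_size = N * H.
N, H = 8, 64          # number of heads, head size
B, S = 1, 4           # batch, sequence length
hidden_size = N * H

# Case 1: a contiguous fused qkv tensor straight out of the qkv linear.
fused_qkv = torch.randn(B, S, 3 * hidden_size)
print(fused_qkv.stride(1) > hidden_size)                             # True  -> fused
print(fused_qkv.stride(0) * fused_qkv.size(0) == fused_qkv.numel())  # True  -> strides are trustworthy

# Case 2: a query tensor that is only a strided view into that buffer.
query = fused_qkv[..., :hidden_size].view(B, S, N, H)
print(query.stride(1) > hidden_size)                                 # True  -> stride alone would misclassify it
print(query.stride(0) * query.size(0) == query.numel())              # False -> a view; fall back to the sizes
print(query.size(2) * query.size(3) > hidden_size)                   # False -> correctly treated as already split

This is the distinction the helper encodes: trust stride(1) only when the batch stride tiles the whole storage, and otherwise recompute the per-token width from the sizes.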
17 changes: 16 additions & 1 deletion csrc/cpu/aten/kernels/RotaryPositionEmbeddingKnl.cpp
@@ -8,6 +8,21 @@ namespace torch_ipex {
 namespace cpu {

 namespace {
+bool is_fused_qkv(at::Tensor& t_in, int64_t hidden_size) {
+  auto in_stride_s = t_in.stride(1);
+  if (t_in.stride(0) * t_in.size(0) != t_in.numel()) {
+    if (t_in.dim() == 4) {
+      in_stride_s = t_in.size(2) * t_in.size(3);
+    } else if (t_in.dim() == 3) {
+      in_stride_s = t_in.size(2);
+    }
+  }
+  if (in_stride_s > hidden_size) {
+    return true;
+  }
+  return false;
+}
+
 /**
  * Applies the Rotary Position Embedding Kernel to the input tensors.
  *
@@ -48,7 +63,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> ApplyROPEKernel(
   auto N_KV = N; // GQA/MQA, N_KV: number of head for key/value
   auto concat_qkv = in_stride_s > N * H;

-  if (in_stride_s > N * H) {
+  if (is_fused_qkv(t_in, N * H)) {
     TORCH_CHECK(
         in_stride_s == HS,
         "The shape of input tensor of rotary_position_embedding should be in (batch, seq_len, qkv_hidden_size) when using fused qkv)");
@@ -403,32 +403,46 @@ def _FalconAttention_forward(
     num_kv_heads = (
         self.num_heads if self.new_decoder_architecture else self.num_kv_heads
     )

-    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-    batch_size, query_length, _, _ = query_layer.shape
+    if self.new_decoder_architecture or not self.rotary:
+        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+        batch_size, query_length, _, _ = query_layer.shape
+    else:
+        batch_size, query_length, _ = fused_qkv.shape

     past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]

     if self.rotary:
         seq_len = query_length + past_kv_length
-        key_layer = self._IPEXROPE(
-            key_layer,
-            torch.tensor(past_kv_length),
-            num_kv_heads,
-            self.head_dim,
-            self.head_dim // 2,
-            self.head_dim,
-            seq_len,
-        )
-        query_layer = self._IPEXROPE(
-            query_layer,
-            torch.tensor(past_kv_length),
-            self.num_heads,
-            self.head_dim,
-            self.head_dim // 2,
-            self.head_dim,
-            seq_len,
-        )
+        if self.new_decoder_architecture:
+            key_layer = self._IPEXROPE(
+                key_layer,
+                torch.tensor(past_kv_length),
+                num_kv_heads,
+                self.head_dim,
+                self.head_dim // 2,
+                self.head_dim,
+                seq_len,
+            )
+            query_layer = self._IPEXROPE(
+                query_layer,
+                torch.tensor(past_kv_length),
+                self.num_heads,
+                self.head_dim,
+                self.head_dim // 2,
+                self.head_dim,
+                seq_len,
+            )
+        else:
+            query_layer, key_layer, value_layer = self._IPEXROPE(
+                fused_qkv,
+                torch.tensor(past_kv_length),
+                self.num_heads,
+                self.head_dim,
+                self.head_dim // 2,
+                self.head_dim,
+                seq_len,
+                3,
+            )
     attention_mask_float = (
         (attention_mask * 1.0)
         .masked_fill(attention_mask.to(torch.bool), float("-1e9"))
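On the optimization itself: falcon-7b uses multi-query attention, so its qkv linear emits (num_heads + 2) * head_dim values per token. The new else branch above hands that buffer to the ROPE op unsplit (the trailing 3 appears to mark the input as a query/key/value concat), so rotation runs in a single pass over the fused storage rather than after a separate head-splitting step. A rough layout sketch, assuming the usual Falcon MQA ordering of query heads first followed by the shared key and value heads; sizes are illustrative (falcon-7b itself uses 71 heads of size 64):

import torch

# Illustrative MQA sizes, much smaller than the real model.
num_heads, head_dim = 8, 64
B, S = 1, 4

# One linear produces [q_0 .. q_{n-1} | k | v] for every token.
fused_qkv = torch.randn(B, S, (num_heads + 2) * head_dim)

q = fused_qkv[..., : num_heads * head_dim]
k = fused_qkv[..., num_heads * head_dim : (num_heads + 1) * head_dim]
v = fused_qkv[..., (num_heads + 1) * head_dim :]

# q, k and v are views into the same storage; their sequence stride is the fused
# width, so a kernel walking the fused buffer can rotate q and k directly, which
# is what lets the forward skip _split_heads before ROPE.
print(q.shape, k.shape, v.shape)
print(q.stride(1) == k.stride(1) == v.stride(1) == (num_heads + 2) * head_dim)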
