Commit

Ensure KV cache is not returned as output tensor during decode phase for Falcon (huggingface#993)
schoi-habana authored and imangohari1 committed Jun 12, 2024
1 parent 4ad3e69 commit d18f0cd
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -331,7 +331,7 @@ def pre_attn_forward(
                 dtype=self.query_key_value.weight.dtype,
                 device=self.query_key_value.weight.device,
             )
-            layer_past = (past_key, past_value)
+            layer_past = [past_key, past_value]
             key_layer = self.k_cache.update(
                 layer_past[0], key_layer, -2, token_idx, self.inp_seq_len
             )  # k_layer bs*1, q_len, head_dim
@@ -352,6 +352,11 @@ def pre_attn_forward(
         else:
             kv_length = present[0][-2] if reuse_cache else present[0].shape[-2]

+        if (not reuse_cache) and (token_idx is not None) and (cache_idx is not None) and (query_length == 1):
+            # Return only the past key value shapes, not the tensors, during the decode phase (q_len is 1)
+            # to avoid making the past key values persistent output tensors of HPU graphs.
+            present = (present[0].shape, present[1].shape)
+
         if alibi is None:
             if output_attentions:
                 attention_scores = query_layer @ key_layer.transpose(-1, -2)
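
For readers skimming the diff, a minimal standalone sketch of the pattern this hunk introduces follows; the helper name, flag values, and tensor shapes are illustrative only and are not part of the repository.

import torch

def kv_output_for_graph(present, reuse_cache, token_idx, cache_idx, query_length):
    # Mirrors the decode-phase check added above: during single-token decode,
    # return only the shapes of the KV tensors so an HPU graph does not keep
    # the full cache alive as a persistent output tensor (sketch, not library code).
    if (not reuse_cache) and (token_idx is not None) and (cache_idx is not None) and (query_length == 1):
        return (present[0].shape, present[1].shape)
    return present

# Toy usage: batch 2, 8 heads, cache length 128, head dim 64 (made-up sizes).
past_key = torch.zeros(2, 8, 128, 64)
past_value = torch.zeros(2, 8, 128, 64)
present = kv_output_for_graph((past_key, past_value), reuse_cache=False, token_idx=torch.tensor([5]), cache_idx=4, query_length=1)
print(present)  # (torch.Size([2, 8, 128, 64]), torch.Size([2, 8, 128, 64]))
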
@@ -861,6 +866,7 @@ def prepare_inputs_for_generation(
         **kwargs,
     ) -> dict:
         reuse_cache = kwargs.get("reuse_cache")
+        bucket_internal = kwargs.get("bucket_internal")
         if past_key_values is not None:
             if token_idx is not None:
                 input_ids = torch.index_select(input_ids, 1, token_idx - 1)
@@ -875,8 +881,9 @@ def prepare_inputs_for_generation(
                 remove_prefix_length = input_ids.shape[1] - 1

                 input_ids = input_ids[:, remove_prefix_length:]
-        elif reuse_cache and token_idx is not None:
-            # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
+        elif (reuse_cache or bucket_internal) and token_idx is not None:
+            # KV cache is pre allocated with reuse cache or will be padded with bucket internal
+            # hence for the 1st token we can slice the inputs till token idx for the fwd pass.
             input_ids = input_ids[:, :token_idx]
             attention_mask = attention_mask[:, :token_idx]

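As a rough illustration of why the new elif branch can truncate the prompt (flag values, sizes, and tensor contents below are made up for the example): when the KV cache is pre-allocated under reuse_cache, or will be padded later under bucket_internal, only the first token_idx positions of the bucketed prompt carry real tokens on the first forward pass, so input_ids and attention_mask can be sliced to that length.

import torch

# Hypothetical prefill step: prompt bucketed/padded to 16 positions,
# but only the first 5 hold real tokens (token_idx == 5).
token_idx = torch.tensor(5)
input_ids = torch.randint(0, 1000, (1, 16))
attention_mask = torch.cat(
    [torch.ones(1, 5, dtype=torch.long), torch.zeros(1, 11, dtype=torch.long)], dim=-1
)
past_key_values, reuse_cache, bucket_internal = None, True, False

if past_key_values is None and (reuse_cache or bucket_internal) and token_idx is not None:
    # Same slicing as the commit's elif branch: feed only the valid prefix on the
    # first step; later decode steps write into the pre-allocated cache at token_idx.
    input_ids = input_ids[:, :token_idx]
    attention_mask = attention_mask[:, :token_idx]

print(input_ids.shape, attention_mask.shape)  # torch.Size([1, 5]) torch.Size([1, 5])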
