Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure KV cache is not returned as output tensor during decode phase for Falcon #993

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def pre_attn_forward(
dtype=self.query_key_value.weight.dtype,
device=self.query_key_value.weight.device,
)
layer_past = (past_key, past_value)
layer_past = [past_key, past_value]
key_layer = self.k_cache.update(
layer_past[0], key_layer, -2, token_idx, self.inp_seq_len
) # k_layer bs*1, q_len, head_dim
Expand All @@ -359,6 +359,11 @@ def pre_attn_forward(
else:
kv_length = present[0][-2] if reuse_cache else present[0].shape[-2]

if (not reuse_cache) and (token_idx is not None) and (cache_idx is not None) and (query_length == 1):
# Return only past key value shapes and not the tensors during decode phase (q len is 1)
# to avoid making past key values as persistent output tensors of HPU graphs.
present = (present[0].shape, present[1].shape)

if alibi is None:
if output_attentions:
attention_scores = query_layer @ key_layer.transpose(-1, -2)
Expand Down Expand Up @@ -871,6 +876,7 @@ def prepare_inputs_for_generation(
**kwargs,
) -> dict:
reuse_cache = kwargs.get("reuse_cache")
bucket_internal = kwargs.get("bucket_internal")
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
Expand All @@ -885,8 +891,9 @@ def prepare_inputs_for_generation(
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
elif reuse_cache and token_idx is not None:
# With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
elif (reuse_cache or bucket_internal) and token_idx is not None:
# KV cache is pre allocated with reuse cache or will be padded with bucket internal
# hence for the 1st token we can slice the inputs till token idx for the fwd pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]

Expand Down
Loading