
Commit 33af761
FIX: Fix multiple generations for new HF cache format (#444)
younesbelkada committed May 1, 2024
1 parent 76edff6 commit 33af761
Showing 1 changed file with 4 additions and 1 deletion.
awq/modules/fused/attn.py (5 changes: 4 additions & 1 deletion)
@@ -189,16 +189,19 @@ def forward(
         self.start_pos = 0
 
         hf_is_generating = False
+        hf_is_first_forward = "past_key_value" in kwargs and kwargs["past_key_value"] is None
+        hf_is_new_cache_first_forward = "past_key_value" in kwargs and isinstance(kwargs["past_key_value"], DynamicCache) and kwargs["past_key_value"].get_seq_length() == 0
 
         if self.is_hf_transformers and "use_cache" in kwargs:
             hf_is_generating = kwargs["use_cache"]
 
+        # print(kwargs["past_key_value"].get_seq_length())
 
         # In case we re-generate, we need to refresh the starting position
         # to 0. We detect it by checking if `past_key_values` is set to None,
         # which indicates that we are on the first step of `generate()`.
         # This is only applicable for `transformers` integration
-        if (self.is_hf_transformers and "past_key_value" in kwargs and kwargs["past_key_value"] is None) or (self.is_hf_transformers and not hf_is_generating):
+        if (self.is_hf_transformers and (hf_is_first_forward or hf_is_new_cache_first_forward)) or (self.is_hf_transformers and not hf_is_generating):
             self.start_pos = 0

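Why the change: with the legacy transformers cache, `past_key_value` arrives as `None` on the first step of `generate()`, so the old `is None` check was enough to reset `start_pos`. The new HF cache format instead passes a freshly created `DynamicCache`, which is never `None`, so the reset never fired and `start_pos` kept growing across repeated `generate()` calls. The fix additionally treats an empty `DynamicCache` (one whose `get_seq_length()` returns 0) as the first forward pass. Below is a minimal sketch of the detection logic, assuming a `transformers` release that ships `DynamicCache`; the helper name `is_first_forward` is hypothetical, used here for illustration only.

    from transformers import DynamicCache

    def is_first_forward(past_key_value) -> bool:
        # Legacy cache format: generate() passes None on its first step.
        if past_key_value is None:
            return True
        # New cache format: generate() passes a freshly created DynamicCache
        # whose accumulated sequence length is still 0 on the first step.
        if isinstance(past_key_value, DynamicCache):
            return past_key_value.get_seq_length() == 0
        return False

    # An empty cache reports a sequence length of 0, so both the legacy
    # and the new first-forward cases trigger the start_pos reset.
    assert is_first_forward(None)
    assert is_first_forward(DynamicCache())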
