diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 0f0191fb144fc..8c8a67fa5cb05 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4221,6 +4221,9 @@ def assisted_decoding(
         # keep track of which sequences are already finished
         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
 
+        # other auxiliary variables
+        max_len = stopping_criteria[0].max_length
+
         this_peer_finished = False  # used by synced_gpus only
         while True:
             if synced_gpus:
@@ -4235,7 +4238,7 @@ def assisted_decoding(
 
             # Assistant: main logic start
             cur_len = input_ids.shape[-1]
-            max_len = stopping_criteria[0].max_length
+            assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
 
             # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
@@ -4244,7 +4247,7 @@ def assisted_decoding(
             for _ in range(int(assistant_model.max_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
                 if "assistant_past_key_values" in model_kwargs:
-                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
+                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                     assist_inputs = candidate_input_ids[:, -new_token_len:]
@@ -4505,6 +4508,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
                 )
             )
         past_key_values = tuple(new_past)
+    elif "gptbigcode" in model.__class__.__name__.lower():  # gptbigcode is too
+        if model.config.multi_query:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
+        else:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
     else:
         for idx in range(len(past_key_values)):
             new_past.append(
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 3b96f2b2bdff1..70de057d5fe7d 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1473,7 +1473,7 @@ def test_assisted_decoding_matches_greedy_search(self):
 
             # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
             if any(
                 model_name in model_class.__name__.lower()
-                for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+                for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
             ):
                 return
@@ -1529,7 +1529,7 @@ def test_assisted_decoding_sample(self):
 
             # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
             if any(
                 model_name in model_class.__name__.lower()
-                for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+                for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
             ):
                 return
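
Reviewer note (not part of the patch): the shape bookkeeping in the changes above is easy to misread, so here is a minimal standalone sketch of the three cache layouts involved. The shapes follow my reading of the bloom and gptbigcode modeling code; all variable names below are illustrative only, not library API.

```python
import torch

batch, heads, head_dim, seq_len, max_len = 2, 4, 8, 10, 6

# Standard decoder layer: a (key, value) pair, each (batch, num_heads, seq_len, head_dim)
standard_layer = (
    torch.zeros(batch, heads, seq_len, head_dim),
    torch.zeros(batch, heads, seq_len, head_dim),
)
# GPTBigCode with multi_query=True: one fused KV tensor per layer, (batch, seq_len, 2 * head_dim)
mqa_layer = torch.zeros(batch, seq_len, 2 * head_dim)
# Bloom: the key is stored transposed, (batch * heads, head_dim, seq_len);
# the value is (batch * heads, seq_len, head_dim)
bloom_layer = (
    torch.zeros(batch * heads, head_dim, seq_len),
    torch.zeros(batch * heads, seq_len, head_dim),
)

# Why `[assistant_kv_indexing].shape[-2]` works: dim -2 is the sequence dim for
# standard keys, for a batch-indexed fused MQA tensor, and for bloom *values*.
# Bloom keys put seq_len last, hence assistant_kv_indexing = 1 for bloom only.
assert standard_layer[0].shape[-2] == seq_len
assert mqa_layer[0].shape[-2] == seq_len  # [0] selects batch element 0 -> (seq_len, 2 * head_dim)
assert bloom_layer[1].shape[-2] == seq_len

# The new `_crop_past_key_values` branch, restated on dummy tensors: the fused
# MQA cache is cropped on dim 1, the per-head cache on dim 2.
cropped_mqa = mqa_layer[:, :max_len, :]
cropped_mha = standard_layer[0][:, :, :max_len, :]
assert cropped_mqa.shape == (batch, max_len, 2 * head_dim)
assert cropped_mha.shape == (batch, heads, max_len, head_dim)
```

The common thread is that once the right per-layer tensor is selected, `shape[-2]` is always the sequence length, which is exactly the selection `assistant_kv_indexing` performs.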