From 145e15f96ed72af9b9c822da423bb35b99ddc86b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Sat, 6 May 2023 16:04:27 +0000
Subject: [PATCH 1/2] starcoder has joined the chat

---
 src/transformers/generation/utils.py | 9 ++++++++-
 tests/generation/test_utils.py       | 4 ++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 0f0191fb144fc0..6ba148339cb7c0 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4244,7 +4244,7 @@ def assisted_decoding(
             for _ in range(int(assistant_model.max_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
                 if "assistant_past_key_values" in model_kwargs:
-                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
+                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[-2]
                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                     assist_inputs = candidate_input_ids[:, -new_token_len:]
@@ -4505,6 +4505,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
                 )
             )
         past_key_values = tuple(new_past)
+    elif "gptbigcode" in model.__class__.__name__.lower():  # gptbigcode is too
+        if model.config.multi_query:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
+        else:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
     else:
         for idx in range(len(past_key_values)):
             new_past.append(
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 3b96f2b2bdff10..70de057d5fe7d6 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1473,7 +1473,7 @@ def test_assisted_decoding_matches_greedy_search(self):
         # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
         if any(
             model_name in model_class.__name__.lower()
-            for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+            for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
         ):
             return
 
@@ -1529,7 +1529,7 @@ def test_assisted_decoding_sample(self):
         # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
         if any(
             model_name in model_class.__name__.lower()
-            for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+            for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
         ):
             return
 
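Why [PATCH 1/2] adds a dedicated `elif` branch to `_crop_past_key_values`: GPT-BigCode does not store its cache as the usual per-layer (key, value) tuple of 4-D tensors. It keeps a single tensor per layer, and with `multi_query=True` that tensor has no separate head axis, so the sequence dimension to crop moves from dim 2 to dim 1. The generic `else` branch cannot be reused here, since it rebuilds each layer as a (key, value) tuple. Below is a minimal sketch of the two slicings performed by the new branch; the concrete sizes and the fused `2 * head_dim` last axis are illustrative assumptions about GPT-BigCode's cache layout, not values taken from the library.

    # Minimal sketch of the two cache layouts cropped by the new `elif` branch.
    # All shapes here are illustrative assumptions, not read from transformers.
    import torch

    batch, heads, seq_len, head_dim, maximum_length = 1, 4, 10, 8, 6

    # multi_query=True: one 3-D tensor per layer (no head axis), key and value
    # assumed fused on the last axis -> the sequence axis is dim 1
    mqa_layer = torch.zeros(batch, seq_len, 2 * head_dim)
    assert mqa_layer[:, :maximum_length, :].shape == (batch, maximum_length, 2 * head_dim)

    # multi_query=False: one 4-D tensor per layer -> the sequence axis is dim 2
    mha_layer = torch.zeros(batch, heads, seq_len, 2 * head_dim)
    assert mha_layer[:, :, :maximum_length, :].shape == (batch, heads, maximum_length, 2 * head_dim)

In both cases the tensor is sliced in place in the list, rather than rebuilt into tuples via `new_past`, which is what makes the separate branch necessary.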
From 02a70b6f748e973339c08cb6623af7f0a9231c61 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Sat, 6 May 2023 16:46:39 +0000
Subject: [PATCH 2/2] indexing that works for all

---
 src/transformers/generation/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 6ba148339cb7c0..8c8a67fa5cb051 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4221,6 +4221,9 @@ def assisted_decoding(
         # keep track of which sequences are already finished
         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
 
+        # other auxiliary variables
+        max_len = stopping_criteria[0].max_length
+
         this_peer_finished = False  # used by synced_gpus only
         while True:
             if synced_gpus:
@@ -4235,7 +4238,7 @@ def assisted_decoding(
 
             # Assistant: main logic start
             cur_len = input_ids.shape[-1]
-            max_len = stopping_criteria[0].max_length
+            assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
 
             # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
@@ -4244,7 +4247,7 @@ def assisted_decoding(
             for _ in range(int(assistant_model.max_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
                 if "assistant_past_key_values" in model_kwargs:
-                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[-2]
+                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                     assist_inputs = candidate_input_ids[:, -new_token_len:]
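Why [PATCH 2/2] replaces the hard-coded `[0][0]` with `assistant_kv_indexing`: for most decoder-only assistants, both elements of a cache layer carry the sequence length at `shape[-2]`, but Bloom caches its keys transposed, so only the value tensor (element 1) does. Together with the `shape[-2]` change from the first patch, the same line now reads `prev_seq_len` correctly for typical 4-D caches, Bloom's transposed keys, and GPT-BigCode's 3-D multi-query cache. A minimal sketch, assuming Bloom's key layout is `(batch * heads, head_dim, seq_len)` and its value layout is `(batch * heads, seq_len, head_dim)` (illustrative shapes, not read from the models):

    # Sketch of why `assistant_kv_indexing` is 1 for bloom: its key tensor is
    # cached transposed, so `shape[-2]` on element 0 would give head_dim.
    # All shapes here are illustrative assumptions, not read from transformers.
    import torch

    batch, heads, seq_len, head_dim = 1, 4, 10, 8

    # typical decoder layer cache: (key, value), both (batch, heads, seq_len, head_dim)
    typical_layer = (
        torch.zeros(batch, heads, seq_len, head_dim),
        torch.zeros(batch, heads, seq_len, head_dim),
    )
    # assumed bloom layer cache: key (batch*heads, head_dim, seq_len),
    # value (batch*heads, seq_len, head_dim)
    bloom_layer = (
        torch.zeros(batch * heads, head_dim, seq_len),
        torch.zeros(batch * heads, seq_len, head_dim),
    )

    # index 0 works for typical models...
    assert typical_layer[0].shape[-2] == seq_len
    # ...but for bloom, element 0 would yield head_dim; only element 1
    # exposes the sequence length at shape[-2]
    assert bloom_layer[0].shape[-2] == head_dim
    assert bloom_layer[1].shape[-2] == seq_len

This also explains why `max_len` moved out of the `while True:` loop in the first hunk: it is constant across iterations, so it is now computed once among the auxiliary variables before the loop.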