Generate: starcoder 🤗 assisted generation (#23182)
* starcoder has joined the chat

* indexing that works for all
gante committed May 8, 2023
1 parent dbc1226 commit bbfb9fc
Showing 2 changed files with 14 additions and 4 deletions.
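
The functional change lets GPT-BigCode checkpoints (StarCoder) act as draft models in assisted generation. As a minimal sketch of how the feature is used downstream (not part of this diff), a large model is paired with a small assistant through `generate()`'s `assistant_model` argument; the checkpoint names below are illustrative:

```python
# Minimal assisted-generation sketch (illustrative checkpoints, not from this diff):
# a small draft model proposes tokens that the large model then verifies in one pass.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
assistant = AutoModelForCausalLM.from_pretrained("bigcode/tiny_starcoder_py")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```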
14 changes: 12 additions & 2 deletions src/transformers/generation/utils.py
@@ -4221,6 +4221,9 @@ def assisted_decoding(
         # keep track of which sequences are already finished
         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
 
+        # other auxiliary variables
+        max_len = stopping_criteria[0].max_length
+
         this_peer_finished = False  # used by synced_gpus only
         while True:
             if synced_gpus:
@@ -4235,7 +4238,7 @@
 
             # Assistant: main logic start
             cur_len = input_ids.shape[-1]
-            max_len = stopping_criteria[0].max_length
+            assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
 
             # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
@@ -4244,7 +4247,7 @@
             for _ in range(int(assistant_model.max_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
                 if "assistant_past_key_values" in model_kwargs:
-                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
+                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                     assist_inputs = candidate_input_ids[:, -new_token_len:]
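
The two hunks above replace a hard-coded `[0][0].shape[2]` with `[0][assistant_kv_indexing].shape[-2]`, because the assistant's cache layout is model-dependent. A small sketch of the reasoning, with shapes as transformers stored them at the time (sizes here are arbitrary):

```python
# Why shape[-2] plus a per-model index recovers the cached sequence length.
import torch

batch, heads, head_dim, seq = 1, 4, 8, 5

# Most decoder-only models: per layer (key, value), each [batch, heads, seq, head_dim]
std_layer = (torch.zeros(batch, heads, seq, head_dim), torch.zeros(batch, heads, seq, head_dim))
assert std_layer[0].shape[-2] == seq  # key tensor: dim -2 is the sequence length

# Bloom: key is [batch*heads, head_dim, seq], value is [batch*heads, seq, head_dim]
bloom_layer = (torch.zeros(batch * heads, head_dim, seq), torch.zeros(batch * heads, seq, head_dim))
assert bloom_layer[0].shape[-2] == head_dim  # key's dim -2 is NOT the sequence length
assert bloom_layer[1].shape[-2] == seq       # the value tensor carries it instead

# Hence assistant_kv_indexing: 1 for bloom, 0 for everything else.
```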
@@ -4505,6 +4508,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
                 )
             )
         past_key_values = tuple(new_past)
+    elif "gptbigcode" in model.__class__.__name__.lower():  # gptbigcode is too
+        if model.config.multi_query:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
+        else:
+            for idx in range(len(past_key_values)):
+                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
     else:
         for idx in range(len(past_key_values)):
             new_past.append(
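
The new `elif` handles GPT-BigCode's cache, which (as the terse comment hints, echoing the `# bloom is special` comment on the branch just above) fuses key and value into one tensor per layer: 3-D `[batch, seq, 2 * head_dim]` under multi-query attention, 4-D `[batch, heads, seq, 2 * head_dim]` otherwise, so the crop slices a different axis. A quick shape check under those assumptions:

```python
# GPT-BigCode caches one fused key/value tensor per layer; sizes are arbitrary.
import torch

batch, heads, seq, head_dim, keep = 1, 4, 10, 8, 6

# multi_query=True: [batch, seq, 2 * head_dim] -> crop dim 1
mqa = torch.zeros(batch, seq, 2 * head_dim)[:, :keep, :]
assert mqa.shape == (batch, keep, 2 * head_dim)

# multi_query=False: [batch, heads, seq, 2 * head_dim] -> crop dim 2
mha = torch.zeros(batch, heads, seq, 2 * head_dim)[:, :, :keep, :]
assert mha.shape == (batch, heads, keep, 2 * head_dim)
```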
4 changes: 2 additions & 2 deletions tests/generation/test_utils.py
@@ -1473,7 +1473,7 @@ def test_assisted_decoding_matches_greedy_search(self):
         # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
         if any(
             model_name in model_class.__name__.lower()
-            for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+            for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
         ):
             return

@@ -1529,7 +1529,7 @@ def test_assisted_decoding_sample(self):
         # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
         if any(
             model_name in model_class.__name__.lower()
-            for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+            for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
         ):
             return

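With gptbigcode removed from both skip lists, the shared assisted-decoding tests now run for GPT-BigCode models as well. A plausible local invocation (test path assumed from the usual repo layout): `python -m pytest tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py -k assisted_decoding`.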
