Fix _speculative_sampling implementation (#28508)
ofirzaf authored and amyeroberts committed Jan 21, 2024
1 parent 5fc3e60 commit 344943b
Showing 3 changed files with 70 additions and 22 deletions.
10 changes: 7 additions & 3 deletions src/transformers/generation/candidate_generator.py
@@ -171,12 +171,16 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
"""
         input_ids = input_ids.to(self.assistant_model.device)
 
+        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        if max_new_tokens == 0:
+            return input_ids, None
+
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
-            new_cur_len = input_ids.shape[-1]
-
             new_cache_size = new_cur_len - 1
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                 self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
@@ -190,7 +194,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "max_new_tokens": int(self.num_assistant_tokens),
+            "max_new_tokens": max_new_tokens,
             "generation_config": self.generation_config,
             "logits_processor": self.logits_processor,
         }
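The cap introduced above keeps the assistant from proposing tokens the target model can never verify: because the target model always emits one extra token after checking the candidates, at most `max_length - new_cur_len - 1` new candidates are useful. A minimal sketch of the arithmetic with made-up numbers (illustrative only, not part of the commit):

# Illustrative values; the names mirror those used in get_candidates above.
num_assistant_tokens = 5   # candidates the assistant would like to propose
max_length = 20            # generation_config.max_length
new_cur_len = 18           # current length of input_ids

max_new_tokens = min(num_assistant_tokens, max_length - new_cur_len - 1)
print(max_new_tokens)      # 1 -> only one candidate still fits under max_length
# With new_cur_len = 19 the cap would be 0 and get_candidates returns early without candidates.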
44 changes: 25 additions & 19 deletions src/transformers/generation/utils.py
@@ -4395,7 +4395,7 @@ def assisted_decoding(
             else:
                 selected_tokens = new_logits.argmax(dim=-1)
 
-            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
+            candidate_new_tokens = candidate_input_ids[:, cur_len:]
             n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
 
             # Ensure we don't generate beyond max_len or an EOS token
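On the greedy path shown above, `n_matches` counts how many leading candidate tokens agree with the target model's selections: the cumulative sum of mismatches stays below 1 only up to the first disagreement. A small sketch with made-up tensors (not from the commit):

import torch

# Hypothetical example: three candidate tokens, the third disagrees with the target model.
candidate_new_tokens = torch.tensor([[7, 2, 9]])
selected_tokens = torch.tensor([[7, 2, 4, 1]])  # the target model also produces one extra token

mismatch = ~(candidate_new_tokens == selected_tokens[:, :-1])  # [[False, False, True]]
n_matches = (mismatch.cumsum(dim=-1) < 1).sum()                # positions before the first mismatch
print(n_matches.item())  # 2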
@@ -4531,41 +4531,47 @@ def _speculative_sampling(
     NOTE: Unless otherwise stated, the variable names match those in the paper.
     """
+    new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
     # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
     # selected by the assistant, respectively.
     q = candidate_logits.softmax(dim=-1)
-    q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
+    q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
     p = new_logits.softmax(dim=-1)
-    p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
+    p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
     probability_ratio = p_i / q_i
 
     # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
     # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
     # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
     r_i = torch.rand_like(probability_ratio)
     is_accepted = r_i <= probability_ratio
-    n_matches = (~is_accepted.cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1
+    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1
 
     # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
     if last_assistant_token_is_eos and n_matches == candidate_length:
+        # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
+        # due to acceptance on EOS we fix `n_matches`
         n_matches -= 1
-    n_matches = min(n_matches, max_matches)
-
-    # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
-    gamma = candidate_logits.shape[1]
-    p_n_plus_1 = p[:, n_matches, :]
-    if n_matches < gamma:
-        q_n_plus_1 = q[:, n_matches, :]
-        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
+        valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
     else:
-        p_prime = p_n_plus_1
-    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+        n_matches = min(n_matches, max_matches)
+
+        # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+        gamma = min(candidate_logits.shape[1], max_matches)
+        p_n_plus_1 = p[:, n_matches, :]
+        if n_matches < gamma:
+            q_n_plus_1 = q[:, n_matches, :]
+            p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
+            p_prime.div_(p_prime.sum())
+        else:
+            p_prime = p_n_plus_1
+        t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
 
-    # The selected tokens include the matches (if any) plus the next sampled tokens
-    if n_matches > 0:
-        valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
-    else:
-        valid_tokens = t
+        # The selected tokens include the matches (if any) plus the next sampled tokens
+        if n_matches > 0:
+            valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
+        else:
+            valid_tokens = t
 
     return valid_tokens, n_matches
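The rewritten block follows the speculative sampling recipe: each candidate token is accepted with probability min(1, p_i / q_i), and on the first rejection the next token is drawn from the renormalized residual max(0, p - q). The added parentheses also matter, since `~is_accepted.cumsum(...)` would negate the cumulative sum rather than the boolean mask. A simplified single-position sketch with made-up distributions (not taken from the commit):

import torch

torch.manual_seed(0)

# Hypothetical vocabulary of 4 tokens at a single position.
q = torch.tensor([0.70, 0.10, 0.10, 0.10])  # assistant (draft) distribution
p = torch.tensor([0.25, 0.25, 0.25, 0.25])  # target model distribution
candidate = 0                               # token proposed by the assistant

ratio = p[candidate] / q[candidate]         # ~0.36, so this token is often rejected
accepted = bool(torch.rand(1) <= ratio)

if accepted:
    token = candidate
else:
    # Resample from the residual distribution max(0, p - q), renormalized to sum to 1.
    residual = torch.clamp(p - q, min=0)
    residual = residual / residual.sum()
    token = torch.multinomial(residual, num_samples=1).item()
print(accepted, token)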

38 changes: 38 additions & 0 deletions tests/generation/test_utils.py
@@ -88,6 +88,7 @@
     TopKLogitsWarper,
     TopPLogitsWarper,
 )
+from transformers.generation.utils import _speculative_sampling
 
 
 class GenerationTesterMixin:
@@ -2391,6 +2392,43 @@ def test_top_k_top_p_filtering_with_filter_value(self):

         self.assertTrue(torch.allclose(expected_output, output, atol=1e-12))
 
+    def test_speculative_sampling(self):
+        # assume vocab size 10, input length 5 + 3 generated candidates
+        candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens
+        candidate_logits = torch.tensor(
+            [
+                [
+                    [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
+                    [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
+                    [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
+                ]
+            ]
+        )
+        candidate_length = 3
+        inf = float("inf")
+        new_logits = torch.tensor(
+            [
+                [
+                    [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 1
+                    [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # accepts 4
+                    [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf],  # rejects 5, accepts 8
+                    [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # N/A
+                ]
+            ]
+        )
+        last_assistant_token_is_eos = False
+        max_matches = 5
+        validated_tokens, n_matches = _speculative_sampling(
+            candidate_input_ids,
+            candidate_logits,
+            candidate_length,
+            new_logits,
+            last_assistant_token_is_eos,
+            max_matches,
+        )
+        self.assertTrue(n_matches.item() == 2)
+        self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
 
 
 @require_torch
 class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
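For context, the path exercised by the new test is normally reached through `generate` when an assistant model is supplied and sampling is enabled. A hedged end-to-end sketch (the checkpoints are placeholders chosen for illustration, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")   # target model
assistant = AutoModelForCausalLM.from_pretrained("gpt2")     # smaller draft model

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    do_sample=True,        # sampling routes candidate validation through _speculative_sampling
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))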