Speculative sampling #17
@@ -357,24 +357,80 @@ def __extract_decode_output(

     Returns:
         Tuple[torch.Tensor, torch.Tensor]
-            the un-flattened next tokens per candidate per sequence, and the
-            un-flattened output embedding vector
+            the un-flattened logit scores per token per candidate per sequence,
+            and the un-flattened output embedding vectors
     """
     logits, _, embeds = model_output  # 1 n' v, 1 n' d OR bk 1+h v, bk 1+h d
     next_vals = torch.argmax(logits, dim=-1)  # 1 n' OR bk 1+h

     # If we used batch flattening / tree attention, unflatten the outputs
     if unflat_indices is not None:
         next_vals = apply_index_map(next_vals[0], unflat_indices)  # b k 1+h
+        logits = apply_index_map(logits[0], unflat_indices)  # b k 1+h v
         embeds = apply_index_map(embeds[0], unflat_indices)  # b k 1+h d
     else:
         next_vals = next_vals.view(
             batch_size, n_candidates, decode_seq_length
         )  # b k 1+h
+        logits = logits.view(
+            batch_size, n_candidates, decode_seq_length, logits.size(2)
+        )  # b k 1+h v
         embeds = embeds.view(
             batch_size, n_candidates, decode_seq_length, embeds.size(2)
         )  # b k 1+h d
-    return next_vals, embeds
+    return logits, embeds
+
+
+def __generate_targets(
+    logits: torch.Tensor,
+    do_sample: torch.Tensor,
+    temperature: float = 1.0,
+    top_k: int = 5,
+) -> torch.Tensor:
+    """
+    Extracts ground-truth tokens from a set of logits. If performing greedy decoding,
+    simply returns the most confident tokens. Otherwise, implements consistent multinomial
+    sampling: two identical distributions will always produce the same (randomized) sample.
+    Thus, by induction, two candidates with identical prefixes will receive the same
+    ground-truth sample up to the point their inputs diverge. This lets us ensure that at
+    least one candidate will be accepted, so long as the candidate set covers the top_k
+    options.
+
+    For example, if the base model predicts tokens A and B with equal 50% probability, and
+    the speculator produces one candidate with A and another with B, then with independent
+    sampling there is a 25% chance of rejecting both, even though one must be correct.
+    Consistent sampling avoids this.
+
+    Args:
+        logits: torch.Tensor
+            Probability logits for a set of candidate sequences. Expects size
+            bsize x n_candidates x seq_len x vocab_size
+        do_sample: torch.Tensor
+            A tensor of booleans enabling/disabling non-greedy decoding with consistent
+            sampling, for each of bsize input sequences
+        temperature: float
+            Degree of smoothing on the softmax sampling distribution
+        top_k: int
+            Sample only among the top_k most confident tokens
+
+    Returns:
+        torch.Tensor
+            Tensor of chosen token values for each sequence
+    """
+
+    # Get sample distributions
+    logits = logits / temperature
+    v, _ = logits.topk(top_k)
+    logits[logits < v[:, :, :, [-1]]] = -float("inf")
+    probs = logits.softmax(-1)  # b k 1+h v
+
+    # Sample candidate-consistent ground truths: partition the number line in [0,1]
+    # according to the given multinomial distribution, pick a random location on
+    # that line, and return the interval containing that location.
+    key = torch.rand(1, 1, logits.size(2), 1, device=probs.device)
+    a = (
+        probs.cumsum(3).sub(key).sign()
+    )  # Sign flips on the probability interval containing the key
+    samples = a.sub(1).div(-2).sum(3)  # Get index of sign-flip
+
+    # Composite greedy and non-greedy outputs
+    greedy = logits.argmax(-1)
+    mask = do_sample[:, None, None].int()
+    return samples * mask + (1 - mask) * greedy
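
To make the consistency claim concrete, here is a minimal standalone sketch of the shared-key cumsum trick, using toy 3-token distributions (illustrative code, not the PR's): two candidates with identical distributions always draw the same token.

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([
        [0.5, 0.5, 0.0],  # candidate 1's distribution
        [0.5, 0.5, 0.0],  # candidate 2: identical distribution
    ])
    key = torch.rand(1, 1)  # one random key, shared across candidates

    # The CDF partitions [0,1] into one interval per token; the sign flips on
    # the interval containing the key, and counting the -1 entries recovers
    # that interval's index.
    signs = probs.cumsum(-1).sub(key).sign()
    samples = signs.sub(1).div(-2).sum(-1)

    assert samples[0] == samples[1]  # identical distributions -> identical sample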
Review comment: If the mask is really a mask and not a weighting, might it be better to use torch.where? We're calculating the sampled results even if we don't use them? I guess that's something to do with compilation, but I would have thought the generation code would be outside the compile path.

Reply: Good point, I'll swap to torch.where.
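
For reference, a minimal sketch of the torch.where alternative discussed above; the shapes and tensors here are illustrative stand-ins, not the PR's:

    import torch

    b, k, h1, v = 2, 3, 4, 10
    logits = torch.randn(b, k, h1, v)
    samples = torch.randint(0, v, (b, k, h1))  # stand-in sampled tokens
    greedy = logits.argmax(-1)                 # greedy tokens
    do_sample = torch.tensor([True, False])    # per-sequence switch

    # Selects elementwise and keeps integer dtype throughout, rather than
    # blending via samples * mask + (1 - mask) * greedy
    next_vals = torch.where(do_sample[:, None, None], samples, greedy)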
@@ -390,6 +446,9 @@ def speculative_generate(
     decode_model: Optional[Union[Callable, torch.nn.Module]] = None,
     # todo: This is a WIP to enable cudagraphs, currently it's only for batch_size=1
     cudagraphs: bool = False,
+    do_sample: bool = False,
+    temperature: float = 1.0,
+    top_k: int = 5,
 ):
     """
     A reference implementation of speculative decoding generation.
@@ -433,6 +492,15 @@ def speculative_generate(
             if True, cudagraphs is used and all metadata will be padded, otherwise
             metadata will not be padded unless required. Note: This is a WIP and
             only works for batch_size=1
+        do_sample: bool
+            Non-deterministic, multinomial output sampling. False for greedy.
+            Provides output diversity, but lowers speculative decoding speedup.
+        temperature: float
+            Temperature of the softmax when sampling. Lowering this should provide
+            better speculative decoding speedup when do_sample=True.
+        top_k: int
+            Only search among the top k tokens. Lowering this should provide
+            better speculative decoding speedup when do_sample=True.
     Returns:
         result: List of id tensors, possibly different lengths if batching.
         n_steps: Number of forward passes used to generate provided tokens.
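
The speedup claims for temperature and top_k follow from how they reshape the target distribution; a quick standalone sketch with toy logits (not the PR's code) shows the effect:

    import torch

    logits = torch.tensor([2.0, 1.5, 1.0, 0.0, -1.0])

    for temperature, top_k in [(1.0, 5), (0.5, 5), (1.0, 2)]:
        scaled = logits / temperature
        v, _ = scaled.topk(top_k)
        filtered = scaled.masked_fill(scaled < v[-1], -float("inf"))
        probs = filtered.softmax(-1)
        # Lower temperature or smaller top_k concentrates mass on the top
        # token, so sampled targets agree with the speculator more often.
        print(f"T={temperature}, k={top_k}: {probs.tolist()}")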
@@ -518,10 +586,18 @@ def speculative_generate(
             use_cache=True,
         )  # 1 n' v OR bk 1+h v

-        next_vals, embeds = __extract_decode_output(
+        logits, embeds = __extract_decode_output(
             output, unflat_indices, bsize, n_candidates, inp_len
         )

+        if do_sample:
+            do_sample_vector = torch.ones(bsize, device=logits.device)
+        else:
+            do_sample_vector = torch.zeros(bsize, device=logits.device)
+        next_vals = __generate_targets(
+            logits, do_sample_vector, temperature=temperature, top_k=top_k
+        )
+
         next_vals_list, embeds, parent_sequence_ids = __prune_candidates(
             input_ids, next_vals, embeds, kv_cache_manager, child_sequence_ids_list
         )
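
The targets produced here feed __prune_candidates, which compares them against the speculated candidates. As a rough illustration of the standard speculative-decoding acceptance rule (toy tensors; not the PR's __prune_candidates implementation):

    import torch

    candidates = torch.tensor([[5, 9, 2], [5, 7, 2]])     # k=2 candidates, h=3 speculated tokens
    targets = torch.tensor([[5, 9, 4, 8], [5, 7, 2, 1]])  # per-candidate ground truths (h+1 each)

    for cand, tgt in zip(candidates, targets):
        # A prefix match of length m accepts m speculated tokens, plus the one
        # "free" target token that follows the last accepted position.
        m = int((cand == tgt[: cand.numel()]).int().cumprod(0).sum())
        print(f"accepted {m} speculated tokens -> emit {tgt[: m + 1].tolist()}")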
Review comment: If the goal is to speculate on a mutually exclusive set of possible continuations, why are we sampling at all, rather than just speculating on the top-k predictions?

Reply: We could do this, but we're more concerned with the ability to sample here than we are with the non-greediness of the approach. In this case "not greedy" is meant strictly literally, in that sampling involves not selecting greedily (assuming I'm understanding the question).