Speculative sampling #17

Open · wants to merge 10 commits into main
fms_extras/utils/generation.py (94 changes: 85 additions & 9 deletions)
@@ -357,24 +357,80 @@ def __extract_decode_output(

    Returns:
        Tuple[torch.Tensor, torch.Tensor]
-            the un-flattened next tokens per candidate per sequence, and the
-            un-flattened output embedding vector
+            the un-flattened logit scores per token per candidate per sequence,
+            and the un-flattened output embedding vectors
    """
    logits, _, embeds = model_output  # 1 n' v, 1 n' d OR bk 1+h v, bk 1+h d
-    next_vals = torch.argmax(logits, dim=-1)  # 1 n' OR bk 1+h

    # If we used batch flattening / tree attention, unflatten the outputs
    if unflat_indices is not None:
-        next_vals = apply_index_map(next_vals[0], unflat_indices)  # b k 1+h
+        logits = apply_index_map(logits[0], unflat_indices)  # b k 1+h v
        embeds = apply_index_map(embeds[0], unflat_indices)  # b k 1+h d
    else:
-        next_vals = next_vals.view(
-            batch_size, n_candidates, decode_seq_length
-        )  # b k 1+h
+        logits = logits.view(
+            batch_size, n_candidates, decode_seq_length, logits.size(2)
+        )  # b k 1+h v
        embeds = embeds.view(
            batch_size, n_candidates, decode_seq_length, embeds.size(2)
        )  # b k 1+h d
-    return next_vals, embeds
+    return logits, embeds


+def __generate_targets(
+    logits: torch.Tensor,
+    do_sample: torch.Tensor,
+    temperature: float = 1.0,
+    top_k: int = 5,
+) -> torch.Tensor:
+    """
+    Extracts ground-truth tokens from a set of logits. If performing greedy decoding,
+    simply returns the most confident tokens. Otherwise, implements consistent multinomial
+    sampling - two identical distributions will always produce the same (randomized) sample.
+    Thus by induction, two candidates with identical prefixes will receive the same ground
+    truth sample up to the point their inputs diverge. This allows us to ensure that at least
+    one candidate will be accepted, so long as the candidate set covers the top_k options.
+
+    For example, if the base model predicts tokens A and B with equal 50% probability, and the
+    speculator produces one candidate with A and another with B, with independent sampling there's
+    a 25% chance of rejecting both, even though one must be correct. Consistent sampling allows us
+    to avoid this.
Member:
if the goal is to speculate on a mutually exclusive set of possible continuations, why are we sampling at all and not just speculating on the top-k predictions?

Collaborator Author:
We could do this, but we're more concerned with the ability to sample here than we are with the non-greediness of the approach. In this case "not greedy" is meant strictly literally, in that sampling involves not selecting greedily (assuming I'm understanding the question)

+    Args:
+        logits: torch.Tensor
+            Probability logits for a set of candidate sequences. Expects size
+            bsize x n_candidates x seq_len x vocab_size
+        do_sample: torch.Tensor
+            A tensor of booleans enabling/disabling non-greedy decoding with consistent
+            sampling, for each of bsize input sequences
+        temperature: float
+            Degree of smoothing on softmax sampling distribution
+        top_k: int
+            Sample only among the top_k most confident tokens
+
+    Returns:
+        torch.Tensor
+            Tensor of chosen token values for each sequence
+    """
+
+    # Get sample distributions
+    logits = logits / temperature
+    v, _ = logits.topk(top_k)
+    logits[logits < v[:, :, :, [-1]]] = -float("inf")
+    probs = logits.softmax(-1)  # b k 1+h v
+
+    # Sample candidate-consistent ground truths: partition number line in [0,1]
+    # according to given multinomial distribution. Pick a random location
+    # on that line, return interval containing that location.
+    key = torch.rand(1, 1, logits.size(2), 1, device=probs.device)
+    a = (
+        probs.cumsum(3).sub(key).sign()
+    )  # Sign flips on probability interval containing key
+    samples = a.sub(1).div(-2).sum(3)  # Get index of sign-flip
+
+    # Composite greedy and non greedy outputs
+    greedy = logits.argmax(-1)
+    mask = do_sample[:, None, None].int()
+    return samples * mask + (1 - mask) * greedy
Member (@nairbv, Apr 30, 2024):
if the mask is really a mask and not a weighting, might be better to use torch.where.

we're calculating the sampled results even if we don't use them? I guess that's something to do with compilation but I would have thought the generation code would be outside the compile path?

Collaborator Author:
Good point, I'll swap to torch.where. We are calculating the sampled result for every case, and while that will be useful for compile down the road, in this case it's mostly just for efficient gpu usage - pretty sure that partitioning the greedy/non-greedy lines and then re-mixing them after is more work than just sampling everything
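To make the cumsum/sign sampling trick and the torch.where alternative discussed above concrete, here is a small self-contained sketch. It is not part of the PR: the toy distribution, shapes, and variable names are illustrative only. Two candidates share the same distribution at one position, so the shared random key lands in the same interval of [0, 1] for both and they receive the same sampled token; the final line composites greedy and sampled outputs with torch.where rather than the arithmetic mask used in the PR.

```python
import torch

torch.manual_seed(0)

# Toy shapes: batch b=1, k=2 candidates, 1+h=1 position, vocab v=5.
# Both candidates see the same distribution at this position, so consistent
# sampling must give them the same token.
logits = torch.tensor([0.0, 1.0, 2.0, 0.5, -1.0]).repeat(1, 2, 1, 1)  # b k 1+h v

temperature, top_k = 1.0, 3
logits = logits / temperature
v, _ = logits.topk(top_k)
logits[logits < v[:, :, :, [-1]]] = -float("inf")  # keep only the top_k tokens
probs = logits.softmax(-1)  # b k 1+h v

# One shared key per position (not per candidate): identical distributions
# partition [0, 1] identically, so the key falls in the same interval for both.
key = torch.rand(1, 1, logits.size(2), 1)
a = probs.cumsum(3).sub(key).sign()  # -1 before the chosen interval, +1 at/after it
samples = a.sub(1).div(-2).sum(3).long()  # number of -1 entries == sampled index

# Composite greedy and sampled outputs per sequence, using torch.where as the
# reviewer suggests instead of the integer-mask arithmetic.
greedy = logits.argmax(-1)
do_sample = torch.tensor([True])  # per-sequence switch, shape (b,)
out = torch.where(do_sample[:, None, None], samples, greedy)

print(samples)  # same token index for both candidates
print(out)
```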



def speculative_generate(
@@ -390,6 +446,9 @@ def speculative_generate(
    decode_model: Optional[Union[Callable, torch.nn.Module]] = None,
    # todo: This is a WIP to enable cudagraphs, currently its only for batch_size=1
    cudagraphs: bool = False,
+    do_sample: bool = False,
+    temperature: float = 1.0,
+    top_k: int = 5,
):
"""
A reference implementation of speculative decoding generation.
Expand Down Expand Up @@ -433,6 +492,15 @@ def speculative_generate(
            if True, cudagraphs is used and all metadata will be padded, otherwise
            metadata will not be padded unless required. Note: This is a WIP and
            only works for batch_size=1
+        do_sample: bool
+            non-deterministic, multinomial output sampling. False for greedy.
+            Provides output diversity, but lowers speculative decoding speedup.
+        temperature: float
+            temperature of softmax when sampling. Lowering this should provide
+            better speculative decoding speedup when do_sample=True.
+        top_k: int
+            only search among top k tokens. Lowering this should provide
+            better speculative decoding speedup when do_sample=True.
    Returns:
        result: List of id tensors, possibly different lengths if batching.
        n_steps: Number of forward passes used to generate provided tokens.
@@ -518,10 +586,18 @@ def speculative_generate(
            use_cache=True,
        )  # 1 n' v OR bk 1+h v

-        next_vals, embeds = __extract_decode_output(
+        logits, embeds = __extract_decode_output(
            output, unflat_indices, bsize, n_candidates, inp_len
        )

+        if do_sample:
+            do_sample_vector = torch.ones(bsize, device=logits.device)
+        else:
+            do_sample_vector = torch.zeros(bsize, device=logits.device)
+        next_vals = __generate_targets(
+            logits, do_sample_vector, temperature=temperature, top_k=top_k
+        )

        next_vals_list, embeds, parent_sequence_ids = __prune_candidates(
            input_ids, next_vals, embeds, kv_cache_manager, child_sequence_ids_list
        )
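To illustrate how the new options reach this code path, a hypothetical top-level call with sampling enabled is sketched below. Only the do_sample, temperature, and top_k keywords are taken from the signature in this hunk; the positional argument order, the other keywords, and the return unpacking mirror the updated scripts/paged_speculative_inference.py call shown in the next file and are otherwise assumptions, not part of this diff.

```python
# Hypothetical usage sketch: model, ids, speculator, kv_cache_manager, and
# decode_model are assumed to be constructed as in scripts/paged_speculative_inference.py.
# Positional argument order and the shape of the return tuple are assumptions here.
outputs = speculative_generate(
    model,
    ids,
    speculator,
    kv_cache_manager,
    decode_model=decode_model,  # optional separate decode-path model, per the signature above
    cudagraphs=False,
    do_sample=True,             # consistent multinomial sampling instead of greedy decoding
    temperature=0.7,            # sharper distribution -> more candidates accepted per step
    top_k=5,                    # sample only among the 5 most confident tokens
)
result, n_steps = outputs[0], outputs[1]  # per the docstring: id tensors and forward-pass count
```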
scripts/paged_speculative_inference.py (23 changes: 21 additions & 2 deletions)
@@ -100,7 +100,21 @@
    action="store_true",
    help="use a batch of prompts as input (note this is still wip for reduce-overhead=True)",
)
-# top_k_tokens_per_head
+parser.add_argument(
+    "--top_k",
+    type=int,
+    default=10,
+    help="sample only among top k most confident tokens (ignored if do_sample=False)",
+)
+parser.add_argument(
+    "--temperature",
+    type=float,
+    default=1.0,
+    help="degree of smoothing for sampling distribution (ignored if do_sample=False)",
+)
+parser.add_argument(
+    "--do_sample", action="store_true", help="enable non-greedy generation"
+)
parser.add_argument(
    "--top_k_tokens_per_head",
    type=lambda s: list(map(int, s.split(","))),
@@ -252,6 +266,9 @@ def infer(ids, warmup):
            # todo: we can only reduce-overhead for now when batch size is 1
            flattening=not (args.compile and compile_mode == "reduce-overhead"),
            cudagraphs=cudagraphs,
+            do_sample=args.do_sample,
+            temperature=args.temperature,
+            top_k=args.top_k,
            threshes=args.top_k_tokens_per_head,
        )
    else:
@@ -261,9 +278,11 @@
            kv_cache_manager,
            max_new_tokens=100,
            max_seq_len=model.config.max_expected_seq_len,
-            do_sample=False,
            decode_model=decode_model,
            cudagraphs=cudagraphs,
+            do_sample=args.do_sample,
+            temperature=args.temperature,
+            top_k=args.top_k,
        )
    if not warmup:
        total_tokens = 0