[Inference/SpecDec] Add Speculative Decoding Implementation #5423

Merged
59 changes: 57 additions & 2 deletions colossalai/inference/batch_bucket.py
@@ -42,6 +42,9 @@ def __init__(
self.device = device or get_current_device()
self.dtype = dtype

self._use_spec_dec = False
self._num_tokens_to_verify = None

self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
@@ -88,6 +91,28 @@ def is_compact(self):
== torch.nonzero(self._block_tables[:, 0] >= 0).numel()
)

@property
def use_spec_dec(self) -> bool:
return self._use_spec_dec

@property
def num_tokens_to_verify(self) -> int:
assert self.use_spec_dec and self._num_tokens_to_verify is not None
return self._num_tokens_to_verify

def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,
and let the main model verifies tokens in parallel.
"""
self._use_spec_dec = True
self._num_tokens_to_verify = num_tokens_to_verify

def reset_use_spec_dec(self) -> None:
"""Reset the usage of speculative decoding for the batch bucket"""
self._use_spec_dec = False
self._num_tokens_to_verify = None

def _make_compact(self) -> None:
# Clean and Compress the batch based on its sequences dict.
# Namely, compress sequences to the front and clean the seq lengths and block tables tensors.
@@ -347,6 +372,19 @@ def append_batch_tokens(self, tokens: torch.Tensor) -> None:
seq.check_finish()
self._sequence_lengths[: self.current_batch_size] += 1

def revoke_batch_tokens(self, n: int) -> None:
"""Revoke the last n output tokens of the sequences in the batch

Args:
n (int): The number of output tokens to revoke from each sequence.
This does not include the context (input) tokens.
"""
if n >= 1:
for seq_id, seq in self._sequences_dict.items():
assert seq.output_len >= n, "Revoked length exceeds the current output length of the sequence"
seq.output_token_id = seq.output_token_id[:-n]
self._sequence_lengths -= n

def clear(self, free_block_tables_fn: Optional[Callable[[torch.Tensor], None]]) -> List[int]:
"""Clear all the sequences in the batch.

@@ -401,6 +439,21 @@ def is_prompts(self) -> bool:
return True
return False

def get_1D_inputs_spec_dec(self, n: int) -> torch.Tensor:
# Used for main model verification in the **Decoding Stage**.
# `n` is the number of tokens to be verified;
# the last `n` output tokens of each sequence are prepared as the inputs.
assert len(self._sequences_dict) > 0, "No sequence in the batch"
assert all(
seq.output_len >= n for seq in self._sequences_dict.values()
), "Sequence output tokens must be greater than or equal to the number of tokens to be verified."
out_li = []
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.output_token_id[-n:])
return torch.tensor(out_li, dtype=torch.long, device=self.device)

# For compatibility
def get_1D_inputs(self) -> torch.Tensor:
assert len(self._sequences_dict) > 0, "No sequence in the batch"
@@ -411,15 +464,17 @@ def get_1D_inputs(self) -> torch.Tensor:
seq.output_len == 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
out_li = []
num_tokens = torch.sum(self._sequence_lengths)
out = torch.empty([num_tokens], dtype=torch.long)
seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x])
for seq_id in seq_ids:
seq: Sequence = self._sequences_dict[seq_id]
out_li.extend(seq.input_token_id)
return torch.tensor(out_li, dtype=torch.long, device=self.device)
else:
# Assume decoding stage
if self.use_spec_dec:
# For Speculative Decoding:
# prepare the tokens to be verified in parallel, plus the correct token accepted in the last step
return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1)
assert all(
seq.output_len > 0 for seq in self._sequences_dict.values()
), "Sequence stage (Prefill/Decoding) must be the same in the batch"
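A minimal usage sketch of the `BatchBucket` speculative-decoding helpers added above (a sketch only: `bb` is assumed to be an already-populated bucket whose sequences are in the decoding stage, and the `n_matches` value is illustrative):

```python
bb.set_use_spec_dec(num_tokens_to_verify=5)   # switch the bucket into spec-dec mode

# In spec-dec mode, get_1D_inputs() delegates to get_1D_inputs_spec_dec(5 + 1),
# i.e. it returns the last 6 output tokens of each sequence for parallel verification.
verifier_inputs = bb.get_1D_inputs()

# Suppose the main model accepted only the first n_matches of the 5 drafted tokens;
# drop the rejected tail tokens from every sequence in the bucket.
n_matches = 3  # illustrative value
bb.revoke_batch_tokens(5 - n_matches)

bb.reset_use_spec_dec()                       # restore normal decoding behavior
```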
6 changes: 6 additions & 0 deletions colossalai/inference/config.py
@@ -50,6 +50,8 @@ class InferenceConfig:
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
max_n_spec_tokens (int): The maximum number of tokens to speculate, defaults to 5.
glimpse_large_kv (bool): Whether to use large KV in the drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
pp_size (int): Pipeline parallel size, defaults to 1.
@@ -81,6 +83,10 @@ class InferenceConfig:
top_p: Optional[float] = None
min_p: Optional[float] = None

# speculative decoding configs
max_n_spec_tokens: int = 5
glimpse_large_kv: bool = False

# paged attention configs
block_size: int = 16

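For context, a sketch of how the new fields are set on an `InferenceConfig` (keyword construction of the dataclass assumed; all other fields keep their defaults):

```python
from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    max_n_spec_tokens=5,     # upper bound on tokens drafted per speculation round
    glimpse_large_kv=False,  # whether to use large KV in the drafter model
    block_size=16,           # existing paged-attention setting, shown for context
)
```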
180 changes: 161 additions & 19 deletions colossalai/inference/core/engine.py
@@ -9,6 +9,7 @@
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -47,18 +48,25 @@ def __init__(
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."
assert tokenizer, "Please provide a tokenizer, either a defined one or str"
self.inference_config = inference_config
self.model_config = model.config
self.model = model
self.device = torch.device("cuda")
self.dtype = inference_config.dtype
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self._verify_args()

self.generation_config = inference_config.to_generation_config(self.model_config)
model = model.eval()
model = model.cuda()
model.to(self.dtype)
model.eval()
model = model.to(self.dtype)
model = model.to(self.device)

# The drafter model and related attributes for speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.n_spec_tokens = self.inference_config.max_n_spec_tokens

if model_policy is None:
if self.inference_config.pad_input:
@@ -86,21 +94,18 @@ def __init__(

self.counter = count()

def _verify_config(self) -> None:
"""
Verify the input config
"""
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
self.tokenizer, PreTrainedTokenizer
):
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
assert (
self.model.__class__.__name__ in _supported_models
), f"Model {self.model.__class__.__name__} is not supported."
if self.model.__class__.__name__ not in _supported_models:
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")

def _shardformer(
self,
@@ -136,6 +141,138 @@ def _shardformer(
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model

def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.

Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if they exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.

```python
...
engine = InferenceEngine(model, tokenizer, inference_config)

engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding

engine.disable_spec_dec()
engine.generate(...) # Normal generation

engine.enable_spec_dec()
engine.generate(...) # Speculative decoding using the previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# using speculative decoding for subsequent generations
self.use_spec_dec = True

def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_spec_dec = False
return

def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_spec_dec = False
return

def steps_spec_dec(self) -> List[Sequence]:
"""
Run Speculative Decoding steps. This retrieves a single batch and launches inference
with multiple rounds of speculation by the drafter model and verification by the main model.

Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
batch.set_use_spec_dec(self.n_spec_tokens) # set batch to use-spec-dec mode

assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model

# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
drafter_out = self.drafter.speculate(input_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values

# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = self.model(batch, self.k_cahce, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_ids = batch.get_1D_inputs_spec_dec(1)

batch.reset_use_spec_dec() # reset batch use-spec-dec mode
finished_sequences = self.request_handler.update()

while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here would re-allocate the same kv cache for the batch
batch = self.request_handler.running_bb # running batch
batch.set_use_spec_dec(self.n_spec_tokens)

# 3. Decoding - Drafter model speculates `n` tokens
drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values

for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length

# 4. Decoding - Main model verifies `n` tokens in parallel
logits = self.model(batch, self.k_cahce, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)

# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = self.n_spec_tokens if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(self.n_spec_tokens - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
input_ids = batch.get_1D_inputs_spec_dec(1)
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(drafter_past_key_values, self.n_spec_tokens - n_matches - 1)

self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break

batch.reset_use_spec_dec()

return finished_sequences
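
A worked sketch of the accept/reject comparison in step 5 above, using illustrative token ids with `n_spec_tokens = 5` (`next_tokens` holds the main model's outputs at the 5 + 1 verified positions, `next_token_ids_spec` the drafter's 5 proposals):

```python
import torch

next_token_ids_spec = torch.tensor([11, 22, 33, 44, 55])      # drafted tokens (illustrative)
next_tokens = torch.tensor([11, 22, 30, 41, 52, 63])          # main-model outputs (illustrative)

# Positions where the main model disagrees with the draft (compare the first 5 outputs).
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = 5 if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
assert n_matches == 2  # the first two drafted tokens are accepted

# The batch then revokes the 5 - 2 = 3 rejected drafted tokens and appends
# next_tokens[2] (the main model's token at the first mismatch) as the next input.
```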

def generate(
self,
prompts: List[str] = None,
@@ -158,7 +295,6 @@ def generate(
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
self.generation_config = generation_config
if prompts is not None or prompts_token_ids is not None:
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)

@@ -169,8 +305,13 @@ def generate(
if generation_config is not None:
self.generation_config = generation_config

while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()

output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

@@ -301,7 +442,8 @@ def step(self) -> List[str]:

if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)

finished_sequences = self.request_handler.update()

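Putting the pieces together, a hypothetical end-to-end sketch of the new API (checkpoint names are placeholders and distributed initialization is omitted; see the `enable_spec_dec` docstring above for the canonical usage pattern):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# Placeholder checkpoints: any causal-LM pair that shares a tokenizer follows the same pattern.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
main_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
drafter_model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m")

engine = InferenceEngine(main_model, tokenizer, InferenceConfig(max_n_spec_tokens=5))

engine.enable_spec_dec(drafter_model, n_spec_tokens=5)   # speculative decoding on
outputs = engine.generate(prompts=["What is speculative decoding?"])
engine.disable_spec_dec()                                # back to normal generation
```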