[Inference] Optimize and Refactor Inference Batching/Scheduling #5367

Merged
29 commits
0e9fc31
add kvcache manager funcs for batching
yuanheng-zhao Feb 5, 2024
9c9d199
(trivial) remove print
yuanheng-zhao Feb 5, 2024
1244546
add batch bucket for batching
yuanheng-zhao Feb 5, 2024
5ed7f1e
revise RunningList struct in handler
yuanheng-zhao Feb 6, 2024
b947131
add kvcache/batch funcs for compatibility
yuanheng-zhao Feb 7, 2024
3e411f3
use new batching methods
yuanheng-zhao Feb 7, 2024
4882943
fix indexing bugs
yuanheng-zhao Feb 7, 2024
de59c2a
(trivial) modify comments
yuanheng-zhao Feb 8, 2024
632d5df
revise abort logic
yuanheng-zhao Feb 8, 2024
ca7820f
use cpu seq lengths/block tables
yuanheng-zhao Feb 8, 2024
8ff3615
rm unused attr in Sequence
yuanheng-zhao Feb 8, 2024
734de9b
fix type conversion/default arg
yuanheng-zhao Feb 8, 2024
183d4cf
add and revise pytests
yuanheng-zhao Feb 8, 2024
0b512a3
revise pytests, rm unused tests
yuanheng-zhao Feb 8, 2024
2d7550f
rm unused statements
yuanheng-zhao Feb 8, 2024
b4d913a
fix pop finished indexing issue
yuanheng-zhao Feb 8, 2024
a5e74a5
trivial revise
yuanheng-zhao Feb 8, 2024
5a8a12b
fix: use index in batch when retrieving inputs/update seqs
yuanheng-zhao Feb 15, 2024
3494374
use dict instead of odict in batch struct
yuanheng-zhao Feb 15, 2024
a99f399
arg type hinting
yuanheng-zhao Feb 16, 2024
5323428
fix make compress
yuanheng-zhao Feb 16, 2024
7293e09
refine comments
yuanheng-zhao Feb 16, 2024
6df7714
fix: pop_n_seqs to pop the first n seqs
yuanheng-zhao Feb 16, 2024
0e70068
(trivial) type hints
yuanheng-zhao Feb 16, 2024
07a25d3
add check in request handler
yuanheng-zhao Feb 19, 2024
78cc43a
remove redundant conversion
yuanheng-zhao Feb 19, 2024
b3dcb18
fix test for request handler
yuanheng-zhao Feb 19, 2024
d2e156b
fix pop method in batch bucket
yuanheng-zhao Feb 19, 2024
e1ff72f
fix prefill adding
yuanheng-zhao Feb 19, 2024
449 changes: 449 additions & 0 deletions colossalai/inference/batch_bucket.py

Large diffs are not rendered by default.
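The new `batch_bucket.py` diff is collapsed above, but the request-handler changes below exercise its interface (`add_seqs`, `pop_seqs`, `pop_finished`, `pop_seq_update_batch`, `merge`, `clear`, `append_batch_tokens`, `is_empty`, `available_batch_size`, `seqs_ids`). The sketch below is only an illustration of what such a bucket could look like: the method and property names are taken from the call sites in this PR, while the class name, the bodies, and anything not referenced in the diff are assumptions, not the actual implementation.

```python
# Illustrative sketch only -- not the BatchBucket implementation added by this PR.
# Method/property names mirror the call sites in request_handler.py; bodies are assumptions.
from typing import Callable, Dict, List, Optional


class BatchBucketSketch:
    """Tracks the sequences of one batch (prefill or decoding) by request id."""

    def __init__(self, max_batch_size: int, is_prompts: bool = True) -> None:
        self.max_batch_size = max_batch_size
        self.is_prompts = is_prompts
        self._sequences: Dict[int, object] = {}  # request_id -> sequence

    @property
    def is_empty(self) -> bool:
        return not self._sequences

    @property
    def current_batch_size(self) -> int:
        return len(self._sequences)

    @property
    def available_batch_size(self) -> int:
        # Remaining room before the bucket reaches max_batch_size.
        return self.max_batch_size - len(self._sequences)

    @property
    def seqs_ids(self) -> List[int]:
        return list(self._sequences.keys())

    def add_seqs(self, seqs: List, alloc_block_tables_fn: Optional[Callable] = None) -> None:
        # Admit sequences up to the remaining capacity; a real implementation would also
        # invoke alloc_block_tables_fn so the KV-cache manager reserves context blocks.
        for seq in seqs[: self.available_batch_size]:
            self._sequences[seq.request_id] = seq

    def pop_seqs(self, request_ids: List[int]) -> List:
        # Remove and return the given sequences (e.g. to recycle them back to waiting).
        return [self._sequences.pop(rid) for rid in request_ids if rid in self._sequences]

    def merge(self, other: "BatchBucketSketch") -> None:
        # Fold a completed prefill bucket into the decoding bucket.
        self._sequences.update(other._sequences)

    def clear(self, free_block_tables_fn: Optional[Callable] = None) -> None:
        # Drop bookkeeping; a real implementation would call free_block_tables_fn
        # when the KV-cache blocks should actually be released.
        self._sequences.clear()
```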

2 changes: 1 addition & 1 deletion colossalai/inference/config.py
@@ -109,7 +109,7 @@ def _verify_config(self) -> None:
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"

# check distributed
assert (
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
# check prompt template
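The revised assertion above now accepts single-process runs: when `torch.distributed` is not initialized, `tp_size * pp_size` must be 1; otherwise it must match the world size. A standalone sketch of the same check (the function name and example usage are illustrative, not part of the PR):

```python
# Standalone sketch of the revised parallel-size check; the real check lives in
# InferenceConfig._verify_config. The function name here is illustrative.
import torch.distributed as dist


def verify_parallel_sizes(tp_size: int, pp_size: int) -> None:
    if not dist.is_initialized():
        # Single-process inference: no process group exists, so TP * PP must be 1.
        assert tp_size * pp_size == 1, "TP/PP > 1 requires an initialized process group"
    else:
        # Distributed inference: the parallel sizes must cover the whole world size.
        world_size = dist.get_world_size()
        assert tp_size * pp_size == world_size, (
            f"TP size({tp_size}) * PP size({pp_size}) should equal world size ({world_size})"
        )
```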
10 changes: 1 addition & 9 deletions colossalai/inference/core/engine.py
@@ -42,7 +42,7 @@ class InferenceEngine:
def __init__(
self,
model: nn.Module,
tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
@@ -254,20 +254,12 @@ def add_request(
else:
prompt = prompts[i]

max_blocks_per_sequence = (
self.inference_config.max_input_len
+ self.inference_config.max_output_len
+ self.inference_config.block_size
- 1
) // self.inference_config.block_size
block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
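The deleted lines above pre-computed a per-sequence block table inside the engine using a ceiling division of the maximum sequence length by the block size. After this PR the `Sequence` is created without a block table, and block allocation is handled by the cache manager through the batch bucket. For reference, the arithmetic the removed code performed, with illustrative values:

```python
# The ceiling division the removed engine code used to size a per-sequence block table.
# The concrete values below are illustrative, not taken from the PR.
max_input_len = 256
max_output_len = 256
block_size = 16

# Integer-arithmetic form of ceil((max_input_len + max_output_len) / block_size)
max_blocks_per_sequence = (max_input_len + max_output_len + block_size - 1) // block_size
assert max_blocks_per_sequence == 32  # 512 tokens / 16 tokens per block
```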
200 changes: 117 additions & 83 deletions colossalai/inference/core/request_handler.py
@@ -1,15 +1,16 @@
from typing import List
from typing import Dict, List, Union

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.logging import get_dist_logger

__all__ = ["RunningList", "RequestHandler"]
@@ -24,45 +25,79 @@ class RunningList:

Args:
prefill_ratio: (float) A ratio for determining whether to perform prefill or not.
prefill: (List) List that contains default inputs, defaults to [].
_prefill (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
_decoding (OrderedDict[Sequence]): Mapping of sequence uid -> Sequence.
"""

def __init__(self, prefill_ratio: str, prefill: List[Sequence] = None):
def __init__(self, prefill_ratio: int, prefill: List[Sequence] = None) -> None:
self.prefill_ratio = prefill_ratio
self.decoding: List[Sequence] = []
self.prefill: List[Sequence] = prefill if prefill is not None else []
self._decoding: Dict[int, Sequence] = dict()
self._prefill: Dict[int, Sequence] = (
dict({seq.request_id: seq for seq in prefill}) if prefill is not None else dict()
)

def append(self, seq: Sequence):
# add seq to prefilling list first.
self.prefill.append(seq)

def find_seq(self, request_id):
for seq in self.decoding:
if request_id == seq.request_id:
return seq
for seq in self.prefill:
if request_id == seq.request_id:
return seq
return None
@property
def decoding(self):
return list(self._decoding.values())

@property
def prefill(self):
return list(self._prefill.values())

@property
def prefill_seq_num(self):
return len(self._prefill)

@property
def decoding_seq_num(self):
return len(self._decoding)

@property
def total_seq_num(self):
return self.prefill_seq_num + self.decoding_seq_num

def remove(self, seq: Sequence):
if seq in self.decoding:
self.decoding.remove(seq)
elif seq in self.prefill:
self.prefill.remove(seq)
def append(self, seq: Sequence):
assert (seq.request_id not in self._prefill) and (
seq.request_id not in self._decoding
), f"Sequence uid {seq.request_id} already exists."
self._prefill[seq.request_id] = seq

def extend(self, seqs: List[Sequence]):
for seq in seqs:
self._prefill[seq.request_id] = seq

def find_seq(self, request_id) -> Union[Sequence, None]:
seq = None
if request_id in self._decoding:
seq = self._decoding[request_id]
elif request_id in self._prefill:
seq = self._prefill[request_id]
return seq

def remove(self, seq: Sequence) -> None:
if seq.request_id in self._decoding:
self._decoding.pop(seq.request_id)
elif seq.request_id in self._prefill:
self._prefill.pop(seq.request_id)
else:
raise ValueError(f"sequence {seq.request_id} is not in running list")
raise ValueError(f"Sequence {seq.request_id} is not in running list")

def ready_for_prefill(self):
if not self.decoding:
return len(self.prefill) > 0
return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
if not self._decoding:
return len(self._prefill) > 0
return len(self._prefill) / len(self._decoding) >= self.prefill_ratio

def is_empty(self):
return not self.decoding and not self.prefill
return not self._decoding and not self._prefill

def total_seq_num(self):
return len(self.decoding) + len(self.prefill)
def mark_prefill_running(self) -> None:
for seq_id in self._prefill:
self._prefill[seq_id].mark_running()

def move_prefill_to_decoding(self, seq_ids: List[int]) -> None:
for seq_id in seq_ids:
assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list"
self._decoding[seq_id] = self._prefill.pop(seq_id)
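With the list-backed `prefill`/`decoding` replaced by dicts keyed on `request_id`, membership checks, removal, and the prefill-to-decoding transition become O(1) lookups instead of list scans. A minimal usage sketch, assuming the `RunningList` defined above is in scope; `_FakeSeq` is a stand-in for the real `Sequence` and only exposes the `request_id` these calls need:

```python
# Usage sketch for the dict-backed RunningList; _FakeSeq is a stand-in for Sequence.
from dataclasses import dataclass


@dataclass
class _FakeSeq:
    request_id: int


running_list = RunningList(prefill_ratio=1.2)

# New requests always enter the prefill dict first.
running_list.extend([_FakeSeq(0), _FakeSeq(1)])
assert running_list.prefill_seq_num == 2 and running_list.decoding_seq_num == 0

# After their prefill step, the scheduler moves them to decoding by id, one O(1) pop each.
running_list.move_prefill_to_decoding([0, 1])
assert running_list.find_seq(0) is not None
assert running_list.decoding_seq_num == 2 and running_list.prefill_seq_num == 0
```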


class RequestHandler:
@@ -110,25 +145,27 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo

# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
device=device,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
block_size=inference_config.block_size,
kv_max_split_num=kv_max_split_num,
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
device=device,
)

def _init_cache(self, model_config):
@@ -159,40 +196,39 @@ def schedule(self):
remove_list.append(seq)
break

# stop feeding new sequence into running list to assure
if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
break
num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
remove_list.extend(lst[:num_seqs_to_add])
self.running_list.extend(lst[:num_seqs_to_add])

# Try to allocate cache blocks for the sequence.
if (
self.cache_manager.check_allocation(seq)
and (len(self.running_list.prefill) + len(self.running_list.decoding))
< self.max_batch_size # There some bugs in continous batching, so we disable it here.
):
# If succeed, add the sequence to running list.
remove_list.append(seq)
self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
for seq in remove_list:
lst.remove(seq)

if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
self.prefill_batch.add_seqs(self.running_list.prefill)
return self.prefill_batch
num_seqs_to_add = min(self.running_list.prefill_seq_num, self.running_bb.available_batch_size)

if not self.running_batch.is_empty:
for seq in self.running_batch.sequences_set:
recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
if recycle:
for seq in self.running_list.prefill[:num_seqs_to_add]:
seq.mark_running()
# allocate blocks for the prefill batch
self.prefill_bb.add_seqs(
self.running_list.prefill[:num_seqs_to_add],
alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
)

return self.prefill_bb

if not self.running_bb.is_empty:
seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables(
self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size
)
if seqs_ids_to_recycle:
seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle)
for seq in seqs_to_recycle:
seq.recycle()
self.running_batch.del_seq(seq)
self.running_list.remove(seq)
self.waiting_list[-1].append(seq)
# the recycled sequences are handled with highest priority.

return self.running_batch
return self.running_bb

def add_sequence(self, req: Sequence):
"""
@@ -213,7 +249,7 @@ def abort_sequence(self, request_id: str):
seq.mark_aborted()
self.waiting_list[priority].remove(seq)
elif seq.status.is_running():
self.cache_manager.free_block_table(seq.block_table)
self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
self.running_list.remove(seq)
else:
try:
@@ -242,7 +278,7 @@ def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config
else:
sample_tokens = greedy_sample(generation_config, logprobs)
else:
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_batch.is_empty)
sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)

return sample_tokens

@@ -273,27 +309,25 @@ def search_tokens(self, generation_config: GenerationConfig, logits):

# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
if not self.prefill_batch.is_empty:
self.prefill_batch.update_batch_tokens(sample_tokens)
if not self.prefill_bb.is_empty:
self.prefill_bb.append_batch_tokens(sample_tokens)
else:
self.running_batch.update_batch_tokens(sample_tokens)
self.running_bb.append_batch_tokens(sample_tokens)

def update(self):
"""
Update current running list and done list
"""
if not self.prefill_batch.is_empty:
self.running_list.decoding.extend(self.running_list.prefill)
self.running_batch.add_seqs(self.running_list.prefill)
self.running_list.prefill.clear()
self.prefill_batch.clear_batch()

finish_seqs = self.running_batch.fliter_batch()

for seq in finish_seqs:
if not self.prefill_bb.is_empty:
self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids)
self.running_bb.merge(self.prefill_bb)
# clear the prefill batch without assigning a free_block_tables_fn
# since we want to reuse the memory recorded on the block tables
self.prefill_bb.clear(free_block_tables_fn=None)

finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
for seq in finished_seqs:
self.running_list.remove(seq)
self.cache_manager.free_block_table(seq.block_table)

self.done_list.extend(finish_seqs)
self.done_list.extend(finished_seqs)

return finish_seqs
return finished_seqs
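Taken together: `schedule()` now returns a `BatchBucket` (the prefill bucket when enough prefill sequences are ready, otherwise the decoding bucket after extending its token allocations), `search_tokens()` appends the sampled tokens to whichever bucket is active, and `update()` merges a completed prefill bucket into the running bucket and pops finished sequences. A sketch of how an engine-side step could drive these methods; the driver function, `model_forward`, and `generation_config` are assumptions, while the handler calls mirror this diff:

```python
# Illustrative driver step; only schedule()/search_tokens()/update() come from this PR,
# the surrounding function and the model_forward callable are assumptions.
def generation_step(request_handler, model_forward, generation_config):
    # schedule() returns either the prefill bucket or the running (decoding) bucket.
    batch = request_handler.schedule()
    if batch.is_empty:
        return []

    # Run the model on the scheduled batch and sample one next token per sequence;
    # search_tokens() appends the sampled tokens to the active bucket.
    logits = model_forward(batch)
    request_handler.search_tokens(generation_config, logits)

    # update() merges a finished prefill bucket into the running bucket, frees the
    # KV-cache blocks of finished sequences, and returns those sequences.
    return request_handler.update()
```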