[Inference] Fix request handler and add recycle logic (#5260)
* fix request handler

* fix comment
CjhHa1 committed Jan 15, 2024
1 parent c597678 commit d8db500
Showing 3 changed files with 37 additions and 7 deletions.
18 changes: 16 additions & 2 deletions colossalai/inference/core/request_handler.py
@@ -57,6 +57,9 @@ def ready_for_prefill(self):
def is_empty(self):
return not self.decoding and not self.prefill

def total_seq_num(self):
return len(self.decoding) + len(self.prefill)


class RequestHandler:
"""
@@ -105,6 +108,11 @@ def schedule(self):
)
self.abort_sequence(seq.request_id)
break

# Stop feeding new sequences into the running list, to ensure there is at least one available cache block left for each sequence that is already running.
if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num:
break

# Try to allocate cache blocks for the sequence.
if self.cache_manager.check_allocation(seq):
# If succeed, add the sequence to running list.
@@ -113,6 +121,7 @@ def schedule(self):
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
for seq in remove_list:
lst.remove(seq)

if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
@@ -121,7 +130,12 @@ def schedule(self):

if not self.running_batch.is_empty:
for seq in self.running_batch.sequences_set:
self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
if recycle:
seq.recycle()
self.running_batch.remove(seq)
self.waiting_list[-1].append(seq)
# the recycled sequences are handled with highest priority.

return self.running_batch

@@ -227,4 +241,4 @@ def update(self):

self.done_list.extend(finish_seqs)

return finish_seqs
return finish_seqs
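For readers following the scheduler change, here is a minimal, hypothetical sketch of the recycle flow added to `schedule()`. `FakeSequence` and `FakeCacheManager` are stand-ins invented for illustration; only the control flow mirrors the diff above.

```python
# Minimal sketch of the recycle flow in schedule(); stand-in classes, not the
# real ColossalAI types.
class FakeSequence:
    def __init__(self, request_id: int):
        self.request_id = request_id
        self.status = "RUNNING"

    def recycle(self) -> None:
        # Mirrors Sequence.recycle(): send a running sequence back to WAITING.
        self.status = "WAITING"


class FakeCacheManager:
    def __init__(self, free_blocks: int):
        self.free_blocks = free_blocks

    def allocate_token(self, seq: FakeSequence) -> bool:
        # Return True ("recycle me") when no block is left for the next token.
        if self.free_blocks == 0:
            return True
        self.free_blocks -= 1
        return False


def decoding_step(running: list, waiting_buckets: list, cache: FakeCacheManager) -> list:
    for seq in list(running):
        if cache.allocate_token(seq):
            # Out of KV-cache blocks: recycle the sequence instead of failing.
            # Appending to the last bucket mirrors the diff, where recycled
            # sequences are handled with the highest priority.
            seq.recycle()
            running.remove(seq)
            waiting_buckets[-1].append(seq)
    return running


# Example: two running sequences but only one free block -> one gets recycled.
running = [FakeSequence(0), FakeSequence(1)]
waiting = [[], [], []]
decoding_step(running, waiting, FakeCacheManager(free_blocks=1))
assert len(running) == 1 and len(waiting[-1]) == 1
```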
16 changes: 11 additions & 5 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -208,9 +208,9 @@ def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len
# The last allocated block may be either partially or fully occupied.
# `alloc_local_block_idx` is the index of block to be allocated on provided block table.
alloc_local_block_idx = context_len // self.block_size
self.allocate_single_block(block_table, alloc_local_block_idx, 1)
return self.allocate_single_block(block_table, alloc_local_block_idx)

def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int, space_asked: int) -> int:
def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
"""Allocate space asked on a single block in the block table, specified by the provided position id,
and updates the provided block table with the allocated block.
@@ -221,11 +221,14 @@ def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int,
Returns:
The remaining space required to be allocated (in other blocks).
"""
assert block_table.dim() == 1
space_asked = 1
block_global_id = block_table[block_local_idx].item()
if block_global_id < 0:
# Allocate a new block if the current position is not assigned a block yet
assert self._available_blocks > 0, "No available blocks to allocate."
if self._available_blocks <= 0:
# No available blocks to allocate: free the current sequence's cache and signal the caller to recycle it to the waiting list.
self.free_block_table(block_table)
return True
free_block_id = torch.nonzero(self._block_states == 1).view(-1)[0]
block: CacheBlock = self._cache_blocks[free_block_id]
block.add_ref()
@@ -235,6 +238,7 @@ def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int,
block_table[block_local_idx] = block_global_id
block: CacheBlock = self._cache_blocks[block_global_id]
return self._allocate_on_block(block, space_asked)
# Only when the space asked is fully satisfied will the return value be zero.

def free_block_table(self, block_table: torch.Tensor) -> None:
"""Free the logical cache blocks for **a single sequence**."""
@@ -269,7 +273,9 @@ def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
Returns:
The remaining space required to be allocated (in other blocks).
"""
assert block.available_space > 0, "No available space on block to allocate."
assert (
block.available_space > 0
), "Tried to allocate some space but found no available space left in chosen block."
space_to_allocate = min(block.available_space, space_asked)
block.allocate(space_to_allocate)
return space_asked - space_to_allocate
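As a quick, hedged illustration of the allocation arithmetic above (a block size of 16 is assumed here; the helpers below are stand-ins, not the manager's API):

```python
BLOCK_SIZE = 16  # assumed value for illustration


def alloc_local_block_idx(context_len: int) -> int:
    # Index (within a sequence's block table) of the block that must hold the
    # next token; the last allocated block may be partially or fully occupied.
    return context_len // BLOCK_SIZE


print(alloc_local_block_idx(15))  # 0 -> the last block still has room
print(alloc_local_block_idx(16))  # 1 -> a new block must be allocated


def allocate_on_block(available_space: int, space_asked: int = 1) -> int:
    # Mirrors _allocate_on_block's return contract: the space still owed after
    # taking as much as fits on this block; 0 means fully satisfied.
    assert available_space > 0, "No available space left in chosen block."
    space_to_allocate = min(available_space, space_asked)
    return space_asked - space_to_allocate


print(allocate_on_block(available_space=3))  # 0 -> the single slot fits
```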
10 changes: 10 additions & 0 deletions colossalai/inference/struct.py
@@ -134,6 +134,16 @@ def mark_aborted(self) -> None:
"""
self.status = RequestStatus.ABORTED

def recycle(self) -> None:
"""
Recycle a running sequence to the waiting list.
"""
assert (
not self.status.is_finished and not self.status == RequestStatus.ABORTED
), "The running sequence \
is already done but it still in running list"
self.status = RequestStatus.WAITING

def __repr__(self) -> str:
return (
f"(request_id={self.request_id}, "
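Taken on its own, the new `recycle()` transition amounts to the small state change sketched below (a toy status enum is used for illustration; the real `RequestStatus` has more states):

```python
from enum import Enum, auto


class ToyStatus(Enum):
    WAITING = auto()
    RUNNING = auto()
    COMPLETED = auto()
    ABORTED = auto()


class ToySequence:
    def __init__(self) -> None:
        self.status = ToyStatus.RUNNING

    def recycle(self) -> None:
        # A finished or aborted sequence must never go back on the waiting list.
        assert self.status not in (ToyStatus.COMPLETED, ToyStatus.ABORTED), (
            "The running sequence is already finished but is still in the running list."
        )
        self.status = ToyStatus.WAITING


seq = ToySequence()
seq.recycle()
assert seq.status is ToyStatus.WAITING
```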
