[Inference] Add BatchInferState, Sequence and InferConfig #5149

73 changes: 73 additions & 0 deletions colossalai/inference/engine/config.py
@@ -0,0 +1,73 @@
from typing import Optional, Union

import torch
import torch.nn as nn
from transformers import AutoConfig, PretrainedConfig


class InferConfig:
    """The inference configuration.

    Args:
        model: The model name or path, or an already-constructed ``nn.Module``.
        tokenizer: The name or path of the tokenizer.
        tokenizer_mode: The tokenizer loading mode, either "auto" or "slow".
        trust_remote_code: Whether to trust remote (Hugging Face Hub) code.
        max_batch_size: The maximum number of sequences in a batch.
        max_output_len: The maximum number of output tokens per sequence.
        max_input_len: The maximum number of input tokens per sequence.
        block_size: The number of tokens held by one cache block.
        gpu_utilization_rate: The fraction of GPU memory the engine may use.
        dtype: The data type of the model weights and activations.
        tp_size: The tensor parallel size.
        pp_size: The pipeline parallel size.
        max_seq_len: The maximum total sequence length (input plus output).
        quant_mode: The quantization mode, if any.
        revision: The specific model revision to load.
    """

    def __init__(
        self,
        model: Union[str, nn.Module],
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        max_batch_size: int,
        max_output_len: int,
        max_input_len: int,
        block_size: int,
        gpu_utilization_rate: float,
        dtype: Union[str, torch.dtype],
        tp_size: int = 1,
        pp_size: int = 1,
        max_seq_len: Optional[int] = None,
        quant_mode: Optional[str] = None,
        revision: Optional[str] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.max_batch_size = max_batch_size
        self.max_output_len = max_output_len
        self.max_input_len = max_input_len
        self.block_size = block_size
        self.gpu_utilization_rate = gpu_utilization_rate
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dtype = dtype
        self.max_seq_len = max_seq_len
        self.quant_mode = quant_mode
        self.revision = revision

        self.hf_model_config = self._get_hf_model_config()
        # Validate the configuration once the HF model config is available.
        self._verify_args()

    def _get_hf_model_config(self) -> PretrainedConfig:
        return AutoConfig.from_pretrained(
            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
        )

    def get_pp_layer_num(self) -> int:
        return self.hf_model_config.num_hidden_layers // self.pp_size

    def _verify_args(self):
        if self.gpu_utilization_rate > 1.0:
            raise ValueError(
                f"GPU utilization rate should be no greater than 1.0, but is set to {self.gpu_utilization_rate}."
            )
        if self.tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(f"Tokenizer mode must be either 'auto' or 'slow', but got {self.tokenizer_mode}.")

        if self.hf_model_config.num_hidden_layers % self.pp_size != 0:
            raise ValueError(
                "When using pipeline parallel, the total number of hidden layers must be divisible by "
                f"the pipeline parallel size. Got {self.hf_model_config.num_hidden_layers} hidden layers "
                f"and pipeline parallel size {self.pp_size}."
            )
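
For reference, a minimal construction sketch for this config class. The model id and all numeric values below are placeholders chosen for illustration, not defaults from this PR:

config = InferConfig(
    model="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint id
    tokenizer="meta-llama/Llama-2-7b-hf",
    tokenizer_mode="auto",
    trust_remote_code=False,
    max_batch_size=8,
    max_output_len=256,
    max_input_len=1024,
    block_size=16,
    gpu_utilization_rate=0.9,
    dtype=torch.float16,
    pp_size=2,
)
# With a 32-layer model and pp_size=2, each pipeline stage holds 16 layers.
print(config.get_pp_layer_num())
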
141 changes: 141 additions & 0 deletions colossalai/inference/engine/infer_struct.py
@@ -0,0 +1,141 @@
import enum
from dataclasses import dataclass
from typing import Dict, List, Set


class RequestStatus(enum.Enum):
    """The status of a request/sequence."""

    WAITING = enum.auto()
    RUNNING = enum.auto()
    ABORTED = enum.auto()
    OVERLENGTH = enum.auto()
    COMPLETED = enum.auto()
    LENGTH_CAPPED = enum.auto()

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        return status in [
            RequestStatus.OVERLENGTH,
            RequestStatus.COMPLETED,
            RequestStatus.LENGTH_CAPPED,
        ]

    @staticmethod
    def is_running(status: "RequestStatus") -> bool:
        return status == RequestStatus.RUNNING

    @staticmethod
    def is_waiting(status: "RequestStatus") -> bool:
        return status == RequestStatus.WAITING


class Sequence:
    """Store the information of an input sequence.

    Args:
        request_id: The ID of the request this sequence belongs to.
        prompt: The prompt text of the sequence.
        token_id: The input token IDs of the sequence.
        block_size: The cache block size used by the sequence.
        sample_params: The sampling parameters of the sequence.
        block_table_index: The index of this sequence in the block table.
    """

    def __init__(
        self,
        request_id: int,
        prompt: str,
        token_id: List[int],
        block_size: int,
        sample_params: "SampleParams",  # quoted to avoid a NameError; assumed to be defined elsewhere in this PR
        block_table_index: int,
    ):
        self.request_id = request_id
        self.input_token_id = token_id
        self.prompt = prompt
        self.block_size = block_size
        self.output_token_id = []
        self.output = ""
        self.status = RequestStatus.WAITING
        self.sample_params = sample_params
        self.block_table_index = block_table_index

    def get_sentence_len(self) -> int:
        return len(self.input_token_id) + len(self.output_token_id)

    def get_input_len(self) -> int:
        return len(self.input_token_id)

    def get_output_len(self) -> int:
        return len(self.output_token_id)

    def check_finish(self) -> bool:
        return RequestStatus.is_finished(self.status)

    def __repr__(self) -> str:
        return (
            f"Sequence(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"block_table_index={self.block_table_index})"
        )


@dataclass
class BatchInferState:
    """Information to be passed around and used for a batch of sequences."""

    sequences_set: Set[Sequence]
    block_table: Dict[int, int]

    @classmethod
    def init_batch(cls, seqs: List[Sequence]) -> "BatchInferState":
        sequences_set = set()
        block_table = {}
        for seq in seqs:
            if seq in sequences_set:
                print("The sequence is already in sequences_set.")
                assert (
                    seq.request_id in block_table
                ), "The sequence is in sequences_set, but it is not in block_table."
                continue
            assert (
                seq.request_id not in block_table
            ), "The sequence is not in sequences_set, but it is already in block_table."

            sequences_set.add(seq)
            block_table[seq.request_id] = seq.block_table_index

        return cls(sequences_set=sequences_set, block_table=block_table)

    def clear_batch(self) -> None:
        for seq in self.sequences_set:
            if not seq.check_finish():
                seq.status = RequestStatus.ABORTED
        self.sequences_set.clear()
        self.block_table.clear()

    def filter_batch(self) -> None:
        # Iterate over a copy: removing from a set while iterating over it raises RuntimeError.
        for seq in list(self.sequences_set):
            if seq.check_finish():
                self.sequences_set.remove(seq)
                del self.block_table[seq.request_id]

    def add_seqs(self, seqs: List[Sequence]) -> None:
        for seq in seqs:
            if seq in self.sequences_set:
                print("The sequence is already in sequences_set.")
                assert (
                    seq.request_id in self.block_table
                ), "The sequence is in sequences_set, but it is not in block_table."
                continue
            assert (
                seq.request_id not in self.block_table
            ), "The sequence is not in sequences_set, but it is already in block_table."
            self.sequences_set.add(seq)
            self.block_table[seq.request_id] = seq.block_table_index
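
To make the intended lifecycle concrete, a small sketch of how these pieces could fit together. The token IDs, block sizes, and block_table_index values are illustrative, and None stands in for a real SampleParams object, which this PR does not define:

seq0 = Sequence(request_id=0, prompt="Hello", token_id=[15043], block_size=16,
                sample_params=None, block_table_index=0)
seq1 = Sequence(request_id=1, prompt="Hi", token_id=[6324], block_size=16,
                sample_params=None, block_table_index=1)

state = BatchInferState.init_batch([seq0, seq1])
assert state.block_table == {0: 0, 1: 1}

# Pretend seq0 finished generating; filter_batch drops it from the batch.
seq0.status = RequestStatus.COMPLETED
state.filter_batch()
assert state.sequences_set == {seq1} and 0 not in state.block_table

# clear_batch aborts anything unfinished and empties the state.
state.clear_batch()
assert seq1.status == RequestStatus.ABORTED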