[inference] refactored config
FrankLeeeee committed Feb 8, 2024
1 parent 1f8c7e7 commit 43a2ee7
Showing 3 changed files with 118 additions and 22 deletions.
53 changes: 32 additions & 21 deletions colossalai/inference/config.py
@@ -35,49 +35,60 @@ class InferenceConfig:
"""The inference configuration.
Args:
micro_batch_size (int): The micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): The buffer size for micro batches. Normally, it should be the same as the number of pipeline stages.
max_batch_size (int): Maximum batch size, defaults to 8.
max_output_len (int): Maximum output length, defaults to 256.
max_input_len (int): Maximum input length, defaults to 256.
block_size (int): The number of token slots in a logical KV cache block, defaults to 16.
dtype (Union[str, torch.dtype]): The data type for weights and activations.
tp_size (int): Tensor parallel size, defaults to 1.
pp_size (int): Pipeline parallel size, defaults to 1.
prompt_template (Optional[str]): The prompt template for generation, defaults to None.
do_sample (bool): Whether to use sampling for generation, defaults to False.
beam_width (int): The maximum beam width used to initialize the KV cache, defaults to 1.
During generation, the beam width provided as a sampling parameter should be less than or equal to this value.
prefill_ratio (Optional[float]): A controlling ratio for prefill and decoding in the running list, defaults to 1.2. We run a prefill
step when the actual value exceeds this ratio.
pad_input (bool): Whether to pad all inputs to the max length.
quant_mode (Optional[str]): Quantization mode.
revision (Optional[str]): The specific version (a branch name, a commit id, or a tag name) of the model to use.
prompt_template (Optional[str]): The prompt template for formatting the input text. Some built-in templates include 'llama' and 'vicuna'. Otherwise, the template should contain '{input_text}' for formatting the input text.
early_stopping (Optional[bool]): Whether to stop generation once all beam hypotheses have finished, defaults to False.
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for top-p (nucleus) sampling; only the smallest set of most probable tokens whose probabilities sum to at least this value is kept, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
block_size (int): The number of token slots in a logical KV cache block, defaults to 16.
tp_size (int): Tensor parallel size, defaults to 1.
pp_size (int): Pipeline parallel size, defaults to 1.
micro_batch_size (int): The micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): The buffer size for micro batches. Normally, it should be the same as the number of pipeline stages.
"""

micro_batch_size: int = 1
micro_batch_buffer_size: Optional[int] = None
# NOTE: arrange configs according to their importance and frequency of usage

# runtime limit
max_batch_size: int = 8
max_output_len: int = 256
max_input_len: int = 256
block_size: int = 16

# general configs
dtype: Union[str, torch.dtype] = torch.float16 # use fp16 by default

tp_size: int = 1
pp_size: int = 1
# TODO: beam search is not supported for now
# generation configs
prompt_template: Optional[str] = None
do_sample: bool = False
beam_width: int = 1
# the ratio of prefill sequences to decoding sequences; we run a prefill step once the actual value exceeds this ratio
prefill_ratio: Optional[float] = 1.2
beam_width: int = 1  # TODO: beam search is not supported for now
prefill_ratio: Optional[
float
] = 1.2  # the ratio of prefill sequences to decoding sequences; we run a prefill step once the actual value exceeds this ratio
pad_input: bool = False
quant_mode: Optional[str] = None
revision: Optional[str] = None
early_stopping: Optional[bool] = False

top_k: Optional[int] = None
top_p: Optional[float] = None
min_p: Optional[float] = None
prompt_template: Optional[str] = None

# paged attention configs
block_size: int = 16

# model parallelism configs
tp_size: int = 1
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: Optional[int] = None

def __post_init__(self):
self._verify_config()
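
For illustration only (not part of the commit), a minimal sketch of constructing the reorganized config, passing arguments in the same groups the refactored dataclass now uses. Field names are taken from the diff above; the sample values are assumptions and may be constrained further by `_verify_config`.

import torch

from colossalai.inference.config import InferenceConfig

# Runtime limits, general, generation, paged-attention, and parallelism settings,
# grouped in the same order as the refactored dataclass above.
inference_config = InferenceConfig(
    max_batch_size=8,
    max_input_len=256,
    max_output_len=256,
    dtype=torch.float16,
    do_sample=True,
    top_k=50,
    top_p=0.9,
    block_size=16,
    tp_size=1,
    pp_size=1,
)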
1 change: 0 additions & 1 deletion colossalai/inference/core/engine.py
@@ -130,7 +130,6 @@ def _shardformer(
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"quant": self.inference_config.quant_mode},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
86 changes: 86 additions & 0 deletions colossalai/inference/modeling/policy/padding_llama.py
@@ -0,0 +1,86 @@
from functools import partial

import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm

from colossalai.inference.modeling.models.padding_llama import (
PadLlamaAttention,
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

try:
from colossalai.kernel.triton import rms_layernorm

HAS_TRITON_RMSNORM = True
except ImportError:
print("You should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False


def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:

def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)

return _triton_rmsnorm_forward
else:
return None


class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()

def module_policy(self):
policy = super().module_policy()

policy[LlamaDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn",
target_module=PadLlamaAttention,
),
]
)

self.shard_config._infer()

infer_forward = llama_causal_lm_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaForCausalLM
)

infer_forward = llama_model_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

infer_forward = llama_decoder_layer_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)

infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()

if infer_forward is not None:
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
)

return policy

def postprocess(self):
init_to_get_rotary(self.model.model)
return self.model
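
For illustration only (not part of the commit), a rough sketch of how the new policy might be applied through ShardFormer, mirroring the `_shardformer` call shown in the engine.py hunk above. The tiny LlamaConfig, the `enable_tensor_parallelism=False` flag (to keep the sketch single-process), and the assumption that `colossalai.launch` has already initialized the distributed context are all hypothetical additions.

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.inference.modeling.policy.padding_llama import PaddingLlamaModelInferPolicy
from colossalai.shardformer import ShardConfig, ShardFormer

# A small Llama model purely for demonstration; assumes colossalai.launch(...) has already
# set up a (single-process) distributed context, as the engine does before sharding.
model = LlamaForCausalLM(
    LlamaConfig(num_hidden_layers=2, hidden_size=256, intermediate_size=512, num_attention_heads=4, num_key_value_heads=4)
)

# Mirror the ShardConfig built in core/engine.py's _shardformer (whose quant extra_kwargs
# entry is removed by this commit).
shard_config = ShardConfig(
    enable_tensor_parallelism=False,
    enable_flash_attention=False,
    enable_jit_fused=False,
    enable_sequence_parallelism=False,
)
shardformer = ShardFormer(shard_config=shard_config)

# Swap in the padding-llama attention module and forward functions defined by the policy above.
shard_model, _ = shardformer.optimize(model, PaddingLlamaModelInferPolicy())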
