[inference] Optimize the usage of intermediate tensors through flash attn #5304

Merged
10 changes: 10 additions & 0 deletions colossalai/inference/config.py
@@ -55,6 +55,7 @@ class InferenceConfig:
def __post_init__(self):
self._init_batch_size()
self._verify_config()
self._get_dtype()

def _init_batch_size(self):
"""
@@ -84,6 +85,7 @@ def _verify_config(self) -> None:
assert (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"

assert self.dtype in [
"fp16",
"fp32",
@@ -97,3 +99,11 @@ def _verify_config(self) -> None:
"gptq",
None,
], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."

def _get_dtype(self) -> None:
if self.dtype == "fp32" or self.dtype == torch.float32:
self.dtype = torch.float32
elif self.dtype == "fp16" or self.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
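
The conversion above runs once in __post_init__, so every consumer of InferenceConfig sees a torch.dtype rather than a string. A minimal standalone sketch of the mapping (it mirrors _get_dtype rather than instantiating InferenceConfig, whose _verify_config also checks the distributed world size):

import torch

def normalize_dtype(dtype):
    # Mirrors InferenceConfig._get_dtype: strings or torch dtypes map to a
    # concrete torch.dtype, with bfloat16 as the fallback.
    if dtype == "fp32" or dtype == torch.float32:
        return torch.float32
    elif dtype == "fp16" or dtype == torch.float16:
        return torch.float16
    return torch.bfloat16

assert normalize_dtype("fp16") is torch.float16
assert normalize_dtype(torch.float32) is torch.float32
assert normalize_dtype("bf16") is torch.bfloat16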
13 changes: 3 additions & 10 deletions colossalai/inference/core/engine.py
@@ -51,17 +51,10 @@ def __init__(
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
self.dtype = inference_config.dtype

model = model.eval()

if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
self.dtype = torch.float16
model.half()
else:
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
model.to(self.dtype)

if model_policy is None:
model_policy = model_policy_map[self.model_config.model_type]()
@@ -217,6 +210,7 @@ def add_request(
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
)
self.request_handler.add_sequence(sequence)
@@ -241,7 +235,6 @@ def step(self) -> List[str]:
batch,
self.k_cahce,
self.v_cache,
padding_id=self.tokenizer.pad_token_id,
)

logits = logits[:, -1, :]
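With the dtype normalized in the config, the engine casts the whole model in a single call instead of branching on strings. A rough sketch with a toy module standing in for the real model (hypothetical, for illustration only):

import torch
import torch.nn as nn

dtype = torch.float16                 # e.g. inference_config.dtype after __post_init__
model = nn.Linear(8, 8).eval()        # stand-in for the HuggingFace model
model = model.to(dtype)               # replaces the old half()/to(torch.bfloat16) branching
assert next(model.parameters()).dtype is dtype
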
51 changes: 46 additions & 5 deletions colossalai/inference/core/request_handler.py
@@ -4,6 +4,7 @@
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
@@ -69,20 +70,60 @@ class RequestHandler:
Args:
inference_config: Configuration for initialize and manage kv cache.
model_config: Configuration for model
dtype (torch.dtype): The data type for weights and activations.
"""

def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
self.inference_config = inference_config
self._init_cache(model_config)

self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
device = torch.cuda.current_device()
self.running_batch = BatchInfo(is_prompts=False, device=device)
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
self.dtype = inference_config.dtype
self.max_batch_size = inference_config.max_batch_size

# initialize cache
self._init_cache(model_config)

# initialize batch
device = torch.cuda.current_device()
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = model_config.hidden_size // model_config.num_attention_heads

fd_inter_tensor = FDIntermTensors()
fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
num_attn_heads=model_config.num_attention_heads,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
dtype=self.dtype,
device=device,
)

# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs; this should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)

def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)

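A worked example of the sizing arithmetic above, with hypothetical config values: kv_max_split_num is a ceiling division of the maximum total sequence length by the block size, and head_dim comes from the model config; both determine the shapes of the FDIntermTensors buffers shared by the running and prefill batches.

max_input_len, max_output_len, block_size = 256, 128, 16    # hypothetical InferenceConfig values
kv_max_split_num = (max_input_len + max_output_len + block_size - 1) // block_size
assert kv_max_split_num == 24                                # ceil((256 + 128) / 16)

hidden_size, num_attention_heads = 4096, 32                  # LLaMA-7B-like model config
head_dim = hidden_size // num_attention_heads
assert head_dim == 128
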
7 changes: 1 addition & 6 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -58,12 +58,7 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
# Parallel settings
self.tp_size = config.tp_size
# Model settings
if config.dtype == "fp32" or config.dtype == torch.float32:
self.dtype = torch.float32
elif config.dtype == "fp16" or config.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
self.dtype = config.dtype
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
# For now we focus on MHA only, TODO add handling for MQA and GQA
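Because config.dtype is now guaranteed to be a torch.dtype, the cache manager only needs its element size for KV-cache memory accounting. For reference, the per-element byte widths:

import torch

assert torch.tensor([], dtype=torch.float16).element_size() == 2
assert torch.tensor([], dtype=torch.bfloat16).element_size() == 2
assert torch.tensor([], dtype=torch.float32).element_size() == 4
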
72 changes: 64 additions & 8 deletions colossalai/inference/modeling/models/llama.py
@@ -4,6 +4,7 @@
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
@@ -50,15 +51,13 @@ def llama_causal_lm_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
batch=batch,
k_caches=k_caches,
v_caches=v_caches,
padding_id=padding_id,
)
logits = self.lm_head(hidden_states)
return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
input_ids = batch.get_batch_inputs()
block_tables = batch.get_block_table_tensor()
attention_mask = batch.get_attn_mask(padding_id)
attention_mask = batch.get_attn_mask()

if attention_mask is not None:
if HAS_TRITON:
@@ -84,6 +82,7 @@
else:
sequence_lengths = batch.get_sequence_lengths()

batch_size, _ = input_ids.shape
kv_seq_len = sequence_lengths.max().item()

if attention_mask is not None:
@@ -102,7 +101,22 @@

hidden_states = self.embed_tokens(input_ids)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
# In testing, the performance of get_xine_cache was lower than that of get_cos_sin.
# cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
# sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
# cos_sin = (cos, sin)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)

if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)

for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
@@ -116,6 +130,9 @@
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = self.norm(hidden_states)
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: int = None,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: float = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states

@@ -151,6 +171,9 @@
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: float = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -206,15 +232,35 @@

if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_

@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
"""
Get cos and sin from the cache and return them in nopad format.
Args:
lengths: shape (num_seqs,), stores the length of each sequence.
cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
is_prompts: bool, marks whether we are in prefill mode.
dtype: The data type of this inference process.
"""

if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:
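A rough sketch, with hypothetical batch sizes and model dimensions, of the intermediate tensors this PR hoists out of the per-layer attention calls: the attention output buffer is allocated once per forward pass (per-token for prefill, per-sequence for decode) and the softmax scale is computed once and passed down to the kernels.

import torch

num_heads, head_dim = 32, 128                                # hypothetical model dimensions
sequence_lengths = torch.tensor([12, 7, 30, 5])              # hypothetical context lengths
batch_size = sequence_lengths.numel()
dtype, device = torch.float16, "cpu"                         # the engine would use "cuda"

# Prefill: one row per (unpadded) token across the whole batch.
prefill_output = torch.zeros(
    (sequence_lengths.sum().item(), num_heads, head_dim), dtype=dtype, device=device
)
# Decode: one new token per sequence.
decode_output = torch.zeros((batch_size, 1, num_heads, head_dim), dtype=dtype, device=device)

sm_scale = 1.0 / (head_dim**0.5)                             # softmax scaling forwarded to the kernels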