
Commit

[inference] Optimize the usage of the mid tensors space in flash attn (#5304)

* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adapted to get_xine_cache

* add comment

* fix ci bugs

* fix some codes

* rm duplicated codes

* rm duplicated codes

* fix code style

* add _get_dtype in config.py
yuehuayingxueluo committed Jan 26, 2024
1 parent af8359c commit 4f28cb4
Showing 16 changed files with 199 additions and 57 deletions.
10 changes: 10 additions & 0 deletions colossalai/inference/config.py
@@ -55,6 +55,7 @@ class InferenceConfig:
def __post_init__(self):
self._init_batch_size()
self._verify_config()
self._get_dtype()

def _init_batch_size(self):
"""
@@ -84,6 +85,7 @@ def _verify_config(self) -> None:
assert (
self.tp_size * self.pp_size == dist.get_world_size()
), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"

assert self.dtype in [
"fp16",
"fp32",
@@ -97,3 +99,11 @@ def _verify_config(self) -> None:
"gptq",
None,
], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."

def _get_dtype(self) -> None:
if self.dtype == "fp32" or self.dtype == torch.float32:
self.dtype = torch.float32
elif self.dtype == "fp16" or self.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
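For reference, a minimal standalone sketch of the normalization _get_dtype performs (the helper name normalize_dtype is ours, not the repo's):

import torch

def normalize_dtype(dtype):
    # "fp32"/"fp16" string aliases and torch dtypes both collapse to a
    # torch.dtype; anything else falls back to bfloat16.
    if dtype == "fp32" or dtype == torch.float32:
        return torch.float32
    elif dtype == "fp16" or dtype == torch.float16:
        return torch.float16
    else:
        return torch.bfloat16

assert normalize_dtype("fp16") is torch.float16
assert normalize_dtype(torch.float32) is torch.float32
assert normalize_dtype("bf16") is torch.bfloat16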
13 changes: 3 additions & 10 deletions colossalai/inference/core/engine.py
@@ -51,17 +51,10 @@ def __init__(
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
self.dtype = inference_config.dtype

model = model.eval()

if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
self.dtype = torch.float16
model.half()
else:
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
model.to(self.dtype)

if model_policy is None:
model_policy = model_policy_map[self.model_config.model_type]()
@@ -217,6 +210,7 @@ def add_request(
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
)
self.request_handler.add_sequence(sequence)
@@ -241,7 +235,6 @@ def step(self) -> List[str]:
batch,
self.k_cache,
self.v_cache,
padding_id=self.tokenizer.pad_token_id,
)

logits = logits[:, -1, :]
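Because __post_init__ now resolves inference_config.dtype to a concrete torch.dtype, the engine-side branching collapses into a single unconditional cast. A minimal sketch, with a toy module standing in for the model:

import torch
import torch.nn as nn

model = nn.Linear(16, 16).eval()
dtype = torch.float16  # stands in for the resolved inference_config.dtype
model = model.to(dtype)  # replaces the old fp32/fp16/bf16 if/elif/else
assert next(model.parameters()).dtype is torch.float16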
51 changes: 46 additions & 5 deletions colossalai/inference/core/request_handler.py
@@ -4,6 +4,7 @@
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
@@ -69,20 +70,60 @@ class RequestHandler:
Args:
inference_config: Configuration for initialize and manage kv cache.
model_config: Configuration for model
dtype (torch.dtype): The data type for weights and activations.
"""

def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
self.inference_config = inference_config
self._init_cache(model_config)

self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
device = torch.cuda.current_device()
self.running_batch = BatchInfo(is_prompts=False, device=device)
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
self.dtype = inference_config.dtype
self.max_batch_size = inference_config.max_batch_size

# initialize cache
self._init_cache(model_config)

# initialize batch
device = torch.cuda.current_device()
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = model_config.hidden_size // model_config.num_attention_heads

fd_inter_tensor = FDIntermTensors()
fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
num_attn_heads=model_config.num_attention_heads,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
dtype=self.dtype,
device=device,
)

# TODO In the continuous batching scenario, the batch size may exceed max_batch_size,
# which can cause bugs; this should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
dtype=self.dtype,
fd_inter_tensor=fd_inter_tensor,
)

def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)

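The sizing arithmetic above caps the number of KV splits at the ceiling of the longest possible sequence over the block size; a worked sketch with example values:

# Example values; the real ones come from InferenceConfig and the model config.
max_input_len, max_output_len, block_size = 512, 128, 16
# Ceil division via the usual add-(divisor - 1) trick:
kv_max_split_num = (max_input_len + max_output_len + block_size - 1) // block_size
assert kv_max_split_num == 40

hidden_size, num_attention_heads = 4096, 32
head_dim = hidden_size // num_attention_heads
assert head_dim == 128

FDIntermTensors is then sized once from max_batch_size, num_attn_heads, kv_max_split_num, and head_dim, and the same instance is shared by the prefill and decode batches, so the flash-attention mid tensors are allocated once per handler rather than once per step.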
7 changes: 1 addition & 6 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -58,12 +58,7 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
# Parallel settings
self.tp_size = config.tp_size
# Model settings
if config.dtype == "fp32" or config.dtype == torch.float32:
self.dtype = torch.float32
elif config.dtype == "fp16" or config.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
self.dtype = config.dtype
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
# For now we focus on MHA only, TODO add handling for MQA and GQA
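The element-size probe above derives the per-element byte width directly from the resolved dtype; for the three supported dtypes it evaluates as follows:

import torch

assert torch.tensor([], dtype=torch.float32).element_size() == 4
assert torch.tensor([], dtype=torch.float16).element_size() == 2
assert torch.tensor([], dtype=torch.bfloat16).element_size() == 2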
72 changes: 64 additions & 8 deletions colossalai/inference/modeling/models/llama.py
@@ -4,6 +4,7 @@
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
@@ -50,15 +51,13 @@ def llama_causal_lm_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
batch=batch,
k_caches=k_caches,
v_caches=v_caches,
padding_id=padding_id,
)
logits = self.lm_head(hidden_states)
return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
input_ids = batch.get_batch_inputs()
block_tables = batch.get_block_table_tensor()
attention_mask = batch.get_attn_mask(padding_id)
attention_mask = batch.get_attn_mask()

if attention_mask is not None:
if HAS_TRITON:
@@ -84,6 +82,7 @@
else:
sequence_lengths = batch.get_sequence_lengths()

batch_size, _ = input_ids.shape
kv_seq_len = sequence_lengths.max().item()

if attention_mask is not None:
@@ -102,7 +101,22 @@

hidden_states = self.embed_tokens(input_ids)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
# In testing, get_xine_cache performed worse than get_cos_sin, so the latter is used:
# cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
# sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
# cos_sin = (cos, sin)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)

if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)

for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
@@ -116,6 +130,9 @@
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = self.norm(hidden_states)
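The forward now preallocates one attention-output buffer and the softmax scale per step and threads them through every decoder layer, instead of letting each layer allocate its own. A sketch of the two shapes, assuming example dimensions:

import torch

batch_size, num_heads, head_dim = 4, 32, 128
sequence_lengths = torch.tensor([7, 12, 3, 9])

# Prefill: one row per unpadded token across the whole batch.
prefill_out = torch.zeros(int(sequence_lengths.sum()), num_heads, head_dim)
assert prefill_out.shape == (31, num_heads, head_dim)

# Decode: exactly one new token per sequence.
decode_out = torch.zeros(batch_size, 1, num_heads, head_dim)

# Softmax scale, computed once instead of inside every layer.
sm_scale = 1.0 / (head_dim**0.5)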
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: int = None,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: float = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states

@@ -151,6 +171,9 @@
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: float = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -206,15 +232,35 @@

if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
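This hunk is the heart of the change: flash_decoding_attention now writes its per-split partial results into the preallocated mid_output / mid_output_lse buffers instead of allocating them on every call. A sketch of the split-K reduction those mid tensors serve — the shapes here are our assumption, not read from the kernel:

import torch

bsz, num_heads, head_dim, kv_max_split_num = 4, 32, 128, 40
mid_output = torch.randn(bsz, num_heads, kv_max_split_num, head_dim)
mid_output_lse = torch.randn(bsz, num_heads, kv_max_split_num)

# Combine per-split partial outputs, weighting each split by its softmax
# mass recovered from the log-sum-exp values.
lse = torch.logsumexp(mid_output_lse, dim=-1, keepdim=True)  # (bsz, heads, 1)
weights = torch.exp(mid_output_lse - lse).unsqueeze(-1)      # (bsz, heads, splits, 1)
out = (weights * mid_output).sum(dim=2)                      # (bsz, heads, head_dim)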
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_

@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
"""
Get cos and sin from the cache and return them in nopad format.
Args:
lengths: shape (num_seqs,), stores the length of each sequence.
cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
is_prompts: bool, marks whether we are in prefill mode.
dtype: the data type used in this inference process.
"""

if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:
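A sketch of the index selection the docstring describes; the prefill branch is visible above, while the lengths - 1 indexing for decode is our assumption, since that branch is collapsed:

import torch

max_pos, head_dim = 2048, 128
cos_cache = torch.randn(max_pos, head_dim)
lengths = torch.tensor([3, 5])

# Prefill: gather every position of every sequence, concatenated (nopad).
prefill_idx = torch.cat([torch.arange(int(l)) for l in lengths])  # [0,1,2, 0,1,2,3,4]
cos_prefill = cos_cache[prefill_idx]  # (sum(lengths), head_dim)

# Decode: gather only the newest position of each sequence.
decode_idx = lengths - 1  # [2, 4]
cos_decode = cos_cache[decode_idx]  # (num_seqs, head_dim)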

