[inference] Optimize the usage of intermediate tensors through flash attn #5304

Merged
6 changes: 4 additions & 2 deletions colossalai/inference/core/engine.py
@@ -51,6 +51,8 @@ def __init__(
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
self.num_heads = self.model_config.num_attention_heads
self.head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads

model = model.eval()

@@ -79,7 +81,7 @@ def __init__(
if verbose:
self.logger = get_dist_logger(__name__)

self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.request_handler = RequestHandler(self.inference_config, self.model_config, self.dtype)
self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
self.counter = count()

@@ -217,6 +219,7 @@ def add_request(
None,
block_table,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
self.inference_config.max_output_len,
)
self.request_handler.add_sequence(sequence)
@@ -241,7 +244,6 @@ def step(self) -> List[str]:
batch,
self.k_cahce,
self.v_cache,
padding_id=self.tokenizer.pad_token_id,
)

logits = logits[:, -1, :]
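For reference, the num_heads/head_dim values the engine now caches follow the standard transformer convention of splitting the hidden size evenly across attention heads. A minimal sketch with illustrative config values (not taken from this PR):

from dataclasses import dataclass

@dataclass
class DummyModelConfig:
    # Illustrative values in the style of a LLaMA-7B config; not read from this PR.
    hidden_size: int = 4096
    num_attention_heads: int = 32

cfg = DummyModelConfig()
head_dim = cfg.hidden_size // cfg.num_attention_heads  # 4096 // 32 == 128
assert head_dim * cfg.num_attention_heads == cfg.hidden_size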
45 changes: 38 additions & 7 deletions colossalai/inference/core/request_handler.py
@@ -69,22 +69,53 @@ class RequestHandler:
Args:
inference_config: Configuration for initialize and manage kv cache.
model_config: Configuration for model
dtype (torch.dtype): The data type for weights and activations.
"""

def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
def __init__(
self, inference_config: InferenceConfig, model_config: PretrainedConfig, dtype: torch.dtype = None
) -> None:
self.inference_config = inference_config
self._init_cache(model_config)

self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
device = torch.cuda.current_device()
self.running_batch = BatchInfo(is_prompts=False, device=device)
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
self.dtype = dtype
self.max_batch_size = inference_config.max_batch_size

# initialize cache
self._init_cache(model_config)

# initialize batch
device = torch.cuda.current_device()
kv_max_split_num = (
inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
) // inference_config.block_size
head_dim = model_config.hidden_size // model_config.num_attention_heads
# TODO: In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs; this issue should be fixed later.
self.running_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=False,
device=device,
dtype=dtype,
)
self.prefill_batch = BatchInfo(
max_batch_size=self.max_batch_size,
kv_max_split_num=kv_max_split_num,
num_heads=model_config.num_attention_heads,
head_dim=head_dim,
is_prompts=True,
device=device,
dtype=dtype,
)
self.running_batch.init_fd_tensors()
self.prefill_batch.init_fd_tensors()

def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
self.cache_manager = KVCacheManager(self.inference_config, model_config, dtype=self.dtype)

def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
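For orientation, below is a rough sketch of the flash-decoding intermediate buffers that init_fd_tensors() presumably allocates, together with the kv_max_split_num arithmetic used above. The shapes follow what split-KV flash decoding typically needs; the actual FDIntermTensors layout in ColossalAI may differ.

import torch

def make_fd_interm_tensors(max_batch_size, num_heads, kv_max_split_num, head_dim,
                           dtype=torch.float16, device="cuda"):
    # Partial attention outputs, one per (sequence, head, KV split).
    mid_output = torch.empty(
        (max_batch_size, num_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
    )
    # Log-sum-exp of each split, used to merge the partial outputs.
    mid_output_lse = torch.empty(
        (max_batch_size, num_heads, kv_max_split_num), dtype=dtype, device=device
    )
    return mid_output, mid_output_lse

# kv_max_split_num is a ceiling division over KV-cache blocks, e.g. with
# max_input_len=512, max_output_len=512, block_size=16:
# (512 + 512 + 16 - 1) // 16 == 64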
11 changes: 4 additions & 7 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -51,19 +51,16 @@ class KVCacheManager:
And it's possible to have a batch of sequences with different lengths of block tables.
"""

def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
def __init__(
self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False, dtype: torch.dtype = None
) -> None:
self.logger = get_dist_logger(__name__)
self.device = get_current_device()

# Parallel settings
self.tp_size = config.tp_size
# Model settings
if config.dtype == "fp32" or config.dtype == torch.float32:
self.dtype = torch.float32
elif config.dtype == "fp16" or config.dtype == torch.float16:
self.dtype = torch.float16
else:
self.dtype = torch.bfloat16
self.dtype = dtype
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
# For now we focus on MHA only, TODO add handling for MQA and GQA
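The dtype-resolution logic removed above presumably now lives with the caller (e.g. the engine). A minimal equivalent of the old mapping, kept here only for reference:

import torch

def resolve_dtype(configured) -> torch.dtype:
    # Mirrors the branch removed from KVCacheManager: fp32 and fp16 map to
    # their torch dtypes, everything else falls back to bfloat16.
    if configured in ("fp32", torch.float32):
        return torch.float32
    if configured in ("fp16", torch.float16):
        return torch.float16
    return torch.bfloat16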
72 changes: 64 additions & 8 deletions colossalai/inference/modeling/models/llama.py
@@ -12,6 +12,7 @@
flash_decoding_attention,
rotary_embedding,
)
from colossalai.kernel.triton.flash_decoding_utils import FDIntermTensors
from colossalai.logging import get_dist_logger

from flash_attn.bert_padding import index_first_axis, pad_input # noqa
@@ -50,15 +51,13 @@ def llama_causal_lm_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
batch=batch,
k_caches=k_caches,
v_caches=v_caches,
padding_id=padding_id,
)
logits = self.lm_head(hidden_states)
return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
batch: BatchInfo = None,
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
padding_id: int = None,
):
input_ids = batch.get_batch_inputs()
block_tables = batch.get_block_table_tensor()
attention_mask = batch.get_attn_mask(padding_id)
attention_mask = batch.get_attn_mask()

if attention_mask is not None:
if HAS_TRITON:
@@ -84,6 +82,7 @@
else:
sequence_lengths = batch.get_sequence_lengths()

batch_size, _ = input_ids.shape
kv_seq_len = sequence_lengths.max().item()

if attention_mask is not None:
@@ -102,7 +101,22 @@

hidden_states = self.embed_tokens(input_ids)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
# When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
# cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
# sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
# cos_sin = (cos, sin)

cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)

if batch.is_prompts:
output_tensor = torch.zeros(
(sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)

for layer_id, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
@@ -116,6 +130,9 @@
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=batch.fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = self.norm(hidden_states)
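To illustrate the two preallocated output buffers above: prefill works on unpadded, flattened tokens (one row per real token), while decode produces exactly one new token per sequence. The numbers below are made up for the example, not taken from this PR.

import torch

batch_size, num_heads, head_dim = 4, 32, 128
sequence_lengths = torch.tensor([7, 3, 5, 2])  # per-sequence prompt lengths (illustrative)

# Prefill: rows = total number of real tokens across the batch.
prefill_output = torch.zeros((int(sequence_lengths.sum()), num_heads, head_dim))  # (17, 32, 128)

# Decode: exactly one new token per sequence.
decode_output = torch.zeros((batch_size, 1, num_heads, head_dim))  # (4, 1, 32, 128)

sm_scale = 1.0 / (head_dim ** 0.5)  # standard attention softmax scaling, 1/sqrt(head_dim)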
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: int = None,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states

@@ -151,6 +171,9 @@
attention_mask=attention_mask,
kv_seq_len=kv_seq_len,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
)

hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -206,15 +232,35 @@

if is_prompts:
attn_output = context_attention_unpadded(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
output=output_tensor,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
output=output_tensor,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_

@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
"""
Get cos and sin for the cache, and return nopad format.
Args:
lengths: shape(num_seqs,), stores lenghth of each sequence.
cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constrcuted in model.
sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constrcuted in model.
is_prompts: bool, mark if in prefill mode.
dtype: The data type of this inference process.
"""

if is_prompts:
index_arrays = [torch.arange(length) for length in lengths]
else:
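Since the body of get_cos_sin is truncated in this diff, here is a hedged sketch of the nopad gathering its docstring describes; the decode branch (taking only the last position per sequence) is an assumption, not shown above.

import torch

def get_cos_sin_sketch(lengths, cos_cache, sin_cache, is_prompts, dtype):
    if is_prompts:
        # Prefill: positions 0..len-1 of every sequence, concatenated into
        # one flat (total_tokens,) index tensor (nopad format).
        index = torch.cat(
            [torch.arange(l, device=cos_cache.device) for l in lengths.tolist()]
        )
    else:
        # Decode: only the latest position of each sequence (assumed).
        index = (lengths - 1).to(cos_cache.device)
    return cos_cache[index].to(dtype), sin_cache[index].to(dtype)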