[Inference] Fused the gate and up proj in MLP, and optimized the autograd process. (#5365)

* fused the gate and up proj in mlp

* fix code styles

* opt auto_grad

* rollback test_inference_engine.py

* modifications based on the review feedback.

* fix bugs in flash attn

* Change reshape to view

* fix test_rmsnorm_triton.py
isky-cd committed Feb 6, 2024
1 parent 1dedb57 commit 35382a7
Showing 10 changed files with 484 additions and 50 deletions.
29 changes: 15 additions & 14 deletions colossalai/inference/core/engine.py
@@ -115,8 +115,9 @@ def _shardformer(
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
- nn.Module: _description_
+ nn.Module: The model optimized by Shardformer.
"""

shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
@@ -149,25 +150,25 @@ def generate(
Returns:
List[str]: Inference result returned by one generation.
"""
+ with torch.inference_mode():
- self.generation_config = generation_config
- if prompts is not None or prompts_token_ids is not None:
-     self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

+     self.generation_config = generation_config
+     if prompts is not None or prompts_token_ids is not None:
+         self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

- output_seqs_list = []
- output_tokens_list = []
+     output_seqs_list = []
+     output_tokens_list = []

- while self.request_handler.check_unfinished_seqs():
-     output_seqs_list += self.step()
+     while self.request_handler.check_unfinished_seqs():
+         output_seqs_list += self.step()

- output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
+     output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

- for seq in output_seqs_list:
-     output_tokens_list.append(seq.input_token_id + seq.output_token_id)
+     for seq in output_seqs_list:
+         output_tokens_list.append(seq.input_token_id + seq.output_token_id)

- output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
+     output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)

- return output_str
+     return output_str
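The hunk above wraps the whole body of generate() in torch.inference_mode(), which lines up with dropping the per-function @torch.no_grad decorators in the files below: a single context at the entry point disables autograd tracking and version-counter bookkeeping for everything the loop calls. A minimal sketch of the pattern, using a hypothetical ToyEngine class purely for illustration:

```python
import torch


class ToyEngine:
    """Hypothetical stand-in for the inference engine, only to show the autograd pattern."""

    def __init__(self) -> None:
        self.weight = torch.randn(16, 16)

    def step(self, x: torch.Tensor) -> torch.Tensor:
        # No @torch.no_grad needed here: the caller's inference_mode() already covers this call.
        return x @ self.weight

    def generate(self, x: torch.Tensor) -> torch.Tensor:
        # One context at the entry point disables gradient tracking for the whole loop.
        with torch.inference_mode():
            for _ in range(4):
                x = self.step(x)
            return x


out = ToyEngine().generate(torch.randn(2, 16))
print(out.requires_grad, out.is_inference())  # False True
```

Tensors produced this way are inference tensors and can never be re-entered into autograd, which is an acceptable trade-off for a dedicated inference engine.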

def add_request(
self,
9 changes: 0 additions & 9 deletions colossalai/inference/modeling/layers/attention.py
@@ -6,7 +6,6 @@
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


- @torch.no_grad
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
"""
Func: copy key/value into key/value cache.
@@ -41,7 +40,6 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
return cache


- @torch.no_grad
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation
@@ -81,7 +79,6 @@ class PagedAttention:
"""

@staticmethod
- @torch.no_grad
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
"""
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
@@ -97,14 +94,12 @@ def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
return padded_tensor

@staticmethod
- @torch.no_grad
def generate_padding_mask(lengths, max_seq_len):
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
padding_mask = range_tensor < lengths.unsqueeze(1)
return padding_mask

@staticmethod
- @torch.no_grad
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
"""
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
@@ -122,7 +117,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)

@staticmethod
- @torch.no_grad
def nopad_context_forward(
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
@@ -191,7 +185,6 @@ def nopad_context_forward(
return attn_output

@staticmethod
- @torch.no_grad
def pad_context_forward(
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
@@ -249,7 +242,6 @@ def pad_context_forward(
return attn_output

@staticmethod
- @torch.no_grad
def pad_decoding_forward(
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
@@ -306,7 +298,6 @@ def pad_decoding_forward(
return attn_output

@staticmethod
- @torch.no_grad
def no_pad_decoding_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]
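For context on the repeat_kv helper shown earlier in this file's diff, whose docstring describes it as equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep), a typical expand-and-reshape implementation looks like the sketch below. This is a reconstruction from the visible docstring and return statement, not the exact body from the file:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
    """Repeat each KV head n_rep times:
    [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_kv_heads * n_rep, seq_len, head_dim]."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, broadcast it for free, then fold it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


kv = torch.randn(2, 4, 8, 64)        # 4 KV heads shared by the query heads
print(repeat_kv(kv, n_rep=8).shape)  # torch.Size([2, 32, 8, 64])
```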
32 changes: 14 additions & 18 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -32,7 +32,6 @@
logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")


- @torch.no_grad()
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchInfo = None,
@@ -58,7 +57,6 @@ def llama_causal_lm_forward(
return logits


- @torch.no_grad()
def llama_model_forward(
self: LlamaModel,
batch: BatchInfo = None,
@@ -120,7 +118,6 @@ def llama_model_forward(
return hidden_states


- @torch.no_grad()
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
@@ -139,7 +136,7 @@ def llama_decoder_layer_forward(
"""This function will replace the forward function of LlamaDecoderLayer.
Args:
- hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
+ hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -154,8 +151,8 @@ def llama_decoder_layer_forward(
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
- residual = hidden_states

+ residual = hidden_states
hidden_states = self.input_layernorm(hidden_states, norm_output)
# Self Attention
hidden_states = self.self_attn(
@@ -240,7 +237,6 @@ def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
return attn_layer

# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
- @torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
@@ -258,8 +254,8 @@
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
- hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
- residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in out_proj.
+ hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+ residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -321,7 +317,7 @@ def forward(
sm_scale=sm_scale,
)

- attn_output = attn_output.reshape(-1, self.hidden_size)
+ attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)

return attn_output
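Two details of this hunk are worth spelling out. Replacing reshape with view guarantees a copy-free flatten (view raises an error if the layout does not allow it, while reshape may silently fall back to a copy), and torch.addmm(residual, x, W) folds the residual addition into the output projection in a single call, assuming the weight is already stored as [in_features, out_features]. A small sketch of the equivalence, with illustrative shapes:

```python
import torch

tokens, num_heads, head_dim = 6, 8, 16
hidden_size = num_heads * head_dim

attn_output = torch.randn(tokens, num_heads, head_dim)
residual = torch.randn(tokens, hidden_size)
o_proj_weight = torch.randn(hidden_size, hidden_size)  # assumed pre-transposed to [in, out]

# view never copies; it would raise an error if attn_output were not viewable in this shape.
flat = attn_output.view(-1, hidden_size)

# addmm computes residual + flat @ o_proj_weight in one fused call.
fused = torch.addmm(residual, flat, o_proj_weight)
reference = residual + flat @ o_proj_weight
print(torch.allclose(fused, reference, atol=1e-5))  # True
```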
@@ -345,9 +341,10 @@ def __init__(
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__(config)
- self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
- self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
+ self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False)
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
+ self.gate_proj = None
+ self.up_proj = None

@staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@@ -371,15 +368,14 @@ def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:

return mlp_layer

- @torch.no_grad()
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
"""
Args:
- hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
- residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in down_proj.
+ hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+ residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
"""
- gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
- act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
- up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
- tmp_out = act_out * up_proj_out
+ hidden_states = hidden_states.expand(2, -1, -1)
+ gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
+ act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
+ tmp_out = act_out * gate_up_proj_out[1]
return torch.addmm(residual, tmp_out, self.down_proj.weight)
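This is the headline change of the commit: the transposed gate and up projection weights are stacked into one [2, hidden_size, intermediate_size] parameter, and the two torch.mm calls become a single batched torch.bmm, with hidden_states.expand(2, -1, -1) broadcasting the activations to both weight slices without a copy. A standalone sketch that checks the fused path against the unfused one, using illustrative dimensions rather than the model's:

```python
import torch
import torch.nn.functional as F

tokens, hidden_size, intermediate_size = 4, 32, 64

hidden_states = torch.randn(tokens, hidden_size)
# Pre-transposed projection weights, i.e. [hidden_size, intermediate_size].
gate_w = torch.randn(hidden_size, intermediate_size)
up_w = torch.randn(hidden_size, intermediate_size)

# Unfused reference: two separate matmuls plus SiLU gating.
reference = F.silu(hidden_states @ gate_w) * (hidden_states @ up_w)

# Fused path: stack the weights once, run one batched matmul over both projections.
gate_up_weight = torch.stack([gate_w, up_w], dim=0)     # [2, hidden_size, intermediate_size]
batched_input = hidden_states.expand(2, -1, -1)         # [2, tokens, hidden_size], stride-0 broadcast
gate_up_out = torch.bmm(batched_input, gate_up_weight)  # [2, tokens, intermediate_size]
fused = F.silu(gate_up_out[0]) * gate_up_out[1]

print(torch.allclose(reference, fused, atol=1e-5))  # True
```

The stacked weight lets both projections share one kernel launch, while the stride-0 expand keeps the extra memory cost at essentially the fused output tensor.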
