Commit

Merge branch 'master' of github.com:bytedance/lightseq
hexisyztem committed Dec 7, 2022
2 parents 8930383 + 1f6d693 commit 2b5592f
Showing 11 changed files with 632 additions and 34 deletions.
105 changes: 105 additions & 0 deletions lightseq/inference/kernels/moeKernels.cc.cu
@@ -533,5 +533,110 @@ template void ker_bias_redirect_residual_launcher<__half>(
int block_dim, cudaStream_t stream, const __half* input, const __half* bias,
const float* score, const int* expert_routed, __half* output);

/**
@brief: ker_hard_gate_reorder_pre
reorder the input: gather the sequences that share the same gate, as indexed by p_d_gate_indexs
@thread
blockIdx.x = batch_index
blockIdx.y = seq_id
@param
input: [seq_len * batch_size, feature_dim]
output: [seq_len * gate_size, feature_dim]
p_d_gate_indexs: [gate_size]
*/
template <typename T>
__global__ void ker_hard_gate_reorder_pre(const T* input, T* output,
int seq_len, int hidden_size,
int* p_d_gate_indexs) {
int batch_index = blockIdx.x, seq_id = blockIdx.y;
int src_index = p_d_gate_indexs[batch_index];
int pos = batch_index * seq_len + seq_id;
int src_pos = src_index * seq_len + seq_id;

for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
output[pos * hidden_size + i] = __ldg(&input[src_pos * hidden_size + i]);
}
}

template <typename T>
void ker_hard_gate_reorder_pre_launcher(const T* input, cudaStream_t stream,
int gate_size, int* p_d_gate_indexs,
T* output, int max_token_num,
int seq_len, int max_thread_per_block,
int hidden_size, int batch_token_num) {
ker_hard_gate_reorder_pre<T>
<<<dim3(gate_size, seq_len), max_thread_per_block, 0, stream>>>(
input, output, seq_len, hidden_size, p_d_gate_indexs);
}

template void ker_hard_gate_reorder_pre_launcher<float>(
const float* input, cudaStream_t stream, int gate_size,
int* p_d_gate_indexs, float* output, int max_token_num, int seq_len,
int max_thread_per_block, int hidden_size, int batch_token_num);
template void ker_hard_gate_reorder_pre_launcher<__half>(
const __half* input, cudaStream_t stream, int gate_size,
int* p_d_gate_indexs, __half* output, int max_token_num, int seq_len,
int max_thread_per_block, int hidden_size, int batch_token_num);
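
A minimal host-side usage sketch for this launcher (not part of the commit); the buffer names and sizes below are illustrative assumptions. Note that max_token_num and batch_token_num are accepted but not forwarded to the kernel in this version.

// Hypothetical sketch: gather the batch entries routed to one expert into a
// contiguous buffer so a single GEMM can process them.
// d_input       : [batch_size * seq_len, hidden_size] hidden states (device)
// d_gathered    : [gate_size * seq_len, hidden_size] destination (device)
// d_gate_indexs : [gate_size] batch indices routed to this expert (device)
const int gate_size = 3, seq_len = 8, hidden_size = 1024;
const int max_thread_per_block = 1024;
ker_hard_gate_reorder_pre_launcher<float>(
    d_input, stream, gate_size, d_gate_indexs, d_gathered,
    /*max_token_num=*/0, seq_len, max_thread_per_block, hidden_size,
    /*batch_token_num=*/0);
// Each block (blockIdx.x = gathered batch index b, blockIdx.y = position s)
// copies one hidden vector:
//   output[b * seq_len + s] = input[gate_indexs[b] * seq_len + s]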

/**
@brief: ker_hard_gate_reorder_post
1. reorder the output according to p_d_gate_reorder_indexs
2. add the bias to the output
@thread
blockIdx.x = batch_index
blockIdx.y = seq_id
@param
input: [seq_len * batch_size, feature_dim]
output: [seq_len * batch_size, feature_dim]
p_d_gate_reorder_indexs: [batch_size]
bias: [expert_num * feature_dim]
*/
template <typename T>
__global__ void ker_hard_gate_reorder_post(const T* input, T* output,
int* p_d_gate_reorder_indexs,
int hidden_size, int seq_len,
const T* bias,
int* _p_d_hard_gates) {
int batch_idx = blockIdx.x;
int seq_id = blockIdx.y;

int src_pos = batch_idx * seq_len + seq_id;
int tgt_batch_idx = p_d_gate_reorder_indexs[batch_idx];
int tgt_pos = tgt_batch_idx * seq_len + seq_id;

int gate_id = _p_d_hard_gates[tgt_batch_idx];

T bias_val;
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
bias_val = __ldg(&bias[gate_id * hidden_size + i]);
output[tgt_pos * hidden_size + i] +=
(__ldg(&input[src_pos * hidden_size + i]) + bias_val);
}
}

template <typename T>
void ker_hard_gate_reorder_post_launcher(
cudaStream_t stream, const T* input, T* output, int seq_len,
int max_thread_per_block, int hidden_size, int* p_d_gate_reorder_indexs,
int batch_size, const T* bias, int* _p_d_hard_gates) {
ker_hard_gate_reorder_post<T>
<<<dim3(batch_size, seq_len), max_thread_per_block, 0, stream>>>(
input, output, p_d_gate_reorder_indexs, hidden_size, seq_len, bias,
_p_d_hard_gates);
}

template void ker_hard_gate_reorder_post_launcher<float>(
cudaStream_t stream, const float* input, float* output, int seq_len,
int max_thread_per_block, int hidden_size, int* p_d_gate_reorder_indexs,
int batch_size, const float* bias, int* _p_d_hard_gates);
template void ker_hard_gate_reorder_post_launcher<__half>(
cudaStream_t stream, const __half* input, __half* output, int seq_len,
int max_thread_per_block, int hidden_size, int* p_d_gate_reorder_indexs,
int batch_size, const __half* bias, int* _p_d_hard_gates);
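
A matching host-side sketch for the post-reorder launcher; again, the names and values are illustrative assumptions rather than code from this commit.

// Hypothetical sketch: scatter the concatenated per-gate FFN outputs back to
// the original batch order and add each expert's second-FFN bias, accumulating
// into the residual stream.
// d_ffn_out        : [batch_size * seq_len, hidden_size] per-gate outputs
// d_residual       : [batch_size * seq_len, hidden_size] accumulated in place
// d_reorder_indexs : [batch_size] original batch index of each reordered row
// d_hard_gates     : [batch_size] expert id chosen for each original entry
// d_bias           : [expert_num * hidden_size] second-FFN bias table
const int batch_size = 4, seq_len = 8, hidden_size = 1024;
ker_hard_gate_reorder_post_launcher<float>(
    stream, d_ffn_out, d_residual, seq_len, /*max_thread_per_block=*/1024,
    hidden_size, d_reorder_indexs, batch_size, d_bias, d_hard_gates);
// For reordered row r at position s, the kernel does:
//   residual[reorder_indexs[r]][s] +=
//       ffn_out[r][s] + bias[hard_gates[reorder_indexs[r]]]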

} // namespace cuda
} // namespace lightseq
13 changes: 13 additions & 0 deletions lightseq/inference/kernels/moeKernels.h
@@ -46,5 +46,18 @@ void ker_bias_redirect_residual_launcher(int hidden_size, int max_token_num,
const float* score,
const int* expert_routed, T* output);

template <typename T>
void ker_hard_gate_reorder_pre_launcher(const T* input, cudaStream_t stream,
int gate_size, int* p_d_gate_indexs,
T* output, int max_token_num,
int seq_len, int max_thread_per_block,
int hidden_size, int batch_token_num);

template <typename T>
void ker_hard_gate_reorder_post_launcher(
cudaStream_t stream, const T* input, T* output, int seq_len,
int max_thread_per_block, int hidden_size, int* p_d_gate_reorder_indexs,
int batch_size, const T* bias, int* _p_d_hard_gates);

} // namespace cuda
} // namespace lightseq
165 changes: 163 additions & 2 deletions lightseq/inference/model/moe_decoder.cc.cu
@@ -668,8 +668,13 @@ void MoeDecoder<OpType_>::encdec_attention()
template <OperationType OpType_>
void MoeDecoder<OpType_>::ffn_add_norm() {
if (_tw._is_moe_layer_decoder[_layer_id]) {
moe_fw();
++_gate_weight_offset;
if (_tw._gate_type == 1) {
moe_fw_hard_gate();
} else {
// soft gate
moe_fw();
++_gate_weight_offset;
}
} else {
ffn();
}
@@ -719,6 +724,162 @@ void MoeDecoder<OpType_>::ffn() {
return;
}

template <OperationType OpType_>
void MoeDecoder<OpType_>::set_hard_gates_ptr(int* hard_gates,
std::set<int>* gate_sets,
int* p_d_hard_gates) {
_h_hard_gates = hard_gates;
_gate_sets = gate_sets;
_p_d_hard_gates = p_d_hard_gates;
}

template <OperationType OpType_>
void MoeDecoder<OpType_>::moe_fw_hard_gate() {
  // the same as ffn(), except that the selected expert's ffn weights are used

if (_batch_size == 1) {
/* ---step 0. layer_norm, add output_bias to "query"--- */
int expert_id = _h_hard_gates[0];

int ffn1_weight_offset = _tw._inner_size * _tw._hidden_size * expert_id;
int ffn1_bias_offset = _tw._inner_size * expert_id;

int ffn2_weight_offset = _tw._inner_size * _tw._hidden_size * expert_id;
int ffn2_bias_offset = _tw._hidden_size * expert_id;

ker_norm_layer_resual_launcher<_DataType>(
_step_token_num, _tw._hidden_size, _stream, _p_d_cur_step_query,
_p_d_query_buf1, _p_d_dec_wei[_weight_offset + 12],
_p_d_dec_wei[_weight_offset + 13],
_p_d_dec_wei[_weight_offset + 17] + ffn2_bias_offset,
_max_thread_per_block, _tw._is_post_ln);

/* ---step 1. first ffn layer--- */
CHECK_GPU_ERROR(cublasGemmEx(
_hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._inner_size, _step_token_num,
_tw._hidden_size, &_type_one,
_p_d_dec_wei[_weight_offset + 14] + ffn1_weight_offset, _AType,
_tw._inner_size, _p_d_query_buf1, _BType, _tw._hidden_size, &_type_zero,
_p_d_query_buf2, _CType, _tw._inner_size, _computeType,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (_tw._use_gelu) {
ker_bias_gelu_launcher<_DataType>(
_step_token_num, _max_thread_per_block, _stream, _p_d_query_buf2,
_p_d_dec_wei[_weight_offset + 15] + ffn1_bias_offset,
_tw._inner_size);
} else {
ker_bias_relu_launcher<_DataType>(
_step_token_num, _max_thread_per_block, _stream, _p_d_query_buf2,
_p_d_dec_wei[_weight_offset + 15] + ffn1_bias_offset,
_tw._inner_size);
}

/* ---step 2. second ffn layer--- */
CHECK_GPU_ERROR(cublasGemmEx(
_hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size, _step_token_num,
_tw._inner_size, &_type_one,
_p_d_dec_wei[_weight_offset + 16] + ffn2_weight_offset, _AType,
_tw._hidden_size, _p_d_query_buf2, _BType, _tw._inner_size, &_type_one,
_p_d_cur_step_query, _CType, _tw._hidden_size, _computeType,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
    /**
    If batch_size > 1: run the FFN for each gate separately, then reorder the
    outputs back to the original input order (a sketch of the hard-gate buffer
    layout consumed here follows this function).
    @param_shape:
    _p_d_cur_step_query: [beam_size * batch_size, hidden_dim]
    _p_d_query_buf1: [beam_size * batch_size, hidden_dim]
    _p_d_moe_input_buf: [beam_size * batch_size, hidden_dim]
    _p_d_moe_input_buf_tmp: [beam_size * cur_gate_size, hidden_dim]
    _p_d_moe_inner_buf: [beam_size * batch_size, inner_dim]
    */

/* ---step 0. layer_norm --- */
ker_norm_layer_prepost_launcher<_DataType>(
_step_token_num, _tw._hidden_size, _stream, _p_d_cur_step_query,
_p_d_query_buf1, _p_d_dec_wei[_weight_offset + 12],
_p_d_dec_wei[_weight_offset + 13], _max_thread_per_block,
_tw._is_post_ln);

    // used to reorder the output of each gate
int cursor_p = 0;
_DataType* _p_d_moe_input_buf_tmp;

int* _p_d_cur_gate_indexs = _p_d_hard_gates + 2 * _max_batch_size;
int sizes_index = _max_batch_size;

for (auto it = _gate_sets->begin(); it != _gate_sets->end(); it++) {
int cur_gate_size = _h_hard_gates[sizes_index];

      // _p_d_moe_input_buf_tmp: [beam_size * cur_gate_size, hidden_dim]
      // this pointer advances by the accumulated size of the gates processed
      // so far
_p_d_moe_input_buf_tmp =
_p_d_moe_input_buf + cursor_p * _tw._beam_size * _tw._hidden_size;

int expert_id = *it;

int ffn1_weight_offset = _tw._inner_size * _tw._hidden_size * expert_id;
int ffn1_bias_offset = _tw._inner_size * expert_id;

int ffn2_weight_offset = _tw._inner_size * _tw._hidden_size * expert_id;
int ffn2_bias_offset = _tw._hidden_size * expert_id;

/* ---step 1. reorder batch-inputs according to gate--- */
ker_hard_gate_reorder_pre_launcher(
_p_d_query_buf1, _stream, cur_gate_size, _p_d_cur_gate_indexs,
_p_d_moe_input_buf_tmp, _max_step_token_num, _tw._beam_size,
_max_thread_per_block, _tw._hidden_size, _step_token_num);

/* ---step 2. first ffn layer--- */
CHECK_GPU_ERROR(cublasGemmEx(
_hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._inner_size,
cur_gate_size * _tw._beam_size, _tw._hidden_size, &_type_one,
_p_d_dec_wei[_weight_offset + 14] + ffn1_weight_offset, _AType,
_tw._inner_size, _p_d_moe_input_buf_tmp, _BType, _tw._hidden_size,
&_type_zero, _p_d_moe_inner_buf, _CType, _tw._inner_size,
_computeType, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (_tw._use_gelu) {
ker_bias_gelu_launcher<_DataType>(
_tw._beam_size * cur_gate_size, _max_thread_per_block, _stream,
_p_d_moe_inner_buf,
_p_d_dec_wei[_weight_offset + 15] + ffn1_bias_offset,
_tw._inner_size);
} else {
ker_bias_relu_launcher<_DataType>(
_tw._beam_size * cur_gate_size, _max_thread_per_block, _stream,
_p_d_moe_inner_buf,
_p_d_dec_wei[_weight_offset + 15] + ffn1_bias_offset,
_tw._inner_size);
}

/* ---step 3. second ffn layer--- */
CHECK_GPU_ERROR(cublasGemmEx(
_hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size,
_tw._beam_size * cur_gate_size, _tw._inner_size, &_type_one,
_p_d_dec_wei[_weight_offset + 16] + ffn2_weight_offset, _AType,
_tw._hidden_size, _p_d_moe_inner_buf, _BType, _tw._inner_size,
&_type_zero, _p_d_moe_input_buf_tmp, _CType, _tw._hidden_size,
_computeType, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

cursor_p += cur_gate_size;
_p_d_cur_gate_indexs = _p_d_cur_gate_indexs + cur_gate_size;
sizes_index++;
}

/* ---step 4. reorder output of different gate--- */
// 1. add ffn2 bias
    // 2. reorder the ffn result (_p_d_moe_input_buf) back to the input order
ker_hard_gate_reorder_post_launcher(
_stream, _p_d_moe_input_buf, _p_d_cur_step_query, _tw._beam_size,
_max_thread_per_block, _tw._hidden_size,
_p_d_hard_gates + 2 * _max_batch_size, _batch_size,
_p_d_dec_wei[_weight_offset + 17], _p_d_hard_gates);
}
}
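
For readers following the pointer arithmetic above, here is a host-side sketch of the hard-gate buffer layout that moe_fw_hard_gate() appears to assume: gate ids in the first _max_batch_size slots, per-gate group sizes starting at _max_batch_size, and batch indices grouped by gate starting at 2 * _max_batch_size. The helper function and concrete values are illustrative assumptions, not code from this commit.

// Hypothetical host-side construction of the hard-gate buffer, mirrored to the
// device and handed to the decoder via set_hard_gates_ptr().
#include <set>
#include <vector>

std::vector<int> build_hard_gates(const std::vector<int>& gate_per_batch,
                                  int max_batch, std::set<int>* gate_sets) {
  std::vector<int> buf(3 * max_batch, 0);
  // [0, max_batch): gate id chosen for each batch entry
  for (size_t b = 0; b < gate_per_batch.size(); ++b) {
    buf[b] = gate_per_batch[b];
    gate_sets->insert(gate_per_batch[b]);
  }
  // [max_batch, 2 * max_batch): group size per gate, in gate-set order
  // [2 * max_batch, ...): batch indices grouped by gate (reorder index list)
  int size_slot = max_batch, index_slot = 2 * max_batch;
  for (int gate : *gate_sets) {
    int cnt = 0;
    for (size_t b = 0; b < gate_per_batch.size(); ++b) {
      if (gate_per_batch[b] == gate) {
        buf[index_slot++] = static_cast<int>(b);
        ++cnt;
      }
    }
    buf[size_slot++] = cnt;
  }
  return buf;
}

// e.g. gate_per_batch = {2, 0, 2, 0} with max_batch = 8 yields
//   gate ids : 2 0 2 0 0 0 0 0
//   sizes    : 2 2 0 ...      (two entries for gate 0, two for gate 2)
//   indices  : 1 3 0 2 ...    (rows for gate 0, then rows for gate 2)
// moe_fw_hard_gate() gathers rows {1, 3} for gate 0 and {0, 2} for gate 2,
// runs each expert FFN on its group, and ker_hard_gate_reorder_post scatters
// the concatenated results back to rows 1, 3, 0, 2, restoring the original
// batch order.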

template <OperationType OpType_>
void MoeDecoder<OpType_>::moe_fw() {
ker_norm_layer_prepost_launcher<_DataType>(
8 changes: 8 additions & 0 deletions lightseq/inference/model/moe_decoder.h
@@ -43,6 +43,7 @@ class MoeDecoder {
void encdec_attention();
void ffn_add_norm();
void ffn();
void moe_fw_hard_gate();
void moe_fw();
bool sample();
bool beam_search();
@@ -75,6 +76,11 @@ class MoeDecoder {
int* _p_d_alive_seq;
int* _p_d_alive_seq_buf;
int* _p_d_expert_id_routed;

int* _p_d_hard_gates;
int* _h_hard_gates;
std::set<int>* _gate_sets;

_DataType* _p_d_cur_step_query;
// cur step's projected query-key-value in self atten, one pointer for one
// decoder layer device memory in [batch_size, beam_size, 3, hidden_size]
@@ -141,6 +147,8 @@ class MoeDecoder {
MoeWeight<OpType_>& tw, cudaStream_t stream, cublasHandle_t hd,
bool output_topk = false, const int* p_d_lang_id = nullptr);
long compute_buffer_bytesize();
void set_hard_gates_ptr(int* hard_gates, std::set<int>* gate_sets,
int* p_d_hard_gates);
void init_buffer(void* pbuf);
std::string check();
void run_one_infer(int batch_size, int batch_seq_len);