diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp
index c08982564..3fb84460c 100644
--- a/xllm/core/common/global_flags.cpp
+++ b/xllm/core/common/global_flags.cpp
@@ -343,7 +343,7 @@ DEFINE_string(store_metadata_connstring,
               "",
               "The address of the kv cache store metadata service.");
 
-// --- for computation communication parallel ---
+// --- computation communication parallel config ---
 
 DEFINE_bool(
     enable_multi_stream_parallel,
@@ -355,7 +355,7 @@ DEFINE_int32(default_micro_batch_num,
              2,
              "Default use two micro batches for multi-stream parallel.");
 
-// --- for dit ---
+// --- dit config ---
 
 DEFINE_int32(max_requests_per_batch, 1, "Max number of request per batch.");
 
 // --- continuous kv cache config ---
@@ -377,4 +377,4 @@ DEFINE_int64(cache_size_per_token,
 
 DEFINE_int64(buffer_size_per_seq,
              0,
-             "Buffer size per sequence in bytes, default 0.");
+             "Buffer size per sequence in bytes, default 0.");
\ No newline at end of file
diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h
index a05182b38..04b58110b 100644
--- a/xllm/core/common/global_flags.h
+++ b/xllm/core/common/global_flags.h
@@ -189,7 +189,6 @@ DECLARE_int32(max_global_ttft_ms);
 
 DECLARE_int32(max_global_tpot_ms);
 
-// dit
 DECLARE_int32(max_requests_per_batch);
 
 DECLARE_bool(enable_continuous_kvcache);
@@ -198,4 +197,4 @@ DECLARE_int64(granularity_size);
 
 DECLARE_int64(cache_size_per_token);
 
-DECLARE_int64(buffer_size_per_seq);
+DECLARE_int64(buffer_size_per_seq);
\ No newline at end of file
diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h
index 264f411e7..8626324e5 100644
--- a/xllm/core/framework/batch/batch_input_builder.h
+++ b/xllm/core/framework/batch/batch_input_builder.h
@@ -128,7 +128,6 @@ class BatchInputBuilder {
       uint32_t q_seq_len,
       BuilderState* state_ptr = nullptr,
       std::unordered_set<int32_t>* write_block_ids_ptr = nullptr);
-
   void setup_continuous_kv_cache_info(Sequence* sequence,
                                       uint32_t n_kv_cache_tokens,
                                       uint32_t seq_len,
diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h
index 63cc79493..aaaae36d6 100644
--- a/xllm/core/framework/model/model_input_params.h
+++ b/xllm/core/framework/model/model_input_params.h
@@ -93,6 +93,7 @@ struct ModelInputParams {
     // Copy graph_buffer to device
     params.graph_buffer = safe_to(graph_buffer, device, true);
+
     return params;
   }
diff --git a/xllm/core/kernels/CMakeLists.txt b/xllm/core/kernels/CMakeLists.txt
index 9639b732e..4aa1941b6 100644
--- a/xllm/core/kernels/CMakeLists.txt
+++ b/xllm/core/kernels/CMakeLists.txt
@@ -1,12 +1,24 @@
 include(cc_library)
 
 if(USE_NPU)
-  include_directories(
-    ${CMAKE_SOURCE_DIR}/third_party/spdlog/include
-  )
   add_subdirectory(npu)
 endif()
 
 if(USE_MLU)
   add_subdirectory(mlu)
 endif()
+
+
+cc_library(
+  NAME
+    kernels
+  HDRS
+    param.h
+    ops_api.h
+  SRCS
+    ops_api.cpp
+  DEPS
+    torch
+    $<$<BOOL:${USE_NPU}>:npu_kernels>
+    $<$<BOOL:${USE_MLU}>:mlu_kernels>
+)
\ No newline at end of file
diff --git a/xllm/core/kernels/mlu/CMakeLists.txt b/xllm/core/kernels/mlu/CMakeLists.txt
index 7e4fdf525..0517eb258 100644
--- a/xllm/core/kernels/mlu/CMakeLists.txt
+++ b/xllm/core/kernels/mlu/CMakeLists.txt
@@ -2,7 +2,6 @@ include(cc_library)
 
 file(GLOB_RECURSE MLU_HEADER_FILES
   "${CMAKE_CURRENT_LIST_DIR}/*.h"
-  "${CMAKE_CURRENT_LIST_DIR}/*.hpp"
 )
 
 file(GLOB_RECURSE MLU_SOURCE_FILES
@@ -11,7 +10,7 @@
 
 cc_library(
   NAME
-    xllm_mlu_ops
+    mlu_kernels
   HDRS
     ${MLU_HEADER_FILES}
   SRCS
diff --git a/xllm/core/kernels/mlu/active.cpp b/xllm/core/kernels/mlu/active.cpp
new file mode 100644
index 000000000..66e864804
--- /dev/null
+++ b/xllm/core/kernels/mlu/active.cpp
@@ -0,0 +1,38 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+#include "torch_mlu_ops.h"
+
+namespace xllm::kernel::mlu {
+
+void active(const torch::Tensor& input,
+            torch::Tensor& output,
+            const std::optional<torch::Tensor>& bias,
+            const std::optional<torch::Tensor>& cusum_token_count,
+            const std::string& act_mode,
+            bool is_gated,
+            int start_expert_id,
+            int expert_size) {
+  tmo::torch_api::active(input,
+                         output,
+                         bias,
+                         cusum_token_count,
+                         act_mode,
+                         is_gated,
+                         start_expert_id,
+                         expert_size);
+}
+}  // namespace xllm::kernel::mlu
\ No newline at end of file
diff --git a/xllm/core/kernels/mlu/attention.cpp b/xllm/core/kernels/mlu/attention.cpp
new file mode 100644
index 000000000..5cee13171
--- /dev/null
+++ b/xllm/core/kernels/mlu/attention.cpp
@@ -0,0 +1,119 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+#include "torch_mlu_ops.h"
+
+namespace xllm::kernel::mlu {
+
+void reshape_paged_cache(torch::Tensor& key,
+                         torch::Tensor& value,
+                         torch::Tensor& k_cache,
+                         torch::Tensor& v_cache,
+                         const torch::Tensor& slot_mapping,
+                         bool direction) {
+  tmo::torch_api::reshape_paged_cache(
+      key, value, k_cache, v_cache, slot_mapping, direction);
+}
+
+void batch_prefill(const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   torch::Tensor& output,
+                   std::optional<torch::Tensor>& output_lse,
+                   const std::optional<torch::Tensor>& query_start_loc,
+                   const std::optional<torch::Tensor>& seq_start_loc,
+                   const std::optional<torch::Tensor>& alibi_slope,
+                   const std::optional<torch::Tensor>& attn_bias,
+                   const std::optional<torch::Tensor>& q_quant_scale,
+                   const std::optional<torch::Tensor>& k_quant_scale,
+                   const std::optional<torch::Tensor>& v_quant_scale,
+                   const std::optional<torch::Tensor>& out_quant_scale,
+                   const std::optional<torch::Tensor>& block_table,
+                   int max_query_len,
+                   int max_seq_len,
+                   float scale,
+                   bool is_causal,
+                   int window_size_left,
+                   int window_size_right,
+                   const std::string& compute_dtype,
+                   bool return_lse) {
+  tmo::torch_api::flash_attention(query,
+                                  key,
+                                  value,
+                                  output,
+                                  output_lse,
+                                  query_start_loc,
+                                  seq_start_loc,
+                                  alibi_slope,
+                                  attn_bias,
+                                  q_quant_scale,
+                                  k_quant_scale,
+                                  v_quant_scale,
+                                  out_quant_scale,
+                                  block_table,
+                                  max_query_len,
+                                  max_seq_len,
+                                  scale,
+                                  is_causal,
+                                  window_size_left,
+                                  window_size_right,
+                                  compute_dtype,
+                                  return_lse);
+}
+
+void batch_decode(const torch::Tensor& query,
+                  const torch::Tensor& k_cache,
+                  torch::Tensor& output,
+                  const torch::Tensor& block_table,
+                  const torch::Tensor& seq_lens,
+                  const torch::Tensor& v_cache,
+                  std::optional<torch::Tensor>& output_lse,
+                  const std::optional<torch::Tensor>& q_quant_scale,
+                  const std::optional<torch::Tensor>& k_cache_quant_scale,
+                  const std::optional<torch::Tensor>& v_cache_quant_scale,
+                  const std::optional<torch::Tensor>& out_quant_scale,
+                  const std::optional<torch::Tensor>& alibi_slope,
+                  const std::optional<torch::Tensor>& mask,
+                  const std::string& compute_dtype,
+                  int max_seq_len,
+                  int window_size_left,
+                  int window_size_right,
+                  float scale,
+                  bool return_lse,
+                  int kv_cache_quant_bit_size) {
+  tmo::torch_api::single_query_cached_kv_attn(query,
+                                              k_cache,
+                                              output,
+                                              block_table,
+                                              seq_lens,
+                                              v_cache,
+                                              output_lse,
+                                              q_quant_scale,
+                                              k_cache_quant_scale,
+                                              v_cache_quant_scale,
+                                              out_quant_scale,
+                                              alibi_slope,
+                                              mask,
+                                              compute_dtype,
+                                              max_seq_len,
+                                              window_size_left,
+                                              window_size_right,
+                                              scale,
+                                              return_lse,
+                                              kv_cache_quant_bit_size);
+}
+
+}  // namespace xllm::kernel::mlu
\ No newline at end of file
diff --git a/xllm/core/kernels/mlu/fused_layernorm.cpp b/xllm/core/kernels/mlu/fused_layernorm.cpp
new file mode 100644
index 000000000..7ca2b8dff
--- /dev/null
+++ b/xllm/core/kernels/mlu/fused_layernorm.cpp
@@ -0,0 +1,53 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+#include "torch_mlu_ops.h"
+
+namespace xllm::kernel::mlu {
+
+void fused_layernorm(const torch::Tensor& input,
+                     torch::Tensor& output,
+                     const std::optional<torch::Tensor>& residual,
+                     const torch::Tensor& weight,
+                     const std::optional<torch::Tensor>& beta,
+                     const std::optional<torch::Tensor>& bias,
+                     const std::optional<torch::Tensor>& quant_scale,
+                     const std::optional<torch::Tensor>& residual_out,
+                     const std::optional<torch::Tensor>& smooth_quant_scale,
+                     const std::optional<torch::Tensor>& normed_out,
+                     const std::string& mode,
+                     double eps,
+                     bool store_output_before_norm,
+                     bool store_output_after_norm,
+                     bool dynamic_quant) {
+  tmo::torch_api::fused_layernorm(input,
+                                  output,
+                                  residual,
+                                  weight,
+                                  beta,
+                                  bias,
+                                  quant_scale,
+                                  residual_out,
+                                  smooth_quant_scale,
+                                  normed_out,
+                                  mode,
+                                  eps,
+                                  store_output_before_norm,
+                                  store_output_after_norm,
+                                  dynamic_quant);
+}
+
+}  // namespace xllm::kernel::mlu
\ No newline at end of file
diff --git a/xllm/core/kernels/mlu/fused_moe.cpp b/xllm/core/kernels/mlu/fused_moe.cpp
index f71e5cb53..7532980af 100644
--- a/xllm/core/kernels/mlu/fused_moe.cpp
+++ b/xllm/core/kernels/mlu/fused_moe.cpp
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "mlu_ops_api.h"
 #include "torch_mlu_ops.h"
-#include "torch_ops_api.h"
 
 namespace {
 torch::Tensor create_group_gemm_output(const torch::Tensor& a,
@@ -27,7 +27,7 @@ torch::Tensor create_group_gemm_output(const torch::Tensor& a,
 }
 }  // namespace
 
-namespace xllm::mlu {
+namespace xllm::kernel::mlu {
 torch::Tensor fused_moe(
     const torch::Tensor& hidden_states,
     const torch::Tensor& gating_output,
@@ -175,4 +175,4 @@ torch::Tensor fused_moe(
   return output.reshape(ori_input_shape);
 }
 
-}  // namespace xllm::mlu
+}  // namespace xllm::kernel::mlu
diff --git a/xllm/core/kernels/mlu/matmul.cpp b/xllm/core/kernels/mlu/matmul.cpp
index 1c20f885e..559192400 100644
--- a/xllm/core/kernels/mlu/matmul.cpp
+++ b/xllm/core/kernels/mlu/matmul.cpp
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "mlu_ops_api.h"
 #include "torch_mlu_ops.h"
-#include "torch_ops_api.h"
 
-namespace xllm::mlu {
+namespace xllm::kernel::mlu {
 
-at::Tensor matmul(const at::Tensor& a,
-                  const at::Tensor& b,
-                  const std::optional<at::Tensor>& bias,
-                  const std::optional<at::Tensor>& c,
-                  double alpha,
-                  double beta) {
+torch::Tensor matmul(const torch::Tensor& a,
+                     const torch::Tensor& b,
+                     const std::optional<torch::Tensor>& bias,
+                     const std::optional<torch::Tensor>& c,
+                     double alpha,
+                     double beta) {
   return tmo::torch_api::matmul(a,
                                 b,
                                 bias,
@@ -43,4 +43,4 @@ at::Tensor matmul(const at::Tensor& a,
                                 true);
 }
 
-}  // namespace xllm::mlu
+}  // namespace xllm::kernel::mlu
diff --git a/xllm/core/kernels/mlu/mlu_ops_api.h b/xllm/core/kernels/mlu/mlu_ops_api.h
new file mode 100644
index 000000000..1db91b210
--- /dev/null
+++ b/xllm/core/kernels/mlu/mlu_ops_api.h
@@ -0,0 +1,157 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/torch.h>
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "ATen/Tensor.h"
+#include "torch_mlu_ops.h"
+
+namespace xllm::kernel::mlu {
+
+static const std::string kActModeSilu = "silu";
+static const std::string kActModeGelu = "gelu";
+static const std::string kActModeQuickGelu = "quick_gelu";
+static const std::string kActModeSwish = "swish";
+
+void apply_rotary(torch::Tensor& q,
+                  torch::Tensor& k,
+                  const torch::Tensor& sin,
+                  const torch::Tensor& cos,
+                  const std::optional<torch::Tensor>& position_ids,
+                  const torch::Tensor& cu_query_lens,
+                  bool interleaved,
+                  bool discrete,
+                  bool dynamic_ntk,
+                  int max_query_len);
+
+void active(const torch::Tensor& input,
+            torch::Tensor& output,
+            const std::optional<torch::Tensor>& bias,
+            const std::optional<torch::Tensor>& cusum_token_count,
+            const std::string& act_mode,
+            bool is_gated,
+            int start_expert_id,
+            int expert_size);
+
+void reshape_paged_cache(torch::Tensor& key,
+                         torch::Tensor& value,
+                         torch::Tensor& k_cache,
+                         torch::Tensor& v_cache,
+                         const torch::Tensor& slot_mapping,
+                         bool direction);
+
+void batch_prefill(const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   torch::Tensor& output,
+                   std::optional<torch::Tensor>& output_lse,
+                   const std::optional<torch::Tensor>& query_start_loc,
+                   const std::optional<torch::Tensor>& seq_start_loc,
+                   const std::optional<torch::Tensor>& alibi_slope,
+                   const std::optional<torch::Tensor>& attn_bias,
+                   const std::optional<torch::Tensor>& q_quant_scale,
+                   const std::optional<torch::Tensor>& k_quant_scale,
+                   const std::optional<torch::Tensor>& v_quant_scale,
+                   const std::optional<torch::Tensor>& out_quant_scale,
+                   const std::optional<torch::Tensor>& block_tables,
+                   int max_query_len,
+                   int max_seq_len,
+                   float scale,
+                   bool is_causal,
+                   int window_size_left,
+                   int window_size_right,
+                   const std::string& compute_dtype,
+                   bool return_lse);
+
+void batch_decode(const torch::Tensor& query,
+                  const torch::Tensor& k_cache,
+                  torch::Tensor& output,
+                  const torch::Tensor& block_table,
+                  const torch::Tensor& seq_lens,
+                  const torch::Tensor& v_cache,
+                  std::optional<torch::Tensor>& output_lse,
+                  const std::optional<torch::Tensor>& q_quant_scale,
+                  const std::optional<torch::Tensor>& k_cache_quant_scale,
+                  const std::optional<torch::Tensor>& v_cache_quant_scale,
+                  const std::optional<torch::Tensor>& out_quant_scale,
+                  const std::optional<torch::Tensor>& alibi_slope,
+                  const std::optional<torch::Tensor>& mask,
+                  const std::string& compute_dtype,
+                  int max_seq_len,
+                  int window_size_left,
+                  int window_size_right,
+                  float scale,
+                  bool return_lse,
+                  int kv_cache_quant_bit_size);
+
+void fused_layernorm(const torch::Tensor& input,
+                     torch::Tensor& output,
+                     const std::optional<torch::Tensor>& residual,
+                     const torch::Tensor& weight,
+                     const std::optional<torch::Tensor>& beta,
+                     const std::optional<torch::Tensor>& bias,
+                     const std::optional<torch::Tensor>& quant_scale,
+                     const std::optional<torch::Tensor>& residual_out,
+                     const std::optional<torch::Tensor>& smooth_quant_scale,
+                     const std::optional<torch::Tensor>& normed_out,
+                     const std::string& mode,
+                     double eps,
+                     bool store_output_before_norm,
+                     bool store_output_after_norm,
+                     bool dynamic_quant);
+
+torch::Tensor matmul(const torch::Tensor& a,
+                     const torch::Tensor& b,
+                     const std::optional<torch::Tensor>& bias,
+                     const std::optional<torch::Tensor>& c,
+                     double alpha,
+                     double beta);
+
+torch::Tensor fused_moe(
+    const torch::Tensor& hidden_states,
+    const torch::Tensor& gating_output,
+    const torch::Tensor& w1,
+    const torch::Tensor& w2,
+    const std::optional<torch::Tensor>& bias1,
+    const std::optional<torch::Tensor>& bias2,
+    const std::optional<torch::Tensor>& residual,
+    const std::optional<torch::Tensor>& input_smooth,
+    const std::optional<torch::Tensor>& act_smooth,
+    const std::optional<torch::Tensor>& w1_scale,
+    const std::optional<torch::Tensor>& w2_scale,
+    const std::optional<torch::Tensor>& e_score_correction_bias,
+    int topk,
+    bool renormalize,
+    bool gated,
+    const std::string& act_mode,
+    const std::string& scoring_func,
+    int start_expert_id,
+    int block_n,
+    bool avg_moe,
+    const std::optional<torch::Tensor>& class_reduce_weight,
+    const std::optional<torch::Tensor>& class_expert_id,
+    const std::optional<std::vector<int>>& w1_quant_flag,
+    const std::optional<std::vector<int>>& w2_quant_flag,
+    int world_size,
+    int shared_expert_num,
+    const std::string& parallel_mode);
+
+}  // namespace xllm::kernel::mlu
diff --git a/xllm/core/kernels/mlu/rope.cpp b/xllm/core/kernels/mlu/rope.cpp
new file mode 100644
index 000000000..850c73ff3
--- /dev/null
+++ b/xllm/core/kernels/mlu/rope.cpp
@@ -0,0 +1,53 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+#include "torch_mlu_ops.h"
+
+namespace xllm::kernel::mlu {
+
+void apply_rotary(torch::Tensor& q,
+                  torch::Tensor& k,
+                  const torch::Tensor& sin,
+                  const torch::Tensor& cos,
+                  const std::optional<torch::Tensor>& position_ids,
+                  const torch::Tensor& cu_query_lens,
+                  bool interleaved,
+                  bool discrete,
+                  bool dynamic_ntk,
+                  int max_query_len) {
+  const int64_t rotary_dim = sin.size(-1);
+  const int64_t T = q.size(0);
+  q = q.view({T, -1});
+  k = k.view({T, -1});
+  auto qk = torch::cat({q, k}, /*dim=*/-1);
+  qk = qk.view({T, -1, rotary_dim});
+  tmo::torch_api::apply_rotary(qk,
+                               qk /* output */,
+                               sin,
+                               cos,
+                               position_ids,
+                               cu_query_lens,
+                               interleaved,
+                               discrete,
+                               false /* dynamic_ntk */,
+                               max_query_len);
+  qk = qk.view({-1, q.size(-1) + k.size(-1)});
+  auto qk_vec = qk.split({q.size(-1), k.size(-1)}, /*dim=*/-1);
+  q = qk_vec[0];
+  k = qk_vec[1];
+}
+
+}  // namespace xllm::kernel::mlu
\ No newline at end of file
diff --git a/xllm/core/kernels/mlu/torch_ops_api.h b/xllm/core/kernels/mlu/torch_ops_api.h
deleted file mode 100644
index 458167e66..000000000
--- a/xllm/core/kernels/mlu/torch_ops_api.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#pragma once
-
-#include <torch/torch.h>
-
-#include <optional>
-
-#include "ATen/Tensor.h"
-#include "torch_mlu_ops.h"
-namespace xllm::mlu {
-
-static const std::string kActModeSilu = "silu";
-static const std::string kActModeGelu = "gelu";
-static const std::string kActModeQuickGelu = "quick_gelu";
-static const std::string kActModeSwish = "swish";
-
-at::Tensor matmul(const at::Tensor& a,
-                  const at::Tensor& b,
-                  const std::optional<at::Tensor>& bias,
-                  const std::optional<at::Tensor>& c,
-                  double alpha,
-                  double beta);
-
-torch::Tensor fused_moe(
-    const torch::Tensor& hidden_states,
-    const torch::Tensor& gating_output,
-    const torch::Tensor& w1,
-    const torch::Tensor& w2,
-    const std::optional<torch::Tensor>& bias1,
-    const std::optional<torch::Tensor>& bias2,
-    const std::optional<torch::Tensor>& residual,
-    const std::optional<torch::Tensor>& input_smooth,
-    const std::optional<torch::Tensor>& act_smooth,
-    const std::optional<torch::Tensor>& w1_scale,
-    const std::optional<torch::Tensor>& w2_scale,
-    const std::optional<torch::Tensor>& e_score_correction_bias,
-    int topk,
-    bool renormalize,
-    bool gated,
-    const std::string& act_mode,
-    const std::string& scoring_func = "softmax",
-    int start_expert_id = 0,
-    int block_n = 0,
-    bool avg_moe = false,
-    const std::optional<torch::Tensor>& class_reduce_weight = std::nullopt,
-    const std::optional<torch::Tensor>& class_expert_id = std::nullopt,
-    const std::optional<std::vector<int>>& w1_quant_flag = std::nullopt,
-    const std::optional<std::vector<int>>& w2_quant_flag = std::nullopt,
-    int world_size = 0,
-    int shared_expert_num = 0,
-    const std::string& parallel_mode = "ep");
-
-}  // namespace xllm::mlu
diff --git a/xllm/core/kernels/npu/CMakeLists.txt b/xllm/core/kernels/npu/CMakeLists.txt
index c8558d6bd..5553d8a09 100644
--- a/xllm/core/kernels/npu/CMakeLists.txt
+++ b/xllm/core/kernels/npu/CMakeLists.txt
@@ -1,96 +1,17 @@
 include(cc_library)
-include(cc_test)
 
+add_subdirectory(impl)
 add_subdirectory(xllm_ops)
 
 cc_library(
   NAME
     npu_kernels
   HDRS
-    npu_split_impl.h
-    npu_linear_impl.h
-    npu_rms_norm_impl.h
-    npu_rope_impl.h
-  SRCS
-    npu_split_impl.cpp
-    npu_linear_impl.cpp
-    npu_rms_norm_impl.cpp
-    npu_rope_impl.cpp
+    linear.h
+    split.h
+    rms_norm.h
+    rope.h
   DEPS
-    :npu_layers
-    :model_context
-    :state_dict
-    glog::glog
-    torch
-    torch_npu
-)
-
-cc_test(
-  NAME
-    npu_rms_norm_test
-  SRCS
-    npu_rms_norm_test.cpp
-  DEPS
-    :npu_kernels
-    GTest::gtest
-    GTest::gtest_main
-    xllm_kernels
-    c_sec
-    atb
-)
-
-cc_test(
-  NAME
-    npu_linear_test
-  SRCS
-    npu_linear_test.cpp
-  DEPS
-    :npu_kernels
-    GTest::gtest
-    GTest::gtest_main
-    xllm_kernels
-    c_sec
-    atb
-)
-
-cc_test(
-  NAME
-    npu_split_test
-  SRCS
-    npu_split_test.cpp
-  DEPS
-    :npu_kernels
-    GTest::gtest
-    GTest::gtest_main
-    xllm_kernels
-    c_sec
-    atb
-)
-
-cc_test(
-  NAME
-    npu_rope_impl_test
-  SRCS
-    npu_rope_impl_test.cpp
-  DEPS
-    :npu_kernels
-    GTest::gtest
-    GTest::gtest_main
-    xllm_kernels
-    c_sec
-    atb
-)
-
-cc_test(
-  NAME
-    npu_sample_model_test
-  SRCS
-    npu_sample_model_test.cpp
-  DEPS
-    :npu_kernels
-    GTest::gtest
-    GTest::gtest_main
-    xllm_kernels
-    c_sec
-    atb
+    :npu_kernels_impl
+    # spdlog::spdlog
 )
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/impl/CMakeLists.txt b/xllm/core/kernels/npu/impl/CMakeLists.txt
new file mode 100644
index 000000000..d8ec37ff7
--- /dev/null
+++ b/xllm/core/kernels/npu/impl/CMakeLists.txt
@@ -0,0 +1,104 @@
+include(cc_library)
+include(cc_test)
+
+include_directories(
+  ${CMAKE_SOURCE_DIR}/third_party/spdlog/include
+)
+
+
+cc_library(
+  NAME
+    npu_kernels_impl
+  HDRS
+    npu_split_impl.h
+    npu_linear_impl.h
+    npu_rms_norm_impl.h
+    npu_rope_impl.h
+  SRCS
+    npu_split_impl.cpp
+    npu_linear_impl.cpp
+    npu_rms_norm_impl.cpp
+    npu_rope_impl.cpp
+  DEPS
+    :npu_layers
+    :model_context
+    :state_dict
+    glog::glog
+    torch
+    torch_npu
+)
+
+cc_test(
+  NAME
+    npu_rms_norm_test
+  SRCS
+    npu_rms_norm_test.cpp
+  DEPS
+    :npu_kernels_impl
+    GTest::gtest
+    GTest::gtest_main
+    xllm_kernels
+    c_sec
+    atb
+    spdlog::spdlog
+)
+
+cc_test(
+  NAME
+    npu_linear_test
+  SRCS
+    npu_linear_test.cpp
+  DEPS
+    :npu_kernels_impl
+    GTest::gtest
+    GTest::gtest_main
+    xllm_kernels
+    c_sec
+    atb
+    spdlog::spdlog
+)
+
+cc_test(
+  NAME
+    npu_split_test
+  SRCS
+    npu_split_test.cpp
+  DEPS
+    :npu_kernels_impl
+    GTest::gtest
+    GTest::gtest_main
+    xllm_kernels
+    c_sec
+    atb
+    spdlog::spdlog
+)
+
+cc_test(
+  NAME
+    npu_rope_impl_test
+  SRCS
+    npu_rope_impl_test.cpp
+  DEPS
+    :npu_kernels_impl
+    GTest::gtest
+    GTest::gtest_main
+    xllm_kernels
+    c_sec
+    atb
+    spdlog::spdlog
+)
+
+cc_test(
+  NAME
+    npu_sample_model_test
+  SRCS
+    npu_sample_model_test.cpp
+  DEPS
+    :npu_kernels_impl
+    GTest::gtest
+    GTest::gtest_main
+    xllm_kernels
+    c_sec
+    atb
+    spdlog::spdlog
+)
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/npu_linear_impl.cpp b/xllm/core/kernels/npu/impl/npu_linear_impl.cpp
similarity index 100%
rename from xllm/core/kernels/npu/npu_linear_impl.cpp
rename to xllm/core/kernels/npu/impl/npu_linear_impl.cpp
diff --git a/xllm/core/kernels/npu/npu_linear_impl.h b/xllm/core/kernels/npu/impl/npu_linear_impl.h
similarity index 100%
rename from xllm/core/kernels/npu/npu_linear_impl.h
rename to xllm/core/kernels/npu/impl/npu_linear_impl.h
diff --git a/xllm/core/kernels/npu/npu_linear_test.cpp b/xllm/core/kernels/npu/impl/npu_linear_test.cpp
similarity index 99%
rename from xllm/core/kernels/npu/npu_linear_test.cpp
rename to xllm/core/kernels/npu/impl/npu_linear_test.cpp
index b18feab20..ec4607a03 100644
--- a/xllm/core/kernels/npu/npu_linear_test.cpp
+++ b/xllm/core/kernels/npu/impl/npu_linear_test.cpp
@@ -16,7 +16,7 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include <torch/torch.h>
 
-#include "core/kernels/linear.h"
+#include "kernels/npu/linear.h"
 
 namespace xllm::kernel {
diff --git a/xllm/core/kernels/npu/npu_rms_norm_impl.cpp b/xllm/core/kernels/npu/impl/npu_rms_norm_impl.cpp
similarity index 100%
rename from xllm/core/kernels/npu/npu_rms_norm_impl.cpp
rename to xllm/core/kernels/npu/impl/npu_rms_norm_impl.cpp
diff --git a/xllm/core/kernels/npu/npu_rms_norm_impl.h b/xllm/core/kernels/npu/impl/npu_rms_norm_impl.h
similarity index 100%
rename from xllm/core/kernels/npu/npu_rms_norm_impl.h
rename to xllm/core/kernels/npu/impl/npu_rms_norm_impl.h
diff --git a/xllm/core/kernels/npu/npu_rms_norm_test.cpp b/xllm/core/kernels/npu/impl/npu_rms_norm_test.cpp
similarity index 99%
rename from xllm/core/kernels/npu/npu_rms_norm_test.cpp
rename to xllm/core/kernels/npu/impl/npu_rms_norm_test.cpp
index 50fdff8e8..df4c0ce3a 100644
--- a/xllm/core/kernels/npu/npu_rms_norm_test.cpp
+++ b/xllm/core/kernels/npu/impl/npu_rms_norm_test.cpp
@@ -16,7 +16,7 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include <torch/torch.h>
 
-#include "core/kernels/rms_norm.h"
+#include "kernels/npu/rms_norm.h"
 
 namespace xllm::kernel {
diff --git a/xllm/core/kernels/npu/npu_rope_impl.cpp b/xllm/core/kernels/npu/impl/npu_rope_impl.cpp
similarity index 100%
rename from xllm/core/kernels/npu/npu_rope_impl.cpp
rename to xllm/core/kernels/npu/impl/npu_rope_impl.cpp
diff --git a/xllm/core/kernels/npu/npu_rope_impl.h b/xllm/core/kernels/npu/impl/npu_rope_impl.h
similarity index 100%
rename from xllm/core/kernels/npu/npu_rope_impl.h
rename to xllm/core/kernels/npu/impl/npu_rope_impl.h
diff --git a/xllm/core/kernels/npu/npu_rope_impl_test.cpp b/xllm/core/kernels/npu/impl/npu_rope_impl_test.cpp
similarity index 99%
rename from xllm/core/kernels/npu/npu_rope_impl_test.cpp
rename to xllm/core/kernels/npu/impl/npu_rope_impl_test.cpp
index 4caa64e93..26a78bef9 100644
--- a/xllm/core/kernels/npu/npu_rope_impl_test.cpp
+++ b/xllm/core/kernels/npu/impl/npu_rope_impl_test.cpp
@@ -16,7 +16,7 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include <torch/torch.h>
 
-#include "core/kernels/rope.h"
+#include "kernels/npu/rope.h"
 
 namespace xllm::kernel {
diff --git a/xllm/core/kernels/npu/npu_sample_model_test.cpp b/xllm/core/kernels/npu/impl/npu_sample_model_test.cpp
similarity index 99%
rename from xllm/core/kernels/npu/npu_sample_model_test.cpp
rename to xllm/core/kernels/npu/impl/npu_sample_model_test.cpp
index 5db6ce41f..c8cee2d4e 100644
--- a/xllm/core/kernels/npu/npu_sample_model_test.cpp
+++ b/xllm/core/kernels/npu/impl/npu_sample_model_test.cpp
@@ -16,10 +16,10 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include <torch/torch.h>
 
-#include "core/kernels/linear.h"
-#include "core/kernels/rms_norm.h"
-#include "core/kernels/rope.h"
-#include "core/kernels/split.h"
+#include "kernels/npu/linear.h"
+#include "kernels/npu/rms_norm.h"
+#include "kernels/npu/rope.h"
+#include "kernels/npu/split.h"
 
 namespace xllm::kernel {
diff --git a/xllm/core/kernels/npu/npu_split_impl.cpp b/xllm/core/kernels/npu/impl/npu_split_impl.cpp
similarity index 100%
rename from xllm/core/kernels/npu/npu_split_impl.cpp
rename to xllm/core/kernels/npu/impl/npu_split_impl.cpp
diff --git a/xllm/core/kernels/npu/npu_split_impl.h b/xllm/core/kernels/npu/impl/npu_split_impl.h
similarity index 100%
rename from xllm/core/kernels/npu/npu_split_impl.h
rename to xllm/core/kernels/npu/impl/npu_split_impl.h
diff --git a/xllm/core/kernels/npu/npu_split_test.cpp b/xllm/core/kernels/npu/impl/npu_split_test.cpp
similarity index 99%
rename from xllm/core/kernels/npu/npu_split_test.cpp
rename to xllm/core/kernels/npu/impl/npu_split_test.cpp
index 33057202d..2e28d26cc 100644
--- a/xllm/core/kernels/npu/npu_split_test.cpp
+++ b/xllm/core/kernels/npu/impl/npu_split_test.cpp
@@ -16,7 +16,7 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include <torch/torch.h>
 
-#include "core/kernels/split.h"
+#include "kernels/npu/split.h"
 
 namespace xllm::kernel {
diff --git a/xllm/core/kernels/linear.h b/xllm/core/kernels/npu/linear.h
similarity index 92%
rename from xllm/core/kernels/linear.h
rename to xllm/core/kernels/npu/linear.h
index 72d0c4e80..0834c014b 100644
--- a/xllm/core/kernels/linear.h
+++ b/xllm/core/kernels/npu/linear.h
@@ -14,13 +14,10 @@ limitations under the License.
 ==============================================================================*/
 
 #pragma once
-#if defined(USE_NPU)
-#include "npu/npu_linear_impl.h"
-#endif
+#include "impl/npu_linear_impl.h"
 
 namespace xllm::kernel {
 
-#if defined(USE_NPU)
 class Linear : public torch::nn::ModuleHolder<NpuLinearImpl> {
  public:
   using torch::nn::ModuleHolder<NpuLinearImpl>::ModuleHolder;
@@ -29,6 +26,5 @@ class Linear : public torch::nn::ModuleHolder<NpuLinearImpl> {
   Linear(const ModelContext& context)
       : ModuleHolder(std::make_shared<NpuLinearImpl>(context)) {}
 };
-#endif
 
 }  // namespace xllm::kernel
diff --git a/xllm/core/kernels/rms_norm.h b/xllm/core/kernels/npu/rms_norm.h
similarity index 92%
rename from xllm/core/kernels/rms_norm.h
rename to xllm/core/kernels/npu/rms_norm.h
index ea6fb42e9..ed7f8d047 100644
--- a/xllm/core/kernels/rms_norm.h
+++ b/xllm/core/kernels/npu/rms_norm.h
@@ -14,14 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 #pragma once
-#if defined(USE_NPU)
-#include "npu/npu_rms_norm_impl.h"
-#endif
+#include "impl/npu_rms_norm_impl.h"
 
 namespace xllm {
 namespace kernel {
 
-#if defined(USE_NPU)
 class RmsNorm : public torch::nn::ModuleHolder<NpuRmsNormImpl> {
  public:
   using torch::nn::ModuleHolder<NpuRmsNormImpl>::ModuleHolder;
@@ -30,7 +27,6 @@ class RmsNorm : public torch::nn::ModuleHolder<NpuRmsNormImpl> {
   RmsNorm(const ModelContext& context)
       : ModuleHolder(std::make_shared<NpuRmsNormImpl>(context)) {}
 };
-#endif
 
 }  // namespace kernel
 }  // namespace xllm
diff --git a/xllm/core/kernels/rope.h b/xllm/core/kernels/npu/rope.h
similarity index 92%
rename from xllm/core/kernels/rope.h
rename to xllm/core/kernels/npu/rope.h
index a79d089e9..7a075b0d3 100644
--- a/xllm/core/kernels/rope.h
+++ b/xllm/core/kernels/npu/rope.h
@@ -14,12 +14,10 @@ limitations under the License.
 ==============================================================================*/
 
 #pragma once
-#if defined(USE_NPU)
-#include "npu/npu_rope_impl.h"
-#endif
+#include "impl/npu_rope_impl.h"
 
 namespace xllm::kernel {
-#if defined(USE_NPU)
+
 class Rope : public torch::nn::ModuleHolder<NpuRopeImpl> {
  public:
   using torch::nn::ModuleHolder<NpuRopeImpl>::ModuleHolder;
@@ -28,5 +26,5 @@ class Rope : public torch::nn::ModuleHolder<NpuRopeImpl> {
   Rope(const ModelContext& context)
       : ModuleHolder(std::make_shared<NpuRopeImpl>(context)) {}
 };
-#endif
+
 }  // namespace xllm::kernel
diff --git a/xllm/core/kernels/split.h b/xllm/core/kernels/npu/split.h
similarity index 93%
rename from xllm/core/kernels/split.h
rename to xllm/core/kernels/npu/split.h
index 806730bd9..cda39703e 100644
--- a/xllm/core/kernels/split.h
+++ b/xllm/core/kernels/npu/split.h
@@ -14,12 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #pragma once
-#if defined(USE_NPU)
-#include "npu/npu_split_impl.h"
-#endif
+#include "impl/npu_split_impl.h"
 
 namespace xllm::kernel {
-#if defined(USE_NPU)
 class Split : public torch::nn::ModuleHolder<NpuSplitImpl> {
  public:
   using torch::nn::ModuleHolder<NpuSplitImpl>::ModuleHolder;
@@ -34,5 +31,5 @@ class Split : public torch::nn::ModuleHolder<NpuSplitImpl> {
                                                  splitNum,
                                                  splitSizes)) {}
 };
-#endif
+
 }  // namespace xllm::kernel
diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp
new file mode 100644
index 000000000..af34e66f1
--- /dev/null
+++ b/xllm/core/kernels/ops_api.cpp
@@ -0,0 +1,187 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ops_api.h"
+
+namespace xllm {
+namespace kernel {
+
+void apply_rotary(RotaryParams& params) {
+#if defined(USE_MLU)
+  mlu::apply_rotary(params.q,
+                    params.k,
+                    params.sin,
+                    params.cos,
+                    params.position_ids,
+                    params.cu_query_lens,
+                    params.interleaved,
+                    params.discrete,
+                    params.dynamic_ntk,
+                    params.max_query_len);
+#else
+  throw std::runtime_error("apply_rotary not implemented");
+#endif
+}
+
+void active(ActivationParams& params) {
+#if defined(USE_MLU)
+  mlu::active(params.input,
+              params.output,
+              params.bias,
+              params.cusum_token_count,
+              params.act_mode,
+              params.is_gated,
+              params.start_expert_id,
+              params.expert_size);
+#else
+  throw std::runtime_error("active not implemented");
+#endif
+}
+
+void reshape_paged_cache(ReshapePagedCacheParams& params) {
+#if defined(USE_MLU)
+  mlu::reshape_paged_cache(params.key,
+                           params.value,
+                           params.k_cache,
+                           params.v_cache,
+                           params.slot_mapping,
+                           params.direction);
+#else
+  throw std::runtime_error("reshape_paged_cache not implemented");
+#endif
+}
+
+void batch_prefill(AttentionParams& params) {
+#if defined(USE_MLU)
+  mlu::batch_prefill(params.query,
+                     params.key,
+                     params.value,
+                     params.output,
+                     params.output_lse,
+                     params.query_start_loc,
+                     params.seq_start_loc,
+                     params.alibi_slope,
+                     params.attn_bias,
+                     params.q_quant_scale,
+                     params.k_quant_scale,
+                     params.v_quant_scale,
+                     params.out_quant_scale,
+                     params.block_table,
+                     params.max_query_len,
+                     params.max_seq_len,
+                     params.scale,
+                     params.is_causal,
+                     params.window_size_left,
+                     params.window_size_right,
+                     params.compute_dtype,
+                     params.return_lse);
+#else
+  throw std::runtime_error("batch_prefill not implemented");
+#endif
+}
+
+void batch_decode(AttentionParams& params) {
+#if defined(USE_MLU)
+  mlu::batch_decode(params.query,
+                    params.k_cache,
+                    params.output,
+                    params.block_table.value(),
+                    params.kv_seq_lens,
+                    params.v_cache,
+                    params.output_lse,
+                    params.q_quant_scale,
+                    params.k_cache_quant_scale,
+                    params.v_cache_quant_scale,
+                    params.out_quant_scale,
+                    params.alibi_slope,
+                    params.mask,
+                    params.compute_dtype,
+                    params.max_seq_len,
+                    params.window_size_left,
+                    params.window_size_right,
+                    params.scale,
+                    params.return_lse,
+                    params.kv_cache_quant_bit_size);
+#else
+  throw std::runtime_error("batch_decode not implemented");
+#endif
+}
+
+void fused_layernorm(FusedLayerNormParams& params) {
+#if defined(USE_MLU)
+  mlu::fused_layernorm(params.input,
+                       params.output,
+                       params.residual,
+                       params.weight,
+                       params.beta,
+                       params.bias,
+                       params.quant_scale,
+                       params.residual_out,
+                       params.smooth_quant_scale,
+                       params.normed_out,
+                       params.mode,
+                       params.eps,
+                       params.store_output_before_norm,
+                       params.store_output_after_norm,
+                       params.dynamic_quant);
+#else
+  throw std::runtime_error("fused_layernorm not implemented");
+#endif
+}
+
+torch::Tensor matmul(MatmulParams& params) {
+#if defined(USE_MLU)
+  return mlu::matmul(
+      params.a, params.b, params.bias, params.c, params.alpha, params.beta);
+#else
+  throw std::runtime_error("matmul not implemented");
+#endif
+}
+
+torch::Tensor fused_moe(FusedMoEParams& params) {
+#if defined(USE_MLU)
+  return mlu::fused_moe(params.hidden_states,
+                        params.gating_output,
+                        params.w1,
+                        params.w2,
+                        params.bias1,
+                        params.bias2,
+                        params.residual,
+                        params.input_smooth,
+                        params.act_smooth,
+                        params.w1_scale,
+                        params.w2_scale,
+                        params.e_score_correction_bias,
+                        params.topk,
+                        params.renormalize,
+                        params.gated,
+                        params.act_mode,
+                        params.scoring_func,
+                        params.start_expert_id,
+                        params.block_n,
+                        params.avg_moe,
+                        params.class_reduce_weight,
+                        params.class_expert_id,
+                        params.w1_quant_flag,
+                        params.w2_quant_flag,
+                        params.world_size,
+                        params.shared_expert_num,
+                        params.parallel_mode);
+#else
+  throw std::runtime_error("fused_moe not implemented");
+#endif
+}
+}  // namespace kernel
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h
new file mode 100644
index 000000000..df96850a7
--- /dev/null
+++ b/xllm/core/kernels/ops_api.h
@@ -0,0 +1,44 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include "param.h"
+
+#if defined(USE_MLU)
+#include "mlu/mlu_ops_api.h"
+#endif
+
+namespace xllm {
+namespace kernel {
+
+void apply_rotary(RotaryParams& params);
+
+void active(ActivationParams& params);
+
+void reshape_paged_cache(ReshapePagedCacheParams& params);
+
+void batch_prefill(AttentionParams& params);
+
+void batch_decode(AttentionParams& params);
+
+void fused_layernorm(FusedLayerNormParams& params);
+
+torch::Tensor matmul(MatmulParams& params);
+
+torch::Tensor fused_moe(FusedMoEParams& params);
+
+}  // namespace kernel
+}  // namespace xllm
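For orientation, a hypothetical caller of this new dispatch layer might look as follows. This sketch is not part of the patch: `x` and `w` are illustrative tensors, and on a build without `USE_MLU` the dispatcher above throws `std::runtime_error("matmul not implemented")` instead of returning.

```cpp
#include <torch/torch.h>

#include "kernels/ops_api.h"

// Bias-free linear projection through the unified kernels entry point.
torch::Tensor linear_no_bias(const torch::Tensor& x, const torch::Tensor& w) {
  xllm::kernel::MatmulParams params;
  params.a = x;  // activations
  params.b = w;  // weight laid out as the backend matmul expects
  // params.bias and params.c stay std::nullopt; alpha (1.0) and beta (0.0)
  // come from the defaults in param.h below.
  return xllm::kernel::matmul(params);
}
```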
diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h
new file mode 100644
index 000000000..2896ce0b8
--- /dev/null
+++ b/xllm/core/kernels/param.h
@@ -0,0 +1,174 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/torch.h>
+
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace xllm {
+namespace kernel {
+
+// Note: add default values for optional parameters in the struct definition
+
+// Rotary embedding parameters
+struct RotaryParams {
+  torch::Tensor q;
+  torch::Tensor k;
+  torch::Tensor sin;
+  torch::Tensor cos;
+  torch::Tensor cos_sin;
+  std::optional<torch::Tensor> position_ids;
+  torch::Tensor cu_query_lens;
+  bool interleaved;
+  bool discrete;
+  bool dynamic_ntk = false;
+  int max_query_len;
+};
+
+// Activation parameters
+struct ActivationParams {
+  torch::Tensor input;
+  torch::Tensor output;
+  std::optional<torch::Tensor> bias;
+  std::optional<torch::Tensor> cusum_token_count;
+  std::string act_mode;
+  bool is_gated;
+  int start_expert_id = 0;
+  int expert_size = 0;
+};
+
+// Reshape paged cache parameters
+struct ReshapePagedCacheParams {
+  torch::Tensor key;
+  torch::Tensor value;
+  torch::Tensor k_cache;
+  torch::Tensor v_cache;
+  torch::Tensor slot_mapping;
+  bool direction = false;
+};
+
+// Attention parameters
+struct AttentionParams {
+  // common parameters
+  torch::Tensor query;
+  torch::Tensor output;
+  std::optional<torch::Tensor> output_lse;
+  std::optional<torch::Tensor> alibi_slope;
+  std::optional<torch::Tensor> q_quant_scale;
+  std::optional<torch::Tensor> out_quant_scale;
+  std::optional<torch::Tensor> block_table;
+  std::string compute_dtype;
+  int max_seq_len;
+  int window_size_left;
+  int window_size_right = -1;
+  float scale;
+  bool return_lse = false;
+  // for flashinfer
+  torch::Tensor paged_kv_indptr;
+  torch::Tensor paged_kv_indices;
+  torch::Tensor paged_kv_last_page_len;
+  torch::Tensor float_workspace_buffer;
+  torch::Tensor int_workspace_buffer;
+  torch::Tensor page_locked_int_workspace_buffer;
+  torch::Tensor kv_cu_seq_lens;
+  torch::Tensor q_cu_seq_lens;
+  bool enable_cuda_graph = false;
+
+  // prefill parameters
+  torch::Tensor key;    // [num_tokens, num_kv_heads, head_dim_qk]
+  torch::Tensor value;  // [num_tokens, num_kv_heads, head_dim_vo]
+  std::optional<torch::Tensor> query_start_loc;
+  std::optional<torch::Tensor> seq_start_loc;
+  std::optional<torch::Tensor> attn_bias;
+  std::optional<torch::Tensor> k_quant_scale;
+  std::optional<torch::Tensor> v_quant_scale;
+  int max_query_len;
+  bool is_causal = true;
+
+  // decode parameters
+  torch::Tensor k_cache;
+  torch::Tensor v_cache;
+  torch::Tensor kv_seq_lens;
+  std::optional<torch::Tensor> k_cache_quant_scale;
+  std::optional<torch::Tensor> v_cache_quant_scale;
+  std::optional<torch::Tensor> mask;
+  int kv_cache_quant_bit_size = -1;
+};
+
+// Fused layer norm parameters
+struct FusedLayerNormParams {
+  torch::Tensor input;
+  torch::Tensor output;
+  std::optional<torch::Tensor> residual;
+  torch::Tensor weight;
+  std::optional<torch::Tensor> beta;
+  std::optional<torch::Tensor> bias;
+  std::optional<torch::Tensor> quant_scale;
+  std::optional<torch::Tensor> residual_out;
+  std::optional<torch::Tensor> smooth_quant_scale;
+  std::optional<torch::Tensor> normed_out;
+  std::string mode;
+  double eps;
+  bool store_output_before_norm = false;
+  bool store_output_after_norm = false;
+  bool dynamic_quant = false;
+};
+
+// Matmul parameters
+struct MatmulParams {
+  torch::Tensor a;
+  torch::Tensor b;
+  std::optional<torch::Tensor> bias;
+  std::optional<torch::Tensor> c;
+  double alpha = 1.0;
+  double beta = 0.0;
+};
+
+// Fused MoE parameters
+struct FusedMoEParams {
+  torch::Tensor hidden_states;
+  torch::Tensor gating_output;
+  torch::Tensor w1;
+  torch::Tensor w2;
+  std::optional<torch::Tensor> bias1;
+  std::optional<torch::Tensor> bias2;
+  std::optional<torch::Tensor> residual;
+  std::optional<torch::Tensor> input_smooth;
+  std::optional<torch::Tensor> act_smooth;
+  std::optional<torch::Tensor> w1_scale;
+  std::optional<torch::Tensor> w2_scale;
+  std::optional<torch::Tensor> e_score_correction_bias;
+  int topk;
+  bool renormalize;
+  bool gated;
+  std::string act_mode;
+  std::string scoring_func = "softmax";
+  int start_expert_id = 0;
+  int block_n = 0;
+  bool avg_moe = false;
+  std::optional<torch::Tensor> class_reduce_weight;
+  std::optional<torch::Tensor> class_expert_id;
+  std::optional<std::vector<int>> w1_quant_flag;
+  std::optional<std::vector<int>> w2_quant_flag;
+  int world_size = 0;
+  int shared_expert_num = 0;
+  std::string parallel_mode = "ep";
+};
+}  // namespace kernel
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/layers/mlu/CMakeLists.txt b/xllm/core/layers/mlu/CMakeLists.txt
index a08e6b68e..cca8f5fb4 100755
--- a/xllm/core/layers/mlu/CMakeLists.txt
+++ b/xllm/core/layers/mlu/CMakeLists.txt
@@ -35,7 +35,7 @@ cc_library(
     :parallel_state
     :state_dict
     :model
-    :xllm_mlu_ops
+    :kernels
     glog::glog
     gflags::gflags
     torch
diff --git a/xllm/core/layers/mlu/attention.cpp b/xllm/core/layers/mlu/attention.cpp
index 1ae056ce3..fa7a77257 100644
--- a/xllm/core/layers/mlu/attention.cpp
+++ b/xllm/core/layers/mlu/attention.cpp
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include "attention.h"
 
-#include "kernels/mlu/torch_ops_api.h"
+#include "kernels/ops_api.h"
 
 DECLARE_bool(enable_chunked_prefill);
 namespace xllm {
@@ -42,7 +42,7 @@ AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
   attn_metadata.is_prefill = is_prefill && !attn_metadata.is_chunked_prefill;
   if (!attn_metadata.is_prefill) {
     attn_metadata.block_table = params.block_tables;
-    attn_metadata.seq_lens = torch::diff(params.kv_seq_lens);
+    attn_metadata.kv_seq_lens = torch::diff(params.kv_seq_lens);  // kv seqlens
   }
 
   return attn_metadata;
@@ -76,83 +76,54 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
   torch::Tensor k_cache = kv_cache.get_k_cache();
   torch::Tensor v_cache = kv_cache.get_v_cache();
 
-  tmo::torch_api::reshape_paged_cache(key,
-                                      value,
-                                      k_cache,
-                                      v_cache,
-                                      attn_metadata.slot_mapping,
-                                      false /* direction */);
+  xllm::kernel::ReshapePagedCacheParams reshape_paged_cache_params;
+  reshape_paged_cache_params.key = key;
+  reshape_paged_cache_params.value = value;
+  reshape_paged_cache_params.k_cache = k_cache;
+  reshape_paged_cache_params.v_cache = v_cache;
+  reshape_paged_cache_params.slot_mapping = attn_metadata.slot_mapping;
+  xllm::kernel::reshape_paged_cache(reshape_paged_cache_params);
+
+  xllm::kernel::AttentionParams attention_params;
+  attention_params.query = query;
+  attention_params.output = output;
+  attention_params.output_lse = output_lse;
+  attention_params.max_seq_len = attn_metadata.max_seq_len;
+  attention_params.window_size_left = sliding_window_;
+  attention_params.scale = scale_;
+  attention_params.compute_dtype = attn_metadata.compute_dtype;
 
   if (attn_metadata.is_prefill) {
-    tmo::torch_api::flash_attention(query,
-                                    key,
-                                    value,
-                                    output,
-                                    output_lse,
-                                    attn_metadata.query_start_loc,
-                                    attn_metadata.seq_start_loc,
-                                    std::nullopt /* alibi_slope */,
-                                    std::nullopt /* attn_bias */,
-                                    std::nullopt /* q_quant_scale */,
-                                    std::nullopt /* k_quant_scale */,
-                                    std::nullopt /* v_quant_scale */,
-                                    std::nullopt /* out_quant_scale */,
-                                    std::nullopt /* block_tables */,
-                                    attn_metadata.max_query_len,
-                                    attn_metadata.max_seq_len,
-                                    scale_,
-                                    true /* is_causal */,
-                                    sliding_window_,
-                                    -1,
-                                    attn_metadata.compute_dtype,
-                                    false /* return_lse */);
+    attention_params.key = key;
+    attention_params.value = value;
+    attention_params.query_start_loc = attn_metadata.query_start_loc;
+    attention_params.seq_start_loc = attn_metadata.seq_start_loc;
+    attention_params.max_query_len = attn_metadata.max_query_len;
+
+    xllm::kernel::batch_prefill(attention_params);
   } else if (attn_metadata.is_chunked_prefill) {
-    tmo::torch_api::flash_attention(query,
-                                    k_cache,
-                                    v_cache,
-                                    output,
-                                    output_lse,
-                                    attn_metadata.query_start_loc,
-                                    attn_metadata.seq_start_loc,
-                                    std::nullopt /* alibi_slope */,
-                                    std::nullopt /* attn_bias */,
-                                    std::nullopt /* q_quant_scale */,
-                                    std::nullopt /* k_quant_scale */,
-                                    std::nullopt /* v_quant_scale */,
-                                    std::nullopt /* out_quant_scale */,
-                                    attn_metadata.block_table,
-                                    attn_metadata.max_query_len,
-                                    attn_metadata.max_seq_len,
-                                    scale_,
-                                    true /* is_causal */,
-                                    sliding_window_,
-                                    -1,
-                                    attn_metadata.compute_dtype,
-                                    false /* return_lse */);
+    attention_params.key = k_cache;
+    attention_params.value = v_cache;
+    attention_params.query_start_loc = attn_metadata.query_start_loc;
+    attention_params.seq_start_loc = attn_metadata.seq_start_loc;
+    attention_params.max_query_len = attn_metadata.max_query_len;
+    attention_params.block_table = attn_metadata.block_table;
+
+    xllm::kernel::batch_prefill(attention_params);
   } else {
     query = query.view({-1, 1, num_heads_, head_size_});
     output = output.view({-1, 1, num_heads_, head_size_});
-    tmo::torch_api::single_query_cached_kv_attn(
-        query,
-        k_cache,
-        output,
-        attn_metadata.block_table,
-        attn_metadata.seq_lens,
-        v_cache,
-        output_lse,
-        std::nullopt /* q_quant_scale */,
-        std::nullopt /* k_cache_quant_scale */,
-        std::nullopt /* v_cache_quant_scale */,
-        std::nullopt /* out_quant_scale */,
-        std::nullopt /* alibi_slope */,
-        std::nullopt /* mask */,
-        attn_metadata.compute_dtype,
-        attn_metadata.max_seq_len,
-        sliding_window_,
-        -1 /* always -1 for window size right */,
-        scale_,
-        false /* return_lse */,
-        -1 /* kv_cache_quant_bit_size */);
+
+    attention_params.query = query;
+    attention_params.output = output;
+    attention_params.k_cache = k_cache;
+    attention_params.v_cache = v_cache;
+
+    // for mlu
+    attention_params.block_table = attn_metadata.block_table;
+    attention_params.kv_seq_lens = attn_metadata.kv_seq_lens;
+
+    xllm::kernel::batch_decode(attention_params);
   }
 
   output = output.view({-1, num_heads_ * head_size_});
diff --git a/xllm/core/layers/mlu/attention.h b/xllm/core/layers/mlu/attention.h
index 5e73571ab..7e2100017 100644
--- a/xllm/core/layers/mlu/attention.h
+++ b/xllm/core/layers/mlu/attention.h
@@ -36,7 +36,7 @@ struct AttentionMetadata {
 
   torch::Tensor query_start_loc;
   torch::Tensor seq_start_loc;
-  torch::Tensor seq_lens;
+  torch::Tensor kv_seq_lens;
   torch::Tensor block_table;
   torch::Tensor slot_mapping;
   int max_query_len;
diff --git a/xllm/core/layers/mlu/dense_mlp.cpp b/xllm/core/layers/mlu/dense_mlp.cpp
index 06797721d..b7a90dd28 100644
--- a/xllm/core/layers/mlu/dense_mlp.cpp
+++ b/xllm/core/layers/mlu/dense_mlp.cpp
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include <torch/torch.h>
 
-#include "kernels/mlu/torch_ops_api.h"
+#include "kernels/ops_api.h"
 
 namespace xllm {
 namespace layer {
@@ -61,14 +61,13 @@ torch::Tensor DenseMLPImpl::forward(const torch::Tensor& hidden_states) {
       {batch_size, intermediate_size_ / parallel_args_.tp_group_->world_size()},
       gate_up.options());
 
-  tmo::torch_api::active(gate_up,
-                         output,
-                         std::nullopt /* bias */,
-                         std::nullopt /* cusum_token_count */,
-                         xllm::mlu::kActModeSilu,
-                         is_gated_,
-                         0 /* start_expert_id */,
-                         0 /* expert_size */);
+  xllm::kernel::ActivationParams activation_params;
+  activation_params.input = gate_up;
+  activation_params.output = output;
+  activation_params.act_mode = xllm::kernel::mlu::kActModeSilu;
+  activation_params.is_gated = is_gated_;
+
+  xllm::kernel::active(activation_params);
 
   return down_proj_->forward(output);
 }
diff --git a/xllm/core/layers/mlu/fuse_norm.cpp b/xllm/core/layers/mlu/fuse_norm.cpp
index 766f36b3a..9b5dec015 100644
--- a/xllm/core/layers/mlu/fuse_norm.cpp
+++ b/xllm/core/layers/mlu/fuse_norm.cpp
@@ -17,7 +17,7 @@ limitations under the License.
 
 #include <torch/torch.h>
 
-#include "kernels/mlu/torch_ops_api.h"
+#include "kernels/ops_api.h"
 
 namespace xllm {
 namespace layer {
@@ -39,22 +39,14 @@ torch::Tensor FusedRMSNormImpl::forward(torch::Tensor& input) {
   input = input.reshape({-1, norm_dim_});
   auto output = torch::empty_like(input);
 
-  tmo::torch_api::fused_layernorm(input,
-                                  output,
-                                  std::nullopt /*residual*/,
-                                  weight_,
-                                  std::nullopt /*beta*/,
-                                  std::nullopt /*bias*/,
-                                  std::nullopt /*quant_scale*/,
-                                  std::nullopt /*residual_out*/,
-                                  std::nullopt /*smooth_quant_scale*/,
-                                  std::nullopt /*normed_out*/,
-                                  kRmsNormMode,
-                                  eps_,
-                                  false /*store_output_before_norm*/,
-                                  false /*store_output_after_norm*/,
-                                  false /*dynamic_quant*/
-  );
+  xllm::kernel::FusedLayerNormParams fused_layernorm_params;
+  fused_layernorm_params.input = input;
+  fused_layernorm_params.output = output;
+  fused_layernorm_params.weight = weight_;
+  fused_layernorm_params.mode = kRmsNormMode;
+  fused_layernorm_params.eps = eps_;
+
+  xllm::kernel::fused_layernorm(fused_layernorm_params);
 
   output = output.view(org_shape);
   return output;
diff --git a/xllm/core/layers/mlu/fused_moe.cpp b/xllm/core/layers/mlu/fused_moe.cpp
index 05059207d..4c20be07f 100644
--- a/xllm/core/layers/mlu/fused_moe.cpp
+++ b/xllm/core/layers/mlu/fused_moe.cpp
@@ -18,7 +18,8 @@ limitations under the License.
 
 #include <torch/torch.h>
 
 #include "framework/parallel_state/parallel_state.h"
-#include "kernels/mlu/torch_ops_api.h"
+#include "kernels/ops_api.h"
+
 namespace xllm {
 namespace layer {
@@ -102,24 +103,23 @@ torch::Tensor FusedMoEImpl::forward_expert(
   if (e_score_correction_bias_.defined()) {
     e_score_correction_bias = e_score_correction_bias_;
   }
-  auto final_hidden_states = xllm::mlu::fused_moe(hidden_states,
-                                                  router_logits,
-                                                  w13_,
-                                                  w2_,
-                                                  std::nullopt,
-                                                  std::nullopt,
-                                                  shared_output,
-                                                  std::nullopt,
-                                                  std::nullopt,
-                                                  std::nullopt,
-                                                  std::nullopt,
-                                                  e_score_correction_bias,
-                                                  topk_,
-                                                  renormalize_,
-                                                  is_gated_,
-                                                  hidden_act_,
-                                                  scoring_func_,
-                                                  start_expert_id_);
+  pack_params();
+
+  xllm::kernel::FusedMoEParams fused_moe_params;
+  fused_moe_params.hidden_states = hidden_states;
+  fused_moe_params.gating_output = router_logits;
+  fused_moe_params.w1 = w13_;
+  fused_moe_params.w2 = w2_;
+  fused_moe_params.e_score_correction_bias = e_score_correction_bias;
+  fused_moe_params.topk = topk_;
+  fused_moe_params.renormalize = renormalize_;
+  fused_moe_params.gated = is_gated_;
+  fused_moe_params.act_mode = hidden_act_;
+  fused_moe_params.scoring_func = scoring_func_;
+  fused_moe_params.start_expert_id = start_expert_id_;
+
+  auto final_hidden_states = xllm::kernel::fused_moe(fused_moe_params);
+
   if (tp_pg_->world_size() > 1) {
     final_hidden_states = parallel_state::reduce(final_hidden_states, tp_pg_);
   }
diff --git a/xllm/core/layers/mlu/linear_impl.cpp b/xllm/core/layers/mlu/linear_impl.cpp
index 7607742ef..64607715c 100644
--- a/xllm/core/layers/mlu/linear_impl.cpp
+++ b/xllm/core/layers/mlu/linear_impl.cpp
@@ -20,7 +20,7 @@ limitations under the License.
 
 #include "framework/parallel_state/parallel_args.h"
 #include "framework/parallel_state/parallel_state.h"
-#include "kernels/mlu/torch_ops_api.h"
+#include "kernels/ops_api.h"
 
 namespace xllm {
 namespace layer {
@@ -61,7 +61,12 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) {
   input = input.to(device_);
   auto bias = (bias_.defined() && rank_ == 0)
                   ? std::optional<torch::Tensor>(bias_)
                   : std::nullopt;
-  auto output = xllm::mlu::matmul(input, weight_, bias, std::nullopt, 1.0, 0.0);
+  xllm::kernel::MatmulParams matmul_params;
+  matmul_params.a = input;
+  matmul_params.b = weight_;
+  matmul_params.bias = bias;
+
+  auto output = xllm::kernel::matmul(matmul_params);
   if (world_size_ > 1 && gather_output_) {
     output = xllm::parallel_state::gather(output, parallel_args_.tp_group_);
   }
@@ -144,8 +149,12 @@ torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) {
   auto bias = (qkv_bias_.defined() && rank_ == 0)
                   ? std::optional<torch::Tensor>(qkv_bias_)
                   : std::nullopt;
-  auto output =
-      xllm::mlu::matmul(input, qkv_weight_, bias, std::nullopt, 1.0, 0.0);
+  xllm::kernel::MatmulParams matmul_params;
+  matmul_params.a = input;
+  matmul_params.b = qkv_weight_;
+  matmul_params.bias = bias;
+
+  auto output = xllm::kernel::matmul(matmul_params);
   if (world_size_ > 1 && gather_output_) {
     output = xllm::parallel_state::gather(output, parallel_args_.tp_group_);
   }
@@ -234,8 +243,12 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) {
   auto bias = (bias_.defined() && rank_ == 0)
                   ? std::optional<torch::Tensor>(bias_)
                   : std::nullopt;
-  auto output = xllm::mlu::matmul(input, weight_, bias, std::nullopt, 1.0, 0.0);
+  xllm::kernel::MatmulParams matmul_params;
+  matmul_params.a = input;
+  matmul_params.b = weight_;
+  matmul_params.bias = bias;
 
+  auto output = xllm::kernel::matmul(matmul_params);
   if (if_reduce_results_ && world_size_ > 1) {
     output = xllm::parallel_state::reduce(output, parallel_args_.tp_group_);
   }
@@ -277,7 +290,12 @@ ReplicatedLinearImpl::ReplicatedLinearImpl(
 torch::Tensor ReplicatedLinearImpl::forward(torch::Tensor input) {
   namespace F = torch::nn::functional;
   auto bias = bias_.defined() ? std::optional<torch::Tensor>(bias_)
                               : std::nullopt;
-  auto output = xllm::mlu::matmul(input, weight_, bias, std::nullopt, 1.0, 0.0);
+  xllm::kernel::MatmulParams matmul_params;
+  matmul_params.a = input;
+  matmul_params.b = weight_;
+  matmul_params.bias = bias;
+
+  auto output = xllm::kernel::matmul(matmul_params);
   return output;
 }
diff --git a/xllm/core/layers/mlu/rotary_embedding.cpp b/xllm/core/layers/mlu/rotary_embedding.cpp
index 6af442c21..656a61d4a 100644
--- a/xllm/core/layers/mlu/rotary_embedding.cpp
+++ b/xllm/core/layers/mlu/rotary_embedding.cpp
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include "rotary_embedding.h"
 
-#include "kernels/mlu/torch_ops_api.h"
+#include "kernels/ops_api.h"
 
 namespace xllm {
 namespace layer {
@@ -78,25 +78,20 @@ void RotaryEmbeddingImpl::forward(torch::Tensor& q,
     discrete = true;
     position_ids = positions;
   }
-  const int64_t T = q.size(0);
-  q = q.view({T, -1});
-  k = k.view({T, -1});
-  auto qk = torch::cat({q, k}, /*dim=*/-1);
-  qk = qk.view({T, -1, rotary_dim_});
-  tmo::torch_api::apply_rotary(qk,
-                               qk /* output */,
-                               sin_,
-                               cos_,
-                               position_ids,
-                               cu_query_lens,
-                               interleaved_,
-                               discrete,
-                               false /* dynamic_ntk */,
-                               max_query_len);
-  qk = qk.view({-1, q.size(-1) + k.size(-1)});
-  auto qk_vec = qk.split({q.size(-1), k.size(-1)}, /*dim=*/-1);
-  q = qk_vec[0];
-  k = qk_vec[1];
+
+  xllm::kernel::RotaryParams rotary_params;
+  rotary_params.q = q;
+  rotary_params.k = k;
+  rotary_params.sin = sin_;
+  rotary_params.cos = cos_;
+  rotary_params.cos_sin = cos_sin_cache_;
+  rotary_params.position_ids = position_ids;
+  rotary_params.cu_query_lens = cu_query_lens;
+  rotary_params.interleaved = interleaved_;
+  rotary_params.discrete = discrete;
+  rotary_params.max_query_len = max_query_len;
+
+  xllm::kernel::apply_rotary(rotary_params);
 }
 
 }  // namespace layer
diff --git a/xllm/core/platform/device.cpp b/xllm/core/platform/device.cpp
index 975063fa0..6c3763c60 100644
--- a/xllm/core/platform/device.cpp
+++ b/xllm/core/platform/device.cpp
@@ -27,10 +27,10 @@ Device::Device(torch::Device device) : device_(device) {}
 
 Device::operator torch::Device() const { return unwrap(); }
 
 void Device::set_device() const {
+  int ret = 0;
 #if defined(USE_NPU)
-  auto ret = c10_npu::SetDevice(index());
+  ret = c10_npu::SetDevice(index());
 #elif defined(USE_MLU)
-  int ret = 0;
   torch_mlu::setDevice(index());
 #endif