From ce55f01a17212eda84b0ef43ded635bd12a9852b Mon Sep 17 00:00:00 2001
From: dengyingxu
Date: Sun, 30 Nov 2025 00:22:08 +0800
Subject: [PATCH 1/2] feat: add NPU process group initialization and management.

---
 CMakeLists.txt                                     |   1 +
 xllm/CMakeLists.txt                                |   2 +-
 xllm/core/common/CMakeLists.txt                    |   1 +
 xllm/core/common/global_flags.cpp                  |   7 +-
 xllm/core/common/global_flags.h                    |   2 +
 .../distributed_runtime/worker_server.cpp          |   2 -
 .../collective_communicator.cpp                    |  30 ++--
 .../parallel_state/npu_process_group.cpp           | 150 +++++++-----------
 .../parallel_state/npu_process_group.h             |  24 ++-
 .../framework/parallel_state/process_group.h       |  14 ++
 10 files changed, 113 insertions(+), 120 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f18bca2ff..052cc3819 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -344,6 +344,7 @@ if(USE_NPU)
     $ENV{PYTORCH_INSTALL_PATH}/include
     $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
     $ENV{PYTORCH_NPU_INSTALL_PATH}/include
+    $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed
     $ENV{NPU_HOME_PATH}/include
     $ENV{ATB_HOME_PATH}/include
     $ENV{NPU_HOME_PATH}/opp/vendors/xllm/op_api/include/
diff --git a/xllm/CMakeLists.txt b/xllm/CMakeLists.txt
index b31f3f239..0c86f08c6 100644
--- a/xllm/CMakeLists.txt
+++ b/xllm/CMakeLists.txt
@@ -34,7 +34,7 @@ target_link_libraries(xllm PRIVATE glog::glog brpc leveldb::leveldb ZLIB::ZLIB p
 add_dependencies(xllm brpc-static)
 
 if(USE_NPU)
-  set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext)
+  set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext torch_npu torch_python)
 elseif(USE_MLU)
   set(COMMON_LIBS Python::Python)
 endif()
diff --git a/xllm/core/common/CMakeLists.txt b/xllm/core/common/CMakeLists.txt
index 2b1fd8dea..b765a22f9 100644
--- a/xllm/core/common/CMakeLists.txt
+++ b/xllm/core/common/CMakeLists.txt
@@ -29,6 +29,7 @@ cc_library(
     absl::random_random
     absl::strings
     torch
+    $<$:torch_python>
     $<$:torch_npu>
     $<$:mspti>
     $<$:ms_tools_ext>
diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp
index db5a38b22..83fb5bb4b 100644
--- a/xllm/core/common/global_flags.cpp
+++ b/xllm/core/common/global_flags.cpp
@@ -430,4 +430,9 @@ DEFINE_bool(
     enable_dp_balance,
     false,
     "Whether to enable dp load balance, if true, sequences within a single "
-    "dp batch will be shuffled.");
\ No newline at end of file
+    "dp batch will be shuffled.");
+
+DEFINE_string(
+    npu_kernel_backend,
+    "ATB",
+    "NPU kernel backend. Supported options: ATB, TORCH. Default is ATB.");
diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h
index f3e40f7d1..003d72f84 100644
--- a/xllm/core/common/global_flags.h
+++ b/xllm/core/common/global_flags.h
@@ -214,3 +214,5 @@ DECLARE_bool(enable_prefetch_weight);
 DECLARE_int32(flashinfer_workspace_buffer_size);
 
 DECLARE_bool(enable_dp_balance);
+
+DECLARE_string(npu_kernel_backend);
\ No newline at end of file
diff --git a/xllm/core/distributed_runtime/worker_server.cpp b/xllm/core/distributed_runtime/worker_server.cpp
index 0177d779f..7f2ed9cb0 100644
--- a/xllm/core/distributed_runtime/worker_server.cpp
+++ b/xllm/core/distributed_runtime/worker_server.cpp
@@ -104,9 +104,7 @@ void WorkerServer::create_server(
   CollectiveCommunicator comm(
       worker_global_rank, world_size, dp_size, ep_size);
   const ParallelArgs* parallel_args = comm.parallel_args();
-#if defined(USE_MLU) || defined(USE_CUDA)
   comm.create_process_groups(master_node_addr, device);
-#endif
 
   std::unique_ptr worker = std::make_unique(
       *parallel_args, device, options, worker_type);
diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp
index c0066be0a..b03d59add 100644
--- a/xllm/core/framework/parallel_state/collective_communicator.cpp
+++ b/xllm/core/framework/parallel_state/collective_communicator.cpp
@@ -18,6 +18,7 @@ limitations under the License.
 #include "mapping_npu.h"
 
 #if defined(USE_NPU)
+#include "npu_process_group.h"
 #include "xllm_kernels/core/include/atb_speed/base/external_comm_manager.h"
 #include "xllm_kernels/core/include/atb_speed/utils/singleton.h"
 #include "xllm_kernels/models/base/param/mapping.h"
@@ -30,23 +31,6 @@ limitations under the License.
 #include "parallel_args.h"
 #include "util/net.h"
 
-namespace {
-#if defined(USE_NPU)
-std::unique_ptr<ProcessGroup> create_process_group(
-    int rank,
-    int world_size,
-    int rank_size,
-    int port,
-    bool trans,
-    const std::string& host,
-    const std::string& group_name,
-    const torch::Device& device) {
-  LOG(FATAL) << "Unsupported device type";
-  return nullptr;
-}
-#endif
-}  // namespace
-
 namespace xllm {
 
 CollectiveCommunicator::CollectiveCommunicator(int global_rank,
@@ -72,6 +56,13 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
   //     std::make_unique(
   //         global_rank, world_size, device, comm);
 
+  // communicator will be initialized in torch.
+  if (FLAGS_npu_kernel_backend == "TORCH") {
+    parallel_args_ = std::make_unique<ParallelArgs>(
+        global_rank, world_size, dp_size, nullptr, ep_size);
+    return;
+  }
+
   // comunicator will be inited in atb.
   MappingNPU::Options mapping_options;
   mapping_options.dp_size(dp_size)
@@ -116,6 +107,11 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
 void CollectiveCommunicator::create_process_groups(
     const std::string& master_addr,
     const torch::Device& device) {
+#if defined(USE_NPU)
+  if (FLAGS_npu_kernel_backend == "ATB") {
+    return;
+  }
+#endif
   std::string host;
   int port;
   net::parse_host_port_from_addr(master_addr, host, port);
diff --git a/xllm/core/framework/parallel_state/npu_process_group.cpp b/xllm/core/framework/parallel_state/npu_process_group.cpp
index fceaa9d00..b401c4371 100644
--- a/xllm/core/framework/parallel_state/npu_process_group.cpp
+++ b/xllm/core/framework/parallel_state/npu_process_group.cpp
@@ -14,6 +14,16 @@ limitations under the License.
 ==============================================================================*/
 #include "npu_process_group.h"
+#ifdef TORCH_HIGHER_THAN_PTA6
+#include 
+#else
+#include 
+#include 
+#endif
+
+#include 
+#include 
+#include 
 
 namespace {
 
@@ -24,113 +34,65 @@ namespace {
       LOG(FATAL) << "Failed, HCCL error :" << HcclGetErrorString(r); \
     }                                                                \
   } while (0)
+}  // namespace
 
-inline bool is_npu(const at::Tensor& tensor) {
-  if (!tensor.defined()) {
-    return false;
-  }
-  return tensor.device().is_privateuseone();
-}
-
-inline bool is_npu(const at::TensorOptions& options) {
-  return options.device().is_privateuseone();
-}
+namespace xllm {
 
-inline bool is_npu(const at::Device& device) {
-  return device.is_privateuseone();
-}
+ProcessGroupHCCL::ProcessGroupHCCL(int global_rank,
+                                   int world_size,
+                                   int rank_size,
+                                   int port,
+                                   bool trans,
+                                   const std::string& host,
+                                   const std::string& group_name,
+                                   const torch::Device& device)
+    : ProcessGroup(device) {
+  c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL::Options> hccl_pg_options =
+      c10d_npu::ProcessGroupHCCL::Options::create();
+  // hccl_pg_options->group_name = group_name;
+  int rank = global_rank;
+  if (world_size != rank_size) {
+    auto [local_rank, group_ranks] =
+        get_group_rank(world_size, global_rank, rank_size, trans);
+    std::vector<uint32_t> uint32_ranks;
+    for (auto rank : group_ranks) {
+      uint32_ranks.push_back(static_cast<uint32_t>(rank));
+    }
+    hccl_pg_options->global_ranks_in_group = uint32_ranks;
+    rank = local_rank;
+  }
 
-at::Tensor flatten_for_scatter_gather(std::vector<at::Tensor>& tensors) {
-  auto& t = tensors[0];
-  std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
-  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
-  return at::empty(sizes, t.options());
+  auto store = create_tcp_store(host, port, rank);
+  pg_ = std::make_unique<c10d_npu::ProcessGroupHCCL>(
+      store, rank, rank_size, hccl_pg_options);
 }
 
-HcclDataType to_hccl_data_type(const torch::Tensor& input) {
-  const auto type = input.scalar_type();
-  switch (type) {
-    case at::kFloat:
-      return HCCL_DATA_TYPE_FP32;
-    case at::kHalf:
-      return HCCL_DATA_TYPE_FP16;
-    case at::kDouble:
-      return HCCL_DATA_TYPE_FP64;
-    case at::kLong:
-      return HCCL_DATA_TYPE_INT64;
-    case at::kInt:
-      return HCCL_DATA_TYPE_INT32;
-    case at::kChar:
-      return HCCL_DATA_TYPE_INT8;
-    case at::kByte:
-      return HCCL_DATA_TYPE_UINT8;
-    case at::kBool:
-      return HCCL_DATA_TYPE_UINT8;
-    case at::kBFloat16:
-      return HCCL_DATA_TYPE_BFP16;
-    default:
-      LOG(FATAL) << "Unconvertible HCCL type: " << type;
+// Destructor.
+ProcessGroupHCCL::~ProcessGroupHCCL() {
+  if (pg_) {
+    pg_->shutdown();
+  } else {
+    HCCLCHECK(HcclCommDestroy(comm_));
   }
 }
 
-void check_input(torch::Tensor input) {
-  CHECK(is_npu(input)) << "input should be npu tensor";
-  CHECK(input.is_contiguous()) << "input should be contiguous";
-  CHECK(!input.is_sparse()) << "input have to be npu dense tensor";
-}
-
-}  // namespace
-
-namespace xllm {
-
 ProcessGroupHCCL::ProcessGroupHCCL(int rank,
                                    int world_size,
                                    const torch::Device& device,
                                    HcclComm comm)
     : ProcessGroup(device), comm_(comm) {}
-// Destructor.
-ProcessGroupHCCL::~ProcessGroupHCCL() { HCCLCHECK(HcclCommDestroy(comm_)); }
 
-void ProcessGroupHCCL::allreduce(torch::Tensor& input) {
-  DCHECK(input.device() == device())
-      << "input should be on the same device as the process group";
-  check_input(input);
-  // inplace all reduce
-  // const auto count = input.numel();
-  // const auto data_type = to_hccl_data_type(input);
-  // auto stream = c10_npu::getCurrentNPUStream();
-  // torch::DeviceGuard device_guard(device());
-  // HCCLCHECK(HcclAllReduce(
-  //     /*sendbuff=*/input.data_ptr(),
-  //     /*recvbuff=*/input.data_ptr(),
-  //     /*count=*/count,
-  //     /*datatype=*/data_type,
-  //     /*op=*/HCCL_REDUCE_SUM,
-  //     /*comm=*/comm_,
-  //     /*stream=*/stream));
-}
 
-void ProcessGroupHCCL::allgather(const torch::Tensor& input,
-                                 std::vector<torch::Tensor>& outputs) {
-  check_input(input);
-  // CHECK(outputs.size() == world_size())
-  //     << "outputs should have the same size as world_size";
-  // DCHECK(input.device() == device())
-  //     << "input should be on the same device as the process group";
-  // torch::DeviceGuard device_guard(device());
-  // torch::Tensor flattened_output = flatten_for_scatter_gather(outputs);
-  // const auto count = input.numel();
-  // const auto data_type = to_hccl_data_type(input);
-  // auto stream = c10_npu::getCurrentNPUStream();
-  // HCCLCHECK(HcclAllGather(
-  //     /*sendbuff=*/input.data_ptr(),
-  //     /*recvbuff=*/flattened_output.data_ptr(),
-  //     /*sendcount=*/count,
-  //     /*datatype=*/data_type,
-  //     /*comm=*/comm_,
-  //     /*stream=*/stream));
-  // // copy the flattened output tensors to the outputs.
-  // for (int i = 0; i < outputs.size(); ++i) {
-  //   outputs[i].copy_(flattened_output[i], /*non_blocking=*/true);
-  // }
+std::unique_ptr<ProcessGroup> create_process_group(
+    int rank,
+    int world_size,
+    int rank_size,
+    int port,
+    bool trans,
+    const std::string& host,
+    const std::string& group_name,
+    const torch::Device& device) {
+  return std::make_unique<ProcessGroupHCCL>(
+      rank, world_size, rank_size, port, trans, host, group_name, device);
 }
 
 }  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/framework/parallel_state/npu_process_group.h b/xllm/core/framework/parallel_state/npu_process_group.h
index 7ca7d23b7..1000d9f72 100644
--- a/xllm/core/framework/parallel_state/npu_process_group.h
+++ b/xllm/core/framework/parallel_state/npu_process_group.h
@@ -28,16 +28,30 @@ class ProcessGroupHCCL : public ProcessGroup {
                    const torch::Device& device,
                    HcclComm comm);
 
+  ProcessGroupHCCL(int rank,
+                   int world_size,
+                   int rank_size,
+                   int port,
+                   bool trans,
+                   const std::string& host,
+                   const std::string& group_name,
+                   const torch::Device& device);
+
   // Destructor.
   ~ProcessGroupHCCL() override;
 
-  void allreduce(torch::Tensor& input) override;
-
-  void allgather(const torch::Tensor& input,
-                 std::vector<torch::Tensor>& outputs) override;
-
  private:
  HcclComm comm_ = nullptr;
 };
 
+std::unique_ptr<ProcessGroup> create_process_group(
+    int rank,
+    int world_size,
+    int rank_size,
+    int port,
+    bool trans,
+    const std::string& host,
+    const std::string& group_name,
+    const torch::Device& device);
+
 }  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/framework/parallel_state/process_group.h b/xllm/core/framework/parallel_state/process_group.h
index ba1d67a9e..5d7ccaa0a 100644
--- a/xllm/core/framework/parallel_state/process_group.h
+++ b/xllm/core/framework/parallel_state/process_group.h
@@ -19,6 +19,11 @@ limitations under the License.
 
 #include 
 #include 
+
+#if defined(USE_NPU)
+#include 
+#endif
+
 namespace xllm {
 std::pair> get_group_rank(int world_size,
                           int global_rank,
@@ -60,7 +65,16 @@ class ProcessGroup {
   torch::Device device_;
 
  protected:
+#if defined(USE_NPU) && \
+    (TORCH_VERSION_MAJOR < 2 || \
+     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 7))
+  // Using ProcessGroupHCCL for NPU devices
+  // Note: torch_npu uses an older torch version where c10d::Backend lacks
+  // shutdown() method
+  std::unique_ptr<c10d_npu::ProcessGroupHCCL> pg_{nullptr};
+#else
   std::unique_ptr<c10d::Backend> pg_{nullptr};
+#endif
 };
 
 }  // namespace xllm
\ No newline at end of file

From 3b6f3d10b4643e15ec529bee4fae382700ce1bc9 Mon Sep 17 00:00:00 2001
From: dengyingxu1
Date: Mon, 1 Dec 2025 17:32:29 +0800
Subject: [PATCH 2/2] refactor: cleanup headers and optimize branch checks.

---
 xllm/core/common/CMakeLists.txt                     |  1 -
 xllm/core/common/global_flags.cpp                   |  2 ++
 xllm/core/common/global_flags.h                     |  4 +++-
 .../parallel_state/collective_communicator.cpp      |  1 -
 .../framework/parallel_state/cuda_process_group.h   |  3 ++-
 .../framework/parallel_state/npu_process_group.cpp  | 11 ++++-------
 xllm/core/layers/common/tests/tests_utils.h         |  3 ++-
 7 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/xllm/core/common/CMakeLists.txt b/xllm/core/common/CMakeLists.txt
index b765a22f9..2b1fd8dea 100644
--- a/xllm/core/common/CMakeLists.txt
+++ b/xllm/core/common/CMakeLists.txt
@@ -29,7 +29,6 @@ cc_library(
     absl::random_random
     absl::strings
     torch
-    $<$:torch_python>
     $<$:torch_npu>
     $<$:mspti>
     $<$:ms_tools_ext>
diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp
index 83fb5bb4b..ed30a2584 100644
--- a/xllm/core/common/global_flags.cpp
+++ b/xllm/core/common/global_flags.cpp
@@ -432,7 +432,9 @@ DEFINE_bool(
     "Whether to enable dp load balance, if true, sequences within a single "
     "dp batch will be shuffled.");
 
+#if defined(USE_NPU)
 DEFINE_string(
     npu_kernel_backend,
     "ATB",
     "NPU kernel backend. Supported options: ATB, TORCH. Default is ATB.");
+#endif
diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h
index 003d72f84..8f4026df9 100644
--- a/xllm/core/common/global_flags.h
+++ b/xllm/core/common/global_flags.h
@@ -215,4 +215,6 @@ DECLARE_int32(flashinfer_workspace_buffer_size);
 
 DECLARE_bool(enable_dp_balance);
 
-DECLARE_string(npu_kernel_backend);
\ No newline at end of file
+#if defined(USE_NPU)
+DECLARE_string(npu_kernel_backend);
+#endif
diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp
index b03d59add..a53b777d3 100644
--- a/xllm/core/framework/parallel_state/collective_communicator.cpp
+++ b/xllm/core/framework/parallel_state/collective_communicator.cpp
@@ -21,7 +21,6 @@ limitations under the License.
#include "npu_process_group.h" #include "xllm_kernels/core/include/atb_speed/base/external_comm_manager.h" #include "xllm_kernels/core/include/atb_speed/utils/singleton.h" -#include "xllm_kernels/models/base/param/mapping.h" #elif defined(USE_MLU) #include "mlu_process_group.h" #elif defined(USE_CUDA) diff --git a/xllm/core/framework/parallel_state/cuda_process_group.h b/xllm/core/framework/parallel_state/cuda_process_group.h index 349cf0083..b6d7b1161 100644 --- a/xllm/core/framework/parallel_state/cuda_process_group.h +++ b/xllm/core/framework/parallel_state/cuda_process_group.h @@ -34,7 +34,8 @@ class ProcessGroupNccl : public ProcessGroup { : ProcessGroup(device) { c10::intrusive_ptr pg_options = c10d::ProcessGroupNCCL::Options::create(); -#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 7 +#if TORCH_VERSION_MAJOR > 2 || \ + (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7) pg_options->group_name = group_name; #endif int rank = global_rank; diff --git a/xllm/core/framework/parallel_state/npu_process_group.cpp b/xllm/core/framework/parallel_state/npu_process_group.cpp index b401c4371..e941e5dc5 100644 --- a/xllm/core/framework/parallel_state/npu_process_group.cpp +++ b/xllm/core/framework/parallel_state/npu_process_group.cpp @@ -14,12 +14,6 @@ limitations under the License. ==============================================================================*/ #include "npu_process_group.h" -#ifdef TORCH_HIGHER_THAN_PTA6 -#include -#else -#include -#include -#endif #include #include @@ -49,7 +43,10 @@ ProcessGroupHCCL::ProcessGroupHCCL(int global_rank, : ProcessGroup(device) { c10::intrusive_ptr hccl_pg_options = c10d_npu::ProcessGroupHCCL::Options::create(); - // hccl_pg_options->group_name = group_name; +#if TORCH_VERSION_MAJOR > 2 || \ + (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7) + hccl_pg_options->group_name = group_name; +#endif int rank = global_rank; if (world_size != rank_size) { auto [local_rank, group_ranks] = diff --git a/xllm/core/layers/common/tests/tests_utils.h b/xllm/core/layers/common/tests/tests_utils.h index 8fdf56f5b..923d48937 100644 --- a/xllm/core/layers/common/tests/tests_utils.h +++ b/xllm/core/layers/common/tests/tests_utils.h @@ -125,7 +125,8 @@ class MockBackend : public c10d::Backend { int64_t getSize() const { return world_size_; } -#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 7 +#if TORCH_VERSION_MAJOR > 2 || \ + (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7) void shutdown() override { // Mock implementation - do nothing }