From 0e34c3f68274ea1050e720911758df51087835ba Mon Sep 17 00:00:00 2001
From: Tao Peng
Date: Mon, 3 Nov 2025 18:19:35 +0800
Subject: [PATCH] bugfix: fix the hang issue of offline inference when enable_shm is enabled.

Signed-off-by: Tao Peng
---
 .../core/distributed_runtime/dist_manager.cpp |  3 +-
 xllm/core/distributed_runtime/shm_channel.cpp |  7 ++-
 .../spawn_worker_server.cpp                   |  6 ++-
 .../spawn_worker_server/spawn_worker_server.h |  3 +-
 .../spawn_worker_server_process.cpp           | 14 ++++--
 .../distributed_runtime/worker_server.cpp     | 50 +++++++++++--------
 xllm/core/runtime/dit_engine.cpp              |  2 +
 .../runtime/forward_shared_memory_manager.cpp | 12 +++--
 .../runtime/forward_shared_memory_manager.h   |  5 +-
 xllm/core/runtime/vlm_engine.cpp              |  2 +
 xllm/core/runtime/vlm_master.cpp              |  4 +-
 xllm/pybind/bind.cpp                          |  3 +-
 xllm/pybind/llm.py                            |  2 +
 xllm/pybind/vlm.py                            |  2 +
 14 files changed, 74 insertions(+), 41 deletions(-)

diff --git a/xllm/core/distributed_runtime/dist_manager.cpp b/xllm/core/distributed_runtime/dist_manager.cpp
index 630e91af..7fb78e31 100644
--- a/xllm/core/distributed_runtime/dist_manager.cpp
+++ b/xllm/core/distributed_runtime/dist_manager.cpp
@@ -100,7 +100,7 @@ std::unique_ptr create_channel(const std::string& worker_addrs,
                                const runtime::Options& options) {
   std::unique_ptr channel;
 
-  if (net::extract_ip(FLAGS_master_node_addr) ==
+  if (net::extract_ip(options.master_node_addr().value_or("")) ==
           net::extract_ip(worker_addrs) &&
       options.enable_shm()) {
     // create shared memory manager for local rank
@@ -118,6 +118,7 @@ std::unique_ptr create_channel(const std::string& worker_addrs,
 
   return channel;
 }
+
 } // namespace
 
 void DistManager::setup_multi_node_workers(
diff --git a/xllm/core/distributed_runtime/shm_channel.cpp b/xllm/core/distributed_runtime/shm_channel.cpp
index ddd1d664..0c492e0e 100644
--- a/xllm/core/distributed_runtime/shm_channel.cpp
+++ b/xllm/core/distributed_runtime/shm_channel.cpp
@@ -16,6 +16,7 @@ limitations under the License.
#include "shm_channel.h" #include "common/global_flags.h" +#include "util/net.h" namespace xllm { @@ -26,16 +27,18 @@ ShmChannel::ShmChannel(int dp_group, : enable_shm_(options.enable_shm()) { bool is_creator; + std::string name_prefix = + "xllm_" + net::extract_port(options.master_node_addr().value_or("")); if (is_driver) { auto name = ForwardSharedMemoryManager::create_unique_name( - dp_group, FORWARD_RAW_INPUT_TYPE, rank); + name_prefix, dp_group, FORWARD_RAW_INPUT_TYPE, rank); input_shm_manager_ = std::make_unique( name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE); LOG(INFO) << "Create input shared memory manager with name: " << name; } auto name = ForwardSharedMemoryManager::create_unique_name( - dp_group, FORWARD_RAW_OUTPUT_TYPE, rank); + name_prefix, dp_group, FORWARD_RAW_OUTPUT_TYPE, rank); output_shm_manager_ = std::make_unique( name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE); LOG(INFO) << "Create output shared memory manager with name: " << name; diff --git a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp index 2e313791..c6285ae4 100644 --- a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp +++ b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp @@ -39,7 +39,8 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr, int device_idx, int num_decoding_tokens, int block_size, - bool enable_shm) { + bool enable_shm, + bool is_local) { // TODO: pass whole xllm::runtime::Options here from main process. xllm::runtime::Options runner_options; runner_options.block_size(block_size) @@ -47,7 +48,8 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr, .enable_schedule_overlap(false) .enable_offline_inference(true) .master_node_addr(master_node_addr) - .enable_shm(enable_shm); + .enable_shm(enable_shm) + .is_local(is_local); FLAGS_enable_schedule_overlap = false; FLAGS_master_node_addr = master_node_addr; FLAGS_block_size = block_size; diff --git a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h index c3e9509b..8a9e4e6b 100644 --- a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h +++ b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h @@ -28,7 +28,8 @@ class SpawnWorkerServer final { int device_idx, int num_decoding_tokens, int block_size, - bool enable_shm); + bool enable_shm, + bool is_local); ~SpawnWorkerServer() = default; diff --git a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp index fa0291a7..6728f649 100644 --- a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp +++ b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp @@ -29,10 +29,11 @@ limitations under the License. // @num_decoding_tokens // @block_size // @enable_shm +// @is_local int main(int argc, char* argv[]) { - if (argc < 8) { + if (argc < 9) { LOG(ERROR) - << "Spwan worker process receive wrong args. Need 8 args, receive " + << "Spwan worker process receive wrong args. 
         << argc;
     return 1;
   }
@@ -52,15 +53,17 @@ int main(int argc, char* argv[]) {
   int num_decoding_tokens = atoi(argv[6]);
   int block_size = atoi(argv[7]);
   int enable_shm = atoi(argv[8]);
+  int is_local = atoi(argv[9]);
 
   LOG(INFO) << "Spwan worker: "
             << "master_node_addr = " << master_node_addr
-            << ", local_rank = " << local_rank
+            << ", is_local = " << is_local << ", local_rank = " << local_rank
             << ", world_size = " << world_size
             << ", device_idx = " << device_idx
             << ", num_decoding_tokens = " << num_decoding_tokens
             << ", block_size = " << block_size
-            << ", enable_shm = " << (enable_shm > 0) << "\n";
+            << ", enable_shm = " << (enable_shm > 0)
+            << ", is_local = " << (is_local > 0) << "\n";
 
   xllm::SpawnWorkerServer worker(master_node_addr,
                                  local_rank,
@@ -69,7 +72,8 @@ int main(int argc, char* argv[]) {
                                  device_idx,
                                  num_decoding_tokens,
                                  block_size,
-                                 enable_shm > 0);
+                                 enable_shm > 0,
+                                 is_local > 0);
 
   worker.run();
 
diff --git a/xllm/core/distributed_runtime/worker_server.cpp b/xllm/core/distributed_runtime/worker_server.cpp
index 9b631587..48635964 100644
--- a/xllm/core/distributed_runtime/worker_server.cpp
+++ b/xllm/core/distributed_runtime/worker_server.cpp
@@ -141,6 +141,8 @@ void WorkerServer::create_spawn_server(int local_rank,
   const char* block_size_ptr = block_size_str.c_str();
   auto enable_shm_str = std::to_string(options.enable_shm());
   const char* enable_shm_ptr = enable_shm_str.c_str();
+  auto is_local_str = std::to_string(options.is_local());
+  const char* is_local_ptr = is_local_str.c_str();
   std::string spawn_worker_bin_path =
       options.spawn_worker_path() + "/spawn_worker";
   LOG(INFO) << "Spawn worker path: " << spawn_worker_bin_path;
@@ -153,6 +155,7 @@ void WorkerServer::create_spawn_server(int local_rank,
                         num_decoding_tokens_ptr,
                         block_size_ptr,
                         enable_shm_ptr,
+                        is_local_ptr,
                         nullptr};
   pid_t pid;
   posix_spawn_file_actions_init(&file_actions_);
@@ -181,14 +184,16 @@ void WorkerServer::prepare_shm(
   int dp_local_tp_size = parallel_args.world_size() / parallel_args.dp_size();
   int dp_group = parallel_args.rank() / dp_local_tp_size;
 
+  std::string name_prefix =
+      "xllm_" + net::extract_port(options.master_node_addr().value());
   string name = ForwardSharedMemoryManager::create_unique_name(
-      dp_group, FORWARD_RAW_INPUT_TYPE, parallel_args.rank());
+      name_prefix, dp_group, FORWARD_RAW_INPUT_TYPE, parallel_args.rank());
   input_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
       name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
   LOG(INFO) << "Create input shared memory manager with name: " << name;
 
   name = ForwardSharedMemoryManager::create_unique_name(
-      dp_group, FORWARD_RAW_OUTPUT_TYPE, parallel_args.rank());
+      name_prefix, dp_group, FORWARD_RAW_OUTPUT_TYPE, parallel_args.rank());
   output_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
       name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
   LOG(INFO) << "Create output shared memory manager with name: " << name;
@@ -204,31 +209,34 @@ WorkerServer::WorkerServer(int local_worker_idx,
                            WorkerType worker_type,
                            bool use_spawn_worker) {
   if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM) {
+    // TODO: Refactor this code later.
     if (use_spawn_worker) {
       // start worker in a spawn process(for offline inference worker.)
       create_spawn_server(
           local_worker_idx, master_node_addr, done, parallel_args, d, options);
       return;
-    }
+    } else {
+      std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr;
+      std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr;
+      prepare_shm(
+          parallel_args, options, input_shm_manager, output_shm_manager);
 
-    std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr;
-    std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr;
-    prepare_shm(parallel_args, options, input_shm_manager, output_shm_manager);
-    // start worker in a thread.
-    worker_thread_ =
-        std::make_unique<std::thread>(&WorkerServer::create_server,
-                                      this,
-                                      std::cref(options),
-                                      std::ref(done),
-                                      std::cref(master_node_addr),
-                                      std::cref(d),
-                                      parallel_args.world_size(),
-                                      parallel_args.rank(),
-                                      parallel_args.dp_size(),
-                                      local_worker_idx,
-                                      parallel_args.ep_size(),
-                                      std::move(input_shm_manager),
-                                      std::move(output_shm_manager));
+      // start worker in a thread.
+      worker_thread_ =
+          std::make_unique<std::thread>(&WorkerServer::create_server,
+                                        this,
+                                        std::cref(options),
+                                        std::ref(done),
+                                        std::cref(master_node_addr),
+                                        std::cref(d),
+                                        parallel_args.world_size(),
+                                        parallel_args.rank(),
+                                        parallel_args.dp_size(),
+                                        local_worker_idx,
+                                        parallel_args.ep_size(),
+                                        std::move(input_shm_manager),
+                                        std::move(output_shm_manager));
+    }
   } else {
     // TODO: support other model type later.
     LOG(ERROR) << "Unsupported model type: " << worker_type;
diff --git a/xllm/core/runtime/dit_engine.cpp b/xllm/core/runtime/dit_engine.cpp
index 808eb42e..3d33e2e5 100644
--- a/xllm/core/runtime/dit_engine.cpp
+++ b/xllm/core/runtime/dit_engine.cpp
@@ -42,6 +42,8 @@ DiTEngine::DiTEngine(const runtime::Options& options) : options_(options) {
   }
 
   const int32_t world_size = static_cast<int32_t>(devices.size());
+  CHECK(!options_.enable_shm()) << "DiT cannot support enable_shm currently.";
+
   // create workers
   for (size_t i = 0; i < devices.size(); ++i) {
     const int32_t rank = static_cast<int32_t>(i);
diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp
index cd3ee268..27b7161d 100644
--- a/xllm/core/runtime/forward_shared_memory_manager.cpp
+++ b/xllm/core/runtime/forward_shared_memory_manager.cpp
@@ -655,10 +655,12 @@ ForwardSharedMemoryManager::~ForwardSharedMemoryManager() = default;
 
 /* The shared memory filename may have duplicates when using kill -9 xllm,
 but this doesn't affect usage.*/
-std::string ForwardSharedMemoryManager::create_unique_name(int dp_group,
-                                                           int forward_type,
-                                                           int rank) {
-  std::string filename = "xllm_" + net::extract_port(FLAGS_master_node_addr);
+std::string ForwardSharedMemoryManager::create_unique_name(
+    const std::string& prefix,
+    int dp_group,
+    int forward_type,
+    int rank) {
+  std::string filename = prefix;
   if (forward_type == FORWARD_PB_INPUT_TYPE ||
       forward_type == FORWARD_RAW_INPUT_TYPE) {
     filename += "_dpg_" + std::to_string(dp_group) + "_input";
@@ -997,4 +999,4 @@ void ForwardSharedMemoryManager::raw_output_read(RawForwardOutput& output) {
 void ForwardSharedMemoryManager::clear() {
   std::memset(base_address(), 0, size());
 }
-} // namespace xllm
\ No newline at end of file
+} // namespace xllm
diff --git a/xllm/core/runtime/forward_shared_memory_manager.h b/xllm/core/runtime/forward_shared_memory_manager.h
index f94b5b96..152fea6d 100644
--- a/xllm/core/runtime/forward_shared_memory_manager.h
+++ b/xllm/core/runtime/forward_shared_memory_manager.h
@@ -47,7 +47,8 @@ class ForwardSharedMemoryManager : public SharedMemoryManager {
                              bool& is_creator,
                              ForwardType type);
   ~ForwardSharedMemoryManager();
-  static std::string create_unique_name(int dp_group,
+  static std::string create_unique_name(const std::string& prefix,
+                                        int dp_group,
                                         int forward_type,
                                         int rank);
 
@@ -121,4 +122,4 @@ class ForwardSharedMemoryManager : public SharedMemoryManager {
   void* metadata_addr_ = nullptr;
   ControlMetadata* control_ptr_ = nullptr;
 };
-} // namespace xllm
\ No newline at end of file
+} // namespace xllm
diff --git a/xllm/core/runtime/vlm_engine.cpp b/xllm/core/runtime/vlm_engine.cpp
index 3c7431bc..3a75a180 100644
--- a/xllm/core/runtime/vlm_engine.cpp
+++ b/xllm/core/runtime/vlm_engine.cpp
@@ -50,6 +50,8 @@ VLMEngine::VLMEngine(const runtime::Options& options) : options_(options) {
     process_groups_ = parallel_state::create_npu_process_groups(devices);
   }
 
+  CHECK(!options_.enable_shm()) << "VLM cannot support enable_shm currently.";
+
   WorkerType worker_type =
       (options_.task_type() == "generate") ? WorkerType::VLM : WorkerType::EVLM;
   const int32_t world_size = static_cast<int32_t>(devices.size());
diff --git a/xllm/core/runtime/vlm_master.cpp b/xllm/core/runtime/vlm_master.cpp
index 62c330b8..5146f181 100644
--- a/xllm/core/runtime/vlm_master.cpp
+++ b/xllm/core/runtime/vlm_master.cpp
@@ -289,9 +289,11 @@ void VLMMaster::run() {
 }
 
 void VLMMaster::generate() {
+  DCHECK(!options_.enable_schedule_overlap())
+      << "Mode generate does not support schedule overlap yet.";
   const bool already_running = running_.load(std::memory_order_relaxed);
   if (already_running) {
-    LOG(WARNING) << "VLMMaster is already running.";
+    LOG(WARNING) << "Generate is already running.";
     return;
   }
 
diff --git a/xllm/pybind/bind.cpp b/xllm/pybind/bind.cpp
index 94815fdd..031074d2 100644
--- a/xllm/pybind/bind.cpp
+++ b/xllm/pybind/bind.cpp
@@ -82,7 +82,8 @@ PYBIND11_MODULE(xllm_export, m) {
       .def_readwrite("enable_offline_inference",
                      &Options::enable_offline_inference_)
       .def_readwrite("spawn_worker_path", &Options::spawn_worker_path_)
-      .def_readwrite("enable_shm", &Options::enable_shm_);
+      .def_readwrite("enable_shm", &Options::enable_shm_)
+      .def_readwrite("is_local", &Options::is_local_);
 
   // 2. export LLMMaster
   py::class_<LLMMaster>(m, "LLMMaster")
diff --git a/xllm/pybind/llm.py b/xllm/pybind/llm.py
index 6eb46d16..d20bc83e 100644
--- a/xllm/pybind/llm.py
+++ b/xllm/pybind/llm.py
@@ -44,6 +44,7 @@ def __init__(
         disable_ttft_profiling: bool = False,
         enable_forward_interruption: bool = False,
         enable_shm: bool = False,
+        is_local: bool = True,
         **kwargs,
     ) -> None:
 
@@ -95,6 +96,7 @@ def __init__(
         options.enable_offline_inference = True
         options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
         options.enable_shm = enable_shm
+        options.is_local = is_local
         self.master = LLMMaster(options)
 
     def finish(self):
diff --git a/xllm/pybind/vlm.py b/xllm/pybind/vlm.py
index df8a4a6b..aa5f0628 100644
--- a/xllm/pybind/vlm.py
+++ b/xllm/pybind/vlm.py
@@ -42,6 +42,7 @@ def __init__(
         enable_schedule_overlap: bool = False,
         kv_cache_transfer_mode: str = 'PUSH',
         enable_shm: bool = False,
+        is_local: bool = True,
         **kwargs,
     ) -> None:
 
@@ -90,6 +91,7 @@ def __init__(
         options.enable_offline_inference = True
         options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
         options.enable_shm = enable_shm
+        options.is_local = is_local
         self.master = VLMMaster(options)
 
     def finish(self):
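
Note: below is a minimal offline-inference usage sketch for the Python wrapper touched by this patch. It assumes the package exposes the LLM class defined in xllm/pybind/llm.py; the model path and prompt are placeholders, and the generate() call is assumed rather than confirmed by the diff (only finish() appears above).

    # Hypothetical usage sketch (not part of the patch): pass the new flags to
    # the offline-inference entry point. enable_shm turns on the shared-memory
    # channel; is_local marks the workers as co-located with the master node.
    from xllm import LLM

    llm = LLM(
        model="/path/to/model",  # placeholder checkpoint path
        enable_shm=True,
        is_local=True,
    )
    outputs = llm.generate(["hello"])  # generate() is assumed; see note above
    print(outputs)
    llm.finish()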