From 0ac5cf7125158fa272b3da7d9cd5de60af22ff89 Mon Sep 17 00:00:00 2001 From: "pengtao.156" Date: Wed, 22 Oct 2025 21:28:24 +0800 Subject: [PATCH] feat: support single-node multi-NPUs offline inference. Signed-off-by: pengtao.156 --- examples/generate.py | 3 +- examples/generate_vlm.py | 3 +- setup.py | 3 +- xllm/core/common/options.h | 5 ++ xllm/core/distributed_runtime/CMakeLists.txt | 2 + .../core/distributed_runtime/dist_manager.cpp | 6 +- .../spawn_worker_server/CMakeLists.txt | 24 ++++++ .../spawn_worker_server.cpp | 84 +++++++++++++++++++ .../spawn_worker_server/spawn_worker_server.h | 41 +++++++++ .../spawn_worker_server_process.cpp | 73 ++++++++++++++++ .../distributed_runtime/worker_server.cpp | 68 ++++++++++++++- xllm/core/distributed_runtime/worker_server.h | 22 ++++- .../prefix_cache/prefix_cache_benchmark.cpp | 5 +- xllm/core/runtime/llm_engine.cpp | 3 +- xllm/core/runtime/master.cpp | 12 ++- xllm/core/runtime/options.h | 5 ++ xllm/core/scheduler/continuous_scheduler.cpp | 3 +- xllm/pybind/bind.cpp | 5 +- xllm/pybind/llm.py | 11 ++- xllm/pybind/util.py | 8 ++ xllm/pybind/vlm.py | 7 +- 21 files changed, 368 insertions(+), 25 deletions(-) create mode 100644 xllm/core/distributed_runtime/spawn_worker_server/CMakeLists.txt create mode 100644 xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp create mode 100644 xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h create mode 100644 xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp create mode 100644 xllm/pybind/util.py diff --git a/examples/generate.py b/examples/generate.py index 740ed8ad..387cb0f5 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -1,4 +1,5 @@ -# python examples/generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0' +# python examples/generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0' +# python generate.py --model='/path/models/Qwen2-7B-Instruct' --devices='npu:0,npu:1' from xllm import ArgumentParser, LLM, RequestParams diff --git a/examples/generate_vlm.py b/examples/generate_vlm.py index d8171f33..b67b0f96 100644 --- a/examples/generate_vlm.py +++ b/examples/generate_vlm.py @@ -1,4 +1,5 @@ -# python examples/generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0' --master_node_addr=127.0.0.1:8888 +# python examples/generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0' +# python generate_vlm.py --model='/path/models/Qwen2.5-VL-7B' --devices='npu:0,npu:1' import os import signal diff --git a/setup.py b/setup.py index 906ed04e..62c56345 100644 --- a/setup.py +++ b/setup.py @@ -610,7 +610,8 @@ def apply_patch(): }, zip_safe=False, py_modules=["xllm/launch_xllm", "xllm/__init__", - "xllm/pybind/llm", "xllm/pybind/vlm", "xllm/pybind/args"], + "xllm/pybind/llm", "xllm/pybind/vlm", + "xllm/pybind/util", "xllm/pybind/args"], entry_points={ 'console_scripts': [ 'xllm = xllm.launch_xllm:launch_xllm' diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 0d7f9014..b0533104 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -170,6 +170,11 @@ class Options { PROPERTY(int, max_requests_per_batch) = 0; PROPERTY(bool, enable_continuous_kvcache) = false; + + // for offline inference: start with offline inference, default is false + PROPERTY(bool, enable_offline_inference) = false; + // for offline inference: the path to spawn worker binary + PROPERTY(std::string, spawn_worker_path) = ""; }; } // namespace xllm diff --git a/xllm/core/distributed_runtime/CMakeLists.txt b/xllm/core/distributed_runtime/CMakeLists.txt index f60828b0..4a68308d 100644 --- a/xllm/core/distributed_runtime/CMakeLists.txt +++ b/xllm/core/distributed_runtime/CMakeLists.txt @@ -4,6 +4,8 @@ if(USE_NPU) include_directories( ${CMAKE_SOURCE_DIR}/third_party/spdlog/include ) + + add_subdirectory(spawn_worker_server) endif() cc_library( diff --git a/xllm/core/distributed_runtime/dist_manager.cpp b/xllm/core/distributed_runtime/dist_manager.cpp index b8669686..046048bd 100644 --- a/xllm/core/distributed_runtime/dist_manager.cpp +++ b/xllm/core/distributed_runtime/dist_manager.cpp @@ -141,6 +141,9 @@ void DistManager::setup_multi_node_workers( // Node2: 0+4, 1+4, 2+4, 3+4 const int32_t rank = static_cast(i) + base_rank; + // we use spawn process worker to launch a xllm instance + // when start a offline inference task with multi-gpu/npu/mpu/... + bool use_spawn_worker = options.enable_offline_inference() && i > 0; ParallelArgs parallel_args(rank, world_size, dp_size, nullptr, ep_size); servers_.emplace_back(std::make_unique(i, master_node_addr, @@ -149,7 +152,8 @@ void DistManager::setup_multi_node_workers( parallel_args, devices[i], worker_server_options, - worker_type)); + worker_type, + use_spawn_worker)); } // Master node need to wait all workers done diff --git a/xllm/core/distributed_runtime/spawn_worker_server/CMakeLists.txt b/xllm/core/distributed_runtime/spawn_worker_server/CMakeLists.txt new file mode 100644 index 00000000..7fbae3e5 --- /dev/null +++ b/xllm/core/distributed_runtime/spawn_worker_server/CMakeLists.txt @@ -0,0 +1,24 @@ +include(cc_binary) + +cc_binary( + NAME + spawn_worker + HDRS + spawn_worker_server.h + SRCS + spawn_worker_server.cpp + spawn_worker_server_process.cpp + DEPS + :models + :model + :distributed_runtime + absl::strings + xllm_kernels + ascendcl + nnopbase + atb + c_sec + spdlog::spdlog +) + +add_dependencies(export_module spawn_worker) diff --git a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp new file mode 100644 index 00000000..13dab2e4 --- /dev/null +++ b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.cpp @@ -0,0 +1,84 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "spawn_worker_server.h" + +#include +#if defined(USE_NPU) +#include +#endif +#include +#include + +#include + +#include "core/distributed_runtime/worker_server.h" +#include "core/platform/device.h" +#include "core/runtime/options.h" + +namespace xllm { + +bool xllm::SpawnWorkerServer::g_running_ = true; + +SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr, + int local_rank, + int global_rank, + int world_size, + int device_idx, + int num_decoding_tokens, + int block_size) { + // TODO: pass whole xllm::runtime::Options here from main process. + xllm::runtime::Options runner_options; + runner_options.block_size(block_size) + .num_decoding_tokens(num_decoding_tokens) + .enable_schedule_overlap(false) + .enable_offline_inference(true) + .master_node_addr(master_node_addr); + FLAGS_enable_schedule_overlap = false; + FLAGS_master_node_addr = master_node_addr; + FLAGS_block_size = block_size; + + std::atomic done(false); +#if defined(USE_NPU) + xllm::Device device("npu:" + std::to_string(device_idx)); + device.set_device(); + device.init_device_context(); + FLAGS_enable_atb_comm_multiprocess = true; +#endif + + ParallelArgs parallel_args(global_rank, world_size, 1, nullptr, 1); + WorkerServer worker_server(local_rank, + master_node_addr, + done, + parallel_args, + device, + runner_options, + WorkerType::LLM, + false); +} + +void SpawnWorkerServer::handle_signal(int signum) { g_running_ = false; } + +void SpawnWorkerServer::run() { + signal(SIGINT, SpawnWorkerServer::handle_signal); + signal(SIGTERM, SpawnWorkerServer::handle_signal); + + // main thread waiting here + while (SpawnWorkerServer::g_running_) { + sleep(5); + } +} + +} // namespace xllm diff --git a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h new file mode 100644 index 00000000..e78b403b --- /dev/null +++ b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server.h @@ -0,0 +1,41 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +namespace xllm { + +class SpawnWorkerServer final { + public: + explicit SpawnWorkerServer(const std::string& master_node_addr, + int local_rank, + int global_rank, + int world_size, + int device_idx, + int num_decoding_tokens, + int block_size); + + ~SpawnWorkerServer() = default; + + void run(); + + static void handle_signal(int signum); + + static bool g_running_; +}; + +} // namespace xllm diff --git a/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp new file mode 100644 index 00000000..1e441776 --- /dev/null +++ b/xllm/core/distributed_runtime/spawn_worker_server/spawn_worker_server_process.cpp @@ -0,0 +1,73 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "spawn_worker_server.h" + +// Worker argv from engine process: +// @master_node_addr +// @local_rank +// @global_rank +// @world_size +// @device_idx +// @num_decoding_tokens +// @block_size +int main(int argc, char* argv[]) { + if (argc < 7) { + LOG(ERROR) + << "Spwan worker process receive wrong args. Need 7 args, receive " + << argc; + return 1; + } + + // set PR_SET_PDEATHSIG flag that child should exit + // when parent process exit + if (prctl(PR_SET_PDEATHSIG, SIGHUP) == -1) { + perror("prctl"); + return EXIT_FAILURE; + } + + std::string master_node_addr = std::string(argv[1]); + int local_rank = atoi(argv[2]); + int global_rank = atoi(argv[3]); + int world_size = atoi(argv[4]); + int device_idx = atoi(argv[5]); + int num_decoding_tokens = atoi(argv[6]); + int block_size = atoi(argv[7]); + + LOG(INFO) << "Spwan worker: " + << "master_node_addr = " << master_node_addr + << ", local_rank = " << local_rank + << ", world_size = " << world_size + << ", device_idx = " << device_idx + << ", num_decoding_tokens = " << num_decoding_tokens + << ", block_size = " << block_size << "\n"; + + xllm::SpawnWorkerServer worker(master_node_addr, + local_rank, + global_rank, + world_size, + device_idx, + num_decoding_tokens, + block_size); + + worker.run(); + + return 0; +} diff --git a/xllm/core/distributed_runtime/worker_server.cpp b/xllm/core/distributed_runtime/worker_server.cpp index f1c8348d..8fd14cf2 100644 --- a/xllm/core/distributed_runtime/worker_server.cpp +++ b/xllm/core/distributed_runtime/worker_server.cpp @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -46,6 +47,8 @@ limitations under the License. #include "xllm_kernels/models/base/param/mapping.h" #endif +extern char** environ; + namespace xllm { void WorkerServer::create_server(const runtime::Options& options, @@ -59,6 +62,7 @@ void WorkerServer::create_server(const runtime::Options& options, int32_t ep_size) { Device device(d); device.set_device(); + LOG(INFO) << "Create worker server with device: " << device.index(); auto worker_global_rank = global_rank; // TODO: FIXME Later @@ -109,15 +113,70 @@ void WorkerServer::create_server(const runtime::Options& options, worker_server->run(); } +void WorkerServer::create_spawn_server(int local_rank, + const std::string& master_node_addr, + std::atomic& done, + const ParallelArgs& parallel_args, + const torch::Device& d, + const runtime::Options& options) { + auto local_rank_str0 = std::to_string(local_rank); + const char* local_rank_str = local_rank_str0.c_str(); + auto global_rank_str0 = std::to_string(parallel_args.rank()); + const char* global_rank_str = global_rank_str0.c_str(); + auto world_size_str0 = std::to_string(parallel_args.world_size()); + const char* world_size_str = world_size_str0.c_str(); + auto device_idx_str0 = std::to_string(d.index()); + const char* device_idx_str = device_idx_str0.c_str(); + auto num_decoding_tokens_str0 = std::to_string(options.num_decoding_tokens()); + const char* num_decoding_tokens_str = num_decoding_tokens_str0.c_str(); + auto block_size_str0 = std::to_string(options.block_size()); + const char* block_size_str = block_size_str0.c_str(); + std::string spawn_worker_bin_path = + options.spawn_worker_path() + "/spawn_worker"; + LOG(INFO) << "Spawn worker path: " << spawn_worker_bin_path; + const char* argv[] = {spawn_worker_bin_path.c_str(), + master_node_addr.c_str(), + local_rank_str, + global_rank_str, + world_size_str, + device_idx_str, + num_decoding_tokens_str, + block_size_str, + nullptr}; + pid_t pid; + posix_spawn_file_actions_init(&file_actions_); + posix_spawnattr_init(&spawn_attr_); + int status = posix_spawnp(&pid, + argv[0], + &file_actions_, + &spawn_attr_, + const_cast(argv), + environ); + if (status != 0) { + LOG(ERROR) << "posix_spawnp failed: " << strerror(status); + return; + } + use_spwan_worker_ = true; + done.store(true); +} + WorkerServer::WorkerServer(int local_worker_idx, const std::string& master_node_addr, std::atomic& done, const ParallelArgs& parallel_args, const torch::Device& d, const runtime::Options& options, - WorkerType worker_type) { + WorkerType worker_type, + bool use_spawn_worker) { if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM) { - // TODO: Use Process or thread. + if (use_spawn_worker) { + // start worker in a spawn process(for offline inference worker.) + create_spawn_server( + local_worker_idx, master_node_addr, done, parallel_args, d, options); + return; + } + + // start worker in a thread. worker_thread_ = std::make_unique(&WorkerServer::create_server, this, std::cref(options), @@ -184,6 +243,11 @@ WorkerServer::~WorkerServer() { if (worker_thread_->joinable()) { worker_thread_->join(); } + + if (use_spwan_worker_) { + posix_spawn_file_actions_destroy(&file_actions_); + posix_spawnattr_destroy(&spawn_attr_); + } } } // namespace xllm diff --git a/xllm/core/distributed_runtime/worker_server.h b/xllm/core/distributed_runtime/worker_server.h index 396dd546..a3b8045c 100644 --- a/xllm/core/distributed_runtime/worker_server.h +++ b/xllm/core/distributed_runtime/worker_server.h @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -41,13 +42,11 @@ class WorkerServer { const ParallelArgs& parallel_args, const torch::Device& d, const runtime::Options& options, - WorkerType worker_type); + WorkerType worker_type, + bool use_spawn_worker = false); virtual ~WorkerServer(); - private: - DISALLOW_COPY_AND_ASSIGN(WorkerServer); - void create_server(const runtime::Options& options, std::atomic& done, const std::string& master_node_addr, @@ -58,12 +57,27 @@ class WorkerServer { int local_rank, int32_t ep_size); + private: + DISALLOW_COPY_AND_ASSIGN(WorkerServer); + + void create_spawn_server(int local_rank, + const std::string& master_node_addr, + std::atomic& done, + const ParallelArgs& parallel_args, + const torch::Device& d, + const runtime::Options& options); + bool sync_master_node(const std::string& master_node_addr, proto::AddressInfo& addr_info, proto::CommUniqueIdList& uids); private: std::unique_ptr worker_thread_; + + // for offline inference spawn process worker. + bool use_spwan_worker_ = false; + posix_spawn_file_actions_t file_actions_; + posix_spawnattr_t spawn_attr_; }; } // namespace xllm diff --git a/xllm/core/framework/prefix_cache/prefix_cache_benchmark.cpp b/xllm/core/framework/prefix_cache/prefix_cache_benchmark.cpp index cd367e5b..b22fb51d 100644 --- a/xllm/core/framework/prefix_cache/prefix_cache_benchmark.cpp +++ b/xllm/core/framework/prefix_cache/prefix_cache_benchmark.cpp @@ -35,6 +35,7 @@ static void BM_HashSearch(benchmark::State& state) { // token_id_count; assert((token_id_count / block_size) < total_blocks); + uint32_t n_blocks = token_id_count / block_size; state.PauseTiming(); BlockManager::Options options; @@ -51,12 +52,10 @@ static void BM_HashSearch(benchmark::State& state) { std::generate( token_ids.begin(), token_ids.end(), [&]() { return dist(gen); }); - uint32_t n_blocks = token_id_count / block_size; - std::vector token_blocks = block_manager.allocate(n_blocks); Slice slice_token_blocks(token_blocks); Slice slice_token_ids(token_ids); - std::vector match_token_ids(token_ids); + Slice match_token_ids(token_ids); prefix_cache.insert(slice_token_ids, slice_token_blocks); state.ResumeTiming(); diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp index 98555de3..fae37725 100644 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -74,7 +74,8 @@ LLMEngine::LLMEngine(const runtime::Options& options, CHECK_EQ(device.type(), device_type) << "All devices should be the same type"; #if defined(USE_NPU) - FLAGS_enable_atb_comm_multiprocess = (options.nnodes() > 1); + FLAGS_enable_atb_comm_multiprocess = + options.enable_offline_inference() || (options.nnodes() > 1); #endif } diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp index 61d24add..8b1a4590 100644 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -109,7 +109,9 @@ Master::Master(const Options& options, EngineType type) : options_(options) { .enable_disagg_pd(options_.enable_disagg_pd()) .enable_service_routing(options_.enable_service_routing()) .enable_cache_upload(options_.enable_cache_upload()) - .enable_schedule_overlap(options_.enable_schedule_overlap()); + .enable_schedule_overlap(options_.enable_schedule_overlap()) + .enable_offline_inference(options_.enable_offline_inference()) + .spawn_worker_path(options_.spawn_worker_path()); auto engine = std::make_unique(eng_options); engine_ = std::move(engine); @@ -148,7 +150,9 @@ Master::Master(const Options& options, EngineType type) : options_(options) { .enable_disagg_pd(options_.enable_disagg_pd()) .enable_service_routing(options_.enable_service_routing()) .enable_schedule_overlap(options_.enable_schedule_overlap()) - .enable_cache_upload(options_.enable_cache_upload()); + .enable_cache_upload(options_.enable_cache_upload()) + .enable_offline_inference(options_.enable_offline_inference()) + .spawn_worker_path(options_.spawn_worker_path()); if (options_.device_ip().has_value()) { spec_options.device_ip(options_.device_ip().value()); @@ -192,7 +196,9 @@ Master::Master(const Options& options, EngineType type) : options_(options) { .store_protocol(options_.store_protocol()) .store_master_server_entry(options_.store_master_server_entry()) .store_metadata_connstring(options_.store_metadata_connstring()) - .enable_continuous_kvcache(options_.enable_continuous_kvcache()); + .enable_continuous_kvcache(options_.enable_continuous_kvcache()) + .enable_offline_inference(options_.enable_offline_inference()) + .spawn_worker_path(options_.spawn_worker_path()); if (options_.device_ip().has_value()) { eng_options.device_ip(options_.device_ip().value()); diff --git a/xllm/core/runtime/options.h b/xllm/core/runtime/options.h index 6e96dbb5..9d2c03ec 100644 --- a/xllm/core/runtime/options.h +++ b/xllm/core/runtime/options.h @@ -157,6 +157,11 @@ struct Options { // enable continuous kvcache PROPERTY(bool, enable_continuous_kvcache) = false; + + // start with offline inference, default is false + PROPERTY(bool, enable_offline_inference) = false; + // the path to spawn worker binary + PROPERTY(std::string, spawn_worker_path) = ""; }; } // namespace runtime diff --git a/xllm/core/scheduler/continuous_scheduler.cpp b/xllm/core/scheduler/continuous_scheduler.cpp index d8b5ad00..b1587bb2 100644 --- a/xllm/core/scheduler/continuous_scheduler.cpp +++ b/xllm/core/scheduler/continuous_scheduler.cpp @@ -957,7 +957,8 @@ void ContinuousScheduler::step_with_schedule_overlap( void ContinuousScheduler::generate() { bool batch_empty = false; - while (num_pending_requests() > 0 || !batch_empty) { + while (num_pending_requests() > 0 || !batch_empty || + request_queue_.size() > 0) { // build a batch of requests/sequences auto batch = prepare_batch(); batch_empty = true; diff --git a/xllm/pybind/bind.cpp b/xllm/pybind/bind.cpp index e0cf014c..3191a3d7 100644 --- a/xllm/pybind/bind.cpp +++ b/xllm/pybind/bind.cpp @@ -77,7 +77,10 @@ PYBIND11_MODULE(xllm_export, m) { .def_readwrite("disable_ttft_profiling", &Options::disable_ttft_profiling_) .def_readwrite("enable_forward_interruption", - &Options::enable_forward_interruption_); + &Options::enable_forward_interruption_) + .def_readwrite("enable_offline_inference", + &Options::enable_offline_inference_) + .def_readwrite("spawn_worker_path", &Options::spawn_worker_path_); // 2. export LLMMaster py::class_(m, "LLMMaster") diff --git a/xllm/pybind/llm.py b/xllm/pybind/llm.py index 3da8cfaf..fd481eb0 100644 --- a/xllm/pybind/llm.py +++ b/xllm/pybind/llm.py @@ -1,5 +1,6 @@ import os import signal +from . import util from typing import List, Optional, Union from xllm_export import (LLMMaster, Options, RequestOutput, @@ -17,9 +18,9 @@ def __init__( max_cache_size: int = 0, max_memory_utilization: float = 0.9, disable_prefix_cache: bool = False, - max_tokens_per_batch: int = 20000, + max_tokens_per_batch: int = 20480, max_seqs_per_batch: int = 256, - max_tokens_per_chunk_for_prefill: int = 512, + max_tokens_per_chunk_for_prefill: int = -1, num_speculative_tokens: int = 0, num_handling_threads: int = 4, communication_backend: str = 'lccl', @@ -27,7 +28,6 @@ def __init__( expert_parallel_degree: int = 0, enable_mla: bool = False, disable_chunked_prefill: bool = False, - master_node_addr: str = '127.0.0.1:9988', instance_role: str = 'DEFAULT', device_ip: str = '', transfer_listen_port: int = 26000, @@ -75,7 +75,8 @@ def __init__( options.enable_chunked_prefill = False else: options.enable_chunked_prefill = True - options.master_node_addr = master_node_addr + free_port = util.get_free_port() + options.master_node_addr = "127.0.0.1:" + str(free_port) options.device_ip = device_ip options.transfer_listen_port = transfer_listen_port options.nnodes = nnodes @@ -90,6 +91,8 @@ def __init__( options.kv_cache_transfer_mode = kv_cache_transfer_mode options.disable_ttft_profiling = disable_ttft_profiling options.enable_forward_interruption = enable_forward_interruption + options.enable_offline_inference = True + options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) self.master = LLMMaster(options) def finish(self): diff --git a/xllm/pybind/util.py b/xllm/pybind/util.py new file mode 100644 index 00000000..36c29f2d --- /dev/null +++ b/xllm/pybind/util.py @@ -0,0 +1,8 @@ +import socket + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('0.0.0.0', 0)) + _, port = s.getsockname() + return port + diff --git a/xllm/pybind/vlm.py b/xllm/pybind/vlm.py index 012b7b13..2f88f899 100644 --- a/xllm/pybind/vlm.py +++ b/xllm/pybind/vlm.py @@ -1,6 +1,7 @@ import os import signal import time +from . import util from typing import List, Optional, Union from xllm_export import (VLMMaster, Options, RequestOutput, @@ -28,7 +29,6 @@ def __init__( expert_parallel_degree: int = 0, enable_mla: bool = False, disable_chunked_prefill: bool = False, - master_node_addr: str = '127.0.0.1:9988', instance_role: str = 'DEFAULT', device_ip: str = '', transfer_listen_port: int = 26000, @@ -73,7 +73,8 @@ def __init__( options.enable_chunked_prefill = False else: options.enable_chunked_prefill = True - options.master_node_addr = master_node_addr + free_port = util.get_free_port() + options.master_node_addr = "127.0.0.1:" + str(free_port) options.device_ip = device_ip options.transfer_listen_port = transfer_listen_port options.nnodes = nnodes @@ -85,6 +86,8 @@ def __init__( options.enable_disagg_pd = enable_disagg_pd options.enable_schedule_overlap = False options.kv_cache_transfer_mode = kv_cache_transfer_mode + options.enable_offline_inference = True + options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) self.master = VLMMaster(options) def finish(self):