3 changes: 2 additions & 1 deletion xllm/core/distributed_runtime/dist_manager.cpp
@@ -100,7 +100,7 @@ std::unique_ptr<CommChannel> create_channel(const std::string& worker_addrs,
const runtime::Options& options) {
std::unique_ptr<CommChannel> channel;

if (net::extract_ip(FLAGS_master_node_addr) ==
if (net::extract_ip(options.master_node_addr().value_or("")) ==
net::extract_ip(worker_addrs) &&
options.enable_shm()) {
// create shared memory manager for local rank
@@ -118,6 +118,7 @@ std::unique_ptr<CommChannel> create_channel(const std::string& worker_addrs,

return channel;
}

} // namespace

void DistManager::setup_multi_node_workers(
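For context, a minimal sketch of the selection rule this hunk arrives at, assuming net::extract_ip returns the host part of an "ip:port" string; the master address is now read from the runtime options rather than from FLAGS_master_node_addr:

// Sketch only; mirrors the condition above.
const bool same_host =
    net::extract_ip(options.master_node_addr().value_or("")) ==
    net::extract_ip(worker_addrs);
if (same_host && options.enable_shm()) {
  // worker is co-located with the master: use the shared-memory channel
} else {
  // remote worker or shm disabled: fall back to the regular channel
}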
7 changes: 5 additions & 2 deletions xllm/core/distributed_runtime/shm_channel.cpp
@@ -16,6 +16,7 @@ limitations under the License.
#include "shm_channel.h"

#include "common/global_flags.h"
#include "util/net.h"

namespace xllm {

@@ -26,16 +27,18 @@ ShmChannel::ShmChannel(int dp_group,
: enable_shm_(options.enable_shm()) {
bool is_creator;

std::string name_prefix =
"xllm_" + net::extract_port(options.master_node_addr().value_or(""));
if (is_driver) {
auto name = ForwardSharedMemoryManager::create_unique_name(
dp_group, FORWARD_RAW_INPUT_TYPE, rank);
name_prefix, dp_group, FORWARD_RAW_INPUT_TYPE, rank);
input_shm_manager_ = std::make_unique<ForwardSharedMemoryManager>(
name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
LOG(INFO) << "Create input shared memory manager with name: " << name;
}

auto name = ForwardSharedMemoryManager::create_unique_name(
dp_group, FORWARD_RAW_OUTPUT_TYPE, rank);
name_prefix, dp_group, FORWARD_RAW_OUTPUT_TYPE, rank);
output_shm_manager_ = std::make_unique<ForwardSharedMemoryManager>(
name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
LOG(INFO) << "Create output shared memory manager with name: " << name;
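For illustration only, here is how the new prefix behaves for a hypothetical master address; the address and port below are made up, and std::string::substr stands in for net::extract_port:

#include <string>

std::string master_node_addr = "10.1.2.3:18888";  // hypothetical
std::string port = master_node_addr.substr(master_node_addr.rfind(':') + 1);
std::string name_prefix = "xllm_" + port;         // "xllm_18888"
// Two xllm instances on one host with different master ports now produce
// distinct shared-memory names instead of sharing the fixed "xllm_" prefix.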
@@ -39,15 +39,17 @@ SpawnWorkerServer::SpawnWorkerServer(const std::string& master_node_addr,
int device_idx,
int num_decoding_tokens,
int block_size,
bool enable_shm) {
bool enable_shm,
bool is_local) {
// TODO: pass whole xllm::runtime::Options here from main process.
xllm::runtime::Options runner_options;
runner_options.block_size(block_size)
.num_decoding_tokens(num_decoding_tokens)
.enable_schedule_overlap(false)
.enable_offline_inference(true)
.master_node_addr(master_node_addr)
.enable_shm(enable_shm);
.enable_shm(enable_shm)
.is_local(is_local);
FLAGS_enable_schedule_overlap = false;
FLAGS_master_node_addr = master_node_addr;
FLAGS_block_size = block_size;
@@ -28,7 +28,8 @@ class SpawnWorkerServer final {
int device_idx,
int num_decoding_tokens,
int block_size,
bool enable_shm);
bool enable_shm,
bool is_local);

~SpawnWorkerServer() = default;

@@ -29,10 +29,11 @@ limitations under the License.
// @num_decoding_tokens
// @block_size
// @enable_shm
// @is_local
int main(int argc, char* argv[]) {
if (argc < 8) {
if (argc < 10) {
LOG(ERROR)
<< "Spwan worker process receive wrong args. Need 8 args, receive "
<< "Spwan worker process receive wrong args. Need 9 args, receive "
<< argc;
return 1;
}
@@ -52,15 +53,17 @@ int main(int argc, char* argv[]) {
int num_decoding_tokens = atoi(argv[6]);
int block_size = atoi(argv[7]);
int enable_shm = atoi(argv[8]);
int is_local = atoi(argv[9]);

LOG(INFO) << "Spwan worker: "
<< "master_node_addr = " << master_node_addr
<< ", local_rank = " << local_rank
<< ", is_local = " << is_local << ", local_rank = " << local_rank
<< ", world_size = " << world_size
<< ", device_idx = " << device_idx
<< ", num_decoding_tokens = " << num_decoding_tokens
<< ", block_size = " << block_size
<< ", enable_shm = " << (enable_shm > 0) << "\n";
<< ", enable_shm = " << (enable_shm > 0)
<< ", enable_shm = " << (is_local > 0) << "\n";

xllm::SpawnWorkerServer worker(master_node_addr,
local_rank,
@@ -69,7 +72,8 @@
device_idx,
num_decoding_tokens,
block_size,
enable_shm > 0);
enable_shm > 0,
is_local > 0);

worker.run();

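One detail worth spelling out, since the bound above depends on it: the child now reads argv[9], so the launcher must pass nine arguments after the binary name, giving argc >= 10. A minimal sketch of the tail of that contract; earlier indices are not shown in this hunk and are left out here rather than guessed:

// argv[6] num_decoding_tokens, argv[7] block_size,
// argv[8] enable_shm (0/1),    argv[9] is_local (0/1, new in this change)
bool enable_shm_flag = atoi(argv[8]) > 0;
bool is_local_flag = atoi(argv[9]) > 0;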
50 changes: 29 additions & 21 deletions xllm/core/distributed_runtime/worker_server.cpp
@@ -141,6 +141,8 @@ void WorkerServer::create_spawn_server(int local_rank,
const char* block_size_ptr = block_size_str.c_str();
auto enable_shm_str = std::to_string(options.enable_shm());
const char* enable_shm_ptr = enable_shm_str.c_str();
auto is_local_str = std::to_string(options.is_local());
const char* is_local_ptr = is_local_str.c_str();
std::string spawn_worker_bin_path =
options.spawn_worker_path() + "/spawn_worker";
LOG(INFO) << "Spawn worker path: " << spawn_worker_bin_path;
@@ -153,6 +155,7 @@
num_decoding_tokens_ptr,
block_size_ptr,
enable_shm_ptr,
is_local_ptr,
nullptr};
pid_t pid;
posix_spawn_file_actions_init(&file_actions_);
@@ -181,14 +184,16 @@ void WorkerServer::prepare_shm(
int dp_local_tp_size = parallel_args.world_size() / parallel_args.dp_size();
int dp_group = parallel_args.rank() / dp_local_tp_size;

std::string name_prefix =
"xllm_" + net::extract_port(options.master_node_addr().value());
string name = ForwardSharedMemoryManager::create_unique_name(
dp_group, FORWARD_RAW_INPUT_TYPE, parallel_args.rank());
name_prefix, dp_group, FORWARD_RAW_INPUT_TYPE, parallel_args.rank());
input_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
LOG(INFO) << "Create input shared memory manager with name: " << name;

name = ForwardSharedMemoryManager::create_unique_name(
dp_group, FORWARD_RAW_OUTPUT_TYPE, parallel_args.rank());
name_prefix, dp_group, FORWARD_RAW_OUTPUT_TYPE, parallel_args.rank());
output_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
LOG(INFO) << "Create output shared memory manager with name: " << name;
@@ -204,31 +209,34 @@ WorkerServer::WorkerServer(int local_worker_idx,
WorkerType worker_type,
bool use_spawn_worker) {
if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM) {
// TODO: Refactor this code later.
if (use_spawn_worker) {
// start worker in a spawned process (for offline inference workers).
create_spawn_server(
local_worker_idx, master_node_addr, done, parallel_args, d, options);
return;
}
} else {
std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr;
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr;
prepare_shm(
parallel_args, options, input_shm_manager, output_shm_manager);

std::unique_ptr<ForwardSharedMemoryManager> input_shm_manager = nullptr;
std::unique_ptr<ForwardSharedMemoryManager> output_shm_manager = nullptr;
prepare_shm(parallel_args, options, input_shm_manager, output_shm_manager);
// start worker in a thread.
worker_thread_ =
std::make_unique<std::thread>(&WorkerServer::create_server,
this,
std::cref(options),
std::ref(done),
std::cref(master_node_addr),
std::cref(d),
parallel_args.world_size(),
parallel_args.rank(),
parallel_args.dp_size(),
local_worker_idx,
parallel_args.ep_size(),
std::move(input_shm_manager),
std::move(output_shm_manager));
// start worker in a thread.
worker_thread_ =
std::make_unique<std::thread>(&WorkerServer::create_server,
this,
std::cref(options),
std::ref(done),
std::cref(master_node_addr),
std::cref(d),
parallel_args.world_size(),
parallel_args.rank(),
parallel_args.dp_size(),
local_worker_idx,
parallel_args.ep_size(),
std::move(input_shm_manager),
std::move(output_shm_manager));
}
} else {
// TODO: support other model type later.
LOG(ERROR) << "Unsupported model type: " << worker_type;
2 changes: 2 additions & 0 deletions xllm/core/runtime/dit_engine.cpp
@@ -42,6 +42,8 @@ DiTEngine::DiTEngine(const runtime::Options& options) : options_(options) {
}
const int32_t world_size = static_cast<int32_t>(devices.size());

CHECK(!options_.enable_shm()) << "Dit can not support enable_shm currently.";

// create workers
for (size_t i = 0; i < devices.size(); ++i) {
const int32_t rank = static_cast<int32_t>(i);
12 changes: 7 additions & 5 deletions xllm/core/runtime/forward_shared_memory_manager.cpp
@@ -655,10 +655,12 @@ ForwardSharedMemoryManager::~ForwardSharedMemoryManager() = default;

/* The shared memory filename may be duplicated when xllm is killed with kill -9, but
this doesn't affect usage. */
std::string ForwardSharedMemoryManager::create_unique_name(int dp_group,
int forward_type,
int rank) {
std::string filename = "xllm_" + net::extract_port(FLAGS_master_node_addr);
std::string ForwardSharedMemoryManager::create_unique_name(
const std::string& prefix,
int dp_group,
int forward_type,
int rank) {
std::string filename = prefix;
if (forward_type == FORWARD_PB_INPUT_TYPE ||
forward_type == FORWARD_RAW_INPUT_TYPE) {
filename += "_dpg_" + std::to_string(dp_group) + "_input";
@@ -997,4 +999,4 @@ void ForwardSharedMemoryManager::raw_output_read(RawForwardOutput& output) {
void ForwardSharedMemoryManager::clear() {
std::memset(base_address(), 0, size());
}
} // namespace xllm
} // namespace xllm
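To make the effect concrete, with a hypothetical prefix of "xllm_18888" the input-side name produced by the branch shown above starts like this (dp_group 0; any rank suffix appended by code outside this hunk is not shown):

#include <string>

std::string prefix = "xllm_18888";  // hypothetical, derived from the master node port
std::string filename = prefix + "_dpg_" + std::to_string(0) + "_input";
// -> "xllm_18888_dpg_0_input", plus whatever suffix the remaining code appends.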
5 changes: 3 additions & 2 deletions xllm/core/runtime/forward_shared_memory_manager.h
@@ -47,7 +47,8 @@ class ForwardSharedMemoryManager : public SharedMemoryManager {
bool& is_creator,
ForwardType type);
~ForwardSharedMemoryManager();
static std::string create_unique_name(int dp_group,
static std::string create_unique_name(const std::string& prefix,
int dp_group,
int forward_type,
int rank);

@@ -121,4 +122,4 @@ class ForwardSharedMemoryManager : public SharedMemoryManager {
void* metadata_addr_ = nullptr;
ControlMetadata* control_ptr_ = nullptr;
};
} // namespace xllm
} // namespace xllm
2 changes: 2 additions & 0 deletions xllm/core/runtime/vlm_engine.cpp
@@ -50,6 +50,8 @@ VLMEngine::VLMEngine(const runtime::Options& options) : options_(options) {
process_groups_ = parallel_state::create_npu_process_groups(devices);
}

CHECK(!options_.enable_shm()) << "VLM can not support enable_shm currently.";

WorkerType worker_type =
(options_.task_type() == "generate") ? WorkerType::VLM : WorkerType::EVLM;
const int32_t world_size = static_cast<int32_t>(devices.size());
4 changes: 3 additions & 1 deletion xllm/core/runtime/vlm_master.cpp
@@ -289,9 +289,11 @@ void VLMMaster::run() {
}

void VLMMaster::generate() {
DCHECK(!options_.enable_schedule_overlap())
<< "Mode generate does not support schedule overlap yet.";
const bool already_running = running_.load(std::memory_order_relaxed);
if (already_running) {
LOG(WARNING) << "VLMMaster is already running.";
LOG(WARNING) << "Generate is already running.";
return;
}

3 changes: 2 additions & 1 deletion xllm/pybind/bind.cpp
@@ -82,7 +82,8 @@ PYBIND11_MODULE(xllm_export, m) {
.def_readwrite("enable_offline_inference",
&Options::enable_offline_inference_)
.def_readwrite("spawn_worker_path", &Options::spawn_worker_path_)
.def_readwrite("enable_shm", &Options::enable_shm_);
.def_readwrite("enable_shm", &Options::enable_shm_)
.def_readwrite("is_local", &Options::is_local_);

// 2. export LLMMaster
py::class_<LLMMaster>(m, "LLMMaster")
2 changes: 2 additions & 0 deletions xllm/pybind/llm.py
@@ -44,6 +44,7 @@ def __init__(
disable_ttft_profiling: bool = False,
enable_forward_interruption: bool = False,
enable_shm: bool = False,
is_local: bool = True,
**kwargs,
) -> None:

@@ -95,6 +96,7 @@ def __init__(
options.enable_offline_inference = True
options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
options.enable_shm = enable_shm
options.is_local = is_local
self.master = LLMMaster(options)

def finish(self):
2 changes: 2 additions & 0 deletions xllm/pybind/vlm.py
@@ -42,6 +42,7 @@ def __init__(
enable_schedule_overlap: bool = False,
kv_cache_transfer_mode: str = 'PUSH',
enable_shm: bool = False,
is_local: bool = True,
**kwargs,
) -> None:

@@ -90,6 +91,7 @@ def __init__(
options.enable_offline_inference = True
options.spawn_worker_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
options.enable_shm = enable_shm
options.is_local = is_local
self.master = VLMMaster(options)

def finish(self):