5 changes: 4 additions & 1 deletion xllm/core/common/global_flags.cpp
@@ -424,7 +424,7 @@ DEFINE_bool(
     "The default prefetching ratio for gateup weight is 40%."
     "If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");
 
-// rec prefill-only mode
+// --- rec prefill-only mode ---
 DEFINE_bool(enable_rec_prefill_only,
             false,
             "Enable rec prefill-only mode (no decoder self-attention blocks "
@@ -438,6 +438,9 @@ DEFINE_bool(
     "Whether to enable dp load balance, if true, sequences within a single "
     "dp batch will be shuffled.");
 
+// --- the seed for random number generator ---
+DEFINE_int32(random_seed, -1, "Random seed for random number generator.");
+
 // --- dit cache config ---
 
 DEFINE_string(dit_cache_policy,
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.h
@@ -215,6 +215,8 @@ DECLARE_int32(flashinfer_workspace_buffer_size);
 
 DECLARE_bool(enable_dp_balance);
 
+DECLARE_int32(random_seed);
+
 DECLARE_string(dit_cache_policy);
 
 DECLARE_int64(dit_cache_warmup_steps);
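The two halves above follow the standard gflags split: the DEFINE_int32 in global_flags.cpp owns the flag and its default, while the DECLARE_int32 in global_flags.h exposes FLAGS_random_seed to every translation unit that includes the header. A minimal standalone sketch of the same pattern (demo code, not part of this PR):

#include <cstdio>

#include <gflags/gflags.h>

// Owning definition; any other file would instead write
// DECLARE_int32(random_seed) and read FLAGS_random_seed after flag parsing.
DEFINE_int32(random_seed, -1, "Random seed for random number generator.");

int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  // Run as: ./demo --random_seed=42
  std::printf("random_seed = %d\n", FLAGS_random_seed);
  return 0;
}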
10 changes: 7 additions & 3 deletions xllm/core/distributed_runtime/comm_channel.cpp
@@ -211,10 +211,12 @@ bool CommChannel::unlink_cluster(const std::vector<uint64_t>& cluster_ids,
   return true;
 }
 
-bool CommChannel::init_model(const std::string& model_weights_path) {
-  proto::ModelPath request;
+bool CommChannel::init_model(const std::string& model_weights_path,
+                             int32_t random_seed) {
+  proto::InitModelRequest request;
 
   request.set_model_weights_path(model_weights_path);
+  request.set_random_seed(random_seed);
   proto::Status response;
   brpc::Controller cntl;
   stub_->InitModel(&cntl, &request, &response, nullptr);
@@ -226,10 +228,12 @@ bool CommChannel::init_model(const std::string& model_weights_path) {
 }
 
 bool CommChannel::init_model_async(const std::string& model_weights_path,
+                                   int32_t random_seed,
                                    folly::Promise<bool>& promise) {
-  proto::ModelPath request;
+  proto::InitModelRequest request;
 
   request.set_model_weights_path(model_weights_path);
+  request.set_random_seed(random_seed);
   auto done = new InitModelClosure();
   done->promise = std::move(promise);
   stub_->InitModel(&done->cntl, &request, &done->response, done);
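The proto::ModelPath to proto::InitModelRequest rename implies a change to the .proto definition that this diff does not show. Judging only by the generated setters used above, the message presumably looks roughly like the following sketch (message shape and field numbers are assumptions):

// Hypothetical reconstruction; the real .proto change is not in this diff.
message InitModelRequest {
  string model_weights_path = 1;
  int32 random_seed = 2;
}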
4 changes: 3 additions & 1 deletion xllm/core/distributed_runtime/comm_channel.h
@@ -62,9 +62,11 @@ class CommChannel {
       const std::vector<std::string>& device_ips,
       const std::vector<uint16_t>& ports);
 
-  virtual bool init_model(const std::string& model_weights_path);
+  virtual bool init_model(const std::string& model_weights_path,
+                          int32_t random_seed);
 
   virtual bool init_model_async(const std::string& model_weights_path,
+                                int32_t random_seed,
                                 folly::Promise<bool>& promise);
 
   virtual bool estimate_kv_cache_capacity(int64_t& available_memory,
20 changes: 12 additions & 8 deletions xllm/core/distributed_runtime/remote_worker.cpp
@@ -104,8 +104,9 @@ bool RemoteWorker::unlink_cluster(const std::vector<uint64_t>& cluster_ids,
   return channel_->unlink_cluster(cluster_ids, addrs, device_ips, ports);
 }
 
-bool RemoteWorker::init_model(const std::string& model_weights_path) {
-  return channel_->init_model(model_weights_path);
+bool RemoteWorker::init_model(const std::string& model_weights_path,
+                              int32_t random_seed) {
+  return channel_->init_model(model_weights_path, random_seed);
 }
 
 std::tuple<int64_t, int64_t> RemoteWorker::estimate_kv_cache_capacity() {
@@ -190,14 +191,17 @@ folly::SemiFuture<folly::Unit> RemoteWorker::process_group_test_async() {
 }
 
 folly::SemiFuture<bool> RemoteWorker::init_model_async(
-    const std::string& model_weights_path) {
+    const std::string& model_weights_path,
+    int32_t random_seed) {
   folly::Promise<bool> promise;
   auto future = promise.getSemiFuture();
-  threadpool_.schedule(
-      [this, model_weights_path, promise = std::move(promise)]() mutable {
-        // call InitModel with callback
-        channel_->init_model_async(model_weights_path, promise);
-      });
+  threadpool_.schedule([this,
+                        model_weights_path,
+                        random_seed,
+                        promise = std::move(promise)]() mutable {
+    // call InitModel with callback
+    channel_->init_model_async(model_weights_path, random_seed, promise);
+  });
   return future;
 }
 
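RemoteWorker forwards the seed through the runtime's usual promise/future hand-off: the promise is moved into the scheduled lambda, and the caller keeps only the SemiFuture. A self-contained sketch of that pattern, with a plain folly::Executor standing in for xllm's thread pool (names here are illustrative):

#include <cstdint>
#include <string>

#include <folly/Executor.h>
#include <folly/futures/Future.h>

folly::SemiFuture<bool> init_model_async_sketch(folly::Executor* executor,
                                                std::string path,
                                                int32_t seed) {
  folly::Promise<bool> promise;
  auto future = promise.getSemiFuture();
  // The promise moves into the task: the task fulfills it, while the caller
  // only ever holds the SemiFuture.
  executor->add([path = std::move(path), seed,
                 promise = std::move(promise)]() mutable {
    (void)seed;               // the real code passes this into init_model
    bool ok = !path.empty();  // stand-in for the actual initialization work
    promise.setValue(ok);
  });
  return future;
}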
6 changes: 4 additions & 2 deletions xllm/core/distributed_runtime/remote_worker.h
@@ -46,7 +46,8 @@ class RemoteWorker : public WorkerClient {
 
   bool wait_for_server_ready(const std::string& server_address);
 
-  virtual bool init_model(const std::string& model_weights_path) override;
+  virtual bool init_model(const std::string& model_weights_path,
+                          int32_t random_seed) override;
 
   virtual std::tuple<int64_t, int64_t> estimate_kv_cache_capacity() override;
 
@@ -87,7 +88,8 @@
       const ForwardInput& inputs) override;
 
   virtual folly::SemiFuture<bool> init_model_async(
-      const std::string& model_weights_path) override;
+      const std::string& model_weights_path,
+      int32_t random_seed) override;
 
   virtual folly::SemiFuture<std::tuple<int64_t, int64_t>>
   estimate_kv_cache_capacity_async() override;
6 changes: 4 additions & 2 deletions xllm/core/distributed_runtime/worker_service.cpp
@@ -220,13 +220,15 @@ void WorkerService::Hello(::google::protobuf::RpcController* controller,
 }
 
 void WorkerService::InitModel(::google::protobuf::RpcController* controller,
-                              const proto::ModelPath* request,
+                              const proto::InitModelRequest* request,
                               proto::Status* response,
                               ::google::protobuf::Closure* done) {
   threadpool_->schedule([this, controller, request, response, done]() mutable {
     brpc::ClosureGuard done_guard(done);
     auto model_weights_path = request->model_weights_path();
-    auto init_future = worker_->init_model_async(model_weights_path);
+    auto random_seed = request->random_seed();
+    auto init_future =
+        worker_->init_model_async(model_weights_path, random_seed);
     bool status = std::move(init_future).get();
     if (!status) {
       response->set_ok(false);
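A detail worth noting in the handler above: brpc sends the RPC response when done->Run() fires, and brpc::ClosureGuard calls it on destruction. Constructing the guard inside the scheduled lambda is therefore what delays the reply until the init future resolves. A minimal sketch of the guard's behavior (illustrative, not xllm code):

#include <brpc/closure_guard.h>
#include <google/protobuf/service.h>

void handle_on_worker_thread(::google::protobuf::Closure* done) {
  // done->Run() sends the response; the guard invokes it at scope exit, so
  // the client is not answered until this function returns.
  brpc::ClosureGuard done_guard(done);
  // ... long-running initialization work ...
}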
2 changes: 1 addition & 1 deletion xllm/core/distributed_runtime/worker_service.h
@@ -45,7 +45,7 @@ class WorkerService : public proto::DistributeWorker {
                  ::google::protobuf::Closure* done) override;
 
   void InitModel(::google::protobuf::RpcController* controller,
-                 const proto::ModelPath* request,
+                 const proto::InitModelRequest* request,
                  proto::Status* response,
                  ::google::protobuf::Closure* done) override;
 
21 changes: 20 additions & 1 deletion xllm/core/platform/device.cpp
@@ -14,10 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 #include "device.h"
-#if defined(USE_MLU)
+#if defined(USE_NPU)
+#include <torch_npu/csrc/aten/NPUGeneratorImpl.h>
+#elif defined(USE_MLU)
 #include <cn_api.h>
 #include <torch_mlu/csrc/framework/core/device.h>
 #include <torch_mlu/csrc/framework/core/device_utils.h>
+#include <torch_mlu/csrc/framework/generator/generator_impl.h>
 #elif defined(USE_CUDA)
 #include <c10/cuda/CUDAStream.h>
 #include <cuda.h>
@@ -39,6 +42,22 @@ void Device::set_device() const {
 #endif
 }
 
+void Device::set_seed(uint64_t seed) const {
+  torch::manual_seed(seed);
+#if defined(USE_NPU)
+  auto gen = at_npu::detail::getDefaultNPUGenerator(index());
+  gen.set_current_seed(seed);
+#elif defined(USE_MLU)
+  auto gen = torch_mlu::getDefaultMLUGenerator(index());
+  {
+    std::lock_guard<std::mutex> lock(gen.mutex());
+    gen.set_current_seed(seed);
+  }
+#elif defined(USE_CUDA)
+  torch::cuda::manual_seed(seed);
+#endif
+}
+
 const torch::Device& Device::unwrap() const { return device_; }
 
 int32_t Device::index() const { return device_.index(); }
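Device::set_seed reseeds both the process-wide CPU generator (torch::manual_seed) and the default generator of the worker's own accelerator, since the two are tracked separately; the MLU branch additionally locks the generator's mutex before reseeding. The payoff is reproducible random initialization, as in this minimal CPU-side demo (standalone, not xllm code):

#include <iostream>

#include <torch/torch.h>

int main() {
  torch::manual_seed(42);
  auto a = torch::randn({2, 2});
  torch::manual_seed(42);  // rewind the generator to the same state
  auto b = torch::randn({2, 2});
  std::cout << torch::allclose(a, b) << std::endl;  // prints 1
  return 0;
}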
2 changes: 2 additions & 0 deletions xllm/core/platform/device.h
@@ -33,6 +33,8 @@ class Device {
 
   void set_device() const;
 
+  void set_seed(uint64_t seed = 42) const;
+
   const torch::Device& unwrap() const;
   int32_t index() const;
 
3 changes: 2 additions & 1 deletion xllm/core/runtime/llm_engine.cpp
@@ -177,13 +177,14 @@ bool LLMEngine::init_model() {
   LOG(INFO) << "Initializing model with " << args_;
   LOG(INFO) << "Initializing model with quant args: " << quant_args_;
   LOG(INFO) << "Initializing model with tokenizer args: " << tokenizer_args_;
+  LOG(INFO) << "Initializing model with random seed: " << FLAGS_random_seed;
 
   // init model for each worker in parallel
   // multiple workers, call async init
   std::vector<folly::SemiFuture<bool>> futures;
   futures.reserve(worker_clients_num_);
   for (auto& worker : worker_clients_) {
-    futures.push_back(worker->init_model_async(model_path));
+    futures.push_back(worker->init_model_async(model_path, FLAGS_random_seed));
   }
   // wait for all futures to complete
   auto results = folly::collectAll(futures).get();
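The engine fans the seed out to every worker and then joins on folly::collectAll, which yields one Try<bool> per worker. A sketch of that join step in isolation (illustrative helper, not part of the PR):

#include <vector>

#include <folly/futures/Future.h>

bool all_workers_ok(std::vector<folly::SemiFuture<bool>> futures) {
  // collectAll never short-circuits: every worker reports back as a Try.
  auto results = folly::collectAll(std::move(futures)).get();
  for (auto& result : results) {
    if (!result.hasValue() || !result.value()) {
      return false;  // a worker threw or returned failure
    }
  }
  return true;
}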
8 changes: 5 additions & 3 deletions xllm/core/runtime/speculative_worker_impl.cpp
@@ -73,18 +73,20 @@ SpeculativeWorkerImpl::SpeculativeWorkerImpl(const ParallelArgs& parallel_args,
       std::make_unique<LLMWorkerImpl>(parallel_args, device, runtime_options);
 }
 
-bool SpeculativeWorkerImpl::init_model(const std::string& model_weights_path) {
+bool SpeculativeWorkerImpl::init_model(const std::string& model_weights_path,
+                                       int32_t random_seed) {
   // initialize model
   bool result = true;
   if (impl_->get_status() == WorkerImpl::Status::UNINITIALIZED) {
-    result = impl_->WorkerImpl::init_model(model_weights_path);
+    result = impl_->WorkerImpl::init_model(model_weights_path, random_seed);
     if (result) {
       dtype_ = impl_->dtype();
       embedding_size_ = impl_->hidden_size();
     }
   } else {
     CHECK_EQ(draft_impl_->get_status(), WorkerImpl::Status::UNINITIALIZED);
-    result = draft_impl_->WorkerImpl::init_model(model_weights_path);
+    result =
+        draft_impl_->WorkerImpl::init_model(model_weights_path, random_seed);
   }
 
   if (draft_impl_->get_status() == WorkerImpl::Status::LOADED) {
3 changes: 2 additions & 1 deletion xllm/core/runtime/speculative_worker_impl.h
@@ -43,7 +43,8 @@ class SpeculativeWorkerImpl : public WorkerImpl {
     return true;
   };
 
-  bool init_model(const std::string& model_weights_path) override;
+  bool init_model(const std::string& model_weights_path,
+                  int32_t random_seed) override;
 
   void get_device_info(std::string& device_ip, uint16_t& port) override {
     impl_->get_device_info(device_ip, port);
3 changes: 2 additions & 1 deletion xllm/core/runtime/vlm_engine.cpp
@@ -162,13 +162,14 @@ bool VLMEngine::init_model() {
   LOG(INFO) << "Initializing model with " << args_;
   LOG(INFO) << "Initializing model with quant args: " << quant_args_;
   LOG(INFO) << "Initializing model with tokenizer args: " << tokenizer_args_;
+  LOG(INFO) << "Initializing model with random seed: " << FLAGS_random_seed;
 
   // init model for each worker in parallel
   // multiple workers, call async init
   std::vector<folly::SemiFuture<bool>> futures;
   futures.reserve(worker_clients_num_);
   for (auto& worker : worker_clients_) {
-    futures.push_back(worker->init_model_async(model_path));
+    futures.push_back(worker->init_model_async(model_path, FLAGS_random_seed));
   }
   // wait for all futures to complete
   auto results = folly::collectAll(futures).get();
10 changes: 6 additions & 4 deletions xllm/core/runtime/worker.cpp
@@ -58,8 +58,9 @@ Worker::Worker(const ParallelArgs& parallel_args,
 
 Worker::~Worker() { delete impl_; }
 
-bool Worker::init_model(const std::string& model_weights_path) {
-  return impl_->init_model(model_weights_path);
+bool Worker::init_model(const std::string& model_weights_path,
+                        int32_t random_seed) {
+  return impl_->init_model(model_weights_path, random_seed);
 }
 
 bool Worker::allocate_kv_cache(
@@ -127,8 +128,9 @@ folly::SemiFuture<folly::Unit> Worker::process_group_test_async() {
 
 // initialize model, cache manager. async call
 folly::SemiFuture<bool> Worker::init_model_async(
-    const std::string& model_weights_path) {
-  return impl_->init_model_async(model_weights_path);
+    const std::string& model_weights_path,
+    int32_t random_seed) {
+  return impl_->init_model_async(model_weights_path, random_seed);
 }
 
 folly::SemiFuture<bool> Worker::allocate_kv_cache_async(
5 changes: 3 additions & 2 deletions xllm/core/runtime/worker.h
@@ -43,7 +43,7 @@ class Worker {
   ~Worker();
 
   // initialize model, cache manager. blocking call
-  bool init_model(const std::string& model_weights_path);
+  bool init_model(const std::string& model_weights_path, int32_t random_seed);
 
   std::tuple<int64_t, int64_t> estimate_kv_cache_capacity();
 
@@ -80,7 +80,8 @@
 
   // initialize model, cache manager. async call
   folly::SemiFuture<bool> init_model_async(
-      const std::string& model_weights_path);
+      const std::string& model_weights_path,
+      int32_t random_seed);
 
   folly::SemiFuture<std::tuple<int64_t, int64_t>>
   estimate_kv_cache_capacity_async();
10 changes: 6 additions & 4 deletions xllm/core/runtime/worker_client.cpp
@@ -32,8 +32,9 @@ limitations under the License.
 
 namespace xllm {
 
-bool WorkerClient::init_model(const std::string& model_weights_path) {
-  return worker_->init_model(model_weights_path);
+bool WorkerClient::init_model(const std::string& model_weights_path,
+                              int32_t random_seed) {
+  return worker_->init_model(model_weights_path, random_seed);
 }
 
 bool WorkerClient::allocate_kv_cache(
@@ -120,8 +121,9 @@ folly::SemiFuture<folly::Unit> WorkerClient::process_group_test_async() {
 
 // initialize model, cache manager. async call
 folly::SemiFuture<bool> WorkerClient::init_model_async(
-    const std::string& model_weights_path) {
-  return worker_->init_model_async(model_weights_path);
+    const std::string& model_weights_path,
+    int32_t random_seed) {
+  return worker_->init_model_async(model_weights_path, random_seed);
 }
 
 folly::SemiFuture<bool> WorkerClient::allocate_kv_cache_async(
6 changes: 4 additions & 2 deletions xllm/core/runtime/worker_client.h
@@ -40,7 +40,8 @@ class WorkerClient {
   virtual ~WorkerClient() = default;
 
   // initialize model, cache manager. blocking call
-  virtual bool init_model(const std::string& model_weights_path);
+  virtual bool init_model(const std::string& model_weights_path,
+                          int32_t random_seed);
 
   virtual std::tuple<int64_t, int64_t> estimate_kv_cache_capacity();
 
@@ -82,7 +83,8 @@
 
   // initialize model, cache manager. async call
   virtual folly::SemiFuture<bool> init_model_async(
-      const std::string& model_weights_path);
+      const std::string& model_weights_path,
+      int32_t random_seed);
 
   virtual folly::SemiFuture<std::tuple<int64_t, int64_t>>
   estimate_kv_cache_capacity_async();
21 changes: 14 additions & 7 deletions xllm/core/runtime/worker_impl.cpp
@@ -565,19 +565,26 @@ folly::SemiFuture<folly::Unit> WorkerImpl::process_group_test_async() {
 
 // initialize model, cache manager. async call
 folly::SemiFuture<bool> WorkerImpl::init_model_async(
-    const std::string& model_weights_path) {
+    const std::string& model_weights_path,
+    int32_t random_seed) {
   folly::Promise<bool> promise;
   auto future = promise.getSemiFuture();
-  threadpool_.schedule(
-      [this, model_weights_path, promise = std::move(promise)]() mutable {
-        auto status = this->init_model(model_weights_path);
-        promise.setValue(status);
-      });
+  threadpool_.schedule([this,
+                        model_weights_path,
+                        random_seed,
+                        promise = std::move(promise)]() mutable {
+    auto status = this->init_model(model_weights_path, random_seed);
+    promise.setValue(status);
+  });
 
   return future;
 }
 
-bool WorkerImpl::init_model(const std::string& model_weights_path) {
+bool WorkerImpl::init_model(const std::string& model_weights_path,
+                            int32_t random_seed) {
+  // set same random seed for all worker
+  device_.set_seed(random_seed);
+
   auto model_loader = ModelLoader::create(model_weights_path);
   auto tokenizer = model_loader->tokenizer();
   CHECK(tokenizer != nullptr);
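End to end, the seed now travels engine -> WorkerClient/RemoteWorker -> InitModel RPC -> WorkerImpl, which seeds the device before any weights are created or loaded. A toy mock of that layering (standalone sketch; the names mirror xllm but none of this is the real API):

#include <cstdint>
#include <iostream>
#include <string>

struct MockDevice {
  void set_seed(uint64_t seed) const {
    std::cout << "device seeded with " << seed << "\n";
  }
};

struct MockWorker {
  MockDevice device_;
  bool init_model(const std::string& path, int32_t random_seed) {
    device_.set_seed(random_seed);  // seed first, before any weight init
    return !path.empty();
  }
};

int main() {
  MockWorker worker;
  // In xllm the value comes from --random_seed via FLAGS_random_seed.
  return worker.init_model("/models/demo", /*random_seed=*/42) ? 0 : 1;
}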