2 changes: 1 addition & 1 deletion xllm/core/framework/parallel_state/parallel_args.h
@@ -99,10 +99,10 @@ struct ParallelArgs {
// ep size
PROPERTY(int32_t, ep_size) = 1;

#if defined(USE_NPU)
// atb hccl mapping json data
PROPERTY(nlohmann::json, mapping_data);

#if defined(USE_NPU)
// atb hccl mapping
PROPERTY(atb_speed::base::Mapping, mapping);

24 changes: 20 additions & 4 deletions xllm/core/framework/state_dict/utils.cpp
@@ -81,14 +81,21 @@ void load_fused_weight(const StateDict& state_dict,
int32_t world_size,
std::vector<torch::Tensor>& accumulated_tensors,
torch::Tensor& weight,
bool& weight_is_loaded) {
bool& weight_is_loaded,
int32_t num_kv_head_replicas) {
// return if the weight is already loaded
if (weight_is_loaded) {
return;
}

weight_is_loaded = load_tensor_list(
state_dict, prefixes, name, dim, rank, world_size, accumulated_tensors);
weight_is_loaded = load_tensor_list(state_dict,
prefixes,
name,
dim,
rank,
world_size,
accumulated_tensors,
num_kv_head_replicas);

if (weight_is_loaded) {
const auto merged_weight = torch::cat(accumulated_tensors, /*dim=*/dim);
@@ -106,7 +113,8 @@ bool load_tensor_list(const StateDict& state_dict,
int64_t dim,
int32_t rank,
int32_t world_size,
std::vector<torch::Tensor>& tensors) {
std::vector<torch::Tensor>& tensors,
int32_t num_kv_head_replicas) {
// resize the accumulated weight list if needed
if (tensors.size() < prefixes.size()) {
tensors.resize(prefixes.size());
@@ -118,6 +126,14 @@ bool load_tensor_list(const StateDict& state_dict,
continue;
}

// When the number of key/value heads is smaller than the number of query
// heads (e.g., multi-query/grouped-query attention), the key/value heads may
// be replicated across ranks while the query heads are partitioned.
if (i == 1 && num_kv_head_replicas > 1) {
rank = rank / num_kv_head_replicas;
world_size = world_size / num_kv_head_replicas;
}

const std::string tensor_name = prefixes[i] + name;
torch::Tensor tensor;
if (dim < 0) {
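The rank remapping above is easier to see with concrete numbers. Below is a minimal, self-contained sketch (not part of this PR) that assumes the common convention num_kv_head_replicas = world_size / num_kv_heads; the helper name kv_shard_for_rank is hypothetical and only illustrates which KV shard each tensor-parallel rank ends up loading:

#include <cstdint>
#include <iostream>
#include <utility>

// Hypothetical helper (illustration only): maps a tensor-parallel rank to the
// (rank, world_size) pair used when loading a replicated key/value projection,
// mirroring the rank / num_kv_head_replicas adjustment in load_tensor_list.
std::pair<int32_t, int32_t> kv_shard_for_rank(int32_t rank,
                                              int32_t world_size,
                                              int32_t num_kv_head_replicas) {
  return {rank / num_kv_head_replicas, world_size / num_kv_head_replicas};
}

int main() {
  // Example: 8-way tensor parallelism with 2 KV heads. Assuming
  // num_kv_head_replicas = world_size / num_kv_heads = 4, ranks 0-3 load KV
  // shard 0 of 2 and ranks 4-7 load KV shard 1 of 2, while the query
  // projection is still split 8 ways.
  const int32_t world_size = 8;
  const int32_t num_kv_head_replicas = 4;
  for (int32_t rank = 0; rank < world_size; ++rank) {
    const auto [kv_rank, kv_world] =
        kv_shard_for_rank(rank, world_size, num_kv_head_replicas);
    std::cout << "rank " << rank << " -> kv shard " << kv_rank << " of "
              << kv_world << "\n";
  }
  return 0;
}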
18 changes: 16 additions & 2 deletions xllm/core/framework/state_dict/utils.h
@@ -56,15 +56,17 @@ void load_fused_weight(const StateDict& state_dict,
int32_t world_size,
std::vector<torch::Tensor>& accumulated_tensors,
torch::Tensor& weight,
bool& weight_is_loaded);
bool& weight_is_loaded,
int32_t num_kv_head_replicas = 1);

bool load_tensor_list(const StateDict& state_dict,
const std::vector<std::string>& prefixes,
const std::string& name,
int64_t dim,
int32_t rank,
int32_t world_size,
std::vector<torch::Tensor>& accumulated_tensors);
std::vector<torch::Tensor>& accumulated_tensors,
int32_t num_kv_head_replicas = 1);

void load_moe_weight(const StateDict& state_dict,
const std::string& sub_prefix,
@@ -114,6 +116,18 @@ void load_moe_fused_weight(const StateDict& state_dict,
name##_, \
name##_is_loaded_);

#define LOAD_QKV_WEIGHT(name, dim, num_kv_head_replicas) \
weight::load_fused_weight(state_dict, \
prefixes, \
#name, \
dim, \
rank, \
world_size, \
name##_list_, \
name##_, \
name##_is_loaded_, \
num_kv_head_replicas);

#define LOAD_SHARDED_WEIGHT(name, dim) \
weight::load_sharded_weight( \
state_dict, #name, dim, rank, world_size, name##_, name##_is_loaded_);
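The new LOAD_QKV_WEIGHT macro above is a thin wrapper over weight::load_fused_weight. For reference, LOAD_QKV_WEIGHT(weight, 0, num_kv_head_replicas_) expands roughly as shown below, assuming the surrounding load_state_dict provides state_dict, prefixes, rank and world_size, and the layer declares weight_, weight_list_ and weight_is_loaded_ members (following the existing name##_ convention):

// Expansion of LOAD_QKV_WEIGHT(weight, 0, num_kv_head_replicas_):
weight::load_fused_weight(state_dict,
                          prefixes,
                          "weight",           // #name stringizes the token
                          0,                  // dim: q/k/v are concatenated on dim 0
                          rank,
                          world_size,
                          weight_list_,       // name##_list_
                          weight_,            // name##_
                          weight_is_loaded_,  // name##_is_loaded_
                          num_kv_head_replicas_);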
12 changes: 6 additions & 6 deletions xllm/core/layers/common/CMakeLists.txt
@@ -6,31 +6,31 @@ cc_library(
HDRS
flashinfer_workspace.h
deepseek_v2_attention.h
qwen3_attention.h
qwen2_attention.h
attention.h
fuse_norm.h
rotary_embedding.h
fused_moe.h
dense_mlp.h
qwen3_decoder_layer.h
qwen2_decoder_layer.h
qwen3_moe_decoder_layer.h
linear_impl.h
linear.h
word_embedding_impl.h
layer_utils.h
indexer.h
SRCS
flashinfer_workspace.cpp
deepseek_v2_attention.cpp
qwen3_attention.cpp
qwen2_attention.cpp
attention.cpp
fuse_norm.cpp
rotary_embedding.cpp
fused_moe.cpp
dense_mlp.cpp
qwen3_decoder_layer.cpp
qwen2_decoder_layer.cpp
qwen3_moe_decoder_layer.cpp
linear_impl.cpp
linear.cpp
word_embedding_impl.cpp
layer_utils.cpp
indexer.cpp
DEPS
7 changes: 1 addition & 6 deletions xllm/core/layers/common/fuse_norm.cpp
@@ -59,12 +59,7 @@ torch::Tensor FusedRMSNormImpl::forward_output(torch::Tensor& input,
}

void FusedRMSNormImpl::load_state_dict(const StateDict& state_dict) {
const auto weight = state_dict.get_tensor("weight");
if (weight.defined()) {
CHECK_EQ(weight_.sizes(), weight.sizes())
<< "weight size mismatch for " << name();
weight_.copy_(weight);
}
LOAD_WEIGHT(weight);
}

} // namespace layer
8 changes: 0 additions & 8 deletions xllm/core/layers/common/indexer.cpp
@@ -323,13 +323,5 @@ void IndexerImpl::load_state_dict(const StateDict& state_dict) {
state_dict.get_dict_with_prefix("weights_proj."));
}

// whether the weight is loaded
void IndexerImpl::verify_loaded_weights(const std::string& prefix) const {
// Verify that all linear layers have loaded their weights
wq_b_->verify_loaded_weights(prefix + "wq_b.");
wk_->verify_loaded_weights(prefix + "wk.");
weights_proj_->verify_loaded_weights(prefix + "weights_proj.");
}

} // namespace layer
} // namespace xllm
3 changes: 0 additions & 3 deletions xllm/core/layers/common/indexer.h
@@ -61,9 +61,6 @@ class IndexerImpl : public torch::nn::Module {
// load the weight from the checkpoint
void load_state_dict(const StateDict& state_dict);

// whether the weight is loaded
void verify_loaded_weights(const std::string& prefix = "") const;

private:
int64_t dim_;
int64_t n_heads_;
14 changes: 1 addition & 13 deletions xllm/core/layers/common/layer_utils.cpp
@@ -20,18 +20,6 @@ limitations under the License.
namespace xllm {
namespace layer {

bool is_dummy_run(const ModelInputParams& input_params,
const ParallelArgs& parallel_args) {
int64_t dp_rank = 0;
if (parallel_args.dp_size() > 1) {
dp_rank = parallel_args.dp_local_process_group_->rank();
}
if (input_params.dp_global_token_nums.size() <= 1) {
return input_params.q_max_seq_len == 0;
}
return input_params.dp_global_token_nums[dp_rank] == 0;
}

void update_dummy_run_input(int64_t dp_rank,
torch::Tensor& positions,
ModelInputParams& input_params) {
@@ -48,4 +36,4 @@ void update_dummy_run_input(int64_t dp_rank,
}

} // namespace layer
} // namespace xllm
} // namespace xllm
5 changes: 1 addition & 4 deletions xllm/core/layers/common/layer_utils.h
@@ -20,12 +20,9 @@ limitations under the License.
namespace xllm {
namespace layer {

bool is_dummy_run(const ModelInputParams& input_params,
const ParallelArgs& parallel_args);

void update_dummy_run_input(int64_t dp_rank,
torch::Tensor& positions,
ModelInputParams& input_params);

} // namespace layer
} // namespace xllm
} // namespace xllm
xllm/core/layers/common/linear_impl.cpp → xllm/core/layers/common/linear.cpp (renamed)
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "linear_impl.h"
#include "linear.h"

#include <glog/logging.h>
#include <torch/torch.h>
@@ -82,9 +82,8 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(

torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) {
input = input.to(device_);
auto bias = (bias_.defined() && rank_ == 0)
? std::optional<torch::Tensor>(bias_)
: std::nullopt;
auto bias =
bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;

torch::Tensor output;

@@ -148,8 +147,8 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) {

// load the weight from the checkpoint
void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
const auto rank = rank_;
const auto world_size = world_size_;
const int64_t rank = rank_;
const int64_t world_size = world_size_;

// load and merge the weights on dim 0
// If quant_args_ indicates SmoothQuant, load qweight; otherwise, load
@@ -172,8 +171,8 @@ void ColumnParallelLinearImpl::load_state_dict(
void ColumnParallelLinearImpl::load_state_dict(
const StateDict& state_dict,
const std::vector<std::string>& prefixes) {
const auto rank = rank_;
const auto world_size = world_size_;
const int64_t rank = rank_;
const int64_t world_size = world_size_;

// load and merge the weights on dim 0
// If quant_args_ indicates SmoothQuant, load qweight
@@ -192,7 +191,6 @@ void ColumnParallelLinearImpl::load_state_dict(
break;
}
}

LOAD_FUSED_WEIGHT(qweight, 0);
LOAD_FUSED_WEIGHT(per_channel_scale, 0);
} else {
@@ -223,36 +221,32 @@ QKVParallelLinearImpl::QKVParallelLinearImpl(
parallel_args_(parallel_args),
options_(options),
device_(options.device()) {
const int32_t QKV_CNT = 3;
rank_ = parallel_args_.tp_group_->rank();
world_size_ = parallel_args_.tp_group_->world_size();
const int64_t out_features_per_partition =
(num_heads + 2 * num_kv_heads) * head_size;
// Note: torch.nn.functional.linear performs XA^T + b and as a result
// we allocate the transpose.
qkv_weight_ = register_parameter(
weight_ = register_parameter(
"weight",
torch::empty({out_features_per_partition, hidden_size}, options),
/*requires_grad=*/false);
qkv_weight_list_.resize(QKV_CNT);

if (bias) {
qkv_bias_ =
bias_ =
register_parameter("bias",
torch::empty({out_features_per_partition}, options),
/*requires_grad=*/false);
qkv_bias_list_.resize(QKV_CNT);
}
}

torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) {
input = input.to(device_);
auto bias = (qkv_bias_.defined() && rank_ == 0)
? std::optional<torch::Tensor>(qkv_bias_)
: std::nullopt;
auto bias =
bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;
xllm::kernel::MatmulParams matmul_params;
matmul_params.a = input;
matmul_params.b = qkv_weight_;
matmul_params.b = weight_;
matmul_params.bias = bias;

auto output = xllm::kernel::matmul(matmul_params);
@@ -262,46 +256,13 @@ torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) {
return output;
}

bool QKVParallelLinearImpl::load_qkv_weight(const StateDict& state_dict,
int32_t index) {
if (qkv_weight_list_[index].defined() || state_dict.size() == 0) {
return false;
}
DEFINE_WEIGHT(weight);
int64_t out_feature = num_heads_ * head_size_;
int64_t rank = rank_;
int64_t world_size = world_size_;
if (index > 0) {
rank = rank_ / num_kv_head_replicas_;
world_size = world_size_ / num_kv_head_replicas_;
out_feature = num_kv_heads_ * head_size_;
}
weight_ = torch::empty({out_feature, hidden_size_}, options_);
LOAD_SHARDED_WEIGHT(weight, 0);
if (weight_is_loaded_) {
qkv_weight_list_[index] = weight_.clone();
}
return weight_is_loaded_;
}

void QKVParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
std::vector<std::string> prefixes = {"q_proj.", "k_proj.", "v_proj."};
if (!qkv_weight_is_loaded_) {
bool all_loaded = true;
for (size_t i = 0; i < prefixes.size(); ++i) {
all_loaded =
all_loaded &&
load_qkv_weight(state_dict.get_dict_with_prefix(prefixes[i]), i);
}
if (all_loaded) {
const auto merged_weight = torch::cat(qkv_weight_list_, /*dim=*/0);
CHECK_EQ(qkv_weight_.sizes(), merged_weight.sizes())
<< "weight size mismatch";
qkv_weight_.copy_(merged_weight);
// release the memory for weight_list
qkv_weight_list_.clear();
qkv_weight_is_loaded_ = true;
}
const int64_t rank = rank_;
const int64_t world_size = world_size_;
LOAD_QKV_WEIGHT(weight, 0, num_kv_head_replicas_);
if (bias_.defined()) {
LOAD_QKV_WEIGHT(bias, 0, num_kv_head_replicas_);
}
}

@@ -424,8 +385,8 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) {

// load the weight from the checkpoint
void RowParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
const auto rank = rank_;
const auto world_size = world_size_;
const int64_t rank = rank_;
const int64_t world_size = world_size_;

// If quant_args_ indicates SmoothQuant, load qweight; otherwise, load
// normal weight.
@@ -462,7 +423,6 @@ ReplicatedLinearImpl::ReplicatedLinearImpl(
}

torch::Tensor ReplicatedLinearImpl::forward(torch::Tensor input) {
namespace F = torch::nn::functional;
auto bias =
bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;
xllm::kernel::MatmulParams matmul_params;
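As a side note on the fused layout used by QKVParallelLinearImpl above: weight_ stacks the query, key and value projections along dim 0, so a caller can split the fused matmul output back into per-projection tensors. The sketch below is a hypothetical caller-side split under that assumption (it does not appear in this diff); num_heads, num_kv_heads and head_size are the per-partition values passed to the constructor, and the sizes follow from out_features_per_partition = (num_heads + 2 * num_kv_heads) * head_size:

// Hypothetical split of the fused QKV output produced by
// QKVParallelLinearImpl::forward().
const std::vector<int64_t> split_sizes = {num_heads * head_size,
                                          num_kv_heads * head_size,
                                          num_kv_heads * head_size};
const auto chunks = output.split_with_sizes(split_sizes, /*dim=*/-1);
const torch::Tensor q = chunks[0];
const torch::Tensor k = chunks[1];
const torch::Tensor v = chunks[2];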