From 3bda858ad4c9bb6fd9c934e20d50f6763996bf5e Mon Sep 17 00:00:00 2001 From: guoxueting Date: Mon, 10 Nov 2025 15:51:58 +0800 Subject: [PATCH] feat: support qwen2 and rename qwen3 layers. --- .../framework/parallel_state/parallel_args.h | 2 +- xllm/core/framework/state_dict/utils.cpp | 24 +- xllm/core/framework/state_dict/utils.h | 18 +- xllm/core/layers/common/CMakeLists.txt | 12 +- xllm/core/layers/common/fuse_norm.cpp | 7 +- xllm/core/layers/common/indexer.cpp | 8 - xllm/core/layers/common/indexer.h | 3 - xllm/core/layers/common/layer_utils.cpp | 14 +- xllm/core/layers/common/layer_utils.h | 5 +- .../common/{linear_impl.cpp => linear.cpp} | 78 ++--- xllm/core/layers/common/linear.h | 261 +++++++++++----- xllm/core/layers/common/linear_impl.h | 289 ------------------ ...wen3_attention.cpp => qwen2_attention.cpp} | 48 +-- .../{qwen3_attention.h => qwen2_attention.h} | 12 +- ...oder_layer.cpp => qwen2_decoder_layer.cpp} | 16 +- ..._decoder_layer.h => qwen2_decoder_layer.h} | 10 +- .../layers/common/qwen3_moe_decoder_layer.cpp | 4 +- .../layers/common/qwen3_moe_decoder_layer.h | 4 +- .../layers/common/word_embedding_impl.cpp | 60 ++++ xllm/core/layers/common/word_embedding_impl.h | 47 +-- xllm/core/layers/lm_head.h | 26 +- xllm/core/layers/qwen2_decoder_layer.h | 11 + xllm/core/layers/qwen3_decoder_layer.h | 10 +- xllm/core/layers/rms_norm.h | 6 + xllm/core/layers/word_embedding.h | 8 + xllm/models/llm/llm_model_base.h | 30 +- xllm/models/llm/qwen2.h | 8 + xllm/models/llm/qwen3.h | 28 +- xllm/models/llm/qwen3_moe.h | 35 +-- xllm/models/models.h | 11 +- 30 files changed, 446 insertions(+), 649 deletions(-) rename xllm/core/layers/common/{linear_impl.cpp => linear.cpp} (87%) delete mode 100644 xllm/core/layers/common/linear_impl.h rename xllm/core/layers/common/{qwen3_attention.cpp => qwen2_attention.cpp} (80%) rename xllm/core/layers/common/{qwen3_attention.h => qwen2_attention.h} (84%) rename xllm/core/layers/common/{qwen3_decoder_layer.cpp => qwen2_decoder_layer.cpp} (86%) rename xllm/core/layers/common/{qwen3_decoder_layer.h => qwen2_decoder_layer.h} (89%) create mode 100644 xllm/core/layers/common/word_embedding_impl.cpp mode change 100755 => 100644 xllm/models/llm/qwen3.h diff --git a/xllm/core/framework/parallel_state/parallel_args.h b/xllm/core/framework/parallel_state/parallel_args.h index d680a3819..fab721eb9 100644 --- a/xllm/core/framework/parallel_state/parallel_args.h +++ b/xllm/core/framework/parallel_state/parallel_args.h @@ -99,10 +99,10 @@ struct ParallelArgs { // ep size PROPERTY(int32_t, ep_size) = 1; -#if defined(USE_NPU) // atb hccl mapping json data PROPERTY(nlohmann::json, mapping_data); +#if defined(USE_NPU) // atb hccl mapping PROPERTY(atb_speed::base::Mapping, mapping); diff --git a/xllm/core/framework/state_dict/utils.cpp b/xllm/core/framework/state_dict/utils.cpp index dadd7790a..375a8153f 100644 --- a/xllm/core/framework/state_dict/utils.cpp +++ b/xllm/core/framework/state_dict/utils.cpp @@ -81,14 +81,21 @@ void load_fused_weight(const StateDict& state_dict, int32_t world_size, std::vector& accumulated_tensors, torch::Tensor& weight, - bool& weight_is_loaded) { + bool& weight_is_loaded, + int32_t num_kv_head_replicas) { // return if the weight is already loaded if (weight_is_loaded) { return; } - weight_is_loaded = load_tensor_list( - state_dict, prefixes, name, dim, rank, world_size, accumulated_tensors); + weight_is_loaded = load_tensor_list(state_dict, + prefixes, + name, + dim, + rank, + world_size, + accumulated_tensors, + num_kv_head_replicas); if 
(weight_is_loaded) { const auto merged_weight = torch::cat(accumulated_tensors, /*dim=*/dim); @@ -106,7 +113,8 @@ bool load_tensor_list(const StateDict& state_dict, int64_t dim, int32_t rank, int32_t world_size, - std::vector& tensors) { + std::vector& tensors, + int32_t num_kv_head_replicas) { // resize the accumulated weight list if needed if (tensors.size() < prefixes.size()) { tensors.resize(prefixes.size()); @@ -118,6 +126,14 @@ bool load_tensor_list(const StateDict& state_dict, continue; } + // When the number of key/value heads is smaller than the number of query + // heads (e.g., multi-query/grouped-query attention), the key/value head may + // be replicated while the query heads are partitioned. + if (i == 1 && num_kv_head_replicas > 1) { + rank = rank / num_kv_head_replicas; + world_size = world_size / num_kv_head_replicas; + } + const std::string tensor_name = prefixes[i] + name; torch::Tensor tensor; if (dim < 0) { diff --git a/xllm/core/framework/state_dict/utils.h b/xllm/core/framework/state_dict/utils.h index 21d836dcb..039dbf410 100644 --- a/xllm/core/framework/state_dict/utils.h +++ b/xllm/core/framework/state_dict/utils.h @@ -56,7 +56,8 @@ void load_fused_weight(const StateDict& state_dict, int32_t world_size, std::vector& accumulated_tensors, torch::Tensor& weight, - bool& weight_is_loaded); + bool& weight_is_loaded, + int32_t num_kv_head_replicas = 1); bool load_tensor_list(const StateDict& state_dict, const std::vector& prefixes, @@ -64,7 +65,8 @@ bool load_tensor_list(const StateDict& state_dict, int64_t dim, int32_t rank, int32_t world_size, - std::vector& accumulated_tensors); + std::vector& accumulated_tensors, + int32_t num_kv_head_replicas = 1); void load_moe_weight(const StateDict& state_dict, const std::string& sub_prefix, @@ -114,6 +116,18 @@ void load_moe_fused_weight(const StateDict& state_dict, name##_, \ name##_is_loaded_); +#define LOAD_QKV_WEIGHT(name, dim, num_kv_head_replicas) \ + weight::load_fused_weight(state_dict, \ + prefixes, \ + #name, \ + dim, \ + rank, \ + world_size, \ + name##_list_, \ + name##_, \ + name##_is_loaded_, \ + num_kv_head_replicas); + #define LOAD_SHARDED_WEIGHT(name, dim) \ weight::load_sharded_weight( \ state_dict, #name, dim, rank, world_size, name##_, name##_is_loaded_); diff --git a/xllm/core/layers/common/CMakeLists.txt b/xllm/core/layers/common/CMakeLists.txt index 6109b26c8..b308b3731 100755 --- a/xllm/core/layers/common/CMakeLists.txt +++ b/xllm/core/layers/common/CMakeLists.txt @@ -6,15 +6,14 @@ cc_library( HDRS flashinfer_workspace.h deepseek_v2_attention.h - qwen3_attention.h + qwen2_attention.h attention.h fuse_norm.h rotary_embedding.h fused_moe.h dense_mlp.h - qwen3_decoder_layer.h + qwen2_decoder_layer.h qwen3_moe_decoder_layer.h - linear_impl.h linear.h word_embedding_impl.h layer_utils.h @@ -22,15 +21,16 @@ cc_library( SRCS flashinfer_workspace.cpp deepseek_v2_attention.cpp - qwen3_attention.cpp + qwen2_attention.cpp attention.cpp fuse_norm.cpp rotary_embedding.cpp fused_moe.cpp dense_mlp.cpp - qwen3_decoder_layer.cpp + qwen2_decoder_layer.cpp qwen3_moe_decoder_layer.cpp - linear_impl.cpp + linear.cpp + word_embedding_impl.cpp layer_utils.cpp indexer.cpp DEPS diff --git a/xllm/core/layers/common/fuse_norm.cpp b/xllm/core/layers/common/fuse_norm.cpp index c7e851e70..651cc920c 100644 --- a/xllm/core/layers/common/fuse_norm.cpp +++ b/xllm/core/layers/common/fuse_norm.cpp @@ -59,12 +59,7 @@ torch::Tensor FusedRMSNormImpl::forward_output(torch::Tensor& input, } void FusedRMSNormImpl::load_state_dict(const StateDict& 
state_dict) { - const auto weight = state_dict.get_tensor("weight"); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - } + LOAD_WEIGHT(weight); } } // namespace layer diff --git a/xllm/core/layers/common/indexer.cpp b/xllm/core/layers/common/indexer.cpp index f81cec6b1..f54750a0d 100644 --- a/xllm/core/layers/common/indexer.cpp +++ b/xllm/core/layers/common/indexer.cpp @@ -323,13 +323,5 @@ void IndexerImpl::load_state_dict(const StateDict& state_dict) { state_dict.get_dict_with_prefix("weights_proj.")); } -// whether the weight is loaded -void IndexerImpl::verify_loaded_weights(const std::string& prefix) const { - // Verify that all linear layers have loaded their weights - wq_b_->verify_loaded_weights(prefix + "wq_b."); - wk_->verify_loaded_weights(prefix + "wk."); - weights_proj_->verify_loaded_weights(prefix + "weights_proj."); -} - } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/indexer.h b/xllm/core/layers/common/indexer.h index 428dc7d80..4804eb767 100644 --- a/xllm/core/layers/common/indexer.h +++ b/xllm/core/layers/common/indexer.h @@ -61,9 +61,6 @@ class IndexerImpl : public torch::nn::Module { // load the weight from the checkpoint void load_state_dict(const StateDict& state_dict); - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const; - private: int64_t dim_; int64_t n_heads_; diff --git a/xllm/core/layers/common/layer_utils.cpp b/xllm/core/layers/common/layer_utils.cpp index 7580aeaef..ab8ac0e85 100644 --- a/xllm/core/layers/common/layer_utils.cpp +++ b/xllm/core/layers/common/layer_utils.cpp @@ -20,18 +20,6 @@ limitations under the License. namespace xllm { namespace layer { -bool is_dummy_run(const ModelInputParams& input_params, - const ParallelArgs& parallel_args) { - int64_t dp_rank = 0; - if (parallel_args.dp_size() > 1) { - dp_rank = parallel_args.dp_local_process_group_->rank(); - } - if (input_params.dp_global_token_nums.size() <= 1) { - return input_params.q_max_seq_len == 0; - } - return input_params.dp_global_token_nums[dp_rank] == 0; -} - void update_dummy_run_input(int64_t dp_rank, torch::Tensor& positions, ModelInputParams& input_params) { @@ -48,4 +36,4 @@ void update_dummy_run_input(int64_t dp_rank, } } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/common/layer_utils.h b/xllm/core/layers/common/layer_utils.h index 5761ffbf3..e9825b42d 100644 --- a/xllm/core/layers/common/layer_utils.h +++ b/xllm/core/layers/common/layer_utils.h @@ -20,12 +20,9 @@ limitations under the License. namespace xllm { namespace layer { -bool is_dummy_run(const ModelInputParams& input_params, - const ParallelArgs& parallel_args); - void update_dummy_run_input(int64_t dp_rank, torch::Tensor& positions, ModelInputParams& input_params); } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/common/linear_impl.cpp b/xllm/core/layers/common/linear.cpp similarity index 87% rename from xllm/core/layers/common/linear_impl.cpp rename to xllm/core/layers/common/linear.cpp index 8bdd50887..ff1ce0304 100644 --- a/xllm/core/layers/common/linear_impl.cpp +++ b/xllm/core/layers/common/linear.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "linear_impl.h" +#include "linear.h" #include #include @@ -82,9 +82,8 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl( torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) { input = input.to(device_); - auto bias = (bias_.defined() && rank_ == 0) - ? std::optional(bias_) - : std::nullopt; + auto bias = + bias_.defined() ? std::optional(bias_) : std::nullopt; torch::Tensor output; @@ -148,8 +147,8 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) { // load the weight from the checkpoint void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) { - const auto rank = rank_; - const auto world_size = world_size_; + const int64_t rank = rank_; + const int64_t world_size = world_size_; // load and merge the weights on dim 0 // If quant_args_ indicates SmoothQuant, load qweight; otherwise, load @@ -172,8 +171,8 @@ void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) { void ColumnParallelLinearImpl::load_state_dict( const StateDict& state_dict, const std::vector& prefixes) { - const auto rank = rank_; - const auto world_size = world_size_; + const int64_t rank = rank_; + const int64_t world_size = world_size_; // load and merge the weights on dim 0 // If quant_args_ indicates SmoothQuant, load qweight @@ -192,7 +191,6 @@ void ColumnParallelLinearImpl::load_state_dict( break; } } - LOAD_FUSED_WEIGHT(qweight, 0); LOAD_FUSED_WEIGHT(per_channel_scale, 0); } else { @@ -223,36 +221,32 @@ QKVParallelLinearImpl::QKVParallelLinearImpl( parallel_args_(parallel_args), options_(options), device_(options.device()) { - const int32_t QKV_CNT = 3; rank_ = parallel_args_.tp_group_->rank(); world_size_ = parallel_args_.tp_group_->world_size(); const int64_t out_features_per_partition = (num_heads + 2 * num_kv_heads) * head_size; // Note: torch.nn.functional.linear performs XA^T + b and as a result // we allocate the transpose. - qkv_weight_ = register_parameter( + weight_ = register_parameter( "weight", torch::empty({out_features_per_partition, hidden_size}, options), /*requires_grad=*/false); - qkv_weight_list_.resize(QKV_CNT); if (bias) { - qkv_bias_ = + bias_ = register_parameter("bias", torch::empty({out_features_per_partition}, options), /*requires_grad=*/false); - qkv_bias_list_.resize(QKV_CNT); } } torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) { input = input.to(device_); - auto bias = (qkv_bias_.defined() && rank_ == 0) - ? std::optional(qkv_bias_) - : std::nullopt; + auto bias = + bias_.defined() ? 
std::optional(bias_) : std::nullopt; xllm::kernel::MatmulParams matmul_params; matmul_params.a = input; - matmul_params.b = qkv_weight_; + matmul_params.b = weight_; matmul_params.bias = bias; auto output = xllm::kernel::matmul(matmul_params); @@ -262,46 +256,13 @@ torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) { return output; } -bool QKVParallelLinearImpl::load_qkv_weight(const StateDict& state_dict, - int32_t index) { - if (qkv_weight_list_[index].defined() || state_dict.size() == 0) { - return false; - } - DEFINE_WEIGHT(weight); - int64_t out_feature = num_heads_ * head_size_; - int64_t rank = rank_; - int64_t world_size = world_size_; - if (index > 0) { - rank = rank_ / num_kv_head_replicas_; - world_size = world_size_ / num_kv_head_replicas_; - out_feature = num_kv_heads_ * head_size_; - } - weight_ = torch::empty({out_feature, hidden_size_}, options_); - LOAD_SHARDED_WEIGHT(weight, 0); - if (weight_is_loaded_) { - qkv_weight_list_[index] = weight_.clone(); - } - return weight_is_loaded_; -} - void QKVParallelLinearImpl::load_state_dict(const StateDict& state_dict) { std::vector prefixes = {"q_proj.", "k_proj.", "v_proj."}; - if (!qkv_weight_is_loaded_) { - bool all_loaded = true; - for (size_t i = 0; i < prefixes.size(); ++i) { - all_loaded = - all_loaded && - load_qkv_weight(state_dict.get_dict_with_prefix(prefixes[i]), i); - } - if (all_loaded) { - const auto merged_weight = torch::cat(qkv_weight_list_, /*dim=*/0); - CHECK_EQ(qkv_weight_.sizes(), merged_weight.sizes()) - << "weight size mismatch"; - qkv_weight_.copy_(merged_weight); - // release the memory for weight_list - qkv_weight_list_.clear(); - qkv_weight_is_loaded_ = true; - } + const int64_t rank = rank_; + const int64_t world_size = world_size_; + LOAD_QKV_WEIGHT(weight, 0, num_kv_head_replicas_); + if (bias_.defined()) { + LOAD_QKV_WEIGHT(bias, 0, num_kv_head_replicas_); } } @@ -424,8 +385,8 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) { // load the weight from the checkpoint void RowParallelLinearImpl::load_state_dict(const StateDict& state_dict) { - const auto rank = rank_; - const auto world_size = world_size_; + const int64_t rank = rank_; + const int64_t world_size = world_size_; // If quant_args_ indicates SmoothQuant, load qweight; otherwise, load // normal weight. @@ -462,7 +423,6 @@ ReplicatedLinearImpl::ReplicatedLinearImpl( } torch::Tensor ReplicatedLinearImpl::forward(torch::Tensor input) { - namespace F = torch::nn::functional; auto bias = bias_.defined() ? std::optional(bias_) : std::nullopt; xllm::kernel::MatmulParams matmul_params; diff --git a/xllm/core/layers/common/linear.h b/xllm/core/layers/common/linear.h index a2b238ab8..8c9dd9aa9 100644 --- a/xllm/core/layers/common/linear.h +++ b/xllm/core/layers/common/linear.h @@ -18,18 +18,32 @@ limitations under the License. 
#include #include -#include "linear_impl.h" +#include "framework/parallel_state/parallel_args.h" +#include "framework/quant_args.h" +#include "framework/state_dict/state_dict.h" +#include "framework/state_dict/utils.h" namespace xllm { namespace layer { -class ColumnParallelLinear - : public torch::nn::ModuleHolder { - public: - using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = ColumnParallelLinearImpl; +// extra args for fused linear operation +struct FusedLinearExtraArgs { + // parameters for fusing smooth quant activation mode and is_gated + std::string act_mode; + bool is_gated; + + // default constructor + FusedLinearExtraArgs(const std::string& act_mode_ = "none", + bool is_gated_ = false) + : act_mode(act_mode_), is_gated(is_gated_) {} +}; - ColumnParallelLinear( +// Linear layer with column parallelism. +// The linear layer is defined as Y = XA + b. A is parallelized along +// its second dimension as A = [A_1, ..., A_p]. +class ColumnParallelLinearImpl : public torch::nn::Module { + public: + ColumnParallelLinearImpl( int64_t in_features, int64_t out_features, bool bias, @@ -37,52 +51,111 @@ class ColumnParallelLinear const QuantArgs& quant_args, const ParallelArgs& parallel_args, const torch::TensorOptions& options, - const FusedLinearExtraArgs& linear_extra_args = FusedLinearExtraArgs()) - : ModuleHolder( - std::make_shared(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options, - linear_extra_args)) {} + const FusedLinearExtraArgs& linear_extra_args = FusedLinearExtraArgs()); + + torch::Tensor forward(torch::Tensor input); + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict); + + // special load_state_dict for fused cases + void load_state_dict(const StateDict& state_dict, + const std::vector& prefixes); + + void pretty_print(std::ostream& stream) const { + stream << name() << " " << weight_.sizes() << " " << weight_.device(); + } + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } + + bool is_weight_loaded() const { return weight_is_loaded_; } + + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. 
+ // A^T: [out_features_per_partition, in_features] + DEFINE_FUSED_WEIGHT(weight); + DEFINE_FUSED_WEIGHT(qweight); + DEFINE_FUSED_WEIGHT(per_channel_scale); + DEFINE_WEIGHT(smooth); + DEFINE_FUSED_WEIGHT(bias); + + int64_t rank_; + int64_t world_size_; + // whether to gather the output + bool gather_output_; + at::Device device_; + // parallel args + ParallelArgs parallel_args_; + + // quantization args + QuantArgs quant_args_; + at::ScalarType output_dtype_; + FusedLinearExtraArgs linear_extra_args_; }; +TORCH_MODULE(ColumnParallelLinear); -class QKVParallelLinear - : public torch::nn::ModuleHolder { +class QKVParallelLinearImpl : public torch::nn::Module { public: - using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = QKVParallelLinearImpl; - - QKVParallelLinear(int64_t hidden_size, - int64_t num_heads, - int64_t num_kv_heads, - int64_t head_size, - int64_t num_kv_head_replicas, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : ModuleHolder( - std::make_shared(hidden_size, - num_heads, - num_kv_heads, - head_size, - num_kv_head_replicas, - bias, - gather_output, - parallel_args, - options)) {} + QKVParallelLinearImpl(int64_t hidden_size, + int64_t num_heads, + int64_t num_kv_heads, + int64_t head_size, + int64_t num_kv_head_replicas, + bool bias, + bool gather_output, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + torch::Tensor forward(torch::Tensor input); + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict); + + void pretty_print(std::ostream& stream) const { + stream << name() << " " << weight().sizes() << " " << weight().device(); + } + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } + + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. + // A^T: [out_features_per_partition, in_features] + DEFINE_FUSED_WEIGHT(weight); + DEFINE_FUSED_WEIGHT(bias); + + int64_t rank_; + int64_t world_size_; + int64_t hidden_size_; + int64_t num_heads_; + int64_t num_kv_heads_; + int64_t head_size_; + int64_t num_kv_head_replicas_; + // whether to gather the output + bool gather_output_; + at::Device device_; + // parallel args + ParallelArgs parallel_args_; + torch::TensorOptions options_; }; - -class RowParallelLinear - : public torch::nn::ModuleHolder { +TORCH_MODULE(QKVParallelLinear); + +// Linear layer with row parallelism. +// The linear layer is defined as Y = XA + b. A is parallelized along +// its first dimension and X along its second dimension as: +// - - +// | A_1 | +// | . | +// A = | . | X = [X_1, ..., X_p] +// | . 
| +// | A_p | +// - - +class RowParallelLinearImpl : public torch::nn::Module { public: - using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = RowParallelLinearImpl; - - RowParallelLinear( + RowParallelLinearImpl( int64_t in_features, int64_t out_features, bool bias, @@ -91,35 +164,77 @@ class RowParallelLinear const QuantArgs& quant_args, const ParallelArgs& parallel_args, const torch::TensorOptions& options, - const FusedLinearExtraArgs& linear_extra_args = FusedLinearExtraArgs()) - : ModuleHolder( - std::make_shared(in_features, - out_features, - bias, - input_is_parallelized, - if_reduce_results, - quant_args, - parallel_args, - options, - linear_extra_args)) {} + const FusedLinearExtraArgs& linear_extra_args = FusedLinearExtraArgs()); + + torch::Tensor forward(torch::Tensor input); + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict); + + void pretty_print(std::ostream& stream) const { + stream << name() << " " << weight_.sizes() << " " << weight_.device(); + } + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } + + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. + // A^T: [out_features, in_features_per_partition] + DEFINE_WEIGHT(weight); + DEFINE_WEIGHT(qweight); + DEFINE_WEIGHT(per_channel_scale); + DEFINE_WEIGHT(smooth); + DEFINE_WEIGHT(bias); + + // whether the input is already parallelized + bool input_is_parallelized_; + + // whether to reduce the results + bool if_reduce_results_; + + // parallel args + ParallelArgs parallel_args_; + + int64_t rank_; + int64_t world_size_; + + // quantization args + QuantArgs quant_args_; + at::ScalarType output_dtype_; + FusedLinearExtraArgs linear_extra_args_; }; +TORCH_MODULE(RowParallelLinear); -class ReplicatedLinear : public torch::nn::ModuleHolder { +class ReplicatedLinearImpl : public torch::nn::Module { public: - using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = ReplicatedLinearImpl; - - ReplicatedLinear(int64_t in_features, - int64_t out_features, - bool bias, - const QuantArgs& quant_args, - const torch::TensorOptions& options) - : ModuleHolder(std::make_shared(in_features, - out_features, - bias, - quant_args, - options)) {} + ReplicatedLinearImpl(int64_t in_features, + int64_t out_features, + bool bias, + const QuantArgs& quant_args, + const torch::TensorOptions& options); + + torch::Tensor forward(torch::Tensor input); + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict); + + void pretty_print(std::ostream& stream) const { + stream << name() << " " << weight_.sizes() << " " << weight_.device(); + } + + // return the weight (for testing) + torch::Tensor weight() const { return weight_; } + + private: + // parameter members, must be registered + // we allocate the transpose since linear performs XA^T. + // A^T: [out_features, in_features] + DEFINE_WEIGHT(weight); + DEFINE_WEIGHT(bias); }; +TORCH_MODULE(ReplicatedLinear); } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/linear_impl.h b/xllm/core/layers/common/linear_impl.h deleted file mode 100644 index ad6baface..000000000 --- a/xllm/core/layers/common/linear_impl.h +++ /dev/null @@ -1,289 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include -#include - -#include "framework/parallel_state/parallel_args.h" -#include "framework/quant_args.h" -#include "framework/state_dict/state_dict.h" -#include "framework/state_dict/utils.h" - -namespace xllm { -namespace layer { - -// extra args for fused linear operation -struct FusedLinearExtraArgs { - // parameters for fusing smooth quant activation mode and is_gated - std::string act_mode; - bool is_gated; - - // default constructor - FusedLinearExtraArgs(const std::string& act_mode_ = "none", - bool is_gated_ = false) - : act_mode(act_mode_), is_gated(is_gated_) {} -}; - -// an interface for parallel linear layer. -// all linear classes should inherit from this class and implement the forward -// function. -class ParallelLinearImpl : public torch::nn::Module { - public: - ~ParallelLinearImpl() override = default; - - virtual torch::Tensor forward(torch::Tensor input) = 0; - - virtual void load_state_dict(const StateDict& state_dict) = 0; - - virtual void verify_loaded_weights(const std::string& prefix = "") const = 0; - - // special load_state_dict for fused cases - virtual void load_state_dict(const StateDict& /*state_dict*/, - const std::vector& /*prefixes*/) { - LOG(FATAL) << "not implemented"; - } -}; - -// Linear layer with column parallelism. -// The linear layer is defined as Y = XA + b. A is parallelized along -// its second dimension as A = [A_1, ..., A_p]. -class ColumnParallelLinearImpl : public ParallelLinearImpl { - public: - ColumnParallelLinearImpl( - int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options, - const FusedLinearExtraArgs& linear_extra_args = FusedLinearExtraArgs()); - - torch::Tensor forward(torch::Tensor input) override; - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // special load_state_dict for fused cases - void load_state_dict(const StateDict& state_dict, - const std::vector& prefixes) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const override { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - - // return the weight (for testing) - torch::Tensor weight() const { return weight_; } - - bool is_weight_loaded() const { return weight_is_loaded_; } - - private: - // parameter members, must be registered - // we allocate the transpose since linear performs XA^T. 
- // A^T: [out_features_per_partition, in_features] - DEFINE_FUSED_WEIGHT(weight); - DEFINE_FUSED_WEIGHT(qweight); - DEFINE_FUSED_WEIGHT(per_channel_scale); - DEFINE_WEIGHT(smooth); - DEFINE_FUSED_WEIGHT(bias); - - int64_t rank_; - int64_t world_size_; - // whether to gather the output - bool gather_output_; - at::Device device_; - // parallel args - ParallelArgs parallel_args_; - - // quantization args - QuantArgs quant_args_; - at::ScalarType output_dtype_; - FusedLinearExtraArgs linear_extra_args_; -}; - -class QKVParallelLinearImpl : public ParallelLinearImpl { - public: - QKVParallelLinearImpl(int64_t hidden_size, - int64_t num_heads, - int64_t num_kv_heads, - int64_t head_size, - int64_t num_kv_head_replicas, - bool bias, - bool gather_output, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); - - torch::Tensor forward(torch::Tensor input) override; - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - bool load_qkv_weight(const StateDict& state_dict, int32_t index); - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const override { - CHECK(qkv_weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!qkv_bias_.defined() || qkv_bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight().sizes() << " " << weight().device(); - } - - // return the weight (for testing) - torch::Tensor weight() const { return qkv_weight_; } - - private: - // parameter members, must be registered - // we allocate the transpose since linear performs XA^T. - // A^T: [out_features_per_partition, in_features] - DEFINE_FUSED_WEIGHT(qkv_weight); - DEFINE_FUSED_WEIGHT(qkv_bias); - - int64_t rank_; - int64_t world_size_; - int64_t hidden_size_; - int64_t num_heads_; - int64_t num_kv_heads_; - int64_t head_size_; - int64_t num_kv_head_replicas_; - // whether to gather the output - bool gather_output_; - at::Device device_; - // parallel args - ParallelArgs parallel_args_; - torch::TensorOptions options_; -}; - -// Linear layer with row parallelism. -// The linear layer is defined as Y = XA + b. A is parallelized along -// its first dimension and X along its second dimension as: -// - - -// | A_1 | -// | . | -// A = | . | X = [X_1, ..., X_p] -// | . 
| -// | A_p | -// - - -class RowParallelLinearImpl : public ParallelLinearImpl { - public: - RowParallelLinearImpl( - int64_t in_features, - int64_t out_features, - bool bias, - bool input_is_parallelized, - bool if_reduce_results, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options, - const FusedLinearExtraArgs& linear_extra_args = FusedLinearExtraArgs()); - - torch::Tensor forward(torch::Tensor input) override; - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const override { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - - // return the weight (for testing) - torch::Tensor weight() const { return weight_; } - - private: - // parameter members, must be registered - // we allocate the transpose since linear performs XA^T. - // A^T: [out_features, in_features_per_partition] - DEFINE_WEIGHT(weight); - DEFINE_WEIGHT(qweight); - DEFINE_WEIGHT(per_channel_scale); - DEFINE_WEIGHT(smooth); - DEFINE_WEIGHT(bias); - - // whether the input is already parallelized - bool input_is_parallelized_; - - // whether to reduce the results - bool if_reduce_results_; - - // parallel args - ParallelArgs parallel_args_; - - int64_t rank_; - int64_t world_size_; - - // quantization args - QuantArgs quant_args_; - at::ScalarType output_dtype_; - FusedLinearExtraArgs linear_extra_args_; -}; - -class ReplicatedLinearImpl : public ParallelLinearImpl { - public: - ReplicatedLinearImpl(int64_t in_features, - int64_t out_features, - bool bias, - const QuantArgs& quant_args, - const torch::TensorOptions& options); - - torch::Tensor forward(torch::Tensor input) override; - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) override; - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix = "") const override { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - CHECK(!bias_.defined() || bias_is_loaded_) - << "bias is not loaded for " << prefix + "bias"; - } - - void pretty_print(std::ostream& stream) const override { - stream << name() << " " << weight_.sizes() << " " << weight_.device(); - } - - // return the weight (for testing) - torch::Tensor weight() const { return weight_; } - - private: - // parameter members, must be registered - // we allocate the transpose since linear performs XA^T. - // A^T: [out_features, in_features] - DEFINE_WEIGHT(weight); - DEFINE_WEIGHT(bias); -}; - -} // namespace layer -} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/common/qwen3_attention.cpp b/xllm/core/layers/common/qwen2_attention.cpp similarity index 80% rename from xllm/core/layers/common/qwen3_attention.cpp rename to xllm/core/layers/common/qwen2_attention.cpp index d3d087686..239cc8c7c 100644 --- a/xllm/core/layers/common/qwen3_attention.cpp +++ b/xllm/core/layers/common/qwen2_attention.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "qwen3_attention.h" +#include "qwen2_attention.h" #include @@ -22,14 +22,18 @@ limitations under the License. namespace xllm { namespace layer { -Qwen3AttentionImpl::Qwen3AttentionImpl(const ModelArgs& args, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { +Qwen2AttentionImpl::Qwen2AttentionImpl(const ModelContext& context) { + const auto& args = context.get_model_args(); + const auto& quant_args = context.get_quant_args(); + const auto& parallel_args = context.get_parallel_args(); + const auto& options = context.get_tensor_options(); const int64_t tp_size = parallel_args.tp_group_->world_size(); const int64_t total_num_heads = args.n_heads(); const int64_t total_num_kv_heads = args.n_kv_heads().value_or(args.n_heads()); + is_qwen3_style_ = + (args.model_type() == "qwen3" || args.model_type() == "qwen3_moe"); + CHECK(total_num_heads % tp_size == 0); num_heads_ = total_num_heads / tp_size; @@ -55,7 +59,7 @@ Qwen3AttentionImpl::Qwen3AttentionImpl(const ModelArgs& args, num_kv_heads_, args.head_dim(), num_kv_head_replicas_, - /*bias=*/false, + args.attention_bias(), /*gather_output=*/false, parallel_args, options)); @@ -72,11 +76,13 @@ Qwen3AttentionImpl::Qwen3AttentionImpl(const ModelArgs& args, options)); // 3. RMSNorm - q_norm_ = register_module( - "q_norm", RmsNorm(args.head_dim(), args.rms_norm_eps(), options)); + if (is_qwen3_style_) { + q_norm_ = register_module( + "q_norm", RmsNorm(args.head_dim(), args.rms_norm_eps(), options)); - k_norm_ = register_module( - "k_norm", RmsNorm(args.head_dim(), args.rms_norm_eps(), options)); + k_norm_ = register_module( + "k_norm", RmsNorm(args.head_dim(), args.rms_norm_eps(), options)); + } // 4. Rotary embedding rotary_emb_ = register_module("rope", @@ -95,7 +101,7 @@ Qwen3AttentionImpl::Qwen3AttentionImpl(const ModelArgs& args, args.sliding_window())); } -torch::Tensor Qwen3AttentionImpl::forward( +torch::Tensor Qwen2AttentionImpl::forward( const torch::Tensor& positions, const torch::Tensor& hidden_states, const AttentionMetadata& attn_metadata, @@ -109,11 +115,13 @@ torch::Tensor Qwen3AttentionImpl::forward( const int64_t T = q.size(0); - // 2. q-norm - q = q_norm_->forward(q); + if (is_qwen3_style_) { + // 2. q-norm + q = q_norm_->forward(q); - // 3. k-norm - k = k_norm_->forward(k); + // 3. k-norm + k = k_norm_->forward(k); + } // 4. 
rope rotary_emb_->forward(q, @@ -132,14 +140,12 @@ torch::Tensor Qwen3AttentionImpl::forward( return o_proj_->forward(out); } -void Qwen3AttentionImpl::load_state_dict(const StateDict& state_dict) { +void Qwen2AttentionImpl::load_state_dict(const StateDict& state_dict) { qkv_proj_->load_state_dict(state_dict); o_proj_->load_state_dict(state_dict.get_dict_with_prefix("o_proj.")); - if (auto w = state_dict.get_tensor("q_norm.weight"); w.defined()) { - q_norm_->load_state_dict(StateDict({{"weight", w}})); - } - if (auto w = state_dict.get_tensor("k_norm.weight"); w.defined()) { - k_norm_->load_state_dict(StateDict({{"weight", w}})); + if (is_qwen3_style_) { + q_norm_->load_state_dict(state_dict.get_dict_with_prefix("q_norm.")); + k_norm_->load_state_dict(state_dict.get_dict_with_prefix("k_norm.")); } } diff --git a/xllm/core/layers/common/qwen3_attention.h b/xllm/core/layers/common/qwen2_attention.h similarity index 84% rename from xllm/core/layers/common/qwen3_attention.h rename to xllm/core/layers/common/qwen2_attention.h index 6b2bc2ba2..8a2940f83 100644 --- a/xllm/core/layers/common/qwen3_attention.h +++ b/xllm/core/layers/common/qwen2_attention.h @@ -30,13 +30,10 @@ limitations under the License. namespace xllm { namespace layer { -class Qwen3AttentionImpl : public torch::nn::Module { +class Qwen2AttentionImpl : public torch::nn::Module { public: - Qwen3AttentionImpl() = default; - Qwen3AttentionImpl(const ModelArgs& args, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options); + Qwen2AttentionImpl() = default; + Qwen2AttentionImpl(const ModelContext& context); torch::Tensor forward(const torch::Tensor& positions, const torch::Tensor& hidden_states, @@ -53,6 +50,7 @@ class Qwen3AttentionImpl : public torch::nn::Module { int64_t q_size_; int64_t kv_size_; float scaling_; + bool is_qwen3_style_; QKVParallelLinear qkv_proj_{nullptr}; RowParallelLinear o_proj_{nullptr}; @@ -61,7 +59,7 @@ class Qwen3AttentionImpl : public torch::nn::Module { Attention attn_{nullptr}; RotaryEmbedding rotary_emb_{nullptr}; }; -TORCH_MODULE(Qwen3Attention); +TORCH_MODULE(Qwen2Attention); } // namespace layer } // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/common/qwen3_decoder_layer.cpp b/xllm/core/layers/common/qwen2_decoder_layer.cpp similarity index 86% rename from xllm/core/layers/common/qwen3_decoder_layer.cpp rename to xllm/core/layers/common/qwen2_decoder_layer.cpp index 51989714e..da4657c45 100644 --- a/xllm/core/layers/common/qwen3_decoder_layer.cpp +++ b/xllm/core/layers/common/qwen2_decoder_layer.cpp @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "qwen3_decoder_layer.h" +#include "qwen2_decoder_layer.h" #include @@ -22,7 +22,7 @@ limitations under the License. 
namespace xllm { namespace layer { -Qwen3DecoderImpl::Qwen3DecoderImpl(const ModelContext& context) +Qwen2DecoderImpl::Qwen2DecoderImpl(const ModelContext& context) : parallel_args_(context.get_parallel_args()) { const auto& model_args = context.get_model_args(); const auto& quant_args = context.get_quant_args(); @@ -30,9 +30,7 @@ Qwen3DecoderImpl::Qwen3DecoderImpl(const ModelContext& context) const auto& options = context.get_tensor_options(); // Initialize attention layers - attention_ = register_module( - "self_attn", - Qwen3Attention(model_args, quant_args, parallel_args, options)); + attention_ = register_module("self_attn", Qwen2Attention(context)); // Initialize norm layers input_norm_ = register_module( @@ -55,7 +53,7 @@ Qwen3DecoderImpl::Qwen3DecoderImpl(const ModelContext& context) options)); } -void Qwen3DecoderImpl::load_state_dict(const StateDict& state_dict) { +void Qwen2DecoderImpl::load_state_dict(const StateDict& state_dict) { attention_->load_state_dict(state_dict.get_dict_with_prefix("self_attn.")); input_norm_->load_state_dict( state_dict.get_dict_with_prefix("input_layernorm.")); @@ -64,15 +62,11 @@ void Qwen3DecoderImpl::load_state_dict(const StateDict& state_dict) { mlp_->load_state_dict(state_dict.get_dict_with_prefix("mlp.")); } -torch::Tensor Qwen3DecoderImpl::forward(torch::Tensor& x, +torch::Tensor Qwen2DecoderImpl::forward(torch::Tensor& x, torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, const ModelInputParams& input_params) { - bool is_dummy = is_dummy_run(input_params, parallel_args_); - if (is_dummy) { - return x; - } // Pre-attention norm auto residual = x; x = input_norm_->forward(x); diff --git a/xllm/core/layers/common/qwen3_decoder_layer.h b/xllm/core/layers/common/qwen2_decoder_layer.h similarity index 89% rename from xllm/core/layers/common/qwen3_decoder_layer.h rename to xllm/core/layers/common/qwen2_decoder_layer.h index f5c8cc26d..dc4b80406 100644 --- a/xllm/core/layers/common/qwen3_decoder_layer.h +++ b/xllm/core/layers/common/qwen2_decoder_layer.h @@ -29,16 +29,16 @@ limitations under the License. 
#include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" #include "layers/rms_norm.h" -#include "qwen3_attention.h" +#include "qwen2_attention.h" namespace xllm { namespace layer { -class Qwen3DecoderImpl : public torch::nn::Module { +class Qwen2DecoderImpl : public torch::nn::Module { public: - explicit Qwen3DecoderImpl(const ModelContext& context); + explicit Qwen2DecoderImpl(const ModelContext& context); - ~Qwen3DecoderImpl() {}; + ~Qwen2DecoderImpl() {}; void load_state_dict(const StateDict& state_dict); @@ -49,7 +49,7 @@ class Qwen3DecoderImpl : public torch::nn::Module { const ModelInputParams& input_params); private: - Qwen3Attention attention_{nullptr}; + Qwen2Attention attention_{nullptr}; DenseMLP mlp_{nullptr}; RmsNorm input_norm_{nullptr}; RmsNorm post_norm_{nullptr}; diff --git a/xllm/core/layers/common/qwen3_moe_decoder_layer.cpp b/xllm/core/layers/common/qwen3_moe_decoder_layer.cpp index ba261a5c4..73dca35fd 100644 --- a/xllm/core/layers/common/qwen3_moe_decoder_layer.cpp +++ b/xllm/core/layers/common/qwen3_moe_decoder_layer.cpp @@ -28,9 +28,7 @@ Qwen3MoeDecoderImpl::Qwen3MoeDecoderImpl(const ModelContext& context, const auto& options = context.get_tensor_options(); // Initialize attention layers - attention_ = register_module( - "self_attn", - Qwen3Attention(model_args, quant_args, parallel_args, options)); + attention_ = register_module("self_attn", Qwen2Attention(context)); // Initialize norm layers input_norm_ = register_module( diff --git a/xllm/core/layers/common/qwen3_moe_decoder_layer.h b/xllm/core/layers/common/qwen3_moe_decoder_layer.h index 448956293..71bbc955b 100644 --- a/xllm/core/layers/common/qwen3_moe_decoder_layer.h +++ b/xllm/core/layers/common/qwen3_moe_decoder_layer.h @@ -28,7 +28,7 @@ limitations under the License. #include "framework/state_dict/state_dict.h" #include "fused_moe.h" #include "layers/rms_norm.h" -#include "qwen3_attention.h" +#include "qwen2_attention.h" namespace xllm { namespace layer { @@ -48,7 +48,7 @@ class Qwen3MoeDecoderImpl : public torch::nn::Module { const ModelInputParams& input_params); private: - Qwen3Attention attention_{nullptr}; + Qwen2Attention attention_{nullptr}; DenseMLP mlp_{nullptr}; FusedMoE moe_mlp_{nullptr}; RmsNorm input_norm_{nullptr}; diff --git a/xllm/core/layers/common/word_embedding_impl.cpp b/xllm/core/layers/common/word_embedding_impl.cpp new file mode 100644 index 000000000..7acb70403 --- /dev/null +++ b/xllm/core/layers/common/word_embedding_impl.cpp @@ -0,0 +1,60 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "word_embedding_impl.h" + +namespace xllm { +namespace layer { + +WordEmbeddingImpl::WordEmbeddingImpl(int64_t num_embeddings, + int64_t embedding_dim, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) + : parallel_args_(parallel_args) { + rank_ = parallel_args_.tp_group_->rank(); + world_size_ = parallel_args_.tp_group_->world_size(); + + CHECK(embedding_dim % world_size_ == 0) + << "out_features " << embedding_dim << " not divisible by world_size " + << world_size_; + const int64_t embedding_dim_per_partition = embedding_dim / world_size_; + + // register the weight parameter + weight_ = register_parameter( + "weight", + torch::empty({num_embeddings, embedding_dim_per_partition}, options), + /*requires_grad=*/false); +} + +// The input to the module is a list of indices, and the output is the +// corresponding word embeddings. +torch::Tensor WordEmbeddingImpl::forward(torch::Tensor input) { + namespace F = torch::nn::functional; + auto output = F::embedding(input, weight_); + if (world_size_ > 1) { + output = xllm::parallel_state::gather(output, parallel_args_.tp_group_); + } + return output; +} + +// load the weight from the checkpoint +void WordEmbeddingImpl::load_state_dict(const StateDict& state_dict) { + const int64_t rank = rank_; + const int64_t world_size = world_size_; + LOAD_SHARDED_WEIGHT(weight, 1); +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/common/word_embedding_impl.h b/xllm/core/layers/common/word_embedding_impl.h index 948c1afea..a749c8e6a 100644 --- a/xllm/core/layers/common/word_embedding_impl.h +++ b/xllm/core/layers/common/word_embedding_impl.h @@ -23,6 +23,7 @@ limitations under the License. #include "framework/parallel_state/parallel_args.h" #include "framework/parallel_state/parallel_state.h" #include "framework/state_dict/state_dict.h" +#include "framework/state_dict/utils.h" namespace xllm { namespace layer { @@ -33,54 +34,14 @@ class WordEmbeddingImpl : public torch::nn::Module { WordEmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim, const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : parallel_args_(parallel_args) { - rank_ = parallel_args_.tp_group_->rank(); - world_size_ = parallel_args_.tp_group_->world_size(); - - CHECK(embedding_dim % world_size_ == 0) - << "out_features " << embedding_dim << " not divisible by world_size " - << world_size_; - const int64_t embedding_dim_per_partition = embedding_dim / world_size_; - - // register the weight parameter - weight_ = register_parameter( - "weight", - torch::empty({num_embeddings, embedding_dim_per_partition}, options), - /*requires_grad=*/false); - } + const torch::TensorOptions& options); // The input to the module is a list of indices, and the output is the // corresponding word embeddings. 
- torch::Tensor forward(torch::Tensor input) { - namespace F = torch::nn::functional; - auto output = F::embedding(input, weight_); - if (world_size_ > 1) { - output = xllm::parallel_state::gather(output, parallel_args_.tp_group_); - } - return output; - } + torch::Tensor forward(torch::Tensor input); // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - const auto weight = - state_dict.get_sharded_tensor("weight", - /*dim=*/1, - /*rank=*/rank_, - /*world_size=*/world_size_); - if (weight.defined()) { - CHECK_EQ(weight_.sizes(), weight.sizes()) - << "weight size mismatch for " << name(); - weight_.copy_(weight); - weight_is_loaded_ = true; - } - } - - // whether the weight is loaded - void verify_loaded_weights(const std::string& prefix) const { - CHECK(weight_is_loaded_) - << "weight is not loaded for " << prefix + "weight"; - } + void load_state_dict(const StateDict& state_dict); void pretty_print(std::ostream& stream) const override { stream << name() << " " << weight_.sizes() << " " << weight_.device(); diff --git a/xllm/core/layers/lm_head.h b/xllm/core/layers/lm_head.h index 3b6fcd491..2a0ef3495 100644 --- a/xllm/core/layers/lm_head.h +++ b/xllm/core/layers/lm_head.h @@ -18,8 +18,9 @@ limitations under the License. #if defined(USE_NPU) #include "npu/npu_lm_head_impl.h" #else -#include "common/linear_impl.h" +#include "common/linear.h" #endif +#include "core/framework/model_context.h" namespace xllm { namespace layer { @@ -39,20 +40,15 @@ class LmHead : public torch::nn::ModuleHolder { using torch::nn::ModuleHolder::ModuleHolder; using Impl __attribute__((__unused__)) = ColumnParallelLinearImpl; - LmHead(int64_t in_features, - int64_t out_features, - bool bias, - bool gather_output, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) - : ModuleHolder(std::make_shared(in_features, - out_features, - bias, - gather_output, - quant_args, - parallel_args, - options)) {} + LmHead(const ModelContext& context) + : ModuleHolder(std::make_shared( + context.get_model_args().hidden_size(), + context.get_model_args().vocab_size(), + /*bias=*/false, + /*gather_output=*/true, + QuantArgs{}, + context.get_parallel_args(), + context.get_tensor_options())) {} }; #endif diff --git a/xllm/core/layers/qwen2_decoder_layer.h b/xllm/core/layers/qwen2_decoder_layer.h index 52a6e924d..a8b466bb9 100644 --- a/xllm/core/layers/qwen2_decoder_layer.h +++ b/xllm/core/layers/qwen2_decoder_layer.h @@ -17,6 +17,8 @@ limitations under the License. #if defined(USE_NPU) #include "npu/npu_qwen2_decoder_layer_impl.h" +#elif defined(USE_MLU) +#include "common/qwen2_decoder_layer.h" #endif namespace xllm { @@ -32,6 +34,15 @@ class Qwen2DecoderLayer Qwen2DecoderLayer(const ModelContext& context) : ModuleHolder(std::make_shared(context)) {} }; +#elif defined(USE_MLU) +class Qwen2DecoderLayer : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + using Impl __attribute__((__unused__)) = Qwen2DecoderImpl; + + Qwen2DecoderLayer(const ModelContext& context) + : ModuleHolder(std::make_shared(context)) {} +}; #endif } // namespace layer diff --git a/xllm/core/layers/qwen3_decoder_layer.h b/xllm/core/layers/qwen3_decoder_layer.h index 324738d55..ae1e33a1a 100644 --- a/xllm/core/layers/qwen3_decoder_layer.h +++ b/xllm/core/layers/qwen3_decoder_layer.h @@ -18,7 +18,7 @@ limitations under the License. 
#if defined(USE_NPU) #include "npu/npu_qwen3_decoder_layer_impl.h" #else -#include "common/qwen3_decoder_layer.h" +#include "common/qwen2_decoder_layer.h" #endif namespace xllm { @@ -35,13 +35,13 @@ class Qwen3DecoderLayer : ModuleHolder(std::make_shared(context)) {} }; #else -class Qwen3DecoderLayer : public torch::nn::ModuleHolder { +class Qwen3DecoderLayer : public torch::nn::ModuleHolder { public: - using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = Qwen3DecoderImpl; + using torch::nn::ModuleHolder::ModuleHolder; + using Impl __attribute__((__unused__)) = Qwen2DecoderImpl; Qwen3DecoderLayer(const ModelContext& context) - : ModuleHolder(std::make_shared(context)) {} + : ModuleHolder(std::make_shared(context)) {} }; #endif diff --git a/xllm/core/layers/rms_norm.h b/xllm/core/layers/rms_norm.h index d8920c680..a9a2f3c9d 100644 --- a/xllm/core/layers/rms_norm.h +++ b/xllm/core/layers/rms_norm.h @@ -19,6 +19,7 @@ limitations under the License. #else #include "common/fuse_norm.h" #endif +#include "core/framework/model_context.h" namespace xllm { namespace layer { @@ -40,6 +41,11 @@ class RmsNorm : public torch::nn::ModuleHolder { RmsNorm(int64_t dim, double eps, const torch::TensorOptions& options) : ModuleHolder(std::make_shared(dim, eps, options)) {} + RmsNorm(const ModelContext& context) + : ModuleHolder(std::make_shared( + context.get_model_args().hidden_size(), + context.get_model_args().rms_norm_eps(), + context.get_tensor_options())) {} }; #endif diff --git a/xllm/core/layers/word_embedding.h b/xllm/core/layers/word_embedding.h index c377dcc2a..d7ba3da38 100644 --- a/xllm/core/layers/word_embedding.h +++ b/xllm/core/layers/word_embedding.h @@ -20,6 +20,7 @@ limitations under the License. #else #include "common/word_embedding_impl.h" #endif +#include "core/framework/model_context.h" namespace xllm { namespace layer { @@ -39,6 +40,7 @@ class WordEmbedding : public torch::nn::ModuleHolder { public: using torch::nn::ModuleHolder::ModuleHolder; using Impl __attribute__((__unused__)) = WordEmbeddingImpl; + WordEmbedding(int64_t num_embeddings, int64_t embedding_dim, const ParallelArgs& parallel_args, @@ -47,6 +49,12 @@ class WordEmbedding : public torch::nn::ModuleHolder { embedding_dim, parallel_args, options)) {} + WordEmbedding(const ModelContext& context) + : ModuleHolder(std::make_shared( + context.get_model_args().vocab_size(), + context.get_model_args().hidden_size(), + context.get_parallel_args(), + context.get_tensor_options())) {} }; #endif diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index b8aa9705e..99cdecbaf 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -40,6 +40,7 @@ limitations under the License. 
#include "xllm_kernels/core/include/atb_speed/log.h" #else #include "core/layers/common/attention.h" +#include "core/layers/common/layer_utils.h" #endif namespace xllm { @@ -306,15 +307,18 @@ class LlmModelImplBase : public torch::nn::Module { auto cancated_h = torch::cat(hs, 0); return norm_(cancated_h, 0); #else - bool is_prefill = input_params[0].q_max_seq_len > 1; + auto modified_input_params = input_params[0]; + auto position = positions[0]; + layer::update_dummy_run_input(dp_rank_, position, modified_input_params); + bool is_prefill = modified_input_params.q_max_seq_len > 1; auto attn_metadata = - layer::AttentionMetadata::build(input_params[0], is_prefill); + layer::AttentionMetadata::build(modified_input_params, is_prefill); torch::Tensor h; for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; h = layer( - hs[0], positions[0], attn_metadata, kv_caches[i], input_params[0]); + hs[0], position, attn_metadata, kv_caches[i], modified_input_params); } return norm_(h); #endif @@ -369,13 +373,14 @@ class LlmModelImplBase : public torch::nn::Module { } protected: -#if defined(USE_NPU) - torch::Tensor cos_pos_; - torch::Tensor sin_pos_; torch::Tensor cos_sin_; int max_seq_len_ = 0; + torch::Tensor cos_pos_; + torch::Tensor sin_pos_; int device_id = 0; layer::AttentionMask attn_mask_; + int dp_rank_ = 0; +#if defined(USE_NPU) std::vector atb_pos_embeds_; #endif @@ -403,20 +408,7 @@ class LlmForCausalLMImplBase : public torch::nn::Module { // register submodules model_ = register_module("model", LlmModelType(context)); -#if defined(USE_NPU) lm_head_ = register_module("lm_head", layer::LmHead(context)); -#else - // lm_head_ is default to no quantization - lm_head_ = - register_module("lm_head", - layer::LmHead(context.get_model_args().hidden_size(), - context.get_model_args().vocab_size(), - /*bias=*/false, - /*gather_output=*/true, - QuantArgs{}, - context.get_parallel_args(), - context.get_tensor_options())); -#endif } torch::Tensor get_input_embeddings(torch::Tensor input_ids) { diff --git a/xllm/models/llm/qwen2.h b/xllm/models/llm/qwen2.h index c510471cd..b137e9939 100644 --- a/xllm/models/llm/qwen2.h +++ b/xllm/models/llm/qwen2.h @@ -39,13 +39,19 @@ class QWen2ModelImpl : public LlmModelImplBase { // register submodules auto model_args = context.get_model_args(); auto options = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + auto dp_local_tp_size = + parallel_args.world_size() / parallel_args.dp_size(); + dp_rank_ = parallel_args.rank() / dp_local_tp_size; blocks_ = register_module("layers", torch::nn::ModuleList()); layers_.reserve(model_args.n_layers()); norm_ = register_module("norm", layer::RmsNorm(context)); for (auto i = 0; i < FLAGS_micro_batch_num; i++) { embed_tokens_.push_back(layer::WordEmbedding(context)); +#if defined(USE_NPU) atb_pos_embeds_.push_back(layer::PosEmbedding(context)); +#endif } cos_sin_ = get_concat_rotary_embedding( model_args.hidden_size() / model_args.n_heads(), @@ -87,6 +93,8 @@ REGISTER_MODEL_ARGS(qwen2, [&] { LOAD_ARG_OR(n_layers, "num_hidden_layers", 28); LOAD_ARG_OR(n_heads, "num_attention_heads", 28); LOAD_ARG(n_kv_heads, "num_key_value_heads"); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(attention_bias, "attention_bias", true); // LOAD_ARG_OR(no_bias, "no_bias", true); LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944); LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768); diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h old mode 100755 
new mode 100644 index 4d35d9c56..fbbffbae7 --- a/xllm/models/llm/qwen3.h +++ b/xllm/models/llm/qwen3.h @@ -35,19 +35,25 @@ class QWen3ModelImpl : public LlmModelImplBase { // register submodules auto model_args = context.get_model_args(); auto options = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + auto dp_local_tp_size = + parallel_args.world_size() / parallel_args.dp_size(); + dp_rank_ = parallel_args.rank() / dp_local_tp_size; blocks_ = register_module("layers", torch::nn::ModuleList()); layers_.reserve(model_args.n_layers()); -#if defined(USE_NPU) norm_ = register_module("norm", layer::RmsNorm(context)); for (auto i = 0; i < FLAGS_micro_batch_num; i++) { embed_tokens_.push_back(layer::WordEmbedding(context)); +#if defined(USE_NPU) atb_pos_embeds_.push_back(layer::PosEmbedding(context)); +#endif } cos_sin_ = get_concat_rotary_embedding(128, model_args.max_position_embeddings(), model_args.rope_theta(), options); +#if defined(USE_NPU) int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1; // encode_attn_mask_ = // layer::AttentionMask(options.device(), @@ -56,17 +62,6 @@ class QWen3ModelImpl : public LlmModelImplBase { attn_mask_ = layer::AttentionMask(options.device(), options.dtype().toScalarType(), /*mask_value=*/mask_value); -#else - norm_ = register_module( - "norm", - layer::RmsNorm( - model_args.hidden_size(), model_args.rms_norm_eps(), options)); - for (auto i = 0; i < FLAGS_micro_batch_num; i++) { - embed_tokens_.push_back(layer::WordEmbedding(model_args.vocab_size(), - model_args.hidden_size(), - context.get_parallel_args(), - options)); - } #endif for (int32_t i = 0; i < model_args.n_layers(); i++) { @@ -226,15 +221,18 @@ class QWen3ModelImpl : public LlmModelImplBase { auto cancated_h = torch::cat(hs, 0); return norm_(cancated_h, 0); #else - bool is_prefill = input_params[0].q_max_seq_len > 1; + auto modified_input_params = input_params[0]; + auto position = positions[0]; + layer::update_dummy_run_input(dp_rank_, position, modified_input_params); + bool is_prefill = modified_input_params.q_max_seq_len > 1; auto attn_metadata = - layer::AttentionMetadata::build(input_params[0], is_prefill); + layer::AttentionMetadata::build(modified_input_params, is_prefill); torch::Tensor h; for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; h = layer( - hs[0], positions[0], attn_metadata, kv_caches[i], input_params[0]); + hs[0], position, attn_metadata, kv_caches[i], modified_input_params); } return norm_(h); #endif diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index 0ab429bed..cfd3dbcd4 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -145,11 +145,9 @@ class Qwen3MoeModelImpl : public torch::nn::Module { device_ = options.device(); dtype_ = options.dtype().toScalarType(); num_speculative_tokens_ = model_args.num_speculative_tokens(); -#if defined(USE_NPU) embed_tokens_ = register_module("embed_tokens", layer::WordEmbedding(context)); - atb_pos_emb_ = layer::PosEmbedding(context); cos_sin_ = get_qwen3_moe_rotary_embedding(128, model_args.max_position_embeddings(), @@ -157,24 +155,16 @@ class Qwen3MoeModelImpl : public torch::nn::Module { options); max_seq_len_ = model_args.max_position_embeddings(); +#if defined(USE_NPU) + atb_pos_emb_ = layer::PosEmbedding(context); int32_t mask_value = model_args.dtype() == "bfloat16" ? 
1 : -9984; attn_mask_ = layer::AttentionMask(options.device(), options.dtype().toScalarType(), /*mask_value=*/mask_value); +#endif norm_ = register_module("norm", layer::RmsNorm(context)); mapping_data_ = parallel_args.mapping_data(); -#else - norm_ = register_module( - "norm", - layer::RmsNorm( - model_args.hidden_size(), model_args.rms_norm_eps(), options)); - embed_tokens_ = - register_module("embed_tokens", - layer::WordEmbedding(model_args.vocab_size(), - model_args.hidden_size(), - context.get_parallel_args(), - options)); -#endif + for (int32_t i = 0; i < model_args.n_layers(); ++i) { auto block = Qwen3MoeDecoderLayer(context, i); layers_.push_back(block); @@ -368,8 +358,8 @@ class Qwen3MoeModelImpl : public torch::nn::Module { layer::WordEmbedding embed_tokens_{nullptr}; layer::AttentionMask attn_mask_; layer::RmsNorm norm_{nullptr}; -#if defined(USE_NPU) torch::Tensor cos_sin_; +#if defined(USE_NPU) layer::PosEmbedding atb_pos_emb_{nullptr}; #endif std::vector mrope_section_; @@ -380,20 +370,7 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { public: Qwen3MoeForCausalLMImpl(const ModelContext& context) { model_ = register_module("model", Qwen3MoeModel(context)); -#if defined(USE_NPU) lm_head_ = register_module("lm_head", layer::LmHead(context)); -#else - // lm_head_ is default to no quantization - lm_head_ = - register_module("lm_head", - layer::LmHead(context.get_model_args().hidden_size(), - context.get_model_args().vocab_size(), - /*bias=*/false, - /*gather_output=*/true, - QuantArgs{}, - context.get_parallel_args(), - context.get_tensor_options())); -#endif } // tokens: [num_tokens] @@ -508,4 +485,4 @@ REGISTER_MODEL_ARGS(qwen3_moe, [&] { SET_ARG(stop_token_ids, std::unordered_set({args->eos_token_id()})); }); -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/models/models.h b/xllm/models/models.h index 161a4c318..c2f0e0503 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -15,6 +15,11 @@ limitations under the License. #pragma once +#include "llm/llm_model_base.h" // IWYU pragma: keep +#include "llm/qwen2.h" // IWYU pragma: keep +#include "llm/qwen3.h" // IWYU pragma: keep +#include "llm/qwen3_moe.h" // IWYU pragma: keep + #if defined(USE_NPU) #include "dit/pipeline_flux.h" // IWYU pragma: keep #include "dit/pipeline_flux_control.h" // IWYU pragma: keep @@ -27,15 +32,9 @@ limitations under the License. #include "llm/kimi_k2.h" // IWYU pragma: keep #include "llm/llama.h" // IWYU pragma: keep #include "llm/llama3.h" // IWYU pragma: keep -#include "llm/llm_model_base.h" // IWYU pragma: keep -#include "llm/qwen2.h" // IWYU pragma: keep #include "llm/qwen3_embedding.h" // IWYU pragma: keep #include "vlm/minicpmv.h" // IWYU pragma: keep #include "vlm/qwen2_5_vl.h" // IWYU pragma: keep #include "vlm/qwen3_vl.h" // IWYU pragma: keep #include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep #endif - -#include "llm/llm_model_base.h" // IWYU pragma: keep -#include "llm/qwen3.h" // IWYU pragma: keep -#include "llm/qwen3_moe.h" // IWYU pragma: keep
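

The central change in this patch is the new fused-QKV load path (QKVParallelLinearImpl plus the LOAD_QKV_WEIGHT macro and the num_kv_head_replicas argument threaded into load_tensor_list): q, k and v checkpoint tensors are sliced per tensor-parallel rank, and when num_kv_head_replicas > 1 (grouped-query attention) several ranks reuse the same k/v shard. A minimal standalone sketch of that slicing rule, assuming libtorch; shard_rows and the toy shapes below are hypothetical and only mirror the rank/world_size adjustment added in the diff, they are not the patch's API.

#include <torch/torch.h>

// Hypothetical helper (not in the patch): return this rank's row shard of a
// checkpoint tensor laid out as [out_features, in_features]. For k/v
// projections, several tensor-parallel ranks map onto the same shard when the
// kv heads are replicated, mirroring the num_kv_head_replicas adjustment in
// load_tensor_list().
torch::Tensor shard_rows(const torch::Tensor& full,
                         int64_t rank,
                         int64_t world_size,
                         bool is_kv,
                         int64_t num_kv_head_replicas) {
  if (is_kv && num_kv_head_replicas > 1) {
    rank /= num_kv_head_replicas;
    world_size /= num_kv_head_replicas;
  }
  return full.chunk(world_size, /*dim=*/0)[rank];
}

int main() {
  // Assumed toy shapes: 8 query heads, 2 kv heads, head_size 4, hidden 16, tp 4.
  const int64_t head = 4, hidden = 16, tp = 4;
  const int64_t kv_replicas = tp / 2;  // 2 kv heads shared by 4 ranks
  auto q = torch::randn({8 * head, hidden});
  auto k = torch::randn({2 * head, hidden});
  for (int64_t r = 0; r < tp; ++r) {
    auto q_shard = shard_rows(q, r, tp, /*is_kv=*/false, kv_replicas);
    auto k_shard = shard_rows(k, r, tp, /*is_kv=*/true, kv_replicas);
    // Each rank would then cat [q_shard; k_shard; v_shard] into its fused weight_.
    TORCH_CHECK(q_shard.size(0) == 2 * head);
    TORCH_CHECK(k_shard.size(0) == 1 * head);
  }
  return 0;
}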
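
ColumnParallelLinearImpl and RowParallelLinearImpl (folded from linear_impl.h into linear.h/linear.cpp) keep the weight transposed as [out_features, in_features] because torch's linear computes X·A^T: column parallelism shards dim 0 and concatenates per-rank outputs (the gather_output step), while row parallelism shards dim 1 and sums the partial outputs (the reduce/all-reduce step). A self-contained sketch of that equivalence, assuming libtorch and simulating the ranks in a loop instead of a real process group.

#include <torch/torch.h>
#include <vector>

int main() {
  namespace F = torch::nn::functional;
  const int64_t world_size = 4, in = 16, out = 8, tokens = 3;
  auto x = torch::randn({tokens, in});
  auto w = torch::randn({out, in});  // stored transposed: A^T
  auto reference = F::linear(x, w);

  // Column parallel: each rank owns out/world_size rows of A^T; outputs are
  // concatenated along the feature dimension.
  std::vector<torch::Tensor> col_outs;
  for (const auto& w_shard : w.chunk(world_size, /*dim=*/0)) {
    col_outs.push_back(F::linear(x, w_shard));
  }
  auto col_result = torch::cat(col_outs, /*dim=*/-1);

  // Row parallel: each rank owns in/world_size columns of A^T plus the matching
  // slice of x; partial outputs are summed (the all-reduce).
  auto x_shards = x.chunk(world_size, /*dim=*/-1);
  auto w_shards = w.chunk(world_size, /*dim=*/1);
  auto row_result = torch::zeros({tokens, out});
  for (int64_t r = 0; r < world_size; ++r) {
    row_result += F::linear(x_shards[r], w_shards[r]);
  }

  TORCH_CHECK(torch::allclose(reference, col_result, /*rtol=*/1e-4, /*atol=*/1e-4));
  TORCH_CHECK(torch::allclose(reference, row_result, /*rtol=*/1e-4, /*atol=*/1e-4));
  return 0;
}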
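
WordEmbeddingImpl, moved out of the header into word_embedding_impl.cpp, shards the embedding table along the embedding dimension (dim 1, hence LOAD_SHARDED_WEIGHT(weight, 1)); each rank looks up its slice and the slices are gathered back into the full hidden vector in forward(). A small sketch of that layout, assuming libtorch; the plain cat below stands in for parallel_state::gather over the tp group.

#include <torch/torch.h>
#include <vector>

int main() {
  namespace F = torch::nn::functional;
  const int64_t vocab = 32, world_size = 2;
  auto table = torch::randn({vocab, 8});     // [num_embeddings, embedding_dim]
  auto ids = torch::tensor({1, 5, 7});       // token indices
  auto reference = F::embedding(ids, table);

  // Per-rank lookup on an embedding-dim shard, then gather to rebuild the
  // full hidden vector, as in WordEmbeddingImpl::forward.
  std::vector<torch::Tensor> partials;
  for (const auto& shard : table.chunk(world_size, /*dim=*/1)) {
    partials.push_back(F::embedding(ids, shard));
  }
  auto gathered = torch::cat(partials, /*dim=*/-1);

  TORCH_CHECK(torch::allclose(reference, gathered));
  return 0;
}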