From 82d1febf4fe92a3f1962047f41f64852e73abc99 Mon Sep 17 00:00:00 2001
From: limenxin1
Date: Thu, 21 Aug 2025 11:59:55 +0800
Subject: [PATCH] feat: optimize the grouped matmul operator.

---
 third_party/xllm_ops                              |  2 +-
 .../aclnn/ops/grouped_matmul_operation.cpp        | 76 +++++++++++++++----
 .../aclnn/ops/grouped_matmul_operation.h          |  1 +
 .../ascend/deepseek_v2_decoder_layer.cpp          |  8 +-
 4 files changed, 66 insertions(+), 21 deletions(-)

diff --git a/third_party/xllm_ops b/third_party/xllm_ops
index 73746aba..a5b84ac9 160000
--- a/third_party/xllm_ops
+++ b/third_party/xllm_ops
@@ -1 +1 @@
-Subproject commit 73746aba666c24fa30aebb45cf8be09955e4342e
+Subproject commit a5b84ac9bfe0a0c69990e457d9bc756a67a2ce4d
diff --git a/xllm/core/kernels/ascend/a2/atb_layers/operations/aclnn/ops/grouped_matmul_operation.cpp b/xllm/core/kernels/ascend/a2/atb_layers/operations/aclnn/ops/grouped_matmul_operation.cpp
index f69f24fe..4c94d025 100755
--- a/xllm/core/kernels/ascend/a2/atb_layers/operations/aclnn/ops/grouped_matmul_operation.cpp
+++ b/xllm/core/kernels/ascend/a2/atb_layers/operations/aclnn/ops/grouped_matmul_operation.cpp
@@ -25,6 +25,7 @@
 #include "atb_speed/log.h"
 #include "atb_speed/utils/timer.h"
 #include "aclnnop/aclnn_grouped_matmul_v4.h"
+#include "aclnn_index_group_matmul.h"
 #include "operations/aclnn/utils/utils.h"
 #include "atb_speed/utils/check_util.h"
@@ -50,7 +51,7 @@ atb::Status GroupedMatmulOperation::InferShape(
         ACL_BF16 : ACL_FLOAT16;
     outTensorDescs.at(DIM0).shape.dimNum = inTensorDescs.at(DIM0).shape.dimNum;
-    int nDim = param_.transposeB ? DIM1 : DIM2;
+    int nDim = DIM2;
     ATB_SPEED_LOG_DEBUG(opName_ << "GroupedMatmulOperation infer shape origin inTensorDescs.at(DIM0).shape.dims[DIM0]"
                         << inTensorDescs.at(DIM0).shape.dims[DIM0]);
     ATB_SPEED_LOG_DEBUG(opName_ << "GroupedMatmulOperation infer shape origin inTensorDescs.at(DIM1).shape.dims[nDim]"
@@ -95,9 +96,9 @@ atb::Dims GetWeightStorageShape(const atb::TensorDesc atbTensorDesc)
         // (group_size, n, k) => (group_size, n / 16, k / 16, 16, 16)
         storageTensorDims.dims[0] = atbTensorDesc.shape.dims[0];
         storageTensorDims.dims[1] = 1 + ((atbTensorDesc.shape.dims[1] - 1) / 16); // 1, 16:1: 维度, 16: padding大小
-        storageTensorDims.dims[2] = 1 + ((atbTensorDesc.shape.dims[2] - 1) / 16); // 2, 16:1: 维度, 16: padding大小
+        storageTensorDims.dims[2] = 1 + ((atbTensorDesc.shape.dims[2] - 1) / 32); // 2: dim index, 32: padding size
         storageTensorDims.dims[3] = 16; // 3, 16:NZ格式要求
-        storageTensorDims.dims[4] = 16; // 4, 16:NZ格式要求
+        storageTensorDims.dims[4] = 32; // 4, 32: required by the NZ format
     }
     return storageTensorDims;
 }
@@ -124,10 +125,18 @@ atb::Status GroupedMatmulOperation::CreateAclNNInTensorVariantPack(const atb::Va
     AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
     aclnnVariantPack.aclInTensors.resize(GetInputNum());
     uint32_t inTensorCount = aclnnVariantPack.aclInTensors.size();
+    atb::Tensor x = variantPack.inTensors.at(DIM0);
+    atb::Tensor y = variantPack.inTensors.at(DIM1);
+    // Token counts in (8, 2048) with a weight whose last dim is 7168 take the aclnnIndexGroupMatmul
+    // path, where the last input sits at argument index 4; aclnnGroupedMatmulV4 expects index 7.
+    int index = 7;
+    if (x.desc.shape.dims[0] > 8 && x.desc.shape.dims[0] < 2048 && y.desc.shape.dims[2] == 7168) {
+        index = 4;
+    }
     for (size_t i = 0; i < inTensorCount; i++) {
         std::shared_ptr<AclNNTensor> aclnnTensor = std::make_shared<AclNNTensor>();
         if (i == inTensorCount - 1) {
-            aclnnTensor->tensorIdx = 7; // 7 : for the last tensor
+            aclnnTensor->tensorIdx = index; // index of the last input tensor for the selected kernel
         } else {
             aclnnTensor->tensorListidx = i;
             aclnnTensor->tensorIdx = 0;
@@ -154,16 +163,16 @@ atb::Status GroupedMatmulOperation::CreateAclNNInTensorVariantPack(const atb::Va
             storageTensorDims = GetWeightStorageShape(squeezedAtbTensor.desc);
         }

-        // ViewShape and Stride
+
         atb::Dims viewDims = squeezedAtbTensor.desc.shape;
-        if (squeezedAtbTensor.desc.shape.dimNum >= 3 && this->param_.transposeB) { // 3: 维度
-            aclnnTensor->strides = GetTransposeTensorStride(viewDims);
-            viewDims.dims[0] = squeezedAtbTensor.desc.shape.dims[0];
-            viewDims.dims[1] = squeezedAtbTensor.desc.shape.dims[2]; // 1, 2: 后两维转置
-            viewDims.dims[2] = squeezedAtbTensor.desc.shape.dims[1]; // 1, 2: 后两维转置
-        } else {
+        // if (squeezedAtbTensor.desc.shape.dimNum >= 3 && this->param_.transposeB) { // 3: dim count
+        //     aclnnTensor->strides = GetTransposeTensorStride(viewDims);
+        //     viewDims.dims[0] = squeezedAtbTensor.desc.shape.dims[0];
+        //     viewDims.dims[1] = squeezedAtbTensor.desc.shape.dims[2]; // 1, 2: transpose the last two dims
+        //     viewDims.dims[2] = squeezedAtbTensor.desc.shape.dims[1]; // 1, 2: transpose the last two dims
+        // } else {
             aclnnTensor->strides = GetCopyTensorStride(viewDims);
-        }
+        // }
         CHECK_OPERATION_STATUS_RETURN(
             CallAclCreateTensor(viewDims, storageTensorDims, squeezedAtbTensor, aclnnTensor));
         aclnnVariantPack.aclInTensors[i] = aclnnTensor;
@@ -266,9 +275,26 @@ int GroupedMatmulOperation::CreateA16(AclNNVariantPack &aclnnVariantPack)
                                                    &this->aclnnOpCache_->aclExecutor);
     return ret;
 }

+int GroupedMatmulOperation::IndexGmmQuant(AclNNVariantPack &aclnnVariantPack)
+{
+    // splitItem, groupType, groupListType and actType are omitted for this kernel.
+    int ret = aclnnIndexGroupMatmulGetWorkspaceSize(aclnnVariantPack.aclInTensorList.at(DIM0),
+        aclnnVariantPack.aclInTensorList.at(DIM1),
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM3) :
+            aclnnVariantPack.aclInTensorList.at(DIM2),
+        param_.hasBias ? aclnnVariantPack.aclInTensorList.at(4) : // 4 : index of input tensor
+            aclnnVariantPack.aclInTensorList.at(3), // 3 : index of input tensor
+        param_.hasBias ? aclnnVariantPack.aclInTensors.at(5)->tensor : // 5 : index of input tensor
+            aclnnVariantPack.aclInTensors.at(4)->tensor, // 4 : index of input tensor
+        aclnnVariantPack.aclOutTensorList.at(DIM0),
+        &this->aclnnOpCache_->workspaceSize,
+        &this->aclnnOpCache_->aclExecutor);
+    return ret;
+}
+
 int GroupedMatmulOperation::CreateW8A8Token(AclNNVariantPack &aclnnVariantPack)
-{
+{
     int ret = aclnnGroupedMatmulV4GetWorkspaceSize(aclnnVariantPack.aclInTensorList.at(DIM0),
         aclnnVariantPack.aclInTensorList.at(DIM1),
         param_.hasBias ? aclnnVariantPack.aclInTensorList.at(DIM2) : nullptr,
@@ -313,7 +339,13 @@ int GroupedMatmulOperation::SetAclNNWorkspaceExecutor()
     } else if (param_.quantType == GmmQuantType::W4A8_GROUP) {
         ret = CreateW4A8(aclnnVariantPack);
     } else {
-        ret = CreateW8A8Token(aclnnVariantPack);
+        atb::Tensor x = aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor;
+        atb::Tensor y = aclnnVariantPack.aclInTensors.at(DIM1)->atbTensor;
+        if (x.desc.shape.dims[0] > 8 && x.desc.shape.dims[0] < 2048 && y.desc.shape.dims[2] == 7168) {
+            ret = IndexGmmQuant(aclnnVariantPack);
+        } else {
+            ret = CreateW8A8Token(aclnnVariantPack);
+        }
     }

     ATB_SPEED_LOG_DEBUG(opName_ << " SetAclNNWorkspaceExecutor end, ret:" << ret
@@ -325,8 +357,20 @@ int GroupedMatmulOperation::SetAclNNWorkspaceExecutor()
 int GroupedMatmulOperation::ExecuteAclNNOp(uint8_t *workspace, aclrtStream &stream)
 {
     ATB_SPEED_LOG_DEBUG(opName_ << " aclnnGroupedMatmul start");
-    int ret = aclnnGroupedMatmulV4(
-        workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+
+    AclNNVariantPack &aclnnVariantPack = this->aclnnOpCache_->aclnnVariantPack;
+    atb::Tensor x = aclnnVariantPack.aclInTensors.at(DIM0)->atbTensor;
+    atb::Tensor y = aclnnVariantPack.aclInTensors.at(DIM1)->atbTensor;
+    int ret = 0;
+    // Must match the dispatch in SetAclNNWorkspaceExecutor: the cached executor was built for this kernel.
+    if (x.desc.shape.dims[0] > 8 && x.desc.shape.dims[0] < 2048 && y.desc.shape.dims[2] == 7168) {
+        ret = aclnnIndexGroupMatmul(
+            workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    } else {
+        ret = aclnnGroupedMatmulV4(
+            workspace, this->aclnnOpCache_->workspaceSize, this->aclnnOpCache_->aclExecutor, stream);
+    }
+
     ATB_SPEED_LOG_DEBUG(opName_ << " aclnnGroupedMatmul end, ret:" << ret);
     return ret;
 }
diff --git a/xllm/core/kernels/ascend/a2/atb_layers/operations/aclnn/ops/grouped_matmul_operation.h b/xllm/core/kernels/ascend/a2/atb_layers/operations/aclnn/ops/grouped_matmul_operation.h
index e752463f..804162ff 100755
--- a/xllm/core/kernels/ascend/a2/atb_layers/operations/aclnn/ops/grouped_matmul_operation.h
+++ b/xllm/core/kernels/ascend/a2/atb_layers/operations/aclnn/ops/grouped_matmul_operation.h
@@ -105,6 +105,7 @@ class GroupedMatmulOperation : public AclNNOperation {
     int CreateA16(AclNNVariantPack &aclnnVariantPack);
     int CreateW8A8Token(AclNNVariantPack &aclnnVariantPack);
     int CreateW4A8(AclNNVariantPack &aclnnVariantPack);
+    int IndexGmmQuant(AclNNVariantPack &aclnnVariantPack);

     std::vector<aclTensor *> yTensorVector;
     std::vector<std::vector<aclTensor *>> inputVectorOfTensor;
diff --git a/xllm/core/layers/ascend/deepseek_v2_decoder_layer.cpp b/xllm/core/layers/ascend/deepseek_v2_decoder_layer.cpp
index 8952329a..358888b8 100644
--- a/xllm/core/layers/ascend/deepseek_v2_decoder_layer.cpp
+++ b/xllm/core/layers/ascend/deepseek_v2_decoder_layer.cpp
@@ -1067,11 +1067,11 @@ void DeepseekV2DecoderImpl::merge_experts_weights() {
   torch::Tensor mlp_down_weight =
       merge_experts_weights(experts_weights_["down_proj.weight"],
-                            /*transpose=*/false);
-  // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] =
-  //     at_npu::native::npu_format_cast(mlp_down_weight, 29);
+                            /*transpose=*/true);
   at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] =
-      at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous();
+      at_npu::native::npu_format_cast(mlp_down_weight, 29); // 29: ACL_FORMAT_FRACTAL_NZ
+  // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] =
+  //     at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous();

   if (quantize_type_ == "w8a8_dynamic") {
     at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] =
         merge_experts_weights(experts_weights_["down_proj.weight_offset"]);
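
Note for reviewers: the same shape-based dispatch appears three times in this patch
(CreateAclNNInTensorVariantPack, SetAclNNWorkspaceExecutor, ExecuteAclNNOp). The sketch
below shows how the predicate and the new 16 x 32 NZ storage-shape padding could be kept
in one place. UseIndexGroupMatmul and GetNzStorageDims are hypothetical names, not
functions that exist in this repository; the thresholds are simply the ones hard-coded
in the patch (8 < num_tokens < 2048, last weight dim == 7168).

    #include <array>
    #include <cstdint>
    #include <iostream>

    // Hypothetical helper: true when the activation/weight shapes match the case this
    // patch routes to aclnnIndexGroupMatmul instead of aclnnGroupedMatmulV4.
    bool UseIndexGroupMatmul(int64_t numTokens, int64_t weightLastDim)
    {
        return numTokens > 8 && numTokens < 2048 && weightLastDim == 7168;
    }

    // Hypothetical helper mirroring GetWeightStorageShape after this patch:
    // (group, n, k) weights stored as NZ blocks, n padded to a multiple of 16 and
    // k padded to a multiple of 32 (the pre-patch code padded both to 16).
    std::array<int64_t, 5> GetNzStorageDims(int64_t group, int64_t n, int64_t k)
    {
        return {group, 1 + (n - 1) / 16, 1 + (k - 1) / 32, 16, 32};
    }

    int main()
    {
        std::cout << UseIndexGroupMatmul(1024, 7168) << "\n";    // 1: inside the (8, 2048) token range
        std::cout << UseIndexGroupMatmul(4096, 7168) << "\n";    // 0: falls back to aclnnGroupedMatmulV4
        for (int64_t d : GetNzStorageDims(8, 7168, 2048)) {      // e.g. 8 groups, n = 7168, k = 2048
            std::cout << d << " ";                               // prints: 8 448 64 16 32
        }
        std::cout << "\n";
        return 0;
    }

Factoring the predicate out would also keep SetAclNNWorkspaceExecutor and ExecuteAclNNOp
from drifting apart: the executor cached by the GetWorkspaceSize call must be replayed by
the matching kernel in the execute call, so the two branches have to stay in sync.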