From ce95ff5aceba18f1b85953b83036dc9a52864b2c Mon Sep 17 00:00:00 2001
From: tianhaodongbd <137985359+tianhaodongbd@users.noreply.github.com>
Date: Thu, 28 Sep 2023 20:13:21 +0800
Subject: [PATCH] compilation optimization for matmul_grad_kernel (#57823)

---
 paddle/phi/kernels/impl/matmul_grad_kernel_impl.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
index 899ee5f3a497b..4125e49db6eef 100644
--- a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/reduce_functor.h"
 #include "paddle/phi/kernels/impl/dot_grad_kernel_impl.h"
 #include "paddle/phi/kernels/impl/matmul_kernel_impl.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
 
 #if defined(__NVCC__) || defined(__HIPCC__)
 #include "paddle/phi/kernels/gpu/reduce.h"
@@ -60,8 +61,8 @@ struct ReduceSumForMatmulGrad {
                   const DenseTensor& input,
                   DenseTensor* output,
                   const std::vector<int>& reduce_dims) {
-    funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-        dev_ctx, input, output, kps::IdentityFunctor<T>(), reduce_dims);
+    phi::SumKernel<T, phi::GPUContext>(
+        dev_ctx, input, reduce_dims, input.dtype(), false, output);
   }
 };
 #endif