diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h
deleted file mode 100644
index 277e29c4d59ce..0000000000000
--- a/paddle/fluid/operators/fused/attn_gemm.h
+++ /dev/null
@@ -1,295 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
-#include "paddle/fluid/platform/float16.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
-#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
-#include "paddle/phi/kernels/funcs/broadcast_function.h"
-#include "paddle/phi/kernels/funcs/elementwise_functor.h"
-#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
-#include "paddle/phi/kernels/primitive/kernel_primitives.h"
-
-namespace paddle {
-namespace operators {
-
-// support gemm-nt and gemm-nn, which is used in fused_attention_op.
-template <typename T>
-class AttnMatMul {
- public:
-  // (m, n, k) = bsz_seq, output_size, input_size
-  AttnMatMul(const phi::GPUContext& dev_ctx,
-             bool transA,
-             bool transB,
-             int bsz_seq,
-             int output_size,
-             int input_size,
-             bool compute_bias)
-      : dev_ctx_(dev_ctx),
-        transA_(transA),
-        transB_(transB),
-        bsz_seq_(bsz_seq),
-        output_size_(output_size),
-        input_size_(input_size),
-        compute_bias_(compute_bias) {}
-
-  void ComputeForward(const phi::DenseTensor* weight,
-                      const phi::DenseTensor* input,
-                      const phi::DenseTensor* bias,
-                      phi::DenseTensor* output,
-                      phi::DenseTensor* bias_out,
-                      bool fused = false) {
-    VLOG(6) << "input.shape={" << input->dims() << "}, weight.shape={"
-            << weight->dims() << "}, output.shape={" << output->dims()
-            << "}, batch_size=" << bsz_seq_ << ", output_size=" << output_size_
-            << ", input_size=" << input_size_ << ", transA=" << transA_
-            << ", transB=" << transB_ << ", compute_bias=" << compute_bias_
-            << ", fused=" << fused;
-
-#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
-    if (compute_bias_ && fused) {
-      PADDLE_ENFORCE_EQ(
-          !output || output == bias_out,
-          true,
-          phi::errors::InvalidArgument(
-              "The output (= input * weight) is expected to be nullptr or the "
-              "same as bias_out when fused is true."));
-
-      phi::funcs::LinearWithCublasLt<T>::Run(
-          dev_ctx_,
-          input,                                      // x
-          weight,                                     // y
-          bias_out,                                   // out
-          static_cast<const void*>(bias->data<T>()),  // bias
-          nullptr,
-          bsz_seq_,      // M
-          output_size_,  // N
-          input_size_,   // K
-          transA_,
-          transB_,
-          phi::funcs::MatmulFusedType::kMatmulBias);
-
-      return;
-    }
-#endif
-
-    // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
-    // here: (transa, transb): nt, input * weight.
-    CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans;
-    CBLAS_TRANSPOSE transB = transB_ ? CblasTrans : CblasNoTrans;
-    T alpha = static_cast<T>(1.0);
-    T beta = static_cast<T>(0.0);
-
-    // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
-    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
-    blas.GEMM(transA,
-              transB,
-              bsz_seq_,
-              output_size_,
-              input_size_,
-              alpha,
-              input->data<T>(),
-              weight->data<T>(),
-              beta,
-              output->data<T>());
-    if (compute_bias_) {
-      // bias_out = output + bias
-      std::vector<const phi::DenseTensor*> ins = {output, bias};
-      std::vector<phi::DenseTensor*> outs = {bias_out};
-      phi::funcs::BroadcastKernel<T>(
-          dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
-    }
-  }
-
-  void ComputeBackward(const phi::DenseTensor* input,
-                       const phi::DenseTensor* weight,
-                       const phi::DenseTensor* d_output,
-                       phi::DenseTensor* d_input,
-                       phi::DenseTensor* d_weight,
-                       phi::DenseTensor* d_bias,
-                       bool use_addto = false,
-                       bool fused = false) {
-#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
-    if (compute_bias_ && fused) {
-      phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx_,
-                                                      d_output,
-                                                      input,
-                                                      weight,
-                                                      nullptr,
-                                                      bsz_seq_,      // M
-                                                      output_size_,  // N
-                                                      input_size_,   // K
-                                                      transA_,
-                                                      transB_,
-                                                      "none",
-                                                      d_input,
-                                                      d_weight,
-                                                      d_bias,
-                                                      use_addto);
-      return;
-    }
-#endif
-
-    T alpha = static_cast<T>(1.0);
-    T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0);
-    T beta_dB = static_cast<T>(0.0);
-
-    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
-    if (!transA_) {
-      // forward: gemm-nt
-      if (transB_) {
-        // backward: gemm-tn, dB = (dC)^T * A
-        if (d_weight) {
-          int dB_m = output_size_;
-          int dB_n = input_size_;
-          int dB_k = bsz_seq_;
-
-          T* dB_output_ptr = d_weight->data<T>();
-          blas.GEMM(CblasTrans,
-                    CblasNoTrans,
-                    dB_m,
-                    dB_n,
-                    dB_k,
-                    alpha,
-                    d_output->data<T>(),
-                    input->data<T>(),
-                    beta_dB,
-                    dB_output_ptr);
-        }
-
-        // backward: gemm-nn, dA = dC * B
-        if (d_input) {
-          int dA_m = bsz_seq_;
-          int dA_n = input_size_;
-          int dA_k = output_size_;
-
-          T* dA_output_ptr = d_input->data<T>();
-          blas.GEMM(CblasNoTrans,
-                    CblasNoTrans,
-                    dA_m,
-                    dA_n,
-                    dA_k,
-                    alpha,
-                    d_output->data<T>(),
-                    weight->data<T>(),
-                    beta_dA,
-                    dA_output_ptr);
-        }
-      } else {  // fw: gemm-nn
-        // backward: gemm-tn, dB = A^T * dC
-        if (d_weight) {
-          int dB_m = input_size_;
-          int dB_n = output_size_;
-          int dB_k = bsz_seq_;
-
-          T* dB_output_ptr = d_weight->data<T>();
-          blas.GEMM(CblasTrans,
-                    CblasNoTrans,
-                    dB_m,
-                    dB_n,
-                    dB_k,
-                    alpha,
-                    input->data<T>(),
-                    d_output->data<T>(),
-                    beta_dB,
-                    dB_output_ptr);
-        }
-
-        // backward: gemm-nt, dA = dC * B^T
-        if (d_input) {
-          int dA_m = bsz_seq_;
-          int dA_n = input_size_;
-          int dA_k = output_size_;
-
-          T* dA_output_ptr = d_input->data<T>();
-          blas.GEMM(CblasNoTrans,
-                    CblasTrans,
-                    dA_m,
-                    dA_n,
-                    dA_k,
-                    alpha,
-                    d_output->data<T>(),
-                    weight->data<T>(),
-                    beta_dA,
-                    dA_output_ptr);
-        }
-      }
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "AttnMatMul wrapper do not support (transA=T, transB=T/N)"
-          "parameters."));
-    }
-    if (compute_bias_ && d_bias) {
-      // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2} or {0,1,2,3}
-      // -> {3} or {0,1,2,3,4} -> {3,4}
-      const auto input_dims = d_output->dims();
-      const auto output_dims = d_bias->dims();
-      bool support_case_1 =
-          (input_dims.size() == 5 && output_dims.size() == 3 &&
-           (input_dims[2] == output_dims[0]) &&
-           (input_dims[3] == output_dims[1]) &&
-           (input_dims[4] == output_dims[2]));
-      bool support_case_2 =
-          (input_dims.size() == 3 && output_dims.size() == 1 &&
-           (input_dims[2] == output_dims[0]));
-      bool support_case_3 =
-          (input_dims.size() == 4 && output_dims.size() == 1 &&
-           input_dims[3] == output_dims[0]);
-      bool support_case_4 =
-          (input_dims.size() == 5 && output_dims.size() == 2 &&
-           input_dims[3] == output_dims[0] && input_dims[4] == output_dims[1]);
-
-      gpuStream_t stream = dev_ctx_.stream();
-      if (support_case_1 || support_case_2) {
-        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-            dev_ctx_,
-            *d_output,
-            d_bias,
-            kps::IdentityFunctor<T>(),
-            {0, 1},
-            stream);
-      } else if (support_case_3 || support_case_4) {
-        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-            dev_ctx_,
-            *d_output,
-            d_bias,
-            kps::IdentityFunctor<T>(),
-            {0, 1, 2},
-            stream);
-      } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "Only support reduce when the input dims are [0,1,2,3,4] and "
-            "output is [2,3,4]"
-            "or input is [0,1,2] and output is [2]."));
-      }
-    }
-  }
-
- private:
-  const phi::GPUContext& dev_ctx_;
-
-  bool transA_;
-  bool transB_;
-
-  int bsz_seq_;
-  int output_size_;
-  int input_size_;
-
-  int compute_bias_;
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h
index 5cbc4788a0c68..89f17f24b74a1 100644
--- a/paddle/fluid/operators/fused/fused_gate_attention.h
+++ b/paddle/fluid/operators/fused/fused_gate_attention.h
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/funcs/transpose_function.cu.h"
+#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/fused/fused_gate_attention_op.cu b/paddle/fluid/operators/fused/fused_gate_attention_op.cu
index e2cdb513feada..9caca507c08bb 100644
--- a/paddle/fluid/operators/fused/fused_gate_attention_op.cu
+++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cu
@@ -14,10 +14,10 @@ limitations under the License. */
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/fused/attn_gemm.h" #include "paddle/fluid/operators/fused/fused_gate_attention.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h" namespace paddle { namespace operators { @@ -73,8 +73,8 @@ void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx, int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = 3 * config.num_heads * config.head_dim; int k = config.q_dim; - auto qkv_compute = - AttnMatMul(ctx.cuda_device_context(), false, true, m, n, k, false); + auto qkv_compute = phi::fusion::AttnMatMul( + ctx.cuda_device_context(), false, true, m, n, k, false); qkv_compute.ComputeForward(qkv_weight, query, nullptr, qkv_out, nullptr); } @@ -95,8 +95,8 @@ void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx, int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = 3 * config.num_heads * config.head_dim; int k = config.q_dim; - auto qkv_compute = - AttnMatMul(ctx.cuda_device_context(), false, true, m, n, k, false); + auto qkv_compute = phi::fusion::AttnMatMul( + ctx.cuda_device_context(), false, true, m, n, k, false); qkv_compute.ComputeBackward(query, qkv_weight, qkv_out_grad, @@ -125,7 +125,7 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx, int q_m = config.batch_size * config.seq_len_m * config.seq_len_r; int q_n = config.num_heads * config.head_dim; int q_k = config.q_dim; - auto q_compute = AttnMatMul( + auto q_compute = phi::fusion::AttnMatMul( ctx.cuda_device_context(), false, false, q_m, q_n, q_k, false); q_compute.ComputeForward(query_weight, query, nullptr, query_out, nullptr); @@ -136,7 +136,7 @@ void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx, int kv_m = config.batch_size * config.seq_len_m * config.m_size; int kv_n = config.num_heads * config.head_dim; int kv_k = config.kv_dim; - auto kv_compute = AttnMatMul( + auto kv_compute = phi::fusion::AttnMatMul( ctx.cuda_device_context(), false, false, kv_m, kv_n, kv_k, false); kv_compute.ComputeForward(key_weight, key, nullptr, key_out, nullptr); @@ -165,7 +165,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx, int kv_m = config.batch_size * config.seq_len_m * config.m_size; int kv_n = config.num_heads * config.head_dim; int kv_k = config.kv_dim; - auto kv_compute = AttnMatMul( + auto kv_compute = phi::fusion::AttnMatMul( ctx.cuda_device_context(), false, false, kv_m, kv_n, kv_k, false); kv_compute.ComputeBackward( key, key_weight, key_out_grad, key_grad, key_weight_grad, nullptr, false); @@ -193,7 +193,7 @@ void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx, int q_m = config.batch_size * config.seq_len_m * config.seq_len_r; int q_n = config.num_heads * config.head_dim; int q_k = config.q_dim; - auto q_compute = AttnMatMul( + auto q_compute = phi::fusion::AttnMatMul( ctx.cuda_device_context(), false, false, q_m, q_n, q_k, false); q_compute.ComputeBackward(query, query_weight, @@ -221,8 +221,8 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx, int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.num_heads * config.head_dim; int k = config.q_dim; - auto gate_linear = - AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); + auto gate_linear = 
+  auto gate_linear = phi::fusion::AttnMatMul<T>(
+      ctx.cuda_device_context(), false, false, m, n, k, true);
   gate_linear.ComputeForward(gate_weight,
                              query,
                              gate_bias,
@@ -258,8 +258,8 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
   int n = config.num_heads * config.head_dim;
   int k = config.q_dim;
-  auto gate_linear =
-      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
+  auto gate_linear = phi::fusion::AttnMatMul<T>(
+      ctx.cuda_device_context(), false, false, m, n, k, true);
   gate_linear.ComputeForward(gate_weight,
                              query,
                              gate_bias,
@@ -307,8 +307,8 @@ void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
   int n = config.q_dim;
   int k = config.num_heads * config.head_dim;
-  auto out_linear =
-      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
+  auto out_linear = phi::fusion::AttnMatMul<T>(
+      ctx.cuda_device_context(), false, false, m, n, k, true);
   out_linear.ComputeForward(out_linear_weight,
                             fmha_or_gate_out,
                             out_linear_bias,
@@ -342,8 +342,8 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
   int m = config.batch_size * config.seq_len_m * config.seq_len_r;
   int n = config.q_dim;
   int k = config.num_heads * config.head_dim;
-  auto out_linear =
-      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
+  auto out_linear = phi::fusion::AttnMatMul<T>(
+      ctx.cuda_device_context(), false, false, m, n, k, true);
   out_linear.ComputeBackward(input,
                              out_linear_weight,
                              out_grad,
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
index aa7c4cb4fd9f8..e3158d74df629 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -109,13 +109,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     // (transA, transB, compute_bias) = (false, trans_qkvw, false)
     // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set
     // compute_bias as false.
-    auto qkv_compute = AttnMatMul<T>(dev_ctx,
-                                     false,
-                                     trans_qkvw,
-                                     token_num,
-                                     output_size,
-                                     input_size,
-                                     /*compute_bias=*/false);
+    auto qkv_compute = phi::fusion::AttnMatMul<T>(dev_ctx,
+                                                  false,
+                                                  trans_qkvw,
+                                                  token_num,
+                                                  output_size,
+                                                  input_size,
+                                                  /*compute_bias=*/false);
 
     phi::DenseTensor qkv_out;
     qkv_out.Resize({{token_num, 3, num_head, dim_head}});
@@ -219,7 +219,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     auto out_linear_biases = ctx.MultiInput<phi::DenseTensor>("OutLinearBias");
     int ring_id = ctx.Attr<int>("ring_id");
     // (transA, transB, compute_bias) = (false, false, false)
-    auto out_linear_compute = AttnMatMul<T>(
+    auto out_linear_compute = phi::fusion::AttnMatMul<T>(
         dev_ctx, false, false, token_num, dim_embed, hidden_size, false);
 
     // 5. ln(residual + bias)
@@ -260,7 +260,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
     auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
-    auto ffn2_linear_compute = AttnMatMul<T>(
+    auto ffn2_linear_compute = phi::fusion::AttnMatMul<T>(
         dev_ctx, false, false, token_num, dim_embed, dim_ffn, false);
 
     // 8. ffn2 Layernorm residual bias
@@ -775,13 +775,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
       // (transA, transB, compute_bias) = (false, trans_qkvw, false)
       // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we
       // set compute_bias as false.
-      auto qkv_compute = AttnMatMul<T>(dev_ctx,
-                                       false,
-                                       trans_qkvw,
-                                       token_num,
-                                       output_size,
-                                       input_size,
-                                       /*compute_bias=*/false);
+      auto qkv_compute = phi::fusion::AttnMatMul<T>(dev_ctx,
+                                                    false,
+                                                    trans_qkvw,
+                                                    token_num,
+                                                    output_size,
+                                                    input_size,
+                                                    /*compute_bias=*/false);
 
       phi::DenseTensor qkv_out;
       qkv_out.Resize({{token_num, 3, num_head, dim_head}});
@@ -885,7 +885,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
       auto out_linear_biases = ctx.MultiInput<phi::DenseTensor>("OutLinearBias");
      int ring_id = ctx.Attr<int>("ring_id");
       // (transA, transB, compute_bias) = (false, false, false)
-      auto out_linear_compute = AttnMatMul<T>(
+      auto out_linear_compute = phi::fusion::AttnMatMul<T>(
          dev_ctx, false, false, token_num, dim_embed, hidden_size, false);
 
       // 5. ln(residual + bias)
@@ -912,7 +912,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
       auto ffn1_weight_dim = ffn1_weights[0]->dims();
 
       int dim_ffn = ffn1_weight_dim[1];
-      auto ffn1_linear_compute = AttnMatMul<T>(
+      auto ffn1_linear_compute = phi::fusion::AttnMatMul<T>(
          dev_ctx, false, false, token_num, dim_ffn, dim_embed, false);
       phi::DenseTensor ffn1_out;
       ffn1_out.Resize({{token_num, dim_ffn}});
@@ -934,7 +934,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
       // 8. ffn2 matmul
       auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
       auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
-      auto ffn2_linear_compute = AttnMatMul<T>(
+      auto ffn2_linear_compute = phi::fusion::AttnMatMul<T>(
          dev_ctx, false, false, token_num, dim_embed, dim_ffn, false);
 
       // 9. ffn2 residual bias
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
index a81a38ed3877f..ba12bdc8b9d7f 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -28,7 +28,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/fused/attention_layer_norm.h"
-#include "paddle/fluid/operators/fused/attn_gemm.h"
 #include "paddle/fluid/operators/fused/fmha_ref.h"
 #include "paddle/fluid/operators/fused/fused_dropout_helper.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
@@ -39,6 +38,7 @@ limitations under the License. */
*/ #include "paddle/phi/core/flags.h" #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/process_group.h" diff --git a/paddle/phi/kernels/fusion/gpu/attn_gemm.h b/paddle/phi/kernels/fusion/gpu/attn_gemm.h index a96601dddacac..8b83ddab93b9b 100644 --- a/paddle/phi/kernels/fusion/gpu/attn_gemm.h +++ b/paddle/phi/kernels/fusion/gpu/attn_gemm.h @@ -73,18 +73,20 @@ class AttnMatMul { phi::errors::InvalidArgument( "The output (= input * weight) is expected to be nullptr or the " "same as bias_out when fused is true.")); - phi::funcs::ComputeFusedGemmEpilogueForward(dev_ctx_, - input, - weight, - bias, - bsz_seq_, // M - output_size_, // N - input_size_, // K - transA_, - transB_, - "none", - bias_out, - nullptr); + + phi::funcs::LinearWithCublasLt::Run( + dev_ctx_, + input, // x + weight, // y + bias_out, // out + static_cast(bias->data()), // bias + nullptr, + bsz_seq_, // M + output_size_, // N + input_size_, // K + transA_, + transB_, + phi::funcs::MatmulFusedType::kMatmulBias); return; } #endif