Polish names of variants.
limin2021 committed Sep 27, 2021
1 parent 4dd4260 commit 2e3f4f2
Showing 5 changed files with 121 additions and 120 deletions.
68 changes: 34 additions & 34 deletions paddle/fluid/operators/fused/fused_attention_op.cc
@@ -29,23 +29,23 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
- OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
+ OP_INOUT_CHECK(ctx->HasInput("LinearW"), "Input", "LinearW",
"FusedAttentionOp");
- OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
+ OP_INOUT_CHECK(ctx->HasInput("LinearBias"), "Input", "LinearBias",
"FusedAttentionOp");

- OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
+ OP_INOUT_CHECK(ctx->HasOutput("PreLnMean"), "Output", "PreLnMean",
"FusedAttentionOp");
- OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
+ OP_INOUT_CHECK(ctx->HasOutput("PreLnVariance"), "Output", "PreLnVariance",
"FusedAttentionOp");
- OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
+ OP_INOUT_CHECK(ctx->HasOutput("PreLnOut"), "Output", "PreLnOut",
"FusedAttentionOp");
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut",
"FusedAttentionOp");
- OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2",
+ OP_INOUT_CHECK(ctx->HasOutput("TransposeOut"), "Output", "TransposeOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut",
"FusedAttentionOp");
@@ -61,11 +61,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut",
"FusedAttentionOp");
- OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut",
+ OP_INOUT_CHECK(ctx->HasOutput("LinearOut"), "Output", "LinearOut",
"FusedAttentionOp");
- OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
+ OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
"FusedAttentionOp");
- OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
+ OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
"BiasDropoutResidualOut", "FusedAttentionOp");
@@ -98,16 +98,16 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]",
x_dim, y_dim));

- ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
- ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
- ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
+ ctx->SetOutputDim("PreLnMean", {x_dim[0] * x_dim[1]});
+ ctx->SetOutputDim("PreLnVariance", {x_dim[0] * x_dim[1]});
+ ctx->SetOutputDim("PreLnOut", ctx->GetInputDim("X"));
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
// [3, batch_size, num_head, seq_len, head_size]
- ctx->SetOutputDim("TransposeOut2",
+ ctx->SetOutputDim("TransposeOut",
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch, num_head, seq_len, seq_len]
ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
@@ -124,10 +124,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch_size, seq_len, number of heads*head size]
ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
- ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));
+ ctx->SetOutputDim("LinearOut", ctx->GetInputDim("X"));

- ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
- ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
+ ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
+ ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
}
@@ -148,47 +148,47 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
- AddInput("LnScale",
+ AddInput("PreLnScale",
"(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
- AddInput("LnBias",
+ AddInput("PreLnBias",
"(optional) Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddInput("QKVW", "The qkv weight tensor.");
AddInput("QKVBias", "The qkv bias tensor.");
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
- AddInput("OutLinearW", "The out_linear weight tensor.");
- AddInput("OutLinearBias", "The out_linear bias tensor.");
- AddInput("Ln2Scale",
+ AddInput("LinearW", "The linear weight tensor.");
+ AddInput("LinearBias", "The linear bias tensor.");
+ AddInput("LnScale",
"(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
- AddInput("Ln2Bias",
+ AddInput("LnBias",
"(optional) Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
- AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate();
- AddOutput("LnVariance", "Variance of the current mini batch.")
+ AddOutput("PreLnMean", "Mean of the current mini batch.").AsIntermediate();
+ AddOutput("PreLnVariance", "Variance of the current mini batch.")
.AsIntermediate();
- AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate();
+ AddOutput("PreLnOut", "The output of pre layer_norm.").AsIntermediate();
AddOutput("QKVOut", "Result after qkv.").AsIntermediate();
AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate();
- AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate();
+ AddOutput("TransposeOut", "Result in fmha.").AsIntermediate();
AddOutput("QKOut", "Result in fmha.").AsIntermediate();
AddOutput("QKTVOut", "Result in fmha.").AsIntermediate();
AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate();
AddOutput("AttnDropoutMaskOut", "Result in fmha.").AsIntermediate();
AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate();
AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate();
AddOutput("FMHAOut", "Result after fmha.").AsIntermediate();
- AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate();
+ AddOutput("LinearOut", "Result after linear.").AsIntermediate();
AddOutput("DropoutMaskOut", "The random sampled dropout mask.")
.AsIntermediate();
- AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate();
- AddOutput("Ln2Variance", "Variance of the current mini batch.")
+ AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate();
+ AddOutput("LnVariance", "Variance of the current mini batch.")
.AsIntermediate();
AddOutput("BiasDropoutResidualOut",
"Result of residual + dropout(src + bias).")
@@ -289,16 +289,16 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
- AddAttr<float>("ln2epsilon",
+ AddAttr<float>("ln_epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
- .AddCustomChecker([](const float &ln2epsilon) {
- PADDLE_ENFORCE_EQ(ln2epsilon >= 0.0f && ln2epsilon <= 0.001f, true,
+ .AddCustomChecker([](const float &ln_epsilon) {
+ PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' of the second LayerNorm in Fused "
"attention op should be between"
"0.0 and 0.001, But received [%s].",
- ln2epsilon));
+ ln_epsilon));
});

AddComment(R"DOC(
@@ -319,7 +319,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
out = transpose(out, perm=[0, 2, 1, 3]);
}
- out = out_linear(out);
+ out = linear(out);
final_out = layer_norm(residual + dropout(bias + out));
)DOC");
}
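For orientation, the computation sketched in the DOC comment above can be written as the following unfused NumPy reference that uses the polished names; the commented intermediates indicate which op output each line corresponds to. This is a minimal illustrative sketch, not the fused kernel or its exact numerics: the helper names (layer_norm, softmax, fused_attention_ref) are invented for this note, attention and output dropout are omitted, and the 1/sqrt(head_dim) scaling is assumed since the truncated DOC block above does not show it. Shapes follow the comments in the InferShape code earlier in this file.

import numpy as np

def layer_norm(x, scale, bias, eps=1e-5):
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale + bias

def softmax(x):
    x = x - x.max(-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(-1, keepdims=True)

def fused_attention_ref(x, qkv_w, qkv_b, linear_w, linear_b,
                        pre_ln_scale, pre_ln_bias, ln_scale, ln_bias,
                        src_mask=None, pre_layer_norm=True):
    # x: [batch, seq_len, embed_dim]
    # qkv_w: [3, num_heads, head_dim, embed_dim], qkv_b: [3, num_heads, head_dim]
    residual = x
    if pre_layer_norm:
        x = layer_norm(x, pre_ln_scale, pre_ln_bias)          # PreLnOut
    b, s, d = x.shape
    _, n, h, _ = qkv_w.shape
    qkv = np.einsum('bse,tnhe->bstnh', x, qkv_w) + qkv_b      # QKVOut / QKVBiasOut
    q, k, v = np.transpose(qkv, (2, 0, 3, 1, 4))              # TransposeOut
    qk = np.einsum('bnsh,bnth->bnst', q, k) / np.sqrt(h)      # QKOut (scaling assumed)
    if src_mask is not None:
        qk = qk + src_mask                                    # SrcMaskOut
    attn = softmax(qk)                                        # SoftmaxOut (attn dropout omitted)
    out = np.einsum('bnst,bnth->bnsh', attn, v)               # QKTVOut
    out = out.transpose(0, 2, 1, 3).reshape(b, s, n * h)      # FMHAOut
    out = out @ linear_w                                      # LinearOut
    # final_out = layer_norm(residual + dropout(bias + out)); dropout omitted here
    return layer_norm(residual + out + linear_b, ln_scale, ln_bias)
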
94 changes: 47 additions & 47 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -38,11 +38,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
- auto *ln_scale = ctx.Input<Tensor>("LnScale");
- auto *ln_bias = ctx.Input<Tensor>("LnBias");
- auto *ln_mean = ctx.Output<Tensor>("LnMean");
- auto *ln_var = ctx.Output<Tensor>("LnVariance");
- auto *ln_out = ctx.Output<Tensor>("LnOut");
+ auto *pre_ln_scale = ctx.Input<Tensor>("PreLnScale");
+ auto *pre_ln_bias = ctx.Input<Tensor>("PreLnBias");
+ auto *pre_ln_mean = ctx.Output<Tensor>("PreLnMean");
+ auto *pre_ln_var = ctx.Output<Tensor>("PreLnVariance");
+ auto *pre_ln_out = ctx.Output<Tensor>("PreLnOut");

// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
@@ -52,7 +52,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *qkv_bias_out = ctx.Output<Tensor>("QKVBiasOut");

auto *src_mask = ctx.Input<Tensor>("SrcMask");
- auto *transpose_out_2 = ctx.Output<Tensor>("TransposeOut2");
+ auto *transpose_out = ctx.Output<Tensor>("TransposeOut");
auto *qk_out = ctx.Output<Tensor>("QKOut");
auto *qktv_out = ctx.Output<Tensor>("QKTVOut");
auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
@@ -61,18 +61,18 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *src_mask_out = ctx.Output<Tensor>("SrcMaskOut");
auto *fmha_out = ctx.Output<Tensor>("FMHAOut");

- auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
- auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
- auto *out_linear_out = ctx.Output<Tensor>("OutLinearOut");
+ auto *linear_weight = ctx.Input<Tensor>("LinearW");
+ auto *linear_bias = ctx.Input<Tensor>("LinearBias");
+ auto *linear_out = ctx.Output<Tensor>("LinearOut");

- auto *ln_scale_2 = ctx.Input<Tensor>("Ln2Scale");
- auto *ln_bias_2 = ctx.Input<Tensor>("Ln2Bias");
+ auto *ln_scale = ctx.Input<Tensor>("LnScale");
+ auto *ln_bias = ctx.Input<Tensor>("LnBias");
auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Output<Tensor>("BiasDropoutResidualOut");
- auto *ln_mean_2 = ctx.Output<Tensor>("Ln2Mean");
- auto *ln_var_2 = ctx.Output<Tensor>("Ln2Variance");
- const float ln2epsilon = ctx.Attr<float>("ln2epsilon");
+ auto *ln_mean = ctx.Output<Tensor>("LnMean");
+ auto *ln_var = ctx.Output<Tensor>("LnVariance");
+ const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

float attn_dropout_prob = ctx.Attr<float>("attn_dropout_prob");
bool attn_dropout_is_test = ctx.Attr<bool>("attn_dropout_is_test");
@@ -94,20 +94,21 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
const auto qkv_w_dims = qkv_weight->dims();

auto *x_data = input_x->data<T>();
- auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
- auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
- auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
- auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
- auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());
+ auto *pre_ln_scale_data =
+ (pre_ln_scale == nullptr ? nullptr : pre_ln_scale->data<U>());
+ auto *pre_ln_bias_data =
+ (pre_ln_bias == nullptr ? nullptr : pre_ln_bias->data<U>());
+ auto *pre_ln_mean_data = pre_ln_mean->mutable_data<U>(ctx.GetPlace());
+ auto *pre_ln_var_data = pre_ln_var->mutable_data<U>(ctx.GetPlace());
+ auto *pre_ln_out_data = pre_ln_out->mutable_data<T>(ctx.GetPlace());

auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for FMHA.
- auto *transpose_out_2_data =
- transpose_out_2->mutable_data<T>(ctx.GetPlace());
+ auto *transpose_out_data = transpose_out->mutable_data<T>(ctx.GetPlace());
auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
@@ -118,22 +119,20 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
attn_dropout_out->mutable_data<T>(ctx.GetPlace());
auto *fmha_out_data = fmha_out->mutable_data<T>(ctx.GetPlace());

- // get data ptr for out_linear.
- auto *out_linear_weight_data = out_linear_weight->data<T>();
- auto *out_linear_bias_data = out_linear_bias->data<T>();
- auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
+ // get data ptr for linear.
+ auto *linear_weight_data = linear_weight->data<T>();
+ auto *linear_bias_data = linear_bias->data<T>();
+ auto *linear_out_data = linear_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for bias+dropout+residual+layernorm
- auto *ln_scale_2_data =
- (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
- auto *ln_bias_2_data =
- (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
+ auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
+ auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *dropout_mask_out_data =
dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
- auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
- auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
+ auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
+ auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());

int batch_size = input_x_dims[0];
@@ -164,38 +163,39 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

output_size = hidden_size;
// (transA, transB, compute_bias) = (false, false, false)
- auto out_linear_compute =
+ auto linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, bsz_seq,
output_size, input_size, false);
- DropoutParam dropout_param2(ctx, 0);
+ DropoutParam dropout_param(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
- ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
- ln2epsilon);
+ ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param,
+ ln_epsilon);

if (pre_layer_norm) {
- layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
- ln_out_data, ln_mean_data, ln_var_data);
- qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
- qkv_out_data, qkv_bias_out_data);
+ layer_norm_compute.ComputeForward(x_data, pre_ln_scale_data,
+ pre_ln_bias_data, pre_ln_out_data,
+ pre_ln_mean_data, pre_ln_var_data);
+ qkv_compute.ComputeForward(qkv_weight_data, pre_ln_out_data,
+ qkv_bias_data, qkv_out_data,
+ qkv_bias_out_data);
} else {
qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
}
- fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
+ fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
- // out_linear_out: [batch_size, seq_len, embed_dim]
- out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
- nullptr, out_linear_out_data, nullptr);
+ // linear_out: [batch_size, seq_len, embed_dim]
+ linear_compute.ComputeForward(linear_weight_data, fmha_out_data, nullptr,
+ linear_out_data, nullptr);
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
- ctx.cuda_device_context(), out_linear_out_data, x_data,
- out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
- bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
- ln_mean_2_data, ln_var_2_data);
+ ctx.cuda_device_context(), linear_out_data, x_data, linear_bias_data,
+ ln_scale_data, ln_bias_data, bias_dropout_residual_out_data,
+ dropout_mask_out_data, final_out_data, ln_mean_data, ln_var_data);
}
};

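The comment above the LayernormResidualDropoutBias call summarizes the fused epilogue as output = layernorm(residual + dropout(input + bias)). A minimal NumPy sketch of that epilogue follows, assuming upscale_in_train-style dropout (one of the two dropout_implementation modes listed in the op maker); the function name and seed argument are illustrative only, and the commented names show which op output each intermediate corresponds to.

import numpy as np

def bias_dropout_residual_layer_norm(linear_out, residual, linear_bias,
                                     ln_scale, ln_bias, dropout_prob=0.0,
                                     is_test=False, ln_epsilon=1e-5, seed=None):
    # dropout(linear_out + bias); upscale_in_train scaling assumed
    y = linear_out + linear_bias
    if not is_test and dropout_prob > 0.0:
        rng = np.random.default_rng(seed)
        keep = (rng.random(y.shape) >= dropout_prob).astype(y.dtype)  # DropoutMaskOut
        y = y * keep / (1.0 - dropout_prob)
    y = residual + y                                                  # BiasDropoutResidualOut
    mean = y.mean(-1, keepdims=True)                                  # LnMean
    var = y.var(-1, keepdims=True)                                    # LnVariance
    return (y - mean) / np.sqrt(var + ln_epsilon) * ln_scale + ln_bias  # Y
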
12 changes: 6 additions & 6 deletions paddle/fluid/pybind/op_function_generator.cc
@@ -41,8 +41,8 @@
std::map<std::string, std::set<std::string>> op_ins_map = {
{"layer_norm", {"X", "Scale", "Bias"}},
{"fused_attention",
- {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW",
- "OutLinearBias", "Ln2Scale", "Ln2Bias"}},
+ {"X", "PreLnScale", "PreLnBias", "QKVW", "QKVBias", "SrcMask", "LinearW",
+ "LinearBias", "LnScale", "LnBias"}},
{"instance_norm", {"X", "Scale", "Bias"}},
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}},
@@ -91,10 +91,10 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
{"fused_attention",
- {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2",
- "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut",
- "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean",
- "Ln2Variance", "BiasDropoutResidualOut", "Y"}},
+ {"PreLnMean", "PreLnVariance", "PreLnOut", "QKVOut", "QKVBiasOut",
+ "TransposeOut", "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut",
+ "AttnDropoutOut", "SrcMaskOut", "FMHAOut", "LinearOut", "DropoutMaskOut",
+ "LnMean", "LnVariance", "BiasDropoutResidualOut", "Y"}},
{"sync_batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
