Skip to content

Commit

Permalink
Add cross batch wasq (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
ForFishes committed Oct 27, 2020
1 parent fd717a3 commit ab2db95
Show file tree
Hide file tree
Showing 10 changed files with 1,099 additions and 171 deletions.
48 changes: 14 additions & 34 deletions paddle/fluid/operators/batch_fc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,27 @@ class BatchFCOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Input"), true,
platform::errors::InvalidArgument(
"X(Input) of Batch Fully Connected should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Out(Output) of Batch Fully Connected should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("W"), true,
platform::errors::InvalidArgument(
"W(Input) of Batch Fully Connected should not be null."));
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "BatchFCOp");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "BatchFCOp");
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "BatchFCOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BatchFCOp");

auto input_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("W");
auto batchcount = ctx->Attrs().Get<int64_t>("batchcount");

PADDLE_ENFORCE_EQ(input_dims.size(), 3,
int feature_dim = input_dims[1] / batchcount;
PADDLE_ENFORCE_EQ(feature_dim, w_dims[0],
platform::errors::InvalidArgument(
"Input of BatchFCOp should have 3D."));
PADDLE_ENFORCE_EQ(w_dims.size(), 3, platform::errors::InvalidArgument(
"W of BatchFCOp should have 3D."));
PADDLE_ENFORCE_EQ(
input_dims[0], w_dims[0],
platform::errors::InvalidArgument(
"Input.dim[0] and W.dim[0] of BatchFCOp should be same."));
PADDLE_ENFORCE_EQ(
input_dims[2], w_dims[1],
platform::errors::InvalidArgument(
"Input.dim[2] and W.dim[1] of BatchFCOp should be same."));
"Input.dim[1]/batchcount and W.dim[0] of BatchFCOp "
"should be same."));

auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(bias_dims[0], input_dims[0],
platform::errors::InvalidArgument(
"Bias.dim[0] should be same as input.dim[0]."));
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[2],
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
platform::errors::InvalidArgument(
"Bias.dim[1] should be same as input.dim[2]."));
"Bias.dim[1] should be same as W.dim[1]."));

ctx->SetOutputDim("Out", {input_dims[0], input_dims[1], w_dims[2]});
ctx->SetOutputDim("Out", {input_dims[0], w_dims[1]});
ctx->ShareLoD("Input", /*->*/ "Out");
}

Expand Down Expand Up @@ -107,6 +89,7 @@ class BatchFCOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Input", "(Tensor) Input tensor of batch_fc_op operator.");
AddInput("W", "(Tensor) Input tensor of batch_fc_op operator.");
AddInput("Bias", "(Tensor) Input tensor of batch_fc_op operator.");
AddAttr<int64_t>("batchcount", "(int64_t) the batchcount");
AddOutput("Out", "Output tensor of batch_fc_op operator.");
AddComment(R"DOC(
BatchFC Operator.
Expand Down Expand Up @@ -136,8 +119,6 @@ class BatchFCGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(BatchFCGradOpNoNeedBufferVarsInferer,
"Bias");

} // namespace operators
} // namespace paddle
Expand All @@ -147,8 +128,7 @@ REGISTER_OPERATOR(batch_fc, ops::BatchFCOp, ops::BatchFCOpMaker,
ops::BatchFCGradOpMaker<paddle::framework::OpDesc>,
ops::BatchFCGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp,
ops::BatchFCGradOpNoNeedBufferVarsInferer);
REGISTER_OPERATOR(batch_fc_grad, ops::BatchFCGradOp);

REGISTER_OP_CPU_KERNEL(
batch_fc, ops::BatchFCKernel<paddle::platform::CPUDeviceContext, float>,
Expand Down

0 comments on commit ab2db95

Please sign in to comment.