Commit 0a3fd42: compile ok
gongweibao committed Aug 30, 2017
1 parent eae9acb · commit 0a3fd42
Showing 3 changed files with 27 additions and 29 deletions.
1 change: 1 addition & 0 deletions paddle/operators/CMakeLists.txt
@@ -57,6 +57,7 @@ op_library(add_op SRCS add_op.cc add_op.cu)
 op_library(mean_op SRCS mean_op.cc mean_op.cu)
 
 op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
+op_library(element_wise_mul_op SRCS element_wise_mul_op.cc element_wise_mul_op.cu DEPS math_function)
 op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
 
 op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
15 changes: 7 additions & 8 deletions paddle/operators/element_wise_mul_op.cc
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/sigmoid_op.h"
+#include "paddle/operators/element_wise_mul_op.h"
 
 namespace paddle {
 namespace operators {
@@ -32,7 +32,7 @@ class ElemWiseMulOp : public framework::OperatorWithKernel {
   }
 };
 
-class ElemWiseOpMaker : public framework::OpProtoAndCheckerMaker {
+class ElemWiseMulOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ElemWiseMulOpMaker(framework::OpProto *proto,
                      framework::OpAttrChecker *op_checker)
@@ -65,21 +65,20 @@ class ElemWiseMulOpGrad : public framework::OperatorWithKernel {
     auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
     PADDLE_ENFORCE(x_dims == out_dims, "Out@GRAD must equal to X dims");
-    PADDLE_ENFORCE(y_dim == out_dims, "Out@GRAD must equal to Y dims");
+    PADDLE_ENFORCE(y_dims == out_dims, "Out@GRAD must equal to Y dims");
 
     x_grad->Resize(x_dims);
     y_grad->Resize(y_dims);
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(sigmoid, ops::ElemWiseMulOp, ops::ElemWiseMulOpMaker, sigmoid_grad,
-            ops::ElemWiseMulOpGrad);
+REGISTER_OP(elemwisemul, ops::ElemWiseMulOp, ops::ElemWiseMulOpMaker,
+            elemwisemul_grad, ops::ElemWiseMulOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    sigmoid, ops::ElemWiseMulKernel<paddle::platform::CPUPlace, float>);
+    elemwisemul, ops::ElemWiseMulKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    sigmoid_grad,
+    elemwisemul_grad,
     ops::ElemWiseMulGradKernel<paddle::platform::CPUPlace, float>);
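
Note on the shape checks and registration above: for an elementwise product Out = X ∘ Y (∘ denoting the Hadamard, i.e. coefficient-wise, product), X, Y, Out, and Out@GRAD all share one shape, which is what the two PADDLE_ENFORCE comparisons against out_dims assert. The backward pass follows from the chain rule; a short derivation (the loss symbol L is ours, not from the commit):

\[
Out_i = X_i\,Y_i \;\Longrightarrow\;
\frac{\partial L}{\partial X_i} = \frac{\partial L}{\partial Out_i}\,Y_i,
\qquad
\frac{\partial L}{\partial Y_i} = \frac{\partial L}{\partial Out_i}\,X_i,
\]

which is exactly the pair dX_e = dOut_e * Y_e and dY_e = dOut_e * X_e computed in element_wise_mul_op.h below.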
40 changes: 19 additions & 21 deletions paddle/operators/element_wise_mul_op.h
@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include "paddle/operators/math/math_function.h"
+
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 
@@ -26,22 +28,20 @@ template <typename T, int MajorType = Eigen::RowMajor,
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
-class ElemWiseMulOPKernel : public framework::OpKernel {
+class ElemWiseMulKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* X = context.Input<Tensor>("X");
     auto* Y = context.Input<Tensor>("Y");
     auto* Z = context.Output<Tensor>("Out");
     Z->mutable_data<T>(context.GetPlace());
-    auto* device_context =
-        const_cast<platform::DeviceContext*>(context.device_context_);
 
-    auto X_e = EigenMatrix<T>::From(*X);
-    auto Y_e = EigenMatrix<T>::From(*Y);
-    auto Z_e = EigenMatrix<T>::From(*Z);
+    auto X_e = framework::EigenVector<T>::Flatten(*X);
+    auto Y_e = framework::EigenVector<T>::Flatten(*Y);
+    auto Z_e = framework::EigenVector<T>::Flatten(*Z);
 
-    // TODO: gpu?
-    Z_e.device(context.GetEigenDevice<place>()) = X_e.cwiseProduct(Y_e);
+    // TODO(gongweibao): gpu?
+    Z_e.device(context.GetEigenDevice<Place>()) = X_e * Y_e;
   }
 };
 
@@ -57,20 +57,18 @@ class ElemWiseMulGradKernel : public framework::OpKernel {
     auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
     dX->mutable_data<T>(ctx.GetPlace());
     dY->mutable_data<T>(ctx.GetPlace());
-    auto* device_context =
-        const_cast<platform::DeviceContext*>(ctx.device_context_);
 
-    auto X_e = EigenMatrix<T>::From(*X);
-    auto Y_e = EigenMatrix<T>::From(*Y);
-    auto dX_e = EigenMatrix<T>::From(*dX);
-    auto dY_e = EigenMatrix<T>::From(*dY);
-    auto dOut_e = EigenMatrix<T>::From(*dOut);
+    auto X_e = framework::EigenVector<T>::Flatten(*X);
+    auto Y_e = framework::EigenVector<T>::Flatten(*Y);
+    auto dX_e = framework::EigenVector<T>::Flatten(*dX);
+    auto dY_e = framework::EigenVector<T>::Flatten(*dY);
+    auto dOut_e = framework::EigenVector<T>::Flatten(*dOut);
 
-    // TODO: gpu?
-    dX.device(context.GetEigenDevice<place>()) = dOut_e.cwiseProduct(Y_e);
-    dY.device(context.GetEigenDevice<place>()) = dOut_e.cwiseProduct(X_e);
+    // TODO(gongweibao): gpu?
+    dX_e.device(ctx.GetEigenDevice<Place>()) = dOut_e * Y_e;
+    dY_e.device(ctx.GetEigenDevice<Place>()) = dOut_e * X_e;
   }
 };
 
 }  // namespace operators
-}
+}  // namespace paddle
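
To see the new Eigen pattern in isolation, here is a minimal standalone sketch. It assumes only stock Eigen (the unsupported Tensor module); the sizes and values are illustrative, not from the commit. It mirrors how the kernels flatten tensors to 1-D expressions and evaluate the coefficient-wise product through .device():

#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  // Stand-ins for the flattened X_e, Y_e, Z_e expressions in ElemWiseMulKernel.
  Eigen::Tensor<float, 1> x(4), y(4), z(4);
  x.setValues({1.f, 2.f, 3.f, 4.f});
  y.setValues({10.f, 20.f, 30.f, 40.f});

  // DefaultDevice plays the role of context.GetEigenDevice<Place>(); assigning
  // through .device() evaluates the elementwise product x * y into z.
  Eigen::DefaultDevice cpu;
  z.device(cpu) = x * y;

  for (int i = 0; i < 4; ++i) std::cout << z(i) << ' ';  // prints: 10 40 90 160
  std::cout << '\n';
  return 0;
}

The same expression form dispatches to the GPU when the device is an Eigen GpuDevice, which is presumably what the TODO(gongweibao): gpu? markers are tracking.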
