LayerNorm acceleration on GPU (apache#14935)
fix lint

fix lint

fix bug

further accelerate

fix

fix bug

fix bug
sxjscience authored and haohuw committed Jun 23, 2019
1 parent c8ace2e commit 4749fed
Showing 3 changed files with 686 additions and 6 deletions.
23 changes: 17 additions & 6 deletions src/operator/nn/layer_norm-inl.h
@@ -63,12 +63,17 @@ struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
}
};


template<typename xpu>
void LayerNormCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const std::vector<TBlob>& outputs);

template<typename xpu>
void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
@@ -146,6 +151,12 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {outputs[0]});
}

template<typename xpu>
void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);

/*
Calculate the gradient of layer normalization.
We have the following gradient for gamma, beta and x:
@@ -157,10 +168,10 @@ grad_beta = sum(og, exclude_axis)
grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis)
*/
template<typename xpu>
void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 5U);
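For context on the refactor above: LayerNormCompute and LayerNormGradCompute are now only declared in this header, while the former backend-agnostic bodies live on as LayerNormComputeGeneral and LayerNormGradComputeGeneral; each backend then supplies its own specialization (the CPU ones appear in layer_norm.cc below). As a reference for what the general forward path computes, here is a minimal standalone C++ sketch of layer normalization over the last axis, an illustration only with hypothetical names, not the mshadow-based code in this diff:

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative sketch only (hypothetical names): layer normalization of a
// row-major [nbatch, nchannel] matrix over the channel axis, also producing
// the per-row mean and std, as the LayerNorm operator does.
void layer_norm_forward(const std::vector<float>& x,
                        const std::vector<float>& gamma,
                        const std::vector<float>& beta,
                        std::size_t nbatch, std::size_t nchannel, float eps,
                        std::vector<float>* out,
                        std::vector<float>* mean,
                        std::vector<float>* std_dev) {
  out->resize(nbatch * nchannel);
  mean->resize(nbatch);
  std_dev->resize(nbatch);
  for (std::size_t i = 0; i < nbatch; ++i) {
    float sum = 0.f, sum_sq = 0.f;
    for (std::size_t j = 0; j < nchannel; ++j) {
      const float v = x[i * nchannel + j];
      sum += v;
      sum_sq += v * v;
    }
    const float mu = sum / nchannel;
    const float sigma = std::sqrt(sum_sq / nchannel - mu * mu + eps);
    (*mean)[i] = mu;
    (*std_dev)[i] = sigma;
    for (std::size_t j = 0; j < nchannel; ++j) {
      (*out)[i * nchannel + j] =
          gamma[j] * (x[i * nchannel + j] - mu) / sigma + beta[j];
    }
  }
}

The per-row mean and std are emitted as extra outputs because the backward pass reuses them, which is why LayerNormGradComputeGeneral checks for five inputs (output gradient, data, gamma, mean, std) in the hunk above.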
16 changes: 16 additions & 0 deletions src/operator/nn/layer_norm.cc
@@ -65,6 +65,22 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
}


template<>
void LayerNormCompute<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
return LayerNormComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
}

template<>
void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
return LayerNormGradComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
}

NNVM_REGISTER_OP(LayerNorm)
.describe(R"code(Layer normalization.
The diff for the third changed file was not rendered.
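That unrendered file presumably carries the GPU side of this dispatch: specializations of LayerNormCompute<gpu> and LayerNormGradCompute<gpu> backed by fused kernels, which is the acceleration the commit title refers to. Purely as an illustration of the general approach, not the kernels from this commit, and with every name and launch parameter below being an assumption, a per-row forward kernel in CUDA could look like this:

#include <cuda_runtime.h>
#include <math.h>

// Illustrative sketch only: one thread block per row of a [nbatch, nchannel]
// input, shared-memory tree reduction for mean/variance, then normalize,
// scale by gamma, and shift by beta. Assumes blockDim.x is a power of two.
__global__ void layer_norm_fwd_kernel(const float* __restrict__ x,
                                      const float* __restrict__ gamma,
                                      const float* __restrict__ beta,
                                      float* __restrict__ out,
                                      int nchannel, float eps) {
  extern __shared__ float buf[];   // 2 * blockDim.x floats
  float* s_sum = buf;
  float* s_sqr = buf + blockDim.x;
  const float* row = x + blockIdx.x * nchannel;
  float* row_out = out + blockIdx.x * nchannel;

  float sum = 0.f, sqr = 0.f;
  for (int j = threadIdx.x; j < nchannel; j += blockDim.x) {
    const float v = row[j];
    sum += v;
    sqr += v * v;
  }
  s_sum[threadIdx.x] = sum;
  s_sqr[threadIdx.x] = sqr;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
      s_sqr[threadIdx.x] += s_sqr[threadIdx.x + stride];
    }
    __syncthreads();
  }
  const float mu = s_sum[0] / nchannel;
  const float sigma = sqrtf(s_sqr[0] / nchannel - mu * mu + eps);
  for (int j = threadIdx.x; j < nchannel; j += blockDim.x) {
    row_out[j] = gamma[j] * (row[j] - mu) / sigma + beta[j];
  }
}

// Hypothetical launch: one block per row, 256 threads, 2 * 256 floats of
// shared memory (sizes are assumptions, not tuned values from this PR):
//   layer_norm_fwd_kernel<<<nbatch, 256, 2 * 256 * sizeof(float)>>>(
//       d_x, d_gamma, d_beta, d_out, nchannel, 1e-5f);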
