Add support for ONNX op "com.microsoft.EmbedLayerNormalization" (#7837)
mateusztabaka committed Oct 7, 2021
1 parent f856ac0 commit 1269438
Showing 11 changed files with 1,187 additions and 4 deletions.
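For context: com.microsoft.EmbedLayerNormalization is an ONNX Runtime contrib op that fuses the input-embedding stage of BERT-style models. As implemented by the translator below, it gathers word embeddings by input_ids, adds position embeddings and (optionally) segment embeddings gathered by segment_ids, then layer-normalizes the sum with learned scale gamma and shift beta. A second output, mask_index, holds the per-batch sum of the attention mask (or zeros when no mask is given). Informally:

output[b, s, :] = LayerNorm(word_emb[input_ids[b, s]] + pos_emb[s] + seg_emb[segment_ids[b, s]]) * gamma + beta
mask_index[b]   = sum over s of mask[b, s]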
74 changes: 74 additions & 0 deletions ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp
@@ -0,0 +1,74 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "op/com.microsoft/embed_layer_normalization.hpp"

#include "default_opset.hpp"
#include "onnx_import/core/null_node.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector embed_layer_normalization(const Node& node) {
    auto nodes = node.get_ng_inputs();
    auto num_nodes = nodes.size();

    NGRAPH_CHECK(num_nodes >= 7 && num_nodes <= 8,
                 "EmbedLayerNormalization takes 7 or 8 inputs. Provided " + std::to_string(num_nodes));
    NGRAPH_CHECK(nodes[0].get_element_type() == element::i32, "input_ids must have int32 type");

    const auto& input_ids = nodes[0];
    const auto& segment_ids = nodes[1];
    const auto& word_embeddings = nodes[2];
    const auto& position_embeddings = nodes[3];
    const auto& segment_embeddings = nodes[4];
    const auto& gamma = nodes[5];
    const auto& beta = nodes[6];

    auto zero = default_opset::Constant::create(element::i32, Shape{1}, {0});
    std::shared_ptr<ngraph::Node> input = std::make_shared<default_opset::Gather>(word_embeddings, input_ids, zero, 0);
    input = std::make_shared<default_opset::Add>(input, position_embeddings);

    // add segment embeddings if available
    if (!ngraph::op::is_null(segment_ids)) {
        NGRAPH_CHECK(!ngraph::op::is_null(segment_embeddings),
                     "segment_ids provided, but segment_embedding input is missing");
        NGRAPH_CHECK(nodes[1].get_element_type() == element::i32, "segment_ids must have int32 type");
        auto gathered_segment_embeddings =
            std::make_shared<default_opset::Gather>(segment_embeddings, segment_ids, zero, 0);
        input = std::make_shared<default_opset::Add>(input, gathered_segment_embeddings);
    }

    float eps = node.get_attribute_value<float>("epsilon");
    // reduce over hidden_size
    // hidden_size dimension is 2 here, because the shape after Gather(word_embedding, input_ids)
    // is (batch_size, seq_len, hidden_size)
    int hidden_size_dim = 2;
    const auto reduction_axes = default_opset::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
    std::shared_ptr<ngraph::Node> result =
        std::make_shared<default_opset::MVN>(input, reduction_axes, true, eps, ngraph::op::MVNEpsMode::INSIDE_SQRT);

    // result = gamma * result + beta
    result = std::make_shared<default_opset::Multiply>(result, gamma);
    result = std::make_shared<default_opset::Add>(result, beta);

    // compute mask_index output
    std::shared_ptr<ngraph::Node> mask_index;
    if (num_nodes > 7 && !ngraph::op::is_null(nodes[7])) {
        NGRAPH_CHECK(nodes[7].get_element_type() == element::i32, "mask must have int32 type");
        auto axis = default_opset::Constant::create(element::i32, Shape{}, {1});
        mask_index = std::make_shared<default_opset::ReduceSum>(nodes[7], axis, false);
    } else {
        auto batch_size = std::make_shared<default_opset::Gather>(std::make_shared<default_opset::ShapeOf>(nodes[0]),
                                                                  zero,   // indices
                                                                  zero);  // axis
        mask_index = std::make_shared<default_opset::Broadcast>(zero, batch_size);
    }
    return {result, mask_index};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
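The subgraph built above is Gather, Add(s), MVN, Multiply, Add. To make the math concrete, here is a minimal single-position reference in plain C++; this is a hypothetical sketch, not part of the commit, and the function name embed_layer_norm_ref is made up. Note that MVNEpsMode::INSIDE_SQRT means epsilon is added to the variance before taking the square root.

// Hypothetical reference (not from the commit): normalize one
// (batch, seq) position of the summed embeddings over hidden_size.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> embed_layer_norm_ref(const std::vector<float>& hidden,  // summed embeddings, length = hidden_size
                                        const std::vector<float>& gamma,
                                        const std::vector<float>& beta,
                                        float eps) {
    const std::size_t h = hidden.size();
    float mean = 0.0f;
    for (float v : hidden)
        mean += v;
    mean /= h;
    float variance = 0.0f;
    for (float v : hidden)
        variance += (v - mean) * (v - mean);
    variance /= h;
    std::vector<float> out(h);
    for (std::size_t i = 0; i < h; ++i) {
        // INSIDE_SQRT: eps goes under the square root
        out[i] = (hidden[i] - mean) / std::sqrt(variance + eps) * gamma[i] + beta[i];
    }
    return out;
}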
17 changes: 17 additions & 0 deletions ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.hpp
@@ -0,0 +1,17 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "onnx_import/core/node.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector embed_layer_normalization(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
2 changes: 2 additions & 0 deletions ngraph/frontend/onnx/frontend/src/ops_bridge.cpp
@@ -30,6 +30,7 @@
#include "op/ceil.hpp"
#include "op/clip.hpp"
#include "op/com.microsoft/bias_gelu.hpp"
#include "op/com.microsoft/embed_layer_normalization.hpp"
#include "op/com.microsoft/skip_layer_normalization.hpp"
#include "op/compress.hpp"
#include "op/concat.hpp"
@@ -482,6 +483,7 @@ OperatorsBridge::OperatorsBridge() {
    REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "Swish", 1, swish);

    REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "BiasGelu", 1, bias_gelu);
    REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "EmbedLayerNormalization", 1, embed_layer_normalization);
    REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "SkipLayerNormalization", 1, skip_layer_normalization);
}
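With the operator registered for MICROSOFT_DOMAIN at version 1, importing a model that uses it needs no special handling. A minimal usage sketch, assuming the public onnx_import entry point (the file name is illustrative):

#include <fstream>
#include <memory>

#include "ngraph/function.hpp"
#include "onnx_import/onnx.hpp"

std::shared_ptr<ngraph::Function> load_bert_embeddings() {
    // import an ONNX model containing com.microsoft.EmbedLayerNormalization
    std::ifstream stream("model_with_embed_layer_normalization.onnx", std::ios::binary);
    return ngraph::onnx_import::import_onnx_model(stream);
}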

3 changes: 3 additions & 0 deletions ngraph/test/engines_util/ie_engines.cpp
@@ -337,6 +337,9 @@ testing::AssertionResult test::IE_Engine::compare_results_with_tolerance_as_fp(c
        comparison_result = test::compare_with_tolerance(test_results.first, test_results.second, tolerance);
        break;
    }
    case InferenceEngine::Precision::I32:
        comparison_result = compare_blobs<int32_t>(computed_output_blob, expected_output_blob, 0);
        break;
    default:
        comparison_result = testing::AssertionFailure() << "Unsupported data type encountered in "
                                                           "'compare_results_with_tolerance_as_fp' method";
3 changes: 3 additions & 0 deletions ngraph/test/engines_util/interpreter_engine.cpp
@@ -124,6 +124,9 @@ testing::AssertionResult test::INTERPRETER_Engine::compare_results_with_toleranc
    case element::Type_t::f32:
        comparison_result = compare_with_fp_tolerance(expected_result_constant, result_tensor, tolerance);
        break;
    case element::Type_t::i32:
        comparison_result = compare_values<int32_t>(expected_result_constant, result_tensor, 0);
        break;
    default:
        comparison_result = testing::AssertionFailure() << "Unsupported data type encountered in "
                                                           "'compare_results_with_tolerance_as_fp' method";
186 changes: 186 additions & 0 deletions (new ONNX test model for EmbedLayerNormalization, prototxt)
@@ -0,0 +1,186 @@
ir_version: 6
producer_name: "nGraph"
graph {
  node {
    input: "input_ids"
    input: "segment_ids"
    input: "word_embeddings"
    input: "position_embeddings"
    input: "segment_embeddings"
    input: "gamma"
    input: "beta"
    input: "mask"
    output: "output"
    output: "mask_index"
    name: "EmbedLayerNormalization_1"
    op_type: "EmbedLayerNormalization"
    attribute {
      name: "epsilon"
      f: 9.999999960041972e-13
      type: FLOAT
    }
    domain: "com.microsoft"
  }
  name: "graph"
  input {
    name: "input_ids"
    type {
      tensor_type {
        elem_type: 6
        shape {
          dim {
            dim_param: "batch_size"
          }
          dim {
            dim_param: "seq_len"
          }
        }
      }
    }
  }
  input {
    name: "segment_ids"
    type {
      tensor_type {
        elem_type: 6
        shape {
          dim {
            dim_param: "batch_size"
          }
          dim {
            dim_param: "seq_len"
          }
        }
      }
    }
  }
  input {
    name: "word_embeddings"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "word_embed_len"
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  input {
    name: "position_embeddings"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "pos_embed_len"
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  input {
    name: "segment_embeddings"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "segment_embed_len"
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  input {
    name: "gamma"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  input {
    name: "beta"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  input {
    name: "mask"
    type {
      tensor_type {
        elem_type: 6
        shape {
          dim {
            dim_param: "batch_size"
          }
          dim {
            dim_param: "seq_len"
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "batch_size"
          }
          dim {
            dim_param: "seq_len"
          }
          dim {
            dim_value: 5
          }
        }
      }
    }
  }
  output {
    name: "mask_index"
    type {
      tensor_type {
        elem_type: 6
        shape {
          dim {
            dim_param: "batch_size"
          }
        }
      }
    }
  }
}
opset_import {
  version: 11
}
opset_import {
  domain: "com.microsoft"
  version: 1
}
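In the model above, elem_type 6 is INT32 and elem_type 1 is FLOAT (ONNX TensorProto type codes); batch_size and seq_len are dynamic, hidden_size is fixed at 5, and the com.microsoft opset import matches the version-1 registration in ops_bridge.cpp. A tiny worked example of the mask_index output, which the translator computes as ReduceSum(mask, axis=1); the mask values here are illustrative:

#include <cstdio>

int main() {
    // attention mask for batch_size = 2, seq_len = 5
    const int mask[2][5] = {{1, 1, 1, 0, 0}, {1, 1, 0, 0, 0}};
    for (int b = 0; b < 2; ++b) {
        int sum = 0;
        for (int s = 0; s < 5; ++s)
            sum += mask[b][s];
        std::printf("mask_index[%d] = %d\n", b, sum);  // prints 3, then 2
    }
    return 0;
}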