-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
754925b
commit 0ee15e5
Showing
3 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
31 changes: 31 additions & 0 deletions
31
...ormations/include/transformations/common_optimizations/division_to_zero_fp16_resolver.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <utility> | ||
#include <memory> | ||
|
||
#include <transformations_visibility.hpp> | ||
#include <ngraph/pass/graph_rewrite.hpp> | ||
#include "ngraph/pattern/matcher.hpp" | ||
|
||
namespace ngraph { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API DivisionToZeroFP16Resolver; | ||
|
||
} // namespace pass | ||
} // namespace ngraph | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief : | ||
*/ | ||
class ngraph::pass::DivisionToZeroFP16Resolver: public ngraph::pass::MatcherPass { | ||
public: | ||
NGRAPH_RTTI_DECLARATION; | ||
DivisionToZeroFP16Resolver(); | ||
}; |
57 changes: 57 additions & 0 deletions
57
...ansformations/src/transformations/common_optimizations/division_to_zero_fp16_resolver.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "itt.hpp" | ||
#include "transformations/common_optimizations/division_to_zero_fp16_resolver.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include <ngraph/opsets/opset8.hpp> | ||
#include <ngraph/rt_info.hpp> | ||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
#include <ngraph/pattern/op/or.hpp> | ||
|
||
NGRAPH_RTTI_DEFINITION(ngraph::pass::DivisionToZeroFP16Resolver, "DivisionToZeroFP16Resolver", 0); | ||
|
||
constexpr float normalized_fp16_min = 6.103515625e-05f; // normalized minimum of fp16 | ||
|
||
ngraph::pass::DivisionToZeroFP16Resolver::DivisionToZeroFP16Resolver() { | ||
MATCHER_SCOPE(DivisionToZeroFP16Resolver); | ||
auto input_1 = ngraph::pattern::any_input(); | ||
auto input_2 = ngraph::pattern::any_input(); | ||
|
||
|
||
auto eps_const_pattern = ngraph::pattern::wrap_type<ngraph::opset8::Constant>(); | ||
auto max = std::make_shared<ngraph::opset8::Maximum>(input_2, eps_const_pattern); | ||
auto add = std::make_shared<ngraph::opset8::Add>(input_2, eps_const_pattern); | ||
auto max_or_add = std::make_shared<pattern::op::Or>(OutputVector{max, add}); | ||
auto divide = std::make_shared<ngraph::opset8::Divide>(input_1, max_or_add); | ||
|
||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { | ||
const auto& pattern_to_output = m.get_pattern_value_map(); | ||
|
||
const auto eps_const = std::dynamic_pointer_cast<ngraph::opset8::Constant>(pattern_to_output.at(eps_const_pattern).get_node_shared_ptr()); | ||
|
||
if (!eps_const) { | ||
return false; | ||
} | ||
for (float val : eps_const->get_vector<float>()) { | ||
if (val >= normalized_fp16_min) { | ||
return false; | ||
} | ||
} | ||
|
||
auto new_constant = std::make_shared<opset8::Constant>(eps_const->get_element_type(), | ||
eps_const->get_shape(), | ||
normalized_fp16_min); | ||
ngraph::copy_runtime_info(eps_const, new_constant); | ||
ngraph::replace_node(eps_const, new_constant); | ||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, matcher_name); | ||
register_matcher(m, callback); | ||
} |
89 changes: 89 additions & 0 deletions
89
...tests/functional/inference_engine/transformations/division_to_zero_fp16_resolver_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <string> | ||
#include <memory> | ||
|
||
#include <ngraph/function.hpp> | ||
#include <ngraph/opsets/opset4.hpp> | ||
#include <ngraph/pass/manager.hpp> | ||
#include <transformations/common_optimizations/division_to_zero_fp16_resolver.hpp> | ||
#include <transformations/init_node_info.hpp> | ||
#include <transformations/utils/utils.hpp> | ||
|
||
#include "common_test_utils/ngraph_test_utils.hpp" | ||
|
||
using namespace testing; | ||
constexpr float normalized_fp16_min = 6.103515625e-05f; // normalized minimum of fp16 | ||
|
||
TEST_F(TransformationTestsF, DivisionToZeroWithMax) { | ||
const float eps_value = 1.e-12; | ||
{ | ||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(3)); | ||
auto exp = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {2.f}); | ||
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp); | ||
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1}); | ||
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const); | ||
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {eps_value}); | ||
auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const); | ||
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max); | ||
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt); | ||
|
||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input}); | ||
|
||
manager.register_pass<ngraph::pass::DivisionToZeroFP16Resolver>(); | ||
} | ||
|
||
{ | ||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(3)); | ||
auto exp = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {2.f}); | ||
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp); | ||
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1}); | ||
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const); | ||
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {normalized_fp16_min}); | ||
auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const); | ||
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max); | ||
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt); | ||
|
||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input}); | ||
} | ||
} | ||
|
||
|
||
TEST_F(TransformationTestsF, DivisionToZeroWithAdd) { | ||
const float eps_value = 0.000099f; | ||
{ | ||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(3)); | ||
auto exp = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {2.f}); | ||
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp); | ||
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1}); | ||
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const); | ||
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {eps_value}); | ||
auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const); | ||
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add); | ||
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt); | ||
|
||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input}); | ||
|
||
manager.register_pass<ngraph::pass::DivisionToZeroFP16Resolver>(); | ||
} | ||
|
||
{ | ||
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(3)); | ||
auto exp = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {2.f}); | ||
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp); | ||
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1}); | ||
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const); | ||
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {normalized_fp16_min}); | ||
auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const); | ||
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add); | ||
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt); | ||
|
||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input}); | ||
|
||
manager.register_pass<ngraph::pass::DivisionToZeroFP16Resolver>(); | ||
} | ||
} |