Skip to content

Commit

Permalink
Add model compression to FP16 weights
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed Oct 7, 2021
1 parent 1269438 commit 6f254d2
Show file tree
Hide file tree
Showing 21 changed files with 502 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer
C.ApplyLowLatencyTransformation(network.impl, use_const_initializer)


# Python-visible entry point for model compression: forwards the wrapped
# network implementation (network.impl) to the C++ binding
# C.CompressModelTransformation and returns nothing.
def CompressModelTransformation(IENetwork network):
C.CompressModelTransformation(network.impl)


# Python-visible entry point for the pruning transformation: forwards the
# wrapped network implementation (network.impl) to the C++ binding
# C.ApplyPruningTransformation and returns nothing.
def ApplyPruningTransformation(IENetwork network):
C.ApplyPruningTransformation(network.impl)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <ngraph/pass/manager.hpp>
#include <pot_transformations.hpp>
#include <pruning.hpp>
#include <transformations/common_optimizations/compress_constants.hpp>
#include <transformations/common_optimizations/mark_precision_sensitive_subgraphs.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp>

void InferenceEnginePython::ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf) {
Expand Down Expand Up @@ -47,6 +49,13 @@ void InferenceEnginePython::GenerateMappingFile(InferenceEnginePython::IENetwork
manager.run_passes(network.actual->getFunction());
}

/// Runs the model-compression pipeline on the nGraph function of @p network.
///
/// Two passes are registered, in this order, and then executed:
///   1. MarkPrecisionSensitiveSubgraphs
///   2. CompressConstants
///
/// @param network Python-side network wrapper; the passes are applied to the
///                function returned by network.actual->getFunction().
void InferenceEnginePython::CompressModelTransformation(InferenceEnginePython::IENetwork network) {
    ngraph::pass::Manager compression_passes;

    // Order matters: marking has to be registered before the constants are compressed.
    compression_passes.register_pass<ngraph::pass::MarkPrecisionSensitiveSubgraphs>();
    compression_passes.register_pass<ngraph::pass::CompressConstants>();

    const auto function = network.actual->getFunction();
    compression_passes.run_passes(function);
}

void InferenceEnginePython::CheckAPI() {
std::shared_ptr<ngraph::Function> f;
{
Expand All @@ -63,4 +72,4 @@ void InferenceEnginePython::CheckAPI() {
auto reshape = f->get_result()->input_value(0).get_node_shared_ptr();
assert(std::dynamic_pointer_cast<ngraph::opset6::Parameter>(reshape->input_value(0).get_node_shared_ptr()));
assert(std::dynamic_pointer_cast<ngraph::opset6::Constant>(reshape->input_value(1).get_node_shared_ptr()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);

void GenerateMappingFile(InferenceEnginePython::IENetwork network, std::string path, bool extract_names);

void CompressModelTransformation(InferenceEnginePython::IENetwork network);

void CheckAPI();

}; // namespace InferenceEnginePython
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEngi
cdef void ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer)

cdef void ApplyPruningTransformation(IENetwork network)

cdef void CompressModelTransformation(IENetwork network)

cdef void GenerateMappingFile(IENetwork network, string path, bool extract_names)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API CompressConstants;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief CompressConstants transformation replaces FP32 and FP64 Constants with FP16 ones
* followed by a Convert back to the original precision.
*/
class ngraph::pass::CompressConstants : public ngraph::pass::MatcherPass {
public:
// RTTI declaration for the pass; the matching NGRAPH_RTTI_DEFINITION lives in compress_constants.cpp.
NGRAPH_RTTI_DECLARATION;
// Registers the matcher (on opset8::Constant) together with the callback that performs the replacement.
CompressConstants();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API DisableDecompressionConvertConstantFolding;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief Disables ConstantFolding for Convert operations that carry the Decompression
* attribute, i.e. the decompression Converts of a compressed function.
*/
class ngraph::pass::DisableDecompressionConvertConstantFolding : public ngraph::pass::MatcherPass {
public:
// RTTI declaration for the pass; the matching NGRAPH_RTTI_DEFINITION lives in the corresponding .cpp file.
NGRAPH_RTTI_DECLARATION;
// Registers the matcher (on opset8::Convert) together with the callback that disables constant folding.
DisableDecompressionConvertConstantFolding();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/pass.hpp>

#include <transformations_visibility.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API MarkPrecisionSensitiveSubgraphs;

} // namespace pass
} // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief Function-level pass that is run before CompressConstants in the model
 * compression pipeline.
 *
 * NOTE(review): the implementation is not part of this header; presumably it tags
 * nodes that must keep their original precision so that CompressConstants skips
 * them - confirm against the .cpp file.
 */
class ngraph::pass::MarkPrecisionSensitiveSubgraphs : public ngraph::pass::FunctionPass {
public:
    NGRAPH_RTTI_DECLARATION;

    /// Entry point invoked by the pass manager for the function @p f.
    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <transformations/rt_info/old_api_map_attribute.hpp>
#include <transformations/rt_info/primitives_priority_attribute.hpp>
#include <transformations/rt_info/strides_property.hpp>
#include <transformations/rt_info/decompression.hpp>

namespace ov {
namespace pass {
Expand All @@ -34,6 +35,7 @@ class TRANSFORMATIONS_API Attributes {
register_factory<NmsSelectedIndices>();
register_factory<StridesPropagation>();
register_factory<OldApiMap>();
register_factory<Decompression>();
}

Variant * create_by_type_info(const ov::DiscreteTypeInfo & type_info) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <assert.h>
#include <functional>
#include <memory>
#include <string>
#include <set>

#include <ngraph/node.hpp>
#include <ngraph/variant.hpp>
#include <transformations_visibility.hpp>


namespace ov {

/**
 * @brief Payload-free run-time attribute (a VariantImpl<void>) registered under
 * the RTTI name "decompression", version "0".
 *
 * Its mere presence in a node's rt_info is what the transformations test for;
 * it stores no data of its own.
 */
class TRANSFORMATIONS_API Decompression : public VariantImpl<void> {
public:
    OPENVINO_RTTI("decompression", "0");

    Decompression() = default;

    /// The attribute has no fields to visit, so the visit trivially succeeds.
    bool visit_attributes(AttributeVisitor& visitor) override {
        return true;
    }

    /// Reports this attribute as non-copyable.
    bool is_copyable() const override {
        return false;
    }
};

} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
#include <ngraph/op/constant.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset8.hpp>

#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations/rt_info/attributes.hpp>

namespace ngraph {
namespace op {
Expand Down Expand Up @@ -50,6 +52,17 @@ bool has_op_with_type(const std::shared_ptr<const ngraph::Function> &function) {
return false;
}

/// Tells whether @p function contains at least one opset8::Convert node whose
/// rt_info holds the ov::Decompression attribute.
///
/// @param function Function whose operations (as returned by get_ops()) are scanned.
/// @return true as soon as such a Convert is found, false if there is none.
inline bool has_decompression_converts(const std::shared_ptr<const ngraph::Function>& function) {
    for (const auto& node : function->get_ops()) {
        // Only Convert operations can be decompression Converts.
        if (!std::dynamic_pointer_cast<ngraph::opset8::Convert>(node))
            continue;

        if (node->get_rt_info().count(ov::Decompression::get_type_info_static()) != 0)
            return true;
    }
    return false;
}

inline std::string create_ie_output_name(const ngraph::Output<ngraph::Node>& output) {
const auto& prev_layer = output.get_node_shared_ptr();
std::string out_name = prev_layer->get_friendly_name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/relu_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/disable_random_uniform_constant_folding.hpp"
#include "transformations/common_optimizations/disable_decomression_convert_constant_folding.hpp"
#include "transformations/common_optimizations/random_uniform_fusion.hpp"
#include "transformations/common_optimizations/add_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/mul_fake_quantize_fusion.hpp"
Expand Down Expand Up @@ -77,6 +78,8 @@
#include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"
#include "transformations/op_conversions/gather_normalize_negative_indices.hpp"
#include "transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp"
#include "transformations/convert_precision.hpp"
#include "transformations/utils/utils.hpp"

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
Expand All @@ -94,8 +97,17 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
// This pass must be called first in pipeline
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::DisableRandomUniformConstantFolding>();
manager.register_pass<ngraph::pass::DisableDecompressionConvertConstantFolding>();
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
manager.register_pass<ngraph::pass::ConstantFolding>();

if (ngraph::op::util::has_decompression_converts(f)) {
const precisions_array convert_precision_list{
{ngraph::element::f32, ngraph::element::f16}
};
manager.register_pass<ngraph::pass::ConvertPrecision>(convert_precision_list);
}

manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
manager.register_pass<ngraph::pass::ConvertNmsGatherPathToUnsigned>(); // workaround until dynamism in NMS is not supported

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/compress_constants.hpp"

#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/rt_info/decompression.hpp>
#include "itt.hpp"

NGRAPH_RTTI_DEFINITION(ngraph::pass::CompressConstants, "CompressConstants", 0);

/**
 * Registers a matcher on every opset8::Constant. For a matched constant of type
 * f32 or f64 that is not tagged with "DISABLE_FP16_COMPRESSION" in its rt_info,
 * the callback:
 *   - creates an f16 Constant of the same shape holding the values cast to float16;
 *   - appends a Convert back to the original element type, which takes over the
 *     friendly name and runtime info of the original constant;
 *   - marks that Convert with the ov::Decompression attribute;
 *   - replaces the original constant with the Convert.
 *
 * The callback returns true only when the graph was changed.
 */
ngraph::pass::CompressConstants::CompressConstants() {
    MATCHER_SCOPE(CompressConstants);
    auto const_pattern = ngraph::pattern::wrap_type<opset8::Constant>();

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();

        // A distinct name is used for the matched node so that it does not shadow
        // the captured pattern node, which is still needed as the map key.
        const auto original_const = std::dynamic_pointer_cast<ngraph::opset8::Constant>(
            pattern_map.at(const_pattern).get_node_shared_ptr());
        if (!original_const)
            return false;

        // Only floating point constants wider than f16 are compressed.
        const auto& original_type = original_const->get_element_type();
        if (original_type != ov::element::f32 && original_type != ov::element::f64)
            return false;

        // Constants explicitly excluded from compression are left untouched.
        const auto& rt_info = original_const->get_rt_info();
        if (rt_info.count("DISABLE_FP16_COMPRESSION"))
            return false;

        // The temporary vector returned by cast_vector lives until the end of this
        // full-expression; the Constant constructor is expected to copy the buffer.
        // NOTE(review): no range check is done here - values outside the f16 range
        // are presumably turned into infinities by the cast; confirm this is acceptable.
        auto new_const = std::make_shared<ngraph::opset8::Constant>(ov::element::f16,
                                                                    original_const->get_shape(),
                                                                    original_const->cast_vector<float16>().data());
        auto convert = std::make_shared<ngraph::opset8::Convert>(new_const, original_type);

        convert->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info(original_const, convert);

        // Tag the Convert so that later passes can recognise it as a decompression Convert.
        auto& convert_rt_info = convert->get_rt_info();
        convert_rt_info[ov::Decompression::get_type_info_static()] = std::make_shared<ov::Decompression>();

        ngraph::replace_node(m.get_match_root(), convert);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(const_pattern, matcher_name);
    this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/disable_decomression_convert_constant_folding.hpp"

#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/rt_info/disable_constant_folding.hpp>
#include <transformations/rt_info/decompression.hpp>
#include <itt.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableDecompressionConvertConstantFolding, "DisableDecompressionConvertConstantFolding", 0);

/**
 * Registers a matcher on every opset8::Convert. When the matched Convert carries
 * the ov::Decompression attribute in its rt_info, constant folding is disabled
 * for it and the callback reports a change; any other Convert is ignored.
 */
ngraph::pass::DisableDecompressionConvertConstantFolding::DisableDecompressionConvertConstantFolding() {
    MATCHER_SCOPE(DisableDecompressionConvertConstantFolding);
    auto convert = pattern::wrap_type<opset8::Convert>();

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto& matched_convert = m.get_match_root();

        // Plain Converts (without the decompression mark) stay foldable.
        const bool is_decompression =
            matched_convert->get_rt_info().count(ov::Decompression::get_type_info_static()) != 0;
        if (is_decompression) {
            disable_constant_folding(matched_convert);
            return true;
        }
        return false;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(convert, matcher_name);
    this->register_matcher(m, callback);
}
Loading

0 comments on commit 6f254d2

Please sign in to comment.