WIP: [CPU][ARM] Weights compression f32->f16 is moved to CPU Plug-in side #21080

Closed
@@ -14,6 +14,7 @@ namespace pass {
class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
class TRANSFORMATIONS_API DisableDecompressionConvertConstantFolding;
class TRANSFORMATIONS_API KeepConstAndDecompression;
class TRANSFORMATIONS_API KeepConstFP32Unfolded;
class TRANSFORMATIONS_API KeepConstantsPrecisionAndAddConverts;

} // namespace pass
@@ -49,6 +50,12 @@ class ov::pass::KeepConstAndDecompression : public MatcherPass {
KeepConstAndDecompression();
};

class ov::pass::KeepConstFP32Unfolded : public MatcherPass {
public:
OPENVINO_RTTI("KeepConstFP32Unfolded", "0");
KeepConstFP32Unfolded();
};

/**
* @ingroup ie_transformation_common_api
* @brief Prevents Const precision conversion and adds a Convert with ConstantFolding disabled
@@ -23,6 +23,12 @@ TRANSFORMATIONS_API void unmark_as_decompression(const std::shared_ptr<Node>& node);

TRANSFORMATIONS_API bool is_decompression(const std::shared_ptr<Node>& node);

TRANSFORMATIONS_API void mark_as_compression(const std::shared_ptr<Node>& node);

TRANSFORMATIONS_API void unmark_as_compression(const std::shared_ptr<Node>& node);

TRANSFORMATIONS_API bool is_compression(const std::shared_ptr<Node>& node);

/**
* @ingroup ie_runtime_attr_api
* @brief Decompression class represents runtime info attribute that marks operation
@@ -43,4 +49,19 @@ class TRANSFORMATIONS_API Decompression : public RuntimeAttribute {
}
};

class TRANSFORMATIONS_API Compression : public RuntimeAttribute {
Contributor comment: Could you please provide a comment explaining why we need this rt_info? From that it should be clear why the existing attributes cannot be reused and a new rt_info is needed.

public:
OPENVINO_RTTI("Compression", "0");

Compression() = default;

bool visit_attributes(AttributeVisitor& visitor) override {
return true;
Contributor comment: Is it really necessary to store this rt_info in the IR?

}

bool is_copyable() const override {
return false;
}
};

} // namespace ov
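
For clarity, here is a minimal usage sketch of the new marking API (a hypothetical standalone snippet, not part of this diff; the helper name is illustrative). It mirrors the existing Decompression helpers: the attribute is simply stored in the node's rt_info map.

#include <memory>
#include "openvino/op/convert.hpp"
#include "transformations/rt_info/decompression.hpp"

// Hypothetical helper: tag a weights-compressing Convert so that plugin-side
// matchers (e.g. ConvertMatMulToFC in this PR) can recognize and bypass it.
void tag_compression_convert(const std::shared_ptr<ov::Node>& weights_f32) {
    auto convert = std::make_shared<ov::op::v0::Convert>(weights_f32, ov::element::f16);
    ov::mark_as_compression(convert);    // inserts the Compression attribute into rt_info
    if (ov::is_compression(convert)) {
        // the marker is now visible to any later transformation
    }
    ov::unmark_as_compression(convert);  // erases the attribute again
}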
@@ -10,6 +10,7 @@
#include "openvino/op/result.hpp"
#include "openvino/op/util/precision_sensitive_attribute.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "transformations/rt_info/decompression.hpp"
#include "transformations/rt_info/disable_fp16_compression.hpp"

using namespace ov;
@@ -48,6 +49,7 @@ bool ov::pass::AlignMixedFP32FP16Types::run_on_model(const std::shared_ptr<ov::Model>& model) {
copy_runtime_info(incoming_node, convert);
input.replace_source_output(convert);
disable_fp16_compression(convert);
mark_as_compression(convert);
Contributor comment: These converts are decompression converts: they upcast to fp32 for precision-sensitive subgraphs. Is it possible to rename mark_as_compression to avoid confusion?
Contributor comment: Since mark_as_compression is used only for converts that are inserted to align types between the f16 and f32 parts, can we name it e.g. mark_type_aligning_convert to avoid confusion?

pass::disable_constant_folding(convert);
is_changed = true;
}
@@ -76,6 +78,7 @@ bool ov::pass::AlignMixedFP32FP16Types::run_on_model(const std::shared_ptr<ov::Model>& model) {
auto init_name = node->get_friendly_name() + "_compressed_to_f16";
convert->set_friendly_name(generate_uniq_name(init_name));
out_inputs.replace_source_output(convert);
mark_as_compression(convert);
pass::disable_constant_folding(convert);
is_changed = true;
}
@@ -77,6 +77,32 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() {
register_matcher(m, callback);
}

pass::KeepConstFP32Unfolded::KeepConstFP32Unfolded() {
MATCHER_SCOPE(KeepConstFP32Unfolded);

auto node_pattern = pattern::wrap_type<ov::op::v0::MatMul>();

matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto node = m.get_match_root();

if (transformation_callback(node)) {
return false;
}

auto constNode = node->get_input_node_shared_ptr(1);
if (!is_type<ov::op::v0::Constant>(constNode) || constNode->get_output_element_type(0) != element::f32)
return false;

disable_constant_folding(constNode);
enable_keep_const_precision(constNode);
disable_fp16_compression(constNode);

return false;
};
auto m = std::make_shared<pattern::Matcher>(node_pattern, matcher_name);
register_matcher(m, callback);
}

pass::KeepConstantsPrecisionAndAddConverts::KeepConstantsPrecisionAndAddConverts() {
MATCHER_SCOPE(KeepConstantsPrecisionAndAddConverts);
auto const_pattern = pattern::wrap_type<ov::op::v0::Constant>();
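A hedged sketch of how the new pass can be wired into a pipeline (the manager boilerplate and function name are illustrative; the actual registration in this PR is the CPU_REGISTER_PASS_ARM call in transformation_pipeline.cpp below):

#include <memory>
#include "openvino/pass/manager.hpp"
// plus the header that declares ov::pass::KeepConstFP32Unfolded

// Run the pass before the global f32->f16 precision conversion so that f32
// MatMul weights stay unfolded and keep their original precision.
void keep_matmul_weights_fp32(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::KeepConstFP32Unfolded>();
    manager.run_passes(model);
}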
@@ -18,3 +18,18 @@ bool ov::is_decompression(const std::shared_ptr<Node>& node) {
const auto& rt_info = node->get_rt_info();
return rt_info.count(Decompression::get_type_info_static());
}

void ov::mark_as_compression(const std::shared_ptr<Node>& node) {
auto& rt_info = node->get_rt_info();
rt_info[Compression::get_type_info_static()] = Compression();
}

void ov::unmark_as_compression(const std::shared_ptr<Node>& node) {
auto& rt_info = node->get_rt_info();
rt_info.erase(Compression::get_type_info_static());
}

bool ov::is_compression(const std::shared_ptr<Node>& node) {
const auto& rt_info = node->get_rt_info();
return rt_info.count(Compression::get_type_info_static());
}
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
@@ -949,8 +949,8 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
&& parent->getChildEdges().size() == 1
&& parent->getChildEdgeAt(0)->getOutputNum() == 1
&& parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected
&& one_of(parent->getOriginalInputPrecisionAtPort(0), ov::element::f16)
&& one_of(parent->getOriginalOutputPrecisionAtPort(0), ov::element::f32, ov::element::bf16)
&& one_of(parent->getOriginalInputPrecisionAtPort(0), ov::element::f32, ov::element::bf16, ov::element::f16)
&& one_of(parent->getOriginalOutputPrecisionAtPort(0), ov::element::f32, ov::element::bf16, ov::element::f16)
&& parent->isConstant();
return res;
};
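For context: one_of(x, a, b, ...) is the plugin's variadic membership helper, so this change widens the fusable Convert-on-weights precisions from f16 input with f32/bf16 output (pure decompression) to any combination of f32/bf16/f16, which also covers the plugin-side f32->f16 compressed-weights path. A hedged paraphrase of the widened predicate (function and variable names illustrative):

#include "utils/general_utils.h"  // assumed location of one_of in the CPU plugin

inline bool fusablePrecisions(ov::element::Type inPrc, ov::element::Type outPrc) {
    // Both endpoints of the Convert may now be any of f32 / bf16 / f16.
    return one_of(inPrc, ov::element::f32, ov::element::bf16, ov::element::f16) &&
           one_of(outPrc, ov::element::f32, ov::element::bf16, ov::element::f16);
}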
24 changes: 23 additions & 1 deletion src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -464,11 +464,33 @@ void FullyConnected::prepareWeightsUsingDummyShape() {
if (selected_pd == nullptr)
OPENVINO_THROW("Preferable primitive descriptor is not set for node ", getName(), ".");

auto inDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(*getBaseMemDescAtInputPort(DATA_ID)));
DnnlMemoryDescPtr inDesc = nullptr;
auto weightDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weightDescIP);
auto biasDesc = withBiases ? MemoryDescUtils::convertToDnnlMemoryDesc(getBaseMemDescAtInputPort(BIAS_ID)) : nullptr;
auto outDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(*getBaseMemDescAtOutputPort(0)));

Shape newInShape = getBaseMemDescAtInputPort(DATA_ID)->getShape();
if (isDynamicNode()) {
auto originalInDesc = getBaseMemDescAtInputPort(DATA_ID);
auto originalInDims = originalInDesc->getShape().getDims();
size_t dimIdx = originalInDims.size() == 3 ? 1 : 0;
// Propagate N dim from the output shape to the input shape
if (newInShape.getDims()[dimIdx] == Shape::UNDEFINED_DIM &&
getBaseMemDescAtOutputPort(0)->getShape().getDims()[dimIdx] != Shape::UNDEFINED_DIM) {
newInShape = cloneShapeWithNewDim(newInShape, getBaseMemDescAtOutputPort(0)->getShape().getDims()[dimIdx], dimIdx);
}
// Propagate K dim from the weights shape to the input shape
if (newInShape.getDims()[dimIdx+1] == Shape::UNDEFINED_DIM &&
weightDesc->getShape().getDims()[1] != Shape::UNDEFINED_DIM) {
newInShape = cloneShapeWithNewDim(newInShape, weightDesc->getShape().getDims()[1], dimIdx+1);
}

auto newInDesc = DnnlBlockedMemoryDesc(originalInDesc->getPrecision(), MemoryDescUtils::makeDummyShape(newInShape));
inDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(newInDesc));
} else {
inDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(*getBaseMemDescAtInputPort(DATA_ID)));
}
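// Illustration (assumed 2D shapes, not from the diff): with input {-1, -1},
// output {B, N} where B is defined, and weights {N, K} where K is defined,
// dimIdx is 0, so the two propagations above turn the input shape into {B, K};
// makeDummyShape() then substitutes dummy extents for any dims still undefined.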

const FCKey key = {inDesc,
weightDesc,
biasDesc,
@@ -37,7 +37,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
auto fc_input_b = pattern_map.at(weights_m);
bool is_convert = false;
if (auto convert_node = std::dynamic_pointer_cast<ov::op::v0::Convert>(fc_input_b.get_node_shared_ptr())) {
if (is_decompression(convert_node)) {
if (is_decompression(convert_node) || fp16_compression_is_disabled(convert_node) || is_compression(convert_node)) {
is_convert = true;
fc_input_b = convert_node->get_input_node_shared_ptr(0);
} else {
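For context, a small illustrative graph fragment (assumed shapes and values, not taken from the PR) of the kind of weights path the widened check recognizes: a Constant behind a marked Convert.

#include <memory>
#include <vector>
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "transformations/rt_info/decompression.hpp"

// Illustrative only: f16 Constant -> marked Convert -> MatMul.
std::shared_ptr<ov::Node> make_marked_weights_matmul() {
    auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 1024});
    auto weights = ov::op::v0::Constant::create(
        ov::element::f16, ov::Shape{768, 1024},
        std::vector<ov::float16>(768 * 1024, ov::float16(0.f)));
    auto convert = std::make_shared<ov::op::v0::Convert>(weights, ov::element::f32);
    ov::mark_as_decompression(convert);  // any of the three checks above sets is_convert
    // ConvertMatMulToFC then takes the Convert's input (the Constant itself) as
    // the FullyConnected weights.
    return std::make_shared<ov::op::v0::MatMul>(input, convert, false, true);
}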
@@ -300,6 +300,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecisions
// It cannot be static data, because it may differ for different inferencePrecision values
const auto precisions = get_convert_precisions();
if (inferencePrecision == ov::element::f16) {
CPU_REGISTER_PASS_ARM(manager, ov::pass::KeepConstFP32Unfolded);
precisions_map fp_convert_precision_map = {{ov::element::f32, ov::element::f16}};
type_to_fuse_map empty_fuse_map = {};
const bool keep_precision_sensitive_in_fp32 = true;
17 changes: 17 additions & 0 deletions src/plugins/intel_cpu/src/utils/cpu_utils.hpp
@@ -48,6 +48,23 @@ inline std::vector<size_t> getNormalizedDimsBySize(const VectorDims &dims, size_t
return normalizedDims;
}

/**
* @brief Clones the passed shape and replaces one of its dimensions.
* @param originalShape
* shape to clone
* @param newDimValue
* new dimension value
* @param dim
* dimension index
* @return cloned shape
*/
inline Shape cloneShapeWithNewDim(Shape originalShape, Dim newDimValue, size_t dim) {
VectorDims newDims = originalShape.getDims();
assert(dim < newDims.size());
newDims[dim] = newDimValue;
return Shape(originalShape.getMinDims(), newDims);
}
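// Usage sketch (illustrative values): for a dynamic shape {?, ?},
// cloneShapeWithNewDim(shape, 1024, 1) yields {?, 1024}; the original min dims
// are preserved and only the requested entry of the dims vector is replaced.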

/**
* @brief Checks that secondInputDims is unidirectionally broadcastable per tensor or per channel to firstInputDims
* @param firstInputDims
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/tests/functional/CMakeLists.txt
@@ -37,7 +37,7 @@ else()
file(GLOB_RECURSE TMP_LIST_OF_TEST_CLASSES ${CMAKE_CURRENT_SOURCE_DIR}/single_layer_tests/classes/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_COMMON_TEST_INSTANCES ${CMAKE_CURRENT_SOURCE_DIR}/single_layer_tests/instances/common/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_ARM_TEST_INSTANCES ${CMAKE_CURRENT_SOURCE_DIR}/single_layer_tests/instances/arm/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_ARM_SUBGRAPH_TESTS ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tests/arm/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_ARM_SUBGRAPH_TESTS ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tests/src/arm/*.cpp)
list(APPEND TMP_LIST_OF_EXPLICITLY_ENABLED_TESTS
${TMP_LIST_OF_TEST_CLASSES} ${TMP_LIST_OF_COMMON_TEST_INSTANCES} ${TMP_LIST_OF_ARM_TEST_INSTANCES} ${TMP_LIST_OF_ARM_SUBGRAPH_TESTS})
set(TMP_EXPLICITLY_ENABLED_TESTS "${TMP_LIST_OF_EXPLICITLY_ENABLED_TESTS}")