From 3a1398316d9faa3a5fe694959e0ac2037ee71361 Mon Sep 17 00:00:00 2001 From: Alexander Nesterov Date: Fri, 21 Jun 2024 21:35:34 +0200 Subject: [PATCH] Add FullyConnected ACL executor --- .../src/nodes/executors/acl/acl_eltwise.cpp | 121 +++------------ .../src/nodes/executors/acl/acl_executor.cpp | 81 ++++++++++ .../src/nodes/executors/acl/acl_executor.hpp | 52 +++++++ .../executors/acl/acl_fullyconnected.cpp | 92 ++++++++++++ .../executors/acl/acl_fullyconnected.hpp | 37 +++++ .../src/nodes/executors/acl/acl_utils.cpp | 42 +++++- .../src/nodes/executors/acl/acl_utils.hpp | 32 ++-- .../src/nodes/executors/debug_messages.hpp | 1 + .../fullyconnected_implementations.cpp | 45 ++++++ .../nodes/executors/implementation_utils.hpp | 5 + .../single_layer_tests/classes/matmul.cpp | 6 + .../instances/arm/matmul.cpp | 142 ++++++++++++++++++ 12 files changed, 538 insertions(+), 118 deletions(-) create mode 100644 src/plugins/intel_cpu/src/nodes/executors/acl/acl_executor.cpp create mode 100644 src/plugins/intel_cpu/src/nodes/executors/acl/acl_executor.hpp create mode 100644 src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp create mode 100644 src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp create mode 100644 src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/arm/matmul.cpp diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp index e22b1493f36ae2..4d9f6e4680141a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp @@ -361,66 +361,6 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto return acl_op; }; break; - case Algorithm::EltwiseRelu: - if (aclEltwiseAttrs.alpha == 0) { - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], - ActivationLayerInfo::ActivationFunction::RELU)) - return false; - } else { - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], - {ActivationLayerInfo::ActivationFunction::LEAKY_RELU, aclEltwiseAttrs.alpha})) - return false; - } - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - if (aclEltwiseAttrs.alpha == 0) { - acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::RELU); - } else { - acl_op->configure(&srcTensors[0], &dstTensors[0], - {ActivationLayerInfo::ActivationFunction::LEAKY_RELU, aclEltwiseAttrs.alpha}); - } - return acl_op; - }; - break; - case Algorithm::EltwiseGeluErf: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::GELU)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::GELU); - return acl_op; - }; - break; - case Algorithm::EltwiseElu: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], - {ActivationLayerInfo::ActivationFunction::ELU, aclEltwiseAttrs.alpha})) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], {ActivationLayerInfo::ActivationFunction::ELU, aclEltwiseAttrs.alpha}); - return acl_op; - }; - break; - case Algorithm::EltwiseTanh: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], - {ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f})) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], - {ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f}); - return acl_op; - }; - break; - case Algorithm::EltwiseSigmoid: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::LOGISTIC)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::LOGISTIC); - return acl_op; - }; - break; case Algorithm::EltwiseAbs: if (!NEAbsLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0])) return false; @@ -430,24 +370,6 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto return acl_op; }; break; - case Algorithm::EltwiseSqrt: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::SQRT)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::SQRT); - return acl_op; - }; - break; - case Algorithm::EltwiseSoftRelu: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::SOFT_RELU)) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::SOFT_RELU); - return acl_op; - }; - break; case Algorithm::EltwiseExp: if (!NEExpLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0])) return false; @@ -457,28 +379,6 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto return acl_op; }; break; - case Algorithm::EltwiseClamp: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], - {ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, aclEltwiseAttrs.beta, aclEltwiseAttrs.alpha})) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], - {ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, aclEltwiseAttrs.beta, aclEltwiseAttrs.alpha}); - return acl_op; - }; - break; - case Algorithm::EltwiseSwish: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], - {ActivationLayerInfo::ActivationFunction::SWISH, aclEltwiseAttrs.alpha})) - return false; - exec_func = [this]() -> std::unique_ptr { - auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], - {ActivationLayerInfo::ActivationFunction::SWISH, aclEltwiseAttrs.alpha}); - return acl_op; - }; - break; case Algorithm::EltwisePrelu: if (!NEPReluLayer::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0])) return false; @@ -488,12 +388,29 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto return acl_op; }; break; + case Algorithm::EltwiseRelu: + case Algorithm::EltwiseGeluErf: + case Algorithm::EltwiseElu: + case Algorithm::EltwiseTanh: + case Algorithm::EltwiseSigmoid: + case Algorithm::EltwiseSqrt: + case Algorithm::EltwiseSoftRelu: + case Algorithm::EltwiseClamp: + case Algorithm::EltwiseSwish: case Algorithm::EltwiseHswish: - if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::HARD_SWISH)) + if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], + getActivationLayerInfo(aclEltwiseAttrs.algorithm, + aclEltwiseAttrs.alpha, + aclEltwiseAttrs.beta, + aclEltwiseAttrs.gamma))) return false; exec_func = [this]() -> std::unique_ptr { auto acl_op = std::make_unique(); - acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::HARD_SWISH); + acl_op->configure(&srcTensors[0], &dstTensors[0], + getActivationLayerInfo(aclEltwiseAttrs.algorithm, + aclEltwiseAttrs.alpha, + aclEltwiseAttrs.beta, + aclEltwiseAttrs.gamma)); return acl_op; }; break; diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_executor.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_executor.cpp new file mode 100644 index 00000000000000..00dee4bda245c3 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_executor.cpp @@ -0,0 +1,81 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "acl_executor.hpp" +#include "acl_utils.hpp" +#include "nodes/executors/memory_arguments.hpp" +#include "utils/debug_capabilities.h" + +namespace ov { +namespace intel_cpu { + +ACLMemoryInfo ACLCommonExecutor::initTensorInfo(const MemoryPtr& memoryPtr, ACLTensorAttrs attrs) { + auto acl_tensor_type = precisionToAclDataType(memoryPtr->getPrecision()); + auto acl_tensor_layout = getAclDataLayoutByMemoryDesc(memoryPtr->getDescPtr()); + + ACLMemoryInfo aclMemoryInfo = nullptr; + if (acl_tensor_type != arm_compute::DataType::UNKNOWN) { + auto collapsed_dims = collapse_dims_to_max_rank(memoryPtr->getStaticDims(), attrs.maxDimsShape); + auto acl_tensor_shape = shapeCast(collapsed_dims); + if (attrs.hasLayoutTypeNHWC) { + changeLayoutToNH_C({&acl_tensor_shape}); + } + aclMemoryInfo = std::make_shared( + acl_tensor_shape, 1, + acl_tensor_type, + acl_tensor_layout); + } + return aclMemoryInfo; +} + +ACLMemory ACLCommonExecutor::initTensor(const ACLMemoryInfo& aclMemoryInfo) { + ACLMemory aclMemory = nullptr; + if (aclMemoryInfo) { + aclMemory = std::make_shared(); + aclMemory->allocator()->init(*aclMemoryInfo); + } + return aclMemory; +} + +bool ACLCommonExecutor::update(const MemoryArgs &memory) { + for (auto& cpu_mem_ptr : memory) { + // Initialize arm_compute::TensorInfo object + auto aclTensorInfo = initTensorInfo(cpu_mem_ptr.second, aclTensorAttrs); + // Initialize arm_compute::Tensor object + aclMemoryMap[cpu_mem_ptr.first] = initTensor(aclTensorInfo); + } + + // Update arm_compute::TensorInfo objects for specific ACL function + auto tensorsInfoValidateStatus = updateTensorsInfo(aclMemoryMap); + if (!tensorsInfoValidateStatus) { + DEBUG_LOG("ACL operator validation was failed: ", tensorsInfoValidateStatus.error_description()); + return false; + } + + // Configure arm_compute::IFunction object + configureThreadSafe([&] { + iFunction = configureFunction(aclMemoryMap); + }); + return true; +} + +void ACLCommonExecutor::execute(const MemoryArgs &memory) { + for (auto& acl_tensor : aclMemoryMap) { + if (acl_tensor.second) { + acl_tensor.second->allocator()->import_memory(memory.at(acl_tensor.first)->getData()); + } + } + iFunction->run(); +} + +ACLCommonExecutor::~ACLCommonExecutor() { + for (auto& acl_tensor : aclMemoryMap) { + if (acl_tensor.second) { + acl_tensor.second->allocator()->free(); + } + } +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_executor.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_executor.hpp new file mode 100644 index 00000000000000..1b608a8d06115e --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_executor.hpp @@ -0,0 +1,52 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "cpu_memory.h" +#include "nodes/executors/executor.hpp" +#include "arm_compute/runtime/NEON/NEFunctions.h" + +namespace ov { +namespace intel_cpu { + +using ACLMemoryInfo = std::shared_ptr; +using ACLMemory = std::shared_ptr; +using ACLMemoryMap = std::unordered_map; +using ACLFunction = std::unique_ptr; + +struct ACLTensorAttrs { + bool hasLayoutTypeNHWC = false; + size_t maxDimsShape = arm_compute::MAX_DIMS; +}; + +class ACLCommonExecutor : public Executor { +public: + virtual arm_compute::Status updateTensorsInfo(const ACLMemoryMap& acl_memory) { + OPENVINO_THROW_NOT_IMPLEMENTED("This version of the 'updateTensorsInfo' method is not implemented by executor"); + } + virtual ACLFunction configureFunction(const ACLMemoryMap& acl_memory) { + OPENVINO_THROW_NOT_IMPLEMENTED("This version of the 'configureFunction' method is not implemented by executor"); + } + impl_desc_type implType() const override { + return impl_desc_type::acl; + } + void execute(const MemoryArgs& memory) final; + bool update(const MemoryArgs& memory) final; + ~ACLCommonExecutor(); + +protected: + ACLTensorAttrs aclTensorAttrs; + +private: + ACLMemoryMap aclMemoryMap; + ACLFunction iFunction = nullptr; + static ACLMemoryInfo initTensorInfo(const MemoryPtr& memoryPtr, ACLTensorAttrs attrs); + static ACLMemory initTensor(const ACLMemoryInfo& aclMemoryInfo); +}; + +using ACLCommonExecutorPtr = std::shared_ptr; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp new file mode 100644 index 00000000000000..eb628cdeb19c96 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp @@ -0,0 +1,92 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "acl_fullyconnected.hpp" +#include "acl_utils.hpp" +#include "nodes/executors/executor.hpp" +#include "nodes/executors/memory_arguments.hpp" +#include "utils/debug_capabilities.h" +#include "nodes/executors/debug_messages.hpp" +#include "nodes/executors/implementation_utils.hpp" + +namespace ov { +namespace intel_cpu { + +ACLFullyConnectedExecutor::ACLFullyConnectedExecutor(const FCAttrs &attrs, const PostOps &postOps, + const MemoryArgs &memory, + const ExecutorContext::CPtr context) { + aclTensorAttrs.hasLayoutTypeNHWC = memory.at(ARG_SRC)->getDescPtr()->hasLayoutType(LayoutType::nspc); + fullyConnectedLayerInfo.weights_trained_layout = getAclDataLayoutByMemoryDesc(memory.at(ARG_WEI)->getDescPtr()); + fullyConnectedLayerInfo.transpose_weights = !attrs.weightsNonTransposed; + if (memory.at(ARG_SRC)->getPrecision() == ov::element::f16) { + fullyConnectedLayerInfo.fp_mixed_precision = true; + } + + // Add postops + if (!postOps.empty() && postOps.size() == 1) { + if (const auto activation = std::dynamic_pointer_cast(postOps[0])) { + fullyConnectedLayerInfo.activation_info = getActivationLayerInfo(convertToEltwiseAlgorithm(activation->type()), + activation->alpha(), + activation->beta(), + activation->gamma()); + } + } +} + +bool ACLFullyConnectedExecutor::supports(const FCConfig &config) { + VERIFY(one_of(srcType(config), ov::element::f16, ov::element::f32), UNSUPPORTED_SRC_PRECISIONS); + VERIFY(postOpsNumbers(config) < 2, UNSUPPORTED_NUMBER_OF_POSTOPS); + VERIFY(one_of(srcRank(config), 2U, 3U, 4U), UNSUPPORTED_SRC_RANK); + VERIFY(one_of(weiRank(config), 2U, 3U), UNSUPPORTED_SRC_RANK); + return true; +} + +arm_compute::Status ACLFullyConnectedExecutor::updateTensorsInfo(const ACLMemoryMap& acl_memory) { + auto wei_shape = acl_memory.at(ARG_WEI)->info()->tensor_shape(); + if (wei_shape.num_dimensions() == 3U) { + acl_memory.at(ARG_WEI)->info()->set_tensor_shape({wei_shape[0] * wei_shape[1], wei_shape[2]}); + } + + auto src_shape = acl_memory.at(ARG_SRC)->info()->tensor_shape(); + if (one_of(src_shape.num_dimensions(), 3U, 4U)) { + acl_memory.at(ARG_SRC)->info()->set_tensor_shape({ + acl_memory.at(ARG_WEI)->info()->tensor_shape()[0], + src_shape.total_size() / acl_memory.at(ARG_WEI)->info()->tensor_shape()[0]}); + } + + if (one_of(acl_memory.at(ARG_DST)->info()->tensor_shape().num_dimensions(), 3U, 4U)) { + acl_memory.at(ARG_DST)->info()->set_tensor_shape({ + acl_memory.at(ARG_WEI)->info()->tensor_shape()[1], + acl_memory.at(ARG_SRC)->info()->tensor_shape()[1]}); + } + + if (!fullyConnectedLayerInfo.transpose_weights) { + arm_compute::TensorShape temp_weights_shape = acl_memory.at(ARG_WEI)->info()->tensor_shape(); + std::swap(temp_weights_shape[0], temp_weights_shape[1]); + acl_memory.at(ARG_WEI)->info()->set_tensor_shape(temp_weights_shape); + } + + return arm_compute::NEFullyConnectedLayer::validate( + acl_memory.at(ARG_SRC)->info(), + acl_memory.at(ARG_WEI)->info(), + acl_memory.at(ARG_BIAS) ? acl_memory.at(ARG_BIAS)->info() : nullptr, + acl_memory.at(ARG_DST)->info(), + fullyConnectedLayerInfo, + weightsInfo); +} + +ACLFunction ACLFullyConnectedExecutor::configureFunction(const ACLMemoryMap& acl_memory) { + auto neFC = std::make_unique(); + neFC->configure( + acl_memory.at(ARG_SRC).get(), + acl_memory.at(ARG_WEI).get(), + acl_memory.at(ARG_BIAS).get(), + acl_memory.at(ARG_DST).get(), + fullyConnectedLayerInfo, + weightsInfo); + return neFC; +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp new file mode 100644 index 00000000000000..d1acb49d85fdef --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp @@ -0,0 +1,37 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "acl_executor.hpp" +#include "nodes/executors/fullyconnected_config.hpp" + +namespace ov { +namespace intel_cpu { + +class ACLFullyConnectedExecutor : public ACLCommonExecutor { +public: + ACLFullyConnectedExecutor(const FCAttrs& attrs, + const PostOps& postOps, + const MemoryArgs& memory, + const ExecutorContext::CPtr context); + + static bool supports(const FCConfig& config); + + arm_compute::Status updateTensorsInfo(const ACLMemoryMap& acl_memory) override; + + ACLFunction configureFunction(const ACLMemoryMap& acl_memory) override; + + impl_desc_type implType() const override { + return impl_desc_type::gemm_acl; + } +private: + arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo; + arm_compute::WeightsInfo weightsInfo; +}; + +using ACLFullyConnectedExecutorPtr = std::shared_ptr; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp index c2dfecbf57106c..64b3b23215c6c9 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp @@ -5,9 +5,49 @@ #include "acl_utils.hpp" #include "support/Mutex.h" -void ov::intel_cpu::configureThreadSafe(const std::function& config) { +namespace ov { +namespace intel_cpu { + +void configureThreadSafe(const std::function& config) { // Issue: CVS-123514 static arm_compute::Mutex mtx_config; arm_compute::lock_guard _lock{mtx_config}; config(); } + +arm_compute::ActivationLayerInfo getActivationLayerInfo(Algorithm algorithm, + float alpha = 0.0, + float beta = 0.0, + float gamma = 0.0) { + switch (algorithm) { + case Algorithm::EltwiseRelu: + if (alpha == 0) { + return arm_compute::ActivationLayerInfo::ActivationFunction::RELU; + } else { + return {arm_compute::ActivationLayerInfo::ActivationFunction::LEAKY_RELU, alpha}; + } + case Algorithm::EltwiseGeluErf: + return arm_compute::ActivationLayerInfo::ActivationFunction::GELU; + case Algorithm::EltwiseElu: + return {arm_compute::ActivationLayerInfo::ActivationFunction::ELU, alpha}; + case Algorithm::EltwiseTanh: + return {arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f}; + case Algorithm::EltwiseSigmoid: + return arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC; + case Algorithm::EltwiseSqrt: + return arm_compute::ActivationLayerInfo::ActivationFunction::SQRT; + case Algorithm::EltwiseSoftRelu: + return arm_compute::ActivationLayerInfo::ActivationFunction::SOFT_RELU; + case Algorithm::EltwiseClamp: + return {arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, beta, alpha}; + case Algorithm::EltwiseSwish: + return {arm_compute::ActivationLayerInfo::ActivationFunction::SWISH, alpha}; + case Algorithm::EltwiseHswish: + return arm_compute::ActivationLayerInfo::ActivationFunction::HARD_SWISH; + default: + OPENVINO_THROW("Unsupported operation type for ACL Eltwise executor: ", static_cast(algorithm)); + } +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp index b3077d4c16e342..de9eed5a96bcb5 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp @@ -5,6 +5,7 @@ #include "memory_desc/cpu_memory_desc.h" #include "arm_compute/core/Types.h" +#include "cpu_types.h" namespace ov { namespace intel_cpu { @@ -15,15 +16,14 @@ namespace intel_cpu { * @param dims vector of dimensions to squash * @return vector of dimensions that complies to ACL */ -inline VectorDims collapse_dims_to_max_rank(VectorDims dims) { - const size_t MAX_NUM_SHAPE = arm_compute::MAX_DIMS; - VectorDims result_dims(MAX_NUM_SHAPE - 1); - if (dims.size() >= MAX_NUM_SHAPE) { - for (size_t i = 0; i < MAX_NUM_SHAPE - 1; i++) { +inline VectorDims collapse_dims_to_max_rank(VectorDims dims, size_t max_num_shape = arm_compute::MAX_DIMS) { + VectorDims result_dims(max_num_shape - 1); + if (dims.size() >= max_num_shape) { + for (size_t i = 0; i < max_num_shape - 1; i++) { result_dims[i] = dims[i]; } - for (size_t i = MAX_NUM_SHAPE - 1; i < dims.size(); i++) { - result_dims[MAX_NUM_SHAPE - 2] *= dims[i]; + for (size_t i = max_num_shape - 1; i < dims.size(); i++) { + result_dims[max_num_shape - 2] *= dims[i]; } } else { result_dims = dims; @@ -51,7 +51,7 @@ inline void changeLayoutToNH_C(const std::vector &_li } /** -* @brief Return ComputeLibrary TensorShape with reverted layout schema used in ACL +* @brief Return ComputeLibrary TensorShape with reverted layout schema used in ACL * @param dims vector of dimensions to convert * @return ComputeLibrary TensorShape object */ @@ -96,13 +96,6 @@ inline int axisCast(const std::size_t axis, const std::size_t shapeSize, ACLAxis } } -inline Dim vectorProduct(const VectorDims& vec, size_t size) { - Dim prod = 1; - for (size_t i = 0; i < size; ++i) - prod *= vec[i]; - return prod; -} - /** * @brief Return ComputeLibrary DataType that corresponds to the given precision * @param precision precision to be converted @@ -150,5 +143,14 @@ inline arm_compute::DataLayout getAclDataLayoutByMemoryDesc(MemoryDescCPtr desc) */ void configureThreadSafe(const std::function& config); +/** +* @brief get ARM Compute Library ActivationLayerInfo for Eltwise or PostOps. +* @param algorithm activation function of openvino representation +* @param alpha alpha coefficient for algorithm +* @param beta beta coefficient for algorithm +* @param gamma gamma coefficient for algorithm +*/ +arm_compute::ActivationLayerInfo getActivationLayerInfo(Algorithm algorithm, float alpha, float beta, float gamma); + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp b/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp index 2ee407564bf957..46339304f7c635 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp @@ -7,6 +7,7 @@ #define UNSUPPORTED_SPARSE_WEIGHTS " sparse weights are not supported" #define UNSUPPORTED_WEIGHTS_DECOMPRESSION " weights decompression is not supported" #define UNSUPPORTED_POST_OPS " post ops are not supported" +#define UNSUPPORTED_NUMBER_OF_POSTOPS " post ops numbers are not supported" #define UNSUPPORTED_SRC_PRECISIONS " unsupported src precisions" #define UNSUPPORTED_WEI_PRECISIONS " unsupported wei precisions" #define UNSUPPORTED_DST_PRECISIONS " unsupported dst precisions" diff --git a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp index f6ecbba58147ba..36b653baf803f2 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp @@ -26,6 +26,10 @@ #include "ov_optional.hpp" #include "utils/cpp/maybe_unused.hpp" +#if defined(OV_CPU_WITH_ACL) +#include "nodes/executors/acl/acl_fullyconnected.hpp" +#endif + namespace ov { namespace intel_cpu { @@ -37,6 +41,7 @@ static const MappingNotation dnnlFCMappingNotation{ARG_SRC, ARG_WEI, ARG_BIAS, A using LayoutConfig = std::vector; static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp}; +static const LayoutConfig aclFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp}; template struct Require { @@ -70,10 +75,20 @@ static const TypeMapping dnnlFCTypeMapping { // @todo explicitly cover configuration limitations for oneDNN on ARM }; +static const TypeMapping aclFCTypeMapping { + // {src, wei, bia, dst} pt + {{_f32 | _f16, _any, _any, _any}, pt(bypass(), use<0>(), use<0>(), use<0>())}, + {{_any, _any, _any, _any}, pt(just(), just(), just(), just())} +}; + static const MappingNotation dnnlConvolutionMappingNotation { ARG_SRC, ARG_WEI, ARG_BIAS, ARG_DST }; +static const MappingNotation aclFullyConnectedMappingNotation { + ARG_SRC, ARG_WEI, ARG_BIAS, ARG_DST +}; + static const TypeMapping dnnlConvolutionTypeMapping { // {src, wei, bia, dst} pt {{_bf16, _bf16 | _f32, _any, _bf16 | _f32}, pt(bypass(), bypass(), use<3>(), bypass())}, @@ -301,6 +316,36 @@ const std::vector>& getImplementations() { context, false); }) + OV_CPU_INSTANCE_ACL( + "fullyconnected_acl", + ExecutorType::Acl, + OperationType::FullyConnected, + ShapeTolerance::Agnostic, + // supports + [](const FCConfig& config) -> bool { + VERIFY(noSparseDecompression(config), UNSUPPORTED_SPARSE_WEIGHTS); + VERIFY(noWeightsDecompression(config), UNSUPPORTED_WEIGHTS_DECOMPRESSION); + return ACLFullyConnectedExecutor::supports(config); + }, + // requiresFallback + [](const FCConfig& config) -> ov::optional> { + return requiresFallbackCommon(config, + aclFCTypeMapping, + aclFCLayoutConfig, + aclFullyConnectedMappingNotation); + }, + // acceptsShapes + [](const MemoryArgs& memory) -> bool { + // @todo create syntactic sugar (functor) for shape agnostic lambda + return true; + }, + // create + [](const FCAttrs& attrs, + const PostOps& postOps, + const MemoryArgs& memory, + const ExecutorContext::CPtr context) { + return std::make_shared(attrs, postOps, memory, context); + }) OV_CPU_INSTANCE_DNNL( "fullyconnected_dnnl", ExecutorType::Dnnl, diff --git a/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp index 2382f5e4091a9f..cd029283a09c50 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp @@ -83,5 +83,10 @@ size_t weiMemSize(const Config& config) { return memSize(config); } +template +size_t postOpsNumbers(const Config& config) { + return config.postOps.size(); +} + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/matmul.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/matmul.cpp index 3e55d3368cefb3..88e06c7903ba90 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/matmul.cpp @@ -121,6 +121,12 @@ void MatMulLayerCPUTest::SetUp() { if (it != additionalConfig.end() && it->second.as() == ov::element::bf16) { inType = outType = netType = ElementType::bf16; rel_threshold = abs_threshold = 1e-2f; + } else if (it != additionalConfig.end() && it->second.as() == ov::element::f16) { + inType = outType = netType = ElementType::f16; + // rel_threshold = abs_threshold = 1e-2f; + // Temporarily created the following rel_threshold because of this bug CVS-144523 and + // https://github.com/ARM-software/ComputeLibrary/issues/1112 + rel_threshold = abs_threshold = 3e-1f; } else { inType = outType = netType; rel_threshold = 1e-4f; diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/arm/matmul.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/arm/matmul.cpp new file mode 100644 index 00000000000000..e43a94b9490054 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/arm/matmul.cpp @@ -0,0 +1,142 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "custom/single_layer_tests/classes/matmul.hpp" +#include "utils/cpu_test_utils.hpp" +#include "utils/filter_cpu_info.hpp" +#include "utils/fusing_test_utils.hpp" + +using namespace CPUTestUtils; + +namespace ov { +namespace test { +namespace MatMul { +/* ============= MatMul ============= */ +namespace matmul { + +std::vector fusingParamsSet2D_smoke { + emptyFusingSpec, + fusingBias, + fusingMultiplyPerChannel, + fusingRelu, + fusingTanh +}; + +const auto testParams2D_smoke = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS2D_smoke()), + ::testing::Values(ElementType::f32), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::Values(utils::InputLayerType::CONSTANT), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(emptyAdditionalConfig())), + ::testing::Values(MatMulNodeType::FullyConnected), + ::testing::ValuesIn(fusingParamsSet2D_smoke), + ::testing::ValuesIn(filterCPUInfo(filterSpecificParams()))); +INSTANTIATE_TEST_SUITE_P(smoke_FC_2D, MatMulLayerCPUTest, testParams2D_smoke, MatMulLayerCPUTest::getTestCaseName); + + +std::vector fusingParamsSet2D_smoke_f16 { + emptyFusingSpec, + fusingBias, + fusingRelu +}; +const auto testParams2D_smoke_f16 = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS2D_smoke()), + ::testing::Values(ElementType::f16), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::Values(utils::InputLayerType::CONSTANT), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values( + ov::AnyMap({ov::hint::inference_precision(ov::element::f16)}))), + ::testing::Values(MatMulNodeType::FullyConnected), + ::testing::ValuesIn(fusingParamsSet2D_smoke_f16), + ::testing::ValuesIn(filterCPUInfo(filterSpecificParams()))); +INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_f16, MatMulLayerCPUTest, testParams2D_smoke_f16, MatMulLayerCPUTest::getTestCaseName); + +std::vector fusingParamsSet3D_smoke { + emptyFusingSpec, + fusingBias, + fusingMultiplyPerChannel, + fusingRelu, + fusingTanh +}; +const auto fullyConnectedParams3D_smoke = ::testing::Combine(::testing::ValuesIn(IS3D_smoke()), + ::testing::Values(ElementType::f32), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::Values(utils::InputLayerType::CONSTANT), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(emptyAdditionalConfig())); +std::vector fusingParamsSet3D_smoke_f16 { + emptyFusingSpec, + fusingBias, + fusingRelu +}; +const auto fullyConnectedParams3D_smoke_f16 = ::testing::Combine(::testing::ValuesIn(IS3D_smoke()), + ::testing::Values(ElementType::f16), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::Values(utils::InputLayerType::CONSTANT), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values( + ov::AnyMap({ov::hint::inference_precision(ov::element::f16)}))); +const auto testParams3D_smoke = ::testing::Combine(fullyConnectedParams3D_smoke, + ::testing::Values(MatMulNodeType::FullyConnected), + ::testing::ValuesIn(fusingParamsSet3D_smoke), + ::testing::ValuesIn(filterCPUInfo(filterSpecificParams()))); +const auto testParams3D_smoke_f16 = ::testing::Combine(fullyConnectedParams3D_smoke_f16, + ::testing::Values(MatMulNodeType::FullyConnected), + ::testing::ValuesIn(fusingParamsSet3D_smoke_f16), + ::testing::ValuesIn(filterCPUInfo(filterSpecificParams()))); +INSTANTIATE_TEST_SUITE_P(smoke_FC_3D, MatMulLayerCPUTest, testParams3D_smoke, MatMulLayerCPUTest::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_f16, MatMulLayerCPUTest, testParams3D_smoke_f16, MatMulLayerCPUTest::getTestCaseName); + +const std::vector IS = { + {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, false}}, + {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, false}}, + {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, true}}, + {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, true}}, +}; + +std::vector fusingParamsSet4D_smoke { + emptyFusingSpec, + fusingMultiplyPerChannel, + fusingRelu, + fusingTanh +}; + +const auto testParams4D_smoke = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS), + ::testing::Values(ElementType::f32), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::Values(utils::InputLayerType::CONSTANT), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(emptyAdditionalConfig())), + ::testing::Values(MatMulNodeType::FullyConnected), + ::testing::ValuesIn(fusingParamsSet4D_smoke), + ::testing::ValuesIn(filterCPUInfo(filterSpecificParams()))); +INSTANTIATE_TEST_SUITE_P(smoke_FC_4D, MatMulLayerCPUTest, testParams4D_smoke, MatMulLayerCPUTest::getTestCaseName); + +std::vector fusingParamsSet4D_smoke_f16 { + emptyFusingSpec, + fusingRelu +}; + +const auto testParams4D_smoke_f16 = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS), + ::testing::Values(ElementType::f16), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::Values(utils::InputLayerType::CONSTANT), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values( + ov::AnyMap({ov::hint::inference_precision(ov::element::f16)}))), + ::testing::Values(MatMulNodeType::FullyConnected), + ::testing::ValuesIn(fusingParamsSet4D_smoke_f16), + ::testing::ValuesIn(filterCPUInfo(filterSpecificParams()))); +INSTANTIATE_TEST_SUITE_P(smoke_FC_4D_f16, MatMulLayerCPUTest, testParams4D_smoke_f16, MatMulLayerCPUTest::getTestCaseName); + +} // namespace matmul +} // namespace MatMul +} // namespace test +} // namespace ov