dnn: add gemm_layer in place of fully_connected_layer for onnx models #23897
Merged
Commits (61 total, all by fengyuentau)
6cf9631  first commit
d94d776  turned C from input to constant; force C constant in impl; better han…
03476d5  integrate with gemm from ficus nn
d466f77  fix const inputs
f4c3640  adjust threshold for int8 tryQuantize
7a19272  adjust threshold for int8 quantized 2
05d0793  support batched gemm and matmul; tune threshold for rcnn_ilsvrc13; up…
fd14e6b  add gemm perf against innerproduct
9d5ac58  add perf tests for innerproduct with bias
de0beac  fix perf
baac71d  add memset
ec613bc  renamings for next step
8657065  add dedicated perf gemm
dfff691  add innerproduct in perf_gemm
7cacfc0  remove gemm and innerproduct perf tests from perf_layer
61e458e  add perf cases for vit sizes; prepack constants
8164e3b  remove batched gemm; fix wrong trans; optimize KC
3ed1d48  remove prepacking for const A; several fixes for const B prepacking
c24d944  add todos and gemm expression
c78a09e  add optimized branch for avx/avx2
e9301b7  trigger build
b5c4bc4  update macros and signature
6a3cf14  update signature
7d00e56  fix macro
5c0897a  fix bugs for neon aarch64 & x64
a7b9c3a  add backends: cuda, cann, inf_ngraph and vkcom
66eb2e2  fix cuda backend
8eadede  test commit for cuda
4ec306b  test cuda backend
d81dae6  remove debug message from cuda backend
d407727  use cpu dispatcher
0697b49  fix neon macro undef in dispatcher
8d94a23  fix dispatcher
6ba3c9a  fix inner kernel for neon aarch64
67ee373  fix compiling issue on armv7; try fixing accuracy issue on other plat…
0e543f4  broadcast C with beta multiplied; improve func namings
fc35800  fix bug for avx and avx2
66c3d47  put all platform-specific kernels in dispatcher
e843852  fix typos
8a84865  attempt to fix compile issues on x64
9a1747a  run old gemm when neon, avx, avx2 are all not available; add kernel f…
2b307a7  fix typo
4b5cd4b  quick fix: add macros for pack4
ae4247c  quick fix: use vmlaq_f32 for armv7
5ae13a9  quick fix for missing macro of fast gemm pack f32 4
bf274bf  disable conformance tests when optimized branches are not supported
d00060e  disable perf tests when optimized branches are not supported
beaddba  decouple cv_try_neon and cv_neon_aarch64
efb7dab  drop googlenet_2023; add fastGemmBatched
1275cd3  fix step in fastGemmBatched
07cf1c5  cpu: fix initialization ofb; gpu: support batch
78afd01  quick followup fix for cuda
d88577a  add default kernels
8feb258  quick followup fix to avoid macro redef
e695285  optmized kernels for lasx
235156c  resolve mis-alignment; remove comments
f614554  tune performance for x64 platform
c08dd61  tune performance for neon aarch64
c1406ca  tune for armv7
02718dc  comment time consuming tests
a0f7379  quick follow-up fix
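For reference while reading the commits above: the layer follows the ONNX Gemm contract, Y = alpha * op(A) * op(B) + beta * C, where op() optionally transposes its operand and a 1-D C is broadcast over the output rows. Below is a deliberately naive sketch of that contract (illustrative only; the name gemm_reference is ours, and the PR's actual kernels are the packed, vectorized ones listed in the commits):

#include <vector>
#include <cassert>

// Naive reference for ONNX Gemm semantics: Y = alpha * op(A) * op(B) + beta * C.
// Row-major matrices; op() transposes its argument when the matching flag is set.
static void gemm_reference(const std::vector<float>& A, int a_rows, int a_cols, bool trans_a,
                           const std::vector<float>& B, int b_rows, int b_cols, bool trans_b,
                           const std::vector<float>& C, // 1-D bias, broadcast over rows; may be empty
                           float alpha, float beta, std::vector<float>& Y)
{
    const int M = trans_a ? a_cols : a_rows;
    const int K = trans_a ? a_rows : a_cols;
    const int N = trans_b ? b_rows : b_cols;
    assert((trans_b ? b_cols : b_rows) == K); // inner dimensions must agree

    Y.assign((size_t)M * N, 0.f);
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k) {
                float a = trans_a ? A[(size_t)k * a_cols + i] : A[(size_t)i * a_cols + k];
                float b = trans_b ? B[(size_t)j * b_cols + k] : B[(size_t)k * b_cols + j];
                acc += a * b;
            }
            float c = C.empty() ? 0.f : C[j]; // 1-D C broadcast along the rows
            Y[(size_t)i * N + j] = alpha * acc + beta * c;
        }
    }
}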
New perf test file added by this PR (249 lines):
@@ -0,0 +1,249 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "perf_precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>

namespace opencv_test {

struct GemmParam_t {
    std::vector<int> a_shape;
    std::vector<int> b_shape;
    std::vector<int> c_shape;
    bool trans_a;
    bool trans_b;

    GemmParam_t(std::vector<int> a_shape_, std::vector<int> b_shape_, std::vector<int> c_shape_ = {}, bool trans_a_ = false, bool trans_b_ = false)
        : a_shape(a_shape_), b_shape(b_shape_), c_shape(c_shape_), trans_a(trans_a_), trans_b(trans_b_) {}
};

// TODO: Disable most of the test cases except vision transformers to save time
static const GemmParam_t test_gemm_configs[] = {
    // vision transformers cases
    { { 768, 768 }, { 768, 768 }, { 768 } },
    { { 1024, 1024 }, { 1024, 1024 }, { 1024 } },
    { { 50, 768 }, { 768, 2304 } },
    { { 197, 768 }, { 768, 2304 } },
    { { 50, 1024 }, { 1024, 3072 } },
    { { 197, 1024 }, { 1024, 3072 } },

    // square mat
    { { 64, 64 }, { 64, 64 } },
    { { 128, 128 }, { 128, 128 } },
    { { 256, 256 }, { 256, 256 } },
    { { 512, 512 }, { 512, 512 } },
    { { 1024, 1024 }, { 1024, 1024 } },
    { { 4096, 4096 }, { 4096, 4096 } },

    // rectangular mat
    { { 256, 256 }, { 256, 1024 } },
    { { 256, 1024 }, { 1024, 256 } },
    { { 256, 1024 }, { 1024, 1024 } },
    { { 1024, 1024 }, { 1024, 256 } },
    { { 1024, 256 }, { 256, 1024 } },
    { { 1024, 256 }, { 256, 256 } },

    // with C
    { { 256, 256 }, { 256, 256 }, { 256 } },
    { { 256, 256 }, { 256, 1024 }, { 1024 } },
    { { 256, 1024 }, { 1024, 256 }, { 256 } },
    { { 256, 1024 }, { 1024, 1024 }, { 1024 } },
    // { { 1024, 1024 }, { 1024, 1024 }, { 1024 } },
    { { 1024, 1024 }, { 1024, 256 }, { 256 } },
    { { 1024, 256 }, { 256, 1024 }, { 1024 } },
    { { 1024, 256 }, { 256, 256 }, { 256 } },

    // with C and trans_b
    { { 256, 256 }, { 256, 256 }, { 256 }, false, true },
    { { 256, 1024 }, { 256, 1024 }, { 256 }, false, true },
    { { 256, 1024 }, { 1024, 1024 }, { 1024 }, false, true },
    { { 1024, 1024 }, { 1024, 1024 }, { 1024 }, false, true },
    { { 1024, 256 }, { 1024, 256 }, { 1024 }, false, true },
    { { 1024, 256 }, { 256, 256 }, { 256 }, false, true },

    // with C and trans_b and trans_a
    { { 256, 256 }, { 256, 256 }, { 256 }, true, true },
    { { 1024, 256 }, { 256, 1024 }, { 256 }, true, true },
    { { 256, 1024 }, { 1024, 256 }, { 1024 }, true, true },
    { { 1024, 1024 }, { 1024, 1024 }, { 1024 }, true, true },
};

struct GemmParamId
{
    enum {
        GEMM_0 = 0,
        GEMM_LAST = sizeof(test_gemm_configs) / sizeof(test_gemm_configs[0])
    };
    int val_;
    GemmParamId(int val = 0) : val_(val) {}
    operator int() const { return val_; }
    static ::testing::internal::ParamGenerator<GemmParamId> all()
    {
        enum { NUM = (int)GEMM_LAST };
        GemmParamId v_[NUM]; for (int i = 0; i < NUM; ++i) { v_[i] = GemmParamId(i); } // reduce generated code size
        return ::testing::ValuesIn(v_, v_ + NUM);
    }
};

static inline void PrintTo(const GemmParamId& v, std::ostream* os)
{
    CV_Assert((int)v >= 0); CV_Assert((int)v < GemmParamId::GEMM_LAST);
    const GemmParam_t& p = test_gemm_configs[(int)v];

    auto print_shape = [os](const std::vector<int>& shape, const std::string& tag) {
        if (shape.empty()) {
            return;
        }

        *os << tag << "=[";
        for (size_t i = 0; i < shape.size(); ++i) {
            if (i == shape.size() - 1) {
                *os << shape[i] << "]";
                break;
            }
            *os << shape[i] << ", ";
        }
    };

    print_shape(p.a_shape, "A");
    print_shape(p.b_shape, ", B");
    print_shape(p.c_shape, ", C");
    *os << ", trans_a=" << p.trans_a << ", trans_b=" << p.trans_b;
}

typedef tuple<GemmParamId, tuple<Backend, Target> > GemmTestParam_t;
typedef TestBaseWithParam<GemmTestParam_t> Gemm;

PERF_TEST_P_(Gemm, gemm)
{
    int test_id = (int)get<0>(GetParam());
    ASSERT_GE(test_id, 0); ASSERT_LT(test_id, GemmParamId::GEMM_LAST);
    const GemmParam_t& params = test_gemm_configs[test_id];
    auto a_shape = params.a_shape;
    auto b_shape = params.b_shape;
    auto c_shape = params.c_shape;
    auto trans_a = params.trans_a;
    auto trans_b = params.trans_b;
    float alpha = 1.f;
    float beta = 1.f;

    Backend backend_id = get<0>(get<1>(GetParam()));
    Target target_id = get<1>(get<1>(GetParam()));

    bool have_bias = !c_shape.empty();

    Mat A(static_cast<int>(a_shape.size()), a_shape.data(), CV_32F);
    randu(A, -1.0f, 1.0f);
    Mat B(static_cast<int>(b_shape.size()), b_shape.data(), CV_32F);
    randu(B, -1.0f, 1.0f);

    LayerParams lp;
    lp.type = "Gemm";
    lp.name = "testLayer";
    lp.set("transA", trans_a);
    lp.set("transB", trans_b);
    lp.set("alpha", alpha);
    lp.set("beta", beta);
    lp.set("real_ndims_C", static_cast<int>(c_shape.size()));

    lp.set("constB", true);
    lp.blobs.push_back(B);
    if (have_bias) {
        Mat C(static_cast<int>(c_shape.size()), c_shape.data(), CV_32F);
        randu(C, -1.0f, 1.0f);
        lp.set("have_bias", true);
        lp.set("constC", true);
        lp.blobs.push_back(C);
    }

    Net net;
    int id = net.addLayerToPrev(lp.name, lp.type, lp);
    net.connect(0, 0, id, 0);
    net.setPreferableBackend(backend_id);
    net.setPreferableTarget(target_id);

    // warmup
    {
        net.setInput(A);
        Mat out = net.forward();
    }

    TEST_CYCLE()
    {
        Mat res = net.forward();
    }

    SANITY_CHECK_NOTHING();
}

PERF_TEST_P_(Gemm, innerproduct)
{
    int test_id = (int)get<0>(GetParam());
    ASSERT_GE(test_id, 0); ASSERT_LT(test_id, GemmParamId::GEMM_LAST);
    const GemmParam_t& params = test_gemm_configs[test_id];
    auto a_shape = params.a_shape;
    auto b_shape = params.b_shape;
    auto c_shape = params.c_shape;
    auto trans_a = params.trans_a;
    auto trans_b = params.trans_b;

    Backend backend_id = get<0>(get<1>(GetParam()));
    Target target_id = get<1>(get<1>(GetParam()));

    bool have_bias = !c_shape.empty();

    Mat A(static_cast<int>(a_shape.size()), a_shape.data(), CV_32F);
    randu(A, -1.0f, 1.0f);
    Mat B(static_cast<int>(b_shape.size()), b_shape.data(), CV_32F);
    randu(B, -1.0f, 1.0f);

    LayerParams lp;
    lp.type = "InnerProduct";
    lp.name = "testLayer";
    // InnerProduct stores its weight blob as (num_output x K), i.e. already
    // transposed relative to Gemm's B, so pre-transpose the inputs to match
    // the Gemm configuration being compared against.
    if (trans_a) {
        cv::transpose(A, A);
    }
    if (!trans_b) {
        cv::transpose(B, B);
    }
    lp.blobs.push_back(B);
    lp.set("num_output", B.size[0]);
    if (have_bias) {
        Mat C(static_cast<int>(c_shape.size()), c_shape.data(), CV_32F);
        randu(C, -1.0f, 1.0f);
        lp.blobs.push_back(C);
        lp.set("bias_term", true);
    } else {
        lp.set("bias_term", false);
    }

    Net net;
    int id = net.addLayerToPrev(lp.name, lp.type, lp);
    net.connect(0, 0, id, 0);
    net.setPreferableBackend(backend_id);
    net.setPreferableTarget(target_id);

    // warmup
    {
        std::vector<std::string> input_names(2);
        input_names[0] = "A";
        net.setInputsNames(input_names);
        net.setInput(A, input_names[0]);
        Mat out = net.forward();
    }

    TEST_CYCLE()
    {
        Mat res = net.forward();
    }

    SANITY_CHECK_NOTHING();
}

INSTANTIATE_TEST_CASE_P(/**/, Gemm, Combine(
    GemmParamId::all(),
    dnnBackendsAndTargets(false, false) // defined in ../test/test_common.hpp
));

} // namespace
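Both benchmarks iterate the same GemmParamId space, so every Gemm timing has a directly comparable InnerProduct timing for the same shapes and backend/target pair. Assuming OpenCV's usual perf-module naming, just these cases can be selected with the standard gtest filter, e.g. opencv_perf_dnn --gtest_filter='*Gemm*'.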
We usually don't use runtime dispatching with NEON (it doesn't work due to a different ABI). The whole library is compiled with NEON instead.
Fixing via #24315
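For context on the review note above: when the whole library is built for a NEON-capable target, the NEON path is chosen by the preprocessor at compile time rather than picked at runtime. A minimal sketch of that pattern, using the standard __ARM_NEON predefined macro (illustrative only; OpenCV's real selection goes through its own CPU baseline/dispatch build machinery):

#include <cstddef>
// Hedged sketch: compile-time NEON selection, not runtime dispatch.
// A NEON branch built against a different ABI cannot simply be linked
// into a non-NEON binary, hence the whole-library build approach.
#if defined(__ARM_NEON)
#include <arm_neon.h>
static void add_f32(const float* a, const float* b, float* dst, size_t n)
{
    size_t i = 0;
    for (; i + 4 <= n; i += 4) // NEON path, baked in at build time
        vst1q_f32(dst + i, vaddq_f32(vld1q_f32(a + i), vld1q_f32(b + i)));
    for (; i < n; ++i)
        dst[i] = a[i] + b[i];
}
#else
static void add_f32(const float* a, const float* b, float* dst, size_t n)
{
    for (size_t i = 0; i < n; ++i) // plain scalar fallback
        dst[i] = a[i] + b[i];
}
#endif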