Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dnn: refactor ONNX MatMul with fastGemm #24694

Merged
merged 20 commits into from Dec 19, 2023

Conversation

fengyuentau
Copy link
Member

@fengyuentau fengyuentau commented Dec 13, 2023

Done:

  • add backends
    • CUDA
    • OpenVINO
    • CANN
    • OpenCL
    • Vulkan
  • add perf tests
  • const B case

Benchmark

Tests are done on M1. All data is in milliseconds (ms).

Configuration MatMul (Prepacked) MatMul InnerProduct
A=[12, 197, 197], B=[12, 197, 64], trans_a=0, trans_b=0 0.39 0.41 1.33
A=[12, 197, 64], B=[12, 64, 197], trans_a=0, trans_b=0 0.42 0.42 1.17
A=[12, 50, 64], B=[12, 64, 50], trans_a=0, trans_b=0 0.13 0.15 0.33
A=[12, 50, 50], B=[12, 50, 64], trans_a=0, trans_b=0 0.11 0.13 0.22
A=[16, 197, 197], B=[16, 197, 64], trans_a=0, trans_b=0 0.46 0.54 1.46
A=[16, 197, 64], B=[16, 64, 197], trans_a=0, trans_b=0 0.46 0.95 1.74
A=[16, 50, 64], B=[16, 64, 50], trans_a=0, trans_b=0 0.18 0.32 0.43
A=[16, 50, 50], B=[16, 50, 64], trans_a=0, trans_b=0 0.15 0.25 0.25

Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

  • I agree to contribute to the project under Apache 2 License.
  • To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
  • The PR is proposed to the proper branch
  • There is a reference to the original bug report and related work
  • There is accuracy test, performance test and test data in opencv_extra repository, if applicable
    Patch to opencv_extra has the same branch name.
  • The feature is well documented and sample code can be built with the project CMake

@fengyuentau fengyuentau marked this pull request as draft December 13, 2023 09:17
@fengyuentau
Copy link
Member Author

The previous performance results of InnerProduct was from its branch of calling BLAS. Now I put B as constant input so as to call the non-blas branch. Results show that fastGemm is generally better than FullyConnected acceleration.

@fengyuentau fengyuentau marked this pull request as ready for review December 19, 2023 09:19
@fengyuentau
Copy link
Member Author

All todo items are checked!

@vpisarev vpisarev self-requested a review December 19, 2023 12:50
@vpisarev
Copy link
Contributor

@asmorkalov, this PR looks good to me. It needs to be merged in order to merge the other important PR, #24476.

@asmorkalov
Copy link
Contributor

@dkurt Please join the review too.

int total_tiles = m_tiles * n_tiles;

auto fn = [&](const Range &r) {
char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OpenCV AutoBuffer makes sense here: https://docs.opencv.org/4.x/d8/dd0/classcv_1_1AutoBuffer.html. No problems with memory leaks and it has built-in logic for alloca.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a temporary buffer is usually small (a few K's of memory)

AutoBuffer says something like this. A typical buff_size would be

FAST_GEMM_F32_PACKED_STRIDE_K * (FAST_GEMM_F32_MC + FAST_GEMM_F32_NC) * 4 / 1024
= 64 * (144 + 72) * 4 / 1024 = 54 KB

Is 54 KB still considered to be a few KBs?

int total_tiles = m_tiles * n_tiles;

auto fn = [&](const Range &r) {
char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same idea for AutoBuffer.

int total_tiles = m_tiles * n_tiles;

auto fn = [&](const Range &r) {
char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AutoBuffer.

int total_tiles = m_tiles * n_tiles;

auto fn = [&](const Range &r) {
char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AutoBuffer

Comment on lines +440 to +452
half **dev_C_slices = 0;
cudaMalloc((void**)&dev_A_slices, batch_count * sizeof(half*));
cudaMalloc((void**)&dev_B_slices, batch_count * sizeof(half*));
cudaMalloc((void**)&dev_C_slices, batch_count * sizeof(half*));
cudaMemcpy(dev_A_slices, A_slices, batch_count * sizeof(half*), cudaMemcpyHostToDevice);
cudaMemcpy(dev_B_slices, B_slices, batch_count * sizeof(half*), cudaMemcpyHostToDevice);
cudaMemcpy(dev_C_slices, C_slices, batch_count * sizeof(half*), cudaMemcpyHostToDevice);

CUDA4DNN_CHECK_CUBLAS(cublasHgemmBatched(handle.get(), opa, opb, iM, iN, iK, &alpha, dev_A_slices, ilda, dev_B_slices, ildb, &beta, dev_C_slices, ildc, batch_count));

cudaFree(dev_A_slices);
cudaFree(dev_B_slices);
cudaFree(dev_C_slices);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional optimization with streams is possible. E.g. create stream, use cudaMemcopyAsync and cublasSetStream(). It reduces amount of CPU-GPU syncs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have examples demonstrating how to use these two APIs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, Linux-RISC-V-Clang seems to have trouble starting jobs.

Comment on lines +491 to +503
cudaMalloc((void**)&dev_A_slices, batch_count * sizeof(float*));
cudaMalloc((void**)&dev_B_slices, batch_count * sizeof(float*));
cudaMalloc((void**)&dev_C_slices, batch_count * sizeof(float*));
cudaMemcpy(dev_A_slices, A_slices, batch_count * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(dev_B_slices, B_slices, batch_count * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(dev_C_slices, C_slices, batch_count * sizeof(float*), cudaMemcpyHostToDevice);

// cuBLAS is column-major
CUDA4DNN_CHECK_CUBLAS(cublasSgemmBatched(handle.get(), opa, opb, iM, iN, iK, &alpha, dev_A_slices, ilda, dev_B_slices, ildb, &beta, dev_C_slices, ildc, batch_count));

cudaFree(dev_A_slices);
cudaFree(dev_B_slices);
cudaFree(dev_C_slices);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same optional recommendation here.

dkurt
dkurt previously approved these changes Dec 19, 2023
@dkurt dkurt dismissed their stale review December 19, 2023 15:33

attention layer

@asmorkalov asmorkalov merged commit fa5ed62 into opencv:4.x Dec 19, 2023
24 of 26 checks passed
@fengyuentau fengyuentau deleted the matmul_refactor branch December 19, 2023 16:45
@asmorkalov asmorkalov mentioned this pull request Jan 19, 2024
@fengyuentau fengyuentau mentioned this pull request Feb 21, 2024
48 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants