
Conversation

@melkap01-Arm

Key changes

This PR improves the performance of dynamic QGEMMs by implementing tiling and threading across operations.

The changes introduce thread-local buffers for reusing memory during inference and use them in the dynamic quantised matmul operations that call KleidiAI kernels.

The KleidiAI version is also updated to 1.15.0.
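
As a rough illustration of the thread-local buffer idea, here is a minimal, self-contained sketch with illustrative names (TlsQgemmBuffers, g_tls_qgemm, EnsureCapacity are stand-ins, not the PR's exact code):

// Sketch: per-thread reusable buffers so repeated inference calls avoid
// re-allocating packed-LHS and output-tile storage every time.
#include <cstddef>
#include <vector>

struct TlsQgemmBuffers {               // illustrative name
  std::vector<std::byte> lhs_packed;   // packed LHS for all batches
  std::vector<float> output_tile;      // per-tile output scratch
};

// One instance per thread, reused across calls.
static thread_local TlsQgemmBuffers g_tls_qgemm;

void EnsureCapacity(std::size_t lhs_bytes, std::size_t tile_elems) {
  // Grow only when needed; with stable shapes across runs this never reallocates.
  if (g_tls_qgemm.lhs_packed.size() < lhs_bytes) {
    g_tls_qgemm.lhs_packed.resize(lhs_bytes);
  }
  if (g_tls_qgemm.output_tile.size() < tile_elems) {
    g_tls_qgemm.output_tile.resize(tile_elems);
  }
}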

Example performance

Single thread:
[benchmark image: ort_ops_compare_encoder_1_2025-10-02_17-21-32_vs_encoder_1_2025-10-02_16-54-55]

2 threads:
[benchmark image: ort_ops_compare_encoder_2_2025-10-02_17-21-47_vs_encoder_2_2025-10-02_16-55-13]

Signed-off-by: melkap01 <melike.kaptan@arm.com>
@melkap01-Arm
Author

@microsoft-github-policy-service agree company="Arm"

@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
-kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
+kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2
Member

With this update of the KAI version from 1.10 to 1.15, can SME/SME2 detection be enabled on Windows too to leverage the kernels?

https://github.com/microsoft/onnxruntime/pull/25187/files#r2223006773
https://github.com/microsoft/onnxruntime/pull/25760/files#r2325260570

@patryk-kaiser-ARM
Contributor

Can we get the workflows run, please?

@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).


g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize);
}
g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize);
Member

Can't we just do the resizing directly instead of reserve + resize?

Author

Yes, both reserve() + resize() and resize() alone end up doing one allocation plus one initialisation. But there is a very small performance difference between separating the allocation and the initialisation and doing both at once with resize(). ("after" is the case where the reserve() calls are removed and only resize() is used.)
[benchmark image: ort_ops_compare_2_thread_before_2025-10-29_13-08-56_vs_2_thread_after_2025-10-29_13-32-05]
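
For reference, a minimal self-contained sketch of the two growth patterns being compared (function and variable names are illustrative, not the PR's code):

#include <cstddef>
#include <vector>

// reserve() first, then resize(): allocation and value-initialisation happen
// as two explicit steps.
void grow_reserve_then_resize(std::vector<float>& buffer, std::size_t element_count) {
  if (buffer.capacity() < element_count) {
    buffer.reserve(element_count);  // allocate only
  }
  buffer.resize(element_count);     // construct/zero any new elements
}

// resize() alone: allocates (if needed) and value-initialises in one call.
void grow_resize_only(std::vector<float>& buffer, std::size_t element_count) {
  buffer.resize(element_count);
}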

g_kai_tls_qgemm.output_tile.reserve(tile_elems);
}
// resize the tile to the required size (doesn't affect memory)
g_kai_tls_qgemm.output_tile.resize(tile_elems);
Member

Ditto - is reserve() + resize() necessary?

Author

Same as above.

// Thread-local reusable buffers to reduce allocation overhead across tiles.
struct KaiTlsBuffersQgemm {
std::vector<float> output_tile;
std::vector<float> bias_zero;
Member

Is bias_zero used anywhere?

Author

@melkap01-Arm melkap01-Arm Oct 30, 2025

Addressed in the new commit.

g_kai_tls_qgemm.output_tile.resize(tile_elems);
}
float* temp_tile = g_kai_tls_qgemm.output_tile.data();
std::fill_n(temp_tile, TileSizeM * TileSizeN, 0.0f);
Member

Is this buffer zeroing absolutely needed, i.e. does the micro-kernel accumulate into the existing contents?

Is there a concept of disregarding the existing contents of the output buffer in the micro-kernel's interface?

Contributor

We can remove the fill_n; the kernel handles zeroing of the tile.

LhsPackedData = g_kai_tls_qgemm.lhs_packed.data();

//Per-batch table of lhs
std::vector<const std::byte*> LhsBase(BatchSize);
Member

@hariharans29 hariharans29 Oct 28, 2025

Just a thought - can this vector containing the per-batch addresses be moved into the KaiTlsBuffersQgemm struct and be resized only when its size is less than the BatchSize?

The pro of that approach:

  1. We generally expect the BatchSize to be stable across runs, which means we can do away with the dynamic-memory-allocation latency variance that comes with using std::vector.

The con of that approach:

  1. The size of that caching vector will be bound by the highest batch size that the kernel encounters.

Given that the batch sizes are generally stable across different runs, I am thinking the pro might outweigh the con?

What are your thoughts on this?

Author

@melkap01-Arm melkap01-Arm Oct 30, 2025

This is an idea worth trying and measuring the impact of. I implemented it; the results with a single thread:

[benchmark image: ort_ops_compare_single_thread_before_2025-10-30_10-17-10_vs_single_thread_after_2025-10-30_14-24-15]

and with 2 threads:
[benchmark image: ort_ops_compare_2_thread_before_2025-10-30_10-17-38_vs_2_thread_after_2025-10-30_14-24-29]

After: LhsBase is moved inside the TLS structure. Before: LhsBase is a local buffer shared with the threads.

Here is the implementation:

// Per-batch table of lhs
if (g_kai_tls_qgemm.LhsBase.capacity() < BatchSize) {
  g_kai_tls_qgemm.LhsBase.reserve(BatchSize);
}
g_kai_tls_qgemm.LhsBase.resize(BatchSize);
// Capture the shared batch table pointer so worker threads use the same backing storage.
const std::byte** tls_lhs_base = g_kai_tls_qgemm.LhsBase.data();
// B batches require no packing
  ⋮
  kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda * sizeof(float), lhs);

  tls_lhs_base[batch_idx] = lhs;
});
  ⋮

const std::byte* A_base = tls_lhs_base[BIdx];  // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace;
auto ATile = reinterpret_cast<const std::byte*>(A_base + lhs_packed_offset);

Member

I suspect perf-wise there isn't much difference; the comment is coming from a performance-variance point of view. If we performed dynamic memory allocations on every Run(), I suspect we may see some latency variance. I was just wondering if this can be avoided since, in most cases, the Gemm problem shapes stay the same across invocations. Let us dynamically resize only when we encounter a change of shape (batch size). Hope the motivation of the comment is clear now.

Author

The motivation behind the comment is clear: if we expect generally stable batch sizes, reusing the vector's capacity across calls makes sense. If the performance results are also acceptable, we are all good with this idea. Please find the implementation in the latest commit.


if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) {
lhs = static_cast<std::byte*>(DataParams->Workspace);
if (Shape.M == 0 || Shape.N == 0) {
Member

Should there be a Shape.K check for completeness?

Author

Addressed in the newest commit.
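
For illustration, a minimal sketch of an early-return guard that includes the K check suggested above (the GemmShape struct here is a stand-in, not the actual MLAS type):

#include <cstddef>

struct GemmShape { std::size_t M, N, K; };  // stand-in for the actual shape type

// Nothing to compute if any of M, N, or K is zero.
bool ShouldSkipDynamicQGemm(const GemmShape& Shape) {
  return Shape.M == 0 || Shape.N == 0 || Shape.K == 0;
}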

@hariharans29
Member

General sanity check question: Are there enough tests that trigger all the nuances of the multi-threaded implementation - i.e., enough tests with multiple batch sizes, M, and N dimensions to exercise all aspects of the multi-threaded implementation?

  return;
}
if ((Shape.M < m_step || Shape.N < n_step) && !DataParams->PackedB) {
  // Fallback to MLAS
Contributor

There is no fallback implementation of MlasDynamicQGemmBatch().

#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
  // No fallback and putting in guards
  if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
    ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
  }
#endif
  MLAS_UNREFERENCED_PARAMETER(Shape);
  MLAS_UNREFERENCED_PARAMETER(DataParams);
  MLAS_UNREFERENCED_PARAMETER(BatchN);
  MLAS_UNREFERENCED_PARAMETER(ThreadPool);

If we get to this point, the computation should happen, or (maybe less preferably) it should be a hard error.

Author

We will investigate the fallback case further and try to provide a better implementation.
Until then, we would like to get your opinion on using ORT_ENFORCE:

ORT_ENFORCE(false, "ArmKleidiAI::MlasDynamicQGemmBatch(): unsupported small-shape case (M < m_step or N < n_step)");

Member

@hariharans29 hariharans29 Oct 31, 2025

Could we instead implement @edgchen1's suggestion in the other PR: #26302 (comment) to have a universal check that can be used in all places to determine whether MLAS supports QGemm for that problem shape, platform, etc.?

Also, since we have a check on the M dimension, this might need some thinking - in the current setup, we turn off MLAS usage for QGemm in PrePack() if we don't detect SME or the weight's shape doesn't match the requirements in PrePack(). See here and here. The M dimension won't be known in PrePack().

Just curious - what would happen if M was < m_step? Would there be a crash or would the perf be sub-optimal? If so, we need to add a runtime check in the CPU kernel's Run() function, which means we may need to perform pre-packing for both KAI and the "regular" path. See here.


// Final output tile pointer
float* dst_tile = reinterpret_cast<float*>(CTile);
std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
Contributor

What's the benefit of writing to a temporary buffer (temp_tile) and then copying it to dst_tile instead of writing to dst_tile directly?

Author

@melkap01-Arm melkap01-Arm Oct 30, 2025

The idea behind it was to do the arithmetic on a temporary tile to make it less error-prone, as was done for the SGEMMs. But I see that doing the calculations on the destination and writing directly lowers the complexity.

Instead of having the result in each thread's TLS buffer and copying it to the destination tile, the destination tile can receive the result directly.
Measuring the impact:
Single thread:

[benchmark image: ort_ops_compare_single_thread_no_fill_after_2025-10-30_15-28-05_vs_single_thread_no_temp_buffer_after_2025-10-30_16-39-27]

2 threads:
[benchmark image: ort_ops_compare_2_thread_no_fill_after_2025-10-30_15-28-19_vs_2_thread_no_temp_buffer_after_2025-10-30_16-39-43]
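
As a self-contained sketch of the change described above (the lambda stands in for the KleidiAI micro-kernel; all names here are illustrative):

#include <cstddef>
#include <cstring>
#include <vector>

int main() {
  constexpr std::size_t TileSizeM = 4, TileSizeN = 8;
  std::vector<float> C(TileSizeM * TileSizeN);  // destination tile

  // Stand-in for the micro-kernel: writes results into the given buffer.
  auto kernel = [](float* out, std::size_t elems) {
    for (std::size_t i = 0; i < elems; ++i) out[i] = static_cast<float>(i);
  };

  // Before: compute into a thread-local temporary tile, then memcpy to the destination.
  std::vector<float> temp_tile(TileSizeM * TileSizeN);
  kernel(temp_tile.data(), temp_tile.size());
  std::memcpy(C.data(), temp_tile.data(), temp_tile.size() * sizeof(float));

  // After: let the kernel write into the destination tile directly, removing
  // the temporary buffer and the copy.
  kernel(C.data(), C.size());
  return 0;
}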

unused variable removed,
unnecessary temp_tile use and copy removed,
K==0 case checked

Signed-off-by: melkap01 <melike.kaptan@arm.com>
@hariharans29
Member

Will trigger CI once you push commits addressing the PR feedback (right now I only see a rebase). Thanks.

@melkap01-Arm
Author

General sanity check question: Are there enough tests that trigger all the nuances of the multi-threaded implementation - i.e., enough tests with multiple batch sizes, M, and N dimensions to exercise all aspects of the multi-threaded implementation?

We checked the existing tests for QGemm. In the current implementation, tests only cover the thread pool == null case. We created a follow-up ticket for test coverage.

@hariharans29
Member

General sanity check question: Are there enough tests that trigger all the nuances of the multi-threaded implementation - i.e., enough tests with multiple batch sizes, M, and N dimensions to exercise all aspects of the multi-threaded implementation?

We checked the existing tests for QGemm. In the current implementation, tests only cover the thread pool == null case. We created a follow-up ticket for test coverage.

If all the tests are with ThreadPool == null, does that mean the new thread-pool-based parallel code path(s) are not exercised?
