
[CPU/CUDA ep] Improve DeformConv op performance#27824

Merged
tianleiwu merged 54 commits into microsoft:main from ShirasawaSama:feature/improve-deform-conv-pref
Apr 9, 2026

Conversation

@ShirasawaSama
Contributor

@ShirasawaSama ShirasawaSama commented Mar 24, 2026

Description

Improve DeformConv op performance

Motivation and Context

This PR consolidates a series of optimizations targeting the DeformConv (Deformable Convolution) operator across both CPU and CUDA execution providers.

  • For CPU: The previous implementation suffered from bottlenecks due to redundant computations, lack of vectorization in bilinear sampling, and sub-optimal thread pool utilization. This overhaul redesigns the memory layout and execution pipeline to maximize SIMD opportunities and harden memory safety.
  • For GPU: The batched GEMM operation previously relied on an intermediate buffer and a custom scatter kernel to format the output, which consumed extra memory and kernel launch overhead. This update introduces a zero-copy approach.

1. CPU Optimizations & Refactoring

The CPU execution path has been heavily refactored to minimize branching in hot paths, maximize vectorization, and safely handle edge cases.

| Feature / Optimization | Description | Key Benefit |
|---|---|---|
| AoSoA Bilinear Sampling Plan | Replaced on-the-fly interpolation with a precomputed sampling plan using an 8-lane Array-of-Structures-of-Arrays (AoSoA) layout (`kPlanAoSoALanes`). | Perfectly aligns with 256-bit AVX2 vectors, enabling highly efficient SIMD unrolling during the im2col gathering phase. |
| Kernel Metadata Caching | Introduced `DeformConvKernelMetaCacheData` to cache static convolution geometry (e.g., kH, kW, padding, dilation). | Eliminates the O(kernel_size) overhead of reallocating and recomputing base offsets on every single `Compute()` step. |
| Fast Math & Branchless Logic | Implemented a custom `DeformConvFastFloor` and an inverted bounds check with bitwise operations to evaluate all four corners simultaneously. | Removes expensive `std::floor` calls and unpredictable branches from the operator's hottest path. |
| Enhanced Parallelization | Flattened the bilinear sampling plan build tasks across spatial pixels. | Allows `concurrency::ThreadPool::TryParallelFor` to split fine-grained work effectively, drastically improving thread pool scaling. |
| Hardened Bounds Checking | Introduced compute-time bounds checks using `CheckedMulSizeT` and `CheckedBatchSpan`. | Keeps batch indexing and stride calculations within the addressable `size_t` range, preventing integer overflow vulnerabilities. |
| Bias Addition Refactoring | Refactored bias addition to avoid expensive div/mod operations, applying `ORT_CPU_RESTRICT` and force-inlining. | Maximizes memory throughput and instruction pipelining during the final bias addition phase. |
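As a rough illustration of the fast-floor and branchless corner checks described in the table, here is a minimal C++ sketch; the names `FastFloor` and `BilinearCornersValid` are illustrative and do not match the actual ORT identifiers:

```cpp
#include <cstdint>

// Truncation-based floor: the cast truncates toward zero, so subtract 1 when
// the value was negative and non-integral. Avoids a std::floor library call.
inline int64_t FastFloor(float v) {
  int64_t i = static_cast<int64_t>(v);
  return i - static_cast<int64_t>(static_cast<float>(i) > v);
}

// In-bounds flags for the four bilinear corners around (h, w) in an H x W
// grid, computed with bitwise ops instead of short-circuit branches.
struct CornerMask {
  bool tl, tr, bl, br;  // top-left, top-right, bottom-left, bottom-right
};

inline CornerMask BilinearCornersValid(float h, float w, int64_t H, int64_t W) {
  const int64_t h0 = FastFloor(h), w0 = FastFloor(w);
  const int64_t h1 = h0 + 1, w1 = w0 + 1;
  // "Inverted" bounds check: 0 <= c < limit collapses to one unsigned
  // comparison, since a negative c wraps to a huge unsigned value.
  const bool h0v = static_cast<uint64_t>(h0) < static_cast<uint64_t>(H);
  const bool h1v = static_cast<uint64_t>(h1) < static_cast<uint64_t>(H);
  const bool w0v = static_cast<uint64_t>(w0) < static_cast<uint64_t>(W);
  const bool w1v = static_cast<uint64_t>(w1) < static_cast<uint64_t>(W);
  CornerMask m;
  m.tl = h0v & w0v;  // bitwise & keeps the evaluation branch-free
  m.tr = h0v & w1v;
  m.bl = h1v & w0v;
  m.br = h1v & w1v;
  return m;
}
```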

2. GPU (CUDA) Optimizations

The CUDA implementation was optimized to reduce memory footprint and eliminate unnecessary kernel launches.

  • Zero-Copy GEMM Output: Removed the temporary gemm_output_buffer allocation entirely. By carefully configuring the stride_c parameter (stride_c_y = M * output_image_size), the cublasGemmStridedBatchedHelper now writes the computed output directly into the correct NCHW memory layout of the final Y tensor.
  • Kernel Elimination: Completely removed the DeformConvCopyGemmOutputRowMajorToNCHW custom kernel and its associated dispatch logic. This reduces kernel launch overhead, lowers GPU memory bandwidth pressure, and simplifies the overall CUDA execution pipeline.
  • Reduced Memory Footprint: Updated the bytes_per_image calculation for workspace memory to reflect the removal of the GEMM output buffer. This allows the operator to potentially process more images in parallel under the same memory constraints.
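The zero-copy stride choice can be verified with simple index arithmetic. A sketch (illustrative helper names, not ORT code) showing that a strided-batched GEMM with `stride_c = M * output_image_size` writes each batch item's row-major M×S result exactly at its NCHW position, so no staging buffer or scatter kernel is needed:

```cpp
#include <cstdint>

// Flat offset where a strided-batched GEMM writes element (row m, col s) of
// batch item i, given a per-item C stride of M * S (S = H_out * W_out).
inline int64_t GemmWriteOffset(int64_t i, int64_t m, int64_t s,
                               int64_t M, int64_t S) {
  const int64_t stride_c_y = M * S;   // per-batch-item output stride
  return i * stride_c_y + m * S + s;  // row-major within the item
}

// Flat offset of (image i, channel m, pixel s) in a contiguous NCHW chunk
// of shape [n, M, H_out, W_out].
inline int64_t NchwOffset(int64_t i, int64_t m, int64_t s,
                          int64_t M, int64_t S) {
  return (i * M + m) * S + s;
}
```

The two expressions are algebraically identical, which is why the GEMM can target the final Y tensor directly.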

3. Changed

  • Batch chunking: Chunk size k is chosen so that the number of outer rounds is minimized under the temp-memory cap; k does not have to divide N. The host loop uses cur_parallel = min(k, N - b), so the last chunk may be smaller. This is the intended default behavior for this EP (not yet in a formal release).
  • Kernel-size templates: Im2col is specialized for 1×1, 3×3, and 7×7; other sizes (including 5×5) use the dynamic kH/kW path. Rationale: 5×5 is less common in current stacks (often replaced by stacked 3×3); specializing 7×7 targets common large-kernel cases. Older DCN/detection models that still use 5×5 deformable conv will take the dynamic path—correctness is unchanged; only compile-time unrolling differs.
  • Aliasing documentation: Updated the DeformConv aliasing comments to make the stronger guarantee explicit: if output Y overlaps any input buffer, results can be incorrect regardless of restrict, because output writes may clobber source elements before they are fully consumed. restrict further tightens this by making such aliasing undefined behavior.
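The batch-chunking behavior in the first bullet can be sketched as a small host-side loop (illustrative, not the actual ORT helper):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Chunk a batch of N images into rounds of at most k images each; k does not
// have to divide N, so the last round may process fewer images.
std::vector<int64_t> ChunkSizes(int64_t N, int64_t k) {
  std::vector<int64_t> chunks;
  for (int64_t b = 0; b < N; b += k) {
    chunks.push_back(std::min(k, N - b));  // tail chunk can be < k
  }
  return chunks;
}
```

For example, N = 7 with k = 3 yields rounds of 3, 3, and 1 images.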

Summary

With the current implementation, CPU performance is 33x that of TorchVision (the main branch is 15x). Implementing AVX2/AVX512 optimizations from scratch could push this to roughly 36x, but I haven't found any similar reference code in the ONNX Runtime repository.

This PR also significantly improves parallelism:

[image: parallelism comparison chart]

Both ORT and TorchVision are configured with 16 threads.

Open Question for Reviewers

Regarding CUDA Temporary Memory Allocation:
Currently, the effective maximum temporary memory for CUDA is calculated using a heuristic (total_global_mem * 0.1 or similar logic in GetDeformConvEffectiveMaxTempBytes). While the removal of gemm_output_buffer has reduced the memory footprint per image, I am not entirely certain if this 10% threshold is still the most appropriate value for balancing parallel image processing (n_parallel_imgs) against overall VRAM consumption in large models.

I would appreciate any feedback or suggestions on whether we should tune this threshold, or if there's a more robust way to dynamically determine the optimal temporary workspace size for DeformConv in ORT.

@ShirasawaSama ShirasawaSama marked this pull request as draft March 24, 2026 05:21
@ShirasawaSama ShirasawaSama force-pushed the feature/improve-deform-conv-pref branch from 860a063 to fea461b Compare March 24, 2026 15:34
@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 24, 2026

DeformConv Benchmark: TorchVision vs ONNX Runtime

Real model: BiRefNet HR matte, 2048×2048

Summary

Same inputs are fed to both TorchVision and ONNX Runtime so that timing and numerical comparison are fair. GPU timing uses cuda.synchronize().
Input tensors and deformconv parameters are exported from real BiRefNet runtime calls (torch.save samples).

1. Absolute time (ms)

Config GPU TV GPU ORT CPU TV CPU ORT
1x64x32x32-1x1 0.079 0.028 1.4 0.1
2x64x32x32-1x1 0.101 0.038 2.5 0.2
4x64x32x32-1x1 0.106 0.042 5.0 0.4
1x64x32x32-1x1 0.102 0.039 1.3 0.1
2x64x32x32-1x1 0.105 0.047 2.5 0.1
4x64x32x32-1x1 0.108 0.042 5.0 0.4
1x64x32x32-3x3 0.106 0.043 6.7 0.7
2x64x32x32-3x3 0.110 0.065 13.4 1.0
4x64x32x32-3x3 0.144 0.082 28.0 1.3
1x64x32x32-7x7 0.159 0.087 32.4 1.4
2x64x32x32-7x7 0.232 0.142 64.6 2.5
4x64x32x32-7x7 0.345 0.237 133.2 5.1
1x64x32x32-1x1 0.106 0.040 1.3 0.1
2x64x32x32-1x1 0.106 0.039 2.5 0.2
4x64x32x32-1x1 0.107 0.047 5.1 0.4
1x64x32x32-1x1 0.103 0.041 1.3 0.1
2x64x32x32-1x1 0.103 0.052 2.6 0.1
4x64x32x32-1x1 0.105 0.046 5.0 0.5
1x64x32x32-3x3 0.127 0.055 6.7 0.6
2x64x32x32-3x3 0.135 0.078 13.5 0.9
4x64x32x32-3x3 0.144 0.079 27.4 1.3
1x64x32x32-7x7 0.171 0.094 32.4 1.4
2x64x32x32-7x7 0.233 0.144 65.9 2.6
4x64x32x32-7x7 0.416 0.256 134.5 5.2
1x64x64x64-1x1 0.109 0.048 5.0 0.2
2x64x64x64-1x1 0.127 0.058 10.0 0.7
4x64x64x64-1x1 0.125 0.074 20.4 0.9
1x64x64x64-1x1 0.106 0.043 5.0 0.4
2x64x64x64-1x1 0.130 0.063 9.9 0.3
4x64x64x64-1x1 0.809 0.070 20.4 0.9
1x64x64x64-3x3 0.118 0.073 28.7 1.2
2x64x64x64-3x3 0.184 0.113 59.0 1.9
4x64x64x64-3x3 0.324 0.193 120.4 5.4
1x64x64x64-7x7 0.346 0.231 139.1 5.2
2x64x64x64-7x7 0.672 0.416 275.8 10.3
4x64x64x64-7x7 1.307 0.860 515.2 20.3
1x64x128x128-1x1 0.120 0.068 20.3 0.9
2x64x128x128-1x1 0.421 0.108 41.7 1.3
4x64x128x128-1x1 0.715 0.266 88.1 2.4
1x64x128x128-1x1 0.135 0.071 20.3 0.9
2x64x128x128-1x1 0.243 0.102 41.8 1.4
4x64x128x128-1x1 0.712 0.265 87.5 2.4
1x64x128x128-3x3 0.329 0.189 119.8 3.9
2x64x128x128-3x3 0.691 0.362 241.0 7.7
4x64x128x128-3x3 1.456 0.745 453.0 15.5
1x64x128x128-7x7 1.309 0.856 501.4 22.1
2x64x128x128-7x7 2.735 1.719 1003.6 43.2
4x64x128x128-7x7 6.657 4.722 1977.6 86.5
1x64x256x256-1x1 0.696 0.233 86.9 1.9
2x64x256x256-1x1 1.471 0.612 180.1 8.3
4x64x256x256-1x1 2.967 1.243 332.8 19.4
1x64x256x256-1x1 0.694 0.235 86.4 1.9
2x64x256x256-1x1 1.473 0.613 179.0 7.7
4x64x256x256-1x1 2.969 1.244 332.0 19.3
1x64x256x256-3x3 1.445 0.743 422.2 17.0
2x64x256x256-3x3 2.885 1.488 832.0 37.8
4x64x256x256-3x3 7.028 4.253 1661.1 73.7

2. Relative time (ORT / TV), ratio < 1 ⇒ ORT faster

Config GPU CPU
1x64x32x32-1x1 0.36 0.03
2x64x32x32-1x1 0.37 0.09
4x64x32x32-1x1 0.40 0.08
1x64x32x32-1x1 0.38 0.07
2x64x32x32-1x1 0.45 0.04
4x64x32x32-1x1 0.39 0.09
1x64x32x32-3x3 0.41 0.10
2x64x32x32-3x3 0.59 0.07
4x64x32x32-3x3 0.57 0.05
1x64x32x32-7x7 0.55 0.04
2x64x32x32-7x7 0.61 0.04
4x64x32x32-7x7 0.69 0.04
1x64x32x32-1x1 0.38 0.04
2x64x32x32-1x1 0.37 0.09
4x64x32x32-1x1 0.44 0.08
1x64x32x32-1x1 0.39 0.07
2x64x32x32-1x1 0.51 0.04
4x64x32x32-1x1 0.43 0.10
1x64x32x32-3x3 0.43 0.09
2x64x32x32-3x3 0.58 0.07
4x64x32x32-3x3 0.55 0.05
1x64x32x32-7x7 0.55 0.04
2x64x32x32-7x7 0.62 0.04
4x64x32x32-7x7 0.62 0.04
1x64x64x64-1x1 0.44 0.04
2x64x64x64-1x1 0.46 0.07
4x64x64x64-1x1 0.59 0.04
1x64x64x64-1x1 0.41 0.07
2x64x64x64-1x1 0.48 0.03
4x64x64x64-1x1 0.09 0.04
1x64x64x64-3x3 0.62 0.04
2x64x64x64-3x3 0.62 0.03
4x64x64x64-3x3 0.60 0.04
1x64x64x64-7x7 0.67 0.04
2x64x64x64-7x7 0.62 0.04
4x64x64x64-7x7 0.66 0.04
1x64x128x128-1x1 0.57 0.04
2x64x128x128-1x1 0.26 0.03
4x64x128x128-1x1 0.37 0.03
1x64x128x128-1x1 0.53 0.04
2x64x128x128-1x1 0.42 0.03
4x64x128x128-1x1 0.37 0.03
1x64x128x128-3x3 0.57 0.03
2x64x128x128-3x3 0.52 0.03
4x64x128x128-3x3 0.51 0.03
1x64x128x128-7x7 0.65 0.04
2x64x128x128-7x7 0.63 0.04
4x64x128x128-7x7 0.71 0.04
1x64x256x256-1x1 0.33 0.02
2x64x256x256-1x1 0.42 0.05
4x64x256x256-1x1 0.42 0.06
1x64x256x256-1x1 0.34 0.02
2x64x256x256-1x1 0.42 0.04
4x64x256x256-1x1 0.42 0.06
1x64x256x256-3x3 0.51 0.04
2x64x256x256-3x3 0.52 0.05
4x64x256x256-3x3 0.61 0.04

3. Numerical accuracy (same inputs)

max_abs_diff = max |output_TV − output_ORT|.


Note for reviewers: Relative error is not used as the primary metric because the DeformConv output tensor contains many values near zero; dividing by small denominators would inflate relative error without indicating a real numerical bug. For float32 and typical GEMM/interpolation pipelines, a max absolute difference on the order of 1e-4–1e-3 is consistent with implementation differences (e.g. TF32, reduction order). All reported max_abs_diff values are below 1e-2.
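The metric itself is straightforward to compute; a C++ sketch mirroring the definition above (illustrative, not the benchmark harness code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// max_abs_diff = max |a[i] - b[i]| over n elements, accumulated in double to
// avoid losing precision when summing float differences.
double MaxAbsDiff(const float* a, const float* b, std::size_t n) {
  double m = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    m = std::max(m, std::fabs(static_cast<double>(a[i]) -
                              static_cast<double>(b[i])));
  }
  return m;
}
```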

Config GPU max_abs_diff CPU max_abs_diff
1x64x32x32-1x1 0.000401 0.000000
2x64x32x32-1x1 0.000409 0.000000
4x64x32x32-1x1 0.000423 0.000000
1x64x32x32-1x1 0.000408 0.000000
2x64x32x32-1x1 0.000392 0.000000
4x64x32x32-1x1 0.000383 0.000000
1x64x32x32-3x3 0.000316 0.000001
2x64x32x32-3x3 0.000315 0.000001
4x64x32x32-3x3 0.000312 0.000001
1x64x32x32-7x7 0.000284 0.000001
2x64x32x32-7x7 0.000305 0.000001
4x64x32x32-7x7 0.000308 0.000001
1x64x32x32-1x1 0.000403 0.000000
2x64x32x32-1x1 0.000443 0.000000
4x64x32x32-1x1 0.000432 0.000000
1x64x32x32-1x1 0.000376 0.000000
2x64x32x32-1x1 0.000397 0.000000
4x64x32x32-1x1 0.000468 0.000000
1x64x32x32-3x3 0.000285 0.000001
2x64x32x32-3x3 0.000299 0.000001
4x64x32x32-3x3 0.000305 0.000001
1x64x32x32-7x7 0.000282 0.000001
2x64x32x32-7x7 0.000289 0.000001
4x64x32x32-7x7 0.000310 0.000001
1x64x64x64-1x1 0.000530 0.000000
2x64x64x64-1x1 0.000418 0.000000
4x64x64x64-1x1 0.000457 0.000000
1x64x64x64-1x1 0.000435 0.000000
2x64x64x64-1x1 0.000427 0.000000
4x64x64x64-1x1 0.000459 0.000000
1x64x64x64-3x3 0.000357 0.000001
2x64x64x64-3x3 0.000362 0.000001
4x64x64x64-3x3 0.000330 0.000001
1x64x64x64-7x7 0.000302 0.000001
2x64x64x64-7x7 0.000321 0.000001
4x64x64x64-7x7 0.000354 0.000001
1x64x128x128-1x1 0.000443 0.000000
2x64x128x128-1x1 0.000476 0.000000
4x64x128x128-1x1 0.000539 0.000000
1x64x128x128-1x1 0.000454 0.000000
2x64x128x128-1x1 0.000497 0.000000
4x64x128x128-1x1 0.000518 0.000000
1x64x128x128-3x3 0.000357 0.000001
2x64x128x128-3x3 0.000346 0.000001
4x64x128x128-3x3 0.000355 0.000001
1x64x128x128-7x7 0.000359 0.000001
2x64x128x128-7x7 0.000352 0.000001
4x64x128x128-7x7 0.000378 0.000001
1x64x256x256-1x1 0.000582 0.000000
2x64x256x256-1x1 0.000569 0.000000
4x64x256x256-1x1 0.000538 0.000000
1x64x256x256-1x1 0.000499 0.000000
2x64x256x256-1x1 0.000495 0.000000
4x64x256x256-1x1 0.000613 0.000000
1x64x256x256-3x3 0.000341 0.000001
2x64x256x256-3x3 0.000381 0.000001
4x64x256x256-3x3 0.000407 0.000001

4. Chart

Batch 1

[image: batch 1 charts]

Batch 2

[image: batch 2 charts]

Batch 4

[image: batch 4 charts]

Charts: GPU/CPU absolute latency (TV vs ORT); ORT/TV ratio (< 1 ⇒ ORT faster); max absolute error per config.


5. Setup

  • Backend: TorchVision deform_conv2d vs ONNX Runtime (custom CUDA/CPU DeformConv).
  • Inputs: Same tensors for TV and ORT (fixed seed per config). Default: offset=0, mask=1; use --deform for random offset/mask. Configs prefixed nm use DeformConv without mask (ONNX: input + offset only).
  • Warmup / runs: 50 / 200 (or 10 / 50 with --quick). GPU timing uses cuda.synchronize().

@ShirasawaSama ShirasawaSama force-pushed the feature/improve-deform-conv-pref branch 3 times, most recently from b14d3b8 to 3623db2 Compare March 26, 2026 12:41
@ShirasawaSama ShirasawaSama marked this pull request as ready for review March 26, 2026 14:07
@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 26, 2026

Hi @tianleiwu , sorry to tag you here — could you please help trigger a Copilot code review for me?

This commit implements performance optimizations for the DeformConv operator on the main branch, achieving ~65% speedup on CPU and ~30% on GPU.
Compared to torchvision, this implementation runs at approximately 0.02×–0.07× latency on CPU and 0.34×–0.6× on GPU.

All unit tests have passed locally, lint checks are clean, and the performance report on the real model (Birefnet) is included above.

I’ve also added additional comments to improve readability and maintainability.

If you have any suggestions or would like me to make any changes, I’d be happy to address them promptly.

@ShirasawaSama ShirasawaSama changed the title from "[CPU/CUDA ep] [WIP] Improve DeformConv op performance" to "[CPU/CUDA ep] Improve DeformConv op performance" Mar 26, 2026
@ShirasawaSama ShirasawaSama marked this pull request as draft March 26, 2026 18:07
@ShirasawaSama ShirasawaSama marked this pull request as ready for review March 26, 2026 18:41
@ShirasawaSama ShirasawaSama force-pushed the feature/improve-deform-conv-pref branch from 7f4e638 to 37099b6 Compare March 26, 2026 18:53
@tianleiwu tianleiwu requested a review from Copilot March 29, 2026 23:25
Contributor

Copilot AI left a comment


Pull request overview

This PR improves DeformConv performance across CPU and CUDA execution providers by refactoring hot paths and reducing GPU workspace/kernel overhead, while also adding additional bounds/indexing safeguards and clarifying documentation.

Changes:

  • CPU: refactors im2col to precompute and reuse an AoSoA bilinear sampling plan (SIMD-friendly) and adds overflow-safe stride/dimension handling.
  • CUDA: removes GEMM-output staging/copy kernel by writing GEMM results directly into Y (zero-copy), and adds a faster bias-add path with an optional 2D kernel launch.
  • Shared: introduces DeformConvValidateAndComputeCommonDims to centralize derived dimension computation.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

| File | Description |
|---|---|
| onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc | Updates test comment to reflect mask semantics more generally. |
| onnxruntime/core/providers/cuda/nn/deform_conv_impl.h | Updates CUDA kernel entry-point docs; removes GEMM-output copy entry point; extends bias API signature. |
| onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu | Implements zero-copy-related kernel changes (im2col masking specialization, bias add rework, 32/64-bit index selection helpers). |
| onnxruntime/core/providers/cuda/nn/deform_conv.cc | Updates host orchestration: new chunk sizing, removes temp GEMM output buffer, writes GEMM directly to NCHW, passes grid-y limit for bias. |
| onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h | Adds shared derived-dim helper used by CPU/CUDA. |
| onnxruntime/core/providers/cpu/nn/deform_conv.cc | Major CPU refactor: sampling plan + SIMD-friendly fill, overflow-checked strides, improved bias add. |


@ShirasawaSama ShirasawaSama marked this pull request as draft March 30, 2026 18:07
@ShirasawaSama ShirasawaSama force-pushed the feature/improve-deform-conv-pref branch 2 times, most recently from f938f4c to f892ad8 Compare March 30, 2026 18:57
@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 30, 2026

In the end, I decided not to make too many changes; the remaining optimizations I could make would be practically negligible.

~~I only added the groups==1 specialization (CUDA) and fixed the issues reported by Copilot, then rebased the changes onto the latest main branch.~~ (reverted: no noticeable improvement in performance)

I only fixed the issues reported by Copilot, then rebased the changes onto the latest main branch.

@ShirasawaSama ShirasawaSama marked this pull request as ready for review March 30, 2026 19:01
@ShirasawaSama ShirasawaSama force-pushed the feature/improve-deform-conv-pref branch from f892ad8 to cf3d79c Compare March 30, 2026 19:17
@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 31, 2026

Failed runner:

The self-hosted runner lost communication with the server. Verify the machine is running and has a healthy network connection. Anything in your workflow that terminates the runner process, starves it for CPU/Memory, or blocks its network access can cause this error.

It doesn't seem to be caused by this PR; it might be due to an unstable runner. You can just rerun the CI.

@ShirasawaSama
Contributor Author

ShirasawaSama commented Apr 2, 2026

> Rebase to latest main branch?

Waiting for the merge of the PR titled "Fix WebGPU Windows CI timeouts by removing redundant tests and sharding provider tests" to prevent pipeline timeouts.

@ShirasawaSama ShirasawaSama force-pushed the feature/improve-deform-conv-pref branch from cf3d79c to 8c6ed74 Compare April 2, 2026 17:31
Contributor

@tianleiwu tianleiwu left a comment


Summary

A substantial, well-motivated performance overhaul of CPU and CUDA DeformConv. The CPU path is redesigned around a precomputed AoSoA bilinear sampling plan that amortizes interpolation setup across channels, while the CUDA path eliminates a temp buffer + scatter kernel by writing GEMM output directly to NCHW Y via strided batched GEMM.

Positives:

  • AoSoA layout (kPlanAoSoALanes=8) aligns with 256-bit AVX2; the gather/interpolate inner loop can SIMD-unroll 8-wide.
  • Plan reuse across channels within an offset group is the key insight that eliminates redundant bilinear coordinate work.
  • Zero-copy GEMM output with cublasGemmStridedBatchedHelper — the stride algebra is correct for direct NCHW writes.
  • UseMask promoted to template parameter eliminates runtime branches from the hottest kernel loop.
  • The branchless bilinear interpolation (safe-address + validity masks) eliminates warp divergence.
  • Overflow checks (CheckedMulSizeT, CheckedBatchSpan) are systematic and well-placed.

One high-priority issue (signed integer overflow UB in CeilDiv) and a few suggestions below.

| # | Severity | Component | Issue |
|---|---|---|---|
| 1 | High | CUDA `GetDeformConvParallelChunkSize` | `CeilDiv` signed integer overflow UB when N is large |
| 2 | Suggestion | CUDA zero-copy GEMM | Behavioral change from evenly-divisible to uneven chunks should be documented |
| 3 | Suggestion | CUDA im2col kernel | 5×5 specialization removed; mention in PR description |
| 4 | Suggestion | Tests | No new test cases for AoSoA tail, prime batch, 7×7 kernel, or overflow paths |
| 5 | Nitpick | CUDA im2col kernel | `offset_byte_offset` is an element offset, not a byte offset |
| 6 | Nitpick | CUDA bias kernel | `max_grid_y > 32` threshold lacks a rationale comment |
| 7 | Nitpick | CPU deform_conv | Plan block allocation is not zero-initialized (currently safe but fragile) |
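To illustrate issue #1: the common `(a + b - 1) / b` formulation of ceiling division overflows (signed-integer UB) when `a` is near `INT64_MAX`, while a remainder-based form never overflows for non-negative `a` and positive `b`. A sketch, not the ORT code:

```cpp
#include <cstdint>

// Overflow-safe ceiling division for a >= 0, b > 0. Unlike (a + b - 1) / b,
// this never forms an intermediate larger than a, so it cannot overflow.
inline int64_t CeilDivSafe(int64_t a, int64_t b) {
  return a / b + static_cast<int64_t>(a % b != 0);
}
```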

Contributor

@tianleiwu tianleiwu left a comment


Code Review: PR 27824

High-Priority

  1. onnxruntime/core/providers/cuda/nn/deform_conv.cc

    The new balanced chunking helper can now produce a final iteration with cur_parallel == 1 after earlier iterations used a larger n_parallel_imgs, but the cur_parallel == 1 GEMM fast path still computes stride_col from the outer-scope col_stride (kernel_dim * n_parallel_imgs * output_image_size). DeformConvIm2ColImpl repacks col_buffer for each iteration using the current cur_parallel, so on a one-image tail the actual per-group stride in col_buffer is only kernel_dim * output_image_size. As soon as group > 1, the strided-batched cuBLAS call skips past the compactly written tail buffer and reads uninitialized memory for later groups, corrupting the last chunk. This regression was masked before because the old divisor-based chunking avoided one-image tails after larger chunks. The fix is to derive the group stride from the current iteration (kernel_dim * cur_out_size, which collapses to kernel_dim * output_image_size here) and to add a CUDA test that combines group > 1 with a non-divisible batch size.
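The fix described above amounts to deriving the per-group column-buffer stride from the current iteration's image count rather than the outer-scope chunk size. As an illustrative sketch (hypothetical helper names):

```cpp
#include <cstdint>

// Buggy: stride derived from the chunk size chosen for full-size iterations,
// which over-reads the compactly packed col_buffer on a short tail chunk.
inline int64_t BuggyGroupStride(int64_t kernel_dim, int64_t planned_chunk,
                                int64_t output_image_size) {
  return kernel_dim * planned_chunk * output_image_size;
}

// Fixed: stride derived from this iteration's actual image count, matching
// how DeformConvIm2ColImpl repacked col_buffer for the current chunk.
inline int64_t FixedGroupStride(int64_t kernel_dim, int64_t cur_parallel,
                                int64_t output_image_size) {
  return kernel_dim * cur_parallel * output_image_size;
}
```

On a one-image tail after chunks of 4 (e.g. kernel_dim = 9, S = 100), the buggy stride is 3600 while the packed buffer's real per-group stride is 900, so with group > 1 the GEMM reads past the written data.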

Suggestion

  1. onnxruntime/core/providers/cpu/nn/deform_conv.cc

  2. onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu

    The late ORT_CPU_RESTRICT / __restrict__ pass now marks external model inputs (X, offset, mask) as non-aliasing. ONNX Runtime does not guarantee that distinct tensor inputs are backed by distinct memory regions, and DeformConv takes the same element type for all three tensors, so callers can legally bind overlapping views of one allocation. Under those conditions the new qualifiers make the optimized code undefined behavior and give the compiler permission to reorder loads as if aliasing were impossible. Unless the op contract explicitly forbids overlapping input buffers, the restrict annotations should be limited to internally owned temporaries (data_col, sampling_plan_blocks, etc.) rather than user-provided input tensors.

@ShirasawaSama
Contributor Author

> High-Priority
> 1. onnxruntime/core/providers/cuda/nn/deform_conv.cc: the grouped tail-chunk GEMM derives stride_col from the outer-scope chunk size, corrupting the last chunk when group > 1 and the batch is not evenly divisible.
>
> Suggestion
> 1. onnxruntime/core/providers/cpu/nn/deform_conv.cc
> 2. onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu: restrict annotations should be limited to internally owned temporaries rather than user-provided input tensors, which ONNX Runtime allows to alias.

  1. 76dfdba: Fixed, and test cases added.
  2. 2574ee2: Updated the DeformConv aliasing comments to make the stronger guarantee explicit: if output Y overlaps any input buffer, results can be incorrect regardless of restrict, because output writes may clobber source elements before they are fully consumed. restrict further tightens this by making such aliasing undefined behavior.

Contributor

@tianleiwu tianleiwu left a comment


Thanks for the cleanup here. The CPU/CUDA refactor is much easier to follow now, but I found one remaining correctness issue in the grouped CUDA tail path and one follow-up coverage gap.

@ShirasawaSama
Contributor Author

ShirasawaSama commented Apr 7, 2026

Thanks for the cleanup here. The CPU/CUDA refactor is much easier to follow now, but I found one remaining correctness issue in the grouped CUDA tail path and one follow-up coverage gap.

These two issues have been fixed in commit cdb979c , and I've reviewed the code again. There shouldn't be any similar issues left. Could you please review it again? Thank you very much!

Besides that, I actually have another question: I’ve currently set THREADS_PER_BLOCK in CUDA to 256, but I’m not sure if there’s a better strategy for determining this value dynamically.

@ShirasawaSama ShirasawaSama marked this pull request as draft April 7, 2026 13:58
@ShirasawaSama ShirasawaSama marked this pull request as ready for review April 7, 2026 14:15
@ShirasawaSama ShirasawaSama marked this pull request as draft April 7, 2026 14:27
@ShirasawaSama
Contributor Author

ShirasawaSama commented Apr 7, 2026

I also ran separate benchmarks for the hand-coded AVX2 and AVX512 implementations, as well as the AoSoA, AoS, and SoA memory layouts, and found that the differences weren’t significant—they can essentially be considered noise. It seems the bottleneck is still in random memory reads. I’ll leave it at that for now and not make any changes.

In fact, I’ve found that the current AoSoA cannot be automatically vectorized by the compiler. 🥲🥲

MSVC:

[screenshot: MSVC vectorization report]

GCC:

[screenshot: GCC vectorization report]

My bad.

However, after I implemented the SIMD manually, AVX2 was indeed being used:

[screenshot: compiler output showing AVX2 instructions]

And I haven’t seen any examples of SIMD being written directly in the OP, so I’m not sure how to handle this situation. Should I add a new function to mlas?
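For reference, the indexing that an 8-lane AoSoA layout implies can be sketched as follows (illustrative; not the actual plan layout in this PR):

```cpp
#include <cstddef>

// 8-lane AoSoA: entries are grouped into blocks of kLanes; within a block,
// each field is stored contiguously across the 8 lanes, so one field of 8
// consecutive entries fills exactly one 256-bit AVX2 vector of floats.
constexpr std::size_t kLanes = 8;

inline std::size_t AoSoAOffset(std::size_t entry, std::size_t field,
                               std::size_t num_fields) {
  const std::size_t block = entry / kLanes;  // which 8-entry group
  const std::size_t lane = entry % kLanes;   // position within the group
  return (block * num_fields + field) * kLanes + lane;
}
```

With 4 fields per entry, field 0 of entries 0..7 occupies offsets 0..7, which is exactly the contiguous run a vector load wants.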

@ShirasawaSama ShirasawaSama marked this pull request as ready for review April 7, 2026 15:01
Contributor

@tianleiwu tianleiwu left a comment


Thanks for the updates. Current head looks sound overall; I left one small non-blocking suggestion inline around reusing the new checked common-dimension computation before CUDA chunk sizing.

Contributor

@tianleiwu tianleiwu left a comment


see comments above

@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu tianleiwu merged commit 9d7e6d5 into microsoft:main Apr 9, 2026
99 of 100 checks passed
