Cache-Aware Block-Transposed Chamfer/MaxSim Distance for f32 and f16 #863
suri-kumkaran wants to merge 18 commits into main from
Conversation
hildebrandmw
left a comment
There was a problem hiding this comment.
This is on the right track, but there are some claims made that I don't think are fully backed up yet.
First, this is not exactly type agnostic. While an implementation does exist for f32, one does not exist for another data type as proof of generality. It would be nice to see at least an f16 implementation, which requires functionality not present in this PR (i.e., lazily unpacking f16 panels into f32 panels before entering micropanel loops to hoist the conversion out of the micro-kernel).
It would also be nice to see this used to implement row-major x row-major kernels again with packing on each tile load. The packing algorithms need not be super optimized (SIMD shuffles can be added later), but it would be great to see that this is possible within the kernel abstraction.
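The row-major × row-major path is not in this PR, but the per-tile packing step it would need can be sketched without SIMD shuffles. A minimal, hypothetical `pack_panel` (the function name, signature, and GROUP handling are illustrative, not from the PR) that repacks a row-major slab into the depth-major order a block-transposed micro-kernel would consume:

```rust
// Hypothetical sketch: pack a GROUP-row panel of a row-major matrix
// (rows x k, f32) into block-transposed order: for each depth step `p`,
// GROUP consecutive values taken from GROUP different rows.
fn pack_panel(src: &[f32], k: usize, row0: usize, group: usize, dst: &mut Vec<f32>) {
    dst.clear();
    for p in 0..k {
        for r in 0..group {
            dst.push(src[(row0 + r) * k + p]);
        }
    }
}

fn main() {
    // 4 rows x 3 cols, row-major: element (r, c) = r*10 + c
    let src: Vec<f32> = (0..4)
        .flat_map(|r| (0..3).map(move |c| (r * 10 + c) as f32))
        .collect();
    let mut panel = Vec::new();
    pack_panel(&src, 3, 0, 4, &mut panel);
    // depth step 0 holds column 0 of rows 0..4
    assert_eq!(&panel[0..4], &[0.0, 10.0, 20.0, 30.0]);
    println!("{:?}", panel);
}
```

A SIMD shuffle version would only change the inner copy, not the layout contract, which is why the packing can stay unoptimized at first.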
Second, we really should make kernel implementations micro-architecture aware from the get-go. Passing around diskann_wide::arch::Current isn't super helpful as that type is always available. Rather, we should parameterize something (perhaps at the trait level) on the Architecture to enable a clean extension point for AVX-512, Neon, etc..
Finally, since the original experimentation for this contained implementations for 8-bit integers, showing that this abstraction layer works there too would make a strong case for the abstraction.
Thanks for the thorough and insightful feedback — addressed in the latest push.
- f16 implementation: Done.
- Architecture parameterization: …
- Row-major × row-major: Not in this PR, but the …
- 8-bit integers: Same story — an i8 kernel would dequantize in …

The goal of this PR is proving the abstraction with two concrete types (f32 identity + f16 lazy unpacking) sharing one tiling loop and micro-kernel. Remaining implementations are mechanical from here.
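The identity-versus-converting preparation idea can be illustrated outside the PR's actual `ConvertTo` trait. A toy sketch using `Cow`: `prepare_f32`, `prepare_i8`, and the `scale` parameter are hypothetical stand-ins, with i8 dequantization standing in for the lazy-unpack pattern, so one micro-kernel body consumes both inputs:

```rust
use std::borrow::Cow;

// f32 input: identity path, zero-cost borrow (no buffer, no copy).
fn prepare_f32(src: &[f32]) -> Cow<'_, [f32]> {
    Cow::Borrowed(src)
}

// i8 input: dequantize once per tile into an owned f32 buffer.
// `scale` is a hypothetical per-tile dequantization factor.
fn prepare_i8(src: &[i8], scale: f32) -> Cow<'static, [f32]> {
    Cow::Owned(src.iter().map(|&v| v as f32 * scale).collect())
}

// Stand-in for the shared micro-kernel body: it only ever sees f32.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn main() {
    let q = [1.0f32, 2.0, 3.0];
    let d_i8 = [2i8, 4, 6];
    let qa = prepare_f32(&q);
    let db = prepare_i8(&d_i8, 0.5); // dequantizes to [1.0, 2.0, 3.0]
    assert_eq!(dot(&qa, &db), 14.0);
    println!("{}", dot(&qa, &db));
}
```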
There was a problem hiding this comment.
Pull request overview
This PR introduces a new SIMD-accelerated, cache-tiled kernel framework for multi-vector MaxSim/Chamfer distance, targeting block-transposed query layouts and supporting both f32 and f16 (via staged f16→f32 preparation).
Changes:
- Added a new `distance::kernels` module with an unsafe `Kernel<A>` abstraction and a shared `tiled_reduce` 5-level cache tiling loop.
- Implemented `f32` (AVX2+FMA + scalar/Neon delegation) and `f16` (SIMD/scalar conversion + delegation to f32 micro-kernels) kernel families with correctness tests vs the fallback.
- Extended matrix and block-transposed utilities (`MatRef::as_matrix_view`, `BlockTransposedRef::available_rows`) and renamed the previous "simple" implementation to `fallback`.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| diskann-quantization/src/multi_vector/matrix.rs | Adds MatRef<Standard<T>>::as_slice() and as_matrix_view() plus a roundtrip test. |
| diskann-quantization/src/multi_vector/distance/mod.rs | Wires in fallback and new kernels module; updates re-exports and docs. |
| diskann-quantization/src/multi_vector/distance/fallback.rs | Renames Simple→Fallback kernel and disambiguates a test conversion. |
| diskann-quantization/src/multi_vector/block_transposed.rs | Adds available_rows() and validates it in tests. |
| diskann-quantization/src/multi_vector/distance/kernels/mod.rs | Introduces kernel framework module and cache-budget helpers. |
| diskann-quantization/src/multi_vector/distance/kernels/tiled_reduce.rs | Adds generic tiling loop + reduction helper trait + planner tests. |
| diskann-quantization/src/multi_vector/distance/kernels/f32/mod.rs | Adds f32 kernel entrypoint and MaxSim/Chamfer impls + tests. |
| diskann-quantization/src/multi_vector/distance/kernels/f32/v3.rs | Adds AVX2+FMA 16×4 microkernel and V4→V3 delegation. |
| diskann-quantization/src/multi_vector/distance/kernels/f32/scalar.rs | Adds scalar 8×2 microkernel and Neon→Scalar delegation. |
| diskann-quantization/src/multi_vector/distance/kernels/f16/mod.rs | Adds f16 kernel entrypoint and MaxSim/Chamfer impls + tests. |
| diskann-quantization/src/multi_vector/distance/kernels/f16/v3.rs | Adds SIMD f16→f32 prepare hooks and delegates to f32 V3 microkernel. |
| diskann-quantization/src/multi_vector/distance/kernels/f16/scalar.rs | Adds scalar f16→f32 prepare hooks and delegates to scalar f32 microkernel. |
arrayka
left a comment
There was a problem hiding this comment.
The PR description suggests that the existing simple kernel is slower than the proposed solution. Could you please support this claim with easy‑to‑reproduce benchmark results?
That would allow others to replicate the numbers to verify the performance claims.
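Absent the repository's own benchmark harness, a std-only timing sketch shows the shape such a reproduction could take. Sizes and data are arbitrary, and `maxsim_naive` is an illustrative flat nested loop, not the PR's fallback kernel:

```rust
use std::time::Instant;

// Naive flat MaxSim over a (qn x k) query and (dn x k) doc matrix:
// for each query row, take the max inner product over doc rows, then sum.
fn maxsim_naive(q: &[f32], d: &[f32], qn: usize, dn: usize, k: usize) -> f32 {
    (0..qn)
        .map(|i| {
            (0..dn)
                .map(|j| (0..k).map(|p| q[i * k + p] * d[j * k + p]).sum::<f32>())
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .sum()
}

fn main() {
    let (qn, dn, k) = (32, 128, 128);
    let q: Vec<f32> = (0..qn * k).map(|i| (i % 7) as f32).collect();
    let d: Vec<f32> = (0..dn * k).map(|i| (i % 5) as f32).collect();
    let t = Instant::now();
    let score = maxsim_naive(&q, &d, qn, dn, k);
    let dt = t.elapsed();
    assert!(score.is_finite());
    println!("score = {score}, {dt:?} for {} multiply-adds", qn * dn * k);
}
```

Running the same driver against both the fallback and the tiled kernel at a few realistic sizes (and printing elements per second) would make the speedup claim easy to check.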
…ides the GROUP const generic and arch token behind vtable
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 17 out of 17 changed files in this pull request and generated 3 comments.
hildebrandmw
left a comment
There was a problem hiding this comment.
Thanks Suryansh - this is getting close. The big thing I noticed is that the data preparation step is happening in the wrong place. It should be done at the tile level, not the panel level, to maximize the reuse of the preparation step.
```rust
a: *const Self::APrepared,
b: *const Self::BPrepared,
k: usize,
r: *mut f32,
```
There was a problem hiding this comment.
With this type of accumulator, are we over-fitting for Chamfer/MaxSim? What if we wanted to implement arg max for brute-force search? Also, in the case of u8/i8, we wouldn't want f32 as the result. We'd want u32/i32, right?
There was a problem hiding this comment.
As discussed offline, we’ll address this in a follow-up. I’ll spend some time thinking it through, outline an approach, and add the proposed direction to the PR description/comments.
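One possible shape for that follow-up, purely as a sketch (trait and type names here are hypothetical, not the PR's API): let the reduction name its own element and accumulator types, so arg-max and u32/i32 accumulation fall out of the same machinery as the f32 max-reduce.

```rust
// Hypothetical sketch: decouple the accumulator from f32 by letting the
// reduction declare its own element/accumulator types.
trait Reduction {
    type Elem;
    type Acc;
    fn identity() -> Self::Acc;
    fn step(acc: Self::Acc, value: Self::Elem, index: usize) -> Self::Acc;
}

// Chamfer/MaxSim-style max over f32 scores.
struct MaxVal;
impl Reduction for MaxVal {
    type Elem = f32;
    type Acc = f32;
    fn identity() -> f32 { f32::NEG_INFINITY }
    fn step(acc: f32, v: f32, _i: usize) -> f32 { acc.max(v) }
}

// Brute-force-search-style arg-max over integer scores (as an i8/u8
// kernel accumulating into i32 would want).
struct ArgMaxI32;
impl Reduction for ArgMaxI32 {
    type Elem = i32;
    type Acc = (i32, usize);
    fn identity() -> (i32, usize) { (i32::MIN, usize::MAX) }
    fn step(acc: (i32, usize), v: i32, i: usize) -> (i32, usize) {
        if v > acc.0 { (v, i) } else { acc }
    }
}

fn reduce<R: Reduction>(xs: &[R::Elem]) -> R::Acc
where
    R::Elem: Copy,
{
    xs.iter()
        .copied()
        .enumerate()
        .fold(R::identity(), |acc, (i, v)| R::step(acc, v, i))
}

fn main() {
    assert_eq!(reduce::<MaxVal>(&[1.0, 3.0, 2.0]), 3.0);
    assert_eq!(reduce::<ArgMaxI32>(&[5, 9, 7]), (9, 1));
    println!("ok");
}
```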
hildebrandmw
left a comment
There was a problem hiding this comment.
Thanks Suryansh - this is getting there. In addition to the inline nits, I have some larger overall concerns:
Documentation: While documentation is good, verbose documentation that is redundant (e.g. repeats what the type signature already says), or that can be retrieved from rustdoc, the function name, or a quick glance at the code, is more harmful than helpful. Documentation of intent, invariants, surprises, etc. is great! But please look through the comments/docs and prune out the ones that are low signal.
Testing: Testing a lot of edge cases for matrix sizes is great. However, the test cases exercised here do not actually hit the main body of the loop nest — they all fall into the peeled section. This means that, as-is, one of the main loops in this PR is not being tested in any way. Please fix this. Unfortunately, matrices of a realistic dimension relative to the L1 and L2 cache sizes are needed to exercise these paths, which is slow and means Miri tests will be especially lethargic. The best way I see to remedy that is to make the cache sizes configurable or overridable (which we will need in any case). But this really needs coverage.
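One way to make the main loop testable without huge matrices is to thread the cache budgets through as a value. A toy sketch — `CacheBudget` and `tile_rows` are hypothetical names, with defaults mirroring the ~625 KB / ~36 KB figures in the PR description — where a test passes tiny budgets to force multiple tiles:

```rust
// Hypothetical sketch: a tile planner parameterized on cache budgets
// instead of hard-coded L1/L2 sizes, so tests can shrink them.
#[derive(Clone, Copy)]
struct CacheBudget {
    l2_bytes: usize,
    l1_bytes: usize,
}

impl Default for CacheBudget {
    // production-ish defaults, mirroring the PR's ~625 KB / ~36 KB figures
    fn default() -> Self {
        CacheBudget { l2_bytes: 625 * 1024, l1_bytes: 36 * 1024 }
    }
}

// Rows of `k` f32 values fitting in a budget, floored to a panel multiple.
fn tile_rows(budget_bytes: usize, k: usize, panel: usize) -> usize {
    let rows = budget_bytes / (k * 4);
    (rows / panel).max(1) * panel
}

fn main() {
    let k = 128;
    // Default budget: a modest test matrix fits in one tile, so only the
    // peeled path runs.
    assert!(tile_rows(CacheBudget::default().l2_bytes, k, 16) >= 64);
    // Shrunken budget: the tile is small, so realistic-but-tiny test
    // matrices now span multiple tiles and hit the main loop body.
    let tiny = CacheBudget { l2_bytes: 8 * 1024, l1_bytes: 1024 };
    assert_eq!(tile_rows(tiny.l2_bytes, k, 16), 16);
    let _ = tiny.l1_bytes; // the L1 budget would bound the inner tile the same way
    println!("ok");
}
```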
```rust
///
/// The blanket identity impl covers every layout converting to itself
/// with `Buffer = ()` and zero cost. Explicit impls handle f16→f32 via
/// [`SliceCast`].
```
There was a problem hiding this comment.
The terminology around row * k and being contiguous makes me worry a little bit about future scenarios for strided access. What are the plans for systematically and safely updating the code when that lands?
There was a problem hiding this comment.
Yes, let’s address it in a follow-up once the design for strided access is clearer. I’m still working out the right level at which to handle K-splitting and perform the conversion.
arkrishn94
left a comment
There was a problem hiding this comment.
Thanks Suryansh, I think this is almost ready to merge! Left some small comments, but beyond those, the things I wanted to highlight:
- Testing. From what I can tell, the `tiled_reduce` loop is not tested across architectures. If that's the case, we should definitely add that.
- On specializing the output type for the scratch to support other reductions like argmax and inputs like `i8`/`u8` — am I misunderstanding, or should this be easy by augmenting `Kernel` with an associated return type and wiring that through?
- This is probably because I don't understand all the details very well, but I still don't see how supporting metadata for quantized vectors will be wired through, in terms of where it will live and how it'll be accessed in the main micro-kernel for post-op processing.
```rust
///
/// * `a` must point to `A_PANEL * k` contiguous `APrepared` values.
/// * `b` must point to `B_PANEL * k` contiguous `BPrepared` values.
/// * `r` must point to at least `A_PANEL` writable `f32` values.
```
There was a problem hiding this comment.
And be valid for the lifetime of this function execution only?
There was a problem hiding this comment.
Lifetime-of-call validity is the implicit raw-pointer convention (stdlib's pointer APIs don't spell it out either) - the kernel doesn't store any of the pointers across the call. Left the contracts in their concise form to keep them scannable; happy to add it explicitly if you'd rather have the trait be paranoid-self-contained.
```rust
/// # Safety
///
/// * `src` must point to `rows * k` valid elements in `Self`'s layout.
/// * `buf` must come from [`new_buffer`](Self::new_buffer) with
```
There was a problem hiding this comment.
I'm guessing buf has to be created with the same k as used in convert?
There was a problem hiding this comment.
Correct - added to the safety contract: buf must come from new_buffer with the same k (and max_tile_rows >= rows). The f16→f32 impls allocate max_tile_rows * k and convert writes rows * k via &mut buf[..count], so a smaller k would short-write or panic on the slice bound. The blanket identity impl ignores both, so this is purely a contract for non-identity converters.
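The contract can be seen in a toy model (helper names are illustrative, and `u16` stands in for the f16 bit pattern): a buffer built with the same `k` accepts any `rows <= max_tile_rows`, while one built with a smaller `k` trips the slice bound instead of silently short-writing.

```rust
// Toy model of the contract: a converter buffer sized max_tile_rows * k,
// with convert writing rows * k values through &mut buf[..rows * k].
fn new_buffer(max_tile_rows: usize, k: usize) -> Vec<f32> {
    vec![0.0; max_tile_rows * k]
}

fn convert(src: &[u16], rows: usize, k: usize, buf: &mut [f32]) {
    let count = rows * k;
    // &mut buf[..count] panics if buf came from a smaller k
    let dst = &mut buf[..count];
    for (d, s) in dst.iter_mut().zip(&src[..count]) {
        *d = *s as f32; // stand-in for the real f16 -> f32 cast
    }
}

fn main() {
    let (max_tile_rows, k) = (8, 4);
    let src: Vec<u16> = (0..max_tile_rows * k).map(|i| i as u16).collect();

    let mut buf = new_buffer(max_tile_rows, k);
    convert(&src, 3, k, &mut buf); // rows <= max_tile_rows: fine
    assert_eq!(buf[11], 11.0);

    // Buffer created with the wrong (smaller) k: the slice bound catches it.
    let mut small = new_buffer(max_tile_rows, 2);
    let r = std::panic::catch_unwind(move || convert(&src, max_tile_rows, k, &mut small));
    assert!(r.is_err());
    println!("ok");
}
```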
```rust
//! | [`BlockTransposedRef`] | Immutable view of a block-transposed matrix |
//! | [`BlockTransposedMut`] | Mutable view of a block-transposed matrix   |
//! | [`QueryMatRef`]        | Query wrapper for asymmetric distances      |
//! | [`QueryComputer`]      | Architecture-dispatched SIMD query computer |
```
There was a problem hiding this comment.
nit: Can we separate the matrix types from the computer type in the documentation? It might be worth adding separate documentation for it here, since it's a core type in the new distance computation path.
There was a problem hiding this comment.
Good call - I'd lean toward keeping the table flat. It's meant as a fast inventory of what multi_vector re-exports, and MaxSim/Chamfer are equally first-class on the new distance path; pulling QueryComputer out into its own section while leaving them in the table would be inconsistent. The detailed docs already live on the type itself in query_computer/mod.rs (dispatch model, build cost, usage). Happy to expand that type-level doc if you feel anything's missing — just don't think the module-level overview is the right place for it. WDYT?
What
SIMD-accelerated MaxSim / Chamfer distance for f32 and f16 multi-vector queries, using a block-transposed layout with L2/L1 cache-aware tiling. Introduces `QueryComputer<T>` — a runtime-dispatched type that hides GROUP and architecture behind a vtable.
The focus is proving the `Kernel`/`tiled_reduce`/`ConvertTo` abstraction is solid and type-agnostic: f32 and f16 share the same tiling loop and micro-kernel body.
Why
The fallback kernel iterates query×doc in a flat nested loop, causing repeated cache evictions. Block-transposing the query and tiling both sides to fit in L2/L1 keeps hot data resident and feeds the FMA pipeline efficiently. f16 comes for free: `ConvertTo` converts f16→f32 once per tile, then the f32 micro-kernel runs unchanged.
Changed Files
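For reference, the semantics being accelerated can be stated in a few lines of naive Rust. This is illustrative, not the PR's code; the PR drives both the Chamfer and MaxSim entry points through the same max-reduce machinery:

```rust
// For one query row, the maximum inner product against any doc row.
fn max_ip(q_row: &[f32], docs: &[Vec<f32>]) -> f32 {
    docs.iter()
        .map(|d| q_row.iter().zip(d).map(|(a, b)| a * b).sum::<f32>())
        .fold(f32::NEG_INFINITY, f32::max)
}

// MaxSim: sum those per-row maxima over all query rows.
fn max_sim(queries: &[Vec<f32>], docs: &[Vec<f32>]) -> f32 {
    queries.iter().map(|q| max_ip(q, docs)).sum()
}

fn main() {
    let queries = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let docs = vec![vec![2.0, 1.0], vec![0.5, 3.0]];
    // row 0: max(2.0, 0.5) = 2.0; row 1: max(1.0, 3.0) = 3.0
    assert_eq!(max_sim(&queries, &docs), 5.0);
    println!("{}", max_sim(&queries, &docs));
}
```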
All paths relative to `diskann-quantization/src/multi_vector/`.

New — `distance/kernels/`:
- `mod.rs` — `Kernel<A>` unsafe trait (`Left`/`Right` layouts, `full_panel`/`partial_panel`), cache budget helpers.
- `layouts.rs` — `Layout` marker trait; `BlockTransposed`/`RowMajor` ZST markers; `DescribeLayout` bridge; `ConvertTo<A, To>` with blanket identity and f16→f32 specializations.
- `tiled_reduce.rs` — generic tiling loop (`K: Kernel`, `LA`/`LB: ConvertTo`); `FullReduce` tile planner; `Reduce` unroll trait.
- `f32/mod.rs` — `F32Kernel<GROUP>`, `max_ip_kernel`, `Target3` dispatch, tests.
- `f32/v3.rs` — AVX2+FMA 16×4 micro-kernel; V4→V3 delegation via `retarget()`.
- `f32/scalar.rs` — scalar 8×2 micro-kernel (`fma()`); Neon→Scalar delegation via `retarget()`.
- `f16.rs` — `F16Entry<GROUP>`, `Target3` dispatch, tests — drives `tiled_reduce` with `F32Kernel` + f16→f32 `ConvertTo`. No `Kernel` impl.

New — `distance/query_computer/`:
- `mod.rs` — `QueryComputer<T>` (`Box<dyn DynQueryComputer<T>>`); `chamfer`/`max_sim` methods; `Chamfer`/`MaxSim` trait impls; tests.
- `f32.rs` — `QueryComputer<f32>::new` via `dispatch1_no_features`; `BuildComputer` `Target1` impls.
- `f16.rs` — `QueryComputer<half::f16>::new` — same pattern, delegates through `F16Entry`.

Modified:
- `block_transposed.rs` — `padded_nrows()`.
- `matrix.rs` — `as_matrix_view()`.
- `multi_vector/mod.rs` — re-exports `QueryComputer`, `Chamfer`, `MaxSim`, `MaxSimError`, `QueryMatRef`.
- `distance/mod.rs` — `QueryComputer` re-export, doc example.
- `distance/max_sim.rs` — `MaxSim`/`Chamfer` types, `MaxSimError` enum.
- `distance/fallback.rs` — `FallbackKernel` (was `SimpleKernel`), `QueryMatRef`, fallback trait impls.

Renamed: `simple.rs` → `fallback.rs`

Design Decisions
Kernel trait
`Kernel<A>` declares `Left`/`Right` layout types and `full_panel`/`partial_panel`. The kernel receives already-converted pointers — it knows nothing about storage formats. V4→V3 and Neon→Scalar delegate via `retarget()`. The GROUP const generic (16 for V3/V4, 8 for Scalar/Neon) acts as a closed-world filter.

Layout markers and ConvertTo
`BlockTransposed<T, GROUP, PACK>` and `RowMajor<T>` are ZST markers. The `Layout` impl requires `T: Copy` (micro-kernels load via raw pointers); `Copy`/`Clone` on the markers themselves are unconditional (`PhantomData` wrappers). `DescribeLayout` bridges matrix types to markers for type inference. `ConvertTo<A, To>`: blanket identity (`Buffer = ()`, zero cost) + f16→f32 specializations (`Vec<f32>` buffer, SIMD-accelerated `SliceCast`). Conversion happens once per tile, not per panel. `SliceCast` dispatches through the runtime architecture token via `arch.run2()` — the same SIMD level used by the micro-kernel.

Tiling loop (reducing-GEMM)
5-level loop: L2 A-tiles → L1 B-tiles → A-panels → B-panels → micro-kernel. `ConvertTo::convert` runs at tile boundaries. Cache budgets derive from kernel layout element sizes (~625 KB L2, ~36 KB L1). Geometry: 16×4 (V3/V4) or 8×2 (Scalar/Neon). Source strides and kernel strides differ when conversion is active (f16→f32).

f16 path
`F16Entry<GROUP>` is a dispatch adapter, not a `Kernel` impl. It calls `tiled_reduce` with `F32Kernel` and f16→f32 `ConvertTo` impls. Zero SIMD code. Extends naturally to PQ/SQ/MinMax via new `ConvertTo` impls.

QueryComputer
`QueryComputer<T>` wraps `Box<dyn DynQueryComputer<T>>`. CPU detection happens once at construction via `dispatch1_no_features`; the hot path uses `Architecture::run3` with `#[target_feature]` — no re-dispatch. Turbofish: `QueryComputer::<f32>::new(q)`. Per-type `BuildComputer` dispatch lives in `f32.rs`/`f16.rs`; `mod.rs` is generic.

Suggested Review Order
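The 5-level loop nest described above can be sketched end-to-end in scalar code and checked against a naive reduction. Tile and panel sizes here are illustrative constants, not the PR's 16×4/8×2 geometry, and the odd matrix sizes exercise the partial-tile paths:

```rust
// Naive reference: flat MaxSim-style max-inner-product reduction.
fn naive(q: &[f32], d: &[f32], qn: usize, dn: usize, k: usize) -> f32 {
    (0..qn)
        .map(|i| {
            (0..dn)
                .map(|j| (0..k).map(|p| q[i * k + p] * d[j * k + p]).sum::<f32>())
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .sum()
}

// Same reduction through a 5-level tiled loop nest.
fn tiled(q: &[f32], d: &[f32], qn: usize, dn: usize, k: usize) -> f32 {
    const A_TILE: usize = 4;  // "L2-resident" A rows (illustrative)
    const B_TILE: usize = 4;  // "L1-resident" B rows
    const A_PANEL: usize = 2; // micro-kernel geometry
    const B_PANEL: usize = 2;
    let mut best = vec![f32::NEG_INFINITY; qn];
    for a0 in (0..qn).step_by(A_TILE) {                 // 1: L2 A-tiles
        for b0 in (0..dn).step_by(B_TILE) {             // 2: L1 B-tiles
            for ap in (a0..(a0 + A_TILE).min(qn)).step_by(A_PANEL) {       // 3: A-panels
                for bp in (b0..(b0 + B_TILE).min(dn)).step_by(B_PANEL) {   // 4: B-panels
                    for i in ap..(ap + A_PANEL).min(qn) {                  // 5: micro-kernel
                        for j in bp..(bp + B_PANEL).min(dn) {
                            let ip: f32 = (0..k).map(|p| q[i * k + p] * d[j * k + p]).sum();
                            if ip > best[i] { best[i] = ip; }
                        }
                    }
                }
            }
        }
    }
    best.iter().sum()
}

fn main() {
    let (qn, dn, k) = (7, 9, 5); // odd sizes hit the partial paths
    let q: Vec<f32> = (0..qn * k).map(|i| ((i * 13) % 7) as f32 - 3.0).collect();
    let d: Vec<f32> = (0..dn * k).map(|i| ((i * 11) % 5) as f32 - 2.0).collect();
    assert_eq!(naive(&q, &d, qn, dn, k), tiled(&q, &d, qn, dn, k));
    println!("match");
}
```

Because every inner product is summed in the same order in both versions, the comparison is exact, which is the property the PR's correctness tests against the fallback rely on as well.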
- `distance/kernels/mod.rs` — `Kernel<A>` trait.
- `distance/kernels/layouts.rs` — markers, `ConvertTo`, blanket identity, f16→f32.
- `distance/kernels/tiled_reduce.rs` — tiling loop, source vs kernel strides.
- `distance/kernels/f32/v3.rs` → `f32/scalar.rs` — micro-kernels.
- `distance/kernels/f32/mod.rs` — `F32Kernel`, `max_ip_kernel`, dispatch.
- `distance/kernels/f16.rs` — `F16Entry`, no `Kernel` impl, no submodules.
- `distance/query_computer/mod.rs` — `QueryComputer<T>`, tests.
- `distance/query_computer/f32.rs` → `f16.rs` — per-type dispatch.
- `distance/max_sim.rs` → `distance/fallback.rs` — types, fallback kernel.
- `block_transposed.rs`, `matrix.rs`, `distance/mod.rs`, `multi_vector/mod.rs` — supporting.

Future Work
- Output generality: the `f32` scratch and max-reduce fit Chamfer/MaxSim but may over-fit. Brute-force search would need arg-max (tracking indices, not just values), and u8/i8 kernels would naturally accumulate into u32/i32 rather than f32.
- `Kernel` + `ConvertTo` for PQ, SQ, MinMax.
- `ConvertTo` transpose.