From 70cfe6216c393343a7cee3f92453aa4c2052a7f4 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 15:07:15 -0700 Subject: [PATCH 1/9] Overhaul PagedSearch. --- diskann-providers/src/index/diskann_async.rs | 19 +- diskann-providers/src/index/wrapped_async.rs | 92 +++--- diskann/src/graph/index.rs | 312 +++++++++---------- diskann/src/graph/search/scratch.rs | 1 + diskann/src/graph/test/cases/paged_search.rs | 82 ++--- 5 files changed, 233 insertions(+), 273 deletions(-) diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 69c254c86..ddb642295 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -436,25 +436,16 @@ pub(crate) mod tests { Q: Copy + std::fmt::Debug + Send + Sync, { assert!(max_candidates <= groundtruth.len()); - let mut state = index - .start_paged_search(strategy, &parameters.context, query, parameters.search_l) + let mut search = index + .paged_search(strategy, &parameters.context, query, parameters.search_l) .await .unwrap(); - let mut buffer = vec![Neighbor::::default(); parameters.search_k]; let mut iter = 0; let mut seen = 0; while !groundtruth.is_empty() { - let count = index - .next_search_results::( - &parameters.context, - &mut state, - parameters.search_k, - &mut buffer, - ) - .await - .unwrap(); - for (i, b) in buffer.iter().enumerate().take(count) { + let page = search.next_page(parameters.search_k).await.unwrap(); + for (i, b) in page.iter().enumerate() { let m = is_match(groundtruth, *b, 0.01); match m { None => { @@ -469,7 +460,7 @@ pub(crate) mod tests { b, iter, i, - &buffer[i..], + &page[i..], ); } Some(j) => groundtruth.remove(j), diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index e8d26a609..73e5308cf 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -13,7 +13,7 @@ use diskann::{ Batch,
DefaultSearchStrategy, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, PruneStrategy, SearchStrategy, }, - index::{DegreeStats, PagedSearchState, PartitionedNeighbors, SearchState}, + index::{DegreeStats, PartitionedNeighbors}, search_output_buffer, }, neighbor::Neighbor, @@ -339,59 +339,51 @@ where ) } - #[allow(clippy::type_complexity)] - pub fn start_paged_search( - &self, + /// Begin a paged search over the index (synchronous wrapper). + /// + /// Returns a [`PagedSearch`] handle. See + /// [`graph::index::PagedSearch::next_page`] for retrieving results. + pub fn paged_search<'a, S, T>( + &'a self, strategy: S, - context: &DP::Context, + context: &'a DP::Context, query: T, l_value: usize, - ) -> ANNResult> + ) -> ANNResult> where S: SearchStrategy + 'static, - T: Copy + Send, + T: Copy + Send + 'a, { - self.handle.block_on( - self.inner - .start_paged_search(strategy, context, query, l_value), - ) + let inner = self + .handle + .block_on(self.inner.paged_search(strategy, context, query, l_value))?; + Ok(PagedSearch { + handle: self.handle.clone(), + inner, + }) } - #[allow(clippy::type_complexity)] - pub fn start_paged_search_with_init_ids( - &self, + /// Begin a paged search with explicit initial seed IDs (synchronous wrapper). 
+ pub fn paged_search_with_init_ids<'a, S, T>( + &'a self, strategy: S, - context: &DP::Context, + context: &'a DP::Context, query: T, l_value: usize, - init_ids: Option<&[DP::InternalId]>, - ) -> ANNResult> + init_ids: Option<&'a [DP::InternalId]>, + ) -> ANNResult> where S: SearchStrategy + 'static, - T: Copy + Send, + T: Copy + Send + 'a, { - self.handle.block_on( + let inner = self.handle.block_on( self.inner - .start_paged_search_with_init_ids(strategy, context, query, l_value, init_ids), - ) - } - - pub fn next_search_results( - &self, - context: &DP::Context, - search_state: &mut SearchState, - k: usize, - result_output: &mut [Neighbor], - ) -> ANNResult - where - S: SearchStrategy, - { - self.handle.block_on(self.inner.next_search_results( - context, - search_state, - k, - result_output, - )) + .paged_search_with_init_ids(strategy, context, query, l_value, init_ids), + )?; + Ok(PagedSearch { + handle: self.handle.clone(), + inner, + }) } pub fn count_reachable_nodes( @@ -416,6 +408,28 @@ where } } +/// Synchronous wrapper around [`graph::index::PagedSearch`] that owns a tokio runtime handle. +/// +/// Created by [`DiskANNIndex::paged_search`]. Each call to [`next_page`](Self::next_page) +/// blocks the current thread to drive the underlying async search forward. +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + handle: tokio::runtime::Handle, + inner: graph::index::PagedSearch<'a, DP, S, T>, +} + +impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +where + DP: DataProvider, + S: SearchStrategy, +{ + /// Returns the next page of at most `k` nearest-neighbor results. + /// + /// Blocks the current thread. Returns an empty `Vec` when the search is exhausted. 
+ pub fn next_page(&mut self, k: usize) -> ANNResult>> { + self.handle.block_on(self.inner.next_page(k)) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index d696135a0..0f0637975 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -137,34 +137,6 @@ pub struct PartitionedNeighbors { pub deleted: Vec, } -/// Placeholder for extra state. -/// -/// The contents of the search state are designed for the synchronous index. -/// However, use cases in the asynchronous index require some extra state. -/// -/// This placeholder is used in the synchronous code-paths. -pub struct NoExtraState; - -/// Represents the state of the pagged search. -/// It can be used to do paged search by doing multiple `nextSearchResults()` queries. -/// -/// Generic extra state can be included to facilitate extra use-cases. -/// However, this extra state **must** be '`static' as we do not know how long the search -/// state will live for. -#[derive(Debug)] -pub struct SearchState { - /// Scratch space for query processing. - pub scratch: SearchScratch, - /// The computed search results ready to be returned in `nextSearchResults()` query - pub computed_result: Vec>, - /// The index of the next result to be returned. - pub next_result_index: usize, - /// The search computes results in the multiple of `search_param_l`. - pub search_param_l: usize, - /// Any extra data needed by down-stream implementations. - pub extra: ExtraState, -} - /// Edge pending submission for multi-insert. #[derive(Debug)] struct PendingEdge { @@ -204,16 +176,6 @@ where /// and `Err` paths. type BatchResult = Result; -/// State used during by paged search to perform multiple, consecutive searches over the index. -/// -/// Type parameters: -/// -/// * `DP`: The type of the [`DataProvider`]. -/// * `S`: The type of the [`SearchStrategy`]. -/// * `C`: The type of `S`'s [`BuildQueryComputer`] computer. 
This exists as a separate -/// type parameter because the type of the query computer depends on the type of the query. -pub type PagedSearchState = SearchState<::InternalId, (S, C)>; - impl DiskANNIndex where DP: DataProvider, @@ -2280,34 +2242,42 @@ where // Paged Search // ////////////////// - pub fn start_paged_search( - &self, + /// Begin a paged search over the index. + /// + /// Returns a [`PagedSearch`] handle whose [`next_page`](PagedSearch::next_page) method + /// yields successive pages of nearest-neighbor results. + pub fn paged_search<'a, S, T>( + &'a self, strategy: S, - context: &DP::Context, + context: &'a DP::Context, query: T, l_value: usize, - ) -> impl SendFuture>> + ) -> impl SendFuture>> where S: SearchStrategy + 'static, - T: Copy + Send, + T: Copy + Send + 'a, { async move { - self.start_paged_search_with_init_ids(strategy, context, query, l_value, None) + self.paged_search_with_init_ids(strategy, context, query, l_value, None) .await } } - pub fn start_paged_search_with_init_ids( - &self, + /// Begin a paged search with explicit initial seed IDs. + /// + /// This is the same as [`paged_search`](Self::paged_search) but allows the caller to + /// provide custom starting points for the graph traversal. 
+ pub fn paged_search_with_init_ids<'a, S, T>( + &'a self, strategy: S, - context: &DP::Context, + context: &'a DP::Context, query: T, l_value: usize, - init_ids: Option<&[DP::InternalId]>, - ) -> impl SendFuture>> + init_ids: Option<&'a [DP::InternalId]>, + ) -> impl SendFuture>> where S: SearchStrategy + 'static, - T: Copy + Send, + T: Copy + Send + 'a, { async move { let (computer, scratch) = { @@ -2350,121 +2320,20 @@ where (computer, scratch) }; - ANNResult::Ok(SearchState { + ANNResult::Ok(PagedSearch { + index: self, + context, scratch, computed_result: vec![Neighbor::default(); l_value], next_result_index: l_value, search_param_l: l_value, - extra: (strategy, computer), + strategy, + computer, + _query: std::marker::PhantomData, }) } } - pub fn next_search_results( - &self, - context: &DP::Context, - search_state: &mut SearchState, - k: usize, - result_output: &mut [Neighbor], - ) -> impl SendFuture> - where - S: SearchStrategy, - { - async move { - if k > search_state.search_param_l { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be less than or equal to search_param_l".to_string(), - )); - } - if k == 0 { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be greater than 0".to_string(), - )); - } - if result_output.len() < k { - return ANNResult::Err(ANNError::log_paged_search_error( - "The size of result_output should be greater than or equal to k".to_string(), - )); - } - - let copy_to_output = - |search_state: &mut SearchState, - count: usize, - result_output: &mut [Neighbor], - result_output_offset: usize| { - result_output[result_output_offset..result_output_offset + count] - .copy_from_slice( - &search_state.computed_result[search_state.next_result_index - ..search_state.next_result_index + count], - ); - search_state.next_result_index += count; - }; - - let used_computed_result_count: usize = cmp::min( - k, - search_state.computed_result.len() - search_state.next_result_index, - ); - if 
used_computed_result_count > 0 { - copy_to_output( - search_state, - used_computed_result_count, - result_output, - 0, // result_output_offset - ); - - if used_computed_result_count == k { - return ANNResult::Ok(k); - } - } - - let start_points = { - let mut accessor = search_state - .extra - .0 - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - - let start_ids = accessor.starting_points().await?; - self.search_internal( - None, // beam_width - &start_ids, - &mut accessor, - &search_state.extra.1, - &mut search_state.scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - start_ids - }; - - let (mut candidates, total_considered) = self - .filter_search_candidates(&start_points, k, &mut search_state.scratch.best) - .await?; - search_state.scratch.best.drain_best(total_considered); - - let computed_result_count = candidates.len(); - search_state.computed_result.clear(); - search_state.computed_result.append(&mut candidates); - - search_state.next_result_index = 0; - if computed_result_count != search_state.search_param_l { - search_state.computed_result.truncate(computed_result_count); - } - - let leftover_results = cmp::min(k - used_computed_result_count, computed_result_count); - - copy_to_output( - search_state, - leftover_results, // count of results to copy - result_output, - used_computed_result_count, // result_output_offset - ); - - ANNResult::Ok(used_computed_result_count + leftover_results) - } - } - /// Count the number of nodes in the graph reachable from the given `start_points`. /// /// This function has a large memory footprint for large graphs and should not be called @@ -3188,6 +3057,135 @@ struct BatchIdMismatch { ids_len: usize, } +////////////////// +// Paged Search // +////////////////// + +/// A paged search handle that owns all search state internally. +/// +/// Created by [`DiskANNIndex::paged_search`] or +/// [`DiskANNIndex::paged_search_with_init_ids`]. 
Each call to +/// [`next_page`](Self::next_page) resumes the graph search and returns the next page of +/// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. +/// +/// # Type Parameters +/// +/// * `'idx` — lifetime of the borrowed [`DiskANNIndex`]. +/// * `'ctx` — lifetime of the borrowed [`DataProvider::Context`]. +/// * `DP` — the [`DataProvider`] type. +/// * `S` — the [`SearchStrategy`] type. +/// * `T` — the original query type (carried only for trait-bound resolution). +#[derive(Debug)] +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + index: &'a DiskANNIndex, + context: &'a DP::Context, + scratch: SearchScratch, + computed_result: Vec>, + next_result_index: usize, + search_param_l: usize, + strategy: S, + computer: S::QueryComputer, + // Note: + _query: std::marker::PhantomData T>, +} + +impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +where + DP: DataProvider, + S: SearchStrategy, +{ + /// Returns the next page of at most `k` nearest-neighbor results. + /// + /// Results across pages are non-overlapping and ordered by non-decreasing distance. + /// When the search is exhausted, returns an empty `Vec`. + pub fn next_page( + &mut self, + k: usize, + ) -> impl SendFuture>>> { + async move { + if k > self.search_param_l { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be less than or equal to search_param_l".to_string(), + )); + } + if k == 0 { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be greater than 0".to_string(), + )); + } + + let mut result = Vec::with_capacity(k); + + // Drain any already-computed results first. 
+ let available = self + .computed_result + .len() + .saturating_sub(self.next_result_index); + let from_cache = cmp::min(k, available); + if from_cache > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + from_cache], + ); + self.next_result_index += from_cache; + + if result.len() == k { + return ANNResult::Ok(result); + } + } + + // Resume graph search to fill the next batch. + let start_points = { + let mut accessor = self + .strategy + .search_accessor(&self.index.data_provider, self.context) + .into_ann_result()?; + + let start_ids = accessor.starting_points().await?; + self.index + .search_internal( + None, // beam_width + &start_ids, + &mut accessor, + &self.computer, + &mut self.scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + start_ids + }; + + let (mut candidates, total_considered) = self + .index + .filter_search_candidates(&start_points, k, &mut self.scratch.best) + .await?; + self.scratch.best.drain_best(total_considered); + + let computed_result_count = candidates.len(); + self.computed_result.clear(); + self.computed_result.append(&mut candidates); + + self.next_result_index = 0; + if computed_result_count != self.search_param_l { + self.computed_result.truncate(computed_result_count); + } + + let remaining_need = k - result.len(); + let leftover = cmp::min(remaining_need, computed_result_count); + if leftover > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + leftover], + ); + self.next_result_index += leftover; + } + + ANNResult::Ok(result) + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/diskann/src/graph/search/scratch.rs b/diskann/src/graph/search/scratch.rs index 98a5b3127..2670f118c 100644 --- a/diskann/src/graph/search/scratch.rs +++ b/diskann/src/graph/search/scratch.rs @@ -158,6 +158,7 @@ where } /// Return the currently configured `search_l`: the number of best candidates to track. 
+ #[cfg(test)] pub fn search_l(&self) -> usize { self.best.search_l() } diff --git a/diskann/src/graph/test/cases/paged_search.rs b/diskann/src/graph/test/cases/paged_search.rs index 0b8fa4706..8c2f79ee1 100644 --- a/diskann/src/graph/test/cases/paged_search.rs +++ b/diskann/src/graph/test/cases/paged_search.rs @@ -6,8 +6,8 @@ //! Tests for paged (iterative) search. //! //! Paged search returns results in pages of k neighbors via a stateful -//! `SearchState`. Tests cover basic pagination, single-page retrieval, -//! and small page sizes that stress the iteration machinery. +//! [`PagedSearch`](crate::graph::index::PagedSearch) handle. Tests cover basic pagination, +//! single-page retrieval, and small page sizes that stress the iteration machinery. use std::sync::Arc; @@ -155,8 +155,8 @@ fn basic_paged_search() { let page_size = 4; let ctx = test_provider::Context::new(); - let mut state = rt - .block_on(index.start_paged_search( + let mut search = rt + .block_on(index.paged_search( test_provider::Strategy::new(), &ctx, query.as_slice(), @@ -165,24 +165,13 @@ fn basic_paged_search() { .unwrap(); let mut pages: Vec>> = Vec::new(); - let mut buffer = vec![Neighbor::::default(); page_size]; loop { - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - if count == 0 { + let page = rt.block_on(search.next_page(page_size)).unwrap(); + if page.is_empty() { break; } - pages.push(buffer[..count].to_vec()); + pages.push(page); } let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); @@ -209,8 +198,8 @@ fn single_page() { let page_size = 200; // larger than total points (125) let ctx = test_provider::Context::new(); - let mut state = rt - .block_on(index.start_paged_search( + let mut search = rt + .block_on(index.paged_search( test_provider::Strategy::new(), &ctx, query.as_slice(), @@ -218,21 +207,8 @@ fn single_page() { )) .unwrap(); - let mut buffer = 
vec![Neighbor::::default(); page_size]; - - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - let results: Vec> = buffer[..count].to_vec(); - let pages = vec![results.clone()]; + let results = rt.block_on(search.next_page(page_size)).unwrap(); + let pages = vec![results]; let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); @@ -242,18 +218,9 @@ fn single_page() { assert_no_duplicates_across_pages(&pages); assert_non_decreasing_distances(&pages); - // Verify second call returns 0 (nothing left) - let count2 = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - assert_eq!(count2, 0, "second page should be empty"); + // Verify second call returns empty (nothing left) + let page2 = rt.block_on(search.next_page(page_size)).unwrap(); + assert!(page2.is_empty(), "second page should be empty"); } #[test] @@ -270,8 +237,8 @@ fn small_page_size() { let page_size = 1; // one result per page, maximum iterations let ctx = test_provider::Context::new(); - let mut state = rt - .block_on(index.start_paged_search( + let mut search = rt + .block_on(index.paged_search( test_provider::Strategy::new(), &ctx, query.as_slice(), @@ -280,24 +247,13 @@ fn small_page_size() { .unwrap(); let mut pages: Vec>> = Vec::new(); - let mut buffer = vec![Neighbor::::default(); page_size]; loop { - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - if count == 0 { + let page = rt.block_on(search.next_page(page_size)).unwrap(); + if page.is_empty() { break; } - pages.push(buffer[..count].to_vec()); + pages.push(page); } let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); From 219302537bfcdf3a83e183f300f850411fc0071d Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 15:08:23 -0700 
Subject: [PATCH 2/9] Query Computers no longer need to be `'static`. --- diskann/src/graph/index.rs | 2 +- diskann/src/provider.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 0f0637975..cb6e75907 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -3085,7 +3085,7 @@ pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { search_param_l: usize, strategy: S, computer: S::QueryComputer, - // Note: + // Note: The use of `fn` here is so _query: std::marker::PhantomData T>, } diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index f55af356c..262e9b15c 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -508,8 +508,7 @@ pub trait BuildQueryComputer: Accessor { /// elements yielded by the [`Accessor`]. type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> + Send - + Sync - + 'static; + + Sync; /// Build the query computer for this accessor. /// From 9c773eb2b779ed14f2c90e0fc0f57d8e80330ecc Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 15:10:52 -0700 Subject: [PATCH 3/9] Clarify use of `fn` pointer. --- diskann/src/graph/index.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index cb6e75907..27c6baa81 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -3067,14 +3067,6 @@ struct BatchIdMismatch { /// [`DiskANNIndex::paged_search_with_init_ids`]. Each call to /// [`next_page`](Self::next_page) resumes the graph search and returns the next page of /// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. -/// -/// # Type Parameters -/// -/// * `'idx` — lifetime of the borrowed [`DiskANNIndex`]. -/// * `'ctx` — lifetime of the borrowed [`DataProvider::Context`]. -/// * `DP` — the [`DataProvider`] type. -/// * `S` — the [`SearchStrategy`] type.
-/// * `T` — the original query type (carried only for trait-bound resolution). #[derive(Debug)] pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { index: &'a DiskANNIndex, @@ -3085,8 +3077,8 @@ pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { search_param_l: usize, strategy: S, computer: S::QueryComputer, - // Note: The use of `fn` here is so - _query: std::marker::PhantomData T>, + // Note: The `fn` is so we derive `Send` and `Sync` more easily: `fn` is always Send/Sync. + _query: std::marker::PhantomData, } impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> From b63d009893f362e07ace518314f7c85733cba212 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 18 May 2026 12:23:09 -0700 Subject: [PATCH 4/9] Paged search 2.0 --- diskann-providers/src/index/wrapped_async.rs | 4 +- diskann/src/graph/index.rs | 158 +-------------- diskann/src/graph/search/mod.rs | 3 + diskann/src/graph/search/paged.rs | 163 ++++++++++++++++ diskann/src/graph/test/cases/paged_search.rs | 4 + rfcs/01078-paged-search.md | 193 +++++++++++++++++++ 6 files changed, 369 insertions(+), 156 deletions(-) create mode 100644 diskann/src/graph/search/paged.rs create mode 100644 rfcs/01078-paged-search.md diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 73e5308cf..858218b55 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -408,13 +408,13 @@ where } } -/// Synchronous wrapper around [`graph::index::PagedSearch`] that owns a tokio runtime handle. +/// Synchronous wrapper around [`graph::search::PagedSearch`] that owns a tokio runtime handle. /// /// Created by [`DiskANNIndex::paged_search`]. Each call to [`next_page`](Self::next_page) /// blocks the current thread to drive the underlying async search forward. 
pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { handle: tokio::runtime::Handle, - inner: graph::index::PagedSearch<'a, DP, S, T>, + inner: graph::search::PagedSearch<'a, DP, S, T>, } impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 27c6baa81..54d40cc81 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -30,7 +30,7 @@ use super::{ }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ - Knn, + Knn, PagedSearch, record::{NoopSearchRecord, SearchRecord, VisitedSearchRecord}, scratch::{self, PriorityQueueConfiguration, SearchScratch, SearchScratchParams}, }, @@ -42,7 +42,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::{ErrorExt, IntoANNResult}, internal, - neighbor::{self, Neighbor, NeighborPriorityQueue, NeighborQueue}, + neighbor::{self, Neighbor, NeighborQueue}, provider::{ Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, Delete, ElementStatus, ExecutionContext, Guard, NeighborAccessor, @@ -2052,35 +2052,6 @@ where } } - /// Filter out start nodes from the best candidates in the scratch. - fn filter_search_candidates( - &self, - start_points: &[DP::InternalId], - l_value: usize, - best: &mut NeighborPriorityQueue, - ) -> impl SendFuture>, usize)>> { - async move { - let mut total = 0usize; - let mut candidates = Vec::with_capacity(l_value); - for n in best.iter() { - total += 1; - if !start_points.contains(&n.id) { - candidates.push(n); - if candidates.len() >= l_value { - break; - } - } - } - - debug_assert!( - l_value.min(best.size().saturating_sub(start_points.len())) <= candidates.len(), - "Not enough candidates after filtering starting points", - ); - - Ok((candidates, total)) - } - } - /// Execute a search using the unified search interface. /// /// This method provides a single entry point for all search types. 
The `search_params` argument @@ -2254,7 +2225,7 @@ where l_value: usize, ) -> impl SendFuture>> where - S: SearchStrategy + 'static, + S: SearchStrategy, T: Copy + Send + 'a, { async move { @@ -2276,7 +2247,7 @@ where init_ids: Option<&'a [DP::InternalId]>, ) -> impl SendFuture>> where - S: SearchStrategy + 'static, + S: SearchStrategy, T: Copy + Send + 'a, { async move { @@ -3057,127 +3028,6 @@ struct BatchIdMismatch { ids_len: usize, } -////////////////// -// Paged Search // -////////////////// - -/// A paged search handle that owns all search state internally. -/// -/// Created by [`DiskANNIndex::paged_search`] or -/// [`DiskANNIndex::paged_search_with_init_ids`]. Each call to -/// [`next_page`](Self::next_page) resumes the graph search and returns the next page of -/// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. -#[derive(Debug)] -pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { - index: &'a DiskANNIndex, - context: &'a DP::Context, - scratch: SearchScratch, - computed_result: Vec>, - next_result_index: usize, - search_param_l: usize, - strategy: S, - computer: S::QueryComputer, - // Note: The `fn` is so we derive `Send` and `Sync` more easily: `fn` is always Send/Sync. - _query: std::marker::PhantomData, -} - -impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> -where - DP: DataProvider, - S: SearchStrategy, -{ - /// Returns the next page of at most `k` nearest-neighbor results. - /// - /// Results across pages are non-overlapping and ordered by non-decreasing distance. - /// When the search is exhausted, returns an empty `Vec`. 
- pub fn next_page( - &mut self, - k: usize, - ) -> impl SendFuture>>> { - async move { - if k > self.search_param_l { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be less than or equal to search_param_l".to_string(), - )); - } - if k == 0 { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be greater than 0".to_string(), - )); - } - - let mut result = Vec::with_capacity(k); - - // Drain any already-computed results first. - let available = self - .computed_result - .len() - .saturating_sub(self.next_result_index); - let from_cache = cmp::min(k, available); - if from_cache > 0 { - result.extend_from_slice( - &self.computed_result - [self.next_result_index..self.next_result_index + from_cache], - ); - self.next_result_index += from_cache; - - if result.len() == k { - return ANNResult::Ok(result); - } - } - - // Resume graph search to fill the next batch. - let start_points = { - let mut accessor = self - .strategy - .search_accessor(&self.index.data_provider, self.context) - .into_ann_result()?; - - let start_ids = accessor.starting_points().await?; - self.index - .search_internal( - None, // beam_width - &start_ids, - &mut accessor, - &self.computer, - &mut self.scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - start_ids - }; - - let (mut candidates, total_considered) = self - .index - .filter_search_candidates(&start_points, k, &mut self.scratch.best) - .await?; - self.scratch.best.drain_best(total_considered); - - let computed_result_count = candidates.len(); - self.computed_result.clear(); - self.computed_result.append(&mut candidates); - - self.next_result_index = 0; - if computed_result_count != self.search_param_l { - self.computed_result.truncate(computed_result_count); - } - - let remaining_need = k - result.len(); - let leftover = cmp::min(remaining_need, computed_result_count); - if leftover > 0 { - result.extend_from_slice( - &self.computed_result - [self.next_result_index..self.next_result_index + 
leftover], - ); - self.next_result_index += leftover; - } - - ANNResult::Ok(result) - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index fac279421..350930671 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -46,6 +46,9 @@ mod knn_search; mod multihop_search; mod range_search; +mod paged; +pub use paged::PagedSearch; + pub mod record; pub(crate) mod scratch; diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs new file mode 100644 index 000000000..3b1cebcbb --- /dev/null +++ b/diskann/src/graph/search/paged.rs @@ -0,0 +1,163 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_utils::future::SendFuture; + +use crate::{ + ANNError, ANNResult, + error::IntoANNResult, + graph::{ + DiskANNIndex, + glue::{SearchExt, SearchStrategy}, + search::{record::NoopSearchRecord, scratch::SearchScratch}, + }, + neighbor::{Neighbor, NeighborPriorityQueue}, + provider::DataProvider, + utils::VectorId, +}; + +/// A paged search handle that owns all search state internally. +/// +/// Created by [`DiskANNIndex::paged_search`] or +/// [`DiskANNIndex::paged_search_with_init_ids`]. Each call to +/// [`next_page`](Self::next_page) resumes the graph search and returns the next page of +/// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. 
+#[derive(Debug)] +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + pub(in crate::graph) index: &'a DiskANNIndex, + pub(in crate::graph) context: &'a DP::Context, + pub(in crate::graph) scratch: SearchScratch, + pub(in crate::graph) computed_result: Vec>, + pub(in crate::graph) next_result_index: usize, + pub(in crate::graph) search_param_l: usize, + pub(in crate::graph) strategy: S, + pub(in crate::graph) computer: S::QueryComputer, + // Note: The `fn` is so we derive `Send` and `Sync` more easily: `fn` is always Send/Sync. + pub(in crate::graph) _query: std::marker::PhantomData, +} + +impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +where + DP: DataProvider, + S: SearchStrategy, +{ + /// Returns the next page of at most `k` nearest-neighbor results. + /// + /// Results across pages are non-overlapping and ordered by non-decreasing distance. + /// When the search is exhausted, returns an empty `Vec`. + pub fn next_page( + &mut self, + k: usize, + ) -> impl SendFuture>>> { + async move { + if k > self.search_param_l { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be less than or equal to search_param_l".to_string(), + )); + } + if k == 0 { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be greater than 0".to_string(), + )); + } + + let mut result = Vec::with_capacity(k); + + // Drain any already-computed results first. + let available = self + .computed_result + .len() + .saturating_sub(self.next_result_index); + let from_cache = std::cmp::min(k, available); + if from_cache > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + from_cache], + ); + self.next_result_index += from_cache; + + if result.len() == k { + return ANNResult::Ok(result); + } + } + + // Resume graph search to fill the next batch. 
+ let start_points = { + let mut accessor = self + .strategy + .search_accessor(&self.index.data_provider, self.context) + .into_ann_result()?; + + let start_ids = accessor.starting_points().await?; + self.index + .search_internal( + None, // beam_width + &start_ids, + &mut accessor, + &self.computer, + &mut self.scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + start_ids + }; + + let (mut candidates, total_considered) = + filter_search_candidates(&start_points, k, &mut self.scratch.best); + self.scratch.best.drain_best(total_considered); + + let computed_result_count = candidates.len(); + self.computed_result.clear(); + self.computed_result.append(&mut candidates); + + self.next_result_index = 0; + if computed_result_count != self.search_param_l { + self.computed_result.truncate(computed_result_count); + } + + let remaining_need = k - result.len(); + let leftover = std::cmp::min(remaining_need, computed_result_count); + if leftover > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + leftover], + ); + self.next_result_index += leftover; + } + + ANNResult::Ok(result) + } + } +} + +// FIXME: Wire proper post-processing support into paged search. 
+fn filter_search_candidates( + start_points: &[I], + l_value: usize, + best: &mut NeighborPriorityQueue, +) -> (Vec>, usize) +where + I: VectorId, +{ + let mut total = 0usize; + let mut candidates = Vec::with_capacity(l_value); + for n in best.iter() { + total += 1; + if !start_points.contains(&n.id) { + candidates.push(n); + if candidates.len() >= l_value { + break; + } + } + } + + debug_assert!( + l_value.min(best.size().saturating_sub(start_points.len())) <= candidates.len(), + "Not enough candidates after filtering starting points", + ); + + (candidates, total) +} diff --git a/diskann/src/graph/test/cases/paged_search.rs b/diskann/src/graph/test/cases/paged_search.rs index 8c2f79ee1..9b7fd27f6 100644 --- a/diskann/src/graph/test/cases/paged_search.rs +++ b/diskann/src/graph/test/cases/paged_search.rs @@ -221,6 +221,10 @@ fn single_page() { // Verify second call returns empty (nothing left) let page2 = rt.block_on(search.next_page(page_size)).unwrap(); assert!(page2.is_empty(), "second page should be empty"); + + // Verify repeated calls after exhaustion remain empty (idempotent, no panic). + let page3 = rt.block_on(search.next_page(page_size)).unwrap(); + assert!(page3.is_empty(), "third page should still be empty"); } #[test] diff --git a/rfcs/01078-paged-search.md b/rfcs/01078-paged-search.md new file mode 100644 index 000000000..50ac58d2c --- /dev/null +++ b/rfcs/01078-paged-search.md @@ -0,0 +1,193 @@ +# Overhaul Paged Search + +| | | +|---|---| +| **Authors** | Mark Hildebrand | +| **Contributors** | | +| **Created** | 2026-05-18 | +| **Updated** | 2026-05-18 | + +## Summary + +Replace the `SearchState<..., ExtraState: 'static>` pattern for paged search with a lifetime-bound `PagedSearch<'a, ...>` in `diskann`, and document a channel-based spawned-task pattern for downstream consumers that need to cross `tokio::spawn` or FFI boundaries. +This removes the `'static` requirement on query computers and search strategies, enabling future trait simplification. 
+ +## Motivation + +### Background + +Paged search allows callers to retrieve nearest-neighbor results incrementally (one "page" at a time) without restarting the graph traversal. +The search state (scratch buffers, the priority queue, the query computer) must persist across page boundaries. + +Earlier synchronous version of DiskANN did this by persisting the search state manually and passing the search state explicitly to the next-page requests. +The async rewrite stuck with this pattern where callers were required to manage a `SearchState` struct whose `ExtraState` type parameter carried a `'static` bound: +```rust +// OLD: ExtraState must be 'static +pub struct SearchState { ... } +pub type PagedSearchState = SearchState<::InternalId, (S, C)>; + +// Note +// * S: SearchStrategy for some T +// * C: S::QueryComputer +``` +In downstream crates that expose paged search across FFI or task boundaries, the state was traditionally type-erased behind `Box` (which also captured the `DiskANNIndex`) and sent as an opaque pointer. +This required both the strategy and the query computer (parameters `S` and `C`) to be `'static`. + +### Problem Statement + +The `'static` bound on `BuildQueryComputer::QueryComputer` propagates throughout the trait hierarchy: + +```rust +// OLD +type QueryComputer: ... + Send + Sync + 'static; +``` + +This prevents: + +1. Query computers that borrow from the index or context (common with quantization tables). +2. Fusing the query-computer into the accessor. + Due to the lifetime needed by accessors, they can't be persisted in a `'static` struct this way. + However, query computers may contain non-trivial pre-processed state, meaning recreating them on each new page retrieval is a performance footgun. +3. Simplifying the `SearchStrategy` / `Accessor` trait tower by removing unnecessary indirection introduced solely to satisfy `'static`. + +### Goals + +1. Remove the `'static` bound from `BuildQueryComputer::QueryComputer`. +2. 
Remove `SearchState`, `NoExtraState`, and `PagedSearchState` from the public API. +3. Provide a `PagedSearch<'a, ...>` handle that is lifetime-bound to the index and context, encapsulating all search state. +4. Document the channel-based pattern for downstream consumers that need to cross task/FFI boundaries (where `'static` is inherently required by the runtime, not by `diskann`). + +## Proposal + +The key idea is this: DiskANN for better or worse is already fully async, and async Rust despite its flaws already provides a clean way of doing this without requiring our traits to bend over backwards. +So let's embrace async to actually help us for a change. + +### Core library (`diskann`) + +Replace the old split API: + +```rust +// OLD +index.start_paged_search(strategy, ctx, query, l) -> SearchState<...> +index.next_search_results(ctx, &mut state, k, &mut buf) -> usize +``` + +With a self-contained handle: + +```rust +// NEW +impl DiskANNIndex { + pub fn paged_search<'a, S, T>( + &'a self, + strategy: S, + context: &'a DP::Context, + query: T, + l_value: usize, + ) -> impl SendFuture>> + where + S: SearchStrategy, + T: Copy + Send + 'a; +} + +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + index: &'a DiskANNIndex, + context: &'a DP::Context, + scratch: SearchScratch, + computed_result: Vec>, + next_result_index: usize, + search_param_l: usize, + strategy: S, + computer: S::QueryComputer, + _query: PhantomData, // covariant, always Send+Sync +} + +impl PagedSearch<'a, DP, S, T> { + pub fn next_page(&mut self, k: usize) -> impl SendFuture>>>; +} +``` + +The key change is that `PagedSearch` borrows all the necessary components. +There is no need to deconstruct the search components into `'static` pieces after each paged search and "reassemble" them on subsequent searches. + +### Crossing spawn boundaries: the channel pattern + +`PagedSearch<'a, ...>` borrows the index, so it cannot be sent to a `tokio::spawn`'d task directly. 
+When a long-lived session or an FFI boundary requires `'static` ownership, the recommended pattern is to **spawn a task that owns the search state as a local variable** and communicate with it via channels: + +```rust +// Types are illustrative — adapt names to your crate. + +type PageResult = ANNResult>>; + +/// Spawn a paged search session. The index is held by Arc so the task is 'static. +/// +/// Returns a request channel and a result channel. The caller sends the desired +/// page size (`k`) and awaits the corresponding result on the other end. +fn spawn_paged_session( + index: Arc>, + context: Arc, + query: T, + l: usize, +) -> (mpsc::Sender, mpsc::Receiver) { + let (req_tx, mut req_rx) = mpsc::channel::(1); + let (res_tx, res_rx) = mpsc::channel::(1); + + tokio::spawn(async move { + // Borrow from the Arc — these references are scoped to the task. + let mut search = index.paged_search(strategy, &*context, query, l).await.unwrap(); + + while let Some(k) = req_rx.recv().await { + let page = search.next_page(k).await; + if res_tx.send(page).await.is_err() { + break; // caller dropped the result receiver + } + } + // Request channel closed -> caller dropped sender -> clean shutdown. + }); + + (req_tx, res_rx) +} +``` + +Key properties of this pattern: + +1. **`'static` is confined to the spawn boundary**: the `Arc` satisfies the runtime's requirement, while the borrow from it lives entirely inside the task's local scope. + Importantly, even though `PagedSearch` borrows, it can be embedded inside a `'static` future. +2. **State is fully encapsulated**: callers never see `SearchScratch`, `QueryComputer`, or any internal types. +3. **Clean shutdown**: dropping the request sender closes the channel; the task exits gracefully. +4. **Per-request context**: the request channel can carry additional metadata (profiling tokens, cancellation flags, etc.) without polluting the core API. 
+ +### Migration guide + +| Old pattern | New pattern | +|---|---| +| `index.start_paged_search(s, ctx, q, l)` | `index.paged_search(s, ctx, q, l).await` | +| `index.next_search_results(ctx, &mut state, k, &mut buf)` | `search.next_page(k).await` | +| `SearchState` | `PagedSearch<'a, DP, S, T>` | +| `PagedSearchState` | `PagedSearch<'a, DP, S, T>` | +| Check return count for exhaustion | Check `page.is_empty()` | +| Type-erased `Box` across task/FFI boundaries | Channel + spawned task (see above) | + +## Feasibility via FFI + +This is a pretty big change in API, but it enables some significant future simplifications to our trait hierarchy by removing the `'static` special case introduced by paged search. +An internal user of paged search was ported to this new approach to check the feasibility. +While it was a bit of work to overcome the impedance mismatch of the quirks of that integration, the end result is cleaner, has fewer overall task spawns, and fewer FFI-related race conditions. +And really, this integration was already basically doing this same thing behind the scenes. + +## Alternatives + +The main alternative I see is to keep the status quo with explicit state management. +While some of the planned trait simplifications are still on the table, I think the opportunity to align paged search with the rest of the trait hierarchy is well worth it. + +## Benchmark Results + +No performance change is expected (nor observed in the simulator for the aforementioned internal FFI user) since the search algorithm is identical. +Existing Rust code will have a similar pattern of future usage as before, just packaged slightly differently. + +## References + +1. [RFC 3498 — Lifetime Capture Rules 2024](https://rust-lang.github.io/rfcs/3498-lifetime-capture-rules-2024.html) — + Rust edition 2024 changes that make `impl Trait + 'a` returns more ergonomic. +2.
[tokio::sync::mpsc](https://docs.rs/tokio/latest/tokio/sync/mpsc/index.html) — the channel + primitive used in the spawned-task pattern. From c357f024b9ddfacd53d314664ec440320db1ed32 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 18 May 2026 14:55:02 -0700 Subject: [PATCH 5/9] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-providers/src/index/wrapped_async.rs | 2 +- diskann/src/graph/test/cases/paged_search.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 858218b55..2f7ee997a 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -342,7 +342,7 @@ where /// Begin a paged search over the index (synchronous wrapper). /// /// Returns a [`PagedSearch`] handle. See - /// [`graph::index::PagedSearch::next_page`] for retrieving results. + /// [`PagedSearch::next_page`] for retrieving results. pub fn paged_search<'a, S, T>( &'a self, strategy: S, diff --git a/diskann/src/graph/test/cases/paged_search.rs b/diskann/src/graph/test/cases/paged_search.rs index 9b7fd27f6..61e2c8c45 100644 --- a/diskann/src/graph/test/cases/paged_search.rs +++ b/diskann/src/graph/test/cases/paged_search.rs @@ -6,7 +6,7 @@ //! Tests for paged (iterative) search. //! //! Paged search returns results in pages of k neighbors via a stateful -//! [`PagedSearch`](crate::graph::index::PagedSearch) handle. Tests cover basic pagination, +//! [`PagedSearch`](crate::graph::search::PagedSearch) handle. Tests cover basic pagination, //! single-page retrieval, and small page sizes that stress the iteration machinery. use std::sync::Arc; From 02e2ccc41ef7ef59bf488d61bace13250d0bcbb3 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 18 May 2026 14:55:22 -0700 Subject: [PATCH 6/9] Fix documentation. 
--- diskann/src/graph/search/paged.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs index 3b1cebcbb..ab77829e3 100644 --- a/diskann/src/graph/search/paged.rs +++ b/diskann/src/graph/search/paged.rs @@ -18,12 +18,12 @@ use crate::{ utils::VectorId, }; -/// A paged search handle that owns all search state internally. +/// Intermediate state for paged search. /// -/// Created by [`DiskANNIndex::paged_search`] or -/// [`DiskANNIndex::paged_search_with_init_ids`]. Each call to -/// [`next_page`](Self::next_page) resumes the graph search and returns the next page of -/// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. +/// Each call to [`next_page`](Self::next_page) resumes the graph search and returns the +/// next page of nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. +/// +/// See also: [`DiskANNIndex::paged_search`], [`DiskANNIndex::paged_search_with_init_ids`]. #[derive(Debug)] pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { pub(in crate::graph) index: &'a DiskANNIndex, @@ -45,7 +45,11 @@ where { /// Returns the next page of at most `k` nearest-neighbor results. /// - /// Results across pages are non-overlapping and ordered by non-decreasing distance. + /// Results across pages are non-overlapping but not guaranteed to be monotonic with + /// respect to distance. + /// + /// Within a page, results are ordered by non-decreasing distance. + /// /// When the search is exhausted, returns an empty `Vec`. pub fn next_page( &mut self, From feb2e97cf9473dbb71875dddb43bb21b4c09a0b6 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 18 May 2026 14:59:47 -0700 Subject: [PATCH 7/9] Commit simplification.
--- diskann/src/graph/search/paged.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs index ab77829e3..34bdd16d6 100644 --- a/diskann/src/graph/search/paged.rs +++ b/diskann/src/graph/search/paged.rs @@ -116,11 +116,7 @@ where let computed_result_count = candidates.len(); self.computed_result.clear(); self.computed_result.append(&mut candidates); - self.next_result_index = 0; - if computed_result_count != self.search_param_l { - self.computed_result.truncate(computed_result_count); - } let remaining_need = k - result.len(); let leftover = std::cmp::min(remaining_need, computed_result_count); From 7344b0023070fd42201f2ca5b4c533474fa3b73f Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 19 May 2026 15:31:29 -0700 Subject: [PATCH 8/9] Add specialized noawait paged-search. --- Cargo.lock | 37 +- diskann-providers/Cargo.toml | 2 +- diskann-providers/src/index/wrapped_async.rs | 343 +++++++++++++++++++ 3 files changed, 362 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94c12d370..a4f7df368 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1236,9 +1236,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", @@ -1251,9 +1251,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -1261,15 +1261,15 @@ dependencies = [ [[package]] name = "futures-core" 
-version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -1278,15 +1278,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", @@ -1295,15 +1295,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" @@ -1313,9 +1313,9 @@ checksum = 
"f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-channel", "futures-core", @@ -1325,7 +1325,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] diff --git a/diskann-providers/Cargo.toml b/diskann-providers/Cargo.toml index c5447dca2..ef8d79182 100644 --- a/diskann-providers/Cargo.toml +++ b/diskann-providers/Cargo.toml @@ -36,7 +36,7 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tempfile = { workspace = true, optional = true } bf-tree = { workspace = true, optional = true } prost = "0.14.1" -futures-util.workspace = true +futures-util = { workspace = true, features = ["async-await"] } serde_json = { workspace = true, optional = true } vfs = { workspace = true, optional = true } diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 2f7ee997a..d546fcc81 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -20,6 +20,7 @@ use diskann::{ provider::{AsNeighbor, AsNeighborMut, DataProvider, Delete, SetElement}, utils::ONE, }; +use diskann_utils::Reborrow; use crate::storage::{LoadWith, StorageReadProvider}; @@ -386,6 +387,28 @@ where }) } + /// Begin a synchronous paged search over the index. + /// + /// This will construct a [`noawait::PagedSearch`] and initialize search with the + /// providers start points. Pages can be retrieved with [`noawait::PagedSearch::next`]. 
+ /// + /// **Caution**: This method should only be used if it is known that all functions reachable + /// via the implementation of [`SearchStrategy`] are synchronous and never + /// truly await. This allows [`noawait::PagedSearch`] to be much more efficient. + pub fn paged_search_no_await( + &self, + strategy: S, + context: DP::Context, + query: T, + l_value: usize, + ) -> ANNResult> + where + T: for<'a> Reborrow<'a, Target: Copy + Send> + 'static, + S: for<'a> SearchStrategy>::Target> + 'static, + { + noawait::PagedSearch::new(self.inner.clone(), strategy, context, query, l_value) + } + pub fn count_reachable_nodes( &self, start_points: &[DP::InternalId], @@ -430,6 +453,210 @@ where } } +pub mod noawait { + //! Implementations of a synchronous wrapper around [`diskann::graph::DiskANNIndex`] that + //! assume the [`Accessor`] and associated implementations never truly `await` and are + //! in fact synchronous. + //! + //! With this assumption, we can perform lighter-weight communication with the index + //! by assuming that each `poll` returns ready. + //! + //! **Do not use this if your index ever actually awaits**: Doing so will lead to deadlock! + + use super::*; + + use std::{ + cell::RefCell, + pin::Pin, + rc::Rc, + task::{Context, Poll, Waker}, + }; + + use diskann::{ANNErrorKind, utils::VectorId}; + use diskann_utils::Reborrow; + use thiserror::Error; + + type Input = Rc>>; + type Output = Rc>>>>; + + fn channel() -> (Input, Output) + where + I: VectorId, + { + let input = Rc::new(RefCell::new(None)); + let output = Rc::new(RefCell::new(None)); + (input, output) + } + + fn step(fut: Pin<&mut dyn Future>) -> Option { + let mut cx = Context::from_waker(Waker::noop()); + match fut.poll(&mut cx) { + Poll::Ready(v) => Some(v), + Poll::Pending => None, + } + } + + /// A synchronous wrapper for [`graph::search::PagedSearch`]. + /// + /// See: [`super::DiskANNIndex::paged_search_no_await`].
+ pub struct PagedSearch { + // The `input` is wrapped in an `Option` so we can fuse `searcher` if it exits + // with an error. Polling a completed future risks panicking. + // + // We construct `searcher` to pull its next-page size from this input. + input: Option, + + // Output yielded from polling `searcher`. + output: Output, + + // We shut down the future by running `Drop`. Thus, the only way it can actually + // finish is if it returns with an error. + searcher: Pin>>, + } + + impl PagedSearch + where + I: VectorId, + { + /// Construct a new [`PagedSearch`]. + /// + /// This works by creating a small async task using [`graph::search::PagedSearch`] + /// internally. The requested k-nearest neighbors are sent using a `Rc>` + /// channel and the actual neighbors are retrieved from a similar data structure. + /// + /// Under the assumption that [`graph::search::PagedSearch`]'s + /// implementations are fully synchronous, we can directly poll this task instead + /// of going through a runtime since we (theoretically) control the only suspension + /// point. + /// + /// Doing so allows stepping the task state machine to be done with a single function + /// call to `Future::poll`. + /// + /// Obviously, if the "noawait" assumption is broken, then the inner async job may + /// yield before our control point, but we can detect this situation since no output + /// will be generated on the output channel. + /// + /// We rely on `Drop` to clean up the paged search resources. + pub(super) fn new( + index: Arc>, + strategy: S, + context: DP::Context, + query: T, + l_value: usize, + ) -> ANNResult + where + DP: DataProvider, + T: for<'a> Reborrow<'a, Target: Copy + Send> + 'static, + S: for<'a> SearchStrategy>::Target> + 'static, + { + // Prepare the input and output channels used to communicate with the search task. + let (input, output) = channel::(); + let input_clone = input.clone(); + let output_clone = output.clone(); + + // Create the search task.
+ let mut searcher: Pin>> = Box::pin(async move { + // The assumption of `noawait` is that this call will always resolve to + // `Poll::Ready`. + let mut state = match index + .paged_search(strategy, &context, query.reborrow(), l_value) + .await + { + Ok(state) => state, + Err(err) => return err, + }; + + loop { + // This is the await point that pauses the future. + // + // Under the "noawait" assumption, this should be the only point where + // this future ever yields `Pending` and is where we expect the future + // to stop every time we poll it. + futures_util::pending!(); + + // We control the invocation of poll and should always ensure that + // input is available. + let k_value = match input_clone.take() { + Some(value) => value, + None => return InternalInvariantViolated::MissingInput.into(), + }; + + // Step paged search and propagate any errors. + let page = match state.next_page(k_value).await { + Ok(page) => page, + Err(err) => return err, + }; + + // Send output to the caller. + output_clone.replace(Some(page)); + } + }); + + // Drive the inner future one step to initialize paged search. + if let Some(err) = step(searcher.as_mut()) { + return Err(err); + } + + let this = Self { + input: Some(input), + output, + searcher, + }; + Ok(this) + } + + /// Retrieve the next results from paged search, returning any errors. + /// + /// If [`next`](Self::next) previously returned with an error, it will continue + /// to do so. + pub fn next(&mut self, k: usize) -> ANNResult>> { + // Prepare input. We use the presence of the input channel to decide whether + // or not it is safe to poll the search task. + match self.input.as_ref() { + Some(input) => input.replace(Some(k)), + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "paged searcher errored and is no longer runnable", + )); + } + }; + + // Progress the future. + // + // The only reason to return `Some` is if the inner future aborts with + // an error.
Here, we fuse the searcher to prevent panics on re-enters and + // forward the error. + if let Some(result) = step(self.searcher.as_mut()) { + self.input = None; + return Err(result); + } + + // Profit! + match self.output.take() { + Some(v) => Ok(v), + None => Err(InternalInvariantViolated::MissingOutput.into()), + } + } + } + + #[derive(Debug, Clone, Copy, Error)] + enum InternalInvariantViolated { + #[error("INTERNAL: input channel was not configured")] + MissingInput, + #[error("noawait contract violated: future suspended before expected yield point")] + MissingOutput, + } + + impl From for ANNError { + #[track_caller] + #[cold] + fn from(err: InternalInvariantViolated) -> Self { + Self::new(ANNErrorKind::Opaque, err) + } + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -552,4 +779,120 @@ mod tests { assert_eq!(ids[0], 0); assert_eq!(distances[0], 0.0); } + + fn wrapped_test_provider() -> DiskANNIndex { + let provider = + graph::test::provider::Provider::grid(graph::test::synthetic::Grid::One, 100).unwrap(); + + DiskANNIndex::new_with_current_thread_runtime( + graph::config::Builder::new( + provider.max_degree(), + diskann::graph::config::MaxDegree::same(), + 100, + (Metric::L2).into(), + ) + .build() + .unwrap(), + provider, + ) + } + + // Test the `noawait` paged searcher. + // + // This relies on the test-provider being no-await. 
+ #[test] + fn test_paged_search_noawait() { + let index = wrapped_test_provider(); + + for page_size in [1, 5, 9, 12] { + let mut paged = index + .paged_search_no_await::<_, Vec>( + graph::test::provider::Strategy::new(), + graph::test::provider::Context::new(), + vec![0.0], + 10.max(page_size), + ) + .unwrap(); + + let mut i = 0u32; + loop { + let v = paged.next(page_size).unwrap(); + assert!( + v.len() <= page_size, + "candidates returned ({}) exceeded page size ({})", + v.len(), + page_size, + ); + + if v.is_empty() { + break; + } + + for neighbor in v { + assert_ne!( + neighbor.id, + u32::MAX, + "paged search should not return start point", + ); + assert_eq!( + neighbor.id, i, + "monotonicity should at least hold for the 1d grid" + ); + assert_eq!( + neighbor.distance, + (i as f32) * (i as f32), + "distance was computed incorrectly!", + ); + i += 1; + } + } + + // Search is exhausted - make sure that subsequent searches yield empty vectors. + let exhausted = paged.next(5).unwrap(); + assert!( + exhausted.is_empty(), + "expected an empty vector when exhausted - instead got {:?}", + exhausted + ); + } + } + + // Verify that the searcher is properly fused when it returns with an error. + #[test] + fn test_paged_search_noawait_fuse() { + let index = wrapped_test_provider(); + + // To do this test, we request more neighbors than the search-L, which triggers + // an inner error. + let search_l = 10; + let bigger_than_search_l = 20; + + let mut paged = index + .paged_search_no_await::<_, Vec>( + graph::test::provider::Strategy::new(), + graph::test::provider::Context::new(), + vec![0.0], + search_l, + ) + .unwrap(); + + let expected = "search_param_l"; + let err = paged.next(bigger_than_search_l).unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains(expected), + "expected error message to contain \"{}\" - instead got\n\n{}", + expected, + msg, + ); + + // Now that we've yielded an error - the next time we request pages should also error. 
+ let err = paged.next(10).unwrap_err(); + let err_msg = err.to_string(); + assert!( + err_msg.contains("paged searcher errored"), + "unexpected error message:\n\n{}", + err_msg + ); + } } From 1431af516f98a347cf7f220cde52f9d8a01e13c0 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 21 May 2026 10:22:51 -0700 Subject: [PATCH 9/9] Nits. --- diskann/src/graph/search/paged.rs | 8 ++++---- rfcs/01078-paged-search.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs index 34bdd16d6..927040272 100644 --- a/diskann/src/graph/search/paged.rs +++ b/diskann/src/graph/search/paged.rs @@ -136,26 +136,26 @@ where // FIXME: Wire proper post-processing support into paged search. fn filter_search_candidates( start_points: &[I], - l_value: usize, + page_size: usize, best: &mut NeighborPriorityQueue, ) -> (Vec>, usize) where I: VectorId, { let mut total = 0usize; - let mut candidates = Vec::with_capacity(l_value); + let mut candidates = Vec::with_capacity(page_size); for n in best.iter() { total += 1; if !start_points.contains(&n.id) { candidates.push(n); - if candidates.len() >= l_value { + if candidates.len() >= page_size { break; } } } debug_assert!( - l_value.min(best.size().saturating_sub(start_points.len())) <= candidates.len(), + page_size.min(best.size().saturating_sub(start_points.len())) <= candidates.len(), "Not enough candidates after filtering starting points", ); diff --git a/rfcs/01078-paged-search.md b/rfcs/01078-paged-search.md index 50ac58d2c..d662e1519 100644 --- a/rfcs/01078-paged-search.md +++ b/rfcs/01078-paged-search.md @@ -20,7 +20,7 @@ Paged search allows callers to retrieve nearest-neighbor results incrementally ( The search state (scratch buffers, the priority queue, the query computer) must persist across page boundaries. 
Earlier synchronous version of DiskANN did this by persisting the search state manually and passing the search state explicitly to the next-page requests. -The async rewrite stuck with this pattern where callers were required to manage a `SearchState` struct whose `ExtraState` type parameter carried a `'static` bound: +The async rewrite stuck with this pattern where callers were required to manage a `SearchState` struct whose `ExtraState` type parameter carried a `'static` bound: ```rust // OLD: ExtraState must be 'static pub struct SearchState { ... }