diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index 5057b58700..c27ecdbba0 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -12,6 +12,7 @@ mod scorer; pub mod tokenizer; mod wand; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use arrow_schema::{DataType, Field}; @@ -24,6 +25,89 @@ pub use lance_tokenizer::Language; pub use scorer::MemBM25Scorer; pub use tokenizer::*; +use crate::scalar::inverted::query::{FtsSearchParams, Tokens}; + +/// Collect the unique terms needed to build a shared BM25 scorer. +/// +/// The scorer only needs corpus-level document frequencies, so we keep a +/// deduplicated term list here instead of constructing a full `Tokens` +/// object with positions. When fuzziness is enabled, each segment may +/// contribute additional terms (via `expand_fuzzy_tokens`); the union of +/// those terms is what the global scorer must cover. +fn scorer_terms( + indices: &[Arc], + query_tokens: &Tokens, + params: &FtsSearchParams, +) -> Result> { + let mut terms = Vec::new(); + let mut seen = HashSet::new(); + + if !matches!(params.fuzziness, Some(n) if n != 0) { + for token in query_tokens { + if seen.insert(token.to_string()) { + terms.push(token.to_string()); + } + } + return Ok(terms); + } + + for index in indices { + let expanded = index.expand_fuzzy_tokens(query_tokens, params)?; + for idx in 0..expanded.len() { + let token = expanded.get_token(idx); + if seen.insert(token.to_string()) { + terms.push(token.to_string()); + } + } + } + Ok(terms) +} + +/// Build a shared [`MemBM25Scorer`] across a set of FTS index segments. +/// +/// Aggregates each segment's `(total_tokens, num_docs, per_term_doc_freq)` +/// statistics — obtained via [`InvertedIndex::bm25_stats_for_terms`] — into a +/// single corpus-wide scorer, so that BM25 IDF scoring uses *global* +/// statistics rather than per-segment statistics. Computes the union of +/// fuzzy-expanded terms when `params.fuzziness` is set. +/// +/// Public as the canonical producer paired with the `with_base_scorer` +/// consumer on FTS exec types: callers holding `Arc` segment +/// handles locally can construct an injectable scorer without reimplementing +/// per-segment stat aggregation, term deduplication, and fuzzy-expansion +/// union. Keeps a single source of truth for BM25 IDF arithmetic across +/// segments. +pub fn build_global_bm25_scorer( + indices: &[Arc], + query_tokens: &Tokens, + params: &FtsSearchParams, +) -> Result { + let terms = scorer_terms(indices, query_tokens, params)?; + let first_index = indices.first().ok_or_else(|| { + lance_core::Error::invalid_input("FTS index requires at least one segment") + })?; + let (mut total_tokens, mut num_docs, first_token_docs) = + first_index.bm25_stats_for_terms(&terms); + let mut token_docs = HashMap::with_capacity(terms.len()); + for (term, count) in terms.iter().cloned().zip(first_token_docs.into_iter()) { + token_docs.insert(term, count); + } + + for index in indices.iter().skip(1) { + let (segment_total_tokens, segment_num_docs, segment_token_docs) = + index.bm25_stats_for_terms(&terms); + total_tokens += segment_total_tokens; + num_docs += segment_num_docs; + for (term, count) in terms.iter().zip(segment_token_docs.into_iter()) { + *token_docs + .get_mut(term) + .expect("global scorer terms should already be initialized") += count; + } + } + + Ok(MemBM25Scorer::new(total_tokens, num_docs, token_docs)) +} + use lance_core::Error; use crate::pbold; diff --git a/rust/lance-index/src/scalar/inverted/query.rs b/rust/lance-index/src/scalar/inverted/query.rs index 21b631d12a..7388bc0401 100644 --- a/rust/lance-index/src/scalar/inverted/query.rs +++ b/rust/lance-index/src/scalar/inverted/query.rs @@ -8,7 +8,7 @@ use serde::ser::SerializeMap; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct FtsSearchParams { pub limit: Option, pub wand_factor: f32, diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 61b851095e..6e68656a55 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -103,7 +103,10 @@ use crate::io::exec::{ }; use crate::io::exec::{AddRowOffsetExec, LanceFilterExec, LanceScanConfig, get_physical_optimizer}; use crate::{Error, Result}; -use crate::{datatypes::Schema, io::exec::fts::BooleanQueryExec}; +use crate::{ + datatypes::Schema, + io::exec::fts::{BoolSlot, BooleanQueryExec, build_boolean_query_children}, +}; pub use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts}; #[cfg(feature = "substrait")] @@ -3240,83 +3243,48 @@ impl Scanner { // so that we won't miss possible matches let unlimited_params = params.clone().with_limit(None); - // For should queries, union the results of each subquery let mut should = Vec::with_capacity(query.should.len()); for subquery in &query.should { - let plan = Box::pin(self.plan_fts( - subquery, - &unlimited_params, - filter_plan, - prefilter_source, - )) - .await?; - should.push(plan); + should.push( + Box::pin(self.plan_fts( + subquery, + &unlimited_params, + filter_plan, + prefilter_source, + )) + .await?, + ); } - let should = if should.is_empty() { - Arc::new(EmptyExec::new(FTS_SCHEMA.clone())) - } else if should.len() == 1 { - should.pop().unwrap() - } else { - let unioned = UnionExec::try_new(should)?; - Arc::new(RepartitionExec::try_new( - unioned, - Partitioning::RoundRobinBatch(1), - )?) - }; - - // For must queries, inner join the results of each subquery on row_id - let mut must = None; - for query in &query.must { - let plan = Box::pin(self.plan_fts( - query, - &unlimited_params, - filter_plan, - prefilter_source, - )) - .await?; - if let Some(joined_plan) = must { - must = Some(Arc::new(HashJoinExec::try_new( - joined_plan, - plan, - vec![( - Arc::new(Column::new_with_schema(ROW_ID, &FTS_SCHEMA)?), - Arc::new(Column::new_with_schema(ROW_ID, &FTS_SCHEMA)?), - )], - None, - &datafusion_expr::JoinType::Inner, - None, - datafusion_physical_plan::joins::PartitionMode::CollectLeft, - NullEquality::NullEqualsNothing, - false, - )?) as _); - } else { - must = Some(plan); - } + let mut must = Vec::with_capacity(query.must.len()); + for subquery in &query.must { + must.push( + Box::pin(self.plan_fts( + subquery, + &unlimited_params, + filter_plan, + prefilter_source, + )) + .await?, + ); } - - // For must_not queries, union the results of each subquery let mut must_not = Vec::with_capacity(query.must_not.len()); - for query in &query.must_not { - let plan = Box::pin(self.plan_fts( - query, - &unlimited_params, - filter_plan, - prefilter_source, - )) - .await?; - must_not.push(plan); + for subquery in &query.must_not { + must_not.push( + Box::pin(self.plan_fts( + subquery, + &unlimited_params, + filter_plan, + prefilter_source, + )) + .await?, + ); } - let must_not = if must_not.is_empty() { - Arc::new(EmptyExec::new(FTS_SCHEMA.clone())) - } else if must_not.len() == 1 { - must_not.pop().unwrap() - } else { - let unioned = UnionExec::try_new(must_not)?; - Arc::new(RepartitionExec::try_new( - unioned, - Partitioning::RoundRobinBatch(1), - )?) - }; + + let should = build_boolean_query_children(BoolSlot::Should, should)? + .expect("Should slot always returns Some"); + let must = build_boolean_query_children(BoolSlot::Must, must)?; + let must_not = build_boolean_query_children(BoolSlot::MustNot, must_not)? + .expect("MustNot slot always returns Some"); if query.should.is_empty() && must.is_none() { return Err(Error::invalid_input( diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 149a2cfb46..0c38fd9238 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -6,6 +6,8 @@ pub(crate) mod inverted; +pub use inverted::{load_segment_details, load_segments}; + use std::sync::{Arc, LazyLock}; use crate::index::DatasetIndexExt; diff --git a/rust/lance/src/index/scalar/inverted.rs b/rust/lance/src/index/scalar/inverted.rs index c9094f6f01..5869b9c5bc 100644 --- a/rust/lance/src/index/scalar/inverted.rs +++ b/rust/lance/src/index/scalar/inverted.rs @@ -118,11 +118,14 @@ pub(crate) async fn build_segment( Ok(built_segment) } -/// Load all committed inverted-index segments that belong to the same named index. -pub(crate) async fn load_segments( - dataset: &Dataset, - column: &str, -) -> Result>> { +/// Load all committed inverted-index segments that belong to the same named +/// FTS index on `column`. +/// +/// Returns `Ok(None)` if no FTS index exists on the column. When an index +/// exists, the returned vector contains every committed segment's +/// [`IndexMetadata`] (UUID, fragment coverage, index details). All segments +/// must share the same indexed fields; mismatched fields return an error. +pub async fn load_segments(dataset: &Dataset, column: &str) -> Result>> { let Some(index_meta) = dataset .load_scalar_index( lance_index::IndexCriteria::default() @@ -152,8 +155,14 @@ pub(crate) async fn load_segments( Ok(Some(indices)) } -/// Load and validate the shared inverted-index details across committed segments. -pub(crate) async fn load_segment_details( +/// Load and validate the shared [`InvertedIndexDetails`] across committed +/// segments returned by [`load_segments`]. +/// +/// All segments are required to agree on their decoded `InvertedIndexDetails` +/// payload (analyzer, tokenizer, position settings, etc.); inconsistent +/// segments return an error. Returns the canonical details that may be used +/// when constructing a tokenizer or running a query against the index. +pub async fn load_segment_details( dataset: &Dataset, column: &str, segments: &[IndexMetadata], diff --git a/rust/lance/src/io/exec/fts.rs b/rust/lance/src/io/exec/fts.rs index ba775ecfe0..240fc70beb 100644 --- a/rust/lance/src/io/exec/fts.rs +++ b/rust/lance/src/io/exec/fts.rs @@ -1,21 +1,26 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::Arc; use arrow::array::{AsArray, BooleanBuilder}; use arrow::datatypes::{Float32Type, UInt64Type}; use arrow_array::{Array, BooleanArray, Float32Array, OffsetSizeTrait, RecordBatch, UInt64Array}; use arrow_schema::DataType; -use datafusion::common::Statistics; +use datafusion::common::{NullEquality, Statistics}; use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::union::UnionExec; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{Distribution, EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion_physical_plan::metrics::{BaselineMetrics, Count}; use futures::future::try_join_all; use futures::stream::{self}; @@ -42,7 +47,8 @@ use lance_index::scalar::inverted::query::{ }; use lance_index::scalar::inverted::tokenizer::document_tokenizer::TextTokenizer; use lance_index::scalar::inverted::{ - FTS_SCHEMA, InvertedIndex, MemBM25Scorer, SCORE_COL, flat_bm25_search_stream, + FTS_SCHEMA, InvertedIndex, MemBM25Scorer, SCORE_COL, build_global_bm25_scorer, + flat_bm25_search_stream, }; use lance_index::{prefilter::PreFilter, scalar::inverted::query::BooleanQuery}; use lance_tokenizer::{SimpleTokenizer, TextAnalyzer}; @@ -87,71 +93,6 @@ async fn open_fts_segments( .await } -/// Collect the unique terms needed to build a shared BM25 scorer. -/// -/// The scorer only needs corpus-level document frequencies, so we keep a deduplicated -/// term list here instead of constructing a full `Tokens` object with positions. -fn scorer_terms( - indices: &[Arc], - query_tokens: &Tokens, - params: &FtsSearchParams, -) -> Result> { - let mut terms = Vec::new(); - let mut seen = HashSet::new(); - - if !matches!(params.fuzziness, Some(n) if n != 0) { - for token in query_tokens { - if seen.insert(token.to_string()) { - terms.push(token.to_string()); - } - } - return Ok(terms); - } - - for index in indices { - let expanded = index.expand_fuzzy_tokens(query_tokens, params)?; - for idx in 0..expanded.len() { - let token = expanded.get_token(idx); - if seen.insert(token.to_string()) { - terms.push(token.to_string()); - } - } - } - Ok(terms) -} - -/// Build a shared BM25 scorer for a set of committed FTS segments. -fn build_global_bm25_scorer( - indices: &[Arc], - query_tokens: &Tokens, - params: &FtsSearchParams, -) -> Result { - let terms = scorer_terms(indices, query_tokens, params)?; - let first_index = indices.first().ok_or_else(|| { - Error::invalid_input("FTS index requires at least one segment".to_string()) - })?; - let (mut total_tokens, mut num_docs, first_token_docs) = - first_index.bm25_stats_for_terms(&terms); - let mut token_docs = HashMap::with_capacity(terms.len()); - for (term, count) in terms.iter().cloned().zip(first_token_docs.into_iter()) { - token_docs.insert(term, count); - } - - for index in indices.iter().skip(1) { - let (segment_total_tokens, segment_num_docs, segment_token_docs) = - index.bm25_stats_for_terms(&terms); - total_tokens += segment_total_tokens; - num_docs += segment_num_docs; - for (term, count) in terms.iter().zip(segment_token_docs.into_iter()) { - *token_docs - .get_mut(term) - .expect("global scorer terms should already be initialized") += count; - } - } - - Ok(MemBM25Scorer::new(total_tokens, num_docs, token_docs)) -} - async fn search_segments( indices: &[Arc], tokens: Arc, @@ -254,6 +195,12 @@ pub struct MatchQueryExec { query: MatchQuery, params: FtsSearchParams, prefilter_source: PreFilterSource, + /// When set, `execute()` skips `build_global_bm25_scorer` and threads this + /// scorer down to `InvertedIndex::bm25_search`. + base_scorer: Option>, + /// When set, `execute()` skips `load_segments` and searches exactly these + /// segments. + preset_segments: Option>, properties: Arc, metrics: ExecutionPlanMetricsSet, @@ -283,6 +230,15 @@ impl DisplayAs for MatchQueryExec { } impl MatchQueryExec { + /// Merge the fuzzy fields from `query` into `params` so that the stored + /// params reflect what BM25 stat collection and search will actually use. + fn effective_params(query: &MatchQuery, params: FtsSearchParams) -> FtsSearchParams { + params + .with_fuzziness(query.fuzziness) + .with_max_expansions(query.max_expansions) + .with_prefix_length(query.prefix_length) + } + pub fn new( dataset: Arc, query: MatchQuery, @@ -295,15 +251,94 @@ impl MatchQueryExec { EmissionType::Final, Boundedness::Bounded, )); + let params = Self::effective_params(&query, params); Self { dataset, query, params, prefilter_source, + base_scorer: None, + preset_segments: None, properties, metrics: ExecutionPlanMetricsSet::new(), } } + + /// Construct a `MatchQueryExec` bound to an explicit, pre-resolved set of + /// FTS segments. Unlike [`Self::new`], `execute()` will not call + /// [`load_segments`] — it will search exactly the segments supplied here. + /// + /// Useful when a caller has already enumerated segments and wants to scope + /// this exec to a strict subset — for example, a distributed query that + /// routes per-segment work across hosts, where each per-host leaf should + /// only search its own assigned subset of the dataset's committed + /// segments. + pub fn new_with_segments( + dataset: Arc, + query: MatchQuery, + params: FtsSearchParams, + prefilter_source: PreFilterSource, + segments: Vec, + ) -> Self { + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(FTS_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + EmissionType::Final, + Boundedness::Bounded, + )); + let params = Self::effective_params(&query, params); + Self { + dataset, + query, + params, + prefilter_source, + base_scorer: None, + preset_segments: Some(segments), + properties, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// Override the BM25 scorer used by `execute()`. When set, the local + /// `build_global_bm25_scorer` call is skipped and the supplied scorer is + /// threaded down to `InvertedIndex::bm25_search`. + /// + /// The default path builds a scorer from the segments this exec searches, + /// which is correct when those segments are the entire corpus. A caller + /// would override that scorer to keep BM25 IDFs corpus-wide when the exec + /// is searching only a subset — for example, a distributed query that + /// routes per-segment work to multiple hosts and aggregates stats + /// out-of-band, so each per-host leaf scores against the full corpus + /// rather than its local segment subset. See [`build_global_bm25_scorer`] + /// for constructing one. + pub fn with_base_scorer(mut self, scorer: Arc) -> Self { + self.base_scorer = Some(scorer); + self + } + + pub fn query(&self) -> &MatchQuery { + &self.query + } + + pub fn params(&self) -> &FtsSearchParams { + &self.params + } + + pub fn dataset(&self) -> &Arc { + &self.dataset + } + + pub fn prefilter_source(&self) -> &PreFilterSource { + &self.prefilter_source + } + + pub fn base_scorer(&self) -> Option<&Arc> { + self.base_scorer.as_ref() + } + + pub fn preset_segments(&self) -> Option<&[IndexMetadata]> { + self.preset_segments.as_deref() + } } impl ExecutionPlan for MatchQueryExec { @@ -348,6 +383,8 @@ impl ExecutionPlan for MatchQueryExec { query: self.query.clone(), params: self.params.clone(), prefilter_source: PreFilterSource::None, + base_scorer: self.base_scorer.clone(), + preset_segments: self.preset_segments.clone(), properties: self.properties.clone(), metrics: ExecutionPlanMetricsSet::new(), } @@ -373,6 +410,8 @@ impl ExecutionPlan for MatchQueryExec { query: self.query.clone(), params: self.params.clone(), prefilter_source, + base_scorer: self.base_scorer.clone(), + preset_segments: self.preset_segments.clone(), properties: self.properties.clone(), metrics: ExecutionPlanMetricsSet::new(), } @@ -396,6 +435,8 @@ impl ExecutionPlan for MatchQueryExec { let params = self.params.clone(); let ds = self.dataset.clone(); let prefilter_source = self.prefilter_source.clone(); + let preset_base_scorer = self.base_scorer.clone(); + let preset_segments = self.preset_segments.clone(); let metrics = Arc::new(FtsIndexMetrics::new(&self.metrics, partition)); let column = query.column.ok_or(DataFusionError::Execution(format!( "column not set for MatchQuery {}", @@ -403,12 +444,15 @@ impl ExecutionPlan for MatchQueryExec { )))?; let stream = stream::once(async move { let _timer = metrics.baseline_metrics.elapsed_compute().timer(); - let segments = load_segments(&ds, &column) - .await? - .ok_or(DataFusionError::Execution(format!( - "No Inverted index found for column {}", - column, - )))?; + let segments = match preset_segments { + Some(segments) => segments, + None => load_segments(&ds, &column) + .await? + .ok_or(DataFusionError::Execution(format!( + "No Inverted index found for column {}", + column, + )))?, + }; let _details = load_segment_details(&ds, &column, &segments).await?; let indices = open_fts_segments(&ds, &column, &segments, &metrics.index_metrics).await?; @@ -431,10 +475,6 @@ impl ExecutionPlan for MatchQueryExec { .record_parts_searched(indices.iter().map(|index| index.partition_count()).sum()); let is_fuzzy = matches!(query.fuzziness, Some(n) if n != 0); - let params = params - .with_fuzziness(query.fuzziness) - .with_max_expansions(query.max_expansions) - .with_prefix_length(query.prefix_length); let first_index = indices.first().ok_or(DataFusionError::Execution(format!( "FTS index for column {} has no segments", column @@ -454,7 +494,10 @@ impl ExecutionPlan for MatchQueryExec { } }; let tokens = collect_query_tokens(&query.terms, &mut tokenizer); - let base_scorer = build_global_bm25_scorer(&indices, &tokens, ¶ms)?; + let base_scorer = match preset_base_scorer { + Some(scorer) => scorer, + None => Arc::new(build_global_bm25_scorer(&indices, &tokens, ¶ms)?), + }; pre_filter.wait_for_ready().await?; let tokens = Arc::new(tokens); @@ -466,7 +509,7 @@ impl ExecutionPlan for MatchQueryExec { query.operator, pre_filter, metrics.clone(), - Arc::new(base_scorer), + base_scorer, ) .await?; scores.iter_mut().for_each(|s| { @@ -510,6 +553,11 @@ pub struct FlatMatchFilterExec { input: Arc, query: MatchQuery, params: FtsSearchParams, + /// Optional pre-resolved segment list. See + /// [`MatchQueryExec::new_with_segments`]. `FlatMatchFilterExec` only + /// uses the first segment's tokenizer, but the full list is preserved so + /// the field round-trips through `with_new_children`. + preset_segments: Option>, metrics: ExecutionPlanMetricsSet, } @@ -557,6 +605,20 @@ impl FlatMatchFilterExec { Ok(default_text_tokenizer()) } + async fn load_tokenizer_from_preset_segments( + dataset: &Dataset, + column: &str, + segments: &[IndexMetadata], + metrics: &IndexMetrics, + ) -> DataFusionResult> { + let index_meta = segments.first().ok_or_else(|| { + DataFusionError::Execution(format!("FTS index for column {} has no segments", column)) + })?; + Ok(open_fts_segment(dataset, column, index_meta, metrics) + .await? + .tokenizer()) + } + pub fn new( input: Arc, dataset: Arc, @@ -568,10 +630,47 @@ impl FlatMatchFilterExec { input, query, params, + preset_segments: None, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// See [`MatchQueryExec::new_with_segments`]. `FlatMatchFilterExec` + /// uses the first segment's tokenizer; the rest are kept for caller-side + /// bookkeeping. + pub fn new_with_segments( + input: Arc, + dataset: Arc, + query: MatchQuery, + params: FtsSearchParams, + segments: Vec, + ) -> Self { + Self { + dataset, + input, + query, + params, + preset_segments: Some(segments), metrics: ExecutionPlanMetricsSet::new(), } } + pub fn query(&self) -> &MatchQuery { + &self.query + } + + pub fn params(&self) -> &FtsSearchParams { + &self.params + } + + pub fn dataset(&self) -> &Arc { + &self.dataset + } + + pub fn preset_segments(&self) -> Option<&[IndexMetadata]> { + self.preset_segments.as_deref() + } + fn find_matches( text_col: &dyn Array, tokenizer: &mut Box, @@ -590,6 +689,7 @@ impl FlatMatchFilterExec { input: SendableRecordBatchStream, dataset: Arc, query: MatchQuery, + preset_segments: Option>, metrics: Arc, ) -> DataFusionResult> + Send> { let column = query @@ -599,7 +699,18 @@ impl FlatMatchFilterExec { "column not set for MatchQuery {}", query.terms )))?; - let mut tokenizer = Self::load_tokenizer(&dataset, column, &metrics.index_metrics).await?; + let mut tokenizer = match preset_segments { + Some(segments) => { + Self::load_tokenizer_from_preset_segments( + &dataset, + column, + &segments, + &metrics.index_metrics, + ) + .await? + } + None => Self::load_tokenizer(&dataset, column, &metrics.index_metrics).await?, + }; let query_tokens = Arc::new(collect_query_tokens(&query.terms, &mut tokenizer)); let column = column.clone(); @@ -658,6 +769,7 @@ impl ExecutionPlan for FlatMatchFilterExec { input, query: self.query.clone(), params: self.params.clone(), + preset_segments: self.preset_segments.clone(), metrics: ExecutionPlanMetricsSet::new(), })) } @@ -669,23 +781,25 @@ impl ExecutionPlan for FlatMatchFilterExec { context: Arc, ) -> DataFusionResult { let query = self.query.clone(); + let preset_segments = self.preset_segments.clone(); let metrics = Arc::new(FtsIndexMetrics::new(&self.metrics, partition)); let metrics_clone = metrics.clone(); let dataset = self.dataset.clone(); let input = self.input.execute(partition, context)?; - let stream = - stream::once(async move { Self::do_filter(input, dataset, query, metrics).await }) - .try_flatten() - .map(move |batch| { - if let Ok(batch) = &batch { - metrics_clone - .baseline_metrics - .record_output(batch.num_rows()); - } - batch - }); + let stream = stream::once(async move { + Self::do_filter(input, dataset, query, preset_segments, metrics).await + }) + .try_flatten() + .map(move |batch| { + if let Ok(batch) = &batch { + metrics_clone + .baseline_metrics + .record_output(batch.num_rows()); + } + batch + }); Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( self.schema(), stream.stream_in_current_span().boxed(), @@ -718,6 +832,12 @@ pub struct FlatMatchQueryExec { query: MatchQuery, params: FtsSearchParams, unindexed_input: Arc, + /// Optional override for the BM25 scorer normally built locally inside + /// `execute()`. See [`MatchQueryExec::with_base_scorer`]. + base_scorer: Option>, + /// Optional pre-resolved segment list. See + /// [`MatchQueryExec::new_with_segments`]. + preset_segments: Option>, properties: Arc, metrics: ExecutionPlanMetricsSet, @@ -764,10 +884,64 @@ impl FlatMatchQueryExec { query, params, unindexed_input, + base_scorer: None, + preset_segments: None, properties, metrics: ExecutionPlanMetricsSet::new(), } } + + /// See [`MatchQueryExec::new_with_segments`]. + pub fn new_with_segments( + dataset: Arc, + query: MatchQuery, + params: FtsSearchParams, + unindexed_input: Arc, + segments: Vec, + ) -> Self { + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(FTS_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + EmissionType::Incremental, + Boundedness::Bounded, + )); + Self { + dataset, + query, + params, + unindexed_input, + base_scorer: None, + preset_segments: Some(segments), + properties, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// Override the local BM25 scorer; see [`MatchQueryExec::with_base_scorer`]. + pub fn with_base_scorer(mut self, scorer: Arc) -> Self { + self.base_scorer = Some(scorer); + self + } + + pub fn query(&self) -> &MatchQuery { + &self.query + } + + pub fn params(&self) -> &FtsSearchParams { + &self.params + } + + pub fn dataset(&self) -> &Arc { + &self.dataset + } + + pub fn base_scorer(&self) -> Option<&Arc> { + self.base_scorer.as_ref() + } + + pub fn preset_segments(&self) -> Option<&[IndexMetadata]> { + self.preset_segments.as_deref() + } } impl ExecutionPlan for FlatMatchQueryExec { @@ -798,6 +972,8 @@ impl ExecutionPlan for FlatMatchQueryExec { query: self.query.clone(), params: self.params.clone(), unindexed_input, + base_scorer: self.base_scorer.clone(), + preset_segments: self.preset_segments.clone(), properties: self.properties.clone(), metrics: ExecutionPlanMetricsSet::new(), })) @@ -811,6 +987,8 @@ impl ExecutionPlan for FlatMatchQueryExec { ) -> DataFusionResult { let query = self.query.clone(); let ds = self.dataset.clone(); + let preset_base_scorer = self.base_scorer.clone(); + let preset_segments = self.preset_segments.clone(); let metrics = Arc::new(FtsIndexMetrics::new(&self.metrics, partition)); let metrics_clone = metrics.clone(); let target_batch_size = context.session_config().batch_size(); @@ -823,7 +1001,10 @@ impl ExecutionPlan for FlatMatchQueryExec { document_input(self.unindexed_input.execute(partition, context)?, &column)?; let stream = stream::once(async move { - let segments = load_segments(&ds, &column).await?; + let segments = match preset_segments { + Some(segments) => Some(segments), + None => load_segments(&ds, &column).await?, + }; let (tokenizer, base_scorer) = match segments { Some(segments) => { let _details = load_segment_details(&ds, &column, &segments).await?; @@ -836,12 +1017,23 @@ impl ExecutionPlan for FlatMatchQueryExec { format!("FTS index for column {} has no segments", column), ))?; let mut tokenizer = first_index.tokenizer(); - let query_tokens = collect_query_tokens(&query.terms, &mut tokenizer); - let base_scorer = - build_global_bm25_scorer(&indices, &query_tokens, &FtsSearchParams::new())?; + let base_scorer = match preset_base_scorer { + Some(scorer) => (*scorer).clone(), + None => { + let query_tokens = collect_query_tokens(&query.terms, &mut tokenizer); + build_global_bm25_scorer( + &indices, + &query_tokens, + &FtsSearchParams::new(), + )? + } + }; (tokenizer, Some(base_scorer)) } - None => (default_text_tokenizer(), None), + None => ( + default_text_tokenizer(), + preset_base_scorer.map(|s| (*s).clone()), + ), }; flat_bm25_search_stream( @@ -890,6 +1082,12 @@ pub struct PhraseQueryExec { query: PhraseQuery, params: FtsSearchParams, prefilter_source: PreFilterSource, + /// Optional override for the BM25 scorer normally built locally inside + /// `execute()`. See [`MatchQueryExec::with_base_scorer`]. + base_scorer: Option>, + /// Optional pre-resolved segment list. See + /// [`MatchQueryExec::new_with_segments`]. + preset_segments: Option>, properties: Arc, metrics: ExecutionPlanMetricsSet, } @@ -937,10 +1135,70 @@ impl PhraseQueryExec { query, params, prefilter_source, + base_scorer: None, + preset_segments: None, + properties, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// See [`MatchQueryExec::new_with_segments`]. + pub fn new_with_segments( + dataset: Arc, + query: PhraseQuery, + mut params: FtsSearchParams, + prefilter_source: PreFilterSource, + segments: Vec, + ) -> Self { + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(FTS_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + EmissionType::Final, + Boundedness::Bounded, + )); + params = params.with_phrase_slop(Some(query.slop)); + + Self { + dataset, + query, + params, + prefilter_source, + base_scorer: None, + preset_segments: Some(segments), properties, metrics: ExecutionPlanMetricsSet::new(), } } + + /// Override the local BM25 scorer; see [`MatchQueryExec::with_base_scorer`]. + pub fn with_base_scorer(mut self, scorer: Arc) -> Self { + self.base_scorer = Some(scorer); + self + } + + pub fn query(&self) -> &PhraseQuery { + &self.query + } + + pub fn params(&self) -> &FtsSearchParams { + &self.params + } + + pub fn dataset(&self) -> &Arc { + &self.dataset + } + + pub fn prefilter_source(&self) -> &PreFilterSource { + &self.prefilter_source + } + + pub fn base_scorer(&self) -> Option<&Arc> { + self.base_scorer.as_ref() + } + + pub fn preset_segments(&self) -> Option<&[IndexMetadata]> { + self.preset_segments.as_deref() + } } impl ExecutionPlan for PhraseQueryExec { @@ -978,6 +1236,8 @@ impl ExecutionPlan for PhraseQueryExec { query: self.query.clone(), params: self.params.clone(), prefilter_source: PreFilterSource::None, + base_scorer: self.base_scorer.clone(), + preset_segments: self.preset_segments.clone(), properties: self.properties.clone(), metrics: ExecutionPlanMetricsSet::new(), }, @@ -1001,6 +1261,8 @@ impl ExecutionPlan for PhraseQueryExec { query: self.query.clone(), params: self.params.clone(), prefilter_source, + base_scorer: self.base_scorer.clone(), + preset_segments: self.preset_segments.clone(), properties: self.properties.clone(), metrics: ExecutionPlanMetricsSet::new(), } @@ -1024,6 +1286,8 @@ impl ExecutionPlan for PhraseQueryExec { let params = self.params.clone(); let ds = self.dataset.clone(); let prefilter_source = self.prefilter_source.clone(); + let preset_base_scorer = self.base_scorer.clone(); + let preset_segments = self.preset_segments.clone(); let metrics = Arc::new(FtsIndexMetrics::new(&self.metrics, partition)); let stream = stream::once(async move { let _timer = metrics.baseline_metrics.elapsed_compute().timer(); @@ -1031,12 +1295,15 @@ impl ExecutionPlan for PhraseQueryExec { "column not set for PhraseQuery {}", query.terms )))?; - let segments = load_segments(&ds, &column) - .await? - .ok_or(DataFusionError::Execution(format!( - "No Inverted index found for column {}", - column, - )))?; + let segments = match preset_segments { + Some(segments) => segments, + None => load_segments(&ds, &column) + .await? + .ok_or(DataFusionError::Execution(format!( + "No Inverted index found for column {}", + column, + )))?, + }; let _details = load_segment_details(&ds, &column, &segments).await?; let indices = open_fts_segments(&ds, &column, &segments, &metrics.index_metrics).await?; @@ -1064,7 +1331,10 @@ impl ExecutionPlan for PhraseQueryExec { )))?; let mut tokenizer = first_index.tokenizer(); let tokens = collect_query_tokens(&query.terms, &mut tokenizer); - let base_scorer = build_global_bm25_scorer(&indices, &tokens, ¶ms)?; + let base_scorer = match preset_base_scorer { + Some(scorer) => scorer, + None => Arc::new(build_global_bm25_scorer(&indices, &tokens, ¶ms)?), + }; pre_filter.wait_for_ready().await?; let tokens = Arc::new(tokens); @@ -1076,7 +1346,7 @@ impl ExecutionPlan for PhraseQueryExec { lance_index::scalar::inverted::query::Operator::And, pre_filter, metrics.clone(), - Arc::new(base_scorer), + base_scorer, ) .await?; metrics.baseline_metrics.record_output(doc_ids.len()); @@ -1162,6 +1432,22 @@ impl BoostQueryExec { metrics: ExecutionPlanMetricsSet::new(), } } + + pub fn query(&self) -> &BoostQuery { + &self.query + } + + pub fn params(&self) -> &FtsSearchParams { + &self.params + } + + pub fn positive(&self) -> &Arc { + &self.positive + } + + pub fn negative(&self) -> &Arc { + &self.negative + } } impl ExecutionPlan for BoostQueryExec { @@ -1279,6 +1565,74 @@ impl ExecutionPlan for BoostQueryExec { } } +/// Identifies which clause of a [`BooleanQuery`] a list of child execs +/// belongs to. Used by [`build_boolean_query_children`] to pick the +/// right exec shape per slot. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BoolSlot { + Should, + Must, + MustNot, +} + +/// Combine N children into the per-slot exec shape that +/// [`BooleanQueryExec::new`] expects. Used by `Scanner::plan_fts` to +/// assemble the per-slot exec shape: +/// +/// | slot | 0 children | 1 child | N children | +/// |-----------|----------------------------|---------------|-----------------------------------------------------| +/// | Should | `Some(EmptyExec(FTS))` | `Some(child)` | `Some(Union -> Repartition(RoundRobinBatch(1)))` | +/// | Must | `None` | `Some(child)` | `Some(chained HashJoin on row_id)` | +/// | MustNot | `Some(EmptyExec(FTS))` | `Some(child)` | `Some(Union -> Repartition(RoundRobinBatch(1)))` | +/// +/// Errors only on internal invariants (HashJoin construction, Schema +/// lookups). Returns `Result>>` so the +/// `Must` slot's `None` case is naturally expressible. +pub fn build_boolean_query_children( + slot: BoolSlot, + mut children: Vec>, +) -> Result>> { + match slot { + BoolSlot::Should | BoolSlot::MustNot => { + if children.is_empty() { + Ok(Some(Arc::new(EmptyExec::new(FTS_SCHEMA.clone())))) + } else if children.len() == 1 { + Ok(Some(children.pop().unwrap())) + } else { + let unioned = UnionExec::try_new(children)?; + Ok(Some(Arc::new(RepartitionExec::try_new( + unioned, + Partitioning::RoundRobinBatch(1), + )?))) + } + } + BoolSlot::Must => { + let mut joined: Option> = None; + for plan in children { + if let Some(left) = joined { + joined = Some(Arc::new(HashJoinExec::try_new( + left, + plan, + vec![( + Arc::new(Column::new_with_schema(ROW_ID, &FTS_SCHEMA)?), + Arc::new(Column::new_with_schema(ROW_ID, &FTS_SCHEMA)?), + )], + None, + &datafusion_expr::JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?) as _); + } else { + joined = Some(plan); + } + } + Ok(joined) + } + } +} + #[derive(Debug)] pub struct BooleanQueryExec { query: BooleanQuery, @@ -1342,6 +1696,26 @@ impl BooleanQueryExec { metrics: ExecutionPlanMetricsSet::new(), } } + + pub fn query(&self) -> &BooleanQuery { + &self.query + } + + pub fn params(&self) -> &FtsSearchParams { + &self.params + } + + pub fn should(&self) -> &Arc { + &self.should + } + + pub fn must(&self) -> Option<&Arc> { + self.must.as_ref() + } + + pub fn must_not(&self) -> &Arc { + &self.must_not + } } impl ExecutionPlan for BooleanQueryExec { @@ -1531,23 +1905,34 @@ mod tests { use std::sync::{Arc, Mutex}; use crate::index::DatasetIndexExt; + use arrow_array::{ + ArrayRef, Float32Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray, + UInt64Array, + }; + use arrow_schema::DataType; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::{execution::TaskContext, physical_plan::ExecutionPlan}; + use futures::TryStreamExt; + use lance_core::ROW_ID; use lance_datafusion::datagen::DatafusionDatagenExt; use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts}; use lance_datafusion::utils::PARTITIONS_SEARCHED_METRIC; use lance_datagen::{BatchCount, ByteCount, RowCount}; use lance_index::metrics::NoOpMetricsCollector; - use lance_index::scalar::inverted::InvertedIndex; - use lance_index::scalar::inverted::Language; use lance_index::scalar::inverted::query::{ BooleanQuery, BoostQuery, FtsQuery, FtsSearchParams, MatchQuery, Occur, Operator, PhraseQuery, collect_query_tokens, has_query_token, }; + use lance_index::scalar::inverted::{ + FTS_SCHEMA, InvertedIndex, Language, SCORE_COL, build_global_bm25_scorer, + }; use lance_index::scalar::{FullTextSearchQuery, InvertedIndexParams}; use lance_index::{IndexCriteria, IndexType}; + use lance_table::format::IndexMetadata; use crate::{ + Dataset, + dataset::WriteParams, dataset::transaction::{Operation, TransactionBuilder}, index::DatasetIndexInternalExt, io::exec::PreFilterSource, @@ -1555,9 +1940,14 @@ mod tests { }; use super::{ - BoostQueryExec, FlatMatchFilterExec, FlatMatchQueryExec, MatchQueryExec, PhraseQueryExec, + BoolSlot, BoostQueryExec, FlatMatchFilterExec, FlatMatchQueryExec, MatchQueryExec, + PhraseQueryExec, build_boolean_query_children, open_fts_segments, }; use crate::io::exec::utils::IndexMetrics; + use datafusion::physical_plan::empty::EmptyExec; + use datafusion::physical_plan::repartition::RepartitionExec; + use datafusion::physical_plan::union::UnionExec; + use datafusion_physical_plan::joins::HashJoinExec; #[derive(Default)] struct StatsHolder { @@ -1850,4 +2240,321 @@ mod tests { "BooleanQuery metrics missing partitions_searched: {boolean_line}" ); } + + #[tokio::test] + async fn test_match_query_exec_with_base_scorer_matches_baseline() { + let test_dir = tempfile::tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + // Skewed term distributions across two fragments — "lance" is common in + // segment 0 and rare in segment 1 — so any local-IDF computation will + // disagree with the global-IDF baseline. That makes the test sensitive + // to a bug where `with_base_scorer` is silently ignored. + let batches = vec![ + RecordBatch::try_from_iter(vec![ + ("id", Arc::new(Int32Array::from(vec![0, 1])) as ArrayRef), + ( + "text", + Arc::new(StringArray::from(vec![ + Some("lance database"), + Some("lance search"), + ])) as ArrayRef, + ), + ]) + .unwrap(), + RecordBatch::try_from_iter(vec![ + ("id", Arc::new(Int32Array::from(vec![2, 3])) as ArrayRef), + ( + "text", + Arc::new(StringArray::from(vec![ + Some("alpha beta"), + Some("gamma lance"), + ])) as ArrayRef, + ), + ]) + .unwrap(), + ]; + let schema = batches[0].schema(); + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); + let mut ds = Dataset::write( + reader, + test_uri, + Some(WriteParams { + max_rows_per_file: 2, + max_rows_per_group: 2, + ..Default::default() + }), + ) + .await + .unwrap(); + + let params = InvertedIndexParams::new("simple".to_string(), Language::English) + .with_position(false) + .lower_case(true) + .stem(false) + .remove_stop_words(false) + .ascii_folding(false) + .max_token_length(None); + let fragment_ids = ds + .get_fragments() + .iter() + .map(|fragment| fragment.id() as u32) + .collect::>(); + assert!( + fragment_ids.len() >= 2, + "test setup should produce >= 2 fragments, got {}", + fragment_ids.len() + ); + + let mut metadatas = Vec::::with_capacity(fragment_ids.len()); + for fragment_id in fragment_ids { + let mut builder = ds + .create_index_builder(&["text"], IndexType::Inverted, ¶ms) + .name("seg_fts".to_string()) + .fragments(vec![fragment_id]); + metadatas.push(builder.execute_uncommitted().await.unwrap()); + } + let segments = ds + .create_index_segment_builder() + .with_index_type(IndexType::Inverted) + .with_segments(metadatas.clone()) + .build_all() + .await + .unwrap(); + ds.commit_existing_index_segments("seg_fts", "text", segments) + .await + .unwrap(); + assert_eq!( + ds.load_indices_by_name("seg_fts").await.unwrap().len(), + metadatas.len(), + "expected one committed segment per fragment" + ); + + let dataset = Arc::new(ds); + let query = MatchQuery::new("lance".to_string()).with_column(Some("text".to_string())); + let search_params = FtsSearchParams::default().with_limit(Some(10)); + + // Baseline: the existing path that builds the global scorer locally. + let baseline_exec = MatchQueryExec::new( + dataset.clone(), + query.clone(), + search_params.clone(), + PreFilterSource::None, + ); + let baseline_batches: Vec = baseline_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap() + .try_collect() + .await + .unwrap(); + let baseline = concat_score_batches(&baseline_batches); + assert!( + !baseline.is_empty(), + "baseline should return at least one hit" + ); + + // Override: build the global scorer manually via the public helper, then + // construct the exec with the preset segments and the preset scorer. + let preset_segments = crate::index::scalar::inverted::load_segments(&dataset, "text") + .await + .unwrap() + .expect("FTS index just created"); + let metrics_set = ExecutionPlanMetricsSet::new(); + let metrics = IndexMetrics::new(&metrics_set, 0); + let indices = open_fts_segments(&dataset, "text", &preset_segments, &metrics) + .await + .unwrap(); + assert!( + indices.len() >= 2, + "expected >= 2 segments to exercise global IDF, got {}", + indices.len() + ); + let mut tokenizer = indices[0].tokenizer(); + let tokens = collect_query_tokens(&query.terms, &mut tokenizer); + let global_scorer = + Arc::new(build_global_bm25_scorer(&indices, &tokens, &search_params).unwrap()); + + let override_exec = MatchQueryExec::new_with_segments( + dataset.clone(), + query.clone(), + search_params.clone(), + PreFilterSource::None, + preset_segments, + ) + .with_base_scorer(global_scorer); + let override_batches: Vec = override_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap() + .try_collect() + .await + .unwrap(); + let overridden = concat_score_batches(&override_batches); + + assert_eq!( + baseline.len(), + overridden.len(), + "row count differs: baseline={}, override={}", + baseline.len(), + overridden.len() + ); + for (i, (b, o)) in baseline.iter().zip(overridden.iter()).enumerate() { + assert_eq!( + b.0, o.0, + "row id mismatch at rank {}: baseline={}, override={}", + i, b.0, o.0 + ); + assert_eq!( + b.1, o.1, + "score mismatch at rank {} (row id {}): baseline={}, override={}", + i, b.0, b.1, o.1 + ); + } + + // Sanity check on FTS schema before extracting columns above. + for batch in baseline_batches.iter().chain(override_batches.iter()) { + assert!( + batch.column_by_name(ROW_ID).is_some(), + "FTS output is expected to carry a row id column" + ); + assert_eq!( + batch.column_by_name(SCORE_COL).unwrap().data_type(), + &DataType::Float32, + "FTS score column should be Float32" + ); + } + + // Locally-bound helper: collect (row_id, score) pairs sorted by score desc. + fn concat_score_batches(batches: &[RecordBatch]) -> Vec<(u64, f32)> { + let mut out: Vec<(u64, f32)> = Vec::new(); + for batch in batches { + let row_ids = batch + .column_by_name(ROW_ID) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let scores = batch + .column_by_name(SCORE_COL) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + out.push((row_ids.value(i), scores.value(i))); + } + } + // Stable order for diffing — descending score, ties broken by row id. + out.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0))); + out + } + } + + fn empty_fts_child() -> Arc { + Arc::new(EmptyExec::new(FTS_SCHEMA.clone())) + } + + #[test] + fn build_boolean_should_empty_returns_empty_exec() { + let plan = build_boolean_query_children(BoolSlot::Should, vec![]) + .unwrap() + .expect("Should slot always returns Some"); + assert!( + plan.as_any().downcast_ref::().is_some(), + "expected EmptyExec for empty Should slot, got {plan:?}" + ); + } + + #[test] + fn build_boolean_should_single_child_passthrough() { + let child = empty_fts_child(); + let child_ptr = Arc::as_ptr(&child); + let plan = build_boolean_query_children(BoolSlot::Should, vec![child]) + .unwrap() + .expect("Should slot always returns Some"); + assert_eq!( + Arc::as_ptr(&plan), + child_ptr, + "single-child Should should return the child unchanged" + ); + } + + #[test] + fn build_boolean_should_multi_child_union_repartition() { + let plan = build_boolean_query_children( + BoolSlot::Should, + vec![empty_fts_child(), empty_fts_child()], + ) + .unwrap() + .expect("Should slot always returns Some"); + let repartition = plan + .as_any() + .downcast_ref::() + .expect("multi-child Should should be wrapped in RepartitionExec"); + let inner = repartition + .input() + .as_any() + .downcast_ref::() + .expect("RepartitionExec should wrap a UnionExec"); + assert_eq!(inner.children().len(), 2); + } + + #[test] + fn build_boolean_must_empty_returns_none() { + let plan = build_boolean_query_children(BoolSlot::Must, vec![]).unwrap(); + assert!(plan.is_none(), "empty Must slot should return None"); + } + + #[test] + fn build_boolean_must_single_child_passthrough_some() { + let child = empty_fts_child(); + let child_ptr = Arc::as_ptr(&child); + let plan = build_boolean_query_children(BoolSlot::Must, vec![child]) + .unwrap() + .expect("single-child Must should be Some"); + assert_eq!( + Arc::as_ptr(&plan), + child_ptr, + "single-child Must should return the child unchanged" + ); + } + + #[test] + fn build_boolean_must_multi_child_chained_hashjoin() { + let children = vec![empty_fts_child(), empty_fts_child(), empty_fts_child()]; + let n = children.len(); + let plan = build_boolean_query_children(BoolSlot::Must, children) + .unwrap() + .expect("multi-child Must should be Some"); + + // Walk the left spine: each layer is a HashJoinExec whose left child is + // either another HashJoinExec or the original leaf. With N children + // there are N-1 joins. + let mut joins = 0usize; + let mut current: Arc = plan; + while let Some(join) = current.clone().as_any().downcast_ref::() { + joins += 1; + current = join.children()[0].clone(); + } + assert_eq!(joins, n - 1, "expected {} joins for {n} children", n - 1); + } + + #[test] + fn build_boolean_must_not_multi_child_union_repartition() { + let plan = build_boolean_query_children( + BoolSlot::MustNot, + vec![empty_fts_child(), empty_fts_child()], + ) + .unwrap() + .expect("MustNot slot always returns Some"); + let repartition = plan + .as_any() + .downcast_ref::() + .expect("multi-child MustNot should be wrapped in RepartitionExec"); + let inner = repartition + .input() + .as_any() + .downcast_ref::() + .expect("RepartitionExec should wrap a UnionExec"); + assert_eq!(inner.children().len(), 2); + } }