diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 693c2b3cfe7..dfa4e2d419f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -23,9 +23,6 @@ jobs: sudo apt install -y protobuf-compiler libssl-dev - name: Run cargo fmt run: cargo fmt --check - - name: Run clippy - #run: cargo clippy -- --deny "warnings" - run: cargo clippy - name: Run tests run: | cargo build --all-features diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index ae30698f9e1..6262450fb27 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -23,9 +23,9 @@ use datafusion::execution::{ context::SessionState, runtime_env::{RuntimeConfig, RuntimeEnv}, }; -use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::{ - limit::GlobalLimitExec, ExecutionPlan, PhysicalExpr, SendableRecordBatchStream, + filter::FilterExec, limit::GlobalLimitExec, ExecutionPlan, PhysicalExpr, + SendableRecordBatchStream, }; use datafusion::prelude::*; use futures::stream::{Stream, StreamExt}; @@ -36,9 +36,7 @@ use crate::datafusion::physical_expr::column_names_in_expr; use crate::datatypes::Schema; use crate::format::Index; use crate::index::vector::{MetricType, Query}; -use crate::io::exec::{ - KNNFlatExec, KNNIndexExec, LanceScanExec, LocalTakeExec, ProjectionExec, TakeExec, -}; +use crate::io::exec::{KNNFlatExec, KNNIndexExec, LanceScanExec, ProjectionExec, TakeExec}; use crate::utils::sql::parse_sql_filter; use crate::{Error, Result}; @@ -273,11 +271,11 @@ impl Scanner { } } - let knn_node = self.ann(q, &index); // score, _rowid + let knn_node = self.ann(q, &index)?; // score, _rowid let with_vector = self.dataset.schema().project(&[&q.column])?; let knn_node_with_vector = self.take(knn_node, &with_vector)?; let knn_node = if q.refine_factor.is_some() { - self.flat_knn(knn_node_with_vector, q) + self.flat_knn(knn_node_with_vector, q)? 
} else { knn_node_with_vector }; // vector, score, _rowid @@ -290,7 +288,7 @@ impl Scanner { let vector_scan_projection = Arc::new(self.dataset.schema().project(&[&q.column]).unwrap()); let scan_node = self.scan(true, vector_scan_projection); - let knn_node = self.flat_knn(scan_node, q); + let knn_node = self.flat_knn(scan_node, q)?; let knn_node = filter_expr .map(|f| self.filter_knn(knn_node.clone(), f)) @@ -308,7 +306,7 @@ impl Scanner { )?, ); let scan = self.scan(true, filter_schema); - self.filter_node(filter, scan, true, None)? + self.filter_node(filter, scan)? } else { self.scan(with_row_id, Arc::new(self.projections.clone())) }; @@ -346,12 +344,7 @@ impl Scanner { knn_node, Arc::new(filter_projection), )?); - self.filter_node( - filter_expression, - take_node, - false, - Some(Arc::new(self.vector_search_schema()?)), - ) + self.filter_node(filter_expression, take_node) } /// Create an Execution plan with a scan node @@ -367,17 +360,17 @@ impl Scanner { } /// Add a knn search node to the input plan - fn flat_knn(&self, input: Arc, q: &Query) -> Arc { - Arc::new(KNNFlatExec::new(input, q.clone())) + fn flat_knn(&self, input: Arc, q: &Query) -> Result> { + Ok(Arc::new(KNNFlatExec::try_new(input, q.clone())?)) } /// Create an Execution plan to do indexed ANN search - fn ann(&self, q: &Query, index: &Index) -> Arc { - Arc::new(KNNIndexExec::new( + fn ann(&self, q: &Query, index: &Index) -> Result> { + Ok(Arc::new(KNNIndexExec::try_new( self.dataset.clone(), &index.uuid.to_string(), q, - )) + )?)) } /// Take row indices produced by input plan from the dataset (with projection) @@ -406,18 +399,14 @@ impl Scanner { &self, filter: Arc, input: Arc, - drop_row_id: bool, - ann_schema: Option>, ) -> Result> { let filter_node = Arc::new(FilterExec::try_new(filter, input)?); let output_schema = self.scanner_output_schema()?; - Ok(Arc::new(LocalTakeExec::new( - filter_node, + Ok(Arc::new(TakeExec::try_new( self.dataset.clone(), + filter_node, output_schema, - ann_schema, - 
drop_row_id, - ))) + )?)) } } @@ -451,16 +440,23 @@ impl Stream for RecordBatchStream { #[cfg(test)] mod test { + use std::collections::BTreeSet; use std::path::PathBuf; use super::*; + use arrow::array::as_primitive_array; use arrow::compute::concat_batches; - use arrow_array::{ArrayRef, Int32Array, Int64Array, RecordBatchReader, StringArray}; + use arrow::datatypes::Int32Type; + use arrow_array::{ + ArrayRef, FixedSizeListArray, Int32Array, Int64Array, RecordBatchReader, StringArray, + }; use arrow_schema::DataType; use futures::TryStreamExt; use tempfile::tempdir; + use crate::index::vector::VectorIndexParams; + use crate::index::IndexType; use crate::{arrow::RecordBatchBuffer, dataset::WriteParams}; #[tokio::test] @@ -620,4 +616,214 @@ mod test { .unwrap(); expected_batches } + + async fn create_vector_dataset(path: &str, build_index: bool) -> Dataset { + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("i", DataType::Int32, true), + ArrowField::new("s", DataType::Utf8, true), + ArrowField::new( + "vec", + DataType::FixedSizeList( + Box::new(ArrowField::new("item", DataType::Float32, true)), + 32, + ), + true, + ), + ])); + + let batches = RecordBatchBuffer::new( + (0..5) + .map(|i| { + let vector_values: Float32Array = (0..32 * 80).map(|v| v as f32).collect(); + let vectors = FixedSizeListArray::try_new(&vector_values, 32).unwrap(); + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(i * 80..(i + 1) * 80)), + Arc::new(StringArray::from_iter_values( + (i * 80..(i + 1) * 80).map(|v| format!("s-{}", v)), + )), + Arc::new(vectors), + ], + ) + .unwrap() + }) + .collect(), + ); + + let mut params = WriteParams::default(); + params.max_rows_per_group = 10; + let mut reader: Box = Box::new(batches); + + let dataset = Dataset::write(&mut reader, path, Some(params)) + .await + .unwrap(); + + if build_index { + let params = VectorIndexParams::ivf_pq(2, 8, 2, false, MetricType::L2, 2); + dataset + .create_index( + 
&["vec"], + IndexType::Vector, + Some("idx".to_string()), + &params, + true, + ) + .await + .unwrap(); + } + + Dataset::open(path).await.unwrap() + } + + #[tokio::test] + async fn test_knn_nodes() { + for build_index in &[true, false] { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = create_vector_dataset(test_uri, *build_index).await; + let mut scan = dataset.scan(); + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + scan.nearest("vec", &key, 5).unwrap(); + scan.refine(5); + + let results = scan + .try_into_stream() + .await + .unwrap() + .try_collect::<Vec<_>>() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + + assert_eq!(batch.num_rows(), 5); + assert_eq!( + batch.schema().as_ref(), + &ArrowSchema::new(vec![ + ArrowField::new("i", DataType::Int32, true), + ArrowField::new("s", DataType::Utf8, true), + ArrowField::new( + "vec", + DataType::FixedSizeList( + Box::new(ArrowField::new("item", DataType::Float32, true)), + 32, + ), + true, + ), + ArrowField::new("score", DataType::Float32, false), + ]) + ); + + let expected_i = BTreeSet::from_iter(vec![1, 81, 161, 241, 321]); + let column_i = batch.column_by_name("i").unwrap(); + let actual_i: BTreeSet<i32> = as_primitive_array::<Int32Type>(column_i.as_ref()) + .values() + .iter() + .copied() + .collect(); + assert_eq!(expected_i, actual_i); + } + } + + #[tokio::test] + async fn test_knn_with_filter() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = create_vector_dataset(test_uri, true).await; + let mut scan = dataset.scan(); + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + scan.nearest("vec", &key, 5).unwrap(); + scan.filter("i > 100").unwrap(); + scan.project(&["i"]).unwrap(); + scan.refine(5); + + let results = scan + .try_into_stream() + .await + .unwrap() + .try_collect::<Vec<_>>() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + 
assert_eq!(batch.num_rows(), 3); + assert_eq!( + batch.schema().as_ref(), + &ArrowSchema::new(vec![ + ArrowField::new("i", DataType::Int32, true), + ArrowField::new( + "vec", + DataType::FixedSizeList( + Box::new(ArrowField::new("item", DataType::Float32, true)), + 32, + ), + true, + ), + ArrowField::new("score", DataType::Float32, false), + ]) + ); + + let expected_i = BTreeSet::from_iter(vec![161, 241, 321]); + let column_i = batch.column_by_name("i").unwrap(); + let actual_i: BTreeSet = as_primitive_array::(column_i.as_ref()) + .values() + .iter() + .copied() + .collect(); + assert_eq!(expected_i, actual_i); + } + + #[tokio::test] + async fn test_refine_factor() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let dataset = create_vector_dataset(test_uri, true).await; + let mut scan = dataset.scan(); + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + scan.nearest("vec", &key, 5).unwrap(); + scan.refine(5); + + let results = scan + .try_into_stream() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + + assert_eq!(batch.num_rows(), 5); + assert_eq!( + batch.schema().as_ref(), + &ArrowSchema::new(vec![ + ArrowField::new("i", DataType::Int32, true), + ArrowField::new("s", DataType::Utf8, true), + ArrowField::new( + "vec", + DataType::FixedSizeList( + Box::new(ArrowField::new("item", DataType::Float32, true)), + 32, + ), + true, + ), + ArrowField::new("score", DataType::Float32, false), + ]) + ); + + let expected_i = BTreeSet::from_iter(vec![1, 81, 161, 241, 321]); + let column_i = batch.column_by_name("i").unwrap(); + let actual_i: BTreeSet = as_primitive_array::(column_i.as_ref()) + .values() + .iter() + .copied() + .collect(); + assert_eq!(expected_i, actual_i); + } } diff --git a/rust/src/index/vector.rs b/rust/src/index/vector.rs index f53495ad63e..70f555b0f19 100644 --- a/rust/src/index/vector.rs +++ b/rust/src/index/vector.rs 
@@ -53,7 +53,7 @@ const MAX_ITERATIONS: usize = 50; /// Maximum number of iterations for OPQ. /// See OPQ paper for details. const MAX_OPQ_ITERATIONS: usize = 100; -const SCORE_COL: &str = "score"; +pub(crate) const SCORE_COL: &str = "score"; const INDEX_FILE_NAME: &str = "index.idx"; /// Query parameters for the vector indices diff --git a/rust/src/index/vector/flat.rs b/rust/src/index/vector/flat.rs index 8d6cbe8e8d4..b4aba8a3dd9 100644 --- a/rust/src/index/vector/flat.rs +++ b/rust/src/index/vector/flat.rs @@ -26,7 +26,7 @@ use arrow_select::{concat::concat_batches, take::take}; use async_trait::async_trait; use futures::stream::{repeat_with, Stream, StreamExt, TryStreamExt}; -use super::{Query, VectorIndex}; +use super::{Query, VectorIndex, SCORE_COL}; use crate::arrow::*; use crate::dataset::Dataset; use crate::io::object_reader::ObjectReader; @@ -63,16 +63,14 @@ pub async fn flat_search( stream: impl Stream>, query: &Query, ) -> Result { - const SCORE_COLUMN: &str = "score"; - let batches = stream .zip(repeat_with(|| query.metric_type)) .map(|(batch, mt)| async move { let k = query.key.clone(); let mut batch = batch?; - if batch.column_by_name(SCORE_COLUMN).is_some() { + if batch.column_by_name(SCORE_COL).is_some() { // Ignore the score calculated from inner vector index. 
- batch = batch.drop_column(SCORE_COLUMN)?; + batch = batch.drop_column(SCORE_COL)?; } let vectors = batch .column_by_name(&query.column) @@ -88,10 +86,8 @@ pub async fn flat_search( // TODO: use heap let indices = sort_to_indices(&scores, None, Some(query.k))?; - let batch_with_score = batch.try_with_column( - ArrowField::new(SCORE_COLUMN, DataType::Float32, false), - scores, - )?; + let batch_with_score = batch + .try_with_column(ArrowField::new(SCORE_COL, DataType::Float32, false), scores)?; let struct_arr = StructArray::from(batch_with_score); let selected_arr = take(&struct_arr, &indices, None)?; Ok::(as_struct_array(&selected_arr).into()) @@ -100,7 +96,7 @@ pub async fn flat_search( .try_collect::>() .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; - let scores = batch.column_by_name(SCORE_COLUMN).unwrap(); + let scores = batch.column_by_name(SCORE_COL).unwrap(); let indices = sort_to_indices(scores, None, Some(query.k))?; let struct_arr = StructArray::from(batch); diff --git a/rust/src/io/exec.rs b/rust/src/io/exec.rs index c497cd4595c..aeda258d257 100644 --- a/rust/src/io/exec.rs +++ b/rust/src/io/exec.rs @@ -17,9 +17,11 @@ mod planner; mod projection; mod scan; mod take; +#[cfg(test)] +pub(crate) mod testing; pub use knn::*; pub use planner::Planner; pub(crate) use projection::ProjectionExec; pub use scan::LanceScanExec; -pub(crate) use take::{LocalTakeExec, TakeExec}; +pub(crate) use take::TakeExec; diff --git a/rust/src/io/exec/knn.rs b/rust/src/io/exec/knn.rs index 6e398998ad1..2d1ac120dd8 100644 --- a/rust/src/io/exec/knn.rs +++ b/rust/src/io/exec/knn.rs @@ -31,7 +31,8 @@ use tokio::task::JoinHandle; use crate::dataset::scanner::RecordBatchStream; use crate::dataset::{Dataset, ROW_ID}; use crate::index::vector::flat::flat_search; -use crate::index::vector::{open_index, Query}; +use crate::index::vector::{open_index, Query, SCORE_COL}; +use crate::{Error, Result}; /// KNN node for post-filtering. 
pub struct KNNFlatStream { @@ -91,9 +92,17 @@ impl DFRecordBatchStream for KNNFlatStream { } } -/// Physical [ExecutionPlan] for Flat KNN node. +/// [ExecutionPlan] for Flat KNN (bruteforce) search. +/// +/// Preconditions: +/// - `input` schema must contains `query.column`, +/// - The column must be a vector. +/// - `input` schema does not have "score" column. pub struct KNNFlatExec { + /// Input node. input: Arc, + + /// The query to execute. query: Query, } @@ -108,8 +117,29 @@ impl std::fmt::Debug for KNNFlatExec { } impl KNNFlatExec { - pub fn new(input: Arc, query: Query) -> Self { - Self { input, query } + /// Create a new [KNNFlatExec] node. + /// + /// Returns an error if the preconditions are not met. + pub fn try_new(input: Arc, query: Query) -> Result { + let schema = input.schema(); + let field = schema.field_with_name(&query.column).map_err(|_| { + Error::IO(format!( + "KNNFlatExec node: query column {} not found in input schema", + query.column + )) + })?; + let is_vector = match field.data_type() { + DataType::FixedSizeList(item, _) => item.as_ref().data_type() == &DataType::Float32, + _ => false, + }; + if !is_vector { + return Err(Error::IO(format!( + "KNNFlatExec node: query column {} is not a vector", + query.column + ))); + }; + + Ok(Self { input, query }) } } @@ -118,11 +148,16 @@ impl ExecutionPlan for KNNFlatExec { self } + /// Flat KNN inherits the schema from input node, and add one score column. 
fn schema(&self) -> arrow_schema::SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("score", DataType::Float32, false), - Field::new(ROW_ID, DataType::UInt16, false), - ])) + let input_schema = self.input.schema(); + let mut fields = input_schema.fields().to_vec(); + fields.push(Field::new(SCORE_COL, DataType::Float32, false)); + + Arc::new(Schema::new_with_metadata( + fields, + input_schema.metadata().clone(), + )) } fn output_partitioning(&self) -> Partitioning { @@ -154,7 +189,10 @@ impl ExecutionPlan for KNNFlatExec { } fn statistics(&self) -> Statistics { - todo!() + Statistics { + num_rows: Some(self.query.k as usize), + ..Default::default() + } } } @@ -212,7 +250,10 @@ impl KNNIndexStream { impl DFRecordBatchStream for KNNIndexStream { fn schema(&self) -> arrow_schema::SchemaRef { - todo!() + Arc::new(Schema::new(vec![ + Field::new(SCORE_COL, DataType::Float32, false), + Field::new(ROW_ID, DataType::UInt16, false), + ])) } } @@ -226,8 +267,11 @@ impl Stream for KNNIndexStream { /// [ExecutionPlan] for KNNIndex node. pub struct KNNIndexExec { + /// Dataset to read from. dataset: Arc, + /// The UUID of the index. index_name: String, + /// The vector query to execute. query: Query, } @@ -242,12 +286,21 @@ impl std::fmt::Debug for KNNIndexExec { } impl KNNIndexExec { - pub fn new(dataset: Arc, index_name: &str, query: &Query) -> Self { - Self { + /// Create a new [KNNIndexExec]. 
+ pub fn try_new(dataset: Arc, index_name: &str, query: &Query) -> Result { + let schema = dataset.schema(); + if schema.field(query.column.as_str()).is_none() { + return Err(Error::IO(format!( + "KNNIndexExec node: query column {} does not exist in dataset.", + query.column + ))); + }; + + Ok(Self { dataset, index_name: index_name.to_string(), query: query.clone(), - } + }) } } @@ -258,7 +311,7 @@ impl ExecutionPlan for KNNIndexExec { fn schema(&self) -> arrow_schema::SchemaRef { Arc::new(Schema::new(vec![ - Field::new("score", DataType::Float32, false), + Field::new(SCORE_COL, DataType::Float32, false), Field::new(ROW_ID, DataType::UInt16, false), ])) } @@ -296,7 +349,10 @@ impl ExecutionPlan for KNNIndexExec { } fn statistics(&self) -> datafusion::physical_plan::Statistics { - todo!() + Statistics { + num_rows: Some(self.query.k * self.query.refine_factor.unwrap_or(1) as usize), + ..Default::default() + } } } @@ -316,6 +372,7 @@ mod tests { use crate::arrow::*; use crate::dataset::{Dataset, WriteParams}; use crate::index::vector::MetricType; + use crate::io::exec::testing::TestingExec; use crate::utils::testing::generate_random_array; #[tokio::test] @@ -400,4 +457,51 @@ mod tests { assert_eq!(expected, results[0]); } + + #[test] + fn test_create_knn_flat() { + let dim: usize = 128; + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("key", DataType::Int32, false), + ArrowField::new( + "vector", + DataType::FixedSizeList( + Box::new(ArrowField::new("item", DataType::Float32, true)), + dim as i32, + ), + true, + ), + ArrowField::new("uri", DataType::Utf8, true), + ])); + let batch = RecordBatch::new_empty(schema.clone()); + + let query = Query { + column: "vector".to_string(), + key: Arc::new(generate_random_array(dim)), + k: 10, + nprobes: 0, + refine_factor: None, + metric_type: MetricType::L2, + use_index: false, + }; + + let input: Arc = Arc::new(TestingExec::new(vec![batch.into()])); + let idx = KNNFlatExec::try_new(input, query).unwrap(); + 
assert_eq!( + idx.schema().as_ref(), + &ArrowSchema::new(vec![ + ArrowField::new("key", DataType::Int32, false), + ArrowField::new( + "vector", + DataType::FixedSizeList( + Box::new(ArrowField::new("item", DataType::Float32, true)), + dim as i32, + ), + true, + ), + ArrowField::new("uri", DataType::Utf8, true), + ArrowField::new(SCORE_COL, DataType::Float32, false), + ]) + ); + } } diff --git a/rust/src/io/exec/take.rs b/rust/src/io/exec/take.rs index 33ad925fa33..6c3a2d24aed 100644 --- a/rust/src/io/exec/take.rs +++ b/rust/src/io/exec/take.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use arrow_array::{cast::as_primitive_array, RecordBatch, UInt64Array}; -use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef}; +use arrow_schema::{Schema as ArrowSchema, SchemaRef}; use datafusion::error::{DataFusionError, Result}; use datafusion::physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; use futures::stream::{self, Stream, StreamExt, TryStreamExt}; @@ -169,17 +169,6 @@ impl TakeExec { } } -fn projection_with_row_id(projection: &Schema, drop_row_id: bool) -> SchemaRef { - let schema = ArrowSchema::from(projection); - if drop_row_id { - Arc::new(schema) - } else { - let mut fields = schema.fields; - fields.push(Field::new(ROW_ID, DataType::UInt64, false)); - Arc::new(ArrowSchema::new(fields)) - } -} - impl ExecutionPlan for TakeExec { fn as_any(&self) -> &dyn std::any::Any { self @@ -232,206 +221,12 @@ impl ExecutionPlan for TakeExec { } } -pub struct LocalTake { - /// The output schema. 
- schema: Arc, - - rx: Receiver>, - _bg_thread: JoinHandle<()>, -} - -impl LocalTake { - pub fn try_new( - input: SendableRecordBatchStream, - dataset: Arc, - schema: Arc, - ann_schema: Option>, // TODO add input/output schema contract to exec nodes and remove this - drop_row_id: bool, - ) -> Result { - let (tx, rx) = mpsc::channel(4); - let inner_schema = Schema::try_from(input.schema().as_ref())?; - let mut take_schema = schema.exclude(&inner_schema)?; - if ann_schema.is_some() { - take_schema = take_schema.exclude(&ann_schema.unwrap())?; - } - let projection = schema.clone(); - - let _bg_thread = tokio::spawn(async move { - if let Err(e) = input - .zip(stream::repeat_with(|| { - (dataset.clone(), take_schema.clone(), projection.clone()) - })) - .then(|(b, (dataset, take_schema, projection))| async move { - // TODO: need to cache the fragments. - let batch = b?; - let projection_schema = ArrowSchema::from(projection.as_ref()); - if batch.num_rows() == 0 { - return Ok(RecordBatch::new_empty(Arc::new(projection_schema))); - } - - let row_id_arr = batch.column_by_name(ROW_ID).unwrap(); - let row_ids: &UInt64Array = as_primitive_array(row_id_arr); - let batch = if take_schema.fields.is_empty() { - batch.project_by_schema(&projection_schema)? - } else { - let remaining_columns = - dataset.take_rows(row_ids.values(), &take_schema).await?; - batch - .merge(&remaining_columns)? - .project_by_schema(&projection_schema)? - }; - - if !drop_row_id { - Ok(batch.try_with_column( - Field::new(ROW_ID, DataType::UInt64, false), - Arc::new(row_id_arr.clone()), - )?) 
- } else { - Ok(batch) - } - }) - .try_for_each(|b| async { - if tx.is_closed() { - return Err(datafusion::error::DataFusionError::Execution( - "ExecNode(Take): channel closed".to_string(), - )); - } - if let Err(_) = tx.send(Ok(b)).await { - return Err(datafusion::error::DataFusionError::Execution( - "ExecNode(Take): channel closed".to_string(), - )); - } - Ok(()) - }) - .await - { - if let Err(e) = tx.send(Err(e)).await { - eprintln!("ExecNode(Take): {}", e); - } - } - drop(tx) - }); - Ok(Self { - schema, - rx, - _bg_thread, - }) - } -} - -impl Stream for LocalTake { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::into_inner(self).rx.poll_recv(cx) - } -} - -impl RecordBatchStream for LocalTake { - fn schema(&self) -> SchemaRef { - Arc::new(self.schema.as_ref().into()) - } -} - -/// [LocalTakeExec] is a physical [`ExecutionPlan`] that takes the rows within the same fragment -/// as its children [super::LanceScanExec] node. -/// -/// It is used to support filter/predicates push-down: -/// -/// `LocalTakeExec` -> `FilterExec` -> `LanceScanExec`: -/// -#[derive(Debug)] -pub struct LocalTakeExec { - dataset: Arc, - input: Arc, - schema: Arc, - ann_schema: Option>, - drop_row_id: bool, -} - -impl LocalTakeExec { - pub fn new( - input: Arc, - dataset: Arc, - schema: Arc, - ann_schema: Option>, - drop_row_id: bool, - ) -> Self { - assert!(input.schema().column_with_name(ROW_ID).is_some()); - Self { - dataset, - input, - schema, - ann_schema, - drop_row_id, - } - } -} - -impl ExecutionPlan for LocalTakeExec { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn schema(&self) -> SchemaRef { - projection_with_row_id(&self.schema, self.drop_row_id) - } - - fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { - self.input.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { - self.input.output_ordering() - } - - fn children(&self) 
-> Vec> { - vec![self.input.clone()] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - if children.len() != 1 { - return Err(DataFusionError::Plan( - "LocalTakeExec only takes 1 child".to_string(), - )); - } - Ok(Arc::new(Self { - input: children[0].clone(), - dataset: self.dataset.clone(), - schema: self.schema.clone(), - ann_schema: self.ann_schema.clone(), - drop_row_id: self.drop_row_id, - })) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - let input_stream = self.input.execute(partition, context)?; - Ok(Box::pin(LocalTake::try_new( - input_stream, - self.dataset.clone(), - self.schema.clone(), - self.ann_schema.clone(), - self.drop_row_id, - )?)) - } - - fn statistics(&self) -> datafusion::physical_plan::Statistics { - self.input.statistics() - } -} - #[cfg(test)] mod tests { use super::*; use arrow_array::{ArrayRef, Float32Array, Int32Array, RecordBatchReader, StringArray}; + use arrow_schema::{DataType, Field}; use tempfile::tempdir; use crate::{dataset::WriteParams, io::exec::LanceScanExec}; diff --git a/rust/src/io/exec/testing.rs b/rust/src/io/exec/testing.rs new file mode 100644 index 00000000000..9d06737cf84 --- /dev/null +++ b/rust/src/io/exec/testing.rs @@ -0,0 +1,77 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Testing Node +//! 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow_array::RecordBatch; +use datafusion::{ + execution::context::TaskContext, + physical_plan::{ExecutionPlan, SendableRecordBatchStream}, +}; + +#[derive(Debug)] +pub(crate) struct TestingExec { + pub(crate) batches: Vec, +} + +impl TestingExec { + pub(crate) fn new(batches: Vec) -> Self { + Self { batches } + } +} + +impl ExecutionPlan for TestingExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + self.batches[0].schema() + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + todo!() + } + + fn output_ordering(&self) -> Option<&[datafusion::physical_expr::PhysicalSortExpr]> { + todo!() + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> datafusion::error::Result> { + todo!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> datafusion::error::Result { + todo!() + } + + fn statistics(&self) -> datafusion::physical_plan::Statistics { + todo!() + } +}