From 9186ec016f49f3a698f3c8ed7dbdf875f6356de2 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 23 Mar 2023 23:07:39 -0700 Subject: [PATCH 01/20] list out the plan graph --- rust/src/dataset/scanner.rs | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 6262450fb27..32e7f1078c2 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -327,6 +327,47 @@ impl Scanner { )) } + /// Create [`ExecutionPlan`] for Scan. + /// + /// An ExecutionPlan is a graph of operators that can be executed. + /// + /// The following plans are supported: + /// + /// - **Plain scan without filter or limits.** + /// + /// ``` + /// Scan(projections) + /// ``` + /// + /// - **Scan with filter and/or limits.** + /// + /// ``` + /// Scan(filtered_cols) -> Filter(expr) + /// -> (*LimitExec(limit, offset)) + /// -> Take(remaining_cols) -> Projection() + /// ``` + /// + /// - **Use KNN Index (with filter and/or limits)** + /// + /// ``` + /// KNNIndex() -> Take(vector) -> FlatRefine() + /// -> Take(filtered_cols) -> Filter(expr) + /// -> (*LimitExec(limit, offset)) + /// -> Take(remaining_cols) -> Projection() + /// ``` + /// + /// - **Use KNN flat (brute force) with filter and/or limits** + /// + /// ``` + /// Scan(vector) -> FlatKNN() + /// -> Take(filtered_cols) -> Filter(expr) + /// -> (*LimitExec(limit, offset)) + /// -> Take(remaining_cols) -> Projection() + /// ``` + fn create_plan(&self) -> Result> { + todo!() + } + fn filter_knn( &self, knn_node: Arc, From 1db6231b6241aa9505a43581d387d23ed317b92f Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 23 Mar 2023 23:38:51 -0700 Subject: [PATCH 02/20] change project --- rust/src/dataset/scanner.rs | 91 ++++++++++++++++++++++++++++++------- rust/src/datatypes.rs | 6 +-- 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 32e7f1078c2..3a07c869fda 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -33,7 +33,7 @@ use futures::stream::{Stream, StreamExt}; use super::Dataset; use crate::arrow::*; use crate::datafusion::physical_expr::column_names_in_expr; -use crate::datatypes::Schema; +use crate::datatypes::{Field, Schema}; use crate::format::Index; use crate::index::vector::{MetricType, Query}; use crate::io::exec::{KNNFlatExec, KNNIndexExec, LanceScanExec, ProjectionExec, TakeExec}; @@ -210,23 +210,27 @@ impl Scanner { /// The Arrow schema of the output, including projections and vector / score pub fn schema(&self) -> Result { let schema = self - .scanner_output_schema() + .output_schema() .map(|s| SchemaRef::new(ArrowSchema::from(s.as_ref())))?; - if self.with_row_id { - let row_id = ArrowField::new(ROW_ID, DataType::UInt64, false); - let schema = schema.as_ref().try_with_column(row_id)?; - Ok(schema.into()) - } else { - Ok(schema) - } + Ok(schema) } - fn scanner_output_schema(&self) -> Result> { - if self.nearest.as_ref().is_some() { - let merged = self.projections.merge(&self.vector_search_schema()?); - Ok(Arc::new(merged)) + fn output_schema(&self) -> Result> { + let schema = if self.nearest.as_ref().is_some() { + self.projections.merge(&self.vector_search_schema()?) + } else { + self.projections.clone() + }; + if self.with_row_id { + let row_id_schema = Schema::try_from(&ArrowSchema::new(vec![ArrowField::new( + ROW_ID, + DataType::UInt64, + false, + )]))?; + let schema = schema.merge(&row_id_schema); + Ok(schema.into()) } else { - Ok(Arc::new(self.projections.clone())) + Ok(schema.into()) } } @@ -364,7 +368,59 @@ impl Scanner { /// -> (*LimitExec(limit, offset)) /// -> Take(remaining_cols) -> Projection() /// ``` - fn create_plan(&self) -> Result> { + /// + /// In general, a plan has 4 stages: + /// + /// 1. Source (from dataset Scan or from index) + /// 2. Filter + /// 3. Limit / Offset + /// 4. Take remaining columns / Projection + async fn create_plan(&self) -> Result> { + let filter_expr = if let Some(filter) = self.filter.as_ref() { + let planner = Planner::new(Arc::new(self.dataset.schema().into())); + let logical_expr = planner.parse_filter(filter)?; + Some(planner.create_physical_expr(&logical_expr)?) + } else { + None + }; + + // Stage 1: source + let mut plan: Arc = if self.nearest.is_some() { + self.knn().await? + } else if let Some(expr) = filter_expr { + let columns_in_filter = column_names_in_expr(expr.as_ref()); + let filter_schema = Arc::new( + self.dataset.schema().project( + &columns_in_filter + .iter() + .map(|s| s.as_str()) + .collect::>(), + )?, + ); + self.scan(self.with_row_id, filter_schema) + } else { + // Scan without filter or limits + self.scan(self.with_row_id, self.output_schema()?) + }; + + // Stage 2: filter + Ok(plan) + } + + // + async fn knn(&self) -> Result> { + let Some(q) = self.nearest.as_ref() else { + return Err(Error::IO("No nearest query".to_string())); + }; + + let column_id = self.dataset.schema().field_id(q.column.as_str())?; + let use_index = self.nearest.as_ref().map(|q| q.use_index).unwrap_or(false); + let indices = if use_index { + self.dataset.load_indices().await? + } else { + vec![] + }; + let idx = indices.iter().find(|i| i.fields.contains(&column_id)); todo!() } @@ -442,8 +498,9 @@ impl Scanner { input: Arc, ) -> Result> { let filter_node = Arc::new(FilterExec::try_new(filter, input)?); - let output_schema = self.scanner_output_schema()?; - Ok(Arc::new(TakeExec::try_new( + let output_schema = self.output_schema()?; + Ok(Arc::new(LocalTakeExec::new( + filter_node, self.dataset.clone(), filter_node, output_schema, diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index 67e54adf304..d5e8d236408 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -679,10 +679,10 @@ impl Schema { /// let schema = Schema::from(...); /// let projected = schema.project(&["col1", "col2.sub_col3.field4"])?; /// ``` - pub fn project(&self, columns: &[&str]) -> Result { + pub fn project>(&self, columns: &[T]) -> Result { let mut candidates: Vec = vec![]; for col in columns { - let split = (*col).split('.').collect::>(); + let split = col.as_ref().split('.').collect::>(); let first = split[0]; if let Some(field) = self.field(first) { let projected_field = field.project(&split[1..])?; @@ -692,7 +692,7 @@ impl Schema { candidates.push(projected_field) } } else { - return Err(Error::Schema(format!("Column {} does not exist", col))); + return Err(Error::Schema(format!("Column {} does not exist", col.as_ref()))); } } From 36a95b20921b869205c91adf5b99c89337160677 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 23 Mar 2023 23:42:49 -0700 Subject: [PATCH 03/20] extract out knn node --- rust/src/dataset/scanner.rs | 84 ++++++++++++++++--------------------- rust/src/datatypes.rs | 5 ++- 2 files changed, 40 insertions(+), 49 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 3a07c869fda..438ad35800b 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -257,48 +257,12 @@ impl Scanner { None }; - let mut plan: Arc = if let Some(q) = self.nearest.as_ref() { - let column_id = self.dataset.schema().field_id(q.column.as_str())?; - let use_index = self.nearest.as_ref().map(|q| q.use_index).unwrap_or(false); - let indices = if use_index { - self.dataset.load_indices().await? - } else { - vec![] - }; - let qcol_index = indices.iter().find(|i| i.fields.contains(&column_id)); - if let Some(index) = qcol_index { - // There is an index built for the column. - // We will use the index. - if let Some(rf) = q.refine_factor { - if rf == 0 { - return Err(Error::IO("Refine factor can not be zero".to_string())); - } - } - - let knn_node = self.ann(q, &index)?; // score, _rowid - let with_vector = self.dataset.schema().project(&[&q.column])?; - let knn_node_with_vector = self.take(knn_node, &with_vector)?; - let knn_node = if q.refine_factor.is_some() { - self.flat_knn(knn_node_with_vector, q)? - } else { - knn_node_with_vector - }; // vector, score, _rowid - - let knn_node = filter_expr - .map(|f| self.filter_knn(knn_node.clone(), f)) - .unwrap_or(Ok(knn_node))?; // vector, score, _rowid - self.take(knn_node, projection)? - } else { - let vector_scan_projection = - Arc::new(self.dataset.schema().project(&[&q.column]).unwrap()); - let scan_node = self.scan(true, vector_scan_projection); - let knn_node = self.flat_knn(scan_node, q)?; - - let knn_node = filter_expr - .map(|f| self.filter_knn(knn_node.clone(), f)) - .unwrap_or(Ok(knn_node))?; // vector, score, _rowid - self.take(knn_node, projection)? - } + let mut plan: Arc = if self.nearest.is_some() { + let knn_node = self.knn().await?; + let knn_node = filter_expr + .map(|f| self.filter_knn(knn_node.clone(), f)) + .unwrap_or(Ok(knn_node))?; // vector, score, _rowid + self.take(knn_node, projection)? } else if let Some(filter) = filter_expr { let columns_in_filter = column_names_in_expr(filter.as_ref()); let filter_schema = Arc::new( @@ -339,13 +303,13 @@ impl Scanner { /// /// - **Plain scan without filter or limits.** /// - /// ``` + /// ```ignore /// Scan(projections) /// ``` /// /// - **Scan with filter and/or limits.** /// - /// ``` + /// ```ignore /// Scan(filtered_cols) -> Filter(expr) /// -> (*LimitExec(limit, offset)) /// -> Take(remaining_cols) -> Projection() @@ -353,7 +317,7 @@ impl Scanner { /// /// - **Use KNN Index (with filter and/or limits)** /// - /// ``` + /// ```ignore /// KNNIndex() -> Take(vector) -> FlatRefine() /// -> Take(filtered_cols) -> Filter(expr) /// -> (*LimitExec(limit, offset)) @@ -362,7 +326,7 @@ impl Scanner { /// /// - **Use KNN flat (brute force) with filter and/or limits** /// - /// ``` + /// ```ignore /// Scan(vector) -> FlatKNN() /// -> Take(filtered_cols) -> Filter(expr) /// -> (*LimitExec(limit, offset)) @@ -420,8 +384,32 @@ impl Scanner { } else { vec![] }; - let idx = indices.iter().find(|i| i.fields.contains(&column_id)); - todo!() + let knn_idx = indices.iter().find(|i| i.fields.contains(&column_id)); + if let Some(index) = knn_idx { + // There is an index built for the column. + // We will use the index. + if let Some(rf) = q.refine_factor { + if rf == 0 { + return Err(Error::IO("Refine factor can not be zero".to_string())); + } + } + + let knn_node = self.ann(q, &index); // score, _rowid + let with_vector = self.dataset.schema().project(&[&q.column])?; + let knn_node_with_vector = self.take(knn_node, &with_vector)?; + let knn_node = if q.refine_factor.is_some() { + self.flat_knn(knn_node_with_vector, q) + } else { + knn_node_with_vector + }; // vector, score, _rowid + Ok(knn_node) + } else { + // No index found. use flat search. + let vector_scan_projection = + Arc::new(self.dataset.schema().project(&[&q.column]).unwrap()); + let scan_node = self.scan(true, vector_scan_projection); + Ok(self.flat_knn(scan_node, q)) + } } fn filter_knn( diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index d5e8d236408..60a311b52f8 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -692,7 +692,10 @@ impl Schema { candidates.push(projected_field) } } else { - return Err(Error::Schema(format!("Column {} does not exist", col.as_ref()))); + return Err(Error::Schema(format!( + "Column {} does not exist", + col.as_ref() + ))); } } From d18e89cf56cad4b9493ed58387421f4c2fa77485 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 23 Mar 2023 23:43:20 -0700 Subject: [PATCH 04/20] extract out knn node --- rust/src/dataset/scanner.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 438ad35800b..ab6d772129f 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -368,6 +368,12 @@ impl Scanner { }; // Stage 2: filter + + // Stage 3: limit / offset + if (self.limit.unwrap_or(0) > 0) || self.offset.is_some() { + plan = self.limit_node(plan); + } + Ok(plan) } From a19100c818ec508e8e3dfc3c9ad8125ff5f8a144 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 23 Mar 2023 23:49:49 -0700 Subject: [PATCH 05/20] cargo fmt --- rust/src/dataset/scanner.rs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index ab6d772129f..3bef7e8c24c 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -31,9 +31,8 @@ use datafusion::prelude::*; use futures::stream::{Stream, StreamExt}; use super::Dataset; -use crate::arrow::*; use crate::datafusion::physical_expr::column_names_in_expr; -use crate::datatypes::{Field, Schema}; +use crate::datatypes::Schema; use crate::format::Index; use crate::index::vector::{MetricType, Query}; use crate::io::exec::{KNNFlatExec, KNNIndexExec, LanceScanExec, ProjectionExec, TakeExec}; @@ -351,16 +350,9 @@ impl Scanner { // Stage 1: source let mut plan: Arc = if self.nearest.is_some() { self.knn().await? - } else if let Some(expr) = filter_expr { + } else if let Some(expr) = filter_expr.as_ref() { let columns_in_filter = column_names_in_expr(expr.as_ref()); - let filter_schema = Arc::new( - self.dataset.schema().project( - &columns_in_filter - .iter() - .map(|s| s.as_str()) - .collect::>(), - )?, - ); + let filter_schema = Arc::new(self.dataset.schema().project(&columns_in_filter)?); self.scan(self.with_row_id, filter_schema) } else { // Scan without filter or limits @@ -368,12 +360,17 @@ impl Scanner { }; // Stage 2: filter + if let Some(expr) = filter_expr.as_ref() { + todo!("filter") + } // Stage 3: limit / offset if (self.limit.unwrap_or(0) > 0) || self.offset.is_some() { plan = self.limit_node(plan); } + // Stage 4: take remaining columns / projection + Ok(plan) } From c74f51bbe06c79812d3989832d34bd112f75c639 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 00:28:39 -0700 Subject: [PATCH 06/20] make exclude takes arrow schema as well --- rust/src/dataset/scanner.rs | 46 ++++++++++++++++++++----------------- rust/src/datatypes.rs | 17 ++++++++++++-- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 3bef7e8c24c..66a28ab92ba 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use arrow_array::{Float32Array, RecordBatch}; -use arrow_schema::DataType::{self, Float32}; +use arrow_schema::DataType; use arrow_schema::{Field as ArrowField, Schema as ArrowSchema, SchemaRef}; use datafusion::execution::{ context::SessionState, @@ -215,30 +215,31 @@ impl Scanner { } fn output_schema(&self) -> Result> { - let schema = if self.nearest.as_ref().is_some() { - self.projections.merge(&self.vector_search_schema()?) - } else { - self.projections.clone() + let mut extra_columns = vec![]; + + if let Some(q) = self.nearest.as_ref() { + let vector_field = self + .dataset + .schema() + .field(&q.column) + .ok_or(Error::IO(format!("Column {} not found", q.column)))?; + let vector_field = ArrowField::try_from(vector_field).map_err(|e| { + Error::IO(format!("Failed to convert vector field: {}", e.to_string())) + })?; + extra_columns.push(vector_field); + extra_columns.push(ArrowField::new("score", DataType::Float32, false)); }; if self.with_row_id { - let row_id_schema = Schema::try_from(&ArrowSchema::new(vec![ArrowField::new( - ROW_ID, - DataType::UInt64, - false, - )]))?; - let schema = schema.merge(&row_id_schema); - Ok(schema.into()) - } else { - Ok(schema.into()) + extra_columns.push(ArrowField::new(ROW_ID, DataType::UInt64, false)); } - } - fn vector_search_schema(&self) -> Result { - let q = self.nearest.as_ref().unwrap(); - let vector_schema = self.dataset.schema().project(&[&q.column])?; - let score = ArrowField::new("score", Float32, false); - let score_schema = Schema::try_from(&ArrowSchema::new(vec![score]))?; - Ok(vector_schema.merge(&score_schema)) + let schema = if !extra_columns.is_empty() { + let extra_schema = Schema::try_from(&ArrowSchema::new(extra_columns))?; + self.projections.merge(&extra_schema) + } else { + self.projections.clone() + }; + Ok(Arc::new(schema)) } /// Create a stream of this Scanner. @@ -361,6 +362,9 @@ impl Scanner { // Stage 2: filter if let Some(expr) = filter_expr.as_ref() { + let columns_in_filter = column_names_in_expr(expr.as_ref()); + let filter_schema = Arc::new(self.dataset.schema().project(&columns_in_filter)?); + let remaining_schema = filter_schema.exclude(plan.schema().as_ref())?; todo!("filter") } diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index 60a311b52f8..698cb10ef9b 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -2,7 +2,7 @@ use std::cmp::max; use std::collections::HashMap; -use std::fmt::Formatter; +use std::fmt::{Formatter, Debug}; use std::fmt::{self}; use arrow_array::cast::{as_large_list_array, as_list_array}; @@ -717,7 +717,13 @@ impl Schema { } /// Exclude the fields from `other` Schema, and returns a new Schema. - pub fn exclude(&self, other: &Self) -> Result { + pub fn exclude + Debug>(&self, other: T) -> Result { + let other = other.try_into().map_err(|_| { + Error::Schema(format!( + "The other schema {:?} is not compatible with this schema", + other + )) + })?; let mut fields = vec![]; for field in self.fields.iter() { if let Some(other_field) = other.field(&field.name) { @@ -857,6 +863,13 @@ impl From<&Schema> for ArrowSchema { } } +/// Convert Lance Schema to Arrow Schema +impl From<&Schema> for Schema { + fn from(schema: &Schema) -> Self { + schema.clone() + } +} + /// Convert list of protobuf `Field` to a Schema. impl From<&Vec> for Schema { fn from(fields: &Vec) -> Self { From 2079f4c76a7cb179c0aaf793cb4bf6bdc26782d7 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 00:43:57 -0700 Subject: [PATCH 07/20] separate create plan --- rust/src/dataset/scanner.rs | 64 ++++++++++++------------------------- rust/src/datatypes.rs | 11 +++---- 2 files changed, 26 insertions(+), 49 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 66a28ab92ba..57e03d6e540 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -246,45 +246,7 @@ impl Scanner { /// /// TODO: implement as IntoStream/IntoIterator. pub async fn try_into_stream(&self) -> Result { - let with_row_id = self.with_row_id; - let projection = &self.projections; - - let filter_expr = if let Some(filter) = self.filter.as_ref() { - let planner = crate::io::exec::Planner::new(Arc::new(self.dataset.schema().into())); - let logical_expr = planner.parse_filter(filter)?; - Some(planner.create_physical_expr(&logical_expr)?) - } else { - None - }; - - let mut plan: Arc = if self.nearest.is_some() { - let knn_node = self.knn().await?; - let knn_node = filter_expr - .map(|f| self.filter_knn(knn_node.clone(), f)) - .unwrap_or(Ok(knn_node))?; // vector, score, _rowid - self.take(knn_node, projection)? - } else if let Some(filter) = filter_expr { - let columns_in_filter = column_names_in_expr(filter.as_ref()); - let filter_schema = Arc::new( - self.dataset.schema().project( - &columns_in_filter - .iter() - .map(|s| s.as_str()) - .collect::>(), - )?, - ); - let scan = self.scan(true, filter_schema); - self.filter_node(filter, scan)? - } else { - self.scan(with_row_id, Arc::new(self.projections.clone())) - }; - - if (self.limit.unwrap_or(0) > 0) || self.offset.is_some() { - plan = self.limit_node(plan); - } - - let project_schema = Schema::try_from(self.schema()?.as_ref())?; - plan = Arc::new(ProjectionExec::try_new(plan, project_schema.into())?); + let plan = self.create_plan().await?; let session_config = SessionConfig::new(); let runtime_config = RuntimeConfig::new(); @@ -354,18 +316,22 @@ impl Scanner { } else if let Some(expr) = filter_expr.as_ref() { let columns_in_filter = column_names_in_expr(expr.as_ref()); let filter_schema = Arc::new(self.dataset.schema().project(&columns_in_filter)?); - self.scan(self.with_row_id, filter_schema) + self.scan(true, filter_schema) } else { // Scan without filter or limits self.scan(self.with_row_id, self.output_schema()?) }; // Stage 2: filter - if let Some(expr) = filter_expr.as_ref() { - let columns_in_filter = column_names_in_expr(expr.as_ref()); + if let Some(predicates) = filter_expr.as_ref() { + let columns_in_filter = column_names_in_expr(predicates.as_ref()); let filter_schema = Arc::new(self.dataset.schema().project(&columns_in_filter)?); let remaining_schema = filter_schema.exclude(plan.schema().as_ref())?; - todo!("filter") + if !remaining_schema.fields.is_empty() { + // Not all columns for filter are ready, so we need to take them first + plan = self.take(plan, &remaining_schema)?; + } + plan = Arc::new(FilterExec::try_new(predicates.clone(), plan)?); } // Stage 3: limit / offset @@ -374,6 +340,12 @@ impl Scanner { } // Stage 4: take remaining columns / projection + let output_schema = self.output_schema()?; + let remaining_schema = output_schema.exclude(plan.schema().as_ref())?; + if !remaining_schema.fields.is_empty() { + plan = self.take(plan, &remaining_schema)?; + } + plan = Arc::new(ProjectionExec::try_new(plan, output_schema)?); Ok(plan) } @@ -419,6 +391,7 @@ impl Scanner { } } +<<<<<<< HEAD fn filter_knn( &self, knn_node: Arc, @@ -439,6 +412,8 @@ impl Scanner { self.filter_node(filter_expression, take_node) } +======= +>>>>>>> b330654 (separate create plan) /// Create an Execution plan with a scan node fn scan(&self, with_row_id: bool, projection: Arc) -> Arc { Arc::new(LanceScanExec::new( @@ -486,6 +461,7 @@ impl Scanner { self.limit.map(|l| l as usize), )) } +<<<<<<< HEAD fn filter_node( &self, @@ -501,6 +477,8 @@ impl Scanner { output_schema, )?)) } +======= +>>>>>>> b330654 (separate create plan) } /// ScannerStream is a container to wrap different types of ExecNode. diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index 698cb10ef9b..24fc2114a13 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -717,12 +717,11 @@ impl Schema { } /// Exclude the fields from `other` Schema, and returns a new Schema. - pub fn exclude + Debug>(&self, other: T) -> Result { - let other = other.try_into().map_err(|_| { - Error::Schema(format!( - "The other schema {:?} is not compatible with this schema", - other - )) + pub fn exclude + Debug>(&self, schema: T) -> Result { + let other = schema.try_into().map_err(|_| { + Error::Schema( + "The other schema is not compatible with this schema" + .to_string()) })?; let mut fields = vec![]; for field in self.fields.iter() { From 6eddfa8cab77312adfd054fb15eb1ea1e10a3717 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 00:58:33 -0700 Subject: [PATCH 08/20] print --- rust/src/dataset/scanner.rs | 43 +----------------------------------- rust/src/index/vector/opq.rs | 4 ++-- 2 files changed, 3 insertions(+), 44 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 57e03d6e540..351f0bbb50a 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -391,29 +391,6 @@ impl Scanner { } } -<<<<<<< HEAD - fn filter_knn( - &self, - knn_node: Arc, - filter_expression: Arc, - ) -> Result> { - let columns_in_filter = column_names_in_expr(filter_expression.as_ref()); - let columns_refs = columns_in_filter - .iter() - .map(|c| c.as_str()) - .collect::>(); - let filter_projection = self.dataset.schema().project(&columns_refs)?; - - let take_node = Arc::new(TakeExec::try_new( - self.dataset.clone(), - knn_node, - Arc::new(filter_projection), - )?); - self.filter_node(filter_expression, take_node) - } - -======= ->>>>>>> b330654 (separate create plan) /// Create an Execution plan with a scan node fn scan(&self, with_row_id: bool, projection: Arc) -> Arc { Arc::new(LanceScanExec::new( @@ -460,25 +437,7 @@ impl Scanner { *self.offset.as_ref().unwrap_or(&0) as usize, self.limit.map(|l| l as usize), )) - } -<<<<<<< HEAD - - fn filter_node( - &self, - filter: Arc, - input: Arc, - ) -> Result> { - let filter_node = Arc::new(FilterExec::try_new(filter, input)?); - let output_schema = self.output_schema()?; - Ok(Arc::new(LocalTakeExec::new( - filter_node, - self.dataset.clone(), - filter_node, - output_schema, - )?)) - } -======= ->>>>>>> b330654 (separate create plan) + }= } /// ScannerStream is a container to wrap different types of ExecNode. diff --git a/rust/src/index/vector/opq.rs b/rust/src/index/vector/opq.rs index 51289f872ef..56b3ac0f1c9 100644 --- a/rust/src/index/vector/opq.rs +++ b/rust/src/index/vector/opq.rs @@ -364,7 +364,6 @@ mod tests { .next() .unwrap() .unwrap(); - println!("{:?}", index_file.path()); let uuid = index_file.file_name().to_str().unwrap().to_string(); let index = open_index(&dataset, &uuid).await.unwrap(); @@ -390,9 +389,10 @@ mod tests { key: Float32Array::from_iter_values((0..64).map(|x| x as f32 + 640.0)).into(), }; let results = index.search(&query).await.unwrap(); - + println!("{:?}", results); let row_ids: &UInt64Array = as_primitive_array(&results[ROW_ID]); assert_eq!(row_ids.len(), 4); + println!("{:?}", row_ids.values()); assert!(row_ids.values().contains(&10)); assert_eq!(min(row_ids).unwrap() + 3, max(row_ids).unwrap()); } From 471da28125302058b60108854c37509847591d6c Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 10:01:05 -0700 Subject: [PATCH 09/20] fix plain scan --- rust/src/dataset/scanner.rs | 7 +++++-- rust/src/index/vector/ivf.rs | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 351f0bbb50a..2b4db117cac 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -319,7 +319,7 @@ impl Scanner { self.scan(true, filter_schema) } else { // Scan without filter or limits - self.scan(self.with_row_id, self.output_schema()?) + self.scan(self.with_row_id, self.projections.clone().into()) }; // Stage 2: filter @@ -345,7 +345,7 @@ impl Scanner { if !remaining_schema.fields.is_empty() { plan = self.take(plan, &remaining_schema)?; } - plan = Arc::new(ProjectionExec::try_new(plan, output_schema)?); + plan = Arc::new(ProjectionExec::try_new(plan, output_schema.clone())?); Ok(plan) } @@ -856,4 +856,7 @@ mod test { .collect(); assert_eq!(expected_i, actual_i); } + + #[tokio::test] + async fn test_simple_scan_plan() {} } diff --git a/rust/src/index/vector/ivf.rs b/rust/src/index/vector/ivf.rs index fdf6257aa0f..0134c2d8ac1 100644 --- a/rust/src/index/vector/ivf.rs +++ b/rust/src/index/vector/ivf.rs @@ -465,6 +465,7 @@ async fn maybe_sample_training_data( .await?; concat_batches(&Arc::new(ArrowSchema::from(&projection)), &batches)? }; + let array = batch.column_by_name(column).ok_or(Error::Index(format!( "Sample training data: column {} does not exist in return", column From 1fd942e4b85f9df2db3f4b268b799f1a0723acb6 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 10:04:23 -0700 Subject: [PATCH 10/20] cargo fmt --- rust/src/dataset/scanner.rs | 2 +- rust/src/datatypes.rs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 2b4db117cac..be271baf287 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -331,7 +331,7 @@ impl Scanner { // Not all columns for filter are ready, so we need to take them first plan = self.take(plan, &remaining_schema)?; } - plan = Arc::new(FilterExec::try_new(predicates.clone(), plan)?); + plan = Arc::new(FilterExec::try_new(predicates.clone(), plan)?); } // Stage 3: limit / offset diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index 24fc2114a13..5628d6df4bb 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -2,8 +2,8 @@ use std::cmp::max; use std::collections::HashMap; -use std::fmt::{Formatter, Debug}; use std::fmt::{self}; +use std::fmt::{Debug, Formatter}; use arrow_array::cast::{as_large_list_array, as_list_array}; use arrow_array::types::{ @@ -719,9 +719,7 @@ impl Schema { /// Exclude the fields from `other` Schema, and returns a new Schema. pub fn exclude + Debug>(&self, schema: T) -> Result { let other = schema.try_into().map_err(|_| { - Error::Schema( - "The other schema is not compatible with this schema" - .to_string()) + Error::Schema("The other schema is not compatible with this schema".to_string()) })?; let mut fields = vec![]; for field in self.fields.iter() { From 3c3928a30921c81f5273608838cdfbb134155bff Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 10:32:55 -0700 Subject: [PATCH 11/20] add test for scan with row id --- rust/src/dataset/scanner.rs | 124 +++++++++++++++++++++++++++++++++++- 1 file changed, 123 insertions(+), 1 deletion(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index be271baf287..9a9d9f35460 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -487,6 +487,8 @@ mod test { use crate::index::vector::VectorIndexParams; use crate::index::IndexType; + use super::*; + use crate::arrow::*; use crate::{arrow::RecordBatchBuffer, dataset::WriteParams}; #[tokio::test] @@ -647,7 +649,11 @@ mod test { expected_batches } +<<<<<<< HEAD async fn create_vector_dataset(path: &str, build_index: bool) -> Dataset { +======= + async fn create_dataset() -> Arc { +>>>>>>> e0c6aed (add test for scan with row id) let schema = Arc::new(ArrowSchema::new(vec![ ArrowField::new("i", DataType::Int32, true), ArrowField::new("s", DataType::Utf8, true), @@ -661,6 +667,7 @@ mod test { ), ])); +<<<<<<< HEAD let batches = RecordBatchBuffer::new( (0..5) .map(|i| { @@ -855,8 +862,123 @@ mod test { .copied() .collect(); assert_eq!(expected_i, actual_i); +======= + let vector_data = Float32Array::from_iter_values((0..3200).map(|v| v as f32)); + let vector = FixedSizeListArray::try_new(&vector_data, 32).unwrap(); + let batches = RecordBatchBuffer::new(vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..100)), + Arc::new(StringArray::from_iter_values( + (0..100).map(|v| format!("s-{}", v)), + )), + Arc::new(vector), + ], + ) + .unwrap()]); + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let mut params = WriteParams::default(); + params.max_rows_per_group = 100; + let mut reader: Box = Box::new(batches); + Dataset::write(&mut reader, test_uri, Some(params)) + .await + .unwrap(); + Arc::new(Dataset::open(test_uri).await.unwrap()) +>>>>>>> e0c6aed (add test for scan with row id) } #[tokio::test] - async fn test_simple_scan_plan() {} + async fn test_simple_scan_plan() { + let dataset = create_dataset().await; + let scan = dataset.scan(); + let plan = scan.create_plan().await.unwrap(); + + assert!(plan.as_any().is::()); + assert_eq!( + plan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["i", "s", "vec"] + ); + let scan = &plan.children()[0]; + assert!(scan.as_any().is::()); + assert_eq!( + scan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["i", "s", "vec"] + ); + + let mut scan = dataset.scan(); + scan.project(&["s"]).unwrap(); + let plan = scan.create_plan().await.unwrap(); + assert!(plan.as_any().is::()); + assert_eq!( + plan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["s"] + ); + let scan = &plan.children()[0]; + assert!(scan.as_any().is::()); + assert_eq!( + scan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["s"] + ); + } + + #[tokio::test] + async fn test_scan_with_row_id() { + let dataset = create_dataset().await; + let mut scan = dataset.scan(); + scan.project(&["i"]).unwrap(); + scan.with_row_id(); + let plan = scan.create_plan().await.unwrap(); + + assert!(plan.as_any().is::()); + assert_eq!( + plan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["i", "_rowid"] + ); + let scan = &plan.children()[0]; + assert!(scan.as_any().is::()); + assert_eq!( + scan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["i", "_rowid"] + ); + } + + // #[tokio::test] + // async fn test_filter_plan() { + // let dataset = create_dataset().await; + // let mut scan = dataset.scan(); + // scan.filter("i > 50").unwrap(); + // let plan = scan.create_plan().await.unwrap(); + + // println!("Plan is {:?}", plan); + // assert!(plan.as_any().is::()); + // let filter = plan.as_any().downcast_ref::().unwrap(); + // // assert!(filter.input.as_any().is::()); + // // assert_eq!(filter.predicate, "i > 50".to_string()); + // } } From e24f6dc5a147ca684619df9c4033ae580ed0c460 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 10:35:18 -0700 Subject: [PATCH 12/20] remove prints --- rust/src/index/vector/opq.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/rust/src/index/vector/opq.rs b/rust/src/index/vector/opq.rs index 56b3ac0f1c9..65ae9c70e04 100644 --- a/rust/src/index/vector/opq.rs +++ b/rust/src/index/vector/opq.rs @@ -389,10 +389,8 @@ mod tests { key: Float32Array::from_iter_values((0..64).map(|x| x as f32 + 640.0)).into(), }; let results = index.search(&query).await.unwrap(); - println!("{:?}", results); let row_ids: &UInt64Array = as_primitive_array(&results[ROW_ID]); assert_eq!(row_ids.len(), 4); - println!("{:?}", row_ids.values()); assert!(row_ids.values().contains(&10)); assert_eq!(min(row_ids).unwrap() + 3, max(row_ids).unwrap()); } From 15ccb91baab350af5533cf7599c00da6caf3e3e4 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 10:42:14 -0700 Subject: [PATCH 13/20] add simple filter case --- rust/src/dataset/scanner.rs | 64 +++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 9a9d9f35460..91d78ccb7c0 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -968,6 +968,70 @@ mod test { ); } + /// Test scan with filter. + /// + /// Query: + /// + /// ``` + /// SELECT s FROM dataset WHERE i > 10 and i < 20 + /// ``` + /// + /// Expected plan: + /// scan(i) -> filter(i) -> take(s) -> projection(s) + #[tokio::test] + async fn test_scan_with_filter() { + let dataset = create_dataset().await; + let mut scan = dataset.scan(); + scan.project(&["s"]).unwrap(); + scan.filter("i > 10 and i < 20").unwrap(); + let plan = scan.create_plan().await.unwrap(); + + assert!(plan.as_any().is::()); + assert_eq!( + plan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["s"] + ); + + let take = &plan.children()[0]; + assert!(take.as_any().is::()); + assert_eq!( + take.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["i", "_rowid", "s"] + ); + + let filter = &take.children()[0]; + assert!(filter.as_any().is::()); + assert_eq!( + filter + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["i", "_rowid"] + ); + + let scan = &filter.children()[0]; + assert!(scan.as_any().is::()); + assert_eq!( + filter + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["i", "_rowid"] + ); + } + // #[tokio::test] // async fn test_filter_plan() { // let dataset = create_dataset().await; From c379f1bdf683922c7dd44f65765811e2c3db3992 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 11:18:47 -0700 Subject: [PATCH 14/20] test case for knn plan --- rust/src/dataset/scanner.rs | 183 ++++++++++++++++++++---------------- rust/src/io/exec/take.rs | 4 +- 2 files changed, 107 insertions(+), 80 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 91d78ccb7c0..e89a389e319 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -489,6 +489,8 @@ mod test { use crate::index::IndexType; use super::*; use crate::arrow::*; + use crate::index::IndexType; + use crate::index::vector::VectorIndexParams; use crate::{arrow::RecordBatchBuffer, dataset::WriteParams}; #[tokio::test] @@ -649,11 +651,7 @@ mod test { expected_batches } -<<<<<<< HEAD - async fn create_vector_dataset(path: &str, build_index: bool) -> Dataset { -======= - async fn create_dataset() -> Arc { ->>>>>>> e0c6aed (add test for scan with row id) + async fn create_vector_dataset(path: &str, build_index: bool) -> Arc { let schema = Arc::new(ArrowSchema::new(vec![ ArrowField::new("i", DataType::Int32, true), ArrowField::new("s", DataType::Utf8, true), @@ -667,7 +665,6 @@ mod test { ), ])); -<<<<<<< HEAD let batches = RecordBatchBuffer::new( (0..5) .map(|i| { @@ -710,7 +707,7 @@ mod test { .unwrap(); } - Dataset::open(path).await.unwrap() + Arc::new(Dataset::open(path).await.unwrap()) } #[tokio::test] @@ -862,86 +859,49 @@ mod test { .copied() .collect(); assert_eq!(expected_i, actual_i); -======= - let vector_data = Float32Array::from_iter_values((0..3200).map(|v| v as f32)); - let vector = FixedSizeListArray::try_new(&vector_data, 32).unwrap(); - let batches = RecordBatchBuffer::new(vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_iter_values(0..100)), - Arc::new(StringArray::from_iter_values( - (0..100).map(|v| format!("s-{}", v)), - )), - Arc::new(vector), - ], - ) - .unwrap()]); - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - let mut params = WriteParams::default(); - params.max_rows_per_group = 100; - let mut reader: Box = Box::new(batches); - Dataset::write(&mut reader, test_uri, Some(params)) - .await - .unwrap(); - Arc::new(Dataset::open(test_uri).await.unwrap()) ->>>>>>> e0c6aed (add test for scan with row id) + + fn get_exec_columns(plan: &dyn ExecutionPlan) -> Vec { + plan.schema() + .fields() + .iter() + .map(|f| f.name().to_string()) + .collect::>() } #[tokio::test] async fn test_simple_scan_plan() { - let dataset = create_dataset().await; + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = create_vector_dataset(test_uri, false).await; + let scan = dataset.scan(); let plan = scan.create_plan().await.unwrap(); assert!(plan.as_any().is::()); - assert_eq!( - plan.schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["i", "s", "vec"] - ); + assert_eq!(get_exec_columns(plan.as_ref()), ["i", "s", "vec"]); + let scan = &plan.children()[0]; assert!(scan.as_any().is::()); - assert_eq!( - scan.schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["i", "s", "vec"] - ); + assert_eq!(get_exec_columns(scan.as_ref()), ["i", "s", "vec"]); let mut scan = dataset.scan(); scan.project(&["s"]).unwrap(); let plan = scan.create_plan().await.unwrap(); assert!(plan.as_any().is::()); - assert_eq!( - plan.schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["s"] - ); + assert_eq!(get_exec_columns(plan.as_ref()), ["s"]); + let scan = &plan.children()[0]; assert!(scan.as_any().is::()); - assert_eq!( - scan.schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["s"] - ); + assert_eq!(get_exec_columns(scan.as_ref()), ["s"]); } #[tokio::test] async fn test_scan_with_row_id() { - let dataset = create_dataset().await; + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = create_vector_dataset(test_uri, false).await; + let mut scan = dataset.scan(); scan.project(&["i"]).unwrap(); scan.with_row_id(); @@ -980,7 +940,10 @@ mod test { /// scan(i) -> filter(i) -> take(s) -> projection(s) #[tokio::test] async fn test_scan_with_filter() { - let dataset = create_dataset().await; + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = create_vector_dataset(test_uri, false).await; + let mut scan = dataset.scan(); scan.project(&["s"]).unwrap(); scan.filter("i > 10 and i < 20").unwrap(); @@ -1032,17 +995,79 @@ mod test { ); } - // #[tokio::test] - // async fn test_filter_plan() { - // let dataset = create_dataset().await; - // let mut scan = dataset.scan(); - // scan.filter("i > 50").unwrap(); - // let plan = scan.create_plan().await.unwrap(); - - // println!("Plan is {:?}", plan); - // assert!(plan.as_any().is::()); - // let filter = plan.as_any().downcast_ref::().unwrap(); - // // assert!(filter.input.as_any().is::()); - // // assert_eq!(filter.predicate, "i > 50".to_string()); - // } + /// Test KNN with index + /// + /// Query: nearest(vec, [...], 10) + filter(i > 10 and i < 20) + /// + /// Expected plan: + /// KNNIndex(vec) -> Take(i) -> filter(i) -> take(s, vec) -> projection(s, vec, score) + #[tokio::test] + async fn test_ann_with_index() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = create_vector_dataset(test_uri, true).await; + + let mut scan = dataset.scan(); + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + scan.nearest("vec", &key, 10).unwrap(); + scan.project(&["s"]).unwrap(); + scan.filter("i > 10 and i < 20").unwrap(); + + let plan = scan.create_plan().await.unwrap(); + + assert!(plan.as_any().is::()); + assert_eq!( + plan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["s", "vec", "score"] + ); + + let take = &plan.children()[0]; + let take = take.as_any().downcast_ref::().unwrap(); + assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec", "i", "s"]); + assert_eq!( + take.extra_schema + .fields + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + vec!["s"] + ); + + let filter = &take.children()[0]; + assert!(filter.as_any().is::()); + assert_eq!(get_exec_columns(filter.as_ref()), ["score", "_rowid", "vec", "i"]); + + let take = &filter.children()[0]; + let take = take.as_any().downcast_ref::().unwrap(); + assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec", "i"]); + assert_eq!( + take.extra_schema + .fields + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + vec!["i"] + ); + + // TODO: Two continuous take execs, we can merge them into one. + let take = &take.children()[0]; + let take = take.as_any().downcast_ref::().unwrap(); + assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec"]); + assert_eq!( + take.extra_schema + .fields + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + vec!["vec"] + ); + + let knn = &take.children()[0]; + assert!(knn.as_any().is::()); + assert_eq!(get_exec_columns(knn.as_ref()), ["score", "_rowid"]); + } } diff --git a/rust/src/io/exec/take.rs b/rust/src/io/exec/take.rs index 6c3a2d24aed..a8a4d47b2dc 100644 --- a/rust/src/io/exec/take.rs +++ b/rust/src/io/exec/take.rs @@ -131,7 +131,9 @@ impl RecordBatchStream for Take { pub(crate) struct TakeExec { /// Dataset to read from. dataset: Arc, - extra_schema: Arc, + + pub(crate) extra_schema: Arc, + input: Arc, /// Output schema is the merged schema between input schema and extra schema. From b4db78662b759591f475325ca6e5143aeaafb981 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 11:23:18 -0700 Subject: [PATCH 15/20] add test for refine --- rust/src/dataset/scanner.rs | 81 +++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index e89a389e319..ba5d0d6eff1 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -1070,4 +1070,85 @@ mod test { assert!(knn.as_any().is::()); assert_eq!(get_exec_columns(knn.as_ref()), ["score", "_rowid"]); } + + /// Test KNN index with refine factor + /// + /// Query: nearest(vec, [...], 10, refine_factor=10) + filter(i > 10 and i < 20) + /// + /// Expected plan: + /// KNNIndex(vec) -> Take(vec) -> KNNFlat(vec, 10) -> Take(i) -> Filter(i) + /// -> take(s, vec) -> projection(s, vec, score) + #[tokio::test] + async fn test_knn_with_refine() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = create_dataset(test_uri, true).await; + + let mut scan = dataset.scan(); + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + scan.nearest("vec", &key, 10).unwrap(); + scan.refine(10); + scan.project(&["s"]).unwrap(); + scan.filter("i > 10 and i < 20").unwrap(); + + let plan = scan.create_plan().await.unwrap(); + + assert!(plan.as_any().is::()); + assert_eq!( + plan.schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + vec!["s", "vec", "score"] + ); + + let take = &plan.children()[0]; + let take = take.as_any().downcast_ref::().unwrap(); + assert_eq!(get_exec_columns(take), ["score", "_rowid", "i", "s", "vec"]); + assert_eq!( + take.extra_schema + .fields + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + vec!["s"] + ); + + let filter = &take.children()[0]; + assert!(filter.as_any().is::()); + assert_eq!(get_exec_columns(filter.as_ref()), ["score", "_rowid", "vec", "i"]); + + let take = &filter.children()[0]; + let take = take.as_any().downcast_ref::().unwrap(); + assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec", "i"]); + assert_eq!( + take.extra_schema + .fields + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + vec!["i"] + ); + + let flat = &take.children()[0]; + assert!(flat.as_any().is::()); + + // TODO: Two continuous take execs, we can merge them into one. + let take = &flat.children()[0]; + let take = take.as_any().downcast_ref::().unwrap(); + assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec"]); + assert_eq!( + take.extra_schema + .fields + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + vec!["vec"] + ); + + let knn = &take.children()[0]; + assert!(knn.as_any().is::()); + assert_eq!(get_exec_columns(knn.as_ref()), ["score", "_rowid"]); + } } From d39ad3d4b73af767bf1f292ebe243da3addb76e5 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 11:27:53 -0700 Subject: [PATCH 16/20] add comments --- rust/src/dataset/scanner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index ba5d0d6eff1..f012f0c03dd 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -1131,10 +1131,10 @@ mod test { vec!["i"] ); + // Flat refine step let flat = &take.children()[0]; assert!(flat.as_any().is::()); - // TODO: Two continuous take execs, we can merge them into one. let take = &flat.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec"]); From 2fb8a87ce8e79a12fcf0dd3243d04091dbc799b1 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 11:43:14 -0700 Subject: [PATCH 17/20] rebase to main --- rust/src/dataset/scanner.rs | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index f012f0c03dd..bfb47d9959c 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -24,8 +24,7 @@ use datafusion::execution::{ runtime_env::{RuntimeConfig, RuntimeEnv}, }; use datafusion::physical_plan::{ - filter::FilterExec, limit::GlobalLimitExec, ExecutionPlan, PhysicalExpr, - SendableRecordBatchStream, + filter::FilterExec, limit::GlobalLimitExec, ExecutionPlan, SendableRecordBatchStream, }; use datafusion::prelude::*; use futures::stream::{Stream, StreamExt}; @@ -35,7 +34,9 @@ use crate::datafusion::physical_expr::column_names_in_expr; use crate::datatypes::Schema; use crate::format::Index; use crate::index::vector::{MetricType, Query}; -use crate::io::exec::{KNNFlatExec, KNNIndexExec, LanceScanExec, ProjectionExec, TakeExec}; +use crate::io::exec::{ + KNNFlatExec, KNNIndexExec, LanceScanExec, Planner, ProjectionExec, TakeExec, +}; use crate::utils::sql::parse_sql_filter; use crate::{Error, Result}; @@ -373,11 +374,11 @@ impl Scanner { } } - let knn_node = self.ann(q, &index); // score, _rowid + let knn_node = self.ann(q, &index)?; // score, _rowid let with_vector = self.dataset.schema().project(&[&q.column])?; let knn_node_with_vector = self.take(knn_node, &with_vector)?; let knn_node = if q.refine_factor.is_some() { - self.flat_knn(knn_node_with_vector, q) + self.flat_knn(knn_node_with_vector, q)? } else { knn_node_with_vector }; // vector, score, _rowid @@ -387,7 +388,7 @@ impl Scanner { let vector_scan_projection = Arc::new(self.dataset.schema().project(&[&q.column]).unwrap()); let scan_node = self.scan(true, vector_scan_projection); - Ok(self.flat_knn(scan_node, q)) + Ok(self.flat_knn(scan_node, q)?) } } @@ -437,7 +438,7 @@ impl Scanner { *self.offset.as_ref().unwrap_or(&0) as usize, self.limit.map(|l| l as usize), )) - }= + } } /// ScannerStream is a container to wrap different types of ExecNode. @@ -473,8 +474,6 @@ mod test { use std::collections::BTreeSet; use std::path::PathBuf; - use super::*; - use arrow::array::as_primitive_array; use arrow::compute::concat_batches; use arrow::datatypes::Int32Type; @@ -485,12 +484,9 @@ mod test { use futures::TryStreamExt; use tempfile::tempdir; - use crate::index::vector::VectorIndexParams; - use crate::index::IndexType; use super::*; use crate::arrow::*; - use crate::index::IndexType; - use crate::index::vector::VectorIndexParams; + use crate::index::{vector::VectorIndexParams, IndexType}; use crate::{arrow::RecordBatchBuffer, dataset::WriteParams}; #[tokio::test] @@ -859,7 +855,7 @@ mod test { .copied() .collect(); assert_eq!(expected_i, actual_i); - + } fn get_exec_columns(plan: &dyn ExecutionPlan) -> Vec { plan.schema() @@ -1039,7 +1035,10 @@ mod test { let filter = &take.children()[0]; assert!(filter.as_any().is::()); - assert_eq!(get_exec_columns(filter.as_ref()), ["score", "_rowid", "vec", "i"]); + assert_eq!( + get_exec_columns(filter.as_ref()), + ["score", "_rowid", "vec", "i"] + ); let take = &filter.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); @@ -1082,7 +1081,7 @@ mod test { async fn test_knn_with_refine() { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); - let dataset = create_dataset(test_uri, true).await; + let dataset = create_vector_dataset(test_uri, true).await; let mut scan = dataset.scan(); let key: Float32Array = (32..64).map(|v| v as f32).collect(); @@ -1117,7 +1116,10 @@ mod test { let filter = &take.children()[0]; assert!(filter.as_any().is::()); - assert_eq!(get_exec_columns(filter.as_ref()), ["score", "_rowid", "vec", "i"]); + assert_eq!( + get_exec_columns(filter.as_ref()), + ["score", "_rowid", "vec", "i"] + ); let take = &filter.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); From 70e856df08267e8d05cdd58ad8175fc684c6f6e0 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 11:46:46 -0700 Subject: [PATCH 18/20] add test case for refine --- rust/src/dataset/scanner.rs | 2 +- rust/src/io/exec/knn.rs | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index bfb47d9959c..3d60ced9456 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -1104,7 +1104,7 @@ mod test { let take = &plan.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); - assert_eq!(get_exec_columns(take), ["score", "_rowid", "i", "s", "vec"]); + assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec", "i", "s"]); assert_eq!( take.extra_schema .fields diff --git a/rust/src/io/exec/knn.rs b/rust/src/io/exec/knn.rs index 2d1ac120dd8..aec58ebd2ae 100644 --- a/rust/src/io/exec/knn.rs +++ b/rust/src/io/exec/knn.rs @@ -152,7 +152,9 @@ impl ExecutionPlan for KNNFlatExec { fn schema(&self) -> arrow_schema::SchemaRef { let input_schema = self.input.schema(); let mut fields = input_schema.fields().to_vec(); - fields.push(Field::new(SCORE_COL, DataType::Float32, false)); + if !input_schema.field_with_name(SCORE_COL).is_ok() { + fields.push(Field::new(SCORE_COL, DataType::Float32, false)); + } Arc::new(Schema::new_with_metadata( fields, From f3009c176b029b630c5dfd21e5d0ebd4a1dae1bd Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 11:48:31 -0700 Subject: [PATCH 19/20] add format --- rust/src/dataset/scanner.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index 3d60ced9456..e7e3811a40d 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -215,6 +215,7 @@ impl Scanner { Ok(schema) } + /// The output schema of the Scanner, in Lance Schema format. fn output_schema(&self) -> Result> { let mut extra_columns = vec![]; From 905d34a5659f54bda81d1a8d687a3ac9abcc4174 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 24 Mar 2023 13:41:01 -0700 Subject: [PATCH 20/20] add ArrowSchema::field_names() --- rust/src/arrow/schema.rs | 6 +++ rust/src/dataset/scanner.rs | 98 ++++++++++--------------------------- 2 files changed, 32 insertions(+), 72 deletions(-) diff --git a/rust/src/arrow/schema.rs b/rust/src/arrow/schema.rs index e63a098df04..8243ed267df 100644 --- a/rust/src/arrow/schema.rs +++ b/rust/src/arrow/schema.rs @@ -20,6 +20,8 @@ use arrow_schema::{ArrowError, Field, Schema}; pub trait SchemaExt { /// Create a new [`Schema`] with one extra field. fn try_with_column(&self, field: Field) -> std::result::Result; + + fn field_names(&self) -> Vec<&String>; } impl SchemaExt for Schema { @@ -35,4 +37,8 @@ impl SchemaExt for Schema { fields.push(field); Ok(Schema::new_with_metadata(fields, self.metadata.clone())) } + + fn field_names(&self) -> Vec<&String> { + self.fields().iter().map(|f| f.name()).collect() + } } diff --git a/rust/src/dataset/scanner.rs b/rust/src/dataset/scanner.rs index e7e3811a40d..d1b2135c837 100644 --- a/rust/src/dataset/scanner.rs +++ b/rust/src/dataset/scanner.rs @@ -858,14 +858,6 @@ mod test { assert_eq!(expected_i, actual_i); } - fn get_exec_columns(plan: &dyn ExecutionPlan) -> Vec { - plan.schema() - .fields() - .iter() - .map(|f| f.name().to_string()) - .collect::>() - } - #[tokio::test] async fn test_simple_scan_plan() { let test_dir = tempdir().unwrap(); @@ -876,21 +868,21 @@ mod test { let plan = scan.create_plan().await.unwrap(); assert!(plan.as_any().is::()); - assert_eq!(get_exec_columns(plan.as_ref()), ["i", "s", "vec"]); + assert_eq!(plan.schema().field_names(), ["i", "s", "vec"]); let scan = &plan.children()[0]; assert!(scan.as_any().is::()); - assert_eq!(get_exec_columns(scan.as_ref()), ["i", "s", "vec"]); + assert_eq!(plan.schema().field_names(), ["i", "s", "vec"]); let mut scan = dataset.scan(); scan.project(&["s"]).unwrap(); let plan = scan.create_plan().await.unwrap(); assert!(plan.as_any().is::()); - assert_eq!(get_exec_columns(plan.as_ref()), ["s"]); + assert_eq!(plan.schema().field_names(), ["s"]); let scan = &plan.children()[0]; assert!(scan.as_any().is::()); - assert_eq!(get_exec_columns(scan.as_ref()), ["s"]); + assert_eq!(scan.schema().field_names(), ["s"]); } #[tokio::test] @@ -905,24 +897,10 @@ mod test { let plan = scan.create_plan().await.unwrap(); assert!(plan.as_any().is::()); - assert_eq!( - plan.schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["i", "_rowid"] - ); + assert_eq!(plan.schema().field_names(), &["i", "_rowid"]); let scan = &plan.children()[0]; assert!(scan.as_any().is::()); - assert_eq!( - scan.schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["i", "_rowid"] - ); + assert_eq!(scan.schema().field_names(), &["i", "_rowid"]); } /// Test scan with filter. @@ -947,49 +925,19 @@ mod test { let plan = scan.create_plan().await.unwrap(); assert!(plan.as_any().is::()); - assert_eq!( - plan.schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["s"] - ); + assert_eq!(plan.schema().field_names(), ["s"]); let take = &plan.children()[0]; assert!(take.as_any().is::()); - assert_eq!( - take.schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["i", "_rowid", "s"] - ); + assert_eq!(take.schema().field_names(), ["i", "_rowid", "s"]); let filter = &take.children()[0]; assert!(filter.as_any().is::()); - assert_eq!( - filter - .schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["i", "_rowid"] - ); + assert_eq!(filter.schema().field_names(), ["i", "_rowid"]); let scan = &filter.children()[0]; assert!(scan.as_any().is::()); - assert_eq!( - filter - .schema() - .fields() - .iter() - .map(|f| f.name()) - .collect::>(), - vec!["i", "_rowid"] - ); + assert_eq!(filter.schema().field_names(), ["i", "_rowid"]); } /// Test KNN with index @@ -1024,7 +972,10 @@ mod test { let take = &plan.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); - assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec", "i", "s"]); + assert_eq!( + take.schema().field_names(), + ["score", "_rowid", "vec", "i", "s"] + ); assert_eq!( take.extra_schema .fields @@ -1037,13 +988,13 @@ mod test { let filter = &take.children()[0]; assert!(filter.as_any().is::()); assert_eq!( - get_exec_columns(filter.as_ref()), + filter.schema().field_names(), ["score", "_rowid", "vec", "i"] ); let take = &filter.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); - assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec", "i"]); + assert_eq!(take.schema().field_names(), ["score", "_rowid", "vec", "i"]); assert_eq!( take.extra_schema .fields @@ -1056,7 +1007,7 @@ mod test { // TODO: Two continuous take execs, we can merge them into one. let take = &take.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); - assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec"]); + assert_eq!(take.schema().field_names(), ["score", "_rowid", "vec"]); assert_eq!( take.extra_schema .fields @@ -1068,7 +1019,7 @@ mod test { let knn = &take.children()[0]; assert!(knn.as_any().is::()); - assert_eq!(get_exec_columns(knn.as_ref()), ["score", "_rowid"]); + assert_eq!(knn.schema().field_names(), ["score", "_rowid"]); } /// Test KNN index with refine factor @@ -1105,7 +1056,10 @@ mod test { let take = &plan.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); - assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec", "i", "s"]); + assert_eq!( + take.schema().field_names(), + ["score", "_rowid", "vec", "i", "s"] + ); assert_eq!( take.extra_schema .fields @@ -1118,13 +1072,13 @@ mod test { let filter = &take.children()[0]; assert!(filter.as_any().is::()); assert_eq!( - get_exec_columns(filter.as_ref()), + filter.schema().field_names(), ["score", "_rowid", "vec", "i"] ); let take = &filter.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); - assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec", "i"]); + assert_eq!(take.schema().field_names(), ["score", "_rowid", "vec", "i"]); assert_eq!( take.extra_schema .fields @@ -1140,7 +1094,7 @@ mod test { let take = &flat.children()[0]; let take = take.as_any().downcast_ref::().unwrap(); - assert_eq!(get_exec_columns(take), ["score", "_rowid", "vec"]); + assert_eq!(take.schema().field_names(), ["score", "_rowid", "vec"]); assert_eq!( take.extra_schema .fields @@ -1152,6 +1106,6 @@ mod test { let knn = &take.children()[0]; assert!(knn.as_any().is::()); - assert_eq!(get_exec_columns(knn.as_ref()), ["score", "_rowid"]); + assert_eq!(knn.schema().field_names(), ["score", "_rowid"]); } }