diff --git a/datafusion/src/physical_plan/memory.rs b/datafusion/src/physical_plan/memory.rs index 85d8aeef073c1..c8ece999d4c91 100644 --- a/datafusion/src/physical_plan/memory.rs +++ b/datafusion/src/physical_plan/memory.rs @@ -35,6 +35,7 @@ use async_trait::async_trait; use futures::Stream; /// Execution plan for reading in-memory batches of data +#[derive(Clone)] pub struct MemoryExec { /// The partitions to query partitions: Vec<Vec<RecordBatch>>, @@ -76,12 +77,16 @@ impl ExecutionPlan for MemoryExec { fn with_new_children( &self, - _: Vec<Arc<dyn ExecutionPlan>>, + children: Vec<Arc<dyn ExecutionPlan>>, ) -> Result<Arc<dyn ExecutionPlan>> { - Err(DataFusionError::Internal(format!( - "Children cannot be replaced in {:?}", - self - ))) + if children.is_empty() { + Ok(Arc::new(self.clone())) + } else { + Err(DataFusionError::Internal(format!( + "Children cannot be replaced in {:?}", + self + ))) + } } async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> { diff --git a/datafusion/src/physical_plan/merge_sort.rs b/datafusion/src/physical_plan/merge_sort.rs index d28b12e3d5fc3..fa40bccb0b4a6 100644 --- a/datafusion/src/physical_plan/merge_sort.rs +++ b/datafusion/src/physical_plan/merge_sort.rs @@ -25,9 +25,11 @@ use std::task::{Context, Poll}; use futures::stream::{Fuse, Stream}; use futures::StreamExt; -use arrow::array::ArrayRef; +use arrow::array::{build_compare, ArrayRef, BooleanArray, DynComparator}; pub use arrow::compute::SortOptions; -use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions}; +use arrow::compute::{ + filter_record_batch, lexsort_to_indices, take, SortColumn, TakeOptions, +}; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -551,6 +553,209 @@ impl RecordBatchStream for MergeSortStream { } } +/// Filter out all but last row by unique key execution plan +#[derive(Debug)] +pub struct LastRowByUniqueKeyExec { + input: Arc<dyn ExecutionPlan>, + /// Unique key columns; only the last row for each key value is kept + pub unique_key: Vec<Column>, +} + +impl LastRowByUniqueKeyExec { + /// Create a new execution plan + pub fn 
try_new( + input: Arc<dyn ExecutionPlan>, + unique_key: Vec<Column>, + ) -> Result<Self> { + if unique_key.is_empty() { + return Err(DataFusionError::Internal( + "Empty unique_key passed for LastRowByUniqueKeyExec".to_string(), + )); + } + Ok(Self { input, unique_key }) + } + + /// Input execution plan + pub fn input(&self) -> &Arc<dyn ExecutionPlan> { + &self.input + } +} + +#[async_trait] +impl ExecutionPlan for LastRowByUniqueKeyExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { + vec![self.input.clone()] + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn with_new_children( + &self, + children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(LastRowByUniqueKeyExec::try_new( + children[0].clone(), + self.unique_key.clone(), + )?)) + } + + fn output_hints(&self) -> OptimizerHints { + OptimizerHints { + single_value_columns: self.input.output_hints().single_value_columns, + sort_order: self.input.output_hints().sort_order, + } + } + + async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> { + if 0 != partition { + return Err(DataFusionError::Internal(format!( + "LastRowByUniqueKeyExec invalid partition {}", + partition + ))); + } + + if self.input.output_partitioning().partition_count() != 1 { + return Err(DataFusionError::Internal(format!( + "LastRowByUniqueKeyExec expects only one partition but got {}", + self.input.output_partitioning().partition_count() + ))); + } + let input_stream = self.input.execute(0).await?; + + Ok(Box::pin(LastRowByUniqueKeyExecStream { + schema: self.input.schema(), + input: input_stream, + unique_key: self.unique_key.clone(), + current_record_batch: None, + })) + } +} + +/// Filter out all but last row by unique key stream +struct LastRowByUniqueKeyExecStream { + /// Output schema, which is the same as the input schema for this operator + schema: SchemaRef, + /// The input stream to filter. 
+ input: SendableRecordBatchStream, + /// Key columns + unique_key: Vec<Column>, + /// Current Record Batch + current_record_batch: Option<RecordBatch>, +} + +impl LastRowByUniqueKeyExecStream { + fn row_equals(comparators: &Vec<DynComparator>, a: usize, b: usize) -> bool { + for comparator in comparators.iter().rev() { + if comparator(a, b) != Ordering::Equal { + return false; + } + } + true + } + + fn keep_only_last_rows_by_key( + &mut self, + next_batch: Option<RecordBatch>, + ) -> ArrowResult<RecordBatch> { + let batch = self.current_record_batch.take().unwrap(); + let num_rows = batch.num_rows(); + let mut builder = BooleanArray::builder(num_rows); + let key_columns = self + .unique_key + .iter() + .map(|k| batch.column(k.index()).clone()) + .collect::<Vec<_>>(); + let mut requires_filtering = false; + let self_column_comparators = key_columns + .iter() + .map(|c| build_compare(c.as_ref(), c.as_ref())) + .collect::<ArrowResult<Vec<DynComparator>>>()?; + for i in 0..num_rows { + let filter_value = if i == num_rows - 1 && next_batch.is_none() { + true + } else if i == num_rows - 1 { + let next_key_columns = self + .unique_key + .iter() + .map(|k| next_batch.as_ref().unwrap().column(k.index()).clone()) + .collect::<Vec<_>>(); + let next_column_comparators = key_columns + .iter() + .zip(next_key_columns.iter()) + .map(|(c, n)| build_compare(c.as_ref(), n.as_ref())) + .collect::<ArrowResult<Vec<DynComparator>>>()?; + !Self::row_equals(&next_column_comparators, i, 0) + } else { + !Self::row_equals(&self_column_comparators, i, i + 1) + }; + if !filter_value { + requires_filtering = true; + } + builder.append_value(filter_value)?; + } + self.current_record_batch = next_batch; + if requires_filtering { + let filter_array = builder.finish(); + filter_record_batch(&batch, &filter_array) + } else { + Ok(batch) + } + } +} + +impl Stream for LastRowByUniqueKeyExecStream { + type Item = ArrowResult<RecordBatch>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + self.input.poll_next_unpin(cx).map(|x| { + match x { + Some(Ok(batch)) => { + if self.current_record_batch.is_none() { + let 
schema = batch.schema(); + self.current_record_batch = Some(batch); + // TODO get rid of empty batch. Returning Poll::Pending here results in stuck stream. + Some(Ok(RecordBatch::new_empty(schema))) + } else { + Some(self.keep_only_last_rows_by_key(Some(batch))) + } + } + None => { + if self.current_record_batch.is_some() { + Some(self.keep_only_last_rows_by_key(None)) + } else { + None + } + } + other => other, + } + }) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + let (lower, upper) = self.input.size_hint(); + (lower, upper.map(|u| u + 1)) + } +} + +impl RecordBatchStream for LastRowByUniqueKeyExecStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + #[cfg(test)] mod tests { use super::*; @@ -1059,6 +1264,32 @@ mod tests { ) } + #[tokio::test] + async fn last_row_by_unique_key_exec() { + let p1 = vec![ + ints(vec![1, 1, 2, 3, 4, 5, 5, 6, 7]), + ints(vec![8, 9, 9, 10]), + ints(vec![11, 12, 13]), + ]; + + let schema = ints_schema(); + let inp = Arc::new(MemoryExec::try_new(&vec![p1], schema.clone(), None).unwrap()); + let r = collect(Arc::new( + LastRowByUniqueKeyExec::try_new(inp, vec![col("a", &schema)]).unwrap(), + )) + .await + .unwrap(); + assert_eq!( + to_ints(r), + vec![ + vec![], + vec![1, 2, 3, 4, 5, 6, 7], + vec![8, 9, 10], + vec![11, 12, 13] + ] + ); + } + fn test_merge(arrays: Vec<&ArrayRef>) -> ArrayRef { let schema = Arc::new(Schema::new(vec![Field::new( "a", diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 0c61721361a06..1a5a1b5866473 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -41,7 +41,9 @@ use crate::physical_plan::hash_join::HashJoinExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::merge::MergeExec; use crate::physical_plan::merge_join::MergeJoinExec; -use crate::physical_plan::merge_sort::{MergeReSortExec, MergeSortExec}; +use crate::physical_plan::merge_sort::{ + 
LastRowByUniqueKeyExec, MergeReSortExec, MergeSortExec, }; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::skip::SkipExec; @@ -947,6 +949,10 @@ impl DefaultPhysicalPlanner { Some(node.clone()) } else if let Some(aliased) = node.as_any().downcast_ref::() { self.merge_sort_node(aliased.children()[0].clone()) + } else if let Some(aliased) = + node.as_any().downcast_ref::<LastRowByUniqueKeyExec>() + { + self.merge_sort_node(aliased.children()[0].clone()) } else if let Some(aliased) = node.as_any().downcast_ref::() { // TODO self.merge_sort_node(aliased.children()[0].clone())