Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 110 additions & 3 deletions datafusion/core/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -698,11 +698,11 @@ impl LogicalPlanBuilder {

let mut missing_exprs = Vec::with_capacity(missing_aggr_exprs.len());
for missing_aggr_expr in missing_aggr_exprs {
let expr_name = missing_aggr_expr.name(input_schema)?;
alias_map.insert(expr_name.clone(), expr_name);
if aggr_expr.contains(missing_aggr_expr) {
continue;
}
let expr_name = missing_aggr_expr.name(input_schema)?;
alias_map.insert(expr_name.clone(), expr_name);
missing_exprs.push(missing_aggr_expr.clone());
}

Expand All @@ -725,6 +725,31 @@ impl LogicalPlanBuilder {
schema: Arc::new(new_schema),
}))
}
LogicalPlan::Union(Union {
inputs,
schema: _,
alias,
}) => {
let inputs = inputs
.into_iter()
.map(|input_plan| {
self.add_missing_aggr_exprs(
input_plan,
missing_aggr_exprs,
alias_map,
)
})
.collect::<Result<Vec<_>>>()?;
let Some(first_input) = inputs.first() else {
return Err(DataFusionError::Internal("Inputs in union are empty".to_string()));
};
let schema = Arc::clone(first_input.schema());
Ok(LogicalPlan::Union(Union {
inputs,
schema,
alias,
}))
}
_ => {
let new_inputs = curr_plan
.inputs()
Expand Down Expand Up @@ -1482,7 +1507,7 @@ pub(crate) fn table_udfs(plan: LogicalPlan, udtf_expr: Vec<Expr>) -> Result<Logi

#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Field};
use arrow::datatypes::{DataType, Field, TimeUnit};

use crate::logical_plan::StringifiedPlan;

Expand Down Expand Up @@ -1725,6 +1750,76 @@ mod tests {
}
}

#[test]
fn plan_builder_order_by_missing_aggr() -> Result<()> {
let builder = LogicalPlanBuilder::scan_empty(Some("Ecom"), &ecom_schema(), None)?
.filter(col("Ecom.status").eq(lit("completed")))?;

let first_plan = builder
.aggregate(
[col("Ecom.created"), col("Ecom.status")],
[sum(col("Ecom.sumPrice"))],
)?
.filter(col("SUM(Ecom.sumPrice)").is_not_null())?
.project([
col("Ecom.created").alias("Ecom[created]"),
col("Ecom.status").alias("Ecom[status]"),
lit(false).alias("[IsGrandTotalRowTotal]"),
col("SUM(Ecom.sumPrice)").alias("[count]"),
])?
.build()?;

let second_plan = builder
.aggregate(Vec::<Expr>::new(), [sum(col("Ecom.sumPrice"))])?
.filter(col("SUM(Ecom.sumPrice)").is_not_null())?
.project([
Expr::Literal(ScalarValue::Null).alias("Ecom[created]"),
Expr::Literal(ScalarValue::Null).alias("Ecom[status]"),
lit(true).alias("[IsGrandTotalRowTotal]"),
col("SUM(Ecom.sumPrice)").alias("[count]"),
])?
.build()?;

let plan_before_sort = LogicalPlanBuilder::from(first_plan)
.union(second_plan)?
.sort([
col("[IsGrandTotalRowTotal]").sort(false, true),
col("Ecom[created]").sort(true, false),
col("Ecom[status]").sort(true, false),
])?
.limit(None, Some(502))?
.build()?;

let plan_with_sort = LogicalPlanBuilder::from(plan_before_sort)
.sort([
col("[IsGrandTotalRowTotal]").sort(false, true),
sum(col("Ecom.sumPrice")).sort(true, false),
col("Ecom[status]").sort(true, false),
])?
.build()?;

let expected = "\
Projection: #Ecom[created], #Ecom[status], #[IsGrandTotalRowTotal], #[count]\
\n Sort: #[IsGrandTotalRowTotal] DESC NULLS FIRST, #SUM(Ecom.sumPrice) ASC NULLS LAST, #Ecom[status] ASC NULLS LAST\
\n Limit: skip=None, fetch=502\
\n Sort: #[IsGrandTotalRowTotal] DESC NULLS FIRST, #Ecom[created] ASC NULLS LAST, #Ecom[status] ASC NULLS LAST\
\n Union\
\n Projection: #Ecom.created AS Ecom[created], #Ecom.status AS Ecom[status], Boolean(false) AS [IsGrandTotalRowTotal], #SUM(Ecom.sumPrice) AS [count], #SUM(Ecom.sumPrice)\
\n Filter: #SUM(Ecom.sumPrice) IS NOT NULL\
\n Aggregate: groupBy=[[#Ecom.created, #Ecom.status]], aggr=[[SUM(#Ecom.sumPrice)]]\
\n Filter: #Ecom.status = Utf8(\"completed\")\
\n TableScan: Ecom projection=None\
\n Projection: CAST(NULL AS Timestamp(Nanosecond, None)) AS Ecom[created], CAST(NULL AS Utf8) AS Ecom[status], Boolean(true) AS [IsGrandTotalRowTotal], #SUM(Ecom.sumPrice) AS [count], #SUM(Ecom.sumPrice)\
\n Filter: #SUM(Ecom.sumPrice) IS NOT NULL\
\n Aggregate: groupBy=[[]], aggr=[[SUM(#Ecom.sumPrice)]]\
\n Filter: #Ecom.status = Utf8(\"completed\")\
\n TableScan: Ecom projection=None";

assert_eq!(expected, format!("{:?}", plan_with_sort));

Ok(())
}

fn employee_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Expand All @@ -1735,6 +1830,18 @@ mod tests {
])
}

fn ecom_schema() -> Schema {
Schema::new(vec![
Field::new(
"created",
DataType::Timestamp(TimeUnit::Nanosecond, None),
true,
),
Field::new("status", DataType::Utf8, true),
Field::new("sumPrice", DataType::Float64, true),
])
}

#[test]
fn stringified_plan() {
let stringified_plan =
Expand Down
Loading