diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index 72787e192763..222cb553cf87 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -698,11 +698,11 @@ impl LogicalPlanBuilder { let mut missing_exprs = Vec::with_capacity(missing_aggr_exprs.len()); for missing_aggr_expr in missing_aggr_exprs { + let expr_name = missing_aggr_expr.name(input_schema)?; + alias_map.insert(expr_name.clone(), expr_name); if aggr_expr.contains(missing_aggr_expr) { continue; } - let expr_name = missing_aggr_expr.name(input_schema)?; - alias_map.insert(expr_name.clone(), expr_name); missing_exprs.push(missing_aggr_expr.clone()); } @@ -725,6 +725,31 @@ impl LogicalPlanBuilder { schema: Arc::new(new_schema), })) } + LogicalPlan::Union(Union { + inputs, + schema: _, + alias, + }) => { + let inputs = inputs + .into_iter() + .map(|input_plan| { + self.add_missing_aggr_exprs( + input_plan, + missing_aggr_exprs, + alias_map, + ) + }) + .collect::>>()?; + let Some(first_input) = inputs.first() else { + return Err(DataFusionError::Internal("Inputs in union are empty".to_string())); + }; + let schema = Arc::clone(first_input.schema()); + Ok(LogicalPlan::Union(Union { + inputs, + schema, + alias, + })) + } _ => { let new_inputs = curr_plan .inputs() @@ -1482,7 +1507,7 @@ pub(crate) fn table_udfs(plan: LogicalPlan, udtf_expr: Vec) -> Result Result<()> { + let builder = LogicalPlanBuilder::scan_empty(Some("Ecom"), &ecom_schema(), None)? + .filter(col("Ecom.status").eq(lit("completed")))?; + + let first_plan = builder + .aggregate( + [col("Ecom.created"), col("Ecom.status")], + [sum(col("Ecom.sumPrice"))], + )? + .filter(col("SUM(Ecom.sumPrice)").is_not_null())? + .project([ + col("Ecom.created").alias("Ecom[created]"), + col("Ecom.status").alias("Ecom[status]"), + lit(false).alias("[IsGrandTotalRowTotal]"), + col("SUM(Ecom.sumPrice)").alias("[count]"), + ])? + .build()?; + + let second_plan = builder + .aggregate(Vec::::new(), [sum(col("Ecom.sumPrice"))])? + .filter(col("SUM(Ecom.sumPrice)").is_not_null())? + .project([ + Expr::Literal(ScalarValue::Null).alias("Ecom[created]"), + Expr::Literal(ScalarValue::Null).alias("Ecom[status]"), + lit(true).alias("[IsGrandTotalRowTotal]"), + col("SUM(Ecom.sumPrice)").alias("[count]"), + ])? + .build()?; + + let plan_before_sort = LogicalPlanBuilder::from(first_plan) + .union(second_plan)? + .sort([ + col("[IsGrandTotalRowTotal]").sort(false, true), + col("Ecom[created]").sort(true, false), + col("Ecom[status]").sort(true, false), + ])? + .limit(None, Some(502))? + .build()?; + + let plan_with_sort = LogicalPlanBuilder::from(plan_before_sort) + .sort([ + col("[IsGrandTotalRowTotal]").sort(false, true), + sum(col("Ecom.sumPrice")).sort(true, false), + col("Ecom[status]").sort(true, false), + ])? + .build()?; + + let expected = "\ + Projection: #Ecom[created], #Ecom[status], #[IsGrandTotalRowTotal], #[count]\ + \n Sort: #[IsGrandTotalRowTotal] DESC NULLS FIRST, #SUM(Ecom.sumPrice) ASC NULLS LAST, #Ecom[status] ASC NULLS LAST\ + \n Limit: skip=None, fetch=502\ + \n Sort: #[IsGrandTotalRowTotal] DESC NULLS FIRST, #Ecom[created] ASC NULLS LAST, #Ecom[status] ASC NULLS LAST\ + \n Union\ + \n Projection: #Ecom.created AS Ecom[created], #Ecom.status AS Ecom[status], Boolean(false) AS [IsGrandTotalRowTotal], #SUM(Ecom.sumPrice) AS [count], #SUM(Ecom.sumPrice)\ + \n Filter: #SUM(Ecom.sumPrice) IS NOT NULL\ + \n Aggregate: groupBy=[[#Ecom.created, #Ecom.status]], aggr=[[SUM(#Ecom.sumPrice)]]\ + \n Filter: #Ecom.status = Utf8(\"completed\")\ + \n TableScan: Ecom projection=None\ + \n Projection: CAST(NULL AS Timestamp(Nanosecond, None)) AS Ecom[created], CAST(NULL AS Utf8) AS Ecom[status], Boolean(true) AS [IsGrandTotalRowTotal], #SUM(Ecom.sumPrice) AS [count], #SUM(Ecom.sumPrice)\ + \n Filter: #SUM(Ecom.sumPrice) IS NOT NULL\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(#Ecom.sumPrice)]]\ + \n Filter: #Ecom.status = Utf8(\"completed\")\ + \n TableScan: Ecom projection=None"; + + assert_eq!(expected, format!("{:?}", plan_with_sort)); + + Ok(()) + } + fn employee_schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -1735,6 +1830,18 @@ mod tests { ]) } + fn ecom_schema() -> Schema { + Schema::new(vec![ + Field::new( + "created", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new("status", DataType::Utf8, true), + Field::new("sumPrice", DataType::Float64, true), + ]) + } + #[test] fn stringified_plan() { let stringified_plan =