Skip to content

Commit

Permalink
feat(cubesql): SQL push down support for window functions (#7403)
Browse files Browse the repository at this point in the history
  • Loading branch information
paveltiunov committed Nov 11, 2023
1 parent b3ea6ab commit b1da6c0
Show file tree
Hide file tree
Showing 18 changed files with 607 additions and 37 deletions.
1 change: 1 addition & 0 deletions packages/cubejs-schema-compiler/src/adapter/BaseQuery.js
Expand Up @@ -2479,6 +2479,7 @@ class BaseQuery {
binary: '({{ left }} {{ op }} {{ right }})',
sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}',
cast: 'CAST({{ expr }} AS {{ data_type }})',
window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})'
},
quotes: {
identifiers: '"',
Expand Down
13 changes: 12 additions & 1 deletion rust/cubesql/cubesql/src/compile/engine/df/scan.rs
Expand Up @@ -140,6 +140,7 @@ pub struct WrappedSelectNode {
pub projection_expr: Vec<Expr>,
pub group_expr: Vec<Expr>,
pub aggr_expr: Vec<Expr>,
pub window_expr: Vec<Expr>,
pub from: Arc<LogicalPlan>,
pub joins: Vec<(Arc<LogicalPlan>, Expr, JoinType)>,
pub filter_expr: Vec<Expr>,
Expand All @@ -158,6 +159,7 @@ impl WrappedSelectNode {
projection_expr: Vec<Expr>,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
window_expr: Vec<Expr>,
from: Arc<LogicalPlan>,
joins: Vec<(Arc<LogicalPlan>, Expr, JoinType)>,
filter_expr: Vec<Expr>,
Expand All @@ -174,6 +176,7 @@ impl WrappedSelectNode {
projection_expr,
group_expr,
aggr_expr,
window_expr,
from,
joins,
filter_expr,
Expand Down Expand Up @@ -207,6 +210,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
exprs.extend(self.projection_expr.clone());
exprs.extend(self.group_expr.clone());
exprs.extend(self.aggr_expr.clone());
exprs.extend(self.window_expr.clone());
exprs.extend(self.joins.iter().map(|(_, expr, _)| expr.clone()));
exprs.extend(self.filter_expr.clone());
exprs.extend(self.having_expr.clone());
Expand All @@ -217,11 +221,12 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"WrappedSelect: select_type={:?}, projection_expr={:?}, group_expr={:?}, aggregate_expr={:?}, from={:?}, joins={:?}, filter_expr={:?}, having_expr={:?}, limit={:?}, offset={:?}, order_expr={:?}, alias={:?}",
"WrappedSelect: select_type={:?}, projection_expr={:?}, group_expr={:?}, aggregate_expr={:?}, window_expr={:?}, from={:?}, joins={:?}, filter_expr={:?}, having_expr={:?}, limit={:?}, offset={:?}, order_expr={:?}, alias={:?}",
self.select_type,
self.projection_expr,
self.group_expr,
self.aggr_expr,
self.window_expr,
self.from,
self.joins,
self.filter_expr,
Expand Down Expand Up @@ -261,6 +266,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
let mut projection_expr = vec![];
let mut group_expr = vec![];
let mut aggregate_expr = vec![];
let mut window_expr = vec![];
let limit = None;
let offset = None;
let alias = None;
Expand All @@ -278,6 +284,10 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
aggregate_expr.push(exprs_iter.next().unwrap().clone());
}

for _ in self.window_expr.iter() {
window_expr.push(exprs_iter.next().unwrap().clone());
}

for _ in self.joins.iter() {
joins_expr.push(exprs_iter.next().unwrap().clone());
}
Expand All @@ -300,6 +310,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
projection_expr,
group_expr,
aggregate_expr,
window_expr,
from,
joins
.into_iter()
Expand Down
96 changes: 95 additions & 1 deletion rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Expand Up @@ -297,6 +297,7 @@ impl CubeScanWrapperNode {
projection_expr,
group_expr,
aggr_expr,
window_expr,
from,
joins: _joins,
filter_expr: _filter_expr,
Expand Down Expand Up @@ -431,6 +432,20 @@ impl CubeScanWrapperNode {
ungrouped_scan_node.clone(),
)
.await?;

let (window, sql) = Self::generate_column_expr(
plan.clone(),
schema.clone(),
window_expr.clone(),
sql,
generator.clone(),
&column_remapping,
&mut next_remapping,
alias.clone(),
can_rename_columns,
ungrouped_scan_node.clone(),
)
.await?;
// Sort node always comes on top and pushed down to select so we need to replace columns here by appropriate column definitions
let order_replace_map = projection_expr
.iter()
Expand Down Expand Up @@ -504,6 +519,12 @@ impl CubeScanWrapperNode {
)
}),
)
.chain(window.iter().map(|m| {
Self::ungrouped_member_def(
m,
&ungrouped_scan_node.used_cubes,
)
}))
.collect::<Result<_>>()?,
);
load_request.dimensions = Some(
Expand Down Expand Up @@ -1333,7 +1354,80 @@ impl CubeScanWrapperNode {
sql_query,
))
}
// Expr::WindowFunction { .. } => {}
Expr::WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
} => {
let mut sql_args = Vec::new();
for arg in args {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
arg,
ungrouped_scan_node.clone(),
)
.await?;
sql_query = query;
sql_args.push(sql);
}
let mut sql_partition_by = Vec::new();
for arg in partition_by {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
arg,
ungrouped_scan_node.clone(),
)
.await?;
sql_query = query;
sql_partition_by.push(sql);
}
let mut sql_order_by = Vec::new();
for arg in order_by {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
arg,
ungrouped_scan_node.clone(),
)
.await?;
sql_query = query;
sql_order_by.push(
sql_generator
.get_sql_templates()
// TODO asc/desc
.sort_expr(sql, true, false)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for sort expr: {}",
e
))
})?,
);
}
let resulting_sql = sql_generator
.get_sql_templates()
.window_function_expr(
fun,
sql_args,
sql_partition_by,
sql_order_by,
window_frame,
)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for window function: {}",
e
))
})?;
Ok((resulting_sql, sql_query))
}
// Expr::AggregateUDF { .. } => {}
// Expr::InList { .. } => {}
// Expr::Wildcard => {}
Expand Down
77 changes: 65 additions & 12 deletions rust/cubesql/cubesql/src/compile/mod.rs
Expand Up @@ -18847,12 +18847,20 @@ ORDER BY \"COUNT(count)\" DESC"
.sql
.contains("CASE WHEN"));

assert!(logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("1123"));
assert!(
logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("1123"),
"SQL contains 1123: {}",
logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
);

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
Expand Down Expand Up @@ -18883,12 +18891,20 @@ ORDER BY \"COUNT(count)\" DESC"
.sql
.contains("CASE WHEN"));

assert!(logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("LIMIT 1123"));
assert!(
logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("1123"),
"SQL contains 1123: {}",
logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
);

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
Expand Down Expand Up @@ -19063,6 +19079,43 @@ ORDER BY \"COUNT(count)\" DESC"
.contains("EXTRACT"));
}

#[tokio::test]
async fn test_wrapper_window_function() {
    // Window-function push down only matters when SQL push down is on.
    if !Rewriter::sql_push_down_enabled() {
        return;
    }
    init_logger();

    // Aggregate query with a window function on top of an aggregate:
    // SUM(COUNT(count)) OVER () must be pushed down into the wrapped SQL.
    let sql = "SELECT customer_gender, AVG(avgPrice) mp, SUM(COUNT(count)) OVER() FROM KibanaSampleDataEcommerce a GROUP BY 1 LIMIT 100";
    let query_plan =
        convert_select_to_query_plan(sql.to_string(), DatabaseProtocol::PostgreSQL).await;

    // The generated wrapped SQL must contain an OVER clause for the window function.
    let logical_plan = query_plan.as_logical_plan();
    let wrapped_sql = logical_plan
        .find_cube_scan_wrapper()
        .wrapped_sql
        .unwrap()
        .sql;
    assert!(
        wrapped_sql.contains("OVER"),
        "SQL should contain 'OVER': {}",
        wrapped_sql
    );

    // Also ensure the plan survives physical planning.
    let physical_plan = query_plan.as_physical_plan().await.unwrap();
    println!(
        "Physical plan: {}",
        displayable(physical_plan.as_ref()).indent()
    );
}

#[tokio::test]
async fn test_thoughtspot_pg_date_trunc_year() {
init_logger();
Expand Down
50 changes: 37 additions & 13 deletions rust/cubesql/cubesql/src/compile/rewrite/converter.rs
Expand Up @@ -41,8 +41,8 @@ use datafusion::{
logical_plan::{
build_join_schema, build_table_udf_schema, exprlist_to_fields, normalize_cols,
plan::{Aggregate, Extension, Filter, Join, Projection, Sort, TableUDFs, Window},
CrossJoin, DFField, DFSchema, DFSchemaRef, Distinct, EmptyRelation, Expr, Like, Limit,
LogicalPlan, LogicalPlanBuilder, TableScan, Union,
replace_col_to_expr, CrossJoin, DFField, DFSchema, DFSchemaRef, Distinct, EmptyRelation,
Expr, Like, Limit, LogicalPlan, LogicalPlanBuilder, TableScan, Union,
},
physical_plan::planner::DefaultPhysicalPlanner,
scalar::ScalarValue,
Expand Down Expand Up @@ -1671,8 +1671,10 @@ impl LanguageToLogicalPlanConverter {
match_expr_list_node!(node_by_id, to_expr, params[2], WrappedSelectGroupExpr);
let aggr_expr =
match_expr_list_node!(node_by_id, to_expr, params[3], WrappedSelectAggrExpr);
let from = Arc::new(self.to_logical_plan(params[4])?);
let joins = match_list_node!(node_by_id, params[5], WrappedSelectJoins)
let window_expr =
match_expr_list_node!(node_by_id, to_expr, params[4], WrappedSelectWindowExpr);
let from = Arc::new(self.to_logical_plan(params[5])?);
let joins = match_list_node!(node_by_id, params[6], WrappedSelectJoins)
.into_iter()
.map(|j| {
if let LogicalPlanLanguage::WrappedSelectJoin(params) = j {
Expand All @@ -1688,28 +1690,49 @@ impl LanguageToLogicalPlanConverter {
.collect::<Result<Vec<_>, _>>()?;

let filter_expr =
match_expr_list_node!(node_by_id, to_expr, params[6], WrappedSelectFilterExpr);
match_expr_list_node!(node_by_id, to_expr, params[7], WrappedSelectFilterExpr);
let having_expr =
match_expr_list_node!(node_by_id, to_expr, params[7], WrappedSelectHavingExpr);
let limit = match_data_node!(node_by_id, params[8], WrappedSelectLimit);
let offset = match_data_node!(node_by_id, params[9], WrappedSelectOffset);
match_expr_list_node!(node_by_id, to_expr, params[8], WrappedSelectHavingExpr);
let limit = match_data_node!(node_by_id, params[9], WrappedSelectLimit);
let offset = match_data_node!(node_by_id, params[10], WrappedSelectOffset);
let order_expr =
match_expr_list_node!(node_by_id, to_expr, params[10], WrappedSelectOrderExpr);
let alias = match_data_node!(node_by_id, params[11], WrappedSelectAlias);
let ungrouped = match_data_node!(node_by_id, params[12], WrappedSelectUngrouped);
match_expr_list_node!(node_by_id, to_expr, params[11], WrappedSelectOrderExpr);
let alias = match_data_node!(node_by_id, params[12], WrappedSelectAlias);
let ungrouped = match_data_node!(node_by_id, params[13], WrappedSelectUngrouped);

let group_expr = normalize_cols(group_expr, &from)?;
let aggr_expr = normalize_cols(aggr_expr, &from)?;
let projection_expr = normalize_cols(projection_expr, &from)?;
let all_expr = match select_type {
let all_expr_without_window = match select_type {
WrappedSelectType::Projection => projection_expr.clone(),
WrappedSelectType::Aggregate => {
group_expr.iter().chain(aggr_expr.iter()).cloned().collect()
}
};
let without_window_fields =
exprlist_to_fields(all_expr_without_window.iter(), from.schema())?;
let replace_map = all_expr_without_window
.iter()
.zip(without_window_fields.iter())
.map(|(e, f)| (f.qualified_column(), e.clone()))
.collect::<Vec<_>>();
let replace_map = replace_map
.iter()
.map(|(c, e)| (c, e))
.collect::<HashMap<_, _>>();
let window_expr_rebased = window_expr
.iter()
.map(|e| replace_col_to_expr(e.clone(), &replace_map))
.collect::<Result<Vec<_>, _>>()?;
let schema = DFSchema::new_with_metadata(
// TODO support joins schema
exprlist_to_fields(all_expr.iter(), from.schema())?,
without_window_fields
.into_iter()
.chain(
exprlist_to_fields(window_expr_rebased.iter(), from.schema())?
.into_iter(),
)
.collect(),
HashMap::new(),
)?;

Expand All @@ -1725,6 +1748,7 @@ impl LanguageToLogicalPlanConverter {
projection_expr,
group_expr,
aggr_expr,
window_expr_rebased,
from,
joins,
filter_expr,
Expand Down

0 comments on commit b1da6c0

Please sign in to comment.