From 3388b5268aa73436e090327fec828585600298de Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Mon, 19 May 2025 01:47:39 +0400 Subject: [PATCH] feat(cubesql): Push down `DATE_TRUNC` expressions as member expressions with granularity Signed-off-by: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> --- .../cubesql/src/compile/engine/df/wrapper.rs | 42 ++++++++++++++++++- rust/cubesql/cubesql/src/compile/mod.rs | 3 +- .../cubesql/src/compile/test/test_wrapper.rs | 4 +- 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index fd7eb01a0e428..d6c7446e081ce 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -8,7 +8,7 @@ use crate::{ extract_exprlist_from_groupping_set, rules::{ filters::Decimal, - utils::{DecomposedDayTime, DecomposedMonthDayNano}, + utils::{granularity_str_to_int_order, DecomposedDayTime, DecomposedMonthDayNano}, }, LikeType, WrappedSelectType, }, @@ -2520,6 +2520,46 @@ impl WrappedSelectNode { )) } Expr::ScalarFunction { fun, args } => { + if args.len() == 2 { + if let ( + BuiltinScalarFunction::DateTrunc, + Expr::Literal(ScalarValue::Utf8(Some(granularity))), + Expr::Column(column), + Some(PushToCubeContext { + ungrouped_scan_node, + known_join_subqueries, + }), + ) = (&fun, &args[0], &args[1], push_to_cube_context) + { + let granularity = granularity.to_ascii_lowercase(); + // Security check to prevent SQL injection + if granularity_str_to_int_order(&granularity, Some(false)).is_some() + && subqueries.get(&column.flat_name()).is_none() + && !column + .relation + .as_ref() + .map(|relation| known_join_subqueries.contains(relation)) + .unwrap_or(false) + { + if let Ok(MemberField::Member(regular_member)) = + Self::find_member_in_ungrouped_scan(ungrouped_scan_node, column) + { + // TODO: check if member is a time dimension + if let MemberField::Member(time_dimension_member) = + MemberField::time_dimension( + regular_member.member.clone(), + granularity, + ) + { + return Ok(( + format!("${{{}}}", time_dimension_member.field_name), + sql_query, + )); + } + } + } + } + } if let BuiltinScalarFunction::DatePart = &fun { if args.len() >= 2 { match &args[0] { diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 6b4e403ff56f1..f0131ac9bb29b 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -14572,8 +14572,7 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; - assert!(sql.contains("DATETIME_TRUNC(")); - assert!(sql.contains("WEEK(MONDAY)")); + assert!(sql.contains(".week")); } #[tokio::test] diff --git a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs index 3a1a68edd0470..6419939a1f90a 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs @@ -1657,11 +1657,11 @@ GROUP BY let dimensions = request.dimensions.unwrap(); assert_eq!(dimensions.len(), 1); let dimension = &dimensions[0]; - assert!(dimension.contains("DATE_TRUNC")); + assert!(dimension.contains(".day")); let segments = request.segments.unwrap(); assert_eq!(segments.len(), 1); let segment = &segments[0]; - assert!(segment.contains("DATE_TRUNC")); + assert!(segment.contains(".day")); } /// Aggregation with falsy filter should NOT get pushed to CubeScan with limit=0