feat(cubesql): Metabase - filters with relative dates support (#4851)
gandronchik committed Jul 12, 2022
1 parent 3fc5a5c commit 423be2f
Showing 7 changed files with 340 additions and 60 deletions.
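
For context, this commit lets cubesql constant-fold the relative-date SQL that Metabase generates down to a plain Cube date range. A hedged sketch of the query shape this enables, adapted from the metabase_date_filters test added below (now() stands in for the fixed str_to_date(...) timestamp the test pins; this snippet is illustrative and not part of the diff):

// Sketch only: the "last 30 days" filter Metabase emits, per the test cases below.
const LAST_30_DAYS_SQL: &str = r#"
SELECT count FROM (
    SELECT count FROM KibanaSampleDataEcommerce
    WHERE order_date >= CAST((now() + (INTERVAL '-30 day')) AS date)
      AND order_date < CAST(now() AS date)
) source
"#;

fn main() {
    // With "now" pinned to 2022-01-01, the rewriter folds this predicate into a
    // time dimension with date_range ["2021-12-02T00:00:00.000Z", "2021-12-31T23:59:59.999Z"].
    println!("{}", LAST_30_DAYS_SQL);
}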
61 changes: 44 additions & 17 deletions rust/cubesql/cubesql/src/compile/engine/udf.rs
@@ -5,16 +5,15 @@ use datafusion::{
arrow::{
array::{
new_null_array, Array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder,
Float64Array, GenericStringArray, Int64Array, Int64Builder, IntervalDayTimeArray,
IntervalDayTimeBuilder, ListArray, ListBuilder, PrimitiveArray, PrimitiveBuilder,
StringArray, StringBuilder, StructBuilder, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
UInt32Builder,
Float64Array, GenericStringArray, Int64Array, Int64Builder, IntervalDayTimeBuilder,
ListArray, ListBuilder, PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder,
StructBuilder, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt32Builder,
},
compute::{cast, concat},
datatypes::{
DataType, Field, Float64Type, Int32Type, Int64Type, IntervalDayTimeType, IntervalUnit,
TimeUnit, TimestampNanosecondType, UInt32Type, UInt64Type,
IntervalYearMonthType, TimeUnit, TimestampNanosecondType, UInt32Type, UInt64Type,
},
},
error::{DataFusionError, Result},
@@ -958,22 +957,30 @@ pub fn create_date_sub_udf() -> ScalarUDF {
)
}

pub fn create_date_add_udf() -> ScalarUDF {
let fun = make_scalar_function(move |args: &[ArrayRef]| {
let timestamps = args[0]
macro_rules! date_add_udf {
($ARGS:expr, $TYPE: ident) => {{
let timestamps = &$ARGS[0]
.as_any()
.downcast_ref::<TimestampNanosecondArray>()
.unwrap();
let intervals = args[1]
.as_any()
.downcast_ref::<IntervalDayTimeArray>()
.unwrap();
let intervals = downcast_primitive_arg!(&$ARGS[1], "interval", $TYPE);
let mut builder = TimestampNanosecondArray::builder(timestamps.len());
for i in 0..timestamps.len() {
let timestamp = timestamps.value(i);
let interval = intervals.value(i);
let interval_days = interval >> 32;
let interval_millis = interval & 0xffffffff;
let (interval_days, interval_millis) = match &$ARGS[1].data_type() {
DataType::Interval(IntervalUnit::DayTime) => {
let interval: i64 = intervals.value(i).into();
(interval >> 32, interval & 0xffffffff)
}
DataType::Interval(IntervalUnit::YearMonth) => {
((intervals.value(i) * 30) as i64, 0_i64)
}
_ => {
return Err(DataFusionError::Execution(format!(
"unsupported interval type"
)))
}
};
let timestamp = NaiveDateTime::from_timestamp(
timestamp / 1000000000,
(timestamp % 1000000000) as u32,
@@ -986,7 +993,19 @@ pub fn create_date_add_udf() -> ScalarUDF {
.unwrap();
builder.append_value(timestamp.timestamp_nanos())?;
}
Ok(Arc::new(builder.finish()))
return Ok(Arc::new(builder.finish()));
}};
}

pub fn create_date_add_udf() -> ScalarUDF {
let fun = make_scalar_function(move |args: &[ArrayRef]| match &args[1].data_type() {
DataType::Interval(IntervalUnit::DayTime) => date_add_udf!(args, IntervalDayTimeType),
DataType::Interval(IntervalUnit::YearMonth) => {
date_add_udf!(args, IntervalYearMonthType)
}
_ => Err(DataFusionError::Execution(format!(
"unsupported interval type"
))),
});

let return_type: ReturnTypeFunction =
@@ -1004,6 +1023,14 @@ pub fn create_date_add_udf() -> ScalarUDF {
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_string())),
DataType::Interval(IntervalUnit::DayTime),
]),
TypeSignature::Exact(vec![
DataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::Interval(IntervalUnit::YearMonth),
]),
TypeSignature::Exact(vec![
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_string())),
DataType::Interval(IntervalUnit::YearMonth),
]),
],
Volatility::Immutable,
),
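
The date_add_udf! macro above leans on DataFusion's two interval encodings. A minimal standalone sketch of that unpacking, under the same arithmetic as the macro (the helper names are illustrative, not from the diff):

// IntervalDayTime packs days into the high 32 bits and milliseconds into the
// low 32 bits of an i64; IntervalYearMonth is a bare month count that the
// macro approximates as 30 days per month.
fn unpack_day_time(interval: i64) -> (i64, i64) {
    (interval >> 32, interval & 0xffffffff)
}

fn year_month_as_days(months: i32) -> (i64, i64) {
    ((months * 30) as i64, 0)
}

fn main() {
    let packed = (3_i64 << 32) | 500; // 3 days + 500 ms
    assert_eq!(unpack_day_time(packed), (3, 500));
    assert_eq!(year_month_as_days(12), (360, 0)); // one "year" ~= 360 days here
}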
76 changes: 75 additions & 1 deletion rust/cubesql/cubesql/src/compile/mod.rs
@@ -3767,7 +3767,7 @@ mod tests {

assert_eq!(
logical_plan,
"Projection: CAST(TimestampNanosecond(0, None) AS Timestamp(Nanosecond, None)) AS COL\
"Projection: TimestampNanosecond(0, None) AS COL\
\n EmptyRelation",
);
}
@@ -9857,4 +9857,78 @@ ORDER BY \"COUNT(count)\" DESC"
}
)
}

#[tokio::test]
async fn metabase_date_filters() {
init_logger();

let now = "str_to_date('2022-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')";
let cases = vec![
// last 30 days
[
format!("CAST(({} + (INTERVAL '-30 day')) AS date)", now),
format!("CAST({} AS date)", now),
"2021-12-02T00:00:00.000Z".to_string(),
"2021-12-31T23:59:59.999Z".to_string(),
],
// last 30 weeks
[
format!("(CAST(date_trunc('week', (({} + (INTERVAL '-30 week')) + (INTERVAL '1 day'))) AS timestamp) + (INTERVAL '-1 day'))", now),
format!("(CAST(date_trunc('week', ({} + (INTERVAL '1 day'))) AS timestamp) + (INTERVAL '-1 day'))", now),
"2021-05-30T00:00:00.000Z".to_string(),
"2021-12-25T23:59:59.999Z".to_string(),
],
// last 30 quarters
[
format!("date_trunc('quarter', ({} + (INTERVAL '-90 month')))", now),
format!("date_trunc('quarter', {})", now),
"2014-07-01T00:00:00.000Z".to_string(),
"2021-12-31T23:59:59.999Z".to_string(),
],
// this year
[
format!("date_trunc('year', {})", now),
format!("date_trunc('year', ({} + (INTERVAL '1 year')))", now),
"2022-01-01T00:00:00.000Z".to_string(),
"2021-12-31T23:59:59.999Z".to_string(),
],
// next 2 years including current
[
format!("date_trunc('year', {})", now),
format!("date_trunc('year', ({} + (INTERVAL '3 year')))", now),
"2022-01-01T00:00:00.000Z".to_string(),
"2023-12-31T23:59:59.999Z".to_string(),
],
];
for [lte, gt, from, to] in cases {
let logical_plan = convert_select_to_query_plan(
format!(
"SELECT count FROM (SELECT count FROM KibanaSampleDataEcommerce
WHERE (order_date >= {} AND order_date < {})) source",
lte, gt
),
DatabaseProtocol::PostgreSQL,
)
.await
.as_logical_plan();

assert_eq!(
logical_plan.find_cube_scan().request,
V1LoadRequestQuery {
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]),
dimensions: Some(vec![]),
segments: Some(vec![]),
time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension {
dimension: "KibanaSampleDataEcommerce.order_date".to_string(),
granularity: None,
date_range: Some(json!(vec![from, to])),
}]),
order: None,
limit: None,
offset: None,
filters: None
}
)
}
}
}
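
A note on the year-based expectations above: INTERVAL '1 year' and '3 year' are YearMonth intervals, so DATE_ADD applies the 30-days-per-month approximation from udf.rs. That is why "this year" ends at 2021-12-31 rather than 2022-12-31. A small chrono sketch of the arithmetic (chrono is assumed available; it is already a cubesql dependency):

use chrono::{Datelike, Duration, NaiveDate};

fn main() {
    let now = NaiveDate::from_ymd(2022, 1, 1);
    // INTERVAL '1 year' becomes 12 * 30 days under the approximation:
    let plus_one_year = now + Duration::days(12 * 30);
    assert_eq!(plus_one_year.to_string(), "2022-12-27");
    // date_trunc('year', ...) pulls that back to 2022-01-01, and the strict
    // `<` bound then folds to the end of the previous day: 2021-12-31T23:59:59.999Z.
    let upper_exclusive = NaiveDate::from_ymd(plus_one_year.year(), 1, 1);
    assert_eq!(upper_exclusive.to_string(), "2022-01-01");
}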
80 changes: 71 additions & 9 deletions rust/cubesql/cubesql/src/compile/rewrite/analysis.rs
@@ -31,11 +31,17 @@ pub struct LogicalPlanData {
pub column: Option<Column>,
pub expr_to_alias: Option<Vec<(Expr, String)>>,
pub referenced_expr: Option<Vec<Expr>>,
pub constant: Option<ScalarValue>,
pub constant: Option<ConstantFolding>,
pub constant_in_list: Option<Vec<ScalarValue>>,
pub cube_reference: Option<String>,
}

#[derive(Debug, Clone)]
pub enum ConstantFolding {
Scalar(ScalarValue),
List(Vec<ScalarValue>),
}

#[derive(Clone)]
pub struct LogicalPlanAnalysis {
cube_context: Arc<CubeContext>,
@@ -294,14 +300,21 @@ impl LogicalPlanAnalysis {
fn make_constant(
egraph: &EGraph<LogicalPlanLanguage, Self>,
enode: &LogicalPlanLanguage,
) -> Option<ScalarValue> {
) -> Option<ConstantFolding> {
let constant_node = |id| egraph.index(id).data.constant.clone();
let constant_expr = |id| {
egraph
.index(id)
.data
.constant
.clone()
.map(|c| Expr::Literal(c))
.and_then(|c| {
if let ConstantFolding::Scalar(c) = c {
Some(Expr::Literal(c))
} else {
None
}
})
.ok_or_else(|| CubeError::internal("Not a constant".to_string()))
};
match enode {
@@ -314,7 +327,7 @@ impl LogicalPlanAnalysis {
)
.ok()?;
match expr {
Expr::Literal(value) => Some(value),
Expr::Literal(value) => Some(ConstantFolding::Scalar(value)),
_ => panic!("Expected Literal but got: {:?}", expr),
}
}
@@ -358,6 +371,20 @@
panic!("Expected ScalarFunctionExpr but got: {:?}", expr);
}
}
LogicalPlanLanguage::ScalarFunctionExprArgs(params)
| LogicalPlanLanguage::ScalarUDFExprArgs(params) => {
let mut list = Vec::new();
for id in params.iter() {
match constant_node(*id)? {
ConstantFolding::Scalar(v) => list.push(v),
ConstantFolding::List(v) => list.extend(v),
};
}
// TODO ConstantFolding::List is currently used only to trigger re-analysis of its parents.
// TODO It should also be used when actual lists are evaluated as part of the node_to_expr() call.
// TODO In case multiple node variants exist, ConstantFolding::List will choose the one that contains actual constants.
Some(ConstantFolding::List(list))
}
LogicalPlanLanguage::AnyExpr(_) => {
let expr = node_to_expr(
enode,
@@ -369,6 +396,33 @@

Self::eval_constant_expr(&egraph, &expr)
}
LogicalPlanLanguage::CastExpr(_) => {
let expr = node_to_expr(
enode,
&egraph.analysis.cube_context,
&constant_expr,
&SingleNodeIndex { egraph },
)
.ok()?;

// Ignore any string casts, as local timestamps are cast incorrectly
if let Expr::Cast { expr, .. } = &expr {
if let Expr::Literal(ScalarValue::Utf8(_)) = expr.as_ref() {
return None;
}
}

// TODO: Support the decimal type in filters, then remove this workaround
if let Expr::Cast {
data_type: DataType::Decimal(_, _),
..
} = &expr
{
return None;
}

Self::eval_constant_expr(&egraph, &expr)
}
LogicalPlanLanguage::BinaryExpr(_) => {
let expr = node_to_expr(
enode,
@@ -411,7 +465,13 @@
.iter()
.map(|id| {
constant(*id)
.map(|c| vec![c])
.and_then(|c| {
if let ConstantFolding::Scalar(c) = c {
Some(vec![c])
} else {
None
}
})
.or_else(|| constant_in_list(*id))
})
.collect::<Option<Vec<_>>>()?
@@ -426,7 +486,7 @@
fn eval_constant_expr(
egraph: &EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
expr: &Expr,
) -> Option<ScalarValue> {
) -> Option<ConstantFolding> {
let schema = DFSchema::empty();
let arrow_schema = Arc::new(schema.to_owned().into());
let physical_expr = match egraph.analysis.planner.create_physical_expr(
@@ -458,10 +518,10 @@
}
};
Some(match value {
ColumnarValue::Scalar(value) => value,
ColumnarValue::Scalar(value) => ConstantFolding::Scalar(value),
ColumnarValue::Array(arr) => {
if arr.len() == 1 {
ScalarValue::try_from_array(&arr, 0).unwrap()
ConstantFolding::Scalar(ScalarValue::try_from_array(&arr, 0).unwrap())
} else {
log::trace!(
"Expected one row but got {} during constant eval",
@@ -541,19 +601,21 @@ impl Analysis<LogicalPlanLanguage> for LogicalPlanAnalysis {
let (column_name_to_alias, b) = self.merge_option_field(a, b, |d| &mut d.expr_to_alias);
let (referenced_columns, b) = self.merge_option_field(a, b, |d| &mut d.referenced_expr);
let (constant_in_list, b) = self.merge_option_field(a, b, |d| &mut d.constant_in_list);
let (constant, b) = self.merge_option_field(a, b, |d| &mut d.constant);
let (cube_reference, b) = self.merge_option_field(a, b, |d| &mut d.cube_reference);
let (column_name, _) = self.merge_option_field(a, b, |d| &mut d.column);
original_expr
| member_name_to_expr
| column_name_to_alias
| referenced_columns
| constant_in_list
| constant
| cube_reference
| column_name
}

fn modify(egraph: &mut EGraph<LogicalPlanLanguage, Self>, id: Id) {
if let Some(c) = &egraph[id].data.constant {
if let Some(ConstantFolding::Scalar(c)) = &egraph[id].data.constant {
let c = c.clone();
let value = egraph.add(LogicalPlanLanguage::LiteralExprValue(LiteralExprValue(c)));
let literal_expr = egraph.add(LogicalPlanLanguage::LiteralExpr([value]));
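
To make the new folding contract concrete: constant data on an e-class is now either a single scalar or a list of scalars, and only the scalar variant is materialized back into a LiteralExpr by modify(). An illustrative standalone sketch (i64 stands in for ScalarValue; this mirrors, but is not, the real types):

#[derive(Debug, Clone, PartialEq)]
enum ConstantFolding {
    Scalar(i64),    // a fully folded value
    List(Vec<i64>), // folded argument lists; see the TODOs in make_constant
}

// Only scalars become literal expressions; lists merely propagate so that
// parent nodes get re-analyzed.
fn to_literal(c: &ConstantFolding) -> Option<i64> {
    match c {
        ConstantFolding::Scalar(v) => Some(*v),
        ConstantFolding::List(_) => None,
    }
}

fn main() {
    assert_eq!(to_literal(&ConstantFolding::Scalar(42)), Some(42));
    assert_eq!(to_literal(&ConstantFolding::List(vec![1, 2])), None);
}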
15 changes: 13 additions & 2 deletions rust/cubesql/cubesql/src/compile/rewrite/language.rs
@@ -522,13 +522,24 @@ macro_rules! variant_field_struct {

impl core::hash::Hash for [<$variant $var_field:camel>] {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
std::mem::discriminant(&self.0).hash(state);
self.0.hash(state);
}
}

impl core::cmp::PartialEq for [<$variant $var_field:camel>] {
fn eq(&self, other: &[<$variant $var_field:camel>]) -> bool {
self.0 == other.0
// TODO DataFusion compares Timestamps incorrectly when no timezone is involved
match &self.0 {
ScalarValue::TimestampNanosecond(_, self_tz) => {
match &other.0 {
ScalarValue::TimestampNanosecond(_, other_tz) => {
self_tz == other_tz && self.0 == other.0
}
_ => self.0 == other.0
}
}
_ => self.0 == other.0
}
}
}

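
The PartialEq override above exists because, per the TODO, DataFusion compares TimestampNanosecond values without regard to timezone; the Hash impl likewise moved from hashing only the enum discriminant to hashing the full value, so distinct literals no longer collide into one bucket. A toy sketch of the intended equality semantics (the Scalar enum is a stand-in, not DataFusion's ScalarValue):

#[derive(Debug)]
enum Scalar {
    TimestampNanosecond(Option<i64>, Option<String>),
    Other(i64),
}

// Two timestamps are equal only if their timezones match as well as their values.
fn scalars_equal(a: &Scalar, b: &Scalar) -> bool {
    match (a, b) {
        (Scalar::TimestampNanosecond(av, atz), Scalar::TimestampNanosecond(bv, btz)) => {
            atz == btz && av == bv
        }
        (Scalar::Other(av), Scalar::Other(bv)) => av == bv,
        _ => false,
    }
}

fn main() {
    let utc = Scalar::TimestampNanosecond(Some(0), Some("UTC".to_string()));
    let naive = Scalar::TimestampNanosecond(Some(0), None);
    assert!(!scalars_equal(&utc, &naive)); // same instant, different tz metadata
}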
4 changes: 3 additions & 1 deletion rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs
@@ -285,7 +285,9 @@ impl Rewriter {
let rules = Self::rewrite_rules(cube_context.clone());
let runner = Self::rewrite_runner(cube_context.clone(), egraph);
let runner = runner.run(rules.iter());
log::debug!("Iterations: {:?}", runner.iterations);
if !IterInfo::egraph_debug_enabled() {
log::debug!("Iterations: {:?}", runner.iterations);
}
let stop_reason = &runner.iterations[runner.iterations.len() - 1].stop_reason;
let stop_reason = match stop_reason {
None => Some("timeout reached".to_string()),
