diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock index 89aff1f1c..156181a7d 100644 --- a/dask_planner/Cargo.lock +++ b/dask_planner/Cargo.lock @@ -58,12 +58,15 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "arrow" -version = "23.0.0" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fedc767fbaa36ea50f086215f54f1a007d22046fc4754b0448c657bcbe9f8413" +checksum = "d68391300d5237f6725f0f869ae7cb65d45fcf8a6d18f6ceecd328fb803bef93" dependencies = [ "ahash 0.8.0", + "arrow-array", "arrow-buffer", + "arrow-data", + "arrow-schema", "bitflags", "chrono", "comfy-table", @@ -78,20 +81,52 @@ dependencies = [ "num", "regex", "regex-syntax", - "serde", "serde_json", ] +[[package]] +name = "arrow-array" +version = "24.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0bb00c5862b5eea683812083c495bef01a9a5149da46ad2f4c0e4aa8800f64d" +dependencies = [ + "ahash 0.8.0", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "hashbrown", + "num", +] + [[package]] name = "arrow-buffer" -version = "23.0.0" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d290050c6e12a81a24ad08525cef2203c4156a6350f75508d49885d677e88ea9" +checksum = "3e594d0fe0026a8bc2459bdc5ac9623e5fb666724a715e0acbc96ba30c5d4cc7" dependencies = [ "half", +] + +[[package]] +name = "arrow-data" +version = "24.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8500df05060d86fdc53e9b5cb32e51bfeaacc040fdeced3eb99ac0d59200ff45" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "half", "num", ] +[[package]] +name = "arrow-schema" +version = "24.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86d1fef01f25e1452c86fa6887f078de8e0aaeeb828370feab205944cfc30e27" + [[package]] name = "async-trait" version = "0.1.57" @@ -314,8 +349,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "12.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=1261741af2a5e142fa0c7916e759859cc18ea59a#1261741af2a5e142fa0c7916e759859cc18ea59a" +version = "13.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "506eab038bf2d39ac02c22be30b019873ca01f887148b939d309a0e9523f4515" dependencies = [ "arrow", "ordered-float", @@ -324,8 +360,9 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "12.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=1261741af2a5e142fa0c7916e759859cc18ea59a#1261741af2a5e142fa0c7916e759859cc18ea59a" +version = "13.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d2810e369c735d69479e27fe8410e97a76ed07484aa9b3ad7c039efa504257" dependencies = [ "ahash 0.8.0", "arrow", @@ -335,8 +372,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "12.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=1261741af2a5e142fa0c7916e759859cc18ea59a#1261741af2a5e142fa0c7916e759859cc18ea59a" +version = "13.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60f3b80326243629d02e33f37e955a7114781c6c44caf9d8b254618157de7143" dependencies = [ "arrow", "async-trait", @@ -350,8 +388,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "12.0.0" -source = 
"git+https://github.com/apache/arrow-datafusion/?rev=1261741af2a5e142fa0c7916e759859cc18ea59a#1261741af2a5e142fa0c7916e759859cc18ea59a" +version = "13.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9bf3b7ae861d351a85174fd4fddb29d249950b2f23671318971960452b4b9ab" dependencies = [ "ahash 0.8.0", "arrow", @@ -374,8 +413,9 @@ dependencies = [ [[package]] name = "datafusion-row" -version = "12.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=1261741af2a5e142fa0c7916e759859cc18ea59a#1261741af2a5e142fa0c7916e759859cc18ea59a" +version = "13.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f44a2a722719c569b437b3aa2108a99dc911369e8d86c44e6293225c3387147" dependencies = [ "arrow", "datafusion-common", @@ -385,16 +425,14 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "12.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=1261741af2a5e142fa0c7916e759859cc18ea59a#1261741af2a5e142fa0c7916e759859cc18ea59a" +version = "13.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e98493e04385c924d1d3d7ab8739c41f1ebf676a46863181103a0fb9c7168fa9" dependencies = [ - "ahash 0.8.0", "arrow", "datafusion-common", "datafusion-expr", - "hashbrown", "sqlparser", - "tokio", ] [[package]] @@ -530,9 +568,9 @@ checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" [[package]] name = "itoa" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8af84674fe1f223a982c933a0ee1086ac4d4052aa0fb8060c12c6ad838e754" +checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc" [[package]] name = "js-sys" @@ -801,9 +839,9 @@ checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" [[package]] name = "ordered-float" -version = "3.1.0" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98ffdb14730ed2ef599c65810c15b000896e21e8776b512de0db0c3d7335cc2a" +checksum = "129d36517b53c461acc6e1580aeb919c8ae6708a4b1eae61c4463a615d4f0411" dependencies = [ "num-traits", ] @@ -1018,20 +1056,6 @@ name = "serde" version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.145" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] [[package]] name = "serde_json" @@ -1039,7 +1063,7 @@ version = "1.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44" dependencies = [ - "itoa 1.0.3", + "itoa 1.0.4", "ryu", "serde", ] @@ -1057,15 +1081,15 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fd0db749597d91ff862fd1d55ea87f7855a744a8425a64695b6fca237d1dad1" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "sqlparser" -version = "0.23.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0beb13adabbdda01b63d595f38c8bfd19a361e697fd94ce0098a634077bc5b25" +checksum = 
"0781f2b6bd03e5adf065c8e772b49eaea9f640d06a1b9130330fe8bd2563f4fd" dependencies = [ "log", ] @@ -1103,9 +1127,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.101" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e90cde112c4b9690b8cbe810cba9ddd8bc1d7472e2cae317b69e9438c1cba7d2" +checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1" dependencies = [ "proc-macro2", "quote", diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index de4356890..e15cc7f3f 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -9,12 +9,12 @@ edition = "2021" rust-version = "1.62" [dependencies] -arrow = { version = "23.0.0", features = ["prettyprint"] } +arrow = { version = "24.0.0", features = ["prettyprint"] } async-trait = "0.1.41" -datafusion-common = { git = "https://github.com/apache/arrow-datafusion/", rev = "1261741af2a5e142fa0c7916e759859cc18ea59a" } -datafusion-expr = { git = "https://github.com/apache/arrow-datafusion/", rev = "1261741af2a5e142fa0c7916e759859cc18ea59a" } -datafusion-optimizer = { git = "https://github.com/apache/arrow-datafusion/", rev = "1261741af2a5e142fa0c7916e759859cc18ea59a" } -datafusion-sql = { git = "https://github.com/apache/arrow-datafusion/", rev = "1261741af2a5e142fa0c7916e759859cc18ea59a" } +datafusion-common = "13.0.0" +datafusion-expr = "13.0.0" +datafusion-optimizer = "13.0.0" +datafusion-sql = "13.0.0" env_logger = "0.9" log = "^0.4" mimalloc = { version = "*", default-features = false } diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 6512c54d5..089d818ae 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -390,7 +390,7 @@ impl DaskSQLContext { match existing_plan.original_plan.accept(&mut visitor) { Ok(valid) => { if valid { - optimizer::DaskSqlOptimizer::new() + optimizer::DaskSqlOptimizer::new(true) .run_optimizations(existing_plan.original_plan) .map(|k| PyLogicalPlan { original_plan: k, diff --git a/dask_planner/src/sql/optimizer.rs b/dask_planner/src/sql/optimizer.rs index ce86e0390..24ddc9e6b 100644 --- a/dask_planner/src/sql/optimizer.rs +++ b/dask_planner/src/sql/optimizer.rs @@ -2,8 +2,14 @@ use datafusion_common::DataFusionError; use datafusion_expr::LogicalPlan; use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; +use datafusion_optimizer::eliminate_filter::EliminateFilter; +use datafusion_optimizer::reduce_cross_join::ReduceCrossJoin; +use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin; +use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; +use datafusion_optimizer::simplify_expressions::SimplifyExpressions; use datafusion_optimizer::type_coercion::TypeCoercion; +use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion_optimizer::{ common_subexpr_eliminate::CommonSubexprEliminate, eliminate_limit::EliminateLimit, filter_null_join_keys::FilterNullJoinKeys, filter_push_down::FilterPushDown, @@ -19,30 +25,42 @@ use eliminate_agg_distinct::EliminateAggDistinct; /// Houses the optimization logic for Dask-SQL. 
This optimization controls the optimizations /// and their ordering in regards to their impact on the underlying `LogicalPlan` instance pub struct DaskSqlOptimizer { + skip_failing_rules: bool, optimizations: Vec<Box<dyn OptimizerRule + Send + Sync>>, } impl DaskSqlOptimizer { /// Creates a new instance of the DaskSqlOptimizer with all the DataFusion desired /// optimizers as well as any custom `OptimizerRule` trait impls that might be desired. - pub fn new() -> Self { + pub fn new(skip_failing_rules: bool) -> Self { let rules: Vec<Box<dyn OptimizerRule + Send + Sync>> = vec![ - Box::new(CommonSubexprEliminate::new()), + Box::new(TypeCoercion::new()), + Box::new(SimplifyExpressions::new()), + Box::new(UnwrapCastInComparison::new()), Box::new(DecorrelateWhereExists::new()), Box::new(DecorrelateWhereIn::new()), Box::new(ScalarSubqueryToJoin::new()), + Box::new(SubqueryFilterToJoin::new()), + // simplify expressions does not simplify expressions in subqueries, so we + // run it again after running the optimizations that potentially converted + // subqueries to joins + Box::new(SimplifyExpressions::new()), + Box::new(EliminateFilter::new()), + Box::new(ReduceCrossJoin::new()), + Box::new(CommonSubexprEliminate::new()), Box::new(EliminateLimit::new()), + Box::new(ProjectionPushDown::new()), + Box::new(RewriteDisjunctivePredicate::new()), Box::new(FilterNullJoinKeys::default()), + Box::new(ReduceOuterJoin::new()), Box::new(FilterPushDown::new()), - Box::new(TypeCoercion::new()), Box::new(LimitPushDown::new()), - Box::new(ProjectionPushDown::new()), // Box::new(SingleDistinctToGroupBy::new()), - Box::new(SubqueryFilterToJoin::new()), // Dask-SQL specific optimizations Box::new(EliminateAggDistinct::new()), ]; Self { + skip_failing_rules, optimizations: rules, } } @@ -65,14 +83,128 @@ impl DaskSqlOptimizer { resulting_plan = optimized_plan } Err(e) => { - println!( - "Skipping optimizer rule {} due to unexpected error: {}", - optimization.name(), - e - ); + if self.skip_failing_rules { + println!( + "Skipping optimizer rule {} due to unexpected error: {}", + optimization.name(), + e + ); + } else { + return Err(e); + } } } } Ok(resulting_plan) } } + +#[cfg(test)] +mod tests { + use crate::dialect::DaskDialect; + use crate::sql::optimizer::DaskSqlOptimizer; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::{DataFusionError, Result}; + use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}; + use datafusion_sql::{ + planner::{ContextProvider, SqlToRel}, + sqlparser::{ast::Statement, parser::Parser}, + TableReference, + }; + use std::any::Any; + use std::collections::HashMap; + use std::sync::Arc; + + #[test] + fn subquery_filter_with_cast() -> Result<()> { + // regression test for https://github.com/apache/arrow-datafusion/issues/3760 + let sql = "SELECT col_int32 FROM test \ + WHERE col_int32 > (\ + SELECT AVG(col_int32) FROM test \ + WHERE col_utf8 BETWEEN '2002-05-08' \ + AND (cast('2002-05-08' as date) + interval '5 days')\ + )"; + let plan = test_sql(sql)?; + let expected = + "Projection: test.col_int32\n Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\ + \n CrossJoin:\ + \n TableScan: test projection=[col_int32]\ + \n Projection: AVG(test.col_int32) AS __value, alias=__sq_1\ + \n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\ + \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) + } + + fn test_sql(sql: &str) -> Result<LogicalPlan> { + // parse the SQL + let
dialect = DaskDialect {}; let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap(); + let statement = &ast[0]; + + // create a logical query plan + let schema_provider = MySchemaProvider {}; + let sql_to_rel = SqlToRel::new(&schema_provider); + let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); + + // optimize the logical plan + let optimizer = DaskSqlOptimizer::new(false); + optimizer.run_optimizations(plan) + } + + struct MySchemaProvider {} + + impl ContextProvider for MySchemaProvider { + fn get_table_provider( + &self, + name: TableReference, + ) -> datafusion_common::Result<Arc<dyn TableSource>> { + let table_name = name.table(); + if table_name.starts_with("test") { + let schema = Schema::new_with_metadata( + vec![ + Field::new("col_int32", DataType::Int32, true), + Field::new("col_uint32", DataType::UInt32, true), + Field::new("col_utf8", DataType::Utf8, true), + Field::new("col_date32", DataType::Date32, true), + Field::new("col_date64", DataType::Date64, true), + ], + HashMap::new(), + ); + + Ok(Arc::new(MyTableSource { + schema: Arc::new(schema), + })) + } else { + Err(DataFusionError::Plan("table does not exist".to_string())) + } + } + + fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> { + None + } + + fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> { + None + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> { + None + } + } + + struct MyTableSource { + schema: SchemaRef, + } + + impl TableSource for MyTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + } +} diff --git a/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs b/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs index 803ad67fc..5ad2e59f6 100644 --- a/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs +++ b/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs @@ -13,9 +13,9 @@ //! //! Would typically produce a LogicalPlan like ... //! ```text -//! Projection: #COUNT(a.a) AS COUNT(DISTINCT a.a)))\ -//! Aggregate: groupBy=[[]], aggr=[[COUNT(#a.a)]]\ -//! Aggregate: groupBy=[[#a.a]], aggr=[[]]\ +//! Projection: COUNT(a.a) AS COUNT(DISTINCT a.a)))\ +//! Aggregate: groupBy=[[]], aggr=[[COUNT(a.a)]]\ +//! Aggregate: groupBy=[[a.a]], aggr=[[]]\ //! TableScan: test"; //! ``` //! @@ -34,9 +34,9 @@ //! //! Would typically produce a LogicalPlan like ... //! ```text -//! Projection: #SUM(alias2) AS COUNT(a), #COUNT(alias1) AS COUNT(DISTINCT a) -//! Aggregate: groupBy=[[]], aggr=[[SUM(alias2), COUNT(#alias1)]] -//! Aggregate: groupBy=[[#a AS alias1]], aggr=[[COUNT(*) AS alias2]] +//! Projection: SUM(alias2) AS COUNT(a), COUNT(alias1) AS COUNT(DISTINCT a) +//! Aggregate: groupBy=[[]], aggr=[[SUM(alias2), COUNT(alias1)]] +//! Aggregate: groupBy=[[a AS alias1]], aggr=[[COUNT(*) AS alias2]] //! TableScan: test projection=[a] //! //! If the query contains DISTINCT aggregates for multiple columns then we need to perform //! @@ -46,26 +46,25 @@ //! CrossJoin:\ //! CrossJoin:\ //! CrossJoin:\ -//! Projection: #SUM(__dask_sql_count__1) AS COUNT(a.a), #COUNT(a.a) AS COUNT(DISTINCT a.a)\ -//! Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\ -//! Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__1]]\ +//! Projection: SUM(__dask_sql_count__1) AS COUNT(a.a), COUNT(a.a) AS COUNT(DISTINCT a.a)\ +//! Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__1), COUNT(a.a)]]\ +//! Aggregate: groupBy=[[a.a]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__1]]\ //! TableScan: a\ -//!
Projection: #SUM(__dask_sql_count__2) AS COUNT(a.b), #COUNT(a.b) AS COUNT(DISTINCT(#a.b))\ -//! Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__2), COUNT(#a.b)]]\ -//! Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__2]]\ +//! Projection: SUM(__dask_sql_count__2) AS COUNT(a.b), COUNT(a.b) AS COUNT(DISTINCT(a.b))\ +//! Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__2), COUNT(a.b)]]\ +//! Aggregate: groupBy=[[a.b]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__2]]\ //! TableScan: a\ -//! Projection: #SUM(__dask_sql_count__3) AS COUNT(a.c), #COUNT(a.c) AS COUNT(DISTINCT(#a.c))\ -//! Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__3), COUNT(#a.c)]]\ -//! Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__3]]\ +//! Projection: SUM(__dask_sql_count__3) AS COUNT(a.c), COUNT(a.c) AS COUNT(DISTINCT(a.c))\ +//! Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__3), COUNT(a.c)]]\ +//! Aggregate: groupBy=[[a.c]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__3]]\ //! TableScan: a\ -//! Projection: #SUM(__dask_sql_count__4) AS COUNT(a.d), #COUNT(a.d) AS COUNT(DISTINCT(#a.d))\ -//! Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__4), COUNT(#a.d)]]\ -//! Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__4]]\ +//! Projection: SUM(__dask_sql_count__4) AS COUNT(a.d), COUNT(a.d) AS COUNT(DISTINCT(a.d))\ +//! Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__4), COUNT(a.d)]]\ +//! Aggregate: groupBy=[[a.d]], aggr=[[COUNT(UInt64(1)) AS __dask_sql_count__4]]\ //! TableScan: a -use datafusion_common::{Column, DFSchema, Result}; +use datafusion_common::{Column, Result}; use datafusion_expr::logical_plan::Projection; -use datafusion_expr::utils::exprlist_to_fields; use datafusion_expr::{ col, count, logical_plan::{Aggregate, LogicalPlan}, @@ -73,7 +72,6 @@ use datafusion_expr::{ }; use datafusion_optimizer::{utils, OptimizerConfig, OptimizerRule}; use log::trace; -use std::collections::hash_map::HashMap; use std::collections::HashSet; use std::sync::Arc; @@ -205,9 +203,9 @@ fn create_plan( if has_distinct && has_non_distinct && distinct_expr.len() == 1 && not_distinct_expr.len() == 1 { - // Projection: #SUM(alias2) AS COUNT(a), #COUNT(alias1) AS COUNT(DISTINCT a) - // Aggregate: groupBy=[[]], aggr=[[SUM(alias2), COUNT(#alias1)]] - // Aggregate: groupBy=[[#a AS alias1]], aggr=[[COUNT(*) AS alias2]] + // Projection: SUM(alias2) AS COUNT(a), COUNT(alias1) AS COUNT(DISTINCT a) + // Aggregate: groupBy=[[]], aggr=[[SUM(alias2), COUNT(alias1)]] + // Aggregate: groupBy=[[a AS alias1]], aggr=[[COUNT(*) AS alias2]] // TableScan: test projection=[a] // The first aggregate groups by the distinct expression and performs a COUNT(*). This @@ -219,18 +217,7 @@ fn create_plan( let expr_name = expr.name()?; let count_expr = Expr::Column(Column::from_qualified_name(&expr_name)); let aggr_expr = vec![count(count_expr).alias(&alias)]; - let mut schema_expr = group_expr.clone(); - schema_expr.extend_from_slice(&aggr_expr); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(&schema_expr, input)?, - HashMap::new(), - )?; - LogicalPlan::Aggregate(Aggregate::try_new( - input.clone(), - group_expr, - aggr_expr, - Arc::new(schema), - )?) + LogicalPlan::Aggregate(Aggregate::try_new(input.clone(), group_expr, aggr_expr)?) 
}; trace!("first agg:\n{}", first_aggregate.display_indent_schema()); @@ -256,17 +243,10 @@ fn create_plan( trace!("aggr_expr = {:?}", aggr_expr); - let mut schema_expr = group_expr.clone(); - schema_expr.extend_from_slice(&aggr_expr); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(&schema_expr, &first_aggregate)?, - HashMap::new(), - )?; LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(first_aggregate), group_expr.clone(), aggr_expr, - Arc::new(schema), )?) }; @@ -308,9 +288,9 @@ fn create_plan( } else if has_distinct && distinct_expr.len() == 1 { // simple case of a single DISTINCT aggregation // - // Projection: #COUNT(#a) AS COUNT(DISTINCT a) - // Aggregate: groupBy=[[]], aggr=[[COUNT(#a)]] - // Aggregate: groupBy=[[#a]], aggr=[[]] + // Projection: COUNT(a) AS COUNT(DISTINCT a) + // Aggregate: groupBy=[[]], aggr=[[COUNT(a)]] + // Aggregate: groupBy=[[a]], aggr=[[]] // TableScan: test projection=[a] // The first aggregate groups by the distinct expression. This is the equivalent @@ -318,16 +298,7 @@ fn create_plan( let first_aggregate = { let mut group_expr = group_expr.clone(); group_expr.push(expr.clone()); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(&group_expr, input)?, - HashMap::new(), - )?; - LogicalPlan::Aggregate(Aggregate::try_new( - input.clone(), - group_expr, - vec![], - Arc::new(schema), - )?) + LogicalPlan::Aggregate(Aggregate::try_new(input.clone(), group_expr, vec![])?) }; trace!("first agg:\n{}", first_aggregate.display_indent_schema()); @@ -344,17 +315,10 @@ fn create_plan( distinct: false, filter: None, }; - let mut second_aggr_schema = group_expr.clone(); - second_aggr_schema.push(count.clone()); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(&second_aggr_schema, &first_aggregate)?, - HashMap::new(), - )?; LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(first_aggregate), group_expr.clone(), vec![count], - Arc::new(schema), )?) }; @@ -485,7 +449,7 @@ mod tests { /// Optimize with all of the optimizer rules, including eliminate_agg_distinct fn assert_fully_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let optimizer = DaskSqlOptimizer::new(); + let optimizer = DaskSqlOptimizer::new(false); let optimized_plan = optimizer .run_optimizations(plan.clone()) .expect("failed to optimize plan"); @@ -524,9 +488,9 @@ mod tests { .aggregate(vec![col("a")], vec![count_distinct(col("b"))])? .build()?; - let expected = "Projection: #a.a, #COUNT(a.b) AS COUNT(DISTINCT a.b)\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.b)]]\ - \n Aggregate: groupBy=[[#a.a, #a.b]], aggr=[[]]\ + let expected = "Projection: a.a, COUNT(a.b) AS COUNT(DISTINCT a.b)\ + \n Aggregate: groupBy=[[a.a]], aggr=[[COUNT(a.b)]]\ + \n Aggregate: groupBy=[[a.a, a.b]], aggr=[[]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -538,9 +502,9 @@ mod tests { .aggregate(vec![col("a")], vec![count_distinct(col("b")).alias("cd_b")])? .build()?; - let expected = "Projection: #a.a, #COUNT(a.b) AS cd_b\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.b)]]\ - \n Aggregate: groupBy=[[#a.a, #a.b]], aggr=[[]]\ + let expected = "Projection: a.a, COUNT(a.b) AS cd_b\ + \n Aggregate: groupBy=[[a.a]], aggr=[[COUNT(a.b)]]\ + \n Aggregate: groupBy=[[a.a, a.b]], aggr=[[]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -553,9 +517,9 @@ mod tests { .aggregate(empty_group_expr, vec![count_distinct(col("a"))])? 
.build()?; - let expected = "Projection: #COUNT(a.a) AS COUNT(DISTINCT a.a)\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(#a.a)]]\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[]]\ + let expected = "Projection: COUNT(a.a) AS COUNT(DISTINCT a.a)\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(a.a)]]\ + \n Aggregate: groupBy=[[a.a]], aggr=[[]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -571,9 +535,9 @@ mod tests { )? .build()?; - let expected = "Projection: #COUNT(a.a) AS cd_a\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(#a.a)]]\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[]]\ + let expected = "Projection: COUNT(a.a) AS cd_a\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(a.a)]]\ + \n Aggregate: groupBy=[[a.a]], aggr=[[]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -588,9 +552,10 @@ mod tests { )? .build()?; - let expected = "Projection: #a.b, #a.b AS COUNT(a.a), #SUM(__dask_sql_count__1) AS COUNT(DISTINCT a.a)\ - \n Aggregate: groupBy=[[#a.b]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\ - \n Aggregate: groupBy=[[#a.b, #a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\ + let expected = + "Projection: a.b, a.b AS COUNT(a.a), SUM(__dask_sql_count__1) AS COUNT(DISTINCT a.a)\ + \n Aggregate: groupBy=[[a.b]], aggr=[[SUM(__dask_sql_count__1), COUNT(a.a)]]\ + \n Aggregate: groupBy=[[a.b, a.a]], aggr=[[COUNT(a.a) AS __dask_sql_count__1]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -606,9 +571,10 @@ mod tests { )? .build()?; - let expected = "Projection: #SUM(__dask_sql_count__1) AS COUNT(a.a), #COUNT(a.a) AS COUNT(DISTINCT a.a)\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\ + let expected = + "Projection: SUM(__dask_sql_count__1) AS COUNT(a.a), COUNT(a.a) AS COUNT(DISTINCT a.a)\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__1), COUNT(a.a)]]\ + \n Aggregate: groupBy=[[a.a]], aggr=[[COUNT(a.a) AS __dask_sql_count__1]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -627,9 +593,9 @@ mod tests { )? 
.build()?; - let expected = "Projection: #SUM(__dask_sql_count__1) AS c_a, #COUNT(a.a) AS cd_a\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\ + let expected = "Projection: SUM(__dask_sql_count__1) AS c_a, COUNT(a.a) AS cd_a\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__1), COUNT(a.a)]]\ + \n Aggregate: groupBy=[[a.a]], aggr=[[COUNT(a.a) AS __dask_sql_count__1]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -657,21 +623,21 @@ mod tests { let expected = "CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ - \n Projection: #SUM(__dask_sql_count__1) AS COUNT(a.a), #COUNT(a.a) AS COUNT(DISTINCT a.a)\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\ + \n Projection: SUM(__dask_sql_count__1) AS COUNT(a.a), COUNT(a.a) AS COUNT(DISTINCT a.a)\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__1), COUNT(a.a)]]\ + \n Aggregate: groupBy=[[a.a]], aggr=[[COUNT(a.a) AS __dask_sql_count__1]]\ \n TableScan: a\ - \n Projection: #SUM(__dask_sql_count__2) AS COUNT(a.b), #COUNT(a.b) AS COUNT(DISTINCT a.b)\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__2), COUNT(#a.b)]]\ - \n Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(#a.b) AS __dask_sql_count__2]]\ + \n Projection: SUM(__dask_sql_count__2) AS COUNT(a.b), COUNT(a.b) AS COUNT(DISTINCT a.b)\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__2), COUNT(a.b)]]\ + \n Aggregate: groupBy=[[a.b]], aggr=[[COUNT(a.b) AS __dask_sql_count__2]]\ \n TableScan: a\ - \n Projection: #SUM(__dask_sql_count__3) AS COUNT(a.c), #COUNT(a.c) AS COUNT(DISTINCT a.c)\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__3), COUNT(#a.c)]]\ - \n Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(#a.c) AS __dask_sql_count__3]]\ + \n Projection: SUM(__dask_sql_count__3) AS COUNT(a.c), COUNT(a.c) AS COUNT(DISTINCT a.c)\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__3), COUNT(a.c)]]\ + \n Aggregate: groupBy=[[a.c]], aggr=[[COUNT(a.c) AS __dask_sql_count__3]]\ \n TableScan: a\ - \n Projection: #SUM(__dask_sql_count__4) AS COUNT(a.d), #COUNT(a.d) AS COUNT(DISTINCT a.d)\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__4), COUNT(#a.d)]]\ - \n Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(#a.d) AS __dask_sql_count__4]]\ + \n Projection: SUM(__dask_sql_count__4) AS COUNT(a.d), COUNT(a.d) AS COUNT(DISTINCT a.d)\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__4), COUNT(a.d)]]\ + \n Aggregate: groupBy=[[a.d]], aggr=[[COUNT(a.d) AS __dask_sql_count__4]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -699,42 +665,42 @@ mod tests { let expected = "CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ - \n Projection: #SUM(__dask_sql_count__1) AS c_a, #COUNT(a.a) AS cd_a\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\ + \n Projection: SUM(__dask_sql_count__1) AS c_a, COUNT(a.a) AS cd_a\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__1), COUNT(a.a)]]\ + \n Aggregate: groupBy=[[a.a]], aggr=[[COUNT(a.a) AS __dask_sql_count__1]]\ \n TableScan: a\ - \n Projection: #SUM(__dask_sql_count__2) AS c_b, #COUNT(a.b) AS cd_b\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__2), COUNT(#a.b)]]\ - \n Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(#a.b) AS __dask_sql_count__2]]\ 
+ \n Projection: SUM(__dask_sql_count__2) AS c_b, COUNT(a.b) AS cd_b\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__2), COUNT(a.b)]]\ + \n Aggregate: groupBy=[[a.b]], aggr=[[COUNT(a.b) AS __dask_sql_count__2]]\ \n TableScan: a\ - \n Projection: #SUM(__dask_sql_count__3) AS c_c, #COUNT(a.c) AS cd_c\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__3), COUNT(#a.c)]]\ - \n Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(#a.c) AS __dask_sql_count__3]]\ + \n Projection: SUM(__dask_sql_count__3) AS c_c, COUNT(a.c) AS cd_c\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__3), COUNT(a.c)]]\ + \n Aggregate: groupBy=[[a.c]], aggr=[[COUNT(a.c) AS __dask_sql_count__3]]\ \n TableScan: a\ - \n Projection: #SUM(__dask_sql_count__4) AS c_d, #COUNT(a.d) AS cd_d\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__4), COUNT(#a.d)]]\ - \n Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(#a.d) AS __dask_sql_count__4]]\ + \n Projection: SUM(__dask_sql_count__4) AS c_d, COUNT(a.d) AS cd_d\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__4), COUNT(a.d)]]\ + \n Aggregate: groupBy=[[a.d]], aggr=[[COUNT(a.d) AS __dask_sql_count__4]]\ \n TableScan: a"; assert_optimized_plan_eq(&plan, expected); let expected = "CrossJoin:\ \n CrossJoin:\ \n CrossJoin:\ - \n Projection: #SUM(__dask_sql_count__1) AS c_a, #COUNT(a.a) AS cd_a\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__1), COUNT(#a.a)]]\ - \n Aggregate: groupBy=[[#a.a]], aggr=[[COUNT(#a.a) AS __dask_sql_count__1]]\ + \n Projection: SUM(__dask_sql_count__1) AS c_a, COUNT(a.a) AS cd_a\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__1), COUNT(a.a)]]\ + \n Aggregate: groupBy=[[a.a]], aggr=[[COUNT(a.a) AS __dask_sql_count__1]]\ \n TableScan: a projection=[a, b, c, d]\ - \n Projection: #SUM(__dask_sql_count__2) AS c_b, #COUNT(a.b) AS cd_b\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__2), COUNT(#a.b)]]\ - \n Aggregate: groupBy=[[#a.b]], aggr=[[COUNT(#a.b) AS __dask_sql_count__2]]\ + \n Projection: SUM(__dask_sql_count__2) AS c_b, COUNT(a.b) AS cd_b\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__2), COUNT(a.b)]]\ + \n Aggregate: groupBy=[[a.b]], aggr=[[COUNT(a.b) AS __dask_sql_count__2]]\ \n TableScan: a projection=[a, b, c, d]\ - \n Projection: #SUM(__dask_sql_count__3) AS c_c, #COUNT(a.c) AS cd_c\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__3), COUNT(#a.c)]]\ - \n Aggregate: groupBy=[[#a.c]], aggr=[[COUNT(#a.c) AS __dask_sql_count__3]]\ + \n Projection: SUM(__dask_sql_count__3) AS c_c, COUNT(a.c) AS cd_c\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__3), COUNT(a.c)]]\ + \n Aggregate: groupBy=[[a.c]], aggr=[[COUNT(a.c) AS __dask_sql_count__3]]\ \n TableScan: a projection=[a, b, c, d]\ - \n Projection: #SUM(__dask_sql_count__4) AS c_d, #COUNT(a.d) AS cd_d\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#__dask_sql_count__4), COUNT(#a.d)]]\ - \n Aggregate: groupBy=[[#a.d]], aggr=[[COUNT(#a.d) AS __dask_sql_count__4]]\ + \n Projection: SUM(__dask_sql_count__4) AS c_d, COUNT(a.d) AS cd_d\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(__dask_sql_count__4), COUNT(a.d)]]\ + \n Aggregate: groupBy=[[a.d]], aggr=[[COUNT(a.d) AS __dask_sql_count__4]]\ \n TableScan: a projection=[a, b, c, d]"; assert_fully_optimized_plan_eq(&plan, expected); diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index c72c457a0..64db29d68 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -1,5 +1,4 @@ import logging -from datetime import timedelta from typing import Any import 
dask.array as da @@ -141,7 +140,9 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any: return literal_value elif sql_type == SqlTypeName.INTERVAL_DAY: - return timedelta(days=literal_value[0], milliseconds=literal_value[1]) + return np.timedelta64(literal_value[0], "D") + np.timedelta64( + literal_value[1], "ms" + ) elif sql_type == SqlTypeName.INTERVAL: # check for finer granular interval types, e.g., INTERVAL MONTH, INTERVAL YEAR try: @@ -160,7 +161,7 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any: # Calcite will always convert INTERVAL types except YEAR, QUATER, MONTH to milliseconds # Issue: if sql_type is INTERVAL MICROSECOND, and value <= 1000, literal_value will be rounded to 0 - return timedelta(milliseconds=float(str(literal_value))) + return np.timedelta64(literal_value, "ms") elif sql_type == SqlTypeName.BOOLEAN: return bool(literal_value) @@ -170,7 +171,9 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any: or sql_type == SqlTypeName.TIME or sql_type == SqlTypeName.DATE ): - if str(literal_value) == "None": + if isinstance(literal_value, str): + literal_value = np.datetime64(literal_value) + elif str(literal_value) == "None": # NULL time return pd.NaT # pragma: no cover if sql_type == SqlTypeName.DATE: diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index f3d6a3ddd..c0b364d89 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -1,8 +1,6 @@ -import datetime import logging import operator import re -from datetime import timedelta from functools import partial, reduce from typing import TYPE_CHECKING, Any, Callable, Union @@ -17,7 +15,11 @@ from dask_planner.rust import SqlTypeName from dask_sql.datacontainer import DataContainer -from dask_sql.mappings import cast_column_to_type, sql_to_python_type +from dask_sql.mappings import ( + cast_column_to_type, + sql_to_python_type, + sql_to_python_value, +) from dask_sql.physical.rex import RexConverter from dask_sql.physical.rex.base import BaseRexPlugin from dask_sql.physical.rex.core.literal import SargPythonImplementation @@ -185,7 +187,7 @@ def div(self, lhs, rhs): # We do not need to truncate in this case # So far, I did not spot any other occurrence # of this function. 
- if isinstance(result, (datetime.timedelta, np.timedelta64)): + if isinstance(result, np.timedelta64): return result else: return da.trunc(result).astype(np.int64) @@ -239,11 +241,13 @@ def __init__(self): super().__init__(self.cast) def cast(self, operand, rex=None) -> SeriesOrScalar: + output_type = str(rex.getType()) + sql_type = SqlTypeName.fromString(output_type.upper()) + if not is_frame(operand): # pragma: no cover - return operand + return sql_to_python_value(sql_type, operand) - output_type = str(rex.getType()) - python_type = sql_to_python_type(SqlTypeName.fromString(output_type.upper())) + python_type = sql_to_python_type(sql_type) return_column = cast_column_to_type(operand, python_type) @@ -613,19 +617,19 @@ def timestampadd(self, unit, interval, df: SeriesOrScalar): df = convert_to_datetime(df) if unit in {"DAY", "SQL_TSI_DAY"}: - return df + timedelta(days=interval) + return df + np.timedelta64(interval, "D") elif unit in {"HOUR", "SQL_TSI_HOUR"}: - return df + timedelta(hours=interval) + return df + np.timedelta64(interval, "h") elif unit == "MICROSECOND": - return df + timedelta(microseconds=interval) + return df + np.timedelta64(interval, "us") elif unit == "MILLISECOND": - return df + timedelta(miliseconds=interval) + return df + np.timedelta64(interval, "ms") elif unit in {"MINUTE", "SQL_TSI_MINUTE"}: - return df + timedelta(minutes=interval) + return df + np.timedelta64(interval, "m") elif unit in {"SECOND", "SQL_TSI_SECOND"}: - return df + timedelta(seconds=interval) + return df + np.timedelta64(interval, "s") elif unit in {"WEEK", "SQL_TSI_WEEK"}: - return df + timedelta(days=interval * 7) + return df + np.timedelta64(interval * 7, "W") else: raise NotImplementedError(f"Extraction of {unit} is not (yet) implemented.") diff --git a/dask_sql/physical/rex/core/literal.py b/dask_sql/physical/rex/core/literal.py index c95d4ef8f..482484eee 100644 --- a/dask_sql/physical/rex/core/literal.py +++ b/dask_sql/physical/rex/core/literal.py @@ -1,4 +1,5 @@ import logging +from datetime import datetime from typing import TYPE_CHECKING, Any import dask.dataframe as dd @@ -167,18 +168,20 @@ def convert( "TimestampNanosecond", }: unit_mapping = { - "Second": "s", - "Millisecond": "ms", - "Microsecond": "us", - "Nanosecond": "ns", + "TimestampSecond": "s", + "TimestampMillisecond": "ms", + "TimestampMicrosecond": "us", + "TimestampNanosecond": "ns", } + numpy_unit = unit_mapping.get(literal_type) literal_value, timezone = rex.getTimestampValue() if timezone and timezone != "UTC": raise ValueError("Non UTC timezones not supported") + elif timezone is None: + literal_value = datetime.fromtimestamp(literal_value // 10**9) + literal_value = str(literal_value) literal_type = SqlTypeName.TIMESTAMP - literal_value = np.datetime64( - literal_value, unit_mapping.get(literal_type.partition("Timestamp")[2]) - ) + literal_value = np.datetime64(literal_value, numpy_unit) else: raise RuntimeError( f"Failed to map literal type {literal_type} to python type in literal.py" diff --git a/tests/integration/test_explain.py b/tests/integration/test_explain.py index ad94d46c8..2a1793612 100644 --- a/tests/integration/test_explain.py +++ b/tests/integration/test_explain.py @@ -10,7 +10,7 @@ def test_sql_query_explain(c, gpu): sql_string = c.sql("EXPLAIN SELECT * FROM df") - assert sql_string.startswith("Projection: #df.a\n") + assert sql_string.startswith("Projection: df.a\n") # TODO: Need to add statistics to Rust optimizer before this can be uncommented. 
# c.create_table("df", data_frame, statistics=Statistics(row_count=1337)) @@ -26,5 +26,5 @@ def test_sql_query_explain(c, gpu): dataframes={"other_df": df}, gpu=gpu, ) - assert sql_string.startswith("Projection: #MIN(other_df.a) AS a_min\n") - assert "Aggregate: groupBy=[[#other_df.a]], aggr=[[MIN(#other_df.a)]]" in sql_string + assert sql_string.startswith("Projection: MIN(other_df.a) AS a_min\n") + assert "Aggregate: groupBy=[[other_df.a]], aggr=[[MIN(other_df.a)]]" in sql_string diff --git a/tests/integration/test_join.py b/tests/integration/test_join.py index 6238a346d..a3c1f1060 100644 --- a/tests/integration/test_join.py +++ b/tests/integration/test_join.py @@ -329,11 +329,11 @@ def test_intersect(c): actual_df = c.sql( """ select count(*) from ( - select * from df_simple + select a, b from df_simple intersect - select * from df_simple + select a, b from df_simple intersect - select * from df_wide + select a, b from df_wide ) hot_item limit 100 """ diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py index 5945df8f9..39b09fb44 100644 --- a/tests/integration/test_rex.py +++ b/tests/integration/test_rex.py @@ -94,6 +94,23 @@ def test_literals(c): assert_eq(df, expected_df) +def test_date_interval_math(c): + df = c.sql( + """SELECT + DATE '1998-08-18' - INTERVAL '4 days' AS "before", + DATE '1998-08-18' + INTERVAL '4 days' AS "after" + """ + ) + + expected_df = pd.DataFrame( + { + "before": [pd.to_datetime("1998-08-14 00:00")], + "after": [pd.to_datetime("1998-08-22 00:00")], + } + ) + assert_eq(df, expected_df) + + def test_literal_null(c): df = c.sql( """ diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 979ee0296..d10f7012b 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -63,7 +63,7 @@ def test_explain(gpu): sql_string = c.explain("SELECT * FROM df") - assert sql_string.startswith("Projection: #df.a\n") + assert sql_string.startswith("Projection: df.a\n") # TODO: Need to add statistics to Rust optimizer before this can be uncommented. # c.create_table("df", data_frame, statistics=Statistics(row_count=1337)) @@ -82,7 +82,7 @@ def test_explain(gpu): "SELECT * FROM other_df", dataframes={"other_df": data_frame}, gpu=gpu ) - assert sql_string.startswith("Projection: #other_df.a\n") + assert sql_string.startswith("Projection: other_df.a\n") @pytest.mark.parametrize( diff --git a/tests/unit/test_queries.py b/tests/unit/test_queries.py index 606c7a3ad..871680a96 100644 --- a/tests/unit/test_queries.py +++ b/tests/unit/test_queries.py @@ -32,7 +32,6 @@ 44, 45, 47, - 48, 49, 50, 51, @@ -54,7 +53,6 @@ 80, 82, 84, - 85, 86, 87, 88,
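
Reviewer note: the Python-side interval handling above replaces `datetime.timedelta` with `numpy.timedelta64` in `dask_sql/mappings.py` and `dask_sql/physical/rex/core/call.py`. A minimal standalone sketch of the resulting semantics, with values chosen to mirror the new `test_date_interval_math` case (illustrative only, not part of the diff):

```python
import numpy as np

# INTERVAL '4 days' now maps to a numpy.timedelta64 rather than a datetime.timedelta
four_days = np.timedelta64(4, "D")

# DATE '1998-08-18' +/- INTERVAL '4 days', as exercised by test_date_interval_math
base = np.datetime64("1998-08-18")
print(base - four_days)  # 1998-08-14
print(base + four_days)  # 1998-08-22

# INTERVAL_DAY literals arrive as a (days, milliseconds) pair; mappings.py now combines
# them in numpy space, which promotes the result to the finer unit
days, millis = 1, 500
print(np.timedelta64(days, "D") + np.timedelta64(millis, "ms"))  # 86400500 milliseconds
```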
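The `literal.py` hunk keys the unit lookup on the full DataFusion type name (e.g. `"TimestampNanosecond"`) and resolves the NumPy unit before the literal type is rewritten to `SqlTypeName.TIMESTAMP`. A rough illustration of the new lookup path, using a hypothetical epoch value:

```python
import numpy as np

unit_mapping = {
    "TimestampSecond": "s",
    "TimestampMillisecond": "ms",
    "TimestampMicrosecond": "us",
    "TimestampNanosecond": "ns",
}

literal_type = "TimestampNanosecond"        # type name as reported by the planner
literal_value = 1_663_113_600_000_000_000   # hypothetical epoch nanoseconds
numpy_unit = unit_mapping.get(literal_type)

print(np.datetime64(literal_value, numpy_unit))  # 2022-09-14T00:00:00.000000000
```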
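The `cast` change in `call.py` means scalar (non-frame) operands are now converted through `sql_to_python_value` instead of being returned untouched. A hedged usage sketch, assuming a dask-sql environment with the compiled `dask_planner` extension available; the exact return representation for DATE depends on `sql_to_python_value`:

```python
from dask_planner.rust import SqlTypeName
from dask_sql.mappings import sql_to_python_value

# Scalar path of the CAST operation: a string literal cast to DATE is converted up
# front, so downstream date arithmetic sees a numpy datetime value rather than a
# raw Python string.
sql_type = SqlTypeName.fromString("DATE")
value = sql_to_python_value(sql_type, "1998-08-18")
print(type(value), value)  # expected: a numpy datetime64-based date value
```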