diff --git a/dask_planner/src/dialect.rs b/dask_planner/src/dialect.rs
index 973f76f4f..b27c81ec3 100644
--- a/dask_planner/src/dialect.rs
+++ b/dask_planner/src/dialect.rs
@@ -1,6 +1,11 @@
 use core::{iter::Peekable, str::Chars};
 
-use datafusion_sql::sqlparser::dialect::Dialect;
+use datafusion_sql::sqlparser::{
+    ast::{Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value},
+    dialect::Dialect,
+    parser::{Parser, ParserError},
+    tokenizer::Token,
+};
 
 #[derive(Debug)]
 pub struct DaskDialect {}
@@ -37,4 +42,75 @@ impl Dialect for DaskDialect {
     fn supports_filter_during_aggregation(&self) -> bool {
         true
     }
+
+    /// override expression parsing
+    fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
+        fn parse_expr(parser: &mut Parser) -> Result<Option<Expr>, ParserError> {
+            match parser.peek_token() {
+                Token::Word(w) if w.value.to_lowercase() == "timestampadd" => {
+                    // TIMESTAMPADD(YEAR, 2, d)
+                    parser.next_token(); // skip timestampadd
+                    parser.expect_token(&Token::LParen)?;
+                    let time_unit = parser.next_token();
+                    parser.expect_token(&Token::Comma)?;
+                    let n = parser.parse_expr()?;
+                    parser.expect_token(&Token::Comma)?;
+                    let expr = parser.parse_expr()?;
+                    parser.expect_token(&Token::RParen)?;
+
+                    // convert to function args
+                    let args = vec![
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
+                            Value::SingleQuotedString(time_unit.to_string()),
+                        ))),
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(n)),
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
+                    ];
+
+                    Ok(Some(Expr::Function(Function {
+                        name: ObjectName(vec![Ident::new("timestampadd")]),
+                        args,
+                        over: None,
+                        distinct: false,
+                        special: false,
+                    })))
+                }
+                Token::Word(w) if w.value.to_lowercase() == "to_timestamp" => {
+                    // TO_TIMESTAMP(d, "%d/%m/%Y")
+                    parser.next_token(); // skip to_timestamp
+                    parser.expect_token(&Token::LParen)?;
+                    let expr = parser.parse_expr()?;
+                    let comma = parser.consume_token(&Token::Comma);
+                    let time_format = if comma {
+                        parser.next_token().to_string()
+                    } else {
+                        "%Y-%m-%d %H:%M:%S".to_string()
+                    };
+                    parser.expect_token(&Token::RParen)?;
+
+                    // convert to function args
+                    let args = vec![
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
+                            Value::SingleQuotedString(time_format),
+                        ))),
+                    ];
+
+                    Ok(Some(Expr::Function(Function {
+                        name: ObjectName(vec![Ident::new("dsql_totimestamp")]),
+                        args,
+                        over: None,
+                        distinct: false,
+                        special: false,
+                    })))
+                }
+                _ => Ok(None),
+            }
+        }
+        match parse_expr(parser) {
+            Ok(Some(expr)) => Some(Ok(expr)),
+            Ok(None) => None,
+            Err(e) => Some(Err(e)),
+        }
+    }
 }
diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index d743af901..61be0e1cd 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -1236,6 +1236,52 @@ impl<'a> DaskParser<'a> {
 mod test {
     use crate::parser::{DaskParser, DaskStatement};
 
+    #[test]
+    fn timestampadd() {
+        let sql = "SELECT TIMESTAMPADD(YEAR, 2, d) FROM t";
+        let statements = DaskParser::parse_sql(sql).unwrap();
+        assert_eq!(1, statements.len());
+        let actual = format!("{:?}", statements[0]);
+        let expected = "projection: [\
+        UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"timestampadd\", quote_style: None }]), \
+        args: [\
+        Unnamed(Expr(Value(SingleQuotedString(\"YEAR\")))), \
+        Unnamed(Expr(Value(Number(\"2\", false)))), \
+        Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None })))\
+        ], over: None, distinct: false, special: false }))\
+        ]";
+        assert!(actual.contains(expected));
+    }
+
+    #[test]
+    fn to_timestamp() {
+        let sql1 = "SELECT TO_TIMESTAMP(d) FROM t";
+        let statements1 = DaskParser::parse_sql(sql1).unwrap();
+        assert_eq!(1, statements1.len());
+        let actual1 = format!("{:?}", statements1[0]);
+        let expected1 = "projection: [\
+        UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"dsql_totimestamp\", quote_style: None }]), \
+        args: [\
+        Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None }))), \
+        Unnamed(Expr(Value(SingleQuotedString(\"%Y-%m-%d %H:%M:%S\"))))\
+        ], over: None, distinct: false, special: false }))\
+        ]";
+        assert!(actual1.contains(expected1));
+
+        let sql2 = "SELECT TO_TIMESTAMP(d, \"%d/%m/%Y\") FROM t";
+        let statements2 = DaskParser::parse_sql(sql2).unwrap();
+        assert_eq!(1, statements2.len());
+        let actual2 = format!("{:?}", statements2[0]);
+        let expected2 = "projection: [\
+        UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"dsql_totimestamp\", quote_style: None }]), \
+        args: [\
+        Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None }))), \
+        Unnamed(Expr(Value(SingleQuotedString(\"\\\"%d/%m/%Y\\\"\"))))\
+        ], over: None, distinct: false, special: false }))\
+        ]";
+        assert!(actual2.contains(expected2));
+    }
+
     #[test]
     fn create_model() {
         let sql = r#"CREATE MODEL my_model WITH (
diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs
index d52211ff7..bf6ce16ab 100644
--- a/dask_planner/src/sql.rs
+++ b/dask_planner/src/sql.rs
@@ -152,6 +152,24 @@ impl ContextProvider for DaskSQLContext {
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
                 return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
             }
+            "dsql_totimestamp" => {
+                let sig = Signature::one_of(
+                    vec![
+                        TypeSignature::Exact(vec![DataType::Int8, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::Int16, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::Int64, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::UInt8, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::UInt16, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::UInt32, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::UInt64, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
+                    ],
+                    Volatility::Immutable,
+                );
+                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64)));
+                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
+            }
             "mod" => {
                 let sig = generate_numeric_signatures(2);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py
index a66b178dc..6a5b01c17 100644
--- a/dask_sql/physical/rex/core/call.py
+++ b/dask_sql/physical/rex/core/call.py
@@ -1,6 +1,7 @@
 import logging
 import operator
 import re
+from datetime import datetime
 from functools import partial, reduce
 from typing import TYPE_CHECKING, Any, Callable, Union
 
@@ -613,6 +614,41 @@ def extract(self, what, df: SeriesOrScalar):
         raise NotImplementedError(f"Extraction of {what} is not (yet) implemented.")
 
 
+class ToTimestampOperation(Operation):
+    def __init__(self):
+        super().__init__(self.to_timestamp)
+
+    def to_timestamp(self, df, format):
+        default_format = "%Y-%m-%d %H:%M:%S"
+        # Remove double and single quotes from string
+        format = format.replace('"', "")
+        format = format.replace("'", "")
+
+        # TODO: format timestamps for GPU tests
+        if "cudf" in str(type(df)):
+            if format != default_format:
+                raise RuntimeError("Non-default timestamp formats not supported on GPU")
+            if df.dtype == "object":
+                return df
+            else:
+                nanoseconds_to_seconds = 10**9
+                return df * nanoseconds_to_seconds
+        # String cases
+        elif type(df) == str:
+            return np.datetime64(datetime.strptime(df, format))
+        elif df.dtype == "object":
+            return dd.to_datetime(df, format=format)
+        # Integer cases
+        elif np.isscalar(df):
+            if format != default_format:
+                raise RuntimeError("Integer input does not accept a format argument")
+            return np.datetime64(int(df), "s")
+        else:
+            if format != default_format:
+                raise RuntimeError("Integer input does not accept a format argument")
+            return dd.to_datetime(df, unit="s")
+
+
 class YearOperation(Operation):
     def __init__(self):
         super().__init__(self.extract_year)
@@ -990,6 +1026,7 @@ class RexCallPlugin(BaseRexPlugin):
             lambda x: x + pd.tseries.offsets.MonthEnd(1),
             lambda x: convert_to_datetime(x) + pd.tseries.offsets.MonthEnd(1),
         ),
+        "dsql_totimestamp": ToTimestampOperation(),
         # Temporary UDF functions that need to be moved after this POC
         "datepart": DatePartOperation(),
         "year": YearOperation(),
diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py
index b7d455fe3..510bf953b 100644
--- a/tests/integration/test_rex.py
+++ b/tests/integration/test_rex.py
@@ -677,3 +677,122 @@ def test_date_functions(c):
         FROM df
         """
     )
+
+
+@pytest.mark.parametrize(
+    "gpu",
+    [
+        False,
+        pytest.param(
+            True,
+            marks=(
+                pytest.mark.gpu,
+                pytest.mark.xfail(
+                    reason="Failing due to dask-cudf bug https://github.com/rapidsai/cudf/issues/12062"
+                ),
+            ),
+        ),
+    ],
+)
+def test_totimestamp(c, gpu):
+    df = pd.DataFrame(
+        {
+            "a": np.array([1203073300, 1406073600, 2806073600]),
+        }
+    )
+    c.create_table("df", df, gpu=gpu)
+
+    df = c.sql(
+        """
+        SELECT to_timestamp(a) AS date FROM df
+        """
+    )
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(2008, 2, 15, 11, 1, 40),
+                datetime(2014, 7, 23),
+                datetime(2058, 12, 2, 16, 53, 20),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
+
+    df = pd.DataFrame(
+        {
+            "a": np.array(["1997-02-28 10:30:00", "1997-03-28 10:30:01"]),
+        }
+    )
+    c.create_table("df", df, gpu=gpu)
+
+    df = c.sql(
+        """
+        SELECT to_timestamp(a) AS date FROM df
+        """
+    )
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(1997, 2, 28, 10, 30, 0),
+                datetime(1997, 3, 28, 10, 30, 1),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
+
+    df = pd.DataFrame(
+        {
+            "a": np.array(["02/28/1997", "03/28/1997"]),
+        }
+    )
+    c.create_table("df", df, gpu=gpu)
+
+    df = c.sql(
+        """
+        SELECT to_timestamp(a, "%m/%d/%Y") AS date FROM df
+        """
+    )
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(1997, 2, 28, 0, 0, 0),
+                datetime(1997, 3, 28, 0, 0, 0),
+            ],
+        }
+    )
+    # https://github.com/rapidsai/cudf/issues/12062
+    if not gpu:
+        assert_eq(df, expected_df, check_dtype=False)
+
+    int_input = 1203073300
+    df = c.sql(f"SELECT to_timestamp({int_input}) as date")
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(2008, 2, 15, 11, 1, 40),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
+
+    string_input = "1997-02-28 10:30:00"
+    df = c.sql(f"SELECT to_timestamp('{string_input}') as date")
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(1997, 2, 28, 10, 30, 0),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
+
+    string_input = "02/28/1997"
+    df = c.sql(f"SELECT to_timestamp('{string_input}', '%m/%d/%Y') as date")
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(1997, 2, 28, 0, 0, 0),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
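
A minimal usage sketch of the behaviour exercised by the integration tests above, run through a local dask_sql Context; the example frame and the two queries are illustrative and not part of this diff:

import numpy as np
import pandas as pd
from dask_sql import Context

c = Context()
c.create_table("df", pd.DataFrame({"a": np.array([1203073300, 1406073600])}))

# Integer seconds since the epoch, default format (%Y-%m-%d %H:%M:%S);
# TO_TIMESTAMP is rewritten by the dialect into the new dsql_totimestamp UDF.
print(c.sql("SELECT TO_TIMESTAMP(a) AS date FROM df").compute())

# Explicit format string applied to a string literal, as in the tests.
print(c.sql("SELECT TO_TIMESTAMP('02/28/1997', '%m/%d/%Y') AS date").compute())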