diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py
index 226032357..72db211df 100644
--- a/dask_sql/mappings.py
+++ b/dask_sql/mappings.py
@@ -106,10 +106,23 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any:
         return literal_value

     elif sql_type.startswith("INTERVAL"):
-        # Calcite will always convert to milliseconds
-        # no matter what the actual interval is
-        # I am not sure if this breaks somewhere,
-        # but so far it works
+        # check for coarser-grained interval types, e.g. INTERVAL MONTH, INTERVAL YEAR
+        try:
+            interval_type = sql_type.split()[1].lower()
+
+            if interval_type in {"year", "quarter", "month"}:
+                # if sql_type is INTERVAL YEAR, Calcite will convert the value to months
+                delta = pd.tseries.offsets.DateOffset(months=float(str(literal_value)))
+                return delta
+        except IndexError:  # pragma: no cover
+            # no finer-grained interval type was specified
+            pass
+        except TypeError:  # pragma: no cover
+            # the interval type is not recognized, fall back to the default case
+            pass
+
+        # Calcite will always convert INTERVAL types other than YEAR, QUARTER and MONTH to milliseconds
+        # Issue: if sql_type is INTERVAL MICROSECOND and the value is <= 1000, literal_value will be rounded to 0
         return timedelta(milliseconds=float(str(literal_value)))

     elif sql_type == "BOOLEAN":
diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py
index 6fd6dc059..b3eda125c 100644
--- a/dask_sql/physical/rex/core/call.py
+++ b/dask_sql/physical/rex/core/call.py
@@ -14,7 +14,7 @@
 from dask_sql.physical.rex import RexConverter
 from dask_sql.physical.rex.base import BaseRexPlugin
-from dask_sql.utils import LoggableDataFrame, is_frame
+from dask_sql.utils import LoggableDataFrame, is_frame, is_datetime, convert_to_datetime
 from dask_sql.datacontainer import DataContainer

 logger = logging.getLogger(__name__)
@@ -37,25 +37,34 @@
-class TensorScalarOperation(Operation):
+class PredicateBasedOperation(Operation):
     """
     Helper operation to call a function on the input,
-    depending if the first is a dataframe or not
+    depending on whether the first argument satisfies a given predicate function
     """

-    def __init__(self, tensor_f: Callable, scalar_f: Callable = None):
-        """Init with the given operation"""
+    def __init__(self, predicate: Callable, true_route: Callable, false_route: Callable):
         super().__init__(self.apply)
-
-        self.tensor_f = tensor_f
-        self.scalar_f = scalar_f or tensor_f
+        self.predicate = predicate
+        self.true_route = true_route
+        self.false_route = false_route

     def apply(self, *operands):
-        """Call the stored functions"""
-        if is_frame(operands[0]):
-            return self.tensor_f(*operands)
+        if self.predicate(operands[0]):
+            return self.true_route(*operands)

-        return self.scalar_f(*operands)
+        return self.false_route(*operands)
+
+
+class TensorScalarOperation(PredicateBasedOperation):
+    """
+    Helper operation to call a function on the input,
+    depending on whether the first argument is a dataframe or not
+    """
+
+    def __init__(self, tensor_f: Callable, scalar_f: Callable = None):
+        """Init with the given operation"""
+        super().__init__(is_frame, tensor_f, scalar_f)


 class ReduceOperation(Operation):
@@ -377,10 +386,7 @@ def __init__(self):
     def extract(self, what, df: SeriesOrScalar):
         input_df = df

-        if is_frame(df):
-            df = df.dt
-        else:
-            df = pd.to_datetime(df)
+        df = convert_to_datetime(df)

         if what == "CENTURY":
             return da.trunc(df.year / 100)
@@ -416,6 +422,48 @@
         raise NotImplementedError(f"Extraction of {what} is not (yet) implemented.")


+class CeilFloorOperation(PredicateBasedOperation):
+    """
+    Apply ceil/floor operations on a series, depending on its dtype (datetime-like vs. normal)
+    """
+
+    def __init__(self, round_method: str):
+        assert round_method in {
+            "ceil",
+            "floor",
+        }, "Round method can only be either ceil or floor"
+
+        super().__init__(
+            is_datetime,  # if the series is of datetime type
+            self._round_datetime,
+            getattr(da, round_method),
+        )
+
+        self.round_method = round_method
+
+    def _round_datetime(self, *operands):
+        df, unit = operands
+
+        df = convert_to_datetime(df)
+
+        unit_map = {
+            "DAY": "D",
+            "HOUR": "H",
+            "MINUTE": "T",
+            "SECOND": "S",
+            "MICROSECOND": "U",
+            "MILLISECOND": "L",
+        }
+
+        try:
+            freq = unit_map[unit.upper()]
+            return getattr(df, self.round_method)(freq)
+        except KeyError:
+            raise NotImplementedError(
+                f"{self.round_method} TO {unit} is not (yet) implemented."
+            )
+
+
 class RexCallPlugin(BaseRexPlugin):
     """
     RexCall is used for expressions, which calculate something.
@@ -468,12 +516,12 @@
         "atan": Operation(da.arctan),
         "atan2": Operation(da.arctan2),
         "cbrt": Operation(da.cbrt),
-        "ceil": Operation(da.ceil),
+        "ceil": CeilFloorOperation("ceil"),
         "cos": Operation(da.cos),
         "cot": Operation(lambda x: 1 / da.tan(x)),
         "degrees": Operation(da.degrees),
         "exp": Operation(da.exp),
-        "floor": Operation(da.floor),
+        "floor": CeilFloorOperation("floor"),
         "log10": Operation(da.log10),
         "ln": Operation(da.log),
         # "mod": Operation(da.mod),  # needs cast
@@ -501,6 +549,10 @@
         "current_time": Operation(lambda *args: pd.Timestamp.now()),
         "current_date": Operation(lambda *args: pd.Timestamp.now()),
         "current_timestamp": Operation(lambda *args: pd.Timestamp.now()),
+        "last_day": TensorScalarOperation(
+            lambda x: x + pd.tseries.offsets.MonthEnd(1),
+            lambda x: convert_to_datetime(x) + pd.tseries.offsets.MonthEnd(1),
+        ),
     }

     def convert(
diff --git a/dask_sql/utils.py b/dask_sql/utils.py
index 7637c1956..50a74cc32 100644
--- a/dask_sql/utils.py
+++ b/dask_sql/utils.py
@@ -52,6 +52,24 @@ def is_frame(df):
     )


+def is_datetime(obj):
+    """
+    Check if a scalar or a series is of datetime type
+    """
+    return pd.api.types.is_datetime64_any_dtype(obj) or isinstance(obj, datetime)
+
+
+def convert_to_datetime(df):
+    """
+    Convert a scalar or a series to datetime type
+    """
+    if is_frame(df):
+        df = df.dt
+    else:
+        df = pd.to_datetime(df)
+    return df
+
+
 class Pluggable:
     """
     Helper class for everything which can be extended by plugins.
diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py
index 1bf161d34..2068af2f5 100644
--- a/tests/integration/test_rex.py
+++ b/tests/integration/test_rex.py
@@ -1,5 +1,6 @@
 from datetime import datetime

+import pytest
 import numpy as np
 import pandas as pd
 import dask.dataframe as dd
@@ -432,7 +433,32 @@
         EXTRACT(QUARTER FROM d) AS "quarter",
         EXTRACT(SECOND FROM d) AS "second",
         EXTRACT(WEEK FROM d) AS "week",
-        EXTRACT(YEAR FROM d) AS "year"
+        EXTRACT(YEAR FROM d) AS "year",
+
+        LAST_DAY(d) as "last_day",
+
+        TIMESTAMPADD(YEAR, 2, d) as "plus_2_years",
+        TIMESTAMPADD(MONTH, 1, d) as "plus_1_month",
+        TIMESTAMPADD(WEEK, 1, d) as "plus_1_week",
+        TIMESTAMPADD(DAY, 1, d) as "plus_1_day",
+        TIMESTAMPADD(HOUR, 1, d) as "plus_1_hour",
+        TIMESTAMPADD(MINUTE, 1, d) as "plus_1_min",
+        TIMESTAMPADD(SECOND, 1, d) as "plus_1_sec",
+        TIMESTAMPADD(MICROSECOND, 1000, d) as "plus_1000_microsec",
+        TIMESTAMPADD(QUARTER, 1, d) as "plus_1_qt",
+
+        CEIL(d TO DAY) as ceil_to_day,
+        CEIL(d TO HOUR) as ceil_to_hour,
+        CEIL(d TO MINUTE) as ceil_to_minute,
+        CEIL(d TO SECOND) as ceil_to_seconds,
+        CEIL(d TO MILLISECOND) as ceil_to_millisec,
+
+        FLOOR(d TO DAY) as floor_to_day,
+        FLOOR(d TO HOUR) as floor_to_hour,
+        FLOOR(d TO MINUTE) as floor_to_minute,
+        FLOOR(d TO SECOND) as floor_to_seconds,
+        FLOOR(d TO MILLISECOND) as floor_to_millisec
+
     FROM df
     """
     ).compute()
@@ -454,7 +480,37 @@
             "second": [42],
             "week": [39],
             "year": [2021],
+            "last_day": [datetime(2021, 10, 31, 15, 53, 42, 47)],
+            "plus_2_years": [datetime(2023, 10, 3, 15, 53, 42, 47)],
+            "plus_1_month": [datetime(2021, 11, 3, 15, 53, 42, 47)],
+            "plus_1_week": [datetime(2021, 10, 10, 15, 53, 42, 47)],
+            "plus_1_day": [datetime(2021, 10, 4, 15, 53, 42, 47)],
+            "plus_1_hour": [datetime(2021, 10, 3, 16, 53, 42, 47)],
+            "plus_1_min": [datetime(2021, 10, 3, 15, 54, 42, 47)],
+            "plus_1_sec": [datetime(2021, 10, 3, 15, 53, 43, 47)],
+            "plus_1000_microsec": [datetime(2021, 10, 3, 15, 53, 42, 1047)],
+            "plus_1_qt": [datetime(2022, 1, 3, 15, 53, 42, 47)],
+            "ceil_to_day": [datetime(2021, 10, 4)],
+            "ceil_to_hour": [datetime(2021, 10, 3, 16)],
+            "ceil_to_minute": [datetime(2021, 10, 3, 15, 54)],
+            "ceil_to_seconds": [datetime(2021, 10, 3, 15, 53, 43)],
+            "ceil_to_millisec": [datetime(2021, 10, 3, 15, 53, 42, 1000)],
+            "floor_to_day": [datetime(2021, 10, 3)],
+            "floor_to_hour": [datetime(2021, 10, 3, 15)],
+            "floor_to_minute": [datetime(2021, 10, 3, 15, 53)],
+            "floor_to_seconds": [datetime(2021, 10, 3, 15, 53, 42)],
+            "floor_to_millisec": [datetime(2021, 10, 3, 15, 53, 42)],
         }
     )

     assert_frame_equal(df, expected_df, check_dtype=False)
+
+    # test exception handling
+    with pytest.raises(NotImplementedError):
+        df = c.sql(
+            """
+            SELECT
+                FLOOR(d TO YEAR) as floor_to_year
+            FROM df
+            """
+        ).compute()
diff --git a/tests/unit/test_call.py b/tests/unit/test_call.py
index 05047c2cc..9e419eb73 100644
--- a/tests/unit/test_call.py
+++ b/tests/unit/test_call.py
@@ -212,3 +212,18 @@
     assert op("SECOND", date) == 42
     assert op("WEEK", date) == 39
     assert op("YEAR", date) == 2021
+
+    ceil_op = call.CeilFloorOperation("ceil")
+    floor_op = call.CeilFloorOperation("floor")
+
+    assert ceil_op(date, "DAY") == datetime(2021, 10, 4)
+    assert ceil_op(date, "HOUR") == datetime(2021, 10, 3, 16)
+    assert ceil_op(date, "MINUTE") == datetime(2021, 10, 3, 15, 54)
+    assert ceil_op(date, "SECOND") == datetime(2021, 10, 3, 15, 53, 43)
+    assert ceil_op(date, "MILLISECOND") == datetime(2021, 10, 3, 15, 53, 42, 1000)
+
+    assert floor_op(date, "DAY") == datetime(2021, 10, 3)
+    assert floor_op(date, "HOUR") == datetime(2021, 10, 3, 15)
+    assert floor_op(date, "MINUTE") == datetime(2021, 10, 3, 15, 53)
+    assert floor_op(date, "SECOND") == datetime(2021, 10, 3, 15, 53, 42)
+    assert floor_op(date, "MILLISECOND") == datetime(2021, 10, 3, 15, 53, 42)
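For a quick manual check outside the test suite, the new operators can also be exercised end-to-end through a `dask_sql.Context`, mirroring the integration test above. A hypothetical mini-session (table and column names chosen for illustration; assumes this branch is installed):

```python
import pandas as pd
import dask.dataframe as dd
from dask_sql import Context

c = Context()
ddf = dd.from_pandas(
    pd.DataFrame({"d": [pd.Timestamp(2021, 10, 3, 15, 53, 42, 47)]}),
    npartitions=1,
)
c.create_table("df", ddf)

result = c.sql(
    """
    SELECT
        LAST_DAY(d) AS "last_day",
        CEIL(d TO HOUR) AS "ceil_to_hour",
        FLOOR(d TO DAY) AS "floor_to_day"
    FROM df
    """
).compute()

# Expected, per the assertions above:
#   last_day      2021-10-31 15:53:42.000047  (MonthEnd keeps the time component)
#   ceil_to_hour  2021-10-03 16:00:00
#   floor_to_day  2021-10-03 00:00:00
print(result)
```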