Skip to content

Commit

Permalink
Feature/datetime functions (#91)
Browse files Browse the repository at this point in the history
* Implement the last_day function

* Modify interval literals for timestampadd

* Ceil and floor for datetime

* Clean up docs and comments

* Code formatting + common util functions

* Add tests for exception handling
  • Loading branch information
Yuhuishishishi authored Dec 7, 2020
1 parent 7a9b4e1 commit 6fa88ed
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 23 deletions.
21 changes: 17 additions & 4 deletions dask_sql/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,23 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any:
return literal_value

elif sql_type.startswith("INTERVAL"):
# Calcite will always convert to milliseconds
# no matter what the actual interval is
# I am not sure if this breaks somewhere,
# but so far it works
# check for finer granular interval types, e.g., INTERVAL MONTH, INTERVAL YEAR
try:
interval_type = sql_type.split()[1].lower()

if interval_type in {"year", "quarter", "month"}:
# if sql_type is INTERVAL YEAR, Calcite will covert to months
delta = pd.tseries.offsets.DateOffset(months=float(str(literal_value)))
return delta
except IndexError: # pragma: no cover
# no finer granular interval type specified
pass
except TypeError: # pragma: no cover
# interval type is not recognized, fall back to default case
pass

# Calcite will always convert INTERVAL types except YEAR, QUATER, MONTH to milliseconds
# Issue: if sql_type is INTERVAL MICROSECOND, and value <= 1000, literal_value will be rounded to 0
return timedelta(milliseconds=float(str(literal_value)))

elif sql_type == "BOOLEAN":
Expand Down
88 changes: 70 additions & 18 deletions dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from dask_sql.physical.rex import RexConverter
from dask_sql.physical.rex.base import BaseRexPlugin
from dask_sql.utils import LoggableDataFrame, is_frame
from dask_sql.utils import LoggableDataFrame, is_frame, is_datetime, convert_to_datetime
from dask_sql.datacontainer import DataContainer

logger = logging.getLogger(__name__)
Expand All @@ -37,25 +37,34 @@ def of(self, op: "Operation") -> "Operation":
return Operation(lambda x: self(op(x)))


class PredicteBasedOperation(Operation):
    """
    Helper operation that routes a call to one of two functions,
    depending on whether the first argument satisfies a predicate.
    """

    def __init__(self, predicte: Callable, true_route: Callable, false_route: Callable):
        """Store the predicate and the two candidate functions."""
        super().__init__(self.apply)

        self.predicte = predicte
        self.true_route = true_route
        self.false_route = false_route

    def apply(self, *operands):
        """Evaluate the predicate on the first operand and dispatch."""
        route = self.true_route if self.predicte(operands[0]) else self.false_route
        return route(*operands)


class TensorScalarOperation(PredicteBasedOperation):
    """
    Helper operation to call a function on the input,
    depending if the first is a dataframe or not
    """

    def __init__(self, tensor_f: Callable, scalar_f: Callable = None):
        """Init with the given operations.

        tensor_f: applied when the first operand is a dask/pandas frame.
        scalar_f: applied otherwise; defaults to tensor_f when omitted.
        """
        # Bug fix: forwarding a None scalar_f as the false_route would make
        # scalar inputs call None(...). Preserve the pre-refactor fallback
        # of using tensor_f for scalars when scalar_f is not given.
        super().__init__(is_frame, tensor_f, scalar_f or tensor_f)


class ReduceOperation(Operation):
Expand Down Expand Up @@ -377,10 +386,7 @@ def __init__(self):

def extract(self, what, df: SeriesOrScalar):
input_df = df
if is_frame(df):
df = df.dt
else:
df = pd.to_datetime(df)
df = convert_to_datetime(df)

if what == "CENTURY":
return da.trunc(df.year / 100)
Expand Down Expand Up @@ -416,6 +422,48 @@ def extract(self, what, df: SeriesOrScalar):
raise NotImplementedError(f"Extraction of {what} is not (yet) implemented.")


class CeilFloorOperation(PredicteBasedOperation):
    """
    Apply ceil/floor operations on a series depending on its dtype:
    datetime-like inputs are rounded to a SQL time unit, everything
    else is rounded numerically via dask.array.
    """

    # Map the SQL unit keyword to the pandas frequency alias used by
    # Series.dt.ceil/floor and Timestamp.ceil/floor.
    # NOTE(review): pandas >= 2.2 deprecates "H"/"T"/"S"/"L"/"U" in favor of
    # lowercase aliases — confirm against the pinned pandas version.
    _UNIT_MAP = {
        "DAY": "D",
        "HOUR": "H",
        "MINUTE": "T",
        "SECOND": "S",
        "MICROSECOND": "U",
        "MILLISECOND": "L",
    }

    def __init__(self, round_method: str):
        """Init with either "ceil" or "floor" as the rounding method."""
        # Explicit validation instead of assert, which is stripped under -O.
        if round_method not in {"ceil", "floor"}:
            raise ValueError("Round method can only be either ceil or floor")

        super().__init__(
            is_datetime,  # if the series is dt type
            self._round_datetime,
            getattr(da, round_method),  # numeric fallback: da.ceil / da.floor
        )

        self.round_method = round_method

    def _round_datetime(self, *operands):
        """Round a datetime-like operand to the given SQL time unit.

        Raises NotImplementedError for units not in _UNIT_MAP (e.g. YEAR).
        """
        df, unit = operands

        df = convert_to_datetime(df)

        try:
            freq = self._UNIT_MAP[unit.upper()]
        except KeyError:
            # Suppress the KeyError context; the SQL-level message is enough.
            raise NotImplementedError(
                f"{self.round_method} TO {unit} is not (yet) implemented."
            ) from None

        return getattr(df, self.round_method)(freq)


class RexCallPlugin(BaseRexPlugin):
"""
RexCall is used for expressions, which calculate something.
Expand Down Expand Up @@ -468,12 +516,12 @@ class RexCallPlugin(BaseRexPlugin):
"atan": Operation(da.arctan),
"atan2": Operation(da.arctan2),
"cbrt": Operation(da.cbrt),
"ceil": Operation(da.ceil),
"ceil": CeilFloorOperation("ceil"),
"cos": Operation(da.cos),
"cot": Operation(lambda x: 1 / da.tan(x)),
"degrees": Operation(da.degrees),
"exp": Operation(da.exp),
"floor": Operation(da.floor),
"floor": CeilFloorOperation("floor"),
"log10": Operation(da.log10),
"ln": Operation(da.log),
# "mod": Operation(da.mod), # needs cast
Expand Down Expand Up @@ -501,6 +549,10 @@ class RexCallPlugin(BaseRexPlugin):
"current_time": Operation(lambda *args: pd.Timestamp.now()),
"current_date": Operation(lambda *args: pd.Timestamp.now()),
"current_timestamp": Operation(lambda *args: pd.Timestamp.now()),
"last_day": TensorScalarOperation(
lambda x: x + pd.tseries.offsets.MonthEnd(1),
lambda x: convert_to_datetime(x) + pd.tseries.offsets.MonthEnd(1),
),
}

def convert(
Expand Down
18 changes: 18 additions & 0 deletions dask_sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ def is_frame(df):
)


def is_datetime(obj):
    """
    Return True if a scalar or a series is datetime-like: either a plain
    ``datetime.datetime`` instance or anything with a datetime64 dtype
    (e.g. a pandas datetime series).
    """
    if isinstance(obj, datetime):
        return True
    return pd.api.types.is_datetime64_any_dtype(obj)


def convert_to_datetime(df):
    """
    Convert a scalar or a series to datetime type.

    Frames/series are exposed through their ``.dt`` accessor; scalars are
    parsed with ``pd.to_datetime``.
    """
    if is_frame(df):
        return df.dt
    return pd.to_datetime(df)


class Pluggable:
"""
Helper class for everything which can be extended by plugins.
Expand Down
58 changes: 57 additions & 1 deletion tests/integration/test_rex.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime

import pytest
import numpy as np
import pandas as pd
import dask.dataframe as dd
Expand Down Expand Up @@ -432,7 +433,32 @@ def test_date_functions(c):
EXTRACT(QUARTER FROM d) AS "quarter",
EXTRACT(SECOND FROM d) AS "second",
EXTRACT(WEEK FROM d) AS "week",
EXTRACT(YEAR FROM d) AS "year"
EXTRACT(YEAR FROM d) AS "year",
LAST_DAY(d) as "last_day",
TIMESTAMPADD(YEAR, 2, d) as "plus_1_year",
TIMESTAMPADD(MONTH, 1, d) as "plus_1_month",
TIMESTAMPADD(WEEK, 1, d) as "plus_1_week",
TIMESTAMPADD(DAY, 1, d) as "plus_1_day",
TIMESTAMPADD(HOUR, 1, d) as "plus_1_hour",
TIMESTAMPADD(MINUTE, 1, d) as "plus_1_min",
TIMESTAMPADD(SECOND, 1, d) as "plus_1_sec",
TIMESTAMPADD(MICROSECOND, 1000, d) as "plus_1000_millisec",
TIMESTAMPADD(QUARTER, 1, d) as "plus_1_qt",
CEIL(d TO DAY) as ceil_to_day,
CEIL(d TO HOUR) as ceil_to_hour,
CEIL(d TO MINUTE) as ceil_to_minute,
CEIL(d TO SECOND) as ceil_to_seconds,
CEIL(d TO MILLISECOND) as ceil_to_millisec,
FLOOR(d TO DAY) as floor_to_day,
FLOOR(d TO HOUR) as floor_to_hour,
FLOOR(d TO MINUTE) as floor_to_minute,
FLOOR(d TO SECOND) as floor_to_seconds,
FLOOR(d TO MILLISECOND) as floor_to_millisec
FROM df
"""
).compute()
Expand All @@ -454,7 +480,37 @@ def test_date_functions(c):
"second": [42],
"week": [39],
"year": [2021],
"last_day": [datetime(2021, 10, 31, 15, 53, 42, 47)],
"plus_1_year": [datetime(2023, 10, 3, 15, 53, 42, 47)],
"plus_1_month": [datetime(2021, 11, 3, 15, 53, 42, 47)],
"plus_1_week": [datetime(2021, 10, 10, 15, 53, 42, 47)],
"plus_1_day": [datetime(2021, 10, 4, 15, 53, 42, 47)],
"plus_1_hour": [datetime(2021, 10, 3, 16, 53, 42, 47)],
"plus_1_min": [datetime(2021, 10, 3, 15, 54, 42, 47)],
"plus_1_sec": [datetime(2021, 10, 3, 15, 53, 43, 47)],
"plus_1000_millisec": [datetime(2021, 10, 3, 15, 53, 42, 1047)],
"plus_1_qt": [datetime(2022, 1, 3, 15, 53, 42, 47)],
"ceil_to_day": [datetime(2021, 10, 4)],
"ceil_to_hour": [datetime(2021, 10, 3, 16)],
"ceil_to_minute": [datetime(2021, 10, 3, 15, 54)],
"ceil_to_seconds": [datetime(2021, 10, 3, 15, 53, 43)],
"ceil_to_millisec": [datetime(2021, 10, 3, 15, 53, 42, 1000)],
"floor_to_day": [datetime(2021, 10, 3)],
"floor_to_hour": [datetime(2021, 10, 3, 15)],
"floor_to_minute": [datetime(2021, 10, 3, 15, 53)],
"floor_to_seconds": [datetime(2021, 10, 3, 15, 53, 42)],
"floor_to_millisec": [datetime(2021, 10, 3, 15, 53, 42)],
}
)

assert_frame_equal(df, expected_df, check_dtype=False)

# test exception handling
with pytest.raises(NotImplementedError):
df = c.sql(
"""
SELECT
FLOOR(d TO YEAR) as floor_to_year
FROM df
"""
).compute()
15 changes: 15 additions & 0 deletions tests/unit/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,18 @@ def test_dates():
assert op("SECOND", date) == 42
assert op("WEEK", date) == 39
assert op("YEAR", date) == 2021

ceil_op = call.CeilFloorOperation("ceil")
floor_op = call.CeilFloorOperation("floor")

assert ceil_op(date, "DAY") == datetime(2021, 10, 4)
assert ceil_op(date, "HOUR") == datetime(2021, 10, 3, 16)
assert ceil_op(date, "MINUTE") == datetime(2021, 10, 3, 15, 54)
assert ceil_op(date, "SECOND") == datetime(2021, 10, 3, 15, 53, 43)
assert ceil_op(date, "MILLISECOND") == datetime(2021, 10, 3, 15, 53, 42, 1000)

assert floor_op(date, "DAY") == datetime(2021, 10, 3)
assert floor_op(date, "HOUR") == datetime(2021, 10, 3, 15)
assert floor_op(date, "MINUTE") == datetime(2021, 10, 3, 15, 53)
assert floor_op(date, "SECOND") == datetime(2021, 10, 3, 15, 53, 42)
assert floor_op(date, "MILLISECOND") == datetime(2021, 10, 3, 15, 53, 42)

0 comments on commit 6fa88ed

Please sign in to comment.