feat(datafusion): add temporal functions
mesejo authored and cpcloud committed Aug 13, 2023
1 parent 57ea7a1 commit 6be6c2b
Showing 2 changed files with 122 additions and 14 deletions.
111 changes: 111 additions & 0 deletions ibis/backends/datafusion/compiler.py
@@ -885,3 +885,114 @@ def join(op, **kw):
    )

    return left.join(right, join_keys=(left_keys, right_keys), how=how)


@translate.register(ops.ExtractYear)
def extract_year(op, **kw):
    arg = translate(op.arg, **kw)
    return df.functions.date_part(df.literal("year"), arg)


@translate.register(ops.ExtractMonth)
def extract_month(op, **kw):
    arg = translate(op.arg, **kw)
    return df.functions.date_part(df.literal("month"), arg)


@translate.register(ops.ExtractDay)
def extract_day(op, **kw):
    arg = translate(op.arg, **kw)
    return df.functions.date_part(df.literal("day"), arg)


@translate.register(ops.ExtractQuarter)
def extract_quarter(op, **kw):
    arg = translate(op.arg, **kw)
    return df.functions.date_part(df.literal("quarter"), arg)


@translate.register(ops.ExtractMinute)
def extract_minute(op, **kw):
    arg = translate(op.arg, **kw)
    return df.functions.date_part(df.literal("minute"), arg)


@translate.register(ops.ExtractHour)
def extract_hour(op, **kw):
    arg = translate(op.arg, **kw)
    return df.functions.date_part(df.literal("hour"), arg)


@translate.register(ops.ExtractMillisecond)
def extract_millisecond(op, **kw):
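    # A PyArrow compute UDF extracts the millisecond component and casts it to
    # int32 so the result matches the ibis output dtype.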
    def ms(array: pa.Array) -> pa.Array:
        return pc.cast(pc.millisecond(array), pa.int32())

    extract_milliseconds_udf = df.udf(
        ms,
        input_types=[PyArrowType.from_ibis(op.arg.dtype)],
        return_type=PyArrowType.from_ibis(op.dtype),
        volatility="immutable",
        name="extract_milliseconds_udf",
    )
    arg = translate(op.arg, **kw)
    return extract_milliseconds_udf(arg)


@translate.register(ops.ExtractSecond)
def extract_second(op, **kw):
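    # Same approach as ExtractMillisecond: a PyArrow UDF that casts the
    # extracted seconds to int32.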
    def s(array: pa.Array) -> pa.Array:
        return pc.cast(pc.second(array), pa.int32())

    extract_seconds_udf = df.udf(
        s,
        input_types=[PyArrowType.from_ibis(op.arg.dtype)],
        return_type=PyArrowType.from_ibis(op.dtype),
        volatility="immutable",
        name="extract_seconds_udf",
    )
    arg = translate(op.arg, **kw)
    return extract_seconds_udf(arg)


@translate.register(ops.ExtractDayOfYear)
def extract_day_of_the_year(op, **kw):
    arg = translate(op.arg, **kw)
    return df.functions.date_part(df.literal("doy"), arg)


@translate.register(ops.DayOfWeekIndex)
def extract_day_of_the_week_index(op, **kw):
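    # date_part("dow") counts Sunday as 0 (PostgreSQL convention); the
    # "+ 6 mod 7" shift remaps it to Monday=0, which DayOfWeekIndex expects.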
    arg = translate(op.arg, **kw)
    return (df.functions.date_part(df.literal("dow"), arg) + df.lit(6)) % df.lit(7)


@translate.register(ops.DayOfWeekName)
def extract_down(op, **kw):
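    # pc.day_of_week defaults to Monday=0, so the names passed to pc.choose
    # are listed Monday-first.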
    def down(array: pa.Array) -> pa.Array:
        return pc.choose(
            pc.day_of_week(array),
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        )

    extract_down_udf = df.udf(
        down,
        input_types=[PyArrowType.from_ibis(op.arg.dtype)],
        return_type=PyArrowType.from_ibis(op.dtype),
        volatility="immutable",
        name="extract_down_udf",
    )
    arg = translate(op.arg, **kw)
    return extract_down_udf(arg)


@translate.register(ops.Date)
def date(op, **kw):
    arg = translate(op.arg, **kw)
    return df.functions.date_trunc(df.literal("day"), arg)
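
A minimal sketch (not part of this commit) of how the new translations might be exercised through the ibis API, assuming a DataFusion connection with a table t that has a timestamp column ts; the table name and parquet path below are hypothetical:

import ibis

con = ibis.datafusion.connect({"t": "t.parquet"})  # hypothetical registration
t = con.table("t")

expr = t.select(
    year=t.ts.year(),
    millisecond=t.ts.millisecond(),
    weekday_index=t.ts.day_of_week.index(),
    weekday_name=t.ts.day_of_week.full_name(),
)
con.execute(expr)  # each column compiles through the handlers added above
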
25 changes: 11 additions & 14 deletions ibis/backends/tests/test_temporal.py
@@ -72,7 +72,6 @@
        ),
    ],
)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(
["druid"],
raises=AttributeError,
@@ -108,7 +107,7 @@ def test_date_extract(backend, alltypes, df, attr, expr_fn):
"second",
],
)
@pytest.mark.notimpl(["datafusion", "druid"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(
["druid"],
raises=AttributeError,
@@ -251,7 +250,6 @@ def test_timestamp_extract(backend, alltypes, df, attr):
        ),
    ],
)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_timestamp_extract_literal(con, func, expected):
    value = ibis.timestamp("2015-09-01 14:48:05.359")
    assert con.execute(func(value).name("tmp")) == expected
@@ -283,7 +281,7 @@ def test_timestamp_extract_microseconds(backend, alltypes, df):
    backend.assert_series_equal(result, expected)


@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
["druid"],
raises=AttributeError,
@@ -1581,10 +1579,13 @@ def test_string_to_timestamp(alltypes, fmt):
param("2017-01-07", 5, "Saturday", id="saturday"),
],
)
@pytest.mark.notimpl(
["datafusion", "mssql", "druid", "oracle"], raises=com.OperationNotDefinedError
)
@pytest.mark.notimpl(["mssql", "druid", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["impala"], raises=com.UnsupportedBackendType)
@pytest.mark.broken(
["datafusion"],
raises=Exception,
reason="Exception: Arrow error: Cast error: Cannot cast string to value of Date64 type",
)
def test_day_of_week_scalar(con, date, expected_index, expected_day):
    expr = ibis.literal(date).cast(dt.date)
    result_index = con.execute(expr.day_of_week.index().name("tmp"))
@@ -1594,9 +1595,7 @@ def test_day_of_week_scalar(con, date, expected_index, expected_day):
    assert result_day.lower() == expected_day.lower()


@pytest.mark.notimpl(
["datafusion", "mssql", "oracle"], raises=com.OperationNotDefinedError
)
@pytest.mark.notimpl(["mssql", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
["druid"],
raises=AttributeError,
@@ -1632,7 +1631,7 @@ def test_day_of_week_column(backend, alltypes, df):
        ),
    ],
)
@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(
["druid"],
raises=AttributeError,
@@ -2125,9 +2124,7 @@ def test_date_column_from_iso(con, alltypes, df):
    tm.assert_series_equal(golden.rename("tmp"), actual.rename("tmp"))


@pytest.mark.notimpl(
["datafusion", "druid", "oracle"], raises=com.OperationNotDefinedError
)
@pytest.mark.notimpl(["druid", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(
["pyspark"],
raises=com.UnsupportedOperationError,
