Skip to content

Commit

Permalink
fix(core): interval resolution should upcast to smallest unit
Browse files Browse the repository at this point in the history
fixes #6139
  • Loading branch information
mesejo authored and kszucs committed May 13, 2023
1 parent 79cef8e commit f7f844d
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 11 deletions.
32 changes: 32 additions & 0 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,38 @@ def convert_to_offset(x):
),
],
),
param(
lambda t, _: t.timestamp_col
+ (ibis.interval(days=4) + ibis.interval(hours=2)),
lambda t, _: t.timestamp_col
+ (pd.Timedelta(days=4) + pd.Timedelta(hours=2)),
id='timestamp-add-interval-binop-different-units',
marks=[
pytest.mark.notimpl(
[
"clickhouse",
"sqlite",
"postgres",
"polars",
"mysql",
"impala",
"snowflake",
],
raises=com.OperationNotDefinedError,
),
pytest.mark.notimpl(
[
"bigquery",
],
raises=com.UnsupportedOperationError,
),
pytest.mark.notimpl(
["druid"],
raises=com.IbisTypeError,
reason="Given argument with datatype interval(<IntervalUnit.HOUR: 'h'>) is not implicitly castable to string",
),
],
),
param(
lambda t, _: t.timestamp_col - ibis.interval(days=17),
lambda t, _: t.timestamp_col - pd.Timedelta(days=17),
Expand Down
10 changes: 9 additions & 1 deletion ibis/expr/operations/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,17 @@ def output_dtype(self):
else arg
for arg in (self.left, self.right)
]

interval_unit_args = [
arg.output_dtype.unit
for arg in (self.left, self.right)
if arg.output_dtype.is_interval()
]

value_dtype = rlz._promote_integral_binop(integer_args, self.op)
unit = rlz._promote_interval_resolution(interval_unit_args)

return self.left.output_dtype.copy(value_type=value_dtype)
return self.left.output_dtype.copy(value_type=value_dtype, unit=unit)


@public
Expand Down
9 changes: 9 additions & 0 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import ibis.expr.types as ir
from ibis import util
from ibis.common.annotations import attribute, optional
from ibis.common.enums import IntervalUnit
from ibis.common.validators import (
bool_,
callable_with, # noqa: F401
Expand Down Expand Up @@ -431,6 +432,14 @@ def output_dtype(self):
return output_dtype


def _promote_interval_resolution(units: list[IntervalUnit]) -> IntervalUnit:
# Find the smallest unit present in units
for unit in reversed(IntervalUnit):
if unit in units:
return unit
raise AssertionError('unreachable')


# TODO(kszucs): it could be as simple as rlz.instance_of(ops.TableNode)
# we have a single test case testing the schema superset condition, not
# used anywhere else
Expand Down
40 changes: 30 additions & 10 deletions ibis/tests/expr/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,47 @@ def test_multiply(expr):


@pytest.mark.parametrize(
'expr',
('expr', 'expected_unit'),
[
api.interval(days=1) + api.interval(days=1),
api.interval(days=2) + api.interval(hours=4),
(
api.interval(days=1) + api.interval(days=1),
IntervalUnit('D'),
),
(
api.interval(days=2) + api.interval(hours=4),
IntervalUnit('h'),
),
(
api.interval(seconds=1) + ibis.interval(minutes=2),
IntervalUnit('s'),
),
],
)
def test_add(expr):
def test_add(expr, expected_unit):
assert isinstance(expr, ir.IntervalScalar)
assert expr.type().unit == IntervalUnit('D')
assert expr.type().unit == expected_unit


@pytest.mark.parametrize(
'expr',
('expr', 'expected_unit'),
[
api.interval(days=3) - api.interval(days=1),
api.interval(days=2) - api.interval(hours=4),
(
api.interval(days=3) - api.interval(days=1),
IntervalUnit('D'),
),
(
api.interval(days=2) - api.interval(hours=4),
IntervalUnit('h'),
),
(
api.interval(minutes=2) - api.interval(seconds=1),
IntervalUnit('s'),
),
],
)
def test_subtract(expr):
def test_subtract(expr, expected_unit):
assert isinstance(expr, ir.IntervalScalar)
assert expr.type().unit == IntervalUnit('D')
assert expr.type().unit == expected_unit


@pytest.mark.parametrize(
Expand Down

0 comments on commit f7f844d

Please sign in to comment.