Skip to content

Commit

Permalink
refactor(datatypes): normalize interval values to integers
Browse files Browse the repository at this point in the history
BREAKING CHANGE: `dt.Interval` no longer has a default unit, and `dt.interval` has been removed
  • Loading branch information
kszucs authored and cpcloud committed Jun 16, 2023
1 parent b3d9619 commit 80a40ab
Show file tree
Hide file tree
Showing 17 changed files with 211 additions and 121 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/duckdb/datatypes.py
Expand Up @@ -25,7 +25,7 @@
def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
"""Parse a DuckDB type into an ibis data type."""
primitive = (
spaceless_string("interval").result(dt.Interval())
spaceless_string("interval").result(dt.Interval('us'))
| spaceless_string("bigint", "int8", "long").result(dt.int64)
| spaceless_string("boolean", "bool", "logical").result(dt.boolean)
| spaceless_string("blob", "bytea", "binary", "varbinary").result(dt.binary)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/duckdb/tests/test_datatypes.py
Expand Up @@ -36,7 +36,7 @@
("INT4", dt.int32),
("INT", dt.int32),
("SIGNED", dt.int32),
("INTERVAL", dt.interval),
("INTERVAL", dt.Interval('us')),
("REAL", dt.float32),
("FLOAT4", dt.float32),
("FLOAT", dt.float32),
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/postgres/datatypes.py
Expand Up @@ -67,7 +67,7 @@ def _get_type(typestr: str) -> dt.DataType:
"geometry": dt.geometry,
"inet": dt.inet,
"integer": dt.int32,
"interval": dt.interval,
"interval": dt.Interval('s'),
"json": dt.json,
"jsonb": dt.json,
"line": dt.linestring,
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_client.py
Expand Up @@ -188,7 +188,7 @@ def test_create_and_drop_table(con, temp_table, params):
("time without time zone", dt.time),
("timestamp without time zone", dt.timestamp),
("timestamp with time zone", dt.Timestamp("UTC")),
("interval", dt.interval),
("interval", dt.Interval("s")),
("numeric", dt.decimal),
("numeric(3, 2)", dt.Decimal(3, 2)),
("uuid", dt.uuid),
Expand Down
70 changes: 11 additions & 59 deletions ibis/backends/tests/test_temporal.py
Expand Up @@ -9,7 +9,6 @@
import pandas.testing as tm
import pytest
import sqlalchemy as sa
import sqlglot
from pytest import param

import ibis
Expand Down Expand Up @@ -973,6 +972,11 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
raises=com.UnsupportedOperationError,
reason='BigQuery does not allow binary operation TIMESTAMP_ADD with INTERVAL offset D',
),
pytest.mark.broken(
["clickhouse"],
raises=AssertionError,
reason="DateTime column overflows, should use DateTime64",
),
],
),
param(
Expand Down Expand Up @@ -1007,28 +1011,10 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
),
],
),
param(
'1.5d',
plus,
marks=[
pytest.mark.broken(["mysql"], raises=AssertionError),
pytest.mark.broken(
['druid'],
raises=com.IbisTypeError,
reason="Given argument with datatype interval('s') is not implicitly castable to string",
),
pytest.mark.broken(
["bigquery"],
raises=GoogleBadRequest,
reason='400 Syntax error: Expected ")" but got integer literal "12" at [1:58]',
),
],
),
param(
'2h',
plus,
marks=[
pytest.mark.broken(["mysql"], raises=AssertionError),
pytest.mark.broken(
['druid'],
raises=com.IbisTypeError,
Expand All @@ -1045,7 +1031,6 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
'3m',
plus,
marks=[
pytest.mark.broken(["mysql"], raises=AssertionError),
pytest.mark.broken(
['druid'],
raises=com.IbisTypeError,
Expand All @@ -1062,7 +1047,6 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
'10s',
plus,
marks=[
pytest.mark.broken(["mysql"], raises=AssertionError),
pytest.mark.broken(
['druid'],
raises=com.IbisTypeError,
Expand All @@ -1089,6 +1073,11 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
raises=com.UnsupportedOperationError,
reason='BigQuery does not allow binary operation TIMESTAMP_SUB with INTERVAL offset D',
),
pytest.mark.broken(
["clickhouse"],
raises=AssertionError,
reason="DateTime column overflows, should use DateTime64",
),
],
),
param(
Expand Down Expand Up @@ -1123,28 +1112,10 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
),
],
),
param(
'1.5d',
minus,
marks=[
pytest.mark.broken(["mysql"], raises=AssertionError),
pytest.mark.broken(
["druid"],
raises=TypeError,
reason="unsupported operand type(s) for -: 'StringColumn' and 'Timedelta'",
),
pytest.mark.broken(
["bigquery"],
raises=GoogleBadRequest,
reason='400 Syntax error: Expected ")" but got integer literal "12" at [1:58]',
),
],
),
param(
'2h',
minus,
marks=[
pytest.mark.broken(["mysql"], raises=AssertionError),
pytest.mark.broken(
["druid"],
raises=TypeError,
Expand All @@ -1161,7 +1132,6 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
'3m',
minus,
marks=[
pytest.mark.broken(["mysql"], raises=AssertionError),
pytest.mark.broken(
["druid"],
raises=TypeError,
Expand All @@ -1178,7 +1148,6 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
'10s',
minus,
marks=[
pytest.mark.broken(["mysql"], raises=AssertionError),
pytest.mark.broken(
["druid"],
raises=TypeError,
Expand All @@ -1194,26 +1163,9 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
],
)
@pytest.mark.notimpl(
["datafusion", "impala", "sqlite", "mssql", "trino", "oracle"],
["datafusion", "sqlite", "mssql", "trino", "oracle"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
["impala"],
raises=ImpalaHiveServer2Error,
)
@pytest.mark.notimpl(
["clickhouse"],
raises=sqlglot.errors.ParseError,
reason="Invalid expression / Unexpected token.",
)
@pytest.mark.notimpl(
["snowflake"],
raises=sa.exc.ProgrammingError,
)
@pytest.mark.broken(
["polars"],
raises=AssertionError,
)
def test_temporal_binop_pandas_timedelta(
backend, con, alltypes, df, timedelta, temporal_fn
):
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/trino/datatypes.py
Expand Up @@ -93,7 +93,7 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
)

primitive = (
spaceless_string("interval").result(dt.Interval())
spaceless_string("interval").result(dt.Interval(unit='s'))
| spaceless_string("bigint").result(dt.int64)
| spaceless_string("boolean").result(dt.boolean)
| spaceless_string("varbinary").result(dt.binary)
Expand Down
73 changes: 72 additions & 1 deletion ibis/common/temporal.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import numbers
from abc import ABCMeta
Expand All @@ -9,7 +11,7 @@
from public import public

from ibis.common.dispatch import lazy_singledispatch
from ibis.common.validators import Coercible
from ibis.common.patterns import Coercible


class ABCEnumMeta(EnumMeta, ABCMeta):
Expand All @@ -19,6 +21,9 @@ class ABCEnumMeta(EnumMeta, ABCMeta):
class Unit(Coercible, Enum, metaclass=ABCEnumMeta):
@classmethod
def __coerce__(cls, value):
if isinstance(value, cls):
return value

# first look for aliases
value = cls.aliases().get(value, value)

Expand Down Expand Up @@ -114,6 +119,72 @@ class IntervalUnit(TemporalUnit):
NANOSECOND = "ns"


def normalize_timedelta(
    value: datetime.timedelta | numbers.Real, unit: IntervalUnit
) -> int:
    """Normalize a timedelta or a number to a whole count of `unit`.

    Parameters
    ----------
    value
        The value to normalize, either a timedelta or a number. A number is
        not rescaled; it is only checked for being integral.
    unit
        The unit to normalize to.

    Returns
    -------
    int
        The value expressed as an integer number of `unit`.

    Raises
    ------
    ValueError
        If `value` is a timedelta and `unit` is a month, quarter or year, if
        `unit` is not a known interval unit, or if the value cannot be
        expressed as a whole number of `unit`.

    Examples
    --------
    >>> from datetime import timedelta
    >>> normalize_timedelta(1, IntervalUnit.SECOND)
    1
    >>> normalize_timedelta(1, IntervalUnit.DAY)
    1
    >>> normalize_timedelta(timedelta(days=14), IntervalUnit.WEEK)
    2
    >>> normalize_timedelta(timedelta(seconds=3), IntervalUnit.MILLISECOND)
    3000
    >>> normalize_timedelta(timedelta(seconds=3), IntervalUnit.MICROSECOND)
    3000000
    """
    if isinstance(value, datetime.timedelta):
        if unit == IntervalUnit.MONTH:
            raise ValueError("Cannot normalize a timedelta to months")
        if unit == IntervalUnit.QUARTER:
            raise ValueError("Cannot normalize a timedelta to quarters")
        if unit == IntervalUnit.YEAR:
            raise ValueError("Cannot normalize a timedelta to years")

        # datetime.timedelta only stores days, seconds, and microseconds
        # internally, so its exact value is an integer number of microseconds.
        # Floor-dividing by a one-microsecond timedelta yields that integer
        # without the float rounding that total_seconds() introduces.
        total_microseconds = value // datetime.timedelta(microseconds=1)

        if unit == IntervalUnit.NANOSECOND:
            # Finer than the stored resolution, so this can never be lossy.
            return total_microseconds * 1_000

        microseconds_per_unit = (
            (IntervalUnit.MICROSECOND, 1),
            (IntervalUnit.MILLISECOND, 1_000),
            (IntervalUnit.SECOND, 1_000_000),
            (IntervalUnit.MINUTE, 60 * 1_000_000),
            (IntervalUnit.HOUR, 3_600 * 1_000_000),
            (IntervalUnit.DAY, 86_400 * 1_000_000),
            (IntervalUnit.WEEK, 604_800 * 1_000_000),
        )
        for candidate, factor in microseconds_per_unit:
            if unit == candidate:
                quotient, remainder = divmod(total_microseconds, factor)
                if remainder:
                    raise ValueError(
                        f"Normalizing {value} to {unit} would lose precision"
                    )
                return quotient

        raise ValueError(f"Unknown unit {unit}")

    if isinstance(value, numbers.Integral):
        # Skip the float round-trip: it would corrupt integers above 2**53.
        return int(value)

    value = float(value)
    if not value.is_integer():
        raise ValueError(f"Normalizing {value} to {unit} would lose precision")

    return int(value)


def normalize_timezone(tz):
if tz is None:
return None
Expand Down
64 changes: 56 additions & 8 deletions ibis/common/tests/test_temporal.py
Expand Up @@ -5,8 +5,13 @@
import pytest
import pytz

from ibis.common.temporal import IntervalUnit, normalize_datetime, normalize_timezone
from ibis.common.validators import coerced_to
from ibis.common.patterns import CoercedTo
from ibis.common.temporal import (
IntervalUnit,
normalize_datetime,
normalize_timedelta,
normalize_timezone,
)

interval_units = pytest.mark.parametrize(
["singular", "plural", "short"],
Expand Down Expand Up @@ -37,10 +42,10 @@ def test_interval_units(singular, plural, short):
@interval_units
def test_interval_unit_coercions(singular, plural, short):
u = IntervalUnit[singular.upper()]
v = coerced_to(IntervalUnit)
assert v(singular) == u
assert v(plural) == u
assert v(short) == u
v = CoercedTo(IntervalUnit)
assert v.match(singular, {}) == u
assert v.match(plural, {}) == u
assert v.match(short, {}) == u


@pytest.mark.parametrize(
Expand All @@ -56,8 +61,51 @@ def test_interval_unit_coercions(singular, plural, short):
],
)
def test_interval_unit_aliases(alias, expected):
v = coerced_to(IntervalUnit)
assert v(alias) == IntervalUnit(expected)
v = CoercedTo(IntervalUnit)
assert v.match(alias, {}) == IntervalUnit(expected)


@pytest.mark.parametrize(
    "value, unit, expected",
    [
        # A bare number comes back unchanged for every unit.
        *(
            (1, passthrough_unit, 1)
            for passthrough_unit in (
                IntervalUnit.DAY,
                IntervalUnit.HOUR,
                IntervalUnit.MINUTE,
                IntervalUnit.SECOND,
                IntervalUnit.MILLISECOND,
                IntervalUnit.MICROSECOND,
            )
        ),
        # Exactly one unit's worth of timedelta normalizes to 1.
        (timedelta(days=1), IntervalUnit.DAY, 1),
        (timedelta(hours=1), IntervalUnit.HOUR, 1),
        (timedelta(minutes=1), IntervalUnit.MINUTE, 1),
        (timedelta(seconds=1), IntervalUnit.SECOND, 1),
        (timedelta(milliseconds=1), IntervalUnit.MILLISECOND, 1),
        (timedelta(microseconds=1), IntervalUnit.MICROSECOND, 1),
        # Mixed components are summed and expressed in the finer unit.
        (
            timedelta(days=1, milliseconds=100),
            IntervalUnit.MILLISECOND,
            86_400_100,
        ),
        (
            timedelta(days=1, milliseconds=21),
            IntervalUnit.MICROSECOND,
            86_400_021_000,
        ),
    ],
)
def test_normalize_timedelta(value, unit, expected):
    """Values that are a whole number of `unit` normalize to that integer."""
    normalized = normalize_timedelta(value, unit)
    assert normalized == expected


@pytest.mark.parametrize(
    "value, unit",
    [
        # Year, quarter and month targets are rejected for timedeltas.
        (timedelta(days=1), IntervalUnit.YEAR),
        (timedelta(days=1), IntervalUnit.QUARTER),
        (timedelta(days=1), IntervalUnit.MONTH),
        # The value is only a fraction of the requested unit.
        (timedelta(days=1), IntervalUnit.WEEK),
        (timedelta(hours=1), IntervalUnit.DAY),
        (timedelta(minutes=1), IntervalUnit.HOUR),
        (timedelta(seconds=1), IntervalUnit.MINUTE),
        (timedelta(milliseconds=1), IntervalUnit.SECOND),
        (timedelta(microseconds=1), IntervalUnit.MILLISECOND),
        # A whole part plus a sub-unit remainder is still lossy.
        (timedelta(days=1, microseconds=100), IntervalUnit.MILLISECOND),
    ],
)
def test_normalize_timedelta_invalid(value, unit):
    """Normalizations that would drop information raise ``ValueError``."""
    with pytest.raises(ValueError):
        normalize_timedelta(value, unit)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 80a40ab

Please sign in to comment.