fixes naive datetime bug in incremental #1020

Merged · 7 commits · Mar 1, 2024

Changes from all commits
1 change: 0 additions & 1 deletion dlt/common/data_types/type_helpers.py
@@ -13,7 +13,6 @@
from dlt.common.data_types.typing import TDataType
from dlt.common.time import (
ensure_pendulum_datetime,
-parse_iso_like_datetime,
ensure_pendulum_date,
ensure_pendulum_time,
)
9 changes: 5 additions & 4 deletions dlt/common/libs/pyarrow.py
@@ -1,4 +1,5 @@
from datetime import datetime, date # noqa: I251
+from pendulum.tz import UTC
from typing import Any, Tuple, Optional, Union, Callable, Iterable, Iterator, Sequence, Tuple

from dlt import version
@@ -314,21 +315,21 @@ def is_arrow_item(item: Any) -> bool:
return isinstance(item, (pyarrow.Table, pyarrow.RecordBatch))


-def to_arrow_compute_input(value: Any, arrow_type: pyarrow.DataType) -> Any:
+def to_arrow_scalar(value: Any, arrow_type: pyarrow.DataType) -> Any:
"""Converts python value to an arrow compute friendly version"""
return pyarrow.scalar(value, type=arrow_type)


-def from_arrow_compute_output(arrow_value: pyarrow.Scalar) -> Any:
-"""Converts arrow scalar into Python type. Currently adds "UTC" to naive date times."""
+def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any:
+"""Converts an arrow scalar into a Python type. Currently adds "UTC" to naive datetimes and converts all others to UTC."""
row_value = arrow_value.as_py()
# dates are not represented as datetimes but I see connector-x represents
# datetimes as dates and keeping the exact time inside. probably a bug
# but can be corrected this way
if isinstance(row_value, date) and not isinstance(row_value, datetime):
row_value = pendulum.from_timestamp(arrow_value.cast(pyarrow.int64()).as_py() / 1000)
elif isinstance(row_value, datetime):
-row_value = pendulum.instance(row_value)
+row_value = pendulum.instance(row_value).in_tz("UTC")
return row_value


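A minimal sketch (not part of the diff, assuming `pyarrow` and `pendulum` are available) of what `from_arrow_scalar` now returns — naive timestamps still get "UTC" attached, and aware timestamps in other zones are additionally converted to UTC:

```python
from datetime import datetime, timedelta, timezone

import pyarrow as pa

from dlt.common.libs.pyarrow import from_arrow_scalar

# naive timestamp scalar: "UTC" is attached, as before
naive = pa.scalar(datetime(2024, 3, 1, 12, 0, 0), type=pa.timestamp("us"))
print(from_arrow_scalar(naive))  # 2024-03-01 12:00:00+00:00

# aware timestamp scalar in another zone: now converted to UTC as well
berlin = timezone(timedelta(hours=1))
aware = pa.scalar(
    datetime(2024, 3, 1, 12, 0, 0, tzinfo=berlin),
    type=pa.timestamp("us", tz="+01:00"),
)
print(from_arrow_scalar(aware))  # 2024-03-01 11:00:00+00:00
```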
9 changes: 6 additions & 3 deletions dlt/common/time.py
@@ -44,9 +44,12 @@ def timestamp_before(timestamp: float, max_inclusive: Optional[float]) -> bool:


def parse_iso_like_datetime(value: Any) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time]:
-# we use internal pendulum parse function. the generic function, for example, parses string "now" as now()
-# it also tries to parse ISO intervals but the code is very low quality
+"""Parses an ISO 8601 string into a pendulum datetime, date or time. Preserves timezone info.
+Note: naive datetimes will be generated from strings without timezone info.

+We use the internal pendulum parse function. The generic function, for example, parses the string "now" as now();
+it also tries to parse ISO intervals, but the code is very low quality.
+"""
# only iso dates are allowed
dtv = None
with contextlib.suppress(ValueError):
@@ -57,7 +60,7 @@ def parse_iso_like_datetime(value: Any) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time]:
if isinstance(dtv, datetime.time):
return pendulum.time(dtv.hour, dtv.minute, dtv.second, dtv.microsecond)
if isinstance(dtv, datetime.datetime):
-return pendulum.instance(dtv)
+return pendulum.instance(dtv, tz=dtv.tzinfo)
if isinstance(dtv, pendulum.Duration):
raise ValueError("Interval ISO 8601 not supported: " + value)
return pendulum.date(dtv.year, dtv.month, dtv.day) # type: ignore[union-attr]
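The `tz=dtv.tzinfo` change means the parser no longer normalizes results to a default timezone. A quick sketch (not in the diff):

```python
from dlt.common.time import parse_iso_like_datetime

# naive ISO string -> naive pendulum datetime (no timezone is invented)
assert parse_iso_like_datetime("2021-01-01T05:02:32").tzinfo is None

# an explicit offset is kept as-is
dt = parse_iso_like_datetime("2021-01-01T05:02:32+05:00")
assert dt.utcoffset().total_seconds() == 5 * 3600
```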
10 changes: 7 additions & 3 deletions dlt/common/utils.py
@@ -272,8 +272,10 @@ def update_dict_with_prune(dest: DictStrAny, update: StrAny) -> None:
del dest[k]


-def update_dict_nested(dst: TDict, src: StrAny) -> TDict:
-"""Merges `src` into `dst` key wise. Does not recur into lists. Values in `src` overwrite `dst` if both keys exit."""
+def update_dict_nested(dst: TDict, src: StrAny, keep_dst_values: bool = False) -> TDict:
+"""Merges `src` into `dst` key wise. Does not recur into lists. Values in `src` overwrite `dst` if both keys exist.
+Optionally (`keep_dst_values`) you can keep the `dst` value on conflict.
+"""
# based on https://github.com/clarketm/mergedeep/blob/master/mergedeep/mergedeep.py

def _is_recursive_merge(a: StrAny, b: StrAny) -> bool:
@@ -290,7 +292,9 @@ def _is_recursive_merge(a: StrAny, b: StrAny) -> bool:
# If a key exists in both objects and the values are `same`, the value from the `dst` object will be used.
pass
else:
-dst[key] = src[key]
+if not keep_dst_values:
+# overwrite the `dst` value unless instructed to keep it
+dst[key] = src[key]
else:
# If the key exists only in `src`, the value from the `src` object will be used.
dst[key] = src[key]
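A minimal illustration of the new flag, using invented column-hint dicts in the spirit of the extractor change below:

```python
from copy import deepcopy

from dlt.common.utils import update_dict_nested

dst = {"updated_at": {"data_type": "timestamp", "nullable": True}}
src = {"updated_at": {"data_type": "timestamp", "nullable": False}}

# default behavior: `src` overwrites `dst` on conflict
assert update_dict_nested(deepcopy(dst), src)["updated_at"]["nullable"] is False

# keep_dst_values=True: `dst` wins on conflict, `src` only fills missing keys
assert update_dict_nested(deepcopy(dst), src, keep_dst_values=True)["updated_at"]["nullable"] is True
```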
16 changes: 15 additions & 1 deletion dlt/extract/extractors.py
@@ -1,6 +1,7 @@
from copy import copy
from typing import Set, Dict, Any, Optional, Set

+from dlt.common import logger
from dlt.common.configuration.inject import with_config
from dlt.common.configuration.specs import BaseConfiguration, configspec
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
@@ -289,9 +290,22 @@ def _compute_table(self, resource: DltResource, items: TDataItems) -> TPartialTa
arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(items.schema)
# normalize arrow table before merging
arrow_table = self.schema.normalize_table_identifiers(arrow_table)
+# issue warnings when overriding computed hints with arrow hints
+for col_name, column in arrow_table["columns"].items():
+if src_column := computed_table["columns"].get(col_name):
+for hint_name, hint in column.items():
+if (src_hint := src_column.get(hint_name)) is not None:
+if src_hint != hint:
+logger.warning(
+f"In resource {resource.name}, while merging arrow schema on column"
+f" {col_name}: the hint {hint_name} value {src_hint} defined in the"
+f" resource is overwritten from arrow with value {hint}."
+)

# we must override the columns to preserve the order in arrow table
arrow_table["columns"] = update_dict_nested(
-arrow_table["columns"], computed_table["columns"]
+arrow_table["columns"], computed_table["columns"], keep_dst_values=True
)

return arrow_table
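Taken together with `update_dict_nested(..., keep_dst_values=True)`, the merge keeps the arrow table's column order and hint values, while hints computed from the resource only fill the gaps (and trigger the warning above when they disagree). A conceptual sketch with invented hints:

```python
from dlt.common.utils import update_dict_nested

arrow_columns = {"id": {"data_type": "bigint"}, "updated_at": {"data_type": "timestamp"}}
resource_columns = {"updated_at": {"data_type": "date", "primary_key": True}}

merged = update_dict_nested(arrow_columns, resource_columns, keep_dst_values=True)
# the conflicting hint stays as arrow inferred it (and is logged as overwritten)
assert merged["updated_at"]["data_type"] == "timestamp"
# hints defined only on the resource are preserved
assert merged["updated_at"]["primary_key"] is True
# column order follows the arrow table
assert list(merged) == ["id", "updated_at"]
```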
13 changes: 13 additions & 0 deletions dlt/extract/incremental/__init__.py
@@ -1,4 +1,5 @@
import os
+from datetime import datetime # noqa: I251
from typing import Generic, ClassVar, Any, Optional, Type, Dict
from typing_extensions import get_origin, get_args
import inspect
@@ -213,6 +214,8 @@ def on_resolved(self) -> None:
"Incremental 'end_value' was specified without 'initial_value'. 'initial_value' is"
" required when using 'end_value'."
)
+self._cursor_datetime_check(self.initial_value, "initial_value")
+self._cursor_datetime_check(self.end_value, "end_value")
# Ensure end value is "higher" than initial value
if (
self.end_value is not None
@@ -281,6 +284,16 @@ def _get_state(resource_name: str, cursor_path: str) -> IncrementalColumnState:
# if state params is empty
return state

+@staticmethod
+def _cursor_datetime_check(value: Any, arg_name: str) -> None:
+if value and isinstance(value, datetime) and value.tzinfo is None:
+logger.warning(
+f"The {arg_name} argument {value} is a datetime without timezone. This may result"
+" in an error when such values are compared by the Incremental class. Note that"
+" `dlt` stores datetimes in timezone-aware types, so the UTC timezone will be"
+" added by the destination"
+)

@property
def last_value(self) -> Optional[TCursorValue]:
s = self.get_state()
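For example, a configuration like this (resource and field names invented) now logs the warning when the incremental resolves, pointing at the naive `initial_value`:

```python
from datetime import datetime

import dlt

@dlt.resource
def events(
    updated_at=dlt.sources.incremental(
        "updated_at",
        initial_value=datetime(2024, 1, 1),  # naive -> warning; UTC is added at the destination
    )
):
    yield [{"updated_at": datetime(2024, 1, 2, 12, 0, 0)}]
```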
25 changes: 15 additions & 10 deletions dlt/extract/incremental/transform.py
@@ -29,7 +29,7 @@
try:
from dlt.common.libs import pyarrow
from dlt.common.libs.pyarrow import pyarrow as pa, TAnyArrowItem
-from dlt.common.libs.pyarrow import from_arrow_compute_output, to_arrow_compute_input
+from dlt.common.libs.pyarrow import from_arrow_scalar, to_arrow_scalar
except MissingDependencyException:
pa = None
pyarrow = None
@@ -120,13 +120,17 @@ def __call__(
return row, start_out_of_range, end_out_of_range

row_value = self.find_cursor_value(row)
-last_value = self.incremental_state["last_value"]

# For datetime cursor, ensure the value is a timezone aware datetime.
# The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable
-if isinstance(row_value, datetime):
-row_value = pendulum.instance(row_value)

+last_value = self.incremental_state["last_value"]
+if (
+isinstance(row_value, datetime)
+and row_value.tzinfo is None
+and isinstance(last_value, datetime)
+and last_value.tzinfo is not None
+):
+row_value = pendulum.instance(row_value).in_tz("UTC")

# Check whether end_value has been reached
# Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value
@@ -250,9 +254,10 @@ def __call__(
cursor_path = self.cursor_path
# The new max/min value
try:
-row_value = from_arrow_compute_output(compute(tbl[cursor_path]))
+# NOTE: datetimes are always pendulum in UTC
+row_value = from_arrow_scalar(compute(tbl[cursor_path]))
cursor_data_type = tbl.schema.field(cursor_path).type
-row_value_scalar = to_arrow_compute_input(row_value, cursor_data_type)
+row_value_scalar = to_arrow_scalar(row_value, cursor_data_type)
except KeyError as e:
raise IncrementalCursorPathMissing(
self.resource_name,
@@ -265,7 +270,7 @@

# If end_value is provided, filter to include table rows that are "less" than end_value
if self.end_value is not None:
-end_value_scalar = to_arrow_compute_input(self.end_value, cursor_data_type)
+end_value_scalar = to_arrow_scalar(self.end_value, cursor_data_type)
tbl = tbl.filter(end_compare(tbl[cursor_path], end_value_scalar))
# Is max row value higher than end value?
# NOTE: pyarrow bool *always* evaluates to python True. `as_py()` is necessary
@@ -275,13 +280,13 @@
if self.start_value is not None:
# Remove rows lower than the last start value
keep_filter = last_value_compare(
-tbl[cursor_path], to_arrow_compute_input(self.start_value, cursor_data_type)
+tbl[cursor_path], to_arrow_scalar(self.start_value, cursor_data_type)
)
start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py())
tbl = tbl.filter(keep_filter)

# Deduplicate after filtering old values
-last_value_scalar = to_arrow_compute_input(last_value, cursor_data_type)
+last_value_scalar = to_arrow_scalar(last_value, cursor_data_type)
tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path)
# Remove already processed rows where the cursor is equal to the last value
eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], last_value_scalar))
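The coercion rule from the first hunk above, extracted into a standalone sketch (function name invented) to show exactly when the UTC conversion kicks in:

```python
from datetime import datetime
from typing import Any

import pendulum

def coerce_row_value(row_value: Any, last_value: Any) -> Any:
    # only coerce when the incoming value is naive but the stored state is aware
    if (
        isinstance(row_value, datetime)
        and row_value.tzinfo is None
        and isinstance(last_value, datetime)
        and last_value.tzinfo is not None
    ):
        row_value = pendulum.instance(row_value).in_tz("UTC")
    return row_value

# naive row vs. aware state: converted so the comparison is valid
assert coerce_row_value(datetime(2024, 3, 1), pendulum.datetime(2024, 2, 1)).tzinfo is not None
# naive row vs. naive state: left untouched
assert coerce_row_value(datetime(2024, 3, 1), datetime(2024, 2, 1)).tzinfo is None
```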
26 changes: 16 additions & 10 deletions tests/common/test_time.py
@@ -4,6 +4,7 @@

from dlt.common import pendulum
from dlt.common.time import (
+parse_iso_like_datetime,
timestamp_before,
timestamp_within,
ensure_pendulum_datetime,
@@ -40,27 +41,27 @@ def test_before() -> None:
# python datetime without tz
(
datetime(2021, 1, 1, 0, 0, 0),
-pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
+pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
),
# python datetime with tz
(
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-8))),
-pendulum.datetime(2021, 1, 1, 8, 0, 0).in_tz("UTC"),
+pendulum.DateTime(2021, 1, 1, 8, 0, 0).in_tz("UTC"),
),
# python date object
-(date(2021, 1, 1), pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
+(date(2021, 1, 1), pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
# pendulum datetime with tz
(
-pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
-pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
+pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
+pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
),
# pendulum datetime without tz
(
-pendulum.datetime(2021, 1, 1, 0, 0, 0),
-pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
+pendulum.DateTime(2021, 1, 1, 0, 0, 0),
+pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
),
# iso datetime in UTC
-("2021-01-01T00:00:00+00:00", pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
+("2021-01-01T00:00:00+00:00", pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
# iso datetime with non utc tz
(
"2021-01-01T00:00:00+05:00",
@@ -69,13 +70,18 @@
# iso datetime without tz
(
"2021-01-01T05:02:32",
-pendulum.datetime(2021, 1, 1, 5, 2, 32).in_tz("UTC"),
+pendulum.DateTime(2021, 1, 1, 5, 2, 32).in_tz("UTC"),
),
# iso date
-("2021-01-01", pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
+("2021-01-01", pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
]


+def test_parse_iso_like_datetime() -> None:
+# naive datetime is still naive
+assert parse_iso_like_datetime("2021-01-01T05:02:32") == pendulum.DateTime(2021, 1, 1, 5, 2, 32)


@pytest.mark.parametrize("date_value, expected", test_params)
def test_ensure_pendulum_datetime(date_value: TAnyDateTime, expected: pendulum.DateTime) -> None:
dt = ensure_pendulum_datetime(date_value)
12 changes: 12 additions & 0 deletions tests/common/test_utils.py
@@ -21,6 +21,7 @@
extend_list_deduplicated,
get_exception_trace,
get_exception_trace_chain,
+update_dict_nested,
)


@@ -277,3 +278,14 @@ def test_exception_trace_chain() -> None:
assert traces[0]["exception_type"] == "dlt.common.exceptions.PipelineException"
assert traces[1]["exception_type"] == "dlt.common.exceptions.IdentifierTooLongException"
assert traces[2]["exception_type"] == "dlt.common.exceptions.TerminalValueError"


+def test_nested_dict_merge() -> None:
+dict_1 = {"a": 1, "b": 2}
+dict_2 = {"a": 2, "c": 4}
+
+assert update_dict_nested(dict(dict_1), dict_2) == {"a": 2, "b": 2, "c": 4}
+assert update_dict_nested(dict(dict_2), dict_1) == {"a": 1, "b": 2, "c": 4}
+assert update_dict_nested(dict(dict_1), dict_2, keep_dst_values=True) == update_dict_nested(
+dict_2, dict_1
+)
45 changes: 45 additions & 0 deletions tests/extract/test_incremental.py
@@ -18,6 +18,7 @@
from dlt.common.json import json

from dlt.extract import DltSource
+from dlt.extract.exceptions import InvalidStepFunctionArguments
from dlt.sources.helpers.transform import take_first
from dlt.extract.incremental.exceptions import (
IncrementalCursorPathMissing,
@@ -977,11 +978,17 @@ def some_data(
"updated_at", initial_value=pendulum_start_dt
),
max_hours: int = 2,
+tz: str = None,
):
data = [
{"updated_at": start_dt + timedelta(hours=hour), "hour": hour}
for hour in range(1, max_hours + 1)
]
+# make sure this is a naive datetime
+assert data[0]["updated_at"].tzinfo is None  # type: ignore[attr-defined]
+if tz:
+data = [{**d, "updated_at": pendulum.instance(d["updated_at"])} for d in data]  # type: ignore[call-overload]

yield data_to_item_format(item_type, data)

pipeline = dlt.pipeline(pipeline_name=uniq_id())
@@ -1024,6 +1031,44 @@ def some_data(
== 2
)

+# initial value is naive
+resource = some_data(max_hours=4).with_name("copy_1")  # also make a new resource state
+resource.apply_hints(incremental=dlt.sources.incremental("updated_at", initial_value=start_dt))
+# and the data is naive, so it will work as expected with naive datetimes in the result set
+data = list(resource)
+if item_type == "json":
+# we do not convert data in arrow tables
+assert data[0]["updated_at"].tzinfo is None

+# end value is naive
+resource = some_data(max_hours=4).with_name("copy_2")  # also make a new resource state
+resource.apply_hints(
+incremental=dlt.sources.incremental(
+"updated_at", initial_value=start_dt, end_value=start_dt + timedelta(hours=3)
+)
+)
+data = list(resource)
+if item_type == "json":
+assert data[0]["updated_at"].tzinfo is None

+# now use a naive initial value but data in UTC
+resource = some_data(max_hours=4, tz="UTC").with_name("copy_3")  # also make a new resource state
+resource.apply_hints(
+incremental=dlt.sources.incremental(
+"updated_at", initial_value=start_dt + timedelta(hours=3)
+)
+)
+# will cause an invalid comparison
+if item_type == "json":
+with pytest.raises(InvalidStepFunctionArguments):
+list(resource)
+else:
+data = data_item_to_list(item_type, list(resource))
+# we select two rows by adding 3 hours to start_dt. rows have hours:
+# 1, 2, 3, 4
+# and we select >= 3
+assert len(data) == 2


@dlt.resource
def endless_sequence(