diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py
index 6ce961d72c..659b4951df 100644
--- a/dlt/common/data_types/type_helpers.py
+++ b/dlt/common/data_types/type_helpers.py
@@ -13,7 +13,6 @@ from dlt.common.data_types.typing import TDataType
 from dlt.common.time import (
     ensure_pendulum_datetime,
-    parse_iso_like_datetime,
     ensure_pendulum_date,
     ensure_pendulum_time,
 )
diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py
index 31423665f7..7b9977a2cf 100644
--- a/dlt/common/libs/pyarrow.py
+++ b/dlt/common/libs/pyarrow.py
@@ -1,4 +1,5 @@
 from datetime import datetime, date  # noqa: I251
+from pendulum.tz import UTC
 from typing import Any, Tuple, Optional, Union, Callable, Iterable, Iterator, Sequence, Tuple

 from dlt import version
@@ -314,13 +315,13 @@ def is_arrow_item(item: Any) -> bool:
     return isinstance(item, (pyarrow.Table, pyarrow.RecordBatch))


-def to_arrow_compute_input(value: Any, arrow_type: pyarrow.DataType) -> Any:
+def to_arrow_scalar(value: Any, arrow_type: pyarrow.DataType) -> Any:
     """Converts python value to an arrow compute friendly version"""
     return pyarrow.scalar(value, type=arrow_type)


-def from_arrow_compute_output(arrow_value: pyarrow.Scalar) -> Any:
-    """Converts arrow scalar into Python type. Currently adds "UTC" to naive date times."""
+def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any:
+    """Converts an arrow scalar into a Python type. Currently adds "UTC" to naive datetimes and converts all others to UTC."""
     row_value = arrow_value.as_py()
     # dates are not represented as datetimes but I see connector-x represents
     # datetimes as dates and keeping the exact time inside. probably a bug
@@ -328,7 +329,7 @@ def from_arrow_compute_output(arrow_value: pyarrow.Scalar) -> Any:
     if isinstance(row_value, date) and not isinstance(row_value, datetime):
         row_value = pendulum.from_timestamp(arrow_value.cast(pyarrow.int64()).as_py() / 1000)
     elif isinstance(row_value, datetime):
-        row_value = pendulum.instance(row_value)
+        row_value = pendulum.instance(row_value).in_tz("UTC")
     return row_value
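Not part of the diff: a minimal sketch of how the renamed scalar helpers behave after this change, mirroring the assertions in `tests/libs/test_pyarrow.py` further down (the literal values are illustrative):

```python
from datetime import datetime, timezone, timedelta  # noqa: I251

import pyarrow as pa

from dlt.common.libs.pyarrow import from_arrow_scalar, to_arrow_scalar

# a naive arrow timestamp comes back as a tz-aware pendulum datetime in UTC
sc = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), pa.timestamp("us"))
print(from_arrow_scalar(sc))  # 2021-01-01T05:02:32+00:00

# tz-aware values are converted to UTC, regardless of the arrow column timezone
sc = to_arrow_scalar(
    datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))),
    pa.timestamp("us", tz="Europe/Berlin"),
)
print(from_arrow_scalar(sc))  # 2021-01-01T13:02:32+00:00
```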
diff --git a/dlt/common/time.py b/dlt/common/time.py
index c06e2e2581..d3c8f9746c 100644
--- a/dlt/common/time.py
+++ b/dlt/common/time.py
@@ -44,9 +44,12 @@ def timestamp_before(timestamp: float, max_inclusive: Optional[float]) -> bool:


 def parse_iso_like_datetime(value: Any) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time]:
-    # we use internal pendulum parse function. the generic function, for example, parses string "now" as now()
-    # it also tries to parse ISO intervals but the code is very low quality
+    """Parses an ISO 8601 string into a pendulum datetime, date or time. Preserves timezone info.
+    Note: naive datetimes will be generated from strings without timezone info.
+    We use the internal pendulum parse function. The generic function, for example, parses the string "now" as now();
+    it also tries to parse ISO intervals, but that code is very low quality.
+    """
     # only iso dates are allowed
     dtv = None
     with contextlib.suppress(ValueError):
@@ -57,7 +60,7 @@ def parse_iso_like_datetime(value: Any) -> Union[pendulum.DateTime, pendulum.Dat
     if isinstance(dtv, datetime.time):
         return pendulum.time(dtv.hour, dtv.minute, dtv.second, dtv.microsecond)
     if isinstance(dtv, datetime.datetime):
-        return pendulum.instance(dtv)
+        return pendulum.instance(dtv, tz=dtv.tzinfo)
     if isinstance(dtv, pendulum.Duration):
         raise ValueError("Interval ISO 8601 not supported: " + value)
     return pendulum.date(dtv.year, dtv.month, dtv.day)  # type: ignore[union-attr]
diff --git a/dlt/common/utils.py b/dlt/common/utils.py
index 49a425780b..9beb4e48bf 100644
--- a/dlt/common/utils.py
+++ b/dlt/common/utils.py
@@ -272,8 +272,10 @@ def update_dict_with_prune(dest: DictStrAny, update: StrAny) -> None:
             del dest[k]


-def update_dict_nested(dst: TDict, src: StrAny) -> TDict:
-    """Merges `src` into `dst` key wise. Does not recur into lists. Values in `src` overwrite `dst` if both keys exit."""
+def update_dict_nested(dst: TDict, src: StrAny, keep_dst_values: bool = False) -> TDict:
+    """Merges `src` into `dst` key-wise. Does not recurse into lists. Values in `src` overwrite `dst` if both keys exist.
+    Optionally (`keep_dst_values`) the `dst` value is kept on conflict.
+    """
     # based on https://github.com/clarketm/mergedeep/blob/master/mergedeep/mergedeep.py

     def _is_recursive_merge(a: StrAny, b: StrAny) -> bool:
@@ -290,7 +292,9 @@ def _is_recursive_merge(a: StrAny, b: StrAny) -> bool:
             # If a key exists in both objects and the values are `same`, the value from the `dst` object will be used.
             pass
         else:
-            dst[key] = src[key]
+            if not keep_dst_values:
+                # overwrite the `dst` value unless explicitly kept
+                dst[key] = src[key]
     else:
         # If the key exists only in `src`, the value from the `src` object will be used.
         dst[key] = src[key]
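For reference, the new `keep_dst_values` flag in action; a small sketch that mirrors the `test_nested_dict_merge` test added below:

```python
from dlt.common.utils import update_dict_nested

dict_1 = {"a": 1, "b": 2}
dict_2 = {"a": 2, "c": 4}

# by default the value from `src` (the second argument) wins on a conflicting key
assert update_dict_nested(dict(dict_1), dict_2) == {"a": 2, "b": 2, "c": 4}
# with keep_dst_values=True the existing `dst` value survives the conflict
assert update_dict_nested(dict(dict_1), dict_2, keep_dst_values=True) == {"a": 1, "b": 2, "c": 4}
```

This is what lets the extractor below keep the hints coming from the arrow schema (`dst`) while still filling in hints that only the resource (`src`) defines.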
diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py
index f6c3fde5d4..51a38fa714 100644
--- a/dlt/extract/extractors.py
+++ b/dlt/extract/extractors.py
@@ -1,6 +1,7 @@
 from copy import copy
 from typing import Set, Dict, Any, Optional, Set

+from dlt.common import logger
 from dlt.common.configuration.inject import with_config
 from dlt.common.configuration.specs import BaseConfiguration, configspec
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
@@ -289,9 +290,21 @@ def _compute_table(self, resource: DltResource, items: TDataItems) -> TPartialTa
                 arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(items.schema)
             # normalize arrow table before merging
             arrow_table = self.schema.normalize_table_identifiers(arrow_table)
+            # issue warnings when arrow hints override the computed (resource-defined) ones
+            for col_name, column in arrow_table["columns"].items():
+                if src_column := computed_table["columns"].get(col_name):
+                    for hint_name, hint in column.items():
+                        if (src_hint := src_column.get(hint_name)) is not None:
+                            if src_hint != hint:
+                                logger.warning(
+                                    f"In resource: {resource.name}, when merging arrow schema on"
+                                    f" column {col_name}: the {hint_name} hint value {src_hint}"
+                                    f" defined in the resource is overwritten by arrow value {hint}."
+                                )
+
             # we must override the columns to preserve the order in arrow table
             arrow_table["columns"] = update_dict_nested(
-                arrow_table["columns"], computed_table["columns"]
+                arrow_table["columns"], computed_table["columns"], keep_dst_values=True
             )
             return arrow_table
diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py
index 955aa12efd..699e9389ad 100644
--- a/dlt/extract/incremental/__init__.py
+++ b/dlt/extract/incremental/__init__.py
@@ -1,4 +1,5 @@
 import os
+from datetime import datetime  # noqa: I251
 from typing import Generic, ClassVar, Any, Optional, Type, Dict
 from typing_extensions import get_origin, get_args
 import inspect
@@ -213,6 +214,8 @@ def on_resolved(self) -> None:
                 "Incremental 'end_value' was specified without 'initial_value'. 'initial_value' is"
                 " required when using 'end_value'."
             )
+        self._cursor_datetime_check(self.initial_value, "initial_value")
+        self._cursor_datetime_check(self.end_value, "end_value")
         # Ensure end value is "higher" than initial value
         if (
             self.end_value is not None
@@ -281,6 +284,16 @@ def _get_state(resource_name: str, cursor_path: str) -> IncrementalColumnState:
         # if state params is empty
         return state

+    @staticmethod
+    def _cursor_datetime_check(value: Any, arg_name: str) -> None:
+        if value and isinstance(value, datetime) and value.tzinfo is None:
+            logger.warning(
+                f"The {arg_name} argument {value} is a datetime without timezone. This may result"
+                " in an error when such values are compared by the Incremental class. Note that"
+                " `dlt` stores datetimes in timezone-aware types, so the UTC timezone will be"
+                " added by the destination."
+            )
+
     @property
     def last_value(self) -> Optional[TCursorValue]:
         s = self.get_state()
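A hypothetical resource showing when the new check fires: passing a naive `initial_value` now logs a warning at resolution time instead of failing much later during comparison (sketch only, not part of the diff):

```python
from datetime import datetime  # noqa: I251

import dlt


@dlt.resource
def events(
    updated_at=dlt.sources.incremental(
        "updated_at",
        # naive datetime: _cursor_datetime_check() logs a warning on resolution,
        # since values stored in the state may be tz-aware and not comparable
        initial_value=datetime(2024, 1, 1),
    )
):
    yield [{"updated_at": datetime(2024, 1, 2, 10, 30)}]
```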
diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py
index 2fc78fe4ee..4614de908e 100644
--- a/dlt/extract/incremental/transform.py
+++ b/dlt/extract/incremental/transform.py
@@ -29,7 +29,7 @@
 try:
     from dlt.common.libs import pyarrow
     from dlt.common.libs.pyarrow import pyarrow as pa, TAnyArrowItem
-    from dlt.common.libs.pyarrow import from_arrow_compute_output, to_arrow_compute_input
+    from dlt.common.libs.pyarrow import from_arrow_scalar, to_arrow_scalar
 except MissingDependencyException:
     pa = None
     pyarrow = None
@@ -120,13 +120,17 @@ def __call__(
             return row, start_out_of_range, end_out_of_range

         row_value = self.find_cursor_value(row)
+        last_value = self.incremental_state["last_value"]
+
         # For datetime cursor, ensure the value is a timezone aware datetime.
         # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable
-        if isinstance(row_value, datetime):
-            row_value = pendulum.instance(row_value)
-
-        last_value = self.incremental_state["last_value"]
+        if (
+            isinstance(row_value, datetime)
+            and row_value.tzinfo is None
+            and isinstance(last_value, datetime)
+            and last_value.tzinfo is not None
+        ):
+            row_value = pendulum.instance(row_value).in_tz("UTC")

         # Check whether end_value has been reached
         # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value
@@ -250,9 +254,10 @@ def __call__(
         cursor_path = self.cursor_path
         # The new max/min value
         try:
-            row_value = from_arrow_compute_output(compute(tbl[cursor_path]))
+            # NOTE: datetimes are always pendulum in UTC
+            row_value = from_arrow_scalar(compute(tbl[cursor_path]))
             cursor_data_type = tbl.schema.field(cursor_path).type
-            row_value_scalar = to_arrow_compute_input(row_value, cursor_data_type)
+            row_value_scalar = to_arrow_scalar(row_value, cursor_data_type)
         except KeyError as e:
             raise IncrementalCursorPathMissing(
                 self.resource_name,
@@ -265,7 +270,7 @@ def __call__(
         # If end_value is provided, filter to include table rows that are "less" than end_value
         if self.end_value is not None:
-            end_value_scalar = to_arrow_compute_input(self.end_value, cursor_data_type)
+            end_value_scalar = to_arrow_scalar(self.end_value, cursor_data_type)
             tbl = tbl.filter(end_compare(tbl[cursor_path], end_value_scalar))
             # Is max row value higher than end value?
             # NOTE: pyarrow bool *always* evaluates to python True. `as_py()` is necessary
@@ -275,13 +280,13 @@ def __call__(
         if self.start_value is not None:
             # Remove rows lower than the last start value
             keep_filter = last_value_compare(
-                tbl[cursor_path], to_arrow_compute_input(self.start_value, cursor_data_type)
+                tbl[cursor_path], to_arrow_scalar(self.start_value, cursor_data_type)
             )
             start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py())
             tbl = tbl.filter(keep_filter)

         # Deduplicate after filtering old values
-        last_value_scalar = to_arrow_compute_input(last_value, cursor_data_type)
+        last_value_scalar = to_arrow_scalar(last_value, cursor_data_type)
         tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path)
         # Remove already processed rows where the cursor is equal to the last value
         eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], last_value_scalar))
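The guard above replaces the unconditional `pendulum.instance()` coercion. In isolation the rule looks like this (a sketch; `coerce_row_value` is a made-up name for illustration):

```python
from datetime import datetime  # noqa: I251

import pendulum


def coerce_row_value(row_value, last_value):
    # coerce only when the row value is naive but the stored last_value is tz-aware
    if (
        isinstance(row_value, datetime)
        and row_value.tzinfo is None
        and isinstance(last_value, datetime)
        and last_value.tzinfo is not None
    ):
        row_value = pendulum.instance(row_value).in_tz("UTC")
    return row_value


# naive row vs tz-aware state: the row value is assumed to be UTC
print(coerce_row_value(datetime(2024, 1, 1, 12, 0), pendulum.datetime(2024, 1, 1)))
# naive vs naive: left untouched, the comparison stays naive on both sides
print(coerce_row_value(datetime(2024, 1, 1, 12, 0), datetime(2024, 1, 1)))
```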
diff --git a/tests/common/test_time.py b/tests/common/test_time.py
index 72a9098e4d..7568e84046 100644
--- a/tests/common/test_time.py
+++ b/tests/common/test_time.py
@@ -4,6 +4,7 @@
 from dlt.common import pendulum
 from dlt.common.time import (
+    parse_iso_like_datetime,
     timestamp_before,
     timestamp_within,
     ensure_pendulum_datetime,
@@ -40,27 +41,27 @@ def test_before() -> None:
     # python datetime without tz
     (
         datetime(2021, 1, 1, 0, 0, 0),
-        pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
+        pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
     ),
     # python datetime with tz
     (
         datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-8))),
-        pendulum.datetime(2021, 1, 1, 8, 0, 0).in_tz("UTC"),
+        pendulum.DateTime(2021, 1, 1, 8, 0, 0).in_tz("UTC"),
     ),
     # python date object
-    (date(2021, 1, 1), pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
+    (date(2021, 1, 1), pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
     # pendulum datetime with tz
     (
-        pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
-        pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
+        pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
+        pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
     ),
     # pendulum datetime without tz
     (
-        pendulum.datetime(2021, 1, 1, 0, 0, 0),
-        pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
+        pendulum.DateTime(2021, 1, 1, 0, 0, 0),
+        pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"),
     ),
     # iso datetime in UTC
-    ("2021-01-01T00:00:00+00:00", pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
+    ("2021-01-01T00:00:00+00:00", pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
     # iso datetime with non utc tz
     (
         "2021-01-01T00:00:00+05:00",
@@ -69,13 +70,18 @@ def test_before() -> None:
     # iso datetime without tz
     (
         "2021-01-01T05:02:32",
-        pendulum.datetime(2021, 1, 1, 5, 2, 32).in_tz("UTC"),
+        pendulum.DateTime(2021, 1, 1, 5, 2, 32).in_tz("UTC"),
     ),
     # iso date
-    ("2021-01-01", pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
+    ("2021-01-01", pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")),
 ]


+def test_parse_iso_like_datetime() -> None:
+    # naive datetime is still naive
+    assert parse_iso_like_datetime("2021-01-01T05:02:32") == pendulum.DateTime(2021, 1, 1, 5, 2, 32)
+
+
 @pytest.mark.parametrize("date_value, expected", test_params)
 def test_ensure_pendulum_datetime(date_value: TAnyDateTime, expected: pendulum.DateTime) -> None:
     dt = ensure_pendulum_datetime(date_value)
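The new test pins down the contrast between the two helpers; a quick sketch of the difference (not part of the diff):

```python
from dlt.common import pendulum
from dlt.common.time import ensure_pendulum_datetime, parse_iso_like_datetime

# parse_iso_like_datetime now preserves naivety: no timezone in, no timezone out
assert parse_iso_like_datetime("2021-01-01T05:02:32").tzinfo is None

# ensure_pendulum_datetime keeps coercing everything to a tz-aware UTC datetime
assert ensure_pendulum_datetime("2021-01-01T05:02:32") == pendulum.datetime(
    2021, 1, 1, 5, 2, 32, tz="UTC"
)
```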
diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py
index 7cd8e9f1a2..456ef3cb91 100644
--- a/tests/common/test_utils.py
+++ b/tests/common/test_utils.py
@@ -21,6 +21,7 @@
     extend_list_deduplicated,
     get_exception_trace,
     get_exception_trace_chain,
+    update_dict_nested,
 )


@@ -277,3 +278,14 @@ def test_exception_trace_chain() -> None:
     assert traces[0]["exception_type"] == "dlt.common.exceptions.PipelineException"
     assert traces[1]["exception_type"] == "dlt.common.exceptions.IdentifierTooLongException"
     assert traces[2]["exception_type"] == "dlt.common.exceptions.TerminalValueError"
+
+
+def test_nested_dict_merge() -> None:
+    dict_1 = {"a": 1, "b": 2}
+    dict_2 = {"a": 2, "c": 4}
+
+    assert update_dict_nested(dict(dict_1), dict_2) == {"a": 2, "b": 2, "c": 4}
+    assert update_dict_nested(dict(dict_2), dict_1) == {"a": 1, "b": 2, "c": 4}
+    assert update_dict_nested(dict(dict_1), dict_2, keep_dst_values=True) == update_dict_nested(
+        dict_2, dict_1
+    )
diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py
index 6b6c7887d3..6228efea03 100644
--- a/tests/extract/test_incremental.py
+++ b/tests/extract/test_incremental.py
@@ -18,6 +18,7 @@
 from dlt.common.json import json

 from dlt.extract import DltSource
+from dlt.extract.exceptions import InvalidStepFunctionArguments
 from dlt.sources.helpers.transform import take_first
 from dlt.extract.incremental.exceptions import (
     IncrementalCursorPathMissing,
@@ -977,11 +978,17 @@ def some_data(
             "updated_at", initial_value=pendulum_start_dt
         ),
         max_hours: int = 2,
+        tz: str = None,
     ):
         data = [
             {"updated_at": start_dt + timedelta(hours=hour), "hour": hour}
             for hour in range(1, max_hours + 1)
         ]
+        # make sure this is a naive datetime
+        assert data[0]["updated_at"].tzinfo is None  # type: ignore[attr-defined]
+        if tz:
+            data = [{**d, "updated_at": pendulum.instance(d["updated_at"])} for d in data]  # type: ignore[call-overload]
+
         yield data_to_item_format(item_type, data)

     pipeline = dlt.pipeline(pipeline_name=uniq_id())
@@ -1024,6 +1031,44 @@ def some_data(
         == 2
     )

+    # initial value is naive
+    resource = some_data(max_hours=4).with_name("copy_1")  # also make a new resource state
+    resource.apply_hints(incremental=dlt.sources.incremental("updated_at", initial_value=start_dt))
+    # and the data is naive, so it works as expected with naive datetimes in the result set
+    data = list(resource)
+    if item_type == "json":
+        # we do not convert data in arrow tables
+        assert data[0]["updated_at"].tzinfo is None
+
+    # end value is naive
+    resource = some_data(max_hours=4).with_name("copy_2")  # also make a new resource state
+    resource.apply_hints(
+        incremental=dlt.sources.incremental(
+            "updated_at", initial_value=start_dt, end_value=start_dt + timedelta(hours=3)
+        )
+    )
+    data = list(resource)
+    if item_type == "json":
+        assert data[0]["updated_at"].tzinfo is None
+
+    # now use a naive initial value but the data is UTC
+    resource = some_data(max_hours=4, tz="UTC").with_name("copy_3")  # also make a new resource state
+    resource.apply_hints(
+        incremental=dlt.sources.incremental(
+            "updated_at", initial_value=start_dt + timedelta(hours=3)
+        )
+    )
+    # will cause an invalid comparison
+    if item_type == "json":
+        with pytest.raises(InvalidStepFunctionArguments):
+            list(resource)
+    else:
+        data = data_item_to_list(item_type, list(resource))
+        # we select two rows by adding 3 hours to start_dt. rows have hours:
+        # 1, 2, 3, 4
+        # and we select >= 3
+        assert len(data) == 2
+

 @dlt.resource
 def endless_sequence(
diff --git a/tests/libs/test_pyarrow.py b/tests/libs/test_pyarrow.py
index dffda35005..68541e96e0 100644
--- a/tests/libs/test_pyarrow.py
+++ b/tests/libs/test_pyarrow.py
@@ -1,9 +1,17 @@
 from copy import deepcopy
-
+from datetime import timezone, datetime, timedelta  # noqa: I251
 import pyarrow as pa

-from dlt.common.libs.pyarrow import py_arrow_to_table_schema_columns, get_py_arrow_datatype
+from dlt.common import pendulum
+from dlt.common.libs.pyarrow import (
+    from_arrow_scalar,
+    get_py_arrow_timestamp,
+    py_arrow_to_table_schema_columns,
+    get_py_arrow_datatype,
+    to_arrow_scalar,
+)
 from dlt.common.destination import DestinationCapabilitiesContext
+
 from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA


@@ -49,3 +57,55 @@ def test_py_arrow_to_table_schema_columns():

     # Resulting schema should match the original
     assert result == dlt_schema
+
+
+def test_to_arrow_scalar() -> None:
+    naive_dt = get_py_arrow_timestamp(6, tz=None)
+    # tz-aware datetimes are converted to UTC and the timezone is stripped when
+    # cast to a naive arrow type; naive datetimes pass through unchanged
+    assert to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt).as_py() == datetime(
+        2021, 1, 1, 5, 2, 32
+    )
+    assert to_arrow_scalar(
+        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), naive_dt
+    ).as_py() == datetime(2021, 1, 1, 5, 2, 32)
+    assert to_arrow_scalar(
+        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), naive_dt
+    ).as_py() == datetime(2021, 1, 1, 5, 2, 32) + timedelta(hours=8)
+
+    # tz-aware datetimes are converted to the target timezone of the arrow type
+    utc_dt = get_py_arrow_timestamp(6, tz="UTC")
+    dt_converted = to_arrow_scalar(
+        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), utc_dt
+    ).as_py()
+    assert dt_converted.utcoffset().seconds == 0
+    assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)
+
+    berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin")
+    dt_converted = to_arrow_scalar(
+        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt
+    ).as_py()
+    # no DST in January, so Berlin is UTC+1
+    assert dt_converted.utcoffset().seconds == 60 * 60
+    assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)
+
+
+def test_from_arrow_scalar() -> None:
+    naive_dt = get_py_arrow_timestamp(6, tz=None)
+    sc_dt = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt)
+
+    # this naive value is treated as UTC
+    py_dt = from_arrow_scalar(sc_dt)
+    assert isinstance(py_dt, pendulum.DateTime)
+    # and we convert to explicit UTC
+    assert py_dt == datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc)
+
+    # converts to UTC
+    berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin")
+    sc_dt = to_arrow_scalar(
+        datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt
+    )
+    py_dt = from_arrow_scalar(sc_dt)
+    assert isinstance(py_dt, pendulum.DateTime)
+    assert py_dt.tzname() == "UTC"
+    assert py_dt == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc)
diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py
index 5fa656ada9..a93599831d 100644
--- a/tests/load/pipeline/test_pipelines.py
+++ b/tests/load/pipeline/test_pipelines.py
@@ -878,7 +878,6 @@ def test_pipeline_upfront_tables_two_loads(
     # use staging tables for replace
     os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy

-    print(destination_config)
     pipeline = destination_config.setup_pipeline(
         "test_pipeline_upfront_tables_two_loads",
         dataset_name="test_pipeline_upfront_tables_two_loads",
@@ -984,6 +983,48 @@ def table_3(make_data=False):
     )


+# @pytest.mark.skip(reason="Finalize the test: compare some_data values to values from database")
+# @pytest.mark.parametrize(
+#     "destination_config",
+#     destinations_configs(all_staging_configs=True, default_sql_configs=True, file_format=["insert_values", "jsonl", "parquet"]),
+#     ids=lambda x: x.name,
+# )
+# def test_load_non_utc_timestamps_with_arrow(destination_config: DestinationTestConfiguration) -> None:
+#     """Checks if dates are stored properly and timezones are not mangled"""
+#     from datetime import timedelta, datetime, timezone
+
+#     start_dt = datetime.now()
+
+#     # columns=[{"name": "Hour", "data_type": "bool"}]
+#     @dlt.resource(standalone=True, primary_key="Hour")
+#     def some_data(
+#         max_hours: int = 2,
+#     ):
+#         data = [
+#             {
+#                 "naive_dt": start_dt + timedelta(hours=hour), "hour": hour,
+#                 "utc_dt": pendulum.instance(start_dt + timedelta(hours=hour)), "hour": hour,
+#                 # tz="Europe/Berlin"
+#                 "berlin_dt": pendulum.instance(start_dt + timedelta(hours=hour), tz=timezone(offset=timedelta(hours=-8))), "hour": hour,
+#             }
+#             for hour in range(0, max_hours)
+#         ]
+#         data = data_to_item_format("arrow", data)
+#         # print(py_arrow_to_table_schema_columns(data[0].schema))
+#         # print(data)
+#         yield data
+
+#     pipeline = destination_config.setup_pipeline(
+#         "test_load_non_utc_timestamps",
+#         dataset_name="test_load_non_utc_timestamps",
+#         full_refresh=True,
+#     )
+#     info = pipeline.run(some_data())
+#     # print(pipeline.default_schema.to_pretty_yaml())
+#     assert_load_info(info)
+#     table_name = pipeline.sql_client().make_qualified_table_name("some_data")
+#     print(select_data(pipeline, f"SELECT * FROM {table_name}"))
+
+
 def simple_nested_pipeline(
     destination_config: DestinationTestConfiguration, dataset_name: str, full_refresh: bool
 ) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]:
diff --git a/tests/load/utils.py b/tests/load/utils.py
index d8a20d5518..7b4cf72b47 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -159,7 +159,7 @@ def destinations_configs(
     all_buckets_filesystem_configs: bool = False,
     subset: Sequence[str] = (),
     exclude: Sequence[str] = (),
-    file_format: Optional[TLoaderFileFormat] = None,
+    file_format: Optional[Union[TLoaderFileFormat, Sequence[TLoaderFileFormat]]] = None,
     supports_merge: Optional[bool] = None,
     supports_dbt: Optional[bool] = None,
 ) -> List[DestinationTestConfiguration]:
@@ -383,8 +383,13 @@ def destinations_configs(
             conf for conf in destination_configs if conf.destination not in exclude
         ]
     if file_format:
+        if isinstance(file_format, str):
+            # a single format was passed: wrap it in a list (note that a plain str is itself a Sequence)
+            file_format = [file_format]
         destination_configs = [
-            conf for conf in destination_configs if conf.file_format == file_format
+            conf
+            for conf in destination_configs
+            if conf.file_format and conf.file_format in file_format
        ]
     if supports_merge is not None:
         destination_configs = [
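With the widened `file_format` parameter, a test can now select configs for several loader file formats at once. A hypothetical parametrization in the style of the (commented-out) test above:

```python
import pytest

from tests.load.utils import DestinationTestConfiguration, destinations_configs


@pytest.mark.parametrize(
    "destination_config",
    destinations_configs(default_sql_configs=True, file_format=["parquet", "jsonl"]),
    ids=lambda x: x.name,
)
def test_with_multiple_file_formats(destination_config: DestinationTestConfiguration) -> None:
    ...  # the fixture now yields configs whose file_format is either "parquet" or "jsonl"
```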