diff --git a/.gitignore b/.gitignore index 8642b9722a..0ed0a27398 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ tlaplus/*.toolbox/*/[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*/ /.idea .vscode .env +/python/.venv **/.DS_Store **/.python-version .coverage @@ -27,4 +28,4 @@ Cargo.lock !/delta-inspect/Cargo.lock !/proofs/Cargo.lock -justfile \ No newline at end of file +justfile diff --git a/python/deltalake/_util.py b/python/deltalake/_util.py index 32b8a20a1b..744a6a2575 100644 --- a/python/deltalake/_util.py +++ b/python/deltalake/_util.py @@ -1,5 +1,8 @@ from datetime import date, datetime -from typing import Any +from typing import TYPE_CHECKING, Any, Optional, cast + +if TYPE_CHECKING: + from deltalake.table import FilterType def encode_partition_value(val: Any) -> str: @@ -10,11 +13,84 @@ def encode_partition_value(val: Any) -> str: return val elif isinstance(val, (int, float)): return str(val) - elif isinstance(val, date): - return val.isoformat() elif isinstance(val, datetime): return val.isoformat(sep=" ") + elif isinstance(val, date): + return val.isoformat() elif isinstance(val, bytes): return val.decode("unicode_escape", "backslashreplace") else: raise ValueError(f"Could not encode partition value for type: {val}") + + +def validate_filters(filters: Optional["FilterType"]) -> Optional["FilterType"]: + """Ensures that the filters are a list of list of tuples in DNF format. + + :param filters: Filters to be validated + + Examples: + >>> validate_filters([("a", "=", 1), ("b", "=", 2)]) + >>> validate_filters([[("a", "=", 1), ("b", "=", 2)], [("c", "=", 3)]]) + >>> validate_filters([("a", "=", 1)]) + >>> validate_filters([[("a", "=", 1)], [("b", "=", 2)], [("c", "=", 3)]]) + + """ + from deltalake.table import FilterType + + if filters is None: + return None + + if not isinstance(filters, list) or len(filters) == 0: + raise ValueError("Filters must be a non-empty list.") + + if all(isinstance(item, tuple) and len(item) == 3 for item in filters): + return cast(FilterType, [filters]) + + elif all( + isinstance(conjunction, list) + and len(conjunction) > 0 + and all( + isinstance(literal, tuple) and len(literal) == 3 for literal in conjunction + ) + for conjunction in filters + ): + if len(filters) == 0 or any(len(c) == 0 for c in filters): + raise ValueError("Malformed DNF") + return filters + + else: + raise ValueError( + "Filters must be a list of tuples, or a list of lists of tuples" + ) + + +def stringify_partition_values( + partition_filters: Optional["FilterType"], +) -> Optional["FilterType"]: + if partition_filters is None: + return None + + if all(isinstance(item, tuple) for item in partition_filters): + return [ # type: ignore + ( + field, + op, + [encode_partition_value(val) for val in value] + if isinstance(value, (list, tuple)) + else encode_partition_value(value), + ) + for field, op, value in partition_filters + ] + return [ + [ + ( + field, + op, + [encode_partition_value(val) for val in value] + if isinstance(value, (list, tuple)) + else encode_partition_value(value), + ) + for field, op, value in sublist + ] + for sublist in partition_filters + ] diff --git a/python/deltalake/table.py b/python/deltalake/table.py index 80a48f619e..8972bb7e60 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1,9 +1,7 @@ import json -import operator import warnings from dataclasses import dataclass from datetime import datetime, timedelta -from functools import reduce from pathlib import Path from typing import ( TYPE_CHECKING, @@ -22,18 +20,19 @@ 
import pyarrow import pyarrow.fs as pa_fs from pyarrow.dataset import ( - Expression, FileSystemDataset, + Fragment, ParquetFileFormat, ParquetFragmentScanOptions, ParquetReadOptions, ) +from pyarrow.parquet import filters_to_expression if TYPE_CHECKING: import pandas from ._internal import RawDeltaTable -from ._util import encode_partition_value +from ._util import stringify_partition_values, validate_filters from .data_catalog import DataCatalog from .exceptions import DeltaProtocolError from .fs import DeltaStorageHandler @@ -107,96 +106,6 @@ class ProtocolVersions(NamedTuple): FilterType = Union[FilterConjunctionType, FilterDNFType] -def _check_contains_null(value: Any) -> bool: - """ - Check if target contains nullish value. - """ - if isinstance(value, bytes): - for byte in value: - if isinstance(byte, bytes): - compare_to = chr(0) - else: - compare_to = 0 - if byte == compare_to: - return True - elif isinstance(value, str): - return "\x00" in value - return False - - -def _check_dnf( - dnf: FilterDNFType, - check_null_strings: bool = True, -) -> FilterDNFType: - """ - Check if DNF are well-formed. - """ - if len(dnf) == 0 or any(len(c) == 0 for c in dnf): - raise ValueError("Malformed DNF") - if check_null_strings: - for conjunction in dnf: - for col, op, val in conjunction: - if ( - isinstance(val, list) - and all(_check_contains_null(v) for v in val) - or _check_contains_null(val) - ): - raise NotImplementedError( - "Null-terminated binary strings are not supported " - "as filter values." - ) - return dnf - - -def _convert_single_predicate(column: str, op: str, value: Any) -> Expression: - """ - Convert given `tuple` to `pyarrow.dataset.Expression`. - """ - import pyarrow.dataset as ds - - field = ds.field(column) - if op == "=" or op == "==": - return field == value - elif op == "!=": - return field != value - elif op == "<": - return field < value - elif op == ">": - return field > value - elif op == "<=": - return field <= value - elif op == ">=": - return field >= value - elif op == "in": - return field.isin(value) - elif op == "not in": - return ~field.isin(value) - else: - raise ValueError( - f'"{(column, op, value)}" is not a valid operator in predicates.' - ) - - -def _filters_to_expression(filters: FilterType) -> Expression: - """ - Check if filters are well-formed and convert to an ``pyarrow.dataset.Expression``. - """ - if isinstance(filters[0][0], str): - # We have encountered the situation where we have one nesting level too few: - # We have [(,,), ..] instead of [[(,,), ..]] - dnf = cast(FilterDNFType, [filters]) - else: - dnf = cast(FilterDNFType, filters) - dnf = _check_dnf(dnf, check_null_strings=False) - disjunction_members = [] - for conjunction in dnf: - conjunction_members = [ - _convert_single_predicate(col, op, val) for col, op, val in conjunction - ] - disjunction_members.append(reduce(operator.and_, conjunction_members)) - return reduce(operator.or_, disjunction_members) - - _DNF_filter_doc = """ Predicates are expressed in disjunctive normal form (DNF), like [("x", "=", "a"), ...]. DNF allows arbitrary boolean logical combinations of single partition predicates. @@ -300,7 +209,7 @@ def version(self) -> int: def files( self, partition_filters: Optional[List[Tuple[str, str, Any]]] = None ) -> List[str]: - return self._table.files(self.__stringify_partition_values(partition_filters)) + return self._table.files(stringify_partition_values(partition_filters)) files.__doc__ = f""" Get the .parquet files of the DeltaTable. 
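Note on the removed block above: the hand-rolled filter machinery (`_check_contains_null`, `_check_dnf`, `_convert_single_predicate`, `_filters_to_expression`) is dropped in favor of the new `_util` helpers plus PyArrow's public `filters_to_expression`. A minimal sketch of how the pieces compose under that flow; the example values are illustrative and not taken from the test data:

from pyarrow.parquet import filters_to_expression

from deltalake._util import stringify_partition_values, validate_filters

# A flat conjunction is wrapped into single-conjunction DNF form.
conjunction = validate_filters([("year", "=", 2021), ("month", "=", 4)])
# -> [[("year", "=", 2021), ("month", "=", 4)]]

# A list of lists (an OR of ANDs) is passed through unchanged.
disjunction = validate_filters([[("month", "=", "2")], [("month", "=", "4")]])

# Partition values are stringified before being handed to the Rust core.
stringify_partition_values(conjunction)
# -> [[("year", "=", "2021"), ("month", "=", "4")]]

# Row-level filters now reuse PyArrow's converter instead of the removed helper.
expr = filters_to_expression(disjunction)  # (month == "2") or (month == "4")

This keeps DNF validation in one place (`validate_filters`), so the read path and the writer accept the same filter shapes.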
@@ -353,9 +262,7 @@ def files_by_partitions( def file_uris( self, partition_filters: Optional[List[Tuple[str, str, Any]]] = None ) -> List[str]: - return self._table.file_uris( - self.__stringify_partition_values(partition_filters) - ) + return self._table.file_uris(stringify_partition_values(partition_filters)) file_uris.__doc__ = f""" Get the list of files as absolute URIs, including the scheme (e.g. "s3://"). @@ -569,9 +476,48 @@ def restore( ) return json.loads(metrics) + def _create_fragments( + self, + partitions: Optional[FilterType], + format: ParquetFileFormat, + filesystem: pa_fs.FileSystem, + ) -> List[Fragment]: + """Create Parquet fragments for the given partitions + + :param partitions: A list of partition filters + :param format: The ParquetFileFormat to use + :param filesystem: The PyArrow FileSystem to use + """ + + partition_filters: Optional[FilterType] = validate_filters(partitions) + partition_filters = stringify_partition_values(partition_filters) + + fragments = [] + if partition_filters is not None: + for partition in partition_filters: + partition = cast(FilterConjunctionType, partition) + for file, partition_expression in self._table.dataset_partitions( + schema=self.schema().to_pyarrow(), partition_filters=partition + ): + fragments.append( + format.make_fragment(file, filesystem, partition_expression) + ) + else: + for file, part_expression in self._table.dataset_partitions( + schema=self.schema().to_pyarrow(), partition_filters=partitions + ): + fragments.append( + format.make_fragment( + file, + filesystem=filesystem, + partition_expression=part_expression, + ) + ) + return fragments + def to_pyarrow_dataset( self, - partitions: Optional[List[Tuple[str, str, Any]]] = None, + partitions: Optional[FilterType] = None, filesystem: Optional[Union[str, pa_fs.FileSystem]] = None, parquet_read_options: Optional[ParquetReadOptions] = None, ) -> pyarrow.dataset.Dataset: @@ -606,17 +552,6 @@ def to_pyarrow_dataset( default_fragment_scan_options=ParquetFragmentScanOptions(pre_buffer=True), ) - fragments = [ - format.make_fragment( - file, - filesystem=filesystem, - partition_expression=part_expression, - ) - for file, part_expression in self._table.dataset_partitions( - self.schema().to_pyarrow(), partitions - ) - ] - schema = self.schema().to_pyarrow() dictionary_columns = format.read_options.dictionary_columns or set() @@ -628,11 +563,13 @@ def to_pyarrow_dataset( ) schema = schema.set(index, dict_field) + fragments = self._create_fragments(partitions, format, filesystem) + return FileSystemDataset(fragments, schema, format, filesystem) def to_pyarrow_table( self, - partitions: Optional[List[Tuple[str, str, Any]]] = None, + partitions: Optional[FilterType] = None, columns: Optional[List[str]] = None, filesystem: Optional[Union[str, pa_fs.FileSystem]] = None, filters: Optional[FilterType] = None, @@ -646,14 +583,15 @@ def to_pyarrow_table( :param filters: A disjunctive normal form (DNF) predicate for filtering rows. 
If you pass a filter you do not need to pass ``partitions`` """ if filters is not None: - filters = _filters_to_expression(filters) + filters = validate_filters(filters) + filters = filters_to_expression(filters) return self.to_pyarrow_dataset( partitions=partitions, filesystem=filesystem ).to_table(columns=columns, filter=filters) def to_pandas( self, - partitions: Optional[List[Tuple[str, str, Any]]] = None, + partitions: Optional[FilterType] = None, columns: Optional[List[str]] = None, filesystem: Optional[Union[str, pa_fs.FileSystem]] = None, filters: Optional[FilterType] = None, @@ -683,21 +621,6 @@ def update_incremental(self) -> None: def create_checkpoint(self) -> None: self._table.create_checkpoint() - def __stringify_partition_values( - self, partition_filters: Optional[List[Tuple[str, str, Any]]] - ) -> Optional[List[Tuple[str, str, Union[str, List[str]]]]]: - if partition_filters is None: - return partition_filters - out = [] - for field, op, value in partition_filters: - str_value: Union[str, List[str]] - if isinstance(value, (list, tuple)): - str_value = [encode_partition_value(val) for val in value] - else: - str_value = encode_partition_value(value) - out.append((field, op, str_value)) - return out - def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch: """ Return a dataframe with all current add actions. diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index db399e857e..0bcc9cf1f5 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -16,12 +16,13 @@ Optional, Tuple, Union, + cast, ) from urllib.parse import unquote from deltalake.fs import DeltaStorageHandler -from ._util import encode_partition_value +from ._util import encode_partition_value, validate_filters if TYPE_CHECKING: import pandas as pd @@ -44,7 +45,13 @@ from ._internal import batch_distinct from ._internal import write_new_deltalake as _write_new_deltalake from .exceptions import DeltaProtocolError, TableNotFoundError -from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable +from .table import ( + MAX_SUPPORTED_WRITER_VERSION, + DeltaTable, + FilterConjunctionType, + FilterLiteralType, + FilterType, +) try: import pandas as pd # noqa: F811 @@ -66,6 +73,43 @@ class AddAction: stats: str +def _match_filter( + filter_: FilterLiteralType, partition_values: Mapping[str, Optional[str]] +) -> bool: + """Matches a filter against a partition value from AddAction instance. + + This ensures that create_write_transaction is called with a valid partition filter. 
+ + :param filter_: A single filter tuple of the form (column, operator, value) + :param partition_values: A mapping of partition values + + Examples: + >>> _match_filter(("a", "=", 1), {"a": 1}) + >>> all(_match_filter(f, {"a": 1, "b": 2}) for f in [("a", "=", 1), ("b", "=", 2)]) + """ + column, op, value = filter_ + actual_value = partition_values.get(column) + + if op == "=" or op == "==": + return actual_value == value + elif op == "!=": + return actual_value != value + elif op == "<": + return actual_value < value + elif op == ">": + return actual_value > value + elif op == "<=": + return actual_value <= value + elif op == ">=": + return actual_value >= value + elif op == "in": + return actual_value in value + elif op == "not in": + return actual_value not in value + else: + raise ValueError(f'"{filter_}" is not a valid operator in predicates.') + + def write_deltalake( table_or_uri: Union[str, Path, DeltaTable], data: Union[ @@ -91,7 +135,7 @@ configuration: Optional[Mapping[str, Optional[str]]] = None, overwrite_schema: bool = False, storage_options: Optional[Dict[str, str]] = None, - partition_filters: Optional[List[Tuple[str, str, Any]]] = None, + partition_filters: Optional[FilterType] = None, large_dtypes: bool = False, ) -> None: """Write to a Delta Lake table @@ -173,6 +217,8 @@ filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) + validated_filters = validate_filters(partition_filters) + __enforce_append_only(table=table, configuration=configuration, mode=mode) if isinstance(partition_by, str): @@ -268,12 +314,22 @@ def check_data_is_aligned_with_partition_filtering( ) -> None: if table is None: return - existed_partitions: FrozenSet[ + + allowed_partitions = set() + if validated_filters is not None: + # get_active_partitions() on the Rust side does not handle a list of + # lists of tuples, so we need to call it once per conjunction + for single_filter in validated_filters: + single_filter = cast(FilterConjunctionType, single_filter) + allowed_partitions.update( + table._table.get_active_partitions(single_filter) + ) + else: + allowed_partitions = table._table.get_active_partitions() + + existing_partitions: FrozenSet[ FrozenSet[Tuple[str, Optional[str]]] ] = table._table.get_active_partitions() - allowed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions(partition_filters) partition_values = pa.RecordBatch.from_arrays( [ batch.column(column_name) @@ -293,7 +349,7 @@ def check_data_is_aligned_with_partition_filtering( partition = frozenset(partition_map.items()) if ( partition not in allowed_partitions - and partition in existed_partitions + and partition in existing_partitions ): partition_repr = " ".join( f"{key}={value}" for key, value in partition_map.items() ) @@ -356,14 +412,54 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch: storage_options, ) else: - table._table.create_write_transaction( - add_actions, - mode, - partition_by or [], - schema, - partition_filters, - ) - table.update_incremental() + if table is not None: + if validated_filters is None: + table._table.create_write_transaction( + add_actions, + mode, + partition_by or [], + schema, + validated_filters, + ) + table.update_incremental() + elif isinstance(validated_filters, list): + if len(validated_filters) == 1: + single_filter = validated_filters[0] + if isinstance(single_filter, list): + table._table.create_write_transaction( + add_actions, + mode, + partition_by or [], + schema, +
single_filter, + ) + table.update_incremental() + elif all(isinstance(x, list) for x in validated_filters): + original_add_actions = add_actions.copy() + for filter_conjunction in validated_filters: + filter_conjunction = cast( + FilterConjunctionType, filter_conjunction + ) + filtered_add_actions = [ + action + for action in original_add_actions + if all( + _match_filter(filter_, action.partition_values) + for filter_ in filter_conjunction + ) + ] + # create_write_transaction() only accepts a list of tuples + # and not a list of list of tuples (OR conjunction) + table._table.create_write_transaction( + filtered_add_actions, + mode, + partition_by or [], + schema, + filter_conjunction, + ) + table.update_incremental() + else: + raise ValueError("Invalid format for validated_filters") def __enforce_append_only( diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index b7de2ba1ae..ec22ec18f4 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -1,14 +1,21 @@ import os -from datetime import datetime +from datetime import date, datetime from pathlib import Path from threading import Barrier, Thread from types import SimpleNamespace +from typing import Any from unittest.mock import Mock from packaging import version +from pyarrow.parquet import filters_to_expression +from deltalake._util import ( + encode_partition_value, + stringify_partition_values, + validate_filters, +) from deltalake.exceptions import DeltaProtocolError -from deltalake.table import ProtocolVersions +from deltalake.table import FilterType, ProtocolVersions from deltalake.writer import write_deltalake try: @@ -481,6 +488,79 @@ def test_delta_table_with_filters(): ) +def test_pyarrow_dataset_partitions(): + table_path = "../rust/tests/data/delta-0.8.0-partitioned" + dt = DeltaTable(table_path) + + single_partition = [("day", "=", "1")] + dataset_filtered = dt.to_pyarrow_dataset(partitions=single_partition) + data_filtered = dataset_filtered.to_table() + dataset = dt.to_pyarrow_dataset() + filter_expr = ds.field("day") == "1" + data = dataset.to_table(filter=filter_expr) + assert data_filtered.num_rows == data.num_rows + + single_partition_multiple_columns = [("month", "=", "2"), ("day", "=", "5")] + dataset_filtered = dt.to_pyarrow_dataset( + partitions=single_partition_multiple_columns + ) + data_filtered = dataset_filtered.to_table() + dataset = dt.to_pyarrow_dataset() + filter_expr = (ds.field("month") == "2") & (ds.field("day") == "5") + data = dataset.to_table(filter=filter_expr) + assert data_filtered.num_rows == data.num_rows + + multiple_partitions_single_column = [[("month", "=", "2")], [("month", "=", "4")]] + dataset_filtered = dt.to_pyarrow_dataset( + partitions=multiple_partitions_single_column + ) + data_filtered = dataset_filtered.to_table() + dataset = dt.to_pyarrow_dataset() + filter_expr = (ds.field("month") == "2") | (ds.field("month") == "4") + data = dataset.to_table(filter=filter_expr) + assert data_filtered.num_rows == data.num_rows + + multiple_partitions_multiple_columns = [ + [("year", "=", "2020"), ("month", "=", "2"), ("day", "=", "5")], + [("year", "=", "2021"), ("month", "=", "4"), ("day", "=", "5")], + [("year", "=", "2021"), ("month", "=", "3"), ("day", "=", "1")], + ] + dataset_filtered = dt.to_pyarrow_dataset( + partitions=multiple_partitions_multiple_columns + ) + data_filtered = dataset_filtered.to_table() + dataset = dt.to_pyarrow_dataset() + filter_expr = ( + ( + (ds.field("year") == "2020") + & (ds.field("month") == "2") + & 
(ds.field("day") == "5") + ) + | ( + (ds.field("year") == "2021") + & (ds.field("month") == "4") + & (ds.field("day") == "5") + ) + | ( + (ds.field("year") == "2021") + & (ds.field("month") == "3") + & (ds.field("day") == "1") + ) + ) + data = dataset.to_table(filter=filter_expr) + assert data_filtered.num_rows == data.num_rows + + single_partition_single_column_list = [[("year", "=", "2020")]] + dataset_filtered = dt.to_pyarrow_dataset( + partitions=single_partition_single_column_list + ) + data_filtered = dataset_filtered.to_table() + dataset = dt.to_pyarrow_dataset() + filter_expr = ds.field("year") == "2020" + data = dataset.to_table(filter=filter_expr) + assert data_filtered.num_rows == data.num_rows + + def test_writer_fails_on_protocol(): table_path = "../rust/tests/data/simple_table" dt = DeltaTable(table_path) @@ -695,3 +775,140 @@ def test_issue_1653_filter_bool_partition(tmp_path: Path): ) == 1 ) + + +@pytest.mark.parametrize( + "input_value, expected", + [ + (True, "true"), + (False, "false"), + (1, "1"), + (1.5, "1.5"), + ("string", "string"), + (date(2023, 10, 17), "2023-10-17"), + (datetime(2023, 10, 17, 12, 34, 56), "2023-10-17 12:34:56"), + (b"bytes", "bytes"), + ([True, False], ["true", "false"]), + ([1, 2], ["1", "2"]), + ([1.5, 2.5], ["1.5", "2.5"]), + (["a", "b"], ["a", "b"]), + ([date(2023, 10, 17), date(2023, 10, 18)], ["2023-10-17", "2023-10-18"]), + ( + [datetime(2023, 10, 17, 12, 34, 56), datetime(2023, 10, 18, 12, 34, 56)], + ["2023-10-17 12:34:56", "2023-10-18 12:34:56"], + ), + ([b"bytes", b"testbytes"], ["bytes", "testbytes"]), + ], +) +def test_encode_partition_value(input_value: Any, expected: str) -> None: + if isinstance(input_value, list): + assert [encode_partition_value(val) for val in input_value] == expected + else: + assert encode_partition_value(input_value) == expected + + +@pytest.mark.parametrize( + "filters, expected", + [ + ( + [("date", "=", "2023-08-25"), ("date", "=", "2023-09-05")], + (ds.field("date") == "2023-08-25") & (ds.field("date") == "2023-09-05"), + ), + ( + [[("date", "=", "2023-08-25")], [("date", "=", "2023-09-05")]], + (ds.field("date") == "2023-08-25") | (ds.field("date") == "2023-09-05"), + ), + ( + [ + ("date", ">=", "2023-08-25"), + ("date", "<", "2023-09-05"), + ("date", "not in", ["2023-09-01", "2023-09-02"]), + ], + (ds.field("date") >= "2023-08-25") + & (ds.field("date") < "2023-09-05") + & ~ds.field("date").isin(["2023-09-01", "2023-09-02"]), + ), + ([["date", "=", "2023-08-25"]], ds.field("date") == "2023-08-25"), + ([("date", "=", "2023-08-25")], ds.field("date") == "2023-08-25"), + ([("date", "=", None)], (ds.field("date") == ds.scalar(None))), + ], +) +def test_filters_to_expression(filters: FilterType, expected: ds.Expression) -> None: + result = filters_to_expression(filters) + assert result.equals(expected) + + +@pytest.mark.parametrize( + "filters, expected", + [ + ([("a", "=", 1)], [[("a", "=", 1)]]), + ([("a", "=", 1), ("b", "=", 2)], [[("a", "=", 1), ("b", "=", 2)]]), + ( + [[("a", "=", 1), ("b", "=", 2)], [("c", "=", 3)]], + [[("a", "=", 1), ("b", "=", 2)], [("c", "=", 3)]], + ), + ([[("a", "=", 1)]], [[("a", "=", 1)]]), + ( + [[("a", "=", 1)], [("b", "=", 2)], [("c", "=", 3)]], + [[("a", "=", 1)], [("b", "=", 2)], [("c", "=", 3)]], + ), + ([("a", "=", 1)], [[("a", "=", 1)]]), + ], +) +def test_validate_filters(filters: FilterType, expected: FilterType) -> None: + result = validate_filters(filters) + assert result == expected + + +# Test cases with invalid filters +@pytest.mark.parametrize( + "filters", + [ 
+ [], + [[]], + [()], + [("a", "=", 1), []], + [[("a", "=", 1)], ()], + ], +) +def test_validate_filters_invalid(filters: FilterType) -> None: + with pytest.raises(ValueError): + validate_filters(filters) + + +@pytest.mark.parametrize( + "input_filters, expected", + [ + ([("a", "=", 1)], [("a", "=", "1")]), + ([[("a", "=", 1), ("b", "!=", 2)]], [[("a", "=", "1"), ("b", "!=", "2")]]), + ([("a", "in", [1, 2])], [("a", "in", ["1", "2"])]), + ( + [[("a", "in", [1, 2]), ("b", "not in", [3, 4])]], + [[("a", "in", ["1", "2"]), ("b", "not in", ["3", "4"])]], + ), + ([("date_col", "=", date(2022, 1, 1))], [("date_col", "=", "2022-01-01")]), + ( + [("datetime_col", "=", datetime(2022, 1, 1, 12, 34, 56))], + [("datetime_col", "=", "2022-01-01 12:34:56")], + ), + ( + [ + [ + ("date_col", "=", date(2022, 1, 1)), + ("datetime_col", "=", datetime(2022, 1, 1, 12, 34, 56)), + ] + ], + [ + [ + ("date_col", "=", "2022-01-01"), + ("datetime_col", "=", "2022-01-01 12:34:56"), + ] + ], + ), + ], +) +def test_stringify_partition_values( + input_filters: FilterType, expected: FilterType +) -> None: + result = stringify_partition_values(input_filters) + assert result == expected diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 3385e32175..8e2ff6ca68 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -746,6 +746,187 @@ def test_partition_overwrite_with_non_partitioned_data( ) +@pytest.fixture() +def sample_data_for_multiple_partitions() -> pa.Table: + data = { + "id": [1, 2, 3, 4, 5], + "account": [ + "account_a", + "account_b", + "account_a", + "account_c", + "account_b", + ], + "created_date": [ + date(2023, 8, 25), + date(2023, 9, 5), + date(2023, 9, 7), + date(2023, 9, 21), + date(2023, 10, 2), + ], + "updated_at": [ + datetime(2023, 8, 25, 0, 0, 0), + datetime(2023, 9, 5, 0, 0, 0), + datetime(2023, 9, 7, 0, 0, 0), + datetime(2023, 9, 21, 0, 0, 0), + datetime(2023, 10, 2, 0, 0, 0), + ], + "value": [10.5, 20.5, 30.5, 40.5, 50.5], + } + + return pa.Table.from_pydict(data) + + +@pytest.fixture() +def sample_data_for_overwrite_partitions( + id_values: list[int], + account_values: list[str], + created_date_values: list[date], + updated_at_values: list[datetime], + value_values: list[float], +) -> pa.Table: + data = { + "id": id_values, + "account": account_values, + "created_date": created_date_values, + "updated_at": updated_at_values, + "value": value_values, + } + + return pa.Table.from_pydict(data) + + +@pytest.mark.pandas +@pytest.mark.parametrize( + "id_values, account_values, created_date_values, updated_at_values, value_values, partition_by, partition_filters", + [ + # Test filtering for multiple dates (OR condition) + ( + [1, 3, 4], + ["account_a", "account_b", "account_a"], + [date(2023, 8, 25), date(2023, 9, 7), date(2023, 9, 21)], + [datetime.utcnow(), datetime.utcnow(), datetime.utcnow()], + [44.5, 68, 11.5], + ["created_date"], + [ + [("created_date", "=", "2023-08-25")], + [("created_date", "=", "2023-09-07")], + [("created_date", "=", "2023-09-21")], + ], + ), + # Test filtering for date or account (OR condition) + ( + [3, 4, 5], + ["account_a", "account_c", "account_b"], + [date(2023, 9, 7), date(2023, 9, 21), date(2023, 10, 2)], + [datetime.utcnow(), datetime.utcnow(), datetime.utcnow()], + [0.1, 5.2, 100], + ["created_date", "account"], + [ + [("created_date", "=", "2023-09-21")], + [("account", "in", ["account_a", "account_b"])], + ], + ), + # Test date range (AND condition) or account (OR condition) + ( + [1, 2, 3, 4, 5], + ["account_a", 
"account_b", "account_a", "account_c", "account_b"], + [ + date(2023, 8, 25), + date(2023, 9, 5), + date(2023, 9, 7), + date(2023, 9, 21), + date(2023, 10, 2), + ], + [ + datetime.utcnow(), + datetime.utcnow(), + datetime.utcnow(), + datetime.utcnow(), + datetime.utcnow(), + ], + [0.1, 0.2, 0.3, 0.4, 0.5], + ["created_date", "account"], + [ + [ + ("created_date", ">", "2023-08-01"), + ("created_date", "<", "2023-12-31"), + ], + [("account", "=", "account_b")], + ], + ), + # Test date and account (AND condition)) + ( + [4], + ["account_c"], + [date(2023, 9, 21)], + [datetime.utcnow()], + [352.5], + ["created_date", "account"], + [ + [ + ("created_date", "=", "2023-09-21"), + ("account", "=", "account_c"), + ], + ], + ), + ], +) +def test_overwriting_multiple_partitions( + tmp_path: pathlib.Path, + sample_data_for_multiple_partitions: pa.Table, + sample_data_for_overwrite_partitions: pa.Table, + partition_by: list[str], + partition_filters: list[list[tuple[str, str, Any]]], + id_values: list[int], + value_values: list[float], +): + # Write initial data + write_deltalake( + tmp_path, + sample_data_for_multiple_partitions, + mode="overwrite", + partition_by=partition_by, + ) + + # Append new data + write_deltalake( + tmp_path, + sample_data_for_overwrite_partitions, + mode="append", + partition_by=partition_by, + ) + + # Filter data using partition filters + delta_table = DeltaTable(tmp_path) + filtered_table = delta_table.to_pyarrow_table(partitions=partition_filters) + + # Sort data by id and updated_at, find latest record for each id + combined_df = filtered_table.to_pandas().sort_values(by=["id", "updated_at"]) + deduplicated_df = combined_df.drop_duplicates(subset=["id"], keep="last") + deduplicated_data = pa.table(deduplicated_df, schema=filtered_table.schema) + + # Overwrite data using partition filters + write_deltalake( + tmp_path, + deduplicated_data, + mode="overwrite", + partition_by=partition_by, + partition_filters=partition_filters, + ) + + delta_table = DeltaTable(tmp_path) + actual_data = delta_table.to_pyarrow_table().sort_by([("id", "ascending")]) + + for id_val, value_val in zip(id_values, value_values): + id_condition = pa.compute.equal(actual_data["id"], id_val) + value_condition = pa.compute.equal(actual_data["value"], value_val) + combined_condition = pa.compute.and_(id_condition, value_condition) + combined_list = combined_condition.to_pylist() + assert True in combined_list + assert set(actual_data["id"].to_pylist()) == set([1, 2, 3, 4, 5]) + + def test_partition_overwrite_with_wrong_partition( tmp_path: pathlib.Path, sample_data_for_partitioning: pa.Table ):