feat: add support for multiple partition columns and filters in to_pyarrow_dataset() and OR filters in write_deltalake() #1722

Open · wants to merge 13 commits into main
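For reviewers, a minimal usage sketch of the call patterns this change is intended to enable (the table path and the year/month partition columns are hypothetical):

```python
from deltalake import DeltaTable

# Hypothetical table partitioned by "year" and "month".
dt = DeltaTable("path/to/table")

# Single conjunction (AND): year == 2021 AND month == 12.
ds_and = dt.to_pyarrow_dataset(partitions=[("year", "=", 2021), ("month", "=", 12)])

# DNF with OR across conjunctions: (year == 2021 AND month == 12) OR year == 2022.
ds_or = dt.to_pyarrow_dataset(
    partitions=[[("year", "=", 2021), ("month", "=", 12)], [("year", "=", 2022)]]
)
```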
3 changes: 2 additions & 1 deletion .gitignore
@@ -11,6 +11,7 @@ tlaplus/*.toolbox/*/[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*/
/.idea
.vscode
.env
/python/.venv
**/.DS_Store
**/.python-version
.coverage
@@ -27,4 +28,4 @@ Cargo.lock
!/delta-inspect/Cargo.lock
!/proofs/Cargo.lock

justfile
justfile
82 changes: 79 additions & 3 deletions python/deltalake/_util.py
@@ -1,5 +1,8 @@
from datetime import date, datetime
from typing import Any
from typing import TYPE_CHECKING, Any, Optional, cast

if TYPE_CHECKING:
from deltalake.table import FilterType


def encode_partition_value(val: Any) -> str:
@@ -10,11 +13,84 @@ def encode_partition_value(val: Any) -> str:
return val
elif isinstance(val, (int, float)):
return str(val)
elif isinstance(val, date):
return val.isoformat()
elif isinstance(val, datetime):
return val.isoformat(sep=" ")
elif isinstance(val, date):
return val.isoformat()
elif isinstance(val, bytes):
return val.decode("unicode_escape", "backslashreplace")
else:
raise ValueError(f"Could not encode partition value for type: {val}")


def validate_filters(filters: Optional["FilterType"]) -> Optional["FilterType"]:
"""Ensures that the filters are a list of list of tuples in DNF format.

:param filters: Filters to be validated

Examples:
>>> validate_filters([("a", "=", 1), ("b", "=", 2)])
>>> validate_filters([[("a", "=", 1), ("b", "=", 2)], [("c", "=", 3)]])
>>> validate_filters([("a", "=", 1)])
>>> validate_filters([[("a", "=", 1)], [("b", "=", 2)], [("c", "=", 3)]])

"""
from deltalake.table import FilterType

if filters is None:
return None

if not isinstance(filters, list) or len(filters) == 0:
raise ValueError("Filters must be a non-empty list.")

if all(isinstance(item, tuple) and len(item) == 3 for item in filters):
return cast(FilterType, [filters])

elif all(
isinstance(conjunction, list)
and len(conjunction) > 0
and all(
isinstance(literal, tuple) and len(literal) == 3 for literal in conjunction
)
for conjunction in filters
):
if len(filters) == 0 or any(len(c) == 0 for c in filters):
raise ValueError("Malformed DNF")
return filters

else:
raise ValueError(
"Filters must be a list of tuples, or a list of lists of tuples"
)


def stringify_partition_values(
partition_filters: Optional["FilterType"],
) -> Optional["FilterType"]:
if partition_filters is None:
return None

if all(isinstance(item, tuple) for item in partition_filters):
return [ # type: ignore
(
field,
op,
[encode_partition_value(val) for val in value]
if isinstance(value, (list, tuple))
else encode_partition_value(value),
)
for field, op, value in partition_filters
]
return [
[
(
field,
op,
[encode_partition_value(val) for val in value]
if isinstance(value, (list, tuple))
else encode_partition_value(value),
)
for field, op, value in sublist
]
for sublist in partition_filters
]
179 changes: 51 additions & 128 deletions python/deltalake/table.py
@@ -1,9 +1,7 @@
import json
import operator
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import reduce
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -22,18 +20,19 @@
import pyarrow
import pyarrow.fs as pa_fs
from pyarrow.dataset import (
Expression,
FileSystemDataset,
Fragment,
ParquetFileFormat,
ParquetFragmentScanOptions,
ParquetReadOptions,
)
from pyarrow.parquet import filters_to_expression

if TYPE_CHECKING:
import pandas

from ._internal import RawDeltaTable
from ._util import encode_partition_value
from ._util import stringify_partition_values, validate_filters
from .data_catalog import DataCatalog
from .exceptions import DeltaProtocolError
from .fs import DeltaStorageHandler
@@ -107,96 +106,6 @@ class ProtocolVersions(NamedTuple):
FilterType = Union[FilterConjunctionType, FilterDNFType]


def _check_contains_null(value: Any) -> bool:
"""
Check if target contains nullish value.
"""
if isinstance(value, bytes):
for byte in value:
if isinstance(byte, bytes):
compare_to = chr(0)
else:
compare_to = 0
if byte == compare_to:
return True
elif isinstance(value, str):
return "\x00" in value
return False


def _check_dnf(
dnf: FilterDNFType,
check_null_strings: bool = True,
) -> FilterDNFType:
"""
Check if DNF are well-formed.
"""
if len(dnf) == 0 or any(len(c) == 0 for c in dnf):
raise ValueError("Malformed DNF")
if check_null_strings:
for conjunction in dnf:
for col, op, val in conjunction:
if (
isinstance(val, list)
and all(_check_contains_null(v) for v in val)
or _check_contains_null(val)
):
raise NotImplementedError(
"Null-terminated binary strings are not supported "
"as filter values."
)
return dnf


def _convert_single_predicate(column: str, op: str, value: Any) -> Expression:
"""
Convert given `tuple` to `pyarrow.dataset.Expression`.
"""
import pyarrow.dataset as ds

field = ds.field(column)
if op == "=" or op == "==":
return field == value
elif op == "!=":
return field != value
elif op == "<":
return field < value
elif op == ">":
return field > value
elif op == "<=":
return field <= value
elif op == ">=":
return field >= value
elif op == "in":
return field.isin(value)
elif op == "not in":
return ~field.isin(value)
else:
raise ValueError(
f'"{(column, op, value)}" is not a valid operator in predicates.'
)


def _filters_to_expression(filters: FilterType) -> Expression:
"""
Check if filters are well-formed and convert to an ``pyarrow.dataset.Expression``.
"""
if isinstance(filters[0][0], str):
# We have encountered the situation where we have one nesting level too few:
# We have [(,,), ..] instead of [[(,,), ..]]
dnf = cast(FilterDNFType, [filters])
else:
dnf = cast(FilterDNFType, filters)
dnf = _check_dnf(dnf, check_null_strings=False)
disjunction_members = []
for conjunction in dnf:
conjunction_members = [
_convert_single_predicate(col, op, val) for col, op, val in conjunction
]
disjunction_members.append(reduce(operator.and_, conjunction_members))
return reduce(operator.or_, disjunction_members)


_DNF_filter_doc = """
Predicates are expressed in disjunctive normal form (DNF), like [("x", "=", "a"), ...].
DNF allows arbitrary boolean logical combinations of single partition predicates.
@@ -300,7 +209,7 @@ def version(self) -> int:
def files(
self, partition_filters: Optional[List[Tuple[str, str, Any]]] = None
) -> List[str]:
return self._table.files(self.__stringify_partition_values(partition_filters))
return self._table.files(stringify_partition_values(partition_filters))

files.__doc__ = f"""
Get the .parquet files of the DeltaTable.
@@ -353,9 +262,7 @@ def files_by_partitions(
def file_uris(
self, partition_filters: Optional[List[Tuple[str, str, Any]]] = None
) -> List[str]:
return self._table.file_uris(
self.__stringify_partition_values(partition_filters)
)
return self._table.file_uris(stringify_partition_values(partition_filters))

file_uris.__doc__ = f"""
Get the list of files as absolute URIs, including the scheme (e.g. "s3://").
@@ -569,9 +476,48 @@ def restore(
)
return json.loads(metrics)

def _create_fragments(
self,
partitions: Optional[FilterType],
format: ParquetFileFormat,
filesystem: pa_fs.FileSystem,
) -> List[Fragment]:
"""Create Parquet fragments for the given partitions

:param partitions: A list of partition filters
:param format: The ParquetFileFormat to use
:param filesystem: The PyArrow FileSystem to use
"""

partition_filters: Optional[FilterType] = validate_filters(partitions)
partition_filters = stringify_partition_values(partition_filters)

fragments = []
if partition_filters is not None:
for partition in partition_filters:
partition = cast(FilterConjunctionType, partition)
for file, partition_expression in self._table.dataset_partitions(
schema=self.schema().to_pyarrow(), partition_filters=partition
):
fragments.append(
format.make_fragment(file, filesystem, partition_expression)
)
else:
for file, part_expression in self._table.dataset_partitions(
schema=self.schema().to_pyarrow(), partition_filters=partitions
):
fragments.append(
format.make_fragment(
file,
filesystem=filesystem,
partition_expression=part_expression,
)
)
return fragments

def to_pyarrow_dataset(
self,
partitions: Optional[List[Tuple[str, str, Any]]] = None,
partitions: Optional[FilterType] = None,
filesystem: Optional[Union[str, pa_fs.FileSystem]] = None,
parquet_read_options: Optional[ParquetReadOptions] = None,
) -> pyarrow.dataset.Dataset:
@@ -606,17 +552,6 @@ def to_pyarrow_dataset(
default_fragment_scan_options=ParquetFragmentScanOptions(pre_buffer=True),
)

fragments = [
format.make_fragment(
file,
filesystem=filesystem,
partition_expression=part_expression,
)
for file, part_expression in self._table.dataset_partitions(
self.schema().to_pyarrow(), partitions
)
]

schema = self.schema().to_pyarrow()

dictionary_columns = format.read_options.dictionary_columns or set()
Expand All @@ -628,11 +563,13 @@ def to_pyarrow_dataset(
)
schema = schema.set(index, dict_field)

fragments = self._create_fragments(partitions, format, filesystem)

return FileSystemDataset(fragments, schema, format, filesystem)

def to_pyarrow_table(
self,
partitions: Optional[List[Tuple[str, str, Any]]] = None,
partitions: Optional[FilterType] = None,
columns: Optional[List[str]] = None,
filesystem: Optional[Union[str, pa_fs.FileSystem]] = None,
filters: Optional[FilterType] = None,
@@ -646,14 +583,15 @@ def to_pyarrow_table(
:param filters: A disjunctive normal form (DNF) predicate for filtering rows. If you pass a filter you do not need to pass ``partitions``
"""
if filters is not None:
filters = _filters_to_expression(filters)
filters = validate_filters(filters)
filters = filters_to_expression(filters)
return self.to_pyarrow_dataset(
partitions=partitions, filesystem=filesystem
).to_table(columns=columns, filter=filters)

def to_pandas(
self,
partitions: Optional[List[Tuple[str, str, Any]]] = None,
partitions: Optional[FilterType] = None,
columns: Optional[List[str]] = None,
filesystem: Optional[Union[str, pa_fs.FileSystem]] = None,
filters: Optional[FilterType] = None,
@@ -683,21 +621,6 @@ def update_incremental(self) -> None:
def create_checkpoint(self) -> None:
self._table.create_checkpoint()

def __stringify_partition_values(
self, partition_filters: Optional[List[Tuple[str, str, Any]]]
) -> Optional[List[Tuple[str, str, Union[str, List[str]]]]]:
if partition_filters is None:
return partition_filters
out = []
for field, op, value in partition_filters:
str_value: Union[str, List[str]]
if isinstance(value, (list, tuple)):
str_value = [encode_partition_value(val) for val in value]
else:
str_value = encode_partition_value(value)
out.append((field, op, str_value))
return out

def get_add_actions(self, flatten: bool = False) -> pyarrow.RecordBatch:
"""
Return a dataframe with all current add actions.
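Design note: the hand-rolled _convert_single_predicate / _filters_to_expression helpers removed above are replaced by pyarrow.parquet.filters_to_expression. A rough equivalence, using hypothetical column names:

```python
import pyarrow.dataset as ds
from pyarrow.parquet import filters_to_expression

# DNF: (year == 2021 AND month in (1, 2)) OR year == 2022.
expr = filters_to_expression(
    [[("year", "=", 2021), ("month", "in", [1, 2])], [("year", "=", 2022)]]
)

# Roughly the same expression built by hand with the dataset API.
manual = ((ds.field("year") == 2021) & ds.field("month").isin([1, 2])) | (
    ds.field("year") == 2022
)
```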