Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Callable, Iterator, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property
from itertools import batched
from types import MappingProxyType

from pyiceberg.catalog import Catalog
Expand Down Expand Up @@ -115,6 +116,7 @@ def __init__(
context: DataLoaderContext | None = None,
max_attempts: int = 3,
batch_size: int | None = None,
files_per_split: int = 1,
):
"""
Args:
Expand All @@ -131,11 +133,15 @@ def __init__(
Passed to PyArrow's Scanner which produces batches of at most this many
rows. Smaller values reduce peak memory but increase per-batch overhead.
None uses the PyArrow default (~131K rows).
files_per_split: Number of files each split reads concurrently.
Default is 1 (one file per split).
"""
if branch is not None and branch.strip() == "":
raise ValueError("branch must not be empty or whitespace")
if branch is not None and snapshot_id is not None:
raise ValueError("Cannot specify both branch and snapshot_id")
if files_per_split < 1:
raise ValueError("files_per_split must be at least 1")
self._catalog = catalog
self._table_id = TableIdentifier(database, table, branch)
self._snapshot_id = snapshot_id
Expand All @@ -144,6 +150,7 @@ def __init__(
self._context = context or DataLoaderContext()
self._max_attempts = max_attempts
self._batch_size = batch_size
self._files_per_split = files_per_split

if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None:
apply_libhdfs_opts(self._context.jvm_config.planner_args)
Expand Down Expand Up @@ -260,9 +267,9 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
lambda: scan.plan_files(), label=f"plan_files {self._table_id}", max_attempts=self._max_attempts
)

for scan_task in scan_tasks:
for chunk in batched(scan_tasks, self._files_per_split):
yield DataLoaderSplit(
file_scan_task=scan_task,
file_scan_tasks=chunk,
scan_context=scan_context,
transform_sql=optimized_sql,
udf_registry=self._context.udf_registry,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import hashlib
from collections.abc import Iterator, Mapping
from collections.abc import Iterator, Mapping, Sequence
from types import MappingProxyType

from datafusion.context import SessionContext
Expand Down Expand Up @@ -45,17 +45,19 @@ def _bind_batch_table(session: SessionContext, table_id: TableIdentifier, batch:


class DataLoaderSplit:
"""A single data split"""
"""A data split that reads one or more files."""

def __init__(
self,
file_scan_task: FileScanTask,
file_scan_tasks: Sequence[FileScanTask],
scan_context: TableScanContext,
transform_sql: str | None = None,
udf_registry: UDFRegistry | None = None,
batch_size: int | None = None,
):
self._file_scan_task = file_scan_task
self._file_scan_tasks = list(file_scan_tasks)
if not self._file_scan_tasks:
raise ValueError("file_scan_tasks must not be empty")
self._scan_context = scan_context
self._transform_sql = transform_sql
self._udf_registry = udf_registry or NoOpRegistry()
Expand All @@ -66,20 +68,19 @@ def id(self) -> str:
"""Unique ID for the split. This is stable across executions for a given
snapshot and split size.
"""
file_path = self._file_scan_task.file.file_path
return hashlib.sha256(file_path.encode("utf-8")).hexdigest()
paths = sorted(t.file.file_path for t in self._file_scan_tasks)
combined = "\0".join(paths)
return hashlib.sha256(combined.encode("utf-8")).hexdigest()

@property
def table_properties(self) -> Mapping[str, str]:
    """Read-only view of the properties of the table being loaded."""
    props = self._scan_context.table_metadata.properties
    return MappingProxyType(props)

def __iter__(self) -> Iterator[RecordBatch]:
"""Reads the file scan task and yields Arrow RecordBatches.
"""Reads the file scan tasks and yields Arrow RecordBatches.

Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution,
delete files, and partition spec lookups. The number of batches loaded
into memory at once is bounded to prevent using too much memory at once.
When the split contains multiple files, they are read concurrently.
"""
ctx = self._scan_context
if ctx.worker_jvm_args is not None:
Expand All @@ -92,8 +93,8 @@ def __iter__(self) -> Iterator[RecordBatch]:
)

batches = arrow_scan.to_record_batches(
[self._file_scan_task],
order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size),
self._file_scan_tasks,
order=ArrivalOrder(concurrent_streams=len(self._file_scan_tasks), batch_size=self._batch_size),
)

if self._transform_sql is None:
Expand Down
53 changes: 51 additions & 2 deletions integrations/python/dataloader/tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,12 @@ def fake_scan(**kwargs):
# Without branch: splits come from main snapshot
main_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl"))
assert len(main_splits) == 1
assert main_splits[0]._file_scan_task.file.file_path == "main.parquet"
assert main_splits[0]._file_scan_tasks[0].file.file_path == "main.parquet"

# With branch: splits come from branch snapshot
branch_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="my-branch"))
assert len(branch_splits) == 1
assert branch_splits[0]._file_scan_task.file.file_path == "branch.parquet"
assert branch_splits[0]._file_scan_tasks[0].file.file_path == "branch.parquet"


# --- batch_size tests ---
Expand Down Expand Up @@ -594,6 +594,55 @@ def test_batch_size_default_is_none(tmp_path):
assert split._batch_size is None


# --- files_per_split tests ---


def _add_file_tasks(catalog, num_tasks: int) -> None:
    """Override plan_files on a catalog from _make_real_catalog to return multiple mock tasks."""
    mock_table = catalog.load_table.return_value
    delegate_scan = mock_table.scan.side_effect

    def scan_with_fake_files(**kwargs):
        # Delegate to the original scan factory, then swap in fabricated file tasks.
        scan = delegate_scan(**kwargs)
        fake_tasks = []
        for idx in range(num_tasks):
            fake_tasks.append(MagicMock(file=MagicMock(file_path=f"file_{idx}.parquet")))
        scan.plan_files.return_value = fake_tasks
        return scan

    mock_table.scan.side_effect = scan_with_fake_files


def test_files_per_split_groups_tasks(tmp_path):
    """files_per_split=2 groups 4 files into 2 splits of 2 files each."""
    catalog = _make_real_catalog(tmp_path)
    _add_file_tasks(catalog, 4)
    splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=2))

    # Two splits, each carrying exactly two file scan tasks.
    assert [len(s._file_scan_tasks) for s in splits] == [2, 2]


def test_files_per_split_remainder_split(tmp_path):
    """When files don't divide evenly, the last split gets the remainder."""
    catalog = _make_real_catalog(tmp_path)
    _add_file_tasks(catalog, 5)
    splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", files_per_split=3))

    # 5 files at 3 per split: a full split of 3 followed by a remainder split of 2.
    assert [len(s._file_scan_tasks) for s in splits] == [3, 2]


def test_files_per_split_invalid_raises():
    """files_per_split < 1 raises ValueError."""
    with pytest.raises(ValueError, match="files_per_split must be at least 1"):
        OpenHouseDataLoader(catalog=MagicMock(), database="db", table="tbl", files_per_split=0)


# --- Predicate pushdown with transformer tests ---


Expand Down
29 changes: 28 additions & 1 deletion integrations/python/dataloader/tests/test_data_loader_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _create_test_split(
task = FileScanTask(data_file=data_file)

return DataLoaderSplit(
file_scan_task=task,
file_scan_tasks=[task],
scan_context=scan_context,
transform_sql=transform_sql,
udf_registry=udf_registry,
Expand Down Expand Up @@ -468,3 +468,30 @@ def test_split_batch_size_preserves_data(tmp_path):
result = pa.Table.from_batches(list(split))
assert result.num_rows == 25
assert sorted(result.column("id").to_pylist()) == list(range(25))


# --- multi-file split tests ---


def test_multi_file_split_returns_all_rows(tmp_path):
    """A split with multiple files yields rows from all files."""
    schema = _BATCH_SCHEMA
    first = _create_test_split(
        tmp_path,
        pa.table({"id": pa.array([1, 2, 3], type=pa.int64())}),
        FileFormat.PARQUET,
        schema,
        filename="a.parquet",
    )
    second = _create_test_split(
        tmp_path,
        pa.table({"id": pa.array([4, 5, 6], type=pa.int64())}),
        FileFormat.PARQUET,
        schema,
        filename="b.parquet",
    )

    combined = DataLoaderSplit(
        file_scan_tasks=first._file_scan_tasks + second._file_scan_tasks,
        scan_context=first._scan_context,
    )
    rows = pa.Table.from_batches(list(combined))

    assert rows.num_rows == 6
    assert sorted(rows.column("id").to_pylist()) == [1, 2, 3, 4, 5, 6]

    # The split id must not depend on the order the file tasks were supplied in.
    flipped = DataLoaderSplit(
        file_scan_tasks=second._file_scan_tasks + first._file_scan_tasks,
        scan_context=first._scan_context,
    )
    assert flipped.id == combined.id