Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion grain/_src/python/dataset/transformations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import math
import pprint
import sys
from typing import Any, Callable, TypeVar, cast
from typing import Any, Callable, Generic, Protocol, TypeVar, cast, runtime_checkable

from grain._src.core import tree_lib
from grain._src.python.dataset import base
Expand Down Expand Up @@ -306,6 +306,24 @@ def __str__(self) -> str:
)


@runtime_checkable
class BatchFn(Protocol, Generic[S, T]):
"""Custom batch function that support element spec inference.

If you need a custom batch function with `ds.batch(batch_fn=...)`, you can
implement this protocol to allow `batch` to infer the element spec of the
batched dataset. If not implemented, the output element spec will be unknown.
"""

def __call__(self, elements: Sequence[S]) -> T:
"""Batches elements."""

def output_spec(
self, input_spec: Any, batch_size: int, drop_remainder: bool
) -> Any:
"""Returns the element spec for batches produced by this function."""


def _get_batch_element_spec(
input_spec: Any,
batch_size: int,
Expand All @@ -317,6 +335,10 @@ def _get_batch_element_spec(
wrapped_batch_fn = batch_fn
if isinstance(batch_fn, functools.partial):
wrapped_batch_fn = batch_fn.func

if isinstance(wrapped_batch_fn, BatchFn):
return wrapped_batch_fn.output_spec(input_spec, batch_size, drop_remainder)

if wrapped_batch_fn is not make_batch and not isinstance(
wrapped_batch_fn, _MakeBatchParallel
):
Expand Down
33 changes: 32 additions & 1 deletion grain/_src/python/dataset/transformations/batch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import importlib
import sys
from typing import Any
from typing import Any, Sequence
from unittest import mock

from absl.testing import absltest
Expand Down Expand Up @@ -57,6 +57,25 @@ def output_spec(self, input_spec: Any) -> Any:
}


class CustomBatchFn(batch.BatchFn):

def __call__(self, elements: Sequence[Any]) -> dict[str, Any]:
return {"batch": batch.make_batch(elements)}

def output_spec(
self, input_spec: Any, batch_size: int, drop_remainder: bool
) -> Any:
batch_dim = batch_size if drop_remainder else None
return {
"batch": tree_lib.map_structure(
lambda s: base.ShapeDtypeStruct(
shape=(batch_dim,) + s.shape, dtype=s.dtype
),
input_spec,
)
}


class MakeBatchTest(absltest.TestCase):

def test_batch_zero_values_error(self):
Expand Down Expand Up @@ -554,6 +573,18 @@ def test_element_spec(
self.assertEqual(spec.shape, expected_shape)
self.assertEqual(spec.dtype, np.int64)

def test_element_spec_custom_batch_fn(self):
ds = dataset.MapDataset.range(0, 10)
batch_size = 3
ds = batch.BatchMapDataset(
ds, batch_size, drop_remainder=True, batch_fn=CustomBatchFn()
)
spec = dataset.get_element_spec(ds)
self.assertEqual(
spec,
{"batch": base.ShapeDtypeStruct(shape=(batch_size,), dtype=np.int64)},
)


class BatchIterDatasetTest(parameterized.TestCase):

Expand Down
Loading