Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supports builder_kwargs in TfdsDataSource #745

Merged
Merged 1 commit on May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def __init__(
] = None,
caching_permitted: bool = True,
decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None,
tfds_builder_kwargs: Optional[dict[str, Any]] = None,
):
"""TfdsTask constructor.

Expand All @@ -514,6 +515,9 @@ def __init__(
Default True.
decoders: dict (optional), mapping from features to tfds.decode.Decoders,
such as tfds.decode.SkipDecoding() for skipping image byte decoding.
tfds_builder_kwargs: `dict` (optional), keyword arguments to be passed to
the `tfds.core.DatasetBuilder` constructor through `tfds.load()` and
`tfds.builder()`.
"""
if splits and not isinstance(splits, dict):
splits = {k: k for k in splits}
Expand All @@ -523,6 +527,7 @@ def __init__(
data_dir=tfds_data_dir,
split_map=splits if isinstance(splits, dict) else None,
decoders=decoders,
builder_kwargs=tfds_builder_kwargs,
)

# If splits are not provided, we pass an empty tuple and use the lazy
Expand Down
16 changes: 15 additions & 1 deletion seqio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
data_dir: Optional[str] = None,
split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None,
decoders=None,
builder_kwargs: Optional[dict[str, Any]] = None,
):
"""LazyTfdsLoader constructor.

Expand All @@ -140,12 +141,16 @@ def __init__(
split='train')`). If `TfdsSplit` are used then `name` must be empty.
decoders: dict (optional), mapping from features to tfds.decode.Decoders,
such as tfds.decode.SkipDecoding() for skipping image byte decoding.
builder_kwargs: `dict` (optional), keyword arguments to be passed to the
`tfds.core.DatasetBuilder` constructor through `tfds.load()` and
`tfds.builder()`.
"""
_validate_tfds_name(name)
self._name = name
self._data_dir = data_dir
self._split_map = split_map
self._decoders = decoders
self._builder_kwargs = builder_kwargs

self._is_custom_split_map = False
if split_map:
Expand Down Expand Up @@ -302,8 +307,16 @@ def _get_builder(self, split: Optional[str] = None):
builder_key = self._get_builder_key(dataset, data_dir)
if builder_key not in LazyTfdsLoader._MEMOIZED_BUILDERS:
if dataset:
builder = tfds.builder(dataset, data_dir=data_dir)
builder_kwargs = self._builder_kwargs if self._builder_kwargs else {}
builder = tfds.builder(
dataset, data_dir=data_dir, **builder_kwargs
)
else:
if self._builder_kwargs:
raise ValueError(
"`builder_kwargs` should be empty when `dataset` value is not"
" present."
)
builder = tfds.builder_from_directory(data_dir)
LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = builder
return LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key]
Expand Down Expand Up @@ -374,6 +387,7 @@ def load(
try_gcs=True,
read_config=read_config,
decoders=self._decoders,
builder_kwargs=self._builder_kwargs,
)

def load_shard(
Expand Down
2 changes: 2 additions & 0 deletions seqio/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def test_split_map(self, mock_tfds_load):
try_gcs=True,
read_config=AnyArg(),
decoders=None,
builder_kwargs=None,
)

# test .size()
Expand Down Expand Up @@ -238,6 +239,7 @@ def test_tfds_split(self, mock_tfds_load):
try_gcs=True,
read_config=AnyArg(),
decoders=None,
builder_kwargs=None,
)

# test .size()
Expand Down
Loading