Skip to content

Commit

Permalink
Supports builder_kwargs in TfdsDataSource
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636306489
  • Loading branch information
jimlinntu authored and SeqIO committed May 24, 2024
1 parent 5966127 commit fc65a0d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
4 changes: 4 additions & 0 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def __init__(
] = None,
caching_permitted: bool = True,
decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None,
tfds_builder_kwargs: Optional[dict[str, Any]] = None,
):
"""TfdsTask constructor.
Expand All @@ -514,6 +515,8 @@ def __init__(
Default True.
decoders: dict (optional), mapping from features to tfds.decode.Decoders,
such as tfds.decode.SkipDecoding() for skipping image byte decoding.
tfds_builder_kwargs: `dict` (optional), keyword arguments to be passed to
the `tfds.core.DatasetBuilder` constructor through `tfds.load()`.
"""
if splits and not isinstance(splits, dict):
splits = {k: k for k in splits}
Expand All @@ -523,6 +526,7 @@ def __init__(
data_dir=tfds_data_dir,
split_map=splits if isinstance(splits, dict) else None,
decoders=decoders,
builder_kwargs=tfds_builder_kwargs,
)

# If splits are not provided, we pass an empty tuple and use the lazy
Expand Down
15 changes: 14 additions & 1 deletion seqio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
data_dir: Optional[str] = None,
split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None,
decoders=None,
builder_kwargs: Optional[dict[str, Any]] = None,
):
"""LazyTfdsLoader constructor.
Expand All @@ -140,12 +141,16 @@ def __init__(
split='train')`). If `TfdsSplit` are used then `name` must be empty.
decoders: dict (optional), mapping from features to tfds.decode.Decoders,
such as tfds.decode.SkipDecoding() for skipping image byte decoding.
builder_kwargs: `dict` (optional), keyword arguments to be passed to the
`tfds.core.DatasetBuilder` constructor through `tfds.load()` and
`tfds.builder()`.
"""
_validate_tfds_name(name)
self._name = name
self._data_dir = data_dir
self._split_map = split_map
self._decoders = decoders
self._builder_kwargs = builder_kwargs

self._is_custom_split_map = False
if split_map:
Expand Down Expand Up @@ -302,8 +307,15 @@ def _get_builder(self, split: Optional[str] = None):
builder_key = self._get_builder_key(dataset, data_dir)
if builder_key not in LazyTfdsLoader._MEMOIZED_BUILDERS:
if dataset:
builder = tfds.builder(dataset, data_dir=data_dir)
builder = tfds.builder(
dataset, data_dir=data_dir, **self._builder_kwargs
)
else:
if self._builder_kwargs:
raise ValueError(
"`builder_kwargs` should be empty when `dataset` value is not"
" present."
)
builder = tfds.builder_from_directory(data_dir)
LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = builder
return LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key]
Expand Down Expand Up @@ -374,6 +386,7 @@ def load(
try_gcs=True,
read_config=read_config,
decoders=self._decoders,
builder_kwargs=self._builder_kwargs,
)

def load_shard(
Expand Down
2 changes: 2 additions & 0 deletions seqio/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def test_split_map(self, mock_tfds_load):
try_gcs=True,
read_config=AnyArg(),
decoders=None,
builder_kwargs=None,
)

# test .size()
Expand Down Expand Up @@ -238,6 +239,7 @@ def test_tfds_split(self, mock_tfds_load):
try_gcs=True,
read_config=AnyArg(),
decoders=None,
builder_kwargs=None,
)

# test .size()
Expand Down

0 comments on commit fc65a0d

Please sign in to comment.