diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index c71ad75f..4d81b98b 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -493,6 +493,7 @@ def __init__( ] = None, caching_permitted: bool = True, decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None, + tfds_builder_kwargs: Optional[dict[str, Any]] = None, ): """TfdsTask constructor. @@ -514,6 +515,9 @@ def __init__( Default True. decoders: dict (optional), mapping from features to tfds.decode.Decoders, such as tfds.decode.SkipDecoding() for skipping image byte decoding. + tfds_builder_kwargs: `dict` (optional), keyword arguments to be passed to + the `tfds.core.DatasetBuilder` constructor through `tfds.load()` and + `tfds.builder()`. """ if splits and not isinstance(splits, dict): splits = {k: k for k in splits} @@ -523,6 +527,7 @@ def __init__( data_dir=tfds_data_dir, split_map=splits if isinstance(splits, dict) else None, decoders=decoders, + builder_kwargs=tfds_builder_kwargs, ) # If splits are not provided, we pass an empty tuple and use the lazy diff --git a/seqio/utils.py b/seqio/utils.py index 68d2cb15..603eae78 100644 --- a/seqio/utils.py +++ b/seqio/utils.py @@ -127,6 +127,7 @@ def __init__( data_dir: Optional[str] = None, split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None, decoders=None, + builder_kwargs: Optional[dict[str, Any]] = None, ): """LazyTfdsLoader constructor. @@ -140,12 +141,16 @@ def __init__( split='train')`). If `TfdsSplit` are used then `name` must be empty. decoders: dict (optional), mapping from features to tfds.decode.Decoders, such as tfds.decode.SkipDecoding() for skipping image byte decoding. + builder_kwargs: `dict` (optional), keyword arguments to be passed to the + `tfds.core.DatasetBuilder` constructor through `tfds.load()` and + `tfds.builder()`. """ _validate_tfds_name(name) self._name = name self._data_dir = data_dir self._split_map = split_map self._decoders = decoders + self._builder_kwargs = builder_kwargs self._is_custom_split_map = False if split_map: @@ -302,8 +307,16 @@ def _get_builder(self, split: Optional[str] = None): builder_key = self._get_builder_key(dataset, data_dir) if builder_key not in LazyTfdsLoader._MEMOIZED_BUILDERS: if dataset: - builder = tfds.builder(dataset, data_dir=data_dir) + builder_kwargs = self._builder_kwargs if self._builder_kwargs else {} + builder = tfds.builder( + dataset, data_dir=data_dir, **builder_kwargs + ) else: + if self._builder_kwargs: + raise ValueError( + "`builder_kwargs` should be empty when `dataset` value is not" + " present." + ) builder = tfds.builder_from_directory(data_dir) LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = builder return LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] @@ -374,6 +387,7 @@ def load( try_gcs=True, read_config=read_config, decoders=self._decoders, + builder_kwargs=self._builder_kwargs, ) def load_shard( diff --git a/seqio/utils_test.py b/seqio/utils_test.py index 5154e06f..a7ca73f9 100644 --- a/seqio/utils_test.py +++ b/seqio/utils_test.py @@ -178,6 +178,7 @@ def test_split_map(self, mock_tfds_load): try_gcs=True, read_config=AnyArg(), decoders=None, + builder_kwargs=None, ) # test .size() @@ -238,6 +239,7 @@ def test_tfds_split(self, mock_tfds_load): try_gcs=True, read_config=AnyArg(), decoders=None, + builder_kwargs=None, ) # test .size()