This CL unifies split argument behavior between TfdsDataSource and FunctionDataSource. #283

Closed
wants to merge 1 commit into from
12 changes: 10 additions & 2 deletions seqio/dataset_providers.py
@@ -317,7 +317,7 @@ class FunctionDataSource(DataSource):

  def __init__(self,
               dataset_fn: DatasetFnCallable,
               splits: Iterable[str],
               splits: Union[Iterable[str], Mapping[str, str]],
               num_input_examples: Optional[Mapping[str, int]] = None,
               caching_permitted: bool = True):
    """FunctionDataSource constructor.
@@ -326,7 +326,8 @@ def __init__(self,
      dataset_fn: a function with the signature `dataset_fn(split,
        shuffle_files)` (and optionally the keyword argument `seed`) that
        returns a `tf.data.Dataset`.
      splits: an iterable of applicable string split names.
      splits: an iterable of applicable string split names, or a dict mapping
        provided split names to the split names passed to `dataset_fn`
        (e.g., {'train': 'dev', 'validation': 'test'}).
      num_input_examples: dict or None, an optional dictionary mapping split to
        its size in number of input examples (before preprocessing). The
        `num_input_examples` method will return None if not provided.
@@ -335,6 +336,10 @@ def __init__(self,
"""
_validate_args(dataset_fn, ["split", "shuffle_files"])
self._dataset_fn = dataset_fn
if isinstance(splits, dict):
self.split_map = splits
else:
self.split_map = None
super().__init__(
splits=splits,
num_input_examples=num_input_examples,
@@ -354,6 +359,9 @@ def get_dataset(self,
"`FunctionDataSource` does not support low-level sharding. Use "
"tf.data.Dataset.shard instead.")

if self.split_map:
split = self.split_map[split]

if seed is None:
ds = self._dataset_fn(split=split, shuffle_files=shuffle)
else:
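For illustration only, here is a minimal sketch, not part of this commit, of how the dict-valued `splits` argument behaves once the change above is applied. The loader `my_dataset_fn` is hypothetical, and the sketch assumes `seqio` re-exports `FunctionDataSource` and that `get_dataset(split, shuffle=...)` keeps its current signature.

import tensorflow as tf
import seqio


def my_dataset_fn(split, shuffle_files=False):
  # Hypothetical loader: a one-example dataset recording which underlying
  # split was actually requested.
  del shuffle_files
  return tf.data.Dataset.from_tensors({"underlying_split": split})


source = seqio.FunctionDataSource(
    dataset_fn=my_dataset_fn,
    # Expose "train" and "validation", but serve them from the loader's
    # "dev" and "test" splits, respectively.
    splits={"train": "dev", "validation": "test"})

print(sorted(source.splits))  # ['train', 'validation'], i.e. the dict keys.
ds = source.get_dataset("train", shuffle=False)  # dataset_fn gets split="dev".
print(next(iter(ds))["underlying_split"].numpy())  # b'dev'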
16 changes: 16 additions & 0 deletions seqio/dataset_providers_test.py
@@ -209,6 +209,22 @@ def predict_metric_fn_with_types(

  # pylint:enable=unused-argument

  def test_function_data_source_splits(self):
    def good_fn(split, shuffle_files):
      del split
      del shuffle_files

    self.assertSameElements(["train", "validation"],
                            dataset_providers.FunctionDataSource(
                                dataset_fn=good_fn,
                                splits=["train", "validation"]).splits)
    self.assertSameElements(["validation"],
                            dataset_providers.FunctionDataSource(
                                dataset_fn=good_fn,
                                splits={
                                    "validation": "train"
                                }).splits)

  def test_no_tfds_version(self):
    with self.assertRaisesWithLiteralMatch(
        ValueError, "TFDS name must contain a version number, got: fake"):
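A natural follow-up check, sketched below and not part of this commit, would assert that `get_dataset` remaps the exposed split name before calling `dataset_fn`. The method name is made up, and it assumes the surrounding test class plus the file's existing `tf` (tensorflow.compat.v2) and `dataset_providers` imports.

  def test_function_data_source_split_map_remaps(self):
    def recording_fn(split, shuffle_files):
      # A one-example dataset that records the split name dataset_fn received.
      del shuffle_files
      return tf.data.Dataset.from_tensors({"split": split})

    source = dataset_providers.FunctionDataSource(
        dataset_fn=recording_fn,
        splits={"validation": "train"})
    ds = source.get_dataset("validation", shuffle=False)
    self.assertEqual(b"train", next(iter(ds))["split"].numpy())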