Allow Task source to be any DatasetProvider.
PiperOrigin-RevId: 478037350
broken authored and SeqIO committed Mar 15, 2023
1 parent 9986307 commit a020633
Showing 14 changed files with 415 additions and 17 deletions.
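
In effect, a Task's `source` can now be another `Task` or a `Mixture`, not only a raw `DataSource`. Below is a minimal sketch of the new capability, assuming a task named "base_task" has already been registered; the names and preprocessor choice are illustrative, not taken from the commit:

import seqio

# Hypothetical setup: "base_task" is assumed to be registered elsewhere.
base_task = seqio.TaskRegistry.get("base_task")

# After this change, `source` may be any DatasetProvider, including a Task
# or a Mixture, rather than only a DataSource.
derived_task = seqio.Task(
    "derived_task",
    source=base_task,
    output_features=base_task.output_features,
    preprocessors=[seqio.preprocessors.append_eos_after_trim],
)

ds = derived_task.get_dataset(
    sequence_length={"inputs": 64, "targets": 64}, split="train"
)

This lets one task layer extra preprocessing on top of another task's fully preprocessed output instead of duplicating the upstream pipeline.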
132 changes: 116 additions & 16 deletions seqio/dataset_providers.py
@@ -72,14 +72,73 @@ class ShardInfo:
num_shards: int


class DatasetProviderBase(metaclass=abc.ABCMeta):
class DatasetProvider(metaclass=abc.ABCMeta):
"""Interface for classes that provide a tf.data.Dataset."""

@property
@abc.abstractmethod
def output_features(self) -> Mapping[str, Feature]:
raise NotImplementedError

@property
@abc.abstractmethod
def splits(self) -> Sequence[str]:
raise NotImplementedError

@abc.abstractmethod
def get_dataset(
self,
sequence_length: Optional[Mapping[str, int]] = None,
split: str = tfds.Split.TRAIN,
use_cached: bool = False,
shuffle: bool = True,
seed: Optional[int] = None,
shard_info: Optional[ShardInfo] = None,
num_epochs: Optional[int] = 1,
) -> tf.data.Dataset:
"""Returns the requested tf.data.Dataset."""
raise NotImplementedError

@abc.abstractmethod
def num_input_examples(self, split: str) -> Optional[int]:
raise NotImplementedError

@property
@abc.abstractmethod
def caching_permitted(self) -> bool:
"""Indicates whether this dataset provider may be cached.
Caching may be prohibited for the sake of data versioning rigor or as a
matter of policy for certain datasets.
"""
return NotImplementedError

@property
@abc.abstractmethod
def supports_arbitrary_sharding(self) -> bool:
return NotImplementedError

@property
@abc.abstractmethod
def cache_dir(self) -> Optional[str]:
return NotImplementedError

@abc.abstractmethod
def list_shards(self, split: str) -> Sequence[str]:
raise NotImplementedError



class DatasetProviderBase(DatasetProvider, metaclass=abc.ABCMeta):
"""Abstract base for classes that provide a tf.data.Dataset."""

@abc.abstractproperty
@property
@abc.abstractmethod
def output_features(self) -> Mapping[str, Feature]:
raise NotImplementedError

@abc.abstractproperty
@property
@abc.abstractmethod
def splits(self) -> Sequence[str]:
raise NotImplementedError

@@ -101,6 +160,27 @@ def get_dataset(
def num_input_examples(self, split: str) -> Optional[int]:
raise NotImplementedError

@property
def caching_permitted(self) -> bool:
"""Indicates whether this dataset provider may be cached.
Caching may be prohibited for the sake of data versioning rigor or as a
matter of policy for certain datasets.
"""
return True

@property
def supports_arbitrary_sharding(self) -> bool:
return True

@property
def cache_dir(self) -> Optional[str]:
return None

def list_shards(self, split: str) -> Sequence[str]:
raise NotImplementedError



class DatasetProviderRegistry(object):
"""Base for registry of data providers.
@@ -110,8 +190,8 @@ class DatasetProviderRegistry(object):
"""

# Class variables must be defined in subclasses.
_REGISTRY: MutableMapping[str, DatasetProviderBase]
_PROVIDER_TYPE: Type[DatasetProviderBase]
_REGISTRY: MutableMapping[str, DatasetProvider]
_PROVIDER_TYPE: Type[DatasetProvider]

@classmethod
def add_provider(cls, name: str, provider):
@@ -228,7 +308,7 @@ def get_dataset(
class DataSource(DatasetProviderBase):
"""A `DatasetProvider` that provides raw data from an input source.
Inherits all abstract methods and properties of `DatasetProviderBase` except
Inherits all abstract methods and properties of `DatasetProvider` except
those overridden below.
"""

@@ -264,7 +344,7 @@ def supports_arbitrary_sharding(self) -> bool:

@property
def output_features(self) -> Mapping[str, Feature]:
"""Override unused property of `DatasetProviderBase`."""
"""Override unused property of `DatasetProvider`."""
raise NotImplementedError

@abc.abstractmethod
@@ -887,7 +967,7 @@ class Task(DatasetProviderBase):
def __init__(
self,
name: str,
source: DataSource,
source: DatasetProvider,
output_features: Mapping[str, Feature],
preprocessors: Optional[Sequence[Callable[..., tf.data.Dataset]]] = None,
postprocess_fn: Optional[Callable[..., Any]] = None,
@@ -899,7 +979,7 @@ def __init__(
Args:
name: a unique name for the Task.
source: a `DataSource` that provides a raw `tf.data.Dataset`.
source: a `DatasetProvider` that provides a raw `tf.data.Dataset`.
output_features: dict(str, Feature), output features of the Task to be
passed to the model. After preprocessing, examples will be validated to
ensure they include features that match this specification. Note that
@@ -1074,7 +1154,7 @@ def splits(self) -> Sequence[str]:
return s

@property
def source(self) -> DataSource:
def source(self) -> DatasetProvider:
return self._source

def _validate_preprocessors(self):
@@ -1344,7 +1424,7 @@ def get_dataset(
seed: Optional[int] = None,
shard_info: Optional[ShardInfo] = None,
num_epochs: Optional[int] = 1,
trim_output_features: bool = True, # Unique to Task
trim_output_features: bool = True, # unique to Task & Mixture
) -> tf.data.Dataset:
"""Returns a tf.data.Dataset from cache or generated on the fly.
@@ -1413,12 +1493,32 @@ def get_dataset(
shard_data_source = True
shard_info = None

# If source is a Task or Mixture, we will have additional arguments to pass.
# This can be removed, and the arguments added to the call directly, once we
# are certain all DataSources implement the full get_dataset signature.
kwargs = {}
if isinstance(source, Task):
kwargs["sequence_length"] = sequence_length
kwargs["use_cached"] = use_cached
kwargs["shuffle_buffer_size"] = shuffle_buffer_size
kwargs["num_epochs"] = num_epochs
kwargs["trim_output_features"] = trim_output_features
elif isinstance(source, Mixture):
kwargs["sequence_length"] = sequence_length
kwargs["use_cached"] = use_cached
kwargs["num_epochs"] = num_epochs
kwargs["trim_output_features"] = trim_output_features

if shard_data_source:
ds = source.get_dataset(
split=split, shuffle=shuffle, seed=seed, shard_info=shard_info
split=split,
shuffle=shuffle,
seed=seed,
shard_info=shard_info,
**kwargs,
)
else:
ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed, **kwargs)
ds = ds.shard(shard_info.num_shards, shard_info.index)

if (
@@ -1466,7 +1566,7 @@ def get_dataset(

def _get_cached_source(
self, split: str, file_shuffle_buffer_size: Optional[int] = None
) -> _CachedDataSource:
) -> DatasetProvider:
"""Returns a DataSource to read cached files for split."""
self.assert_cached()
file_shuffle_buffer_size = (
@@ -1502,7 +1602,7 @@ class TaskRegistry(DatasetProviderRegistry):
def add(
cls,
name: str,
source: DataSourceInterface,
source: DatasetProvider,
output_features: Mapping[str, Feature],
preprocessors: Optional[Sequence[Callable[..., tf.data.Dataset]]] = None,
postprocess_fn: Optional[Callable[..., Any]] = None,
@@ -2146,7 +2246,7 @@ def get_dataset(

mixture_or_task = (
get_mixture_or_task(mixture_or_task_name)
if not isinstance(mixture_or_task_name, DatasetProviderBase)
if not isinstance(mixture_or_task_name, DatasetProvider)
else mixture_or_task_name
)
is_grain_task = False
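For orientation, here is a minimal concrete implementation of the `DatasetProvider` interface introduced above. It is a sketch under stated assumptions: the class name and in-memory data are made up, and it assumes the interface is exported at the package level as `seqio.DatasetProvider`.

from typing import Mapping, Optional, Sequence

import seqio
import tensorflow as tf


class InMemorySource(seqio.DatasetProvider):
  """Illustrative provider that serves a fixed list of string examples."""

  def __init__(self, examples: Sequence[Mapping[str, str]]):
    self._examples = list(examples)

  @property
  def output_features(self) -> Mapping[str, seqio.Feature]:
    return {}  # a raw source has no tokenized output features

  @property
  def splits(self) -> Sequence[str]:
    return ("train",)

  @property
  def caching_permitted(self) -> bool:
    return True

  @property
  def supports_arbitrary_sharding(self) -> bool:
    return True

  @property
  def cache_dir(self) -> Optional[str]:
    return None

  def num_input_examples(self, split: str) -> Optional[int]:
    return len(self._examples)

  def list_shards(self, split: str) -> Sequence[str]:
    return [split]  # everything lives in a single logical shard

  def get_dataset(self, sequence_length=None, split="train", use_cached=False,
                  shuffle=True, seed=None, shard_info=None, num_epochs=1):
    ds = tf.data.Dataset.from_tensor_slices({
        "inputs": [ex["inputs"] for ex in self._examples],
        "targets": [ex["targets"] for ex in self._examples],
    })
    if shard_info is not None:
      ds = ds.shard(shard_info.num_shards, shard_info.index)
    if shuffle:
      ds = ds.shuffle(len(self._examples), seed=seed)
    return ds.repeat(num_epochs)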
61 changes: 61 additions & 0 deletions seqio/dataset_providers_test.py
@@ -492,6 +492,67 @@ def test_get_dataset_no_truncation(self):
"uncached_task", use_cached=False, sequence_length=None
)

def test_task_with_task_source(self):
self.verify_task_matches_fake_datasets(
"task_with_task_source", use_cached=False
)

# Test with cache
self.verify_task_matches_fake_datasets(
"task_with_task_source", use_cached=True
)

# Test with token preprocessor.
preproc_task_with_task_source = self.task_with_task_source.replace(
preprocessors=(
test_utils.test_token_preprocessor,
)
)
self.verify_task_matches_fake_datasets(
task=preproc_task_with_task_source, token_preprocessed=True
)

# Test with token preprocessor on source.
task_with_preproc_task_source = self.task_with_task_source.source.replace(
preprocessors=(
self.DEFAULT_PREPROCESSORS + (test_utils.test_token_preprocessor,)
)
)
self.verify_task_matches_fake_datasets(
task=task_with_preproc_task_source, token_preprocessed=True
)

def test_task_with_mixture_source(self):
self.verify_task_matches_fake_datasets(
"task_with_mixture_source", use_cached=False
)

# Test with cache
self.verify_task_matches_fake_datasets(
"task_with_mixture_source", use_cached=True
)

# Test with token preprocessor.
preproc_task_with_mixture_source = self.task_with_mixture_source.replace(
preprocessors=(test_utils.test_token_preprocessor,)
)
self.verify_task_matches_fake_datasets(
task=preproc_task_with_mixture_source, token_preprocessed=True
)

# Test with token preprocessor on mixture's task source.
task_with_preproc_mixture_source = (
self.task_with_mixture_source.source.tasks[0].replace(
preprocessors=(
self.DEFAULT_PREPROCESSORS
+ (test_utils.test_token_preprocessor,)
)
)
)
self.verify_task_matches_fake_datasets(
task=task_with_preproc_mixture_source, token_preprocessed=True
)

def test_sharding(self):
for i in range(3):
self.verify_task_matches_fake_datasets(
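The fixtures exercised above ("task_with_task_source", "task_with_mixture_source") could plausibly be registered along the following lines; the vocabulary, inner task name, and rate are illustrative guesses rather than the actual test_utils setup.

vocab = seqio.PassThroughVocabulary(size=128)
features = {
    "inputs": seqio.Feature(vocab),
    "targets": seqio.Feature(vocab),
}

# A Task whose source is another (assumed pre-registered) Task.
seqio.TaskRegistry.add(
    "task_with_task_source",
    source=seqio.TaskRegistry.get("tfds_task"),
    output_features=features,
)

# A Task whose source is a Mixture over that same Task.
mixture = seqio.Mixture("mixture_source", tasks=["tfds_task"], default_rate=1.0)
seqio.TaskRegistry.add(
    "task_with_mixture_source", source=mixture, output_features=features,
)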
2 changes: 1 addition & 1 deletion seqio/helpers.py
@@ -173,7 +173,7 @@ class TruncatedDatasetProvider(dp.DataSource):

def __init__(
self,
child: dp.DataSource,
child: dp.DatasetProvider,
split_sizes: Mapping[str, int],
shuffle_buffer_size: Optional[int] = None,
):
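With the widened annotation, `TruncatedDatasetProvider` can wrap any `DatasetProvider`, e.g. a whole Task rather than just its underlying DataSource. A hedged usage sketch, where the task name and split sizes are made up:

import seqio
from seqio import helpers

truncated = helpers.TruncatedDatasetProvider(
    child=seqio.TaskRegistry.get("base_task"),  # any DatasetProvider now works
    split_sizes={"train": 128, "validation": 16},
)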
Empty file.
26 changes: 26 additions & 0 deletions seqio/test_data/task_with_mixture_source/info.train.json
@@ -0,0 +1,26 @@
{
"features": {
"inputs": {
"dtype": "int32",
"shape": [
null
]
},
"inputs_pretokenized": {
"dtype": "string",
"shape": []
},
"targets": {
"dtype": "int32",
"shape": [
null
]
},
"targets_pretokenized": {
"dtype": "string",
"shape": []
}
},
"num_shards": 2,
"seqio_version": "0.0.0"
}
46 changes: 46 additions & 0 deletions seqio/test_data/task_with_mixture_source/info.validation.json
@@ -0,0 +1,46 @@
{
"features": {
"id": {
"dtype": "string",
"shape": []
},
"ids": {
"dtype": "string",
"shape": [
null
]
},
"idx": {
"dtype": "int64",
"shape": []
},
"idxs": {
"dtype": "int32",
"shape": [
null
]
},
"inputs": {
"dtype": "int32",
"shape": [
null
]
},
"inputs_pretokenized": {
"dtype": "string",
"shape": []
},
"targets": {
"dtype": "int32",
"shape": [
null
]
},
"targets_pretokenized": {
"dtype": "string",
"shape": []
}
},
"num_shards": 1,
"seqio_version": "0.0.0"
}
9 changes: 9 additions & 0 deletions seqio/test_data/task_with_mixture_source/stats.train.json
@@ -0,0 +1,9 @@
{
"examples": 3,
"inputs_chars": 43,
"inputs_max_tokens": 13,
"inputs_tokens": 36,
"targets_chars": 29,
"targets_max_tokens": 6,
"targets_tokens": 18
}
