Skip to content

Commit

Permalink
[Train] Split all Ray Datasets by default (ray-project#38694)
Browse files Browse the repository at this point in the history
Previously we only shard the "train" Ray Dataset by default. If users want to shard other datasets, they need to explicitly specify it with a `DataConfig`. e.g. `DataConfig(datasets_to_split=["train", "eval"])`.

We now change the default behavior to shard all datasets by default for the following considerations:

- Efficiency: We want people to leverage Ray Data as best as possible. The best way to optimize training time is to leverage the fact that Ray Data can effectively shard all the datasets across workers. Training frameworks (e.g. Lightning) provide ways to aggregate results across workers, and we should be recommending users to shard their validation datasets.
- Consistency: It is conceptually easier for users to understand a single default behavior applied to all Datasets and to be provided options to configure them.
- Explicitness: The behavior for the magic “train” key is not very explicit, and users will not understand this until they really read through the documentation. Relying on untyped keywords is non-ideal.

### API
- Shard all datasets(default):
```python
TorchTrainer(
    datasets={"a": ds_1, "b": ds_2, "c": ds_3},
    # data_config=DataConfig(datasets_to_split="all")
)
```

- Shard a subset of datasets
```python
TorchTrainer(
    datasets={"a": ds_1, "b": ds_2, "c": ds_3},
    data_config=DataConfig(datasets_to_split=["a", "b"])
)
```

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Co-authored-by: Eric Liang <ekhliang@gmail.com>
Co-authored-by: Cheng Su <scnju13@gmail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
4 people authored and arvind-chandra committed Aug 31, 2023
1 parent e816d43 commit 4e35cd0
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 29 deletions.
29 changes: 19 additions & 10 deletions doc/source/train/user-guides/data-loading-preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,13 @@ Your preprocessed datasets can be passed into a Ray Train Trainer (e.g. :class:`

The datasets passed into the Trainer's ``datasets`` can be accessed inside of the ``train_loop_per_worker`` run on each distributed training worker by calling :meth:`ray.train.get_dataset_shard`.

The default splitting behavior is as follows:
All datasets are split (i.e. sharded) across the training workers by default. :meth:`~ray.train.get_dataset_shard` will return ``1/n`` of the dataset, where ``n`` is the number of training workers.

- The ``"train"`` dataset is split (i.e. sharded) across the training workers. :meth:`~ray.train.get_dataset_shard` will return ``1/n`` of the dataset, where ``n`` is the number of training workers.
- All other dataset are not split. :meth:`~ray.train.get_dataset_shard` will return the full dataset.
.. note::

Please be aware that as the evaluation dataset is split, users have to aggregate the evaluation results across workers.
You might consider using `TorchMetrics <https://torchmetrics.readthedocs.io/en/latest/>`_ (:ref:`example <deepspeed_example>`) or
utilities available in other frameworks that you can explore.

This behavior can be overwritten by passing in the ``dataset_config`` argument. For more information on configuring splitting logic, see :ref:`Splitting datasets <train-datasets-split>`.

Expand Down Expand Up @@ -298,11 +301,11 @@ For more details, see the following sections for each framework.

Splitting datasets
------------------
By default, Ray Train splits the ``"train"`` dataset across workers using :meth:`Dataset.streaming_split <ray.data.Dataset.streaming_split>`. Each worker sees a disjoint subset of the data, instead of iterating over the entire dataset. Unless randomly shuffled, the same splits are used for each iteration of the dataset.
By default, Ray Train splits all datasets across workers using :meth:`Dataset.streaming_split <ray.data.Dataset.streaming_split>`. Each worker sees a disjoint subset of the data, instead of iterating over the entire dataset. Unless randomly shuffled, the same splits are used for each iteration of the dataset.

For all other datasets, Ray Train passes the entire dataset to each worker.
If want to customize which datasets are split, pass in a :class:`DataConfig <ray.train.DataConfig>` to the Trainer constructor.

To customize this, pass in a :class:`DataConfig <ray.train.DataConfig>` to the Trainer constructor. For example, to split both the training and validation datasets, do the following:
For example, to split only the training dataset, do the following:

.. testcode::

Expand All @@ -317,18 +320,24 @@ To customize this, pass in a :class:`DataConfig <ray.train.DataConfig>` to the T
train_ds, val_ds = ds.train_test_split(0.3)

def train_loop_per_worker():
# Get an iterator to the dataset we passed in below.
it = train.get_dataset_shard("train")
# Get the sharded training dataset
train_ds = train.get_dataset_shard("train")
for _ in range(2):
for batch in it.iter_batches(batch_size=128):
for batch in train_ds.iter_batches(batch_size=128):
print("Do some training on batch", batch)

# Get the unsharded full validation dataset
val_ds = train.get_dataset_shard("val")
for _ in range(2):
for batch in val_ds.iter_batches(batch_size=128):
print("Do some evaluation on batch", batch)

my_trainer = TorchTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=2),
datasets={"train": train_ds, "val": val_ds},
dataset_config=ray.train.DataConfig(
datasets_to_split=["train", "val"],
datasets_to_split=["train"],
),
)
my_trainer.fit()
Expand Down
57 changes: 55 additions & 2 deletions python/ray/air/tests/test_new_dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def test_basic(ray_start_4_cpus):
test = TestBasic(1, True, {"train": 10, "test": -1}, datasets={"train": ds})
test.fit()

# Two workers, train split.
# Two workers, train and test split.
test = TestBasic(
2, True, {"train": 5, "test": 10}, datasets={"train": ds, "test": ds}
2, True, {"train": 5, "test": 5}, datasets={"train": ds, "test": ds}
)
test.fit()

Expand All @@ -78,6 +78,59 @@ def test_basic(ray_start_4_cpus):
test.fit()


def test_split(ray_start_4_cpus):
ds = ray.data.range(10)

# Split all by default
test = TestBasic(
2,
True,
{"train": 5, "test": 5, "val": 5},
datasets={"train": ds, "test": ds, "val": ds},
)
test.fit()

# Test flag "all"
test = TestBasic(
2,
True,
{"train": 5, "test": 5},
datasets={"train": ds, "test": ds},
dataset_config=DataConfig(datasets_to_split="all"),
)

# Test split train only.
test = TestBasic(
2,
True,
{"train": 5, "test": 10},
datasets={"train": ds, "test": ds},
dataset_config=DataConfig(datasets_to_split=["train"]),
)
test.fit()

# Test invalid arguments
for datasets_to_split in ["train", ("train"), {}]:
with pytest.raises(TypeError, match="`datasets_to_split` should be.*"):
test = TestBasic(
2,
True,
{"train": 5, "test": 10},
datasets={"train": ds, "test": ds},
dataset_config=DataConfig(datasets_to_split=datasets_to_split),
)

# Test empty `datasets_to_split` list
test = TestBasic(
2,
True,
{"train": 10, "test": 10},
datasets={"train": ds, "test": ds},
dataset_config=DataConfig(datasets_to_split=[]),
)
test.fit()


@pytest.mark.skip(
reason="Incomplete implementation of _validate_dag causes other errors, so we "
"remove DAG validation for now; see https://github.com/ray-project/ray/pull/37829"
Expand Down
41 changes: 31 additions & 10 deletions python/ray/train/_internal/data_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional, Dict, List
from typing import Optional, Union, Dict, List

import ray
from ray.actor import ActorHandle
from ray.train.constants import TRAIN_DATASET_KEY

# TODO(justinvyu): Fix the circular import error
from ray.train.constants import TRAIN_DATASET_KEY # noqa
from ray.train._internal.dataset_spec import DataParallelIngestSpec
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.air.config import DatasetConfig
Expand All @@ -15,6 +17,13 @@
)
from ray.data.preprocessor import Preprocessor

import sys

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


@PublicAPI
class DataConfig:
Expand All @@ -26,21 +35,28 @@ class DataConfig:

def __init__(
self,
datasets_to_split: Optional[List[str]] = None,
datasets_to_split: Union[Literal["all"], List[str]] = "all",
execution_options: Optional[ExecutionOptions] = None,
):
"""Construct a DataConfig.
Args:
datasets_to_split: The list of dataset names to split between workers.
By default, only the "train" dataset will be split.
datasets_to_split: Specifies which datasets should be split among workers.
Can be set to "all" or a list of dataset names. Defaults to "all",
i.e. split all datasets.
execution_options: The execution options to pass to Ray Data. By default,
the options will be optimized for data ingest. When overriding this,
base your options off of `DataConfig.default_ingest_options()`.
"""
self._datasets_to_split: List[str] = (
datasets_to_split if datasets_to_split is not None else [TRAIN_DATASET_KEY]
)
if isinstance(datasets_to_split, list) or datasets_to_split == "all":
self._datasets_to_split = datasets_to_split
else:
raise TypeError(
"`datasets_to_split` should be a 'all' or a list of strings of "
"dataset names. Received "
f"{type(datasets_to_split).__name__} with value {datasets_to_split}."
)

self._execution_options: ExecutionOptions = (
execution_options or DataConfig.default_ingest_options()
)
Expand Down Expand Up @@ -68,12 +84,17 @@ def configure(
equal to `world_size`. Each element of the list contains the assigned
`DataIterator` instances by name for the worker.
"""
output = [{} for i in range(world_size)]
output = [{} for _ in range(world_size)]

if self._datasets_to_split == "all":
datasets_to_split = set(datasets.keys())
else:
datasets_to_split = set(self._datasets_to_split)

for name, ds in datasets.items():
ds = ds.copy(ds)
ds.context.execution_options = self._execution_options
if name in self._datasets_to_split:
if name in datasets_to_split:
for i, split in enumerate(
ds.streaming_split(
world_size, equal=True, locality_hints=worker_node_ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import ray
import ray.train as train
from ray.train import ScalingConfig
from ray.train import ScalingConfig, DataConfig
from ray.data import Dataset
from ray.train.torch import TorchTrainer

Expand Down Expand Up @@ -120,6 +120,7 @@ def train_regression(num_workers=2, use_gpu=False):
train_loop_config=config,
scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
datasets={"train": train_dataset, "validation": val_dataset},
dataset_config=DataConfig(datasets_to_split=["train"]),
)

result = trainer.fit()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ray
from ray import tune
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from ray.train import ScalingConfig, DataConfig
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner

Expand All @@ -20,6 +20,7 @@ def tune_linear(num_workers, num_samples, use_gpu):
train_loop_config=config,
scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
datasets={"train": train_dataset, "validation": val_dataset},
dataset_config=DataConfig(datasets_to_split=["train"]),
)

tuner = Tuner(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/train/tests/test_data_parallel_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def get_dataset():
# All other datasets should not be sharded.
val_dataset = train.get_dataset_shard("val")
val_ds_count = len(list(val_dataset.iter_rows()))
assert val_ds_count == num_val_data
assert val_ds_count == num_val_data / scale_config.num_workers

trainer = DataParallelTrainer(
train_loop_per_worker=get_dataset,
Expand Down
6 changes: 2 additions & 4 deletions python/ray/train/torch/torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,8 @@ def train_loop_per_worker(config):
Sharding and additional configuration can be done by
passing in a ``dataset_config``.
dataset_config: The configuration for ingesting the input ``datasets``.
By default:
- The ``"train"`` Dataset is split equally across workers.
- All other Datasets are **not** split.
By default, all the Ray Dataset are split equally across workers.
See :class:`~ray.train.DataConfig` for more details.
resume_from_checkpoint: A checkpoint to resume training from.
This checkpoint can be accessed from within ``train_loop_per_worker``
by calling ``ray.train.get_checkpoint()``.
Expand Down

0 comments on commit 4e35cd0

Please sign in to comment.