diff --git a/CHANGELOG.md b/CHANGELOG.md index c1ad611271..57ae65db9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- Added support for Petastorm reader-level parallel shuffling. ([#3665](https://github.com/horovod/horovod/pull/3665)) +- Added random seed support for the Lightning datamodule to generate reproducible data loading outputs. ([#3665](https://github.com/horovod/horovod/pull/3665)) - Added support for `int8, uint8` allreduce and grouped_allreduce in tensorflow. ([#3649](https://github.com/horovod/horovod/pull/3649)) - Added support for batched memory copies in GPUAllgather. ([#3590](https://github.com/horovod/horovod/pull/3590)) - Added support for batched memory copies in GPUReducescatter. ([#3621](https://github.com/horovod/horovod/pull/3621)) @@ -20,8 +22,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Changed +- Changed default Petastorm reader pool from `process` to `thread` to reduce memory usage. ([#3665](https://github.com/horovod/horovod/pull/3665)) + ### Deprecated +- Deprecated the `shuffle_buffer_size` field of `EstimatorParams`. Use `shuffle` to enable or disable shuffling instead. ([#3665](https://github.com/horovod/horovod/pull/3665)) + ### Removed ### Fixed diff --git a/horovod/spark/common/params.py b/horovod/spark/common/params.py index 186ca6960c..6073d5d8cf 100644 --- a/horovod/spark/common/params.py +++ b/horovod/spark/common/params.py @@ -74,9 +74,14 @@ class EstimatorParams(Params): shuffle_buffer_size = Param(Params._dummy(), 'shuffle_buffer_size', - 'shuffling buffer size of data before training in number of samples', + '(Deprecated) shuffling buffer size used for training samples', typeConverter=TypeConverters.toInt) + shuffle = Param(Params._dummy(), + 'shuffle', + 'Whether to shuffle training samples or not. Defaults to True', + typeConverter=TypeConverters.toBoolean) + verbose = Param(Params._dummy(), 'verbose', 'verbose flag (0=silent, 1=enabled, other values used by frameworks)', typeConverter=TypeConverters.toInt) @@ -149,6 +154,7 @@ def __init__(self): callbacks=[], random_seed=None, shuffle_buffer_size=None, + shuffle=True, partitions_per_process=10, run_id=None, train_steps_per_epoch=None, @@ -158,7 +164,7 @@ def __init__(self): transformation_removed_fields=None, train_reader_num_workers=2, val_reader_num_workers=2, - reader_pool_type='process', + reader_pool_type='thread', label_shapes=None, inmemory_cache_all=False, use_gpu=True, @@ -316,9 +322,15 @@ def getRandomSeed(self): def setShufflingBufferSize(self, value): return self._set(shuffle_buffer_size=value) + def setShuffle(self, value): + return self._set(shuffle=value) + def getShufflingBufferSize(self): return self.getOrDefault(self.shuffle_buffer_size) + def getShuffle(self): + return self.getOrDefault(self.shuffle) + def setOptimizer(self, value): return self._set(optimizer=value) diff --git a/horovod/spark/keras/estimator.py b/horovod/spark/keras/estimator.py index 25a6a48b6e..8fc37c5e17 100644 --- a/horovod/spark/keras/estimator.py +++ b/horovod/spark/keras/estimator.py @@ -116,10 +116,11 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable, epochs: Number of epochs to train. verbose: Verbosity level [0, 2] (default: 1). random_seed: Optional random seed to use for Tensorflow. Default: None. - shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows (on training data).
+ shuffle_buffer_size: (Deprecated) Optional size of in-memory shuffle buffer in rows (on training data). Allocating a larger buffer size increases randomness of shuffling at the cost of more host memory. Defaults to estimating with an assumption of 4GB of memory per host. Set shuffle_buffer_size=0 would turn off shuffle. + shuffle: (Optional) Whether to shuffle training samples or not. Defaults to True. partitions_per_process: Number of Parquet partitions to assign per worker process from `num_proc` (default: 10). run_id: Optional unique ID for this run for organization in the Store. Will be automatically assigned if not provided. @@ -142,8 +143,8 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable, high enough, or users need to apply transformation such as decompression or data augmentation on raw data. val_reader_num_workers: Similar to the train_reader_num_workers. - reader_pool_type: Type of worker pool used to parallelize reading data from the dataset. - Should be one of ['thread', 'process']. Defaults to 'process'. + reader_pool_type: Type of Petastorm worker pool used to parallelize reading data from the dataset. + Should be one of ['thread', 'process', 'dummy']. Defaults to 'thread'. inmemory_cache_all: boolean value. Cache the data in memory for training and validation. Default: False. backend_env: dict to add to the environment of the backend. Defaults to setting the java heap size to 2G min and max for libhdfs through petastorm @@ -180,6 +181,7 @@ def __init__(self, verbose=None, random_seed=None, shuffle_buffer_size=None, + shuffle=True, partitions_per_process=None, run_id=None, train_steps_per_epoch=None, @@ -187,7 +189,7 @@ def __init__(self, transformation_fn=None, train_reader_num_workers=None, val_reader_num_workers=None, - reader_pool_type=None, + reader_pool_type='thread', label_shapes=None, checkpoint_callback=None, inmemory_cache_all=False, diff --git a/horovod/spark/keras/remote.py b/horovod/spark/keras/remote.py index 3125d425bf..d613fd80c9 100644 --- a/horovod/spark/keras/remote.py +++ b/horovod/spark/keras/remote.py @@ -17,6 +17,7 @@ import io import math import os +import warnings import h5py import tensorflow as tf @@ -49,6 +50,9 @@ def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx): should_validate = estimator.getValidation() random_seed = estimator.getRandomSeed() user_shuffle_buffer_size = estimator.getShufflingBufferSize() + if user_shuffle_buffer_size is not None: + warnings.warn('shuffle_buffer_size is deprecated, use shuffle instead', DeprecationWarning) + shuffle = estimator.getShuffle() user_verbose = estimator.getVerbose() checkpoint_callback = estimator.getCheckpointCallback() inmemory_cache_all = estimator.getInMemoryCacheAll() @@ -84,7 +88,6 @@ def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx): # Utility functions deserialize_keras_model = _deserialize_keras_model_fn() - calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn() pin_gpu = _pin_gpu_fn() # Storage @@ -134,15 +137,6 @@ def train(serialized_model, train_rows, val_rows, avg_row_size): else: tf.random.set_seed(random_seed) - # If user specifies any user_shuffle_buffer_size (even 0), we should honor it. 
- if user_shuffle_buffer_size is None: - shuffle_buffer_size = calculate_shuffle_buffer_size( - hvd, avg_row_size, train_rows / hvd.size()) - else: - if user_shuffle_buffer_size < 0: - raise ValueError("user_shuffle_buffer_size cannot be negative!") - shuffle_buffer_size = user_shuffle_buffer_size - # needs to be deserialized in the with scope with k.utils.custom_object_scope(custom_objects): model = deserialize_keras_model( @@ -238,7 +232,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size): if verbose: print(f"Training parameters: Epochs: {epochs}, Scaled lr: {scaled_lr}, " - f"Shuffle size: {shuffle_buffer_size}, random_seed: {random_seed}\n" + f"Shuffle: {shuffle}, random_seed: {random_seed}\n" f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {steps_per_epoch}\n" f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {validation_steps}\n" f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n") @@ -265,8 +259,9 @@ def train(serialized_model, train_rows, val_rows, avg_row_size): schema_fields=schema_fields, transform_spec=transform_spec, storage_options=storage_options, - # Don't shuffle row groups if shuffle_buffer_size is 0 (non-shuffle case). - shuffle_row_groups=True if shuffle_buffer_size > 0 else False, + shuffle_rows=shuffle, + shuffle_row_groups=shuffle, + seed=random_seed, **reader_factory_kwargs) as train_reader: with reader_factory(remote_store.val_data_path, num_epochs=1, @@ -278,14 +273,14 @@ def train(serialized_model, train_rows, val_rows, avg_row_size): schema_fields=schema_fields, transform_spec=transform_spec, storage_options=storage_options, + shuffle_rows=False, shuffle_row_groups=False, **reader_factory_kwargs) \ if should_validate else empty_batch_reader() as val_reader: - train_data = make_dataset(train_reader, batch_size, shuffle_buffer_size, - is_batch_reader, shuffle=True if shuffle_buffer_size > 0 else False, - cache=inmemory_cache_all, seed=random_seed) - val_data = make_dataset(val_reader, val_batch_size, shuffle_buffer_size, + train_data = make_dataset(train_reader, batch_size, + is_batch_reader, shuffle=shuffle, cache=inmemory_cache_all) + val_data = make_dataset(val_reader, val_batch_size, is_batch_reader, shuffle=False, cache=inmemory_cache_all) \ if val_reader else None @@ -329,47 +324,6 @@ def deserialize_keras_model(model_bytes, load_model_fn): return deserialize_keras_model -def _calculate_shuffle_buffer_size_fn(): - def calculate_shuffle_buffer_size(hvd, avg_row_size, train_row_count_per_worker): - """ - Determines the shuffling buffer size such that each worker gets at most 1GB for shuffling - buffer such that on a single machine, among all the workers on that machine, at most - memory_cap_gb GB are allocated for shuffling buffer. Also, it ensures that the buffer size - is identical among all the workers. 
- - example 1: - memory_cap_gb = 4 - machine1: 8 workers - machine2: 3 workers - shuffle_buffer_size = 0.5 GB - - example 2: - memory_cap_gb = 4 - machine1: 2 workers - machine2: 3 workers - shuffle_buffer_size = 1 GB - - example 3: - memory_cap_gb = 4 - machine1: 2 workers - machine2: 8 workers - machine3: 5 workers - shuffle_buffer_size = 0.5 GB - """ - local_size = hvd.local_size() - local_sizes = hvd.allgather([local_size]) - max_local_size = int(max(local_sizes)) - - if max_local_size > TOTAL_BUFFER_MEMORY_CAP_GIB: - shuffle_buffer_size = TOTAL_BUFFER_MEMORY_CAP_GIB * BYTES_PER_GIB / avg_row_size / max_local_size - else: - shuffle_buffer_size = BYTES_PER_GIB / avg_row_size - - return int(min(shuffle_buffer_size, train_row_count_per_worker)) - - return calculate_shuffle_buffer_size - - def _pin_gpu_fn(): # Horovod: pin GPU to be used to process local rank (one GPU per process) return _pin_gpu_tensorflow2_fn() if LooseVersion(tf.__version__) >= LooseVersion('2.0.0') \ diff --git a/horovod/spark/keras/util.py index 9d3057afcf..85ad3d132b 100644 --- a/horovod/spark/keras/util.py +++ b/horovod/spark/keras/util.py @@ -60,20 +60,19 @@ def make_dataset_fn(feature_columns, label_columns, sample_weight_col, metadata, has_sparse_col, sample_weight_col, feature_columns, label_columns, input_shapes, label_shapes, output_names) - def fn(reader, batch_size, shuffle_buffer_size, is_batch_reader, shuffle=False, cache=False, seed=None): + def fn(reader, batch_size, is_batch_reader, shuffle=True, cache=False): from petastorm.tf_utils import make_petastorm_dataset + # Samples coming from the Petastorm reader are already shuffled if needed, + # so there is no need to shuffle again in the TensorFlow dataset. dataset = make_petastorm_dataset(reader) if is_batch_reader: dataset = dataset.apply(tf.data.experimental.unbatch()) - # Apply cache() before shuffle, so we can reshuffle in each iteration. - if cache: + # cache() is only applied when shuffling is disabled, so that every epoch yields the same samples. + if cache and not shuffle: dataset = dataset.cache() - if shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, seed=seed) - # Use tf.data.Dataset.repeat() to set up an infinite iterator # and to enable ranks to perform training and validation with # unequal number of samples.
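The Keras changes above replace the in-memory `tf.data` shuffle buffer with shuffling inside the Petastorm reader itself (`shuffle_rows`, `shuffle_row_groups`, and `seed`, available since Petastorm 0.12.0, which is why setup.py later in this diff bumps the requirement). A minimal sketch of that reader-level pattern outside the estimator, assuming a hypothetical Parquet dataset at `file:///tmp/train_data`:

```python
# Minimal sketch of Petastorm reader-level shuffling (assumes petastorm>=0.12.0
# and a hypothetical Parquet dataset at file:///tmp/train_data).
import tensorflow as tf
from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset

with make_batch_reader('file:///tmp/train_data',
                       num_epochs=None,           # cycle over the data indefinitely
                       reader_pool_type='thread', # new default pool type
                       shuffle_row_groups=True,   # shuffle the order of row groups
                       shuffle_rows=True,         # shuffle rows within each row group
                       seed=42) as reader:        # reproducible shuffling order
    dataset = make_petastorm_dataset(reader)
    # Rows already arrive shuffled from the reader, so no dataset.shuffle(buffer)
    # call (and no host-memory shuffle buffer) is needed anymore.
    dataset = dataset.apply(tf.data.experimental.unbatch()).batch(32)
    for batch in dataset.take(1):
        print(batch)
```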
diff --git a/horovod/spark/lightning/datamodule.py b/horovod/spark/lightning/datamodule.py index 2f63848d80..92ae48c2bf 100644 --- a/horovod/spark/lightning/datamodule.py +++ b/horovod/spark/lightning/datamodule.py @@ -20,9 +20,9 @@ def __init__( has_val: bool = True, train_batch_size: int = 32, val_batch_size: int = 32, - shuffle_size: int = 1000, + shuffle: bool = True, num_reader_epochs=None, - reader_pool_type: str = "process", + reader_pool_type: str = "thread", reader_worker_count: int = 2, transformation=None, transformation_edit_fields=None, @@ -38,6 +38,7 @@ def __init__( debug_data_loader: bool = False, train_async_data_loader_queue_size: int = None, val_async_data_loader_queue_size: int = None, + seed: int = None, **kwargs): super().__init__() self.train_dir = train_dir @@ -46,7 +47,7 @@ def __init__( self.has_val = has_val self.train_batch_size = train_batch_size self.val_batch_size = val_batch_size - self.shuffle_size = shuffle_size + self.shuffle = shuffle self.num_reader_epochs = num_reader_epochs self.reader_pool_type = reader_pool_type self.reader_worker_count = reader_worker_count @@ -64,6 +65,7 @@ def __init__( self.debug_data_loader = debug_data_loader self.train_async_data_loader_queue_size = train_async_data_loader_queue_size self.val_async_data_loader_queue_size = val_async_data_loader_queue_size + self.seed = seed if debug_data_loader: print("Creating data_module") @@ -101,9 +103,9 @@ def setup(self, stage=None): schema_fields=self.schema_fields, storage_options=self.storage_options, transform_spec=transform_spec, - # Don't shuffle row groups - # without shuffling. - shuffle_row_groups=True if self.shuffle_size > 0 else False, + shuffle_rows=self.shuffle, + shuffle_row_groups=self.shuffle, + seed=self.seed, **reader_factory_kwargs) if self.has_val: self.val_reader = reader_factory( @@ -117,6 +119,7 @@ def setup(self, stage=None): schema_fields=self.schema_fields, storage_options=self.storage_options, transform_spec=transform_spec, + shuffle_rows=False, shuffle_row_groups=False, **reader_factory_kwargs) @@ -151,11 +154,13 @@ def train_dataloader(self): if self.inmemory_cache_all: # Use inmem dataloader dataloader_class = PytorchInmemAsyncDataLoader - kwargs['shuffle'] = self.shuffle_size > 0 + kwargs['shuffle'] = self.shuffle kwargs['num_epochs'] = self.num_train_epochs else: dataloader_class = PytorchInfiniteAsyncDataLoader - kwargs['shuffling_queue_capacity'] = self.shuffle_size + # No need to shuffle again at the dataloader level: + # the reader shuffles rows within every row group since Petastorm 0.12.0. + kwargs['shuffling_queue_capacity'] = 0 if self.debug_data_loader: kwargs['debug_data_loader'] = self.debug_data_loader diff --git a/horovod/spark/lightning/estimator.py b/horovod/spark/lightning/estimator.py index e9f5e568a9..fcfd912c61 100644 --- a/horovod/spark/lightning/estimator.py +++ b/horovod/spark/lightning/estimator.py @@ -135,17 +135,18 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable, for training. Not needed for lightning model. partitions_per_process: (Optional) Number of Parquet partitions to assign per worker process from `num_proc` (default: 10). - reader_pool_type: (Optional) Type of worker pool used to parallelize reading data from - the dataset. Should be one of ['thread', 'process']. Defaults to - 'process'. + reader_pool_type: (Optional) Type of Petastorm worker pool used to parallelize reading data from + the dataset. Should be one of ['thread', 'process', 'dummy']. Defaults to + 'thread'.
run_id: (Optional) unique ID for this run for organization in the Store. Will be automatically assigned if not provided. sample_weight_col: (Optional) column indicating the weight of each sample. random_seed: Optional random seed to use for PyTorch Lightning. Default: None. - shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows (on training data). + shuffle_buffer_size: (Deprecated) Optional size of in-memory shuffle buffer in rows (on training data). Allocating a larger buffer size increases randomness of shuffling at the cost of more host memory. Defaults to estimating with an assumption of 4GB of memory per host. Set shuffle_buffer_size=0 would turn off shuffle. + shuffle: (Optional) Whether to shuffle training samples or not. Defaults to True. store: Store object that abstracts reading and writing of intermediate data and run results. terminate_on_nan : (Optinoal) terminate the training process on seeing NaN output. @@ -249,6 +250,7 @@ def __init__(self, verbose=1, random_seed=None, shuffle_buffer_size=None, + shuffle=True, partitions_per_process=None, run_id=None, train_minibatch_fn=None, diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py index 112fb55c7b..f242e6a7af 100644 --- a/horovod/spark/lightning/remote.py +++ b/horovod/spark/lightning/remote.py @@ -18,6 +18,7 @@ import os import tempfile import math +import warnings from distutils.version import LooseVersion import torch @@ -49,6 +50,10 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro epochs = estimator.getEpochs() random_seed = estimator.getRandomSeed() user_shuffle_buffer_size = estimator.getShufflingBufferSize() + if user_shuffle_buffer_size is not None: + warnings.warn('shuffle_buffer_size is deprecated and will be removed in future releases, '\ + 'use shuffle instead', DeprecationWarning) + shuffle = estimator.getShuffle() terminate_on_nan = estimator.getTerminateOnNan() transformation_fn = estimator.getTransformationFn() transformation = transformation_fn if transformation_fn else None @@ -89,8 +94,6 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro # Utility functions deserialize = deserialize_fn() - calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn( - train_rows, avg_row_size, user_shuffle_buffer_size) schema_fields = feature_columns + label_columns if sample_weight_col: @@ -194,10 +197,9 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \ int(math.floor(float(val_rows) / val_batch_size / hvd.size())) - shuffle_size = calculate_shuffle_buffer_size() if verbose: print(f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}, " - f"Shuffle_size: {shuffle_size}, Random seed: {random_seed}\n" + f"Shuffle: {shuffle}, Random seed: {random_seed}\n" f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {_train_steps_per_epoch}\n" f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {_val_steps_per_epoch}\n" f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n") @@ -269,7 +271,7 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - 'has_val': should_validate is not None, 'train_batch_size': batch_size, 'val_batch_size': val_batch_size, - 'shuffle_size': shuffle_size, + 'shuffle': shuffle, 'num_reader_epochs': loader_num_epochs, 'reader_pool_type': reader_pool_type, 
'reader_worker_count': train_reader_worker_count, @@ -287,6 +289,7 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - 'debug_data_loader': debug_data_loader, 'train_async_data_loader_queue_size': train_async_data_loader_queue_size, 'val_async_data_loader_queue_size': val_async_data_loader_queue_size, + 'seed': random_seed, } if debug_data_loader and hvd.rank() == 0: print(f"Creating data module with args:\n {data_module_kwargs}") @@ -319,54 +322,6 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - return train -def _calculate_shuffle_buffer_size_fn(train_rows, avg_row_size, user_shuffle_buffer_size): - def calculate_shuffle_buffer_size(): - """ - Determines the shuffling buffer size such that each worker gets at most 1GB for shuffling - buffer such that on a single machine, among all the workers on that machine, at most - memory_cap_gb GB are allocated for shuffling buffer. Also, it ensures that the buffer size - is identical among all the workers. - - example 1: - memory_cap_gb = 4 - machine1: 8 workers - machine2: 3 workers - shuffle_buffer_size = 0.5 GB - - example 2: - memory_cap_gb = 4 - machine1: 2 workers - machine2: 3 workers - shuffle_buffer_size = 1 GB - - example 3: - memory_cap_gb = 4 - machine1: 2 workers - machine2: 8 workers - machine3: 5 workers - shuffle_buffer_size = 0.5 GB - """ - import horovod.torch as hvd - - # If user specifies any user_shuffle_buffer_size (even 0), we should honor it. - if user_shuffle_buffer_size is not None: - if user_shuffle_buffer_size < 0: - raise ValueError("user_shuffle_buffer_size cannot be negative!") - return user_shuffle_buffer_size - - local_size = hvd.local_size() - local_sizes = hvd.allgather(torch.tensor([local_size])) - max_local_size = torch.max(local_sizes).item() - - if max_local_size > TOTAL_BUFFER_MEMORY_CAP_GIB: - shuffle_buffer_size = TOTAL_BUFFER_MEMORY_CAP_GIB * BYTES_PER_GIB / avg_row_size / max_local_size - else: - shuffle_buffer_size = BYTES_PER_GIB / avg_row_size - return int(min(shuffle_buffer_size, train_rows / hvd.size())) - - return calculate_shuffle_buffer_size - - def _prepare_data_fn(metadata): def prepare_data(col_name, rows): if col_name not in metadata: diff --git a/horovod/spark/torch/estimator.py b/horovod/spark/torch/estimator.py index b79021f8b9..c7b4ecff0a 100644 --- a/horovod/spark/torch/estimator.py +++ b/horovod/spark/torch/estimator.py @@ -117,10 +117,11 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable, epochs: Number of epochs to train. verbose: Verbosity level [0, 2] (default: 1). random_seed: Optional random seed to use for Torch. Default: None. - shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows (on training data). + shuffle_buffer_size: (Deprecated) Optional size of in-memory shuffle buffer in rows (on training data). Allocating a larger buffer size increases randomness of shuffling at the cost of more host memory. Defaults to estimating with an assumption of 4GB of memory per host. Set shuffle_buffer_size=0 would turn off shuffle. + shuffle: (Optional) Whether to shuffle training samples or not. Defaults to True. partitions_per_process: Number of Parquet partitions to assign per worker process from `num_proc` (default: 10). run_id: Optional unique ID for this run for organization in the Store. Will be automatically assigned if not provided. 
@@ -145,8 +146,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable, high enough, or users need to apply transformation such as decompression or data augmentation on raw data. val_reader_num_workers: Similar to the train_reader_num_workers. - reader_pool_type: Type of worker pool used to parallelize reading data from the dataset. - Should be one of ['thread', 'process']. Defaults to 'process'. + reader_pool_type: Type of Petastorm worker pool used to parallelize reading data from the dataset. + Should be one of ['thread', 'process', 'dummy']. Defaults to 'thread'. inmemory_cache_all: (Optional) Cache the data in memory for training and validation. use_gpu: Whether to use the GPU for training. Defaults to True. mp_start_method: The method to use to start multiprocessing. Defaults to None. @@ -182,6 +183,7 @@ def __init__(self, verbose=1, random_seed=None, shuffle_buffer_size=None, + shuffle=True, partitions_per_process=None, run_id=None, train_minibatch_fn=None, @@ -190,7 +192,7 @@ def __init__(self, transformation_fn=None, train_reader_num_workers=None, val_reader_num_workers=None, - reader_pool_type=None, + reader_pool_type='thread', label_shapes=None, inmemory_cache_all=False, use_gpu=True, diff --git a/horovod/spark/torch/remote.py b/horovod/spark/torch/remote.py index 9c24843473..778061b88e 100644 --- a/horovod/spark/torch/remote.py +++ b/horovod/spark/torch/remote.py @@ -17,6 +17,7 @@ import io import math import os +import warnings from datetime import datetime, timezone import torch @@ -52,6 +53,10 @@ def RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_id metric_fn_groups = estimator.getMetrics() random_seed = estimator.getRandomSeed() user_shuffle_buffer_size = estimator.getShufflingBufferSize() + if user_shuffle_buffer_size is not None: + warnings.warn('shuffle_buffer_size is deprecated and will be removed in future releases, '\ + 'use shuffle instead', DeprecationWarning) + shuffle = estimator.getShuffle() user_verbose = estimator.getVerbose() train_minibatch_fn = estimator.getTrainMinibatchFn() train_minibatch = train_minibatch_fn if train_minibatch_fn else _train_minibatch_fn() @@ -81,7 +86,6 @@ def RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_id # Utility functions deserialize = deserialize_fn() get_optimizer_with_unscaled_lr = _get_optimizer_with_unscaled_lr_fn() - calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn() construct_metric_value_holders = _construct_metric_value_holders_fn() metric_cls = _metric_cls() prepare_np_data = _prepare_np_data_fn() @@ -131,15 +135,6 @@ def train(serialized_model, optimizer_cls, model_opt_state_serialized, import horovod as _horovod print(f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}") - # If user specifies any user_shuffle_buffer_size (even 0), we should honor it. 
- if user_shuffle_buffer_size is None: - shuffle_buffer_size = \ - calculate_shuffle_buffer_size(hvd, avg_row_size, train_rows / hvd.size()) - else: - if user_shuffle_buffer_size < 0: - raise ValueError("user_shuffle_buffer_size cannot be negative!") - shuffle_buffer_size = user_shuffle_buffer_size - if not should_use_gpu and user_verbose: print("Skip pinning current process to the GPU.") @@ -240,7 +235,7 @@ def save_checkpoint(): if hvd.rank() == 0 and user_verbose: print(f"Training parameters: Epochs: {epochs}\n" f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {steps_per_epoch}\n" - f"Shuffle buffer size: {shuffle_buffer_size}, Random seed: {random_seed}\n" + f"Shuffle: {shuffle}, Random seed: {random_seed}\n" f"Checkpoint file: {ckpt_file}, Logs dir: {logs_dir}\n") # In general, make_batch_reader is faster than make_reader for reading the dataset. # However, we found out that make_reader performs data transformations much faster than @@ -268,8 +263,9 @@ def save_checkpoint(): schema_fields=schema_fields, transform_spec=transform_spec, storage_options=storage_options, - # Don't shuffle row groups without shuffling. - shuffle_row_groups=True if shuffle_buffer_size > 0 else False, + shuffle_rows=shuffle, + shuffle_row_groups=shuffle, + seed=random_seed, **reader_factory_kwargs) as train_reader: with reader_factory(remote_store.val_data_path, num_epochs=None, @@ -281,7 +277,9 @@ def save_checkpoint(): schema_fields=schema_fields, transform_spec=transform_spec, storage_options=storage_options, + shuffle_rows=False, shuffle_row_groups=False, + seed=random_seed, **reader_factory_kwargs) \ if should_validate else empty_batch_reader() as val_reader: @@ -291,11 +289,12 @@ def save_checkpoint(): batch_size=batch_size, num_epochs=epochs, rows_capacity=steps_per_epoch*batch_size, - shuffle=True) + shuffle=shuffle) else: train_loader = BatchedDataLoader(train_reader, batch_size=batch_size, - shuffling_queue_capacity=shuffle_buffer_size) + # No need to shuffle again in dataloader level + shuffling_queue_capacity=0) train_loader_iter = iter(train_loader) def prepare_batch(row): @@ -480,46 +479,6 @@ def get_optimizer_with_unscaled_lr(hvd, current_optimizer, optimizer_cls, model) return get_optimizer_with_unscaled_lr -def _calculate_shuffle_buffer_size_fn(): - def calculate_shuffle_buffer_size(hvd, avg_row_size, train_row_count_per_worker): - """ - Determines the shuffling buffer size such that each worker gets at most 1GB for shuffling - buffer such that on a single machine, among all the workers on that machine, at most - memory_cap_gb GB are allocated for shuffling buffer. Also, it ensures that the buffer size - is identical among all the workers. 
- - example 1: - memory_cap_gb = 4 - machine1: 8 workers - machine2: 3 workers - shuffle_buffer_size = 0.5 GB - - example 2: - memory_cap_gb = 4 - machine1: 2 workers - machine2: 3 workers - shuffle_buffer_size = 1 GB - - example 3: - memory_cap_gb = 4 - machine1: 2 workers - machine2: 8 workers - machine3: 5 workers - shuffle_buffer_size = 0.5 GB - """ - local_size = hvd.local_size() - local_sizes = hvd.allgather(torch.tensor([local_size])) - max_local_size = torch.max(local_sizes).item() - - if max_local_size > TOTAL_BUFFER_MEMORY_CAP_GIB: - shuffle_buffer_size = TOTAL_BUFFER_MEMORY_CAP_GIB * BYTES_PER_GIB / avg_row_size / max_local_size - else: - shuffle_buffer_size = BYTES_PER_GIB / avg_row_size - return int(min(shuffle_buffer_size, train_row_count_per_worker)) - - return calculate_shuffle_buffer_size - - def _construct_metric_value_holders_fn(): def construct_metric_value_holders(metric_class, metric_fn_groups, label_columns, hvd): metric_values = [] diff --git a/setup.py b/setup.py index 8df5b71900..b0f133ed06 100644 --- a/setup.py +++ b/setup.py @@ -166,7 +166,7 @@ def build_extensions(self): mxnet_require_list = ['mxnet>=1.4.1'] pyspark_require_list = ['pyspark>=2.3.2;python_version<"3.8"', 'pyspark>=3.0.0;python_version>="3.8"'] -spark_require_list = ['numpy', 'petastorm>=0.11.0', 'pyarrow>=0.15.0', 'fsspec>=2021.07.0'] +spark_require_list = ['numpy', 'petastorm>=0.12.0', 'pyarrow>=0.15.0', 'fsspec>=2021.07.0'] # https://github.com/ray-project/ray/pull/17465 ray_require_list = ['ray', 'aioredis<2'] pytorch_spark_require_list = pytorch_require_list + \ diff --git a/test/integration/test_spark_keras.py b/test/integration/test_spark_keras.py index d3abb9b921..1eea8114f6 100644 --- a/test/integration/test_spark_keras.py +++ b/test/integration/test_spark_keras.py @@ -391,34 +391,6 @@ def test_serialize_param_value(self): serialized_dummy_param = _serialize_param_value('dummy_param_name', None, None, None) assert serialized_dummy_param is None - def test_calculate_shuffle_buffer_size_small_row_size(self): - hvd_size = 4 - local_size = 2 - hvd_mock = mock.MagicMock() - hvd_mock.local_size.return_value = local_size - hvd_mock.allgather.return_value = [local_size for _ in range(hvd_size)] - - avg_row_size = 100 - train_row_count_per_worker = 100 - - calculate_shuffle_buffer_size = remote._calculate_shuffle_buffer_size_fn() - shuffle_size = calculate_shuffle_buffer_size(hvd_mock, avg_row_size, train_row_count_per_worker) - assert shuffle_size == train_row_count_per_worker - - def test_calculate_shuffle_buffer_size(self): - # case with 2 workers, one with 5 ranks and second with 3 ranks - hvd_mock = mock.MagicMock() - hvd_mock.allgather.return_value = [5, 5, 5, 5, 5, 3, 3, 3] - hvd_mock.local_size.return_value = 2 - - avg_row_size = 100000 - train_row_count_per_worker = 1000000 - - calculate_shuffle_buffer_size = remote._calculate_shuffle_buffer_size_fn() - shuffle_size = calculate_shuffle_buffer_size(hvd_mock, avg_row_size, train_row_count_per_worker) - - assert int(shuffle_size) == int(constants.TOTAL_BUFFER_MEMORY_CAP_GIB * constants.BYTES_PER_GIB / avg_row_size / 5) - def test_custom_sparse_to_dense_fn(self): dense_shape = 10 custom_sparse_to_dense = _custom_sparse_to_dense_fn() diff --git a/test/integration/test_spark_lightning.py b/test/integration/test_spark_lightning.py index 585800544d..9a422d2367 100644 --- a/test/integration/test_spark_lightning.py +++ b/test/integration/test_spark_lightning.py @@ -331,52 +331,6 @@ def test_transform_multi_class(self): for field in 
out_df.schema.fields: assert type(field.dataType) == expected_types[field.name] - @mock.patch('horovod.torch.allgather') - @mock.patch('horovod.torch.local_size') - def test_calculate_shuffle_buffer_size_small_row_size(self, mock_local_size, mock_allgather): - import horovod.torch as hvd - hvd.init() - - hvd_size = 4 - local_size = 2 - mock_local_size.return_value = local_size - mock_allgather.return_value = torch.tensor([local_size for _ in range(hvd_size)]) - - avg_row_size = 100 - train_row_count_per_worker = 100 - - calculate_shuffle_buffer_size = remote._calculate_shuffle_buffer_size_fn( - train_row_count_per_worker, avg_row_size, None) - shuffle_size = calculate_shuffle_buffer_size() - assert shuffle_size == train_row_count_per_worker - - @mock.patch('horovod.torch.allgather') - @mock.patch('horovod.torch.local_size') - def test_calculate_shuffle_buffer_size(self, mock_local_size, mock_allgather): - import horovod.torch as hvd - hvd.init() - - # case with 2 workers, one with 5 ranks and second with 3 ranks - mock_allgather.return_value = torch.tensor([5, 5, 5, 5, 5, 3, 3, 3]) - mock_local_size.return_value = 2 - - avg_row_size = 100000 - train_row_count_per_worker = 1000000 - - calculate_shuffle_buffer_size = remote._calculate_shuffle_buffer_size_fn( - train_row_count_per_worker, avg_row_size, None) - shuffle_size = calculate_shuffle_buffer_size() - - actual = int(shuffle_size) - expected = int(constants.TOTAL_BUFFER_MEMORY_CAP_GIB * constants.BYTES_PER_GIB / avg_row_size / 5) - assert actual == expected - - calculate_shuffle_buffer_size = remote._calculate_shuffle_buffer_size_fn( - train_row_count_per_worker, avg_row_size, 0) - shuffle_size = calculate_shuffle_buffer_size() - # Set 0 for non-shuffle - assert int(shuffle_size) == 0 - def test_prepare_data(self): with spark_session('test_prepare_data') as spark: df = create_xor_data(spark) @@ -898,7 +852,7 @@ def test_train_with_custom_data_module(self): class CustomDataModule(pl.LightningDataModule): """Custom DataModule for Lightning Estimator, using PytorchAsyncDataLoader""" def __init__(self, train_dir: str, val_dir: str, has_val: bool=True, - train_batch_size: int=32, val_batch_size: int=32, shuffle_size: int=100, + train_batch_size: int=32, val_batch_size: int=32, shuffle: bool=True, num_reader_epochs=None, cur_shard: int=0, shard_count: int=1, schema_fields=None, storage_options=None, steps_per_epoch_train: int=1, steps_per_epoch_val: int=1, verbose=True, **kwargs): super().__init__() @@ -907,7 +861,7 @@ def __init__(self, train_dir: str, val_dir: str, has_val: bool=True, self.has_val = has_val self.train_batch_size = train_batch_size self.val_batch_size = val_batch_size - self.shuffle_size = shuffle_size + self.shuffle = shuffle self.num_reader_epochs = num_reader_epochs self.cur_shard = cur_shard self.shard_count = shard_count @@ -923,6 +877,8 @@ def setup(self, stage=None): if stage == 'fit' or stage is None: self.train_reader = make_batch_reader(self.train_dir, num_epochs=self.num_reader_epochs, cur_shard=self.cur_shard, shard_count=self.shard_count, + shuffle_rows=self.shuffle, + shuffle_row_groups=self.shuffle, hdfs_driver='libhdfs', schema_fields=self.schema_fields, storage_options=self.storage_options) @@ -948,7 +904,7 @@ def train_dataloader(self): print("Setup train dataloader") kwargs = dict(reader=self.train_reader, batch_size=self.train_batch_size, name="train dataloader", - shuffling_queue_capacity = self.shuffle_size, + shuffling_queue_capacity=0, limit_step_per_epoch=self.steps_per_epoch_train, 
verbose=self.verbose) return PytorchAsyncDataLoader(**kwargs) diff --git a/test/integration/test_spark_torch.py b/test/integration/test_spark_torch.py index 5d12bc0446..08e3c4c708 100644 --- a/test/integration/test_spark_torch.py +++ b/test/integration/test_spark_torch.py @@ -184,35 +184,6 @@ def test_pytorch_get_optimizer_with_unscaled_lr(self): for i in range(len(optimizer_state['param_groups'])): assert optimizer_state['param_groups'][i]['lr'] == init_learning_rate / hvd_size - def test_calculate_shuffle_buffer_size_small_row_size(self): - hvd_size = 4 - local_size = 2 - hvd_mock = mock.MagicMock() - hvd_mock.local_size = lambda: local_size - hvd_mock.allgather = lambda x: torch.tensor([local_size for _ in range(hvd_size)]) - - avg_row_size = 100 - train_row_count_per_worker = 100 - - calculate_shuffle_buffer_size = remote._calculate_shuffle_buffer_size_fn() - shuffle_size = calculate_shuffle_buffer_size(hvd_mock, avg_row_size, train_row_count_per_worker) - assert shuffle_size == train_row_count_per_worker - - def test_calculate_shuffle_buffer_size(self): - # case with 2 workers, one with 5 ranks and second with 3 ranks - hvd_mock = mock.MagicMock() - hvd_mock.allgather = lambda x: torch.tensor([5, 5, 5, 5, 5, 3, 3, 3]) - hvd_mock.local_size = lambda: 2 - - avg_row_size = 100000 - train_row_count_per_worker = 1000000 - - calculate_shuffle_buffer_size = remote._calculate_shuffle_buffer_size_fn() - shuffle_size = calculate_shuffle_buffer_size(hvd_mock, avg_row_size, train_row_count_per_worker) - - assert int(shuffle_size) == \ - int(constants.TOTAL_BUFFER_MEMORY_CAP_GIB * constants.BYTES_PER_GIB / avg_row_size / 5) - def test_metric_class(self): hvd_mock = mock.MagicMock() hvd_mock.allreduce = lambda tensor, name: 2 * tensor
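Taken together, the user-facing change across the Keras, Torch, and Lightning estimators is that shuffling becomes a boolean toggle plus an optional seed instead of a tuned buffer size. A hedged usage sketch of that surface; the store path, toy model, and column names below are placeholders, not part of this diff:

```python
# Hypothetical usage of the API introduced in this diff; the store location,
# model, optimizer, and column names are placeholders for illustration only.
import tensorflow as tf
from horovod.spark.common.store import Store
from horovod.spark.keras import KerasEstimator

store = Store.create('/tmp/horovod-work')  # placeholder work directory
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
optimizer = tf.keras.optimizers.Adam(0.001)

keras_estimator = KerasEstimator(
    num_proc=2,
    store=store,
    model=model,
    optimizer=optimizer,
    loss='mse',
    feature_cols=['features'],
    label_cols=['label'],
    batch_size=32,
    epochs=2,
    random_seed=1234,           # now also seeds Petastorm reader-level shuffling
    shuffle=True,               # new boolean flag replacing shuffle_buffer_size
    reader_pool_type='thread',  # new default; 'process' and 'dummy' remain valid
)
# keras_model = keras_estimator.fit(train_df)  # train_df: a Spark DataFrame

# The deprecated knob is still accepted, but per this diff it only triggers a
# DeprecationWarning and no longer controls shuffling:
# KerasEstimator(..., shuffle_buffer_size=100000)
```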