diff --git a/horovod/spark/keras/estimator.py b/horovod/spark/keras/estimator.py
index b1ec936738..c9982363b6 100644
--- a/horovod/spark/keras/estimator.py
+++ b/horovod/spark/keras/estimator.py
@@ -115,9 +115,10 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
         val_batch_size: Number of rows from the DataFrame per batch for validation, if not set, will use batch_size.
         epochs: Number of epochs to train.
         verbose: Verbosity level [0, 2] (default: 1).
-        shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows. Allocating a larger buffer size
-                             increases randomness of shuffling at the cost of more host memory. Defaults to estimating
-                             with an assumption of 4GB of memory per host.
+        shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows (on training data).
+                             Allocating a larger buffer size increases randomness of shuffling at
+                             the cost of more host memory. Defaults to estimating with an assumption
+                             of 4GB of memory per host. Setting shuffle_buffer_size=0 turns off shuffling.
         partitions_per_process: Number of Parquet partitions to assign per worker process from `num_proc` (default: 10).
         run_id: Optional unique ID for this run for organization in the Store. Will be automatically
                 assigned if not provided.
diff --git a/horovod/spark/keras/remote.py b/horovod/spark/keras/remote.py
index 879e80e12b..c0b5a30e83 100644
--- a/horovod/spark/keras/remote.py
+++ b/horovod/spark/keras/remote.py
@@ -112,10 +112,12 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
         pin_gpu(hvd, tf, k)
 
-        if not user_shuffle_buffer_size:
+        # If the user specifies any user_shuffle_buffer_size (even 0), honor it.
+        if user_shuffle_buffer_size is None:
             shuffle_buffer_size = calculate_shuffle_buffer_size(
                 hvd, avg_row_size, train_rows / hvd.size())
         else:
+            assert user_shuffle_buffer_size >= 0, "user_shuffle_buffer_size cannot be negative!"
             shuffle_buffer_size = user_shuffle_buffer_size
 
         # needs to be deserialized in the with scope
@@ -214,7 +216,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
             schema_fields.append(sample_weight_col)
 
         if verbose:
-            print(f"Training parameters: Epochs: {epochs}, Scaled lr: {scaled_lr}\n"
+            print(f"Training parameters: Epochs: {epochs}, Scaled lr: {scaled_lr}, Shuffle size: {shuffle_buffer_size}\n"
                   f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {steps_per_epoch}\n"
                   f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {validation_steps}\n"
                   f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")
@@ -241,6 +243,8 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
                             schema_fields=schema_fields,
                             transform_spec=transform_spec,
                             storage_options=storage_options,
+                            # Don't shuffle row groups if shuffle_buffer_size is 0 (non-shuffle case).
+                            shuffle_row_groups=True if shuffle_buffer_size > 0 else False,
                             **reader_factory_kwargs) as train_reader:
            with reader_factory(remote_store.val_data_path,
                                num_epochs=1,
@@ -252,11 +256,13 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                storage_options=storage_options,
+                               shuffle_row_groups=False,
                                **reader_factory_kwargs) \
                if should_validate else empty_batch_reader() as val_reader:

                train_data = make_dataset(train_reader, batch_size, shuffle_buffer_size,
-                                         is_batch_reader, shuffle=True, cache=inmemory_cache_all)
+                                         is_batch_reader, shuffle=True if shuffle_buffer_size > 0 else False,
+                                         cache=inmemory_cache_all)
                val_data = make_dataset(val_reader, val_batch_size, shuffle_buffer_size,
                                        is_batch_reader, shuffle=False, cache=inmemory_cache_all) \
                    if val_reader else None
diff --git a/horovod/spark/lightning/datamodule.py b/horovod/spark/lightning/datamodule.py
index 213e779267..367340db2e 100644
--- a/horovod/spark/lightning/datamodule.py
+++ b/horovod/spark/lightning/datamodule.py
@@ -63,13 +63,16 @@ def setup(self, stage=None):
                                                cur_shard=self.cur_shard, shard_count=self.shard_count,
                                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                                schema_fields=self.schema_fields,
-                                               storage_options=self.storage_options)
+                                               storage_options=self.storage_options,
+                                               # Don't shuffle row groups when shuffling is disabled.
+                                               shuffle_row_groups=True if self.shuffle_size > 0 else False)
            if self.has_val:
                self.val_reader = reader_factory(self.val_dir, num_epochs=self.num_reader_epochs,
                                                 cur_shard=self.cur_shard, shard_count=self.shard_count,
                                                 hdfs_driver=PETASTORM_HDFS_DRIVER,
                                                 schema_fields=self.schema_fields,
-                                                storage_options=self.storage_options)
+                                                storage_options=self.storage_options,
+                                                shuffle_row_groups=False)

    def teardown(self, stage=None):
        if stage == "fit" or stage is None:
diff --git a/horovod/spark/lightning/estimator.py b/horovod/spark/lightning/estimator.py
index 5111ca4cd1..b5d78fd5cc 100644
--- a/horovod/spark/lightning/estimator.py
+++ b/horovod/spark/lightning/estimator.py
@@ -141,9 +141,10 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
        run_id: (Optional) unique ID for this run for organization in the Store. Will be automatically
                assigned if not provided.
        sample_weight_col: (Optional) column indicating the weight of each sample.
-       shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows. Allocating a larger
-                            buffer size increases randomness of shuffling at the cost of more host memory. Defaults to estimating with an assumption of 4GB of memory per
-                            host.
+       shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows (on training data).
+                            Allocating a larger buffer size increases randomness of shuffling at
+                            the cost of more host memory. Defaults to estimating with an assumption
+                            of 4GB of memory per host. Setting shuffle_buffer_size=0 turns off shuffling.
        store: Store object that abstracts reading and writing of intermediate data and run results.
        terminate_on_nan : (Optinoal) terminate the training process on seeing NaN output.
diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py
index dce397907e..9d5540e973 100644
--- a/horovod/spark/lightning/remote.py
+++ b/horovod/spark/lightning/remote.py
@@ -171,8 +171,9 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
        _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
            int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

+        shuffle_size = calculate_shuffle_buffer_size()
        if verbose:
-            print(f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}\n"
+            print(f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}, shuffle_size: {shuffle_size}\n"
                  f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {_train_steps_per_epoch}\n"
                  f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {_val_steps_per_epoch}\n"
                  f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")
@@ -235,7 +236,7 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
            'has_val': should_validate is not None,
            'train_batch_size': batch_size,
            'val_batch_size': val_batch_size,
-           'shuffle_size': calculate_shuffle_buffer_size(),
+           'shuffle_size': shuffle_size,
            'num_reader_epochs': loader_num_epochs,
            'reader_pool_type': reader_pool_type,
            'reader_worker_count': train_reader_worker_count,
@@ -312,7 +313,9 @@ def calculate_shuffle_buffer_size():
        """
        import horovod.torch as hvd

-       if user_shuffle_buffer_size:
+       # If the user specifies any user_shuffle_buffer_size (even 0), honor it.
+       if user_shuffle_buffer_size is not None:
+           assert user_shuffle_buffer_size >= 0, "user_shuffle_buffer_size cannot be negative!"
            return user_shuffle_buffer_size

        local_size = hvd.local_size()
diff --git a/horovod/spark/torch/estimator.py b/horovod/spark/torch/estimator.py
index ae350b6310..dd7f1670fa 100644
--- a/horovod/spark/torch/estimator.py
+++ b/horovod/spark/torch/estimator.py
@@ -115,9 +115,10 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
        val_batch_size: Number of rows from the DataFrame per batch for validation, if not set, will use batch_size.
        epochs: Number of epochs to train.
        verbose: Verbosity level [0, 2] (default: 1).
-       shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows. Allocating a larger buffer size
-                            increases randomness of shuffling at the cost of more host memory. Defaults to estimating
-                            with an assumption of 4GB of memory per host.
+       shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows (on training data).
+                            Allocating a larger buffer size increases randomness of shuffling at
+                            the cost of more host memory. Defaults to estimating with an assumption
+                            of 4GB of memory per host. Setting shuffle_buffer_size=0 turns off shuffling.
        partitions_per_process: Number of Parquet partitions to assign per worker process from `num_proc` (default: 10).
        run_id: Optional unique ID for this run for organization in the Store. Will be automatically
               assigned if not provided.
diff --git a/horovod/spark/torch/remote.py b/horovod/spark/torch/remote.py
index 6b2f325caa..1691149738 100644
--- a/horovod/spark/torch/remote.py
+++ b/horovod/spark/torch/remote.py
@@ -121,10 +121,12 @@ def train(serialized_model, optimizer_cls, model_opt_state_serialized,
            import horovod as _horovod
            print(f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")

-        if not user_shuffle_buffer_size:
+        # If the user specifies any user_shuffle_buffer_size (even 0), honor it.
+        if user_shuffle_buffer_size is None:
            shuffle_buffer_size = \
                calculate_shuffle_buffer_size(hvd, avg_row_size, train_rows / hvd.size())
        else:
+            assert user_shuffle_buffer_size >= 0, "user_shuffle_buffer_size cannot be negative!"
            shuffle_buffer_size = user_shuffle_buffer_size

        cuda_available = torch.cuda.is_available()
@@ -218,6 +220,7 @@ def save_checkpoint():
        if hvd.rank() == 0 and user_verbose:
            print(f"Training parameters: Epochs: {epochs}\n"
                  f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {steps_per_epoch}\n"
+                 f"Shuffle buffer size: {shuffle_buffer_size}\n"
                  f"Checkpoint file: {ckpt_file}, Logs dir: {logs_dir}\n")
        # In general, make_batch_reader is faster than make_reader for reading the dataset.
        # However, we found out that make_reader performs data transformations much faster than
@@ -245,6 +248,8 @@ def save_checkpoint():
                            schema_fields=schema_fields,
                            transform_spec=transform_spec,
                            storage_options=storage_options,
+                            # Don't shuffle row groups when shuffling is disabled.
+                            shuffle_row_groups=True if shuffle_buffer_size > 0 else False,
                            **reader_factory_kwargs) as train_reader:
            with reader_factory(remote_store.val_data_path,
                                num_epochs=None,
@@ -256,6 +261,7 @@ def save_checkpoint():
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                storage_options=storage_options,
+                               shuffle_row_groups=False,
                                **reader_factory_kwargs) \
                if should_validate else empty_batch_reader() as val_reader:
diff --git a/test/integration/test_spark_lightning.py b/test/integration/test_spark_lightning.py
index 4e1e81fb29..ea7350146b 100644
--- a/test/integration/test_spark_lightning.py
+++ b/test/integration/test_spark_lightning.py
@@ -364,6 +364,12 @@ def test_calculate_shuffle_buffer_size(self, mock_local_size, mock_allgather):
            expected = int(constants.TOTAL_BUFFER_MEMORY_CAP_GIB * constants.BYTES_PER_GIB / avg_row_size / 5)
            assert actual == expected

+        calculate_shuffle_buffer_size = remote._calculate_shuffle_buffer_size_fn(
+            train_row_count_per_worker, avg_row_size, 0)
+        shuffle_size = calculate_shuffle_buffer_size()
+        # A user-specified size of 0 disables shuffling.
+        assert int(shuffle_size) == 0
+
    def test_prepare_data(self):
        with spark_session('test_prepare_data') as spark:
            df = create_xor_data(spark)
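
Usage note (illustrative, not part of the patch): with this change, passing shuffle_buffer_size=0 to a Spark estimator disables both the in-memory shuffle buffer and Petastorm row-group shuffling on the training data, while leaving it unset (None) keeps the memory-based auto-estimate. A minimal sketch with the Keras estimator, assuming an existing SparkBackend `backend`, a Store `store`, a compiled tf.keras `model`, and a training DataFrame `train_df`:

    import tensorflow as tf
    from horovod.spark.keras import KerasEstimator

    # shuffle_buffer_size=0 is now honored as "no shuffling"; any negative value
    # trips the new assertion in the remote training function.
    keras_estimator = KerasEstimator(
        backend=backend,                # assumed: SparkBackend(num_proc=...)
        store=store,                    # assumed: Store.create(...)
        model=model,                    # assumed: a compiled tf.keras model
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='mse',
        feature_cols=['features'],
        label_cols=['y'],
        batch_size=32,
        epochs=5,
        shuffle_buffer_size=0,          # 0 turns off shuffling entirely
        verbose=1)

    keras_model = keras_estimator.fit(train_df)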