From d4e1341749def9773c99f30cfa9bfcb9f10c2321 Mon Sep 17 00:00:00 2001 From: chongxiaoc <74630762+chongxiaoc@users.noreply.github.com> Date: Fri, 11 Jun 2021 16:38:34 -0500 Subject: [PATCH] lightning: turn off shuffling for validation dataset (#2974) Signed-off-by: Chongxiao Cao --- horovod/spark/lightning/remote.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py index e1d23f608a..7d0368473c 100644 --- a/horovod/spark/lightning/remote.py +++ b/horovod/spark/lightning/remote.py @@ -82,8 +82,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro store = estimator.getStore() remote_store = store.to_remote(run_id, dataset_idx) - set_data_loader = _set_data_loader_fn(transformation, schema_fields, - batch_size, calculate_shuffle_buffer_size, + set_data_loader = _set_data_loader_fn(transformation, schema_fields, batch_size, data_loader_cls, loader_num_epochs, store, verbose) def train(serialized_model): @@ -170,11 +169,11 @@ def train(serialized_model): # print(row_group) with set_data_loader(model, remote_store.train_data_path, 'train_dataloader', - train_reader_worker_count, reader_pool_type, + train_reader_worker_count, reader_pool_type, calculate_shuffle_buffer_size(), name="train_dataloader", limit_step_per_epoch=_train_steps_per_epoch), \ set_data_loader(model, remote_store.val_data_path, 'val_dataloader', - val_reader_worker_count, reader_pool_type, + val_reader_worker_count, reader_pool_type, 0, should_validate, name="val_dataloader", limit_step_per_epoch=_val_steps_per_epoch): @@ -219,10 +218,11 @@ def on_sanity_check_end(self, trainer, model): return [ResetCallback()] -def _set_data_loader_fn(transformation, schema_fields, batch_size, calculate_shuffle_buffer_size, data_loader_cls, num_epochs, store, verbose=False): +def _set_data_loader_fn(transformation, schema_fields, batch_size, data_loader_cls, num_epochs, store, verbose=False): @contextlib.contextmanager - def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, should_read=True, name="", limit_step_per_epoch=-1): + def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, shuffling_queue_capacity, + should_read=True, name="", limit_step_per_epoch=-1): from petastorm import TransformSpec, make_reader, make_batch_reader import horovod.torch as hvd @@ -268,7 +268,7 @@ def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, read **reader_factory_kwargs) as reader: def dataloader_fn(): return data_loader_cls(reader=reader, batch_size=batch_size, - shuffling_queue_capacity=calculate_shuffle_buffer_size(), + shuffling_queue_capacity=shuffling_queue_capacity, name=name, limit_step_per_epoch=limit_step_per_epoch, verbose=verbose)