Skip to content

Commit

Permalink
lightning: turn off shuffling for validation dataset (#2974)
Browse files Browse the repository at this point in the history
Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
  • Loading branch information
chongxiaoc committed Jun 11, 2021
1 parent 52fffed commit d4e1341
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions horovod/spark/lightning/remote.py
Expand Up @@ -82,8 +82,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
store = estimator.getStore()
remote_store = store.to_remote(run_id, dataset_idx)

-set_data_loader = _set_data_loader_fn(transformation, schema_fields,
-batch_size, calculate_shuffle_buffer_size,
+set_data_loader = _set_data_loader_fn(transformation, schema_fields, batch_size,
data_loader_cls, loader_num_epochs, store, verbose)

def train(serialized_model):
Expand Down Expand Up @@ -170,11 +169,11 @@ def train(serialized_model):
# print(row_group)

with set_data_loader(model, remote_store.train_data_path, 'train_dataloader',
-train_reader_worker_count, reader_pool_type,
+train_reader_worker_count, reader_pool_type, calculate_shuffle_buffer_size(),
name="train_dataloader",
limit_step_per_epoch=_train_steps_per_epoch), \
set_data_loader(model, remote_store.val_data_path, 'val_dataloader',
-val_reader_worker_count, reader_pool_type,
+val_reader_worker_count, reader_pool_type, 0,
should_validate, name="val_dataloader",
limit_step_per_epoch=_val_steps_per_epoch):

Expand Down Expand Up @@ -219,10 +218,11 @@ def on_sanity_check_end(self, trainer, model):
return [ResetCallback()]


-def _set_data_loader_fn(transformation, schema_fields, batch_size, calculate_shuffle_buffer_size, data_loader_cls, num_epochs, store, verbose=False):
+def _set_data_loader_fn(transformation, schema_fields, batch_size, data_loader_cls, num_epochs, store, verbose=False):

@contextlib.contextmanager
-def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, should_read=True, name="", limit_step_per_epoch=-1):
+def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, shuffling_queue_capacity,
+should_read=True, name="", limit_step_per_epoch=-1):
from petastorm import TransformSpec, make_reader, make_batch_reader
import horovod.torch as hvd

Expand Down Expand Up @@ -268,7 +268,7 @@ def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, read
**reader_factory_kwargs) as reader:
def dataloader_fn():
return data_loader_cls(reader=reader, batch_size=batch_size,
-shuffling_queue_capacity=calculate_shuffle_buffer_size(),
+shuffling_queue_capacity=shuffling_queue_capacity,
name=name,
limit_step_per_epoch=limit_step_per_epoch,
verbose=verbose)
Expand Down

0 comments on commit d4e1341

Please sign in to comment.