diff --git a/horovod/data/data_loader_base.py b/horovod/data/data_loader_base.py
index 6fd0f5f255..70e5d8955c 100644
--- a/horovod/data/data_loader_base.py
+++ b/horovod/data/data_loader_base.py
@@ -55,7 +55,7 @@ class AsyncDataLoaderMixin(object):
         class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
     """
 
-    def __init__(self, async_loader_queue_size=64, *args, **kwargs):
+    def __init__(self, async_loader_queue_size=5, *args, **kwargs):
         """
         initialize the async data loader. Need to add this in the __init__() of the implementation
         """
@@ -115,6 +115,7 @@ def __iter__(self):
         if not self.started:
             self.started = True
             self.thread.start()
+
         while True:
             batch = self.queue.get()
             if batch is None:
diff --git a/horovod/spark/lightning/datamodule.py b/horovod/spark/lightning/datamodule.py
index 534361469a..ebee791aad 100644
--- a/horovod/spark/lightning/datamodule.py
+++ b/horovod/spark/lightning/datamodule.py
@@ -78,6 +78,8 @@ def teardown(self, stage=None):
             if self.has_val:
                 self.val_reader.stop()
                 self.val_reader.join()
+            if self.verbose:
+                print("Tear down: async dataloaders closed.")
 
     def train_dataloader(self):
         if self.verbose:
@@ -94,6 +96,10 @@ def train_dataloader(self):
         else:
             dataloader_class = PytorchInfiniteAsyncDataLoader
             kwargs['shuffling_queue_capacity'] = self.shuffle_size
+            # To avoid loading too much data into memory, calculate the queue size
+            # dynamically and limit how much data is buffered in the queue.
+            # Add 1 to the size to hold the None sentinel at the end of each epoch.
+            kwargs['async_loader_queue_size'] = max(1, min(100000 // kwargs['batch_size'], kwargs['limit_step_per_epoch'] // 4)) + 1
 
         self.train_dl = dataloader_class(**kwargs)
         return self.train_dl
@@ -115,6 +121,10 @@ def val_dataloader(self):
         else:
             dataloader_class = PytorchInfiniteAsyncDataLoader
             kwargs['shuffling_queue_capacity'] = 0
+            # To avoid loading too much data into memory, calculate the queue size
+            # dynamically and limit how much data is buffered in the queue.
+            # Add 1 to the size to hold the None sentinel at the end of each epoch.
+            kwargs['async_loader_queue_size'] = max(1, min(10000 // kwargs['batch_size'], kwargs['limit_step_per_epoch'] // 4)) + 1
 
         self.val_dl = dataloader_class(**kwargs)
         return self.val_dl
diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py
index 2c083810c8..f4f180955c 100644
--- a/horovod/spark/lightning/remote.py
+++ b/horovod/spark/lightning/remote.py
@@ -193,8 +193,6 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                       'gpus': _num_gpus,
                       'callbacks': callbacks,
                       'max_epochs': epochs,
-                      'limit_train_batches': _train_steps_per_epoch,
-                      'limit_val_batches': _val_steps_per_epoch,
                       'logger': train_logger,
                       'log_every_n_steps': log_every_n_steps,
                       'num_sanity_val_steps': 0,
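
For context, here is a minimal sketch of how the queue-size formula introduced in datamodule.py behaves. The helper name and the sample batch_size / limit_step_per_epoch values are illustrative assumptions, not part of the patch; the sample budgets (100000 for the train loader, 10000 for validation) come from the diff above.

    # Hypothetical helper mirroring the queue-size computation added above.
    def async_loader_queue_size(batch_size, limit_step_per_epoch, sample_budget=100000):
        # Cap the queue at roughly `sample_budget` samples worth of batches,
        # never buffer more than a quarter of an epoch, keep at least one slot,
        # and reserve one extra slot for the None sentinel ending each epoch.
        return max(1, min(sample_budget // batch_size, limit_step_per_epoch // 4)) + 1

    print(async_loader_queue_size(32, 1000))                      # min(3125, 250) + 1 -> 251
    print(async_loader_queue_size(2048, 1000))                    # min(48, 250) + 1   -> 49
    print(async_loader_queue_size(2048, 8, sample_budget=10000))  # min(4, 2) + 1      -> 3

The net effect is that small batch sizes no longer cause the background loader thread to buffer tens of thousands of samples: the queue is bounded both by a fixed sample budget and by a fraction of the epoch length, which is why the fixed default of 64 in data_loader_base.py could be lowered to 5.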