Fix pytorch async dataloader race condition #3120

Status: Merged (1 commit, Aug 19, 2021)
19 changes: 7 additions & 12 deletions horovod/data/data_loader_base.py
@@ -71,24 +71,19 @@ def __init__(self, async_loader_queue_size=64, *args, **kwargs):
         self.thread.daemon = True
         self.started = False
 
-    def __del__(self):
-        self._close_async_loader()
-        s = super()
-        if hasattr(s, "__del__"):
-            s.__del__(self)
-
-    def _close_async_loader(self):
+    def close_async_loader(self):
         """
         Close the async data loader.
         """
         print("Closing the AsyncDataLoaderMixin.")
         if self.async_loader_queue_size > 0 and self.started:
             self.finished_event.set()
-            try:
-                # Free buffer to allow worker to retry
-                self.queue.get_nowait()
-            except Empty:
-                pass
+            while True:
+                try:
+                    # Drain buffer
+                    self.queue.get_nowait()
+                except Empty:
+                    break
             self.thread.join()
 
     def _async_worker(self):
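The race this change addresses: the async worker thread can be blocked on a `put()` against the bounded queue while the consumer shuts down, and freeing a single slot with one `get_nowait()` may not be enough if the worker still has items (or a trailing sentinel) to enqueue before it can exit, so `thread.join()` can hang. Draining the queue in a loop guarantees the worker unblocks, observes `finished_event`, and finishes. Dropping `__del__` and exposing `close_async_loader()` publicly also lets the owner close the loader at a well-defined point instead of relying on garbage collection. Below is a minimal standalone sketch of the drain-before-join pattern; the names are simplified and the trailing `None` sentinel is an illustrative assumption, not the actual Horovod mixin:

```python
import threading
from queue import Empty, Queue


class AsyncLoaderSketch:
    """Toy loader demonstrating the drain-before-join shutdown pattern."""

    def __init__(self, batches, queue_size=64):
        self.queue = Queue(queue_size)      # bounded buffer (cf. async_loader_queue_size)
        self.finished_event = threading.Event()
        self.batches = batches
        self.started = True
        self.thread = threading.Thread(target=self._async_worker, daemon=True)
        self.thread.start()

    def _async_worker(self):
        for batch in self.batches:
            if self.finished_event.is_set():
                break
            self.queue.put(batch)           # blocks while the buffer is full
        self.queue.put(None)                # end-of-stream sentinel; can also block

    def close_async_loader(self):
        if self.started:
            self.finished_event.set()
            while True:
                try:
                    # Drain the whole buffer so any blocked put() (a batch or
                    # the sentinel) can complete and the worker can exit.
                    self.queue.get_nowait()
                except Empty:
                    break
            self.thread.join()
```

With the old single `get_nowait()`, `AsyncLoaderSketch(range(1000), queue_size=4).close_async_loader()` can still hang in `thread.join()` because the one freed slot is immediately reused; with the drain loop, shutdown completes.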
13 changes: 10 additions & 3 deletions horovod/spark/lightning/datamodule.py
@@ -65,9 +65,14 @@ def setup(self, stage=None):
     def teardown(self, stage=None):
         if stage == "fit" or stage is None:
             if self.verbose:
-                print("Tear down petastorm readers")
+                print("Tear down: closing async dataloaders")
+            self.train_dl.close_async_loader()
+            if self.has_val:
+                self.val_dl.close_async_loader()
             if not self.inmemory_cache_all:
                 # Reader was loaded once and stopped for inmemory dataloader.
+                if self.verbose:
+                    print("Tear down: closing petastorm readers")
                 self.train_reader.stop()
                 self.train_reader.join()
                 if self.has_val:
@@ -90,7 +95,8 @@ def train_dataloader(self):
             dataloader_class = PytorchInfiniteAsyncDataLoader
             kwargs['shuffling_queue_capacity'] = self.shuffle_size
 
-        return dataloader_class(**kwargs)
+        self.train_dl = dataloader_class(**kwargs)
+        return self.train_dl
 
     def val_dataloader(self):
         if not self.has_val:
@@ -110,4 +116,5 @@ def val_dataloader(self):
             dataloader_class = PytorchInfiniteAsyncDataLoader
             kwargs['shuffling_queue_capacity'] = 0
 
-        return dataloader_class(**kwargs)
+        self.val_dl = dataloader_class(**kwargs)
+        return self.val_dl
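On the datamodule side, the loaders returned by train_dataloader() and val_dataloader() are now kept on self so that teardown() can call close_async_loader() on the exact objects Lightning consumed, before the petastorm readers are stopped. A condensed sketch of that wiring, with illustrative names rather than the full Horovod DataModule:

```python
class DataModuleSketch:
    """Illustrative wiring: keep loader references so teardown can close them."""

    def __init__(self, make_train_loader, make_val_loader=None, verbose=False):
        self.make_train_loader = make_train_loader
        self.make_val_loader = make_val_loader
        self.has_val = make_val_loader is not None
        self.verbose = verbose
        self.train_dl = None
        self.val_dl = None

    def train_dataloader(self):
        self.train_dl = self.make_train_loader()   # keep the reference
        return self.train_dl

    def val_dataloader(self):
        self.val_dl = self.make_val_loader()
        return self.val_dl

    def teardown(self, stage=None):
        if stage == "fit" or stage is None:
            if self.verbose:
                print("Tear down: closing async dataloaders")
            # Close the loader worker threads first; the underlying readers
            # (omitted in this sketch) can then be stopped and joined safely.
            self.train_dl.close_async_loader()
            if self.has_val:
                self.val_dl.close_async_loader()
```

Paired with the AsyncLoaderSketch above, `dm = DataModuleSketch(lambda: AsyncLoaderSketch(range(10)))`, `dm.train_dataloader()`, `dm.teardown('fit')` exercises the full create/close path.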