Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
make the data loader interface more general.
Signed-off-by: Peng Zhang <pengz@uber.com>
- Loading branch information
Showing
11 changed files
with
203 additions
and
163 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from queue import Queue, Empty | ||
from threading import Thread, Event | ||
|
||
|
||
class BaseDataLoader:
    """
    Minimal interface for a batch-producing data loader.

    Subclasses supply __len__() and _iterate(); iteration itself is provided
    here and routes every batch through the _process_batch() hook.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative constructor: forward whatever is left to the next class
        # in the MRO so the loader can be combined with mixins.
        super().__init__(*args, **kwargs)

    def __len__(self):
        """
        Return the length of the batches to be loaded.

        Must be provided by the concrete loader.
        """
        raise NotImplementedError

    def _iterate(self):
        """
        Produce the raw batches.

        Must be provided by the concrete loader; __iter__() consumes whatever
        iterable this returns.
        """
        raise NotImplementedError

    def __iter__(self):
        """
        Start an iteration and yield the batches one by one, each passed
        through _process_batch() first.
        """
        raw_batches = self._iterate()
        for raw_batch in raw_batches:
            processed = self._process_batch(raw_batch)
            yield processed

    def _process_batch(self, batch):
        """
        Hook applied to every batch right before it is handed to the caller.

        The trainer overrides it to reshape the data as needed, so loader
        implementations should not override it. The default returns the batch
        unchanged.
        """
        return batch
|
||
|
||
class AsyncDataLoaderMixin(object):
    """
    Async mixin on top of an implementation of BaseDataLoader.

    It owns a separate daemon thread which reads data from self._iterate() and
    pushes it into a bounded queue. self.__iter__() pops the data from that
    queue. The mixin has to come before the concrete loader in the bases list.
    For example:
        class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
    """

    def __init__(self, async_loader_queue_size=64, *args, **kwargs):
        """
        Initialize the async data loader.

        :param async_loader_queue_size: capacity of the batch queue. A value
            <= 0 disables the worker thread; iteration is then synchronous.
        :param args: forwarded to the next class in the MRO.
        :param kwargs: forwarded to the next class in the MRO.
        """
        super().__init__(*args, **kwargs)

        print(f"Apply the AsyncDataLoaderMixin on top of the data loader, async_loader_queue_size={async_loader_queue_size}. ")
        self.async_loader_queue_size = async_loader_queue_size
        # Always defined so that closing (and __del__) is safe in both modes.
        self.started = False

        # The worker machinery is only needed when the queue is enabled.
        if self.async_loader_queue_size > 0:
            self.finished_event = Event()
            self.queue = Queue(self.async_loader_queue_size)
            self.thread = Thread(target=self._start_async_worker)
            self.thread.daemon = True

    def __del__(self):
        """
        Stop the worker thread, then chain to a parent finalizer if any.
        """
        self._close_async_loader()
        # super().__del__ is already bound to self, so call it with no args.
        parent_del = getattr(super(), "__del__", None)
        if parent_del is not None:
            parent_del()

    def _close_async_loader(self):
        """
        Close the async data loader.

        Signals the worker to stop and waits for it to exit. Does nothing
        beyond logging when the loader is synchronous or was never started.
        """
        print("Closing the AsyncDataLoaderMixin.")
        # getattr keeps this safe when __init__ failed before the attributes
        # were assigned and __del__ still runs on the partial instance.
        if getattr(self, "async_loader_queue_size", 0) > 0 and getattr(self, "started", False):
            self.finished_event.set()
            # Keep draining until the worker has exited: after noticing the
            # event it still has to enqueue its trailing None sentinels, and
            # each of those put() calls blocks while the queue is full.
            while self.thread.is_alive():
                try:
                    self.queue.get(timeout=0.05)
                except Empty:
                    pass
            self.thread.join()

    def _start_async_worker(self):
        """
        Worker thread body: load data asynchronously.

        Repeatedly runs self._iterate() (which the concrete loader implements)
        and enqueues every batch, followed by None to mark the end of one pass.
        An exception raised while loading is enqueued so that the consumer can
        re-raise it. A final None is always enqueued when the worker exits.
        """
        try:
            while not self.finished_event.is_set():
                for batch in self._iterate():
                    if self.finished_event.is_set():
                        break
                    self.queue.put(batch)
                # End-of-pass marker for the consumer.
                self.queue.put(None)
        except Exception as ex:
            # Hand the failure over to the consuming thread.
            self.queue.put(ex)
            self.queue.put(None)
        finally:
            self.queue.put(None)

    def __iter__(self):
        """
        Override __iter__() to produce batches asynchronously.

        In async mode the worker thread is started on first use and batches
        are taken from the queue until a None marker is seen; an exception
        object found in the queue is raised. With the queue disabled the
        batches come straight from self._iterate(). Every batch goes through
        self._process_batch() before being yielded.
        """

        print("Start generating batches from axync data loader.")
        if self.async_loader_queue_size > 0:
            if not self.started:
                self.started = True
                self.thread.start()
            while True:
                batch = self.queue.get()
                if batch is None:
                    break
                if isinstance(batch, Exception):
                    raise batch
                yield self._process_batch(batch)
        else:
            for batch in self._iterate():
                yield self._process_batch(batch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from petastorm.pytorch import BatchedDataLoader | ||
from .data_loader_base import BaseDataLoader, AsyncDataLoaderMixin | ||
|
||
|
||
class PytorchDataLoader(BaseDataLoader):
    """
    BaseDataLoader implementation that draws batches from a reader through
    petastorm's BatchedDataLoader.
    """

    def __init__(self, reader, batch_size, shuffling_queue_capacity, *args, **kwargs):
        """
        Store the reader and the batching parameters.

        :param reader: source of rows; this class uses its last_row_consumed,
            reset(), dataset.paths and len() (presumably a petastorm reader
            - confirm against the caller).
        :param batch_size: passed to BatchedDataLoader as batch_size.
        :param shuffling_queue_capacity: passed to BatchedDataLoader as
            shuffling_queue_capacity.
        Remaining arguments are forwarded to the next class in the MRO.
        """
        super().__init__(*args, **kwargs)

        self.reader, self.batch_size = reader, batch_size
        self.shuffling_queue_capacity = shuffling_queue_capacity
        print(f"Initializing petastorm dataloader with batch_size {batch_size}"
              f" and shuffling_queue_capacity {shuffling_queue_capacity}")

    def __len__(self):
        """Return the length reported by the underlying reader."""
        return len(self.reader)

    def _iterate(self):
        """
        Yield the batches of one pass over the reader.

        The reader is reset first when it reports that its last row has
        already been consumed.
        """
        source = self.reader
        if source.last_row_consumed:
            print(f"Resetting Petastorm reader for {source.dataset.paths}")
            source.reset()

        # Build a fresh data loader on every call. There may be left over data
        # from the last epoch which would make petastorm's BatchedDataLoader
        # fail to reset.
        yield from BatchedDataLoader(
            source,
            batch_size=self.batch_size,
            shuffling_queue_capacity=self.shuffling_queue_capacity,
        )
|
||
|
||
class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
    """
    PytorchDataLoader combined with AsyncDataLoaderMixin.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments along the MRO, then log the creation."""
        super(PytorchAsyncDataLoader, self).__init__(*args, **kwargs)
        print("Created PytorchAsyncDataLoader. ")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.