make the data loader interface more general.
Signed-off-by: Peng Zhang <pengz@uber.com>
irasit committed May 28, 2021
1 parent ac29047 commit e3986c3
Showing 11 changed files with 203 additions and 163 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased] - YYYY-MM-DD

### Added
- Custom Spark data loader interface. ([#2938](https://github.com/horovod/horovod/issues/2938))

- Added support to supply a logger and associated parameter to control the frequency of logging. ([#2926](https://github.com/horovod/horovod/pull/2926))

2 changes: 2 additions & 0 deletions docs/pytorch.rst
@@ -139,3 +139,5 @@ Start the training job and specify the number of workers on the command line as
You can find an example of using the PyTorch Lightning trainer with the Horovod backend in the `pytorch_lightning_mnist.py script <../examples/pytorch/pytorch_lightning_mnist.py>`__

See the PyTorch Lightning `docs <https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html#horovod>`_ for more details.

A PyTorch Lightning based Spark estimator example has also been added; see `pytorch_lightning_spark_mnist.py <../examples/spark/pytorch/pytorch_lightning_spark_mnist.py>`__
2 changes: 2 additions & 0 deletions docs/spark.rst
@@ -96,6 +96,8 @@ logging (for Tensorboard) using the Estimator ``Store`` abstraction. Stores are
artifacts including intermediate representations of the training data. Horovod natively supports stores for HDFS
and local filesystems.

A Petastorm-based data loader is used by default, but users can define a custom data loader by overriding the ``BaseDataLoader`` interface.
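
For illustration, a minimal sketch of such a custom loader (the ListDataLoader class and its pre-materialized batch source are hypothetical, not part of this commit):

    from horovod.spark.data_loaders.data_loader_base import BaseDataLoader

    class ListDataLoader(BaseDataLoader):
        """Hypothetical loader serving pre-materialized batches."""
        def __init__(self, batches, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.batches = batches

        def __len__(self):
            # Number of batches this loader will yield
            return len(self.batches)

        def _iterate(self):
            # BaseDataLoader.__iter__ wraps this generator with _process_batch()
            for batch in self.batches:
                yield batch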

End-to-end example
------------------
`keras_spark_rossmann_estimator.py script <../examples/spark/keras/keras_spark_rossmann_estimator.py>`__ provides
1 change: 0 additions & 1 deletion examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -95,7 +95,6 @@ def __init__(self):
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
# raise RuntimeError("x shape is {}".format(x.shape))
x = x.float().reshape((-1, 1, 28, 28))
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
151 changes: 0 additions & 151 deletions horovod/spark/common/data_loader.py

This file was deleted.

Empty file.
119 changes: 119 additions & 0 deletions horovod/spark/data_loaders/data_loader_base.py
@@ -0,0 +1,119 @@
from queue import Queue, Empty
from threading import Thread, Event


class BaseDataLoader(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __len__(self):
"""
        Number of batches to be loaded.
"""
raise NotImplementedError()

def _iterate(self):
"""
        Interface for the implementation of batch iteration.
"""
raise NotImplementedError()

def __iter__(self):
"""
        Start iteration and get batches.
"""
for batch in self._iterate():
yield self._process_batch(batch)

def _process_batch(self, batch):
"""
        Hook to modify a batch before it is returned. The trainer overrides this to
        reshape the data as needed; loader implementations should not override it.
"""
return batch


class AsyncDataLoaderMixin(object):
"""
    Async mixin on top of a BaseDataLoader implementation. It runs a separate thread
    which reads batches from self._iterate() and pushes them into a queue; the
    self.__iter__() function pops the batches from the queue.
For example:
class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
"""

def __init__(self, async_loader_queue_size=64, *args, **kwargs):
"""
        Initialize the async data loader. Implementations need to call this from their own __init__().
"""
super().__init__(*args, **kwargs)

print(f"Apply the AsyncDataLoaderMixin on top of the data loader, async_loader_queue_size={async_loader_queue_size}. ")
self.async_loader_queue_size = async_loader_queue_size

        if self.async_loader_queue_size > 0:
self.finished_event = Event()
self.queue = Queue(self.async_loader_queue_size)
self.thread = Thread(target=self._start_async_worker)
self.thread.daemon = True
self.started = False

def __del__(self):
self._close_async_loader()
s = super()
if hasattr(s, "__del__"):
            s.__del__()

def _close_async_loader(self):
"""
Close the async data loader.
"""
print("Closing the AsyncDataLoaderMixin.")
if self.async_loader_queue_size > 0 and self.started:
self.finished_event.set()
            try:
                # Drain the queue so a worker blocked on put() can observe the
                # finished event and exit.
                while True:
                    self.queue.get_nowait()
            except Empty:
                pass
self.thread.join()

def _start_async_worker(self):
"""
Start worker thread to load data asynchronously.
        Implementations need to provide self._iterate() to read the data.
"""
try:
while not self.finished_event.is_set():
for batch in self._iterate():
if self.finished_event.is_set():
break
self.queue.put(batch)
                self.queue.put(None)  # signal the end of an epoch to the consumer
except Exception as ex:
self.queue.put(ex)
self.queue.put(None)
finally:
self.queue.put(None)

def __iter__(self):
"""
        Override __iter__() to iterate data asynchronously and produce batches.
        Pops batches from the queue which were generated by self._iterate().
"""

print("Start generating batches from axync data loader.")
if self.async_loader_queue_size > 0:
if not self.started:
self.started = True
self.thread.start()
while True:
batch = self.queue.get()
if batch is None:
break
if isinstance(batch, Exception):
raise batch
yield self._process_batch(batch)
else:
for batch in self._iterate():
yield self._process_batch(batch)
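
As a hedged usage sketch of the mixin (the RangeDataLoader subclass below is hypothetical): the async class lists AsyncDataLoaderMixin first so that, per Python's MRO, its __init__() and __iter__() wrap the synchronous loader's _iterate():

    from horovod.spark.data_loaders.data_loader_base import (
        AsyncDataLoaderMixin, BaseDataLoader)

    class RangeDataLoader(BaseDataLoader):
        """Hypothetical synchronous loader yielding one-element batches."""
        def __init__(self, n, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.n = n

        def __len__(self):
            return self.n

        def _iterate(self):
            for i in range(self.n):
                yield [i]

    # Mixin first: its __init__ consumes async_loader_queue_size and
    # forwards n=10 to RangeDataLoader.
    class AsyncRangeDataLoader(AsyncDataLoaderMixin, RangeDataLoader):
        pass

    loader = AsyncRangeDataLoader(async_loader_queue_size=4, n=10)
    batches = list(loader)        # batches arrive via the worker's queue
    loader._close_async_loader()  # stop the worker thread explicitly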
38 changes: 38 additions & 0 deletions horovod/spark/data_loaders/pytorch_data_loaders.py
@@ -0,0 +1,38 @@
from petastorm.pytorch import BatchedDataLoader
from .data_loader_base import BaseDataLoader, AsyncDataLoaderMixin


class PytorchDataLoader(BaseDataLoader):
def __init__(self, reader, batch_size, shuffling_queue_capacity, *args, **kwargs):
super().__init__(*args, **kwargs)

self.reader = reader
self.batch_size = batch_size
self.shuffling_queue_capacity = shuffling_queue_capacity
print(f"Initializing petastorm dataloader with batch_size {batch_size}"
f" and shuffling_queue_capacity {shuffling_queue_capacity}")

def __len__(self):
return len(self.reader)

def _iterate(self):
if self.reader.last_row_consumed:
print(f"Resetting Petastorm reader for {self.reader.dataset.paths}")
self.reader.reset()

        # Re-create the data loader for each iteration. There may be leftover data
        # from the last epoch which causes petastorm's BatchedDataLoader to fail to reset.
data_loader = BatchedDataLoader(
self.reader,
batch_size=self.batch_size,
shuffling_queue_capacity=self.shuffling_queue_capacity,
)

for batch in data_loader:
yield batch


class PytorchAsyncDataLoader(AsyncDataLoaderMixin, PytorchDataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
print("Created PytorchAsyncDataLoader. ")
2 changes: 1 addition & 1 deletion horovod/spark/lightning/estimator.py
@@ -176,7 +176,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
'Name of the dataloader class.')

loader_num_epochs = Param(Params._dummy(), 'loader_num_epochs',
'Number of epochs whcih data loader reads in each iteration. If set to None, reader will be in infinite loop mode.')
                              'An epoch is a single pass over all rows in the dataset. If set to None, the reader will be in infinite loop mode and generate unlimited data as needed.')

@keyword_only
def __init__(self,
10 changes: 4 additions & 6 deletions horovod/spark/lightning/remote.py
@@ -119,8 +119,6 @@ def train(serialized_model):

model = deserialize(serialized_model)

# _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else 1.0
# _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else 1.0
_train_steps_per_epoch = train_steps_per_epoch
if _train_steps_per_epoch is None:
_train_steps_per_epoch = int(math.floor(float(train_rows) / batch_size / hvd.size()))
@@ -254,7 +252,7 @@ def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count
**reader_factory_kwargs) as reader:
def dataloader_fn():
return data_loader_cls(reader, batch_size=batch_size,
shuffling_queue_capacity=calculate_shuffle_buffer_size())
shuffling_queue_capacity=calculate_shuffle_buffer_size())
try:
setattr(model, dataloader_attr, dataloader_fn)
yield
@@ -310,9 +308,9 @@ def calculate_shuffle_buffer_size():

def _create_dataloader(feature_columns, input_shapes, metadata, data_loader_cls=None):
if data_loader_cls is None:
# set PetastormAsyncDataLoader as default
from horovod.spark.common.data_loader import PetastormAsyncDataLoader
data_loader_cls = PetastormAsyncDataLoader
# set PytorchAsyncDataLoader as default
from horovod.spark.data_loaders.pytorch_data_loaders import PytorchAsyncDataLoader
data_loader_cls = PytorchAsyncDataLoader

print(f"Using dataloader: {data_loader_cls}")

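
Tying it together, a hedged sketch of routing a loader class through the lightning estimator (store, model, and train_df are placeholders, and data_loader_class is the assumed Param name behind 'Name of the dataloader class.' above):

    from horovod.spark.lightning import TorchEstimator
    from horovod.spark.data_loaders.pytorch_data_loaders import PytorchAsyncDataLoader

    torch_estimator = TorchEstimator(
        num_proc=2,
        store=store,
        model=model,
        input_shapes=[[-1, 1, 28, 28]],
        feature_cols=['features'],
        label_cols=['label'],
        batch_size=64,
        epochs=4,
        data_loader_class=PytorchAsyncDataLoader,  # the default _create_dataloader picks anyway
        loader_num_epochs=None)  # infinite-loop reader mode, per the param doc above
    torch_model = torch_estimator.fit(train_df)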
